Athena in Python

Keywords: AWS, Amazon, Athena, Python, pyathena

This post covers how to use Athena from Python.

In the Python world there is a standard called DB API 2.0. No matter which SQL database sits underneath (as long as it is a SQL database), any driver library that follows this standard lets you open a connection, run cursor.execute(sql_statement), and iterate over the resulting cursor, with each row returned as a tuple. The production-grade SQL toolkit sqlalchemy also has solid support for any driver that implements DB API 2.0.
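A minimal sketch of this pattern against Athena with pyathena (assuming default AWS credentials; the bucket and region below are placeholders):

from pyathena import connect

# open a DB API 2.0 connection; s3_staging_dir is where Athena writes query result files
conn = connect(
    s3_staging_dir="s3://my-bucket/athena/results/",  # placeholder bucket
    region_name="us-east-1",  # placeholder region
)
cursor = conn.cursor()
cursor.execute("SELECT 1 AS one, 2 AS two")
for row in cursor:  # each row comes back as a plain tuple, e.g. (1, 2)
    print(row)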

The Python community's DB API 2.0 implementation for Athena is https://pypi.org/project/pyathena/. Under the hood, Athena keeps its data in S3 buckets, and pyathena is a wrapper around the Athena API that implements the DB API 2.0 standard. If you want to work with Athena from Python, the best experience is to follow the pyathena documentation and use it together with sqlalchemy and pandas.
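For sqlalchemy, pyathena registers an awsathena dialect, so you can build an engine from a connection string (a sketch assuming default AWS credentials; the region, schema, and bucket are placeholders):

from urllib.parse import quote_plus

from sqlalchemy import create_engine, text

# placeholder region, Glue database (schema), and query result location
staging_dir = quote_plus("s3://my-bucket/athena/results/")
engine = create_engine(
    "awsathena+rest://:@athena.us-east-1.amazonaws.com:443/"
    f"athena_in_python?s3_staging_dir={staging_dir}"
)
with engine.connect() as connection:
    print(connection.execute(text("SELECT 1")).fetchall())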

requirements.txt Dependencies.

# test dependencies for pyathena
awswrangler>=3.0.0,<4.0.0
jsonpath_ng>=1.6.0,<2.0.0
pyathena[pandas,sqlalchemy]>=3.0.0,<4.0.0
pandas>=2.0.0,<3.0.0
sqlalchemy>=2.0.0,<3.0.0
s3pathlib>=2.1.2,<3.0.0
boto_session_manager>=1.5.4,<2.0.0
smart_open>=6.0.0,<7.0.0

prepare_data.py Prepare the test dataset.

# -*- coding: utf-8 -*-

import numpy as np
import pandas as pd
import awswrangler as wr
from boto_session_manager import BotoSesManager
from s3pathlib import S3Path, context

# AWS profile, S3 location, and Glue database / table used by the examples
class Config:
    aws_profile = "bmt_app_dev_us_east_1"
    bucket = "bmt-app-dev-us-east-1-data"
    prefix = "poc/2023-12-01-athena-in-python"
    glue_database = "athena_in_python"
    glue_table = "events"


bsm = BotoSesManager(profile_name=Config.aws_profile)
context.attach_boto_session(bsm.boto_ses)


if __name__ == "__main__":
    s3dir = S3Path(
        Config.bucket,
        Config.prefix,
        "events",
    ).to_dir()
    print(f"preview data at {s3dir.console_url}")

    # create the Glue database in the Glue Data Catalog if it does not exist yet
    databases = wr.catalog.databases(boto3_session=bsm.boto_ses)
    if Config.glue_database not in databases.values:
        wr.catalog.create_database(Config.glue_database, boto3_session=bsm.boto_ses)

    # generate dummy data
    n_rows = 1000
    df = pd.DataFrame()
    df["id"] = range(1, n_rows + 1)
    df["time"] = pd.date_range(start="2000-01-01", end="2000-03-31", periods=n_rows)
    df["category"] = np.random.randint(1, 1 + 3, size=n_rows)
    df["value"] = np.random.randint(1, 1 + 100, size=n_rows)

    # write the CSV files to S3 and register the dataset as a Glue table
    wr.s3.to_csv(
        df=df,
        path=s3dir.uri,
        dataset=True,
        database=Config.glue_database,
        table=Config.glue_table,
        mode="overwrite",
        boto3_session=bsm.boto_ses,
    )
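
Optionally, you can check that awswrangler registered the table in the Glue Data Catalog (a quick verification sketch that reuses Config and bsm from prepare_data.py):

import awswrangler as wr
from prepare_data import Config, bsm

# print the column schema that was registered in the Glue Data Catalog
schema_df = wr.catalog.table(
    database=Config.glue_database,
    table=Config.glue_table,
    boto3_session=bsm.boto_ses,
)
print(schema_df)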

query_data.py pyathena usage example.

# -*- coding: utf-8 -*-

import pandas as pd
from s3pathlib import S3Path
from pyathena import connect
from prepare_data import Config, bsm

# S3 location where Athena stores query result files
s3path_athena_result = S3Path(Config.bucket, "athena/results/")

# define connection, use AWS CLI named profile for authentication
conn = connect(
    s3_staging_dir=s3path_athena_result.uri,
    profile_name=Config.aws_profile,
    region_name=bsm.aws_region,
)

# define the SQL statement, use ${database}.${table} as t to specify the table
sql = f"""
SELECT 
    t.category,
    AVG(t.value) as average_value  
FROM {Config.glue_database}.{Config.glue_table} t
GROUP BY t.category
ORDER BY t.category
"""

# execute the SQL query and load the result into a pandas DataFrame
# (pandas warns that it only officially supports SQLAlchemy connectables, but a DB API 2.0 connection works)
df = pd.read_sql_query(sql, conn)
print(df)
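
pyathena also ships a PandasCursor that fetches query results directly into a DataFrame, which avoids the pandas warning about non-SQLAlchemy connections (a sketch reusing Config, bsm, s3path_athena_result, and sql from query_data.py above):

from pyathena import connect
from pyathena.pandas.cursor import PandasCursor

# same connection parameters as above, but the cursor returns pandas DataFrames
cursor = connect(
    s3_staging_dir=s3path_athena_result.uri,
    profile_name=Config.aws_profile,
    region_name=bsm.aws_region,
    cursor_class=PandasCursor,
).cursor()
df = cursor.execute(sql).as_pandas()
print(df)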