AWS Step Functions - Distributed Map Example#

Keywords: AWS Step Function, StepFunction, SFN, State Machine, StateMachine, Lambda

早期 AWS Step Function 的 Map 并行处理只有 inline 模式. 这种并发模式指的是这个并发是在一个同一个 scheduler 的子进程中调度的. 类似于你一个机器上 fork 了一个 thread. 所以这种模式的并发不可能做的很高, 所以它有 40 个 concurrency 的限制.

而在 2022-12-01, AWS 发布了 Distributed Map 功能, 它允许同时并发 10,000 个子任务. 这种模式的底层是用的 sub step function workflow execution 的模式来实现的, 所以作为一个分布式系统, 它的并发是可以做的很高的.

Item Sources#

给每个子任务指定参数的方式有很多种:

  1. State Input: 上一步返回的结果就是一个 List of Dictionary. 里面每个 Dictionary 就是一个子任务的参数, List 中元素的个数就是并发数.

  2. Amazon S3 - S3 Object List: 把某个 S3 prefix 下的所有 object 作为输入, 每个 object 都有是一个 JSON, 里面是子任务的参数.

  3. Amazon S3 - Json File in S3: 把 List of Dictionary 用 JSON 序列化保存在 S3 object 中. 然后 Step Function 到这个 S3 object 中去读数据.

  4. Amazon S3 - CSV File in S3: 跟 Json File in S3 类似, 唯一的区别是文件格式是 CSV.

  5. Amazon S3 - S3 Inventory: 跟 S3 Object List 类似, 不过是在一个指定的 Manifest File (一个 S3 Object) 中去读所有文件的列表, 这个 Manifest File 是一个 Json Line 文件, 每一行是一条数据. 注意, 每一行只是包含了每个子任务的参数所在的 S3 object 的 uri, 并不包含参数数据本身.

我个人最喜欢 #3, 因为它简单, 直观, 灵活 (适用于各种复杂的参数列表). 只有在数据量巨大的时候才会用到 #5.

Amazon S3 Json File in S3 Example#

这里给出了一个具体的例子.

这里的关键是 sfn_def 中的这个部分, 指定了 Map 要从哪里读参数:

"Parameters": {
  "Bucket.$": "$.map_input_s3_bucket",
  "Key.$": "$.map_input_s3_key"
}

以及 Lambda Function 的 input 是直接从 state input 过来的.

SFN 的定义

sfn_def.json
 1{
 2  "Comment": "A description of my state machine",
 3  "StartAt": "Map",
 4  "States": {
 5    "Map": {
 6      "Type": "Map",
 7      "ItemProcessor": {
 8        "ProcessorConfig": {
 9          "Mode": "DISTRIBUTED",
10          "ExecutionType": "STANDARD"
11        },
12        "StartAt": "Lambda Invoke",
13        "States": {
14          "Lambda Invoke": {
15            "Type": "Task",
16            "Resource": "arn:aws:states:::lambda:invoke",
17            "OutputPath": "$.Payload",
18            "Parameters": {
19              "FunctionName": "arn:aws:lambda:us-east-1:878625312159:function:sfn-poc-lbd-1:$LATEST",
20              "Payload.$": "$"
21            },
22            "Retry": [
23              {
24                "ErrorEquals": [
25                  "Lambda.ServiceException",
26                  "Lambda.AWSLambdaException",
27                  "Lambda.SdkClientException",
28                  "Lambda.TooManyRequestsException"
29                ],
30                "IntervalSeconds": 1,
31                "MaxAttempts": 3,
32                "BackoffRate": 2
33              }
34            ],
35            "End": true
36          }
37        }
38      },
39      "End": true,
40      "Label": "Map",
41      "MaxConcurrency": 1000,
42      "ItemReader": {
43        "Resource": "arn:aws:states:::s3:getObject",
44        "ReaderConfig": {
45          "InputType": "JSON"
46        },
47        "Parameters": {
48          "Bucket.$": "$.map_input_s3_bucket",
49          "Key.$": "$.map_input_s3_key"
50        }
51      }
52    }
53  }
54}

Lambda 的源码

lbd.py
 1# -*- coding: utf-8 -*-
 2
 3from pprint import pprint
 4import boto3
 5
 6s3_client = boto3.client("s3")
 7
 8
 9def main(bucket: str, key: str) -> int:
10    return s3_client.get_object(Bucket=bucket, Key=key)["ContentLength"]
11
12
13def lambda_handler(event: dict, context):
14    print("----- event -----")
15    pprint(event)
16    return main(event["bucket"], event["key"])

测试代码

sfn_test.py
 1# -*- coding: utf-8 -*-
 2
 3import json
 4from s3pathlib import S3Path, context
 5from boto_session_manager import BotoSesManager
 6
 7bsm = BotoSesManager(profile_name="bmt_app_dev_us_east_1")
 8context.attach_boto_session(boto_ses=bsm.boto_ses)
 9sfn_arn = f"arn:aws:states:{bsm.aws_region}:{bsm.aws_account_id}:stateMachine:sfn-poc"
10lbd_arn = f"arn:aws:lambda:{bsm.aws_region}:{bsm.aws_account_id}:function:sfn-poc-lbd-1"
11bucket = f"{bsm.aws_account_alias}-{bsm.aws_region}-data"
12key = "projects/tmp/"
13
14n_file = 3
15s3dir_tmp = S3Path(bucket, key)
16
17# prepare map payload
18map_payload = []
19for ith in range(1, 1 + n_file):
20    s3path = s3dir_tmp.joinpath(f"data/{ith}.json")
21    s3path.write_text("hello world")
22    map_payload.append(
23        {
24            "bucket": s3path.bucket,
25            "key": s3path.key,
26        }
27    )
28s3path = s3dir_tmp.joinpath("map_payload.json")
29s3path.write_text(json.dumps(map_payload))
30
31input_data = {
32    "map_input_s3_bucket": s3path.bucket,
33    "map_input_s3_key": s3path.key,
34}
35bsm.sfn_client.start_execution(
36    stateMachineArn=sfn_arn,
37    input=json.dumps(input_data),
38)