Apache Airflow is a popular open-source workflow management platform. Typically, tasks are run remotely by Celery workers for scalability. In AWS, however, scalability can also be achieved in a simpler way using serverless computing services. For example, the ECS Operator allows dockerized tasks to be run and, with the Fargate launch type, they can run in a serverless environment.

The ECS Operator alone is not sufficient because it can take up to several minutes to pull a Docker image and, with the Fargate launch type, to set up a network interface. Due to this latency, it is not suitable for frequently running tasks. The latency of a Lambda function, on the other hand, is negligible, which makes it more suitable for managing such tasks.

This post demonstrates how AWS Lambda can be integrated with Apache Airflow using a custom operator inspired by the ECS Operator.

How it works

The following steps are performed when an Airflow task is executed by the ECS Operator.

  • Running the associated ECS task
  • Waiting for the task to end
  • Checking the task status

The status of a task is checked by inspecting its stopped reason, and an AirflowException is raised if the reason is considered a failure. While checking the status, the associated CloudWatch log events are pulled and printed so that the ECS task's container logs can be found in the Airflow web server.
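
For illustration, below is a minimal sketch of this kind of post-run status check using boto3 - it is not the actual ECS Operator source, and the cluster name and task ARN are hypothetical placeholders supplied by the caller.

import boto3
from airflow.exceptions import AirflowException

ecs = boto3.client("ecs")

def check_task_stopped_ok(cluster, task_arn):
    # describe the stopped task and inspect why it stopped
    resp = ecs.describe_tasks(cluster=cluster, tasks=[task_arn])
    for task in resp["tasks"]:
        if task.get("stoppedReason", "").startswith("Host EC2"):
            raise AirflowException("Task stopped because the host instance was terminated")
        for container in task.get("containers", []):
            # any non-zero container exit code is treated as a failure
            if container.get("exitCode", 0) != 0:
                raise AirflowException(
                    "Container exited with code {0}".format(container["exitCode"])
                )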

The key difference between ECS and Lambda is that the former sends log events to a dedicated CloudWatch Log Stream while the latter may reuse an existing Log Stream due to container reuse. It is therefore not straightforward to pull the execution logs for a specific Lambda invocation. This can be handled by creating a custom CloudWatch Log Group and sending log events to a CloudWatch Log Stream within that custom Log Group. For example, let's say there is a Lambda function named airflow-test. In order to pull log events for a specific Lambda invocation, a custom Log Group (e.g. /airflow/lambda/airflow-test) can be created and, inside the Lambda function, log events can be sent to a Log Stream within the custom Log Group. Note that the CloudWatch Log Stream name can be determined by the operator and sent to the Lambda function in the Lambda payload. In this way, the Lambda function sends log events to a Log Stream that Airflow knows about. The steps of a custom Lambda Operator are then as follows - a minimal sketch of this contract appears after the list.

  • Invoking the Lambda function
  • Waiting for the function to end
  • Checking the invocation status
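
The sketch below summarizes the contract between the operator and the function. It reuses the airflow-test names from the example above, omits error handling, and is not part of the operator code shown in the next section.

import json, uuid, boto3
from datetime import datetime

function_name = "airflow-test"
awslogs_group = "/airflow/lambda/{0}".format(function_name)
# the operator decides the Log Stream name up front ...
awslogs_stream = "{0}/[$LATEST]{1}".format(
    datetime.utcnow().strftime("%Y/%m/%d"), uuid.uuid4().hex
)

# ... and hands it to the function in the payload
payload = {"group_name": awslogs_group, "stream_name": awslogs_stream, "max_len": 6}
boto3.client("lambda").invoke(
    FunctionName=function_name,
    InvocationType="RequestResponse",
    Payload=json.dumps(payload).encode("utf8"),
)

# afterwards, the logs of exactly this invocation can be pulled from the known stream
events = boto3.client("logs").get_log_events(
    logGroupName=awslogs_group, logStreamName=awslogs_stream
)["events"]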

Lambda Operator

Below is a simplified version of the custom Lambda Operator - the full version can be found here. Note that the associated CloudWatch Log Group name is a required argument (awslogs_group), while the Log Stream name is built from the current UTC date, the qualifier and a UUID. The Log Group and Log Stream names are sent to the Lambda function in the payload. Note also that, in _check_success_invocation(), whether a function invocation failed or succeeded is determined by searching for ERROR within the messages of the log events. I find this gives a more stable outcome than relying on the Lambda invocation response (see the short comparison after the listing).

import re, time, json, math, uuid
from datetime import datetime
from botocore import exceptions
from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.utils import apply_defaults
from airflow.contrib.hooks.aws_hook import AwsHook
from airflow.contrib.hooks.aws_logs_hook import AwsLogsHook

class LambdaOperator(BaseOperator):
    @apply_defaults
    def __init__(
        self, function_name, awslogs_group, qualifier="$LATEST",
        payload={}, aws_conn_id=None, region_name=None, *args, **kwargs
    ):
        super(LambdaOperator, self).__init__(**kwargs)
        self.function_name = function_name
        self.qualifier = qualifier
        self.payload = payload
        # log stream is created and added to payload
        self.awslogs_group = awslogs_group
        self.awslogs_stream = "{0}/[{1}]{2}".format(
            datetime.utcnow().strftime("%Y/%m/%d"),
            self.qualifier,
            re.sub("-", "", str(uuid.uuid4())),
        )
        # lambda client and cloudwatch logs hook
        self.client = AwsHook(aws_conn_id=aws_conn_id).get_client_type("lambda")
        self.awslogs_hook = AwsLogsHook(aws_conn_id=aws_conn_id, region_name=region_name)

    def execute(self, context):
        # invoke - wait - check
        payload = json.dumps(
            {
                **{"group_name": self.awslogs_group, "stream_name": self.awslogs_stream},
                **self.payload,
            }
        )
        invoke_opts = {
            "FunctionName": self.function_name,
            "Qualifier": self.qualifier,
            "InvocationType": "RequestResponse",
            "Payload": bytes(payload, encoding="utf8"),
        }
        try:
            resp = self.client.invoke(**invoke_opts)
            self.log.info("Lambda function invoked - StatusCode {0}".format(resp["StatusCode"]))
        except exceptions.ClientError as e:
            raise AirflowException(e.response["Error"])

        self._wait_for_function_ended()

        self._check_success_invocation()
        self.log.info("Lambda Function has been successfully invoked")

    def _wait_for_function_ended(self):
        waiter = self.client.get_waiter("function_active")
        waiter.config.max_attempts = math.ceil(
            self._get_function_timeout() / 5
        )  # poll interval - 5 seconds
        waiter.wait(FunctionName=self.function_name, Qualifier=self.qualifier)

    def _check_success_invocation(self):
        self.log.info("Lambda Function logs output")
        messages, max_trial, current_trial = [], 5, 0
        # sometimes events are not retrieved straight away, so retry up to 5 times
        while True:
            current_trial += 1
            for event in self.awslogs_hook.get_log_events(self.awslogs_group, self.awslogs_stream):
                dt = datetime.fromtimestamp(event["timestamp"] / 1000.0)
                self.log.info("[{}] {}".format(dt.isoformat(), event["message"]))
                messages.append(event["message"])
            if len(messages) > 0 or current_trial > max_trial:
                break
            time.sleep(2)
        if len(messages) == 0:
            raise AirflowException("Fails to get log events")
        for m in reversed(messages):
            if re.search("ERROR", m) is not None:
                raise AirflowException("Lambda Function invocation is not successful")

    def _get_function_timeout(self):
        resp = self.client.get_function(FunctionName=self.function_name, Qualifier=self.qualifier)
        return resp["Configuration"]["Timeout"]
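
For comparison, the invocation response itself also flags function errors, so an alternative check (sketched below, not what the operator above uses) would be to inspect the FunctionError field returned by invoke.

import boto3
from airflow.exceptions import AirflowException

client = boto3.client("lambda")
resp = client.invoke(FunctionName="airflow-test", InvocationType="RequestResponse")
# with RequestResponse invocations, a failed function sets FunctionError to
# "Handled" or "Unhandled" in the response
if resp.get("FunctionError"):
    raise AirflowException("Lambda returned a FunctionError: {0}".format(resp["FunctionError"]))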

Lambda Function

Below is a simplified version of the Lambda function - the full version can be found here. CustomLogManager includes methods to create a CloudWatch Log Stream and to put log events. LambdaDecorator manages actions before and after the function invocation, as well as when an exception occurs - it is used as a decorator and is modified from the lambda_decorators package. Before an invocation, it initializes the custom Log Stream. Log events are put to the Log Stream after the invocation completes or when an exception occurs; in the latter case the traceback is also sent to the Log Stream. The Lambda function itself simply runs a loop and exits, or raises an exception at random.

import time, re, random, logging, traceback, boto3
from datetime import datetime
from botocore import exceptions
from io import StringIO
from functools import update_wrapper

# save logs to an in-memory stream
stream = StringIO()
logger = logging.getLogger()
log_handler = logging.StreamHandler(stream)
formatter = logging.Formatter("%(levelname)-8s %(asctime)-s %(name)-12s %(message)s")
log_handler.setFormatter(formatter)
logger.addHandler(log_handler)
logger.setLevel(logging.INFO)

cwlogs = boto3.client("logs")

class CustomLogManager(object):
    # create log stream and send logs to it
    def __init__(self, event):
        self.group_name = event["group_name"]
        self.stream_name = event["stream_name"]

    def has_log_group(self):
        group_exists = True
        try:
            resp = cwlogs.describe_log_groups(logGroupNamePrefix=self.group_name)
            group_exists = len(resp["logGroups"]) > 0
        except exceptions.ClientError as e:
            logger.error(e.response["Error"]["Code"])
            group_exists = False
        return group_exists

    def create_log_stream(self):
        is_created = True
        try:
            cwlogs.create_log_stream(logGroupName=self.group_name, logStreamName=self.stream_name)
        except exceptions.ClientError as e:
            logger.error(e.response["Error"]["Code"])
            is_created = False
        return is_created

    def delete_log_stream(self):
        is_deleted = True
        try:
            cwlogs.delete_log_stream(logGroupName=self.group_name, logStreamName=self.stream_name)
        except exceptions.ClientError as e:
            # ResourceNotFoundException is ok
            codes = [
                "InvalidParameterException",
                "OperationAbortedException",
                "ServiceUnavailableException",
            ]
            if e.response["Error"]["Code"] in codes:
                logger.error(e.response["Error"]["Code"])
                is_deleted = False
        return is_deleted

    def init_log_stream(self):
        if not all([self.has_log_group(), self.delete_log_stream(), self.create_log_stream()]):
            raise Exception("fails to create log stream")
        logger.info("log stream created")

    def create_log_events(self, stream):
        fmt = "%Y-%m-%d %H:%M:%S,%f"
        log_events = []
        for m in [s for s in stream.getvalue().split("\n") if s]:
            match = re.search(r"\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2},\d{3}", m)
            dt_str = match.group() if match else datetime.utcnow().strftime(fmt)
            log_events.append(
                {"timestamp": int(datetime.strptime(dt_str, fmt).timestamp()) * 1000, "message": m}
            )
        return log_events

    def put_log_events(self, stream):
        try:
            resp = cwlogs.put_log_events(
                logGroupName=self.group_name,
                logStreamName=self.stream_name,
                logEvents=self.create_log_events(stream),
            )
            logger.info(resp)
        except exceptions.ClientError as e:
            logger.error(e)
            raise Exception("fails to put log events")

class LambdaDecorator(object):
    # keep functions to run before, after and on exception
    # modified from lambda_decorators (https://lambda-decorators.readthedocs.io/en/latest/)
    def __init__(self, handler):
        update_wrapper(self, handler)
        self.handler = handler

    def __call__(self, event, context):
        try:
            self.event = event
            self.log_manager = CustomLogManager(event)
            return self.after(self.handler(*self.before(event, context)))
        except Exception as exception:
            return self.on_exception(exception)

    def before(self, event, context):
        # remove existing logs
        stream.seek(0)
        stream.truncate(0)
        # create log stream
        self.log_manager.init_log_stream()
        logger.info("Start Request")
        return event, context

    def after(self, retval):
        logger.info("End Request")
        # send logs to the log stream
        self.log_manager.put_log_events(stream)
        return retval

    def on_exception(self, exception):
        logger.error(str(exception))
        # log traceback
        logger.error(traceback.format_exc())
        # send logs to the log stream
        self.log_manager.put_log_events(stream)
        return str(exception)

@LambdaDecorator
def lambda_handler(event, context):
    max_len = event.get("max_len", 6)
    fails_at = random.randint(0, max_len * 2)
    for i in range(max_len):
        if i != fails_at:
            logger.info("current run {0}".format(i))
        else:
            raise Exception("fails at {0}".format(i))
        time.sleep(1)
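
Before deploying, the handler can be exercised locally by appending a small driver to the same module. This is only a sanity-check sketch; it assumes AWS credentials are configured locally and the /airflow/lambda/airflow-test Log Group already exists, and the stream name shown is a hypothetical local value rather than one generated by the operator.

import uuid

if __name__ == "__main__":
    # invoke the decorated handler once with a hand-made event
    event = {
        "group_name": "/airflow/lambda/airflow-test",
        "stream_name": "local-test/{0}".format(uuid.uuid4().hex),
        "max_len": 3,
    }
    print(lambda_handler(event, None))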

Run Lambda Task

A simple demo DAG with a single task is created as follows. It just runs the Lambda function every 30 seconds.

from airflow import DAG
from airflow.utils.dates import days_ago
from datetime import timedelta
from dags.operators import LambdaOperator

function_name = "airflow-test"

demo_dag = DAG(
    dag_id="demo-dag",
    start_date=days_ago(1),
    catchup=False,
    max_active_runs=1,
    concurrency=1,
    schedule_interval=timedelta(seconds=30),
)

demo_task = LambdaOperator(
    task_id="demo-task",
    function_name=function_name,
    awslogs_group="/airflow/lambda/{0}".format(function_name),
    payload={"max_len": 6},
    dag=demo_dag,
)

The task can be tested with the following Docker Compose services. Note that the web server and the scheduler are split into separate services, although this doesn't seem to be recommended for the Local Executor - I had an issue launching Airflow in ECS when they were combined.

version: "3.7"
services:
  postgres:
    image: postgres:11
    container_name: airflow-postgres
    networks:
      - airflow-net
    ports:
      - 5432:5432
    environment:
      - POSTGRES_USER=airflow
      - POSTGRES_PASSWORD=airflow
      - POSTGRES_DB=airflow
  webserver:
    image: puckel/docker-airflow:1.10.6
    container_name: webserver
    command: webserver
    networks:
      - airflow-net
    user: root # for DockerOperator
    volumes:
      - ${HOME}/.aws:/root/.aws # run as root user
      - ./requirements.txt:/requirements.txt
      - ./dags:/usr/local/airflow/dags
      - ./config/airflow.cfg:/usr/local/airflow/config/airflow.cfg
      - ./entrypoint.sh:/entrypoint.sh # override entrypoint
      - /var/run/docker.sock:/var/run/docker.sock # for DockerOperator
      - ./custom:/usr/local/airflow/custom
    ports:
      - 8080:8080
    environment:
      - AIRFLOW__CORE__EXECUTOR=LocalExecutor
      - AIRFLOW__CORE__LOAD_EXAMPLES=False
      - AIRFLOW__CORE__LOGGING_LEVEL=INFO
      - AIRFLOW__CORE__FERNET_KEY=Gg3ELN1gITETZAbBQpLDBI1y2P0d7gHLe_7FwcDjmKc=
      - AIRFLOW__CORE__REMOTE_LOGGING=True
      - AIRFLOW__CORE__REMOTE_BASE_LOG_FOLDER=s3://airflow-lambda-logs
      - AIRFLOW__CORE__ENCRYPT_S3_LOGS=True
      - POSTGRES_HOST=postgres
      - POSTGRES_USER=airflow
      - POSTGRES_PASSWORD=airflow
      - POSTGRES_DB=airflow
      - AWS_DEFAULT_REGION=ap-southeast-2
    restart: always
    healthcheck:
      test: ["CMD-SHELL", "[ -f /usr/local/airflow/config/airflow-webserver.pid ]"]
      interval: 30s
      timeout: 30s
      retries: 3
  scheduler:
    image: puckel/docker-airflow:1.10.6
    container_name: scheduler
    command: scheduler
    networks:
      - airflow-net
    user: root # for DockerOperator
    volumes:
      - ${HOME}/.aws:/root/.aws # run as root user
      - ./requirements.txt:/requirements.txt
      - ./logs:/usr/local/airflow/logs
      - ./dags:/usr/local/airflow/dags
      - ./config/airflow.cfg:/usr/local/airflow/config/airflow.cfg
      - ./entrypoint.sh:/entrypoint.sh # override entrypoint
      - /var/run/docker.sock:/var/run/docker.sock # for DockerOperator
      - ./custom:/usr/local/airflow/custom
    environment:
      - AIRFLOW__CORE__EXECUTOR=LocalExecutor
      - AIRFLOW__CORE__LOAD_EXAMPLES=False
      - AIRFLOW__CORE__LOGGING_LEVEL=INFO
      - AIRFLOW__CORE__FERNET_KEY=Gg3ELN1gITETZAbBQpLDBI1y2P0d7gHLe_7FwcDjmKc=
      - AIRFLOW__CORE__REMOTE_LOGGING=True
      - AIRFLOW__CORE__REMOTE_BASE_LOG_FOLDER=s3://airflow-lambda-logs
      - AIRFLOW__CORE__ENCRYPT_S3_LOGS=True
      - POSTGRES_HOST=postgres
      - POSTGRES_USER=airflow
      - POSTGRES_PASSWORD=airflow
      - POSTGRES_DB=airflow
      - AWS_DEFAULT_REGION=ap-southeast-2
    restart: always

networks:
  airflow-net:
    name: airflow-network

Below is the demo DAG after running for a while.

Lambda logs (and the traceback on failure) can be found for both succeeded and failed tasks.