Tags: python, django, celery, amazon-ecs

Celery Task ECS Termination Issue - Need Help Updating Decorator for Handling ProtectionEnabled State Changes


Explanation:

I have a Django application that runs multiple Celery tasks on AWS Elastic Container Service (ECS), using SQS as the broker. The issue is that once a Celery task completes, the worker starts the next Celery task inside the same, already running ECS task. My decorator switches ProtectionEnabled from true to false when the previous Celery task completes, so about 20 seconds later ECS terminates the container and the newly started Celery task no longer works. Below is the command I use to start the Celery worker.

celery -A myapp_settings.celery worker --concurrency=1 -l info -Q sqs-celery

I use CloudWatch alarms that monitor the number of messages in the broker to terminate ECS tasks whose work is complete. The problem is that Celery starts a new Celery task in an existing ECS task once the previous one has completed. That alone would not be a problem, but my decorator changes ProtectionEnabled from true to false, so after 20 seconds the ECS task is terminated and the newly started Celery task stops working.

Question:

I am considering updating my decorator so that it sets ProtectionEnabled back from false to true when a new Celery task starts, but I am unsure how to implement this.

Code:

container_decorator.py

class ContainerAgent:
    """Small client for the ECS container agent's task scale-in protection endpoint.

    Every call is a JSON ``PUT`` to ``<ecs_agent_uri><path>``; transport, HTTP and
    response-shape failures are all surfaced as ``ContainerAgent.Error`` subclasses
    so callers only need to catch one exception type.
    """

    class Error(Exception):
        """Base class for every error raised by ContainerAgent."""

    class RequestError(Error, IOError):
        """The HTTP request failed, returned an error status, or a non-JSON body."""

    def __init__(
        self,
        ecs_agent_uri: str,
        timeout: int = 10,
        session: Optional["requests.Session"] = None,
        logger: Optional[logging.Logger] = None,
    ) -> None:
        """
        :param ecs_agent_uri: Base URI of the ECS container agent (the value of
            ``ECS_AGENT_URI`` inside an ECS task).
        :param timeout: Per-request timeout in seconds, passed to ``Session.put``.
        :param session: Optional pre-configured ``requests.Session``; a new one is
            created when omitted.
        :param logger: Optional logger; defaults to a logger named after the class.
        """
        self._ecs_agent_uri = ecs_agent_uri
        self._timeout = timeout

        self._session = session or requests.Session()
        self._logger = logger or logging.getLogger(self.__class__.__name__)

    def _request(self, *, path: str, data: Optional[dict] = None) -> dict:
        """PUT ``data`` as JSON to ``path`` on the agent and return the decoded body.

        :raises ContainerAgent.RequestError: on transport/HTTP errors or a body
            that is not valid JSON.
        """
        url = f"{self._ecs_agent_uri}{path}"
        self._logger.info(f"Performing request... {url=} {data=}")

        try:
            response = self._session.put(url=url, json=data, timeout=self._timeout)
            self._logger.info(f"Got response. {response.status_code=} {response.content=}")

            response.raise_for_status()
            return response.json()
        except requests.RequestException as e:
            response_body = e.response.text if e.response is not None else None
            self._logger.warning(f"Request error! {url=} {data=} {e=} {response_body=}")

            raise self.RequestError(str(e)) from e
        except ValueError as e:
            # Older requests versions (< 2.27) raise a plain ValueError rather than
            # a RequestException for a non-JSON body; wrap it so callers that only
            # catch ContainerAgent.Error are not bypassed.
            self._logger.warning(f"Invalid JSON response! {url=} {data=} {e=}")

            raise self.RequestError(str(e)) from e

    def toggle_scale_in_protection(self, *, enable: bool = True, expire_in_minutes: int = 2880):
        """Set scale-in protection for the current ECS task.

        :param enable: Desired ``ProtectionEnabled`` value.
        :param expire_in_minutes: Sent as ``ExpiresInMinutes`` (2880 = 48 hours).
        :return: The ``ProtectionEnabled`` value reported back by the agent.
        :raises ContainerAgent.RequestError: if the request fails.
        :raises ContainerAgent.Error: if the response carries no protection state.
        """
        response = self._request(
            path="/task-protection/v1/state",
            data={"ProtectionEnabled": enable, "ExpiresInMinutes": expire_in_minutes},
        )
        return self._protection_enabled_from(response)

    @classmethod
    def _protection_enabled_from(cls, response):
        """Extract ``response["protection"]["ProtectionEnabled"]``.

        Missing keys, ``null`` values and non-object bodies all raise ``cls.Error``
        instead of leaking ``KeyError``/``TypeError`` to the caller.
        """
        try:
            return response["protection"]["ProtectionEnabled"]
        except (KeyError, TypeError) as e:
            raise cls.Error(f"Task scale-in protection endpoint error: {response=}") from e


def enable_scale_in_protection(*, logger: Optional[logging.Logger] = None):
    """Decorator factory: hold ECS scale-in protection while the wrapped function runs.

    When ``ECS_AGENT_URI`` is not set, the function is returned undecorated and a
    warning is logged. Otherwise protection is enabled before every call and
    disabled after it returns or raises.

    Failures talking to the agent are logged, never raised: a failed enable still
    runs the function, and a failed disable does not replace the function's
    return value or exception.

    NOTE(review): protection is disabled unconditionally after each call, even if
    the worker immediately picks up another task, which leaves a window in which
    ECS may scale in a container that has just started new work.

    :param logger: Logger for warnings; falls back to the ``logging`` module.
    """
    log = logger or logging

    def decorator(f):
        if not (ecs_agent_uri := os.getenv("ECS_AGENT_URI")):
            log.warning(f"Scale-in protection not enabled. {ecs_agent_uri=}")
            return f

        client = ContainerAgent(ecs_agent_uri=ecs_agent_uri, logger=logger)

        @wraps(f)
        def wrapper(*args, **kwargs):
            try:
                client.toggle_scale_in_protection(enable=True)
            except client.Error as e:
                log.warning(f"Scale-in protection not enabled. {e}")
                protection_set = False
            else:
                protection_set = True

            try:
                return f(*args, **kwargs)
            finally:
                if protection_set:
                    try:
                        client.toggle_scale_in_protection(enable=False)
                    except client.Error as e:
                        # An exception raised here would replace the wrapped
                        # function's result or in-flight exception; log instead.
                        log.warning(f"Scale-in protection not disabled. {e}")

        return wrapper

    return decorator

celery_tasks.py

@shared_task(name="add_spider_schedule", base=AbortableTask)
@enable_scale_in_protection(logger=get_task_logger(__name__))
def add_spider_schedule(user_id, spider_id):
    """Celery entry point for running a spider schedule.

    Only runs under the production settings module; under any other
    ``DJANGO_SETTINGS_MODULE`` it prints a notice and returns ``None``.
    """
    if os.environ.get('DJANGO_SETTINGS_MODULE') != 'myapp_settings.settings.production':
        print('Unknown settings module')
        return None
    return add_spider_schedule_production(user_id, spider_id, add_spider_schedule)

def add_spider_schedule_production(user_id, spider_id, task_object):
    """
    Run the configured spider for ``spider_id`` inside the current Celery task.

    Loads the spider's config, YAML and template files from S3, imports the
    spider config as a module, and awaits its ``run`` coroutine while polling
    ``task_object.is_aborted()`` about every 0.1 s. On abort the spider coroutine
    is cancelled (not left pending on the event loop) and an exception is raised.

    For the duration of the call stdout/stderr are redirected to a task logger
    so that ``print`` output (here and inside the spider script) is logged; they
    are restored afterwards even if the spider fails or the task is aborted.

    :param user_id: ID of a ``User`` that must exist (looked up, not otherwise used).
    :param spider_id: The ID of the spider to schedule.
    :param task_object: The abortable Celery task executing this call.
    :return: A string representation of the spider and task IDs.
    :raises Exception: if the spider fails or the task is aborted.
    """
    # Route all prints (both below and inside the spider script) to the task logger.
    logger = get_task_logger(task_object.request.id)
    old_outs = sys.stdout, sys.stderr
    try:
        rlevel = add_spider_schedule.app.conf.worker_redirect_stdouts_level
        add_spider_schedule.app.log.redirect_stdouts_to_logger(logger, rlevel)

        spider = Spider.objects.get(id=spider_id)

        # Looked up only so that a missing user fails the task; the instance is unused.
        User.objects.get(id=user_id)

        # Names of the relevant files, plus MongoDB target, from the model instance.
        spider_config_file = spider.spider_config_file.file
        yaml_config_file = spider.yaml_config_file.file
        template_file = spider.template_file.file
        mongodb_database_name = spider.mongodb_collection.database_name
        mongodb_collection_name = spider.mongodb_collection.collection_name

        # Read the contents of the files from the S3 bucket.
        spider_config_file_contents = load_content_from_s3(AWS_STORAGE_BUCKET_NAME, rf"{PUBLIC_MEDIA_LOCATION}/{spider_config_file}")
        yaml_config_path = load_content_from_s3(AWS_STORAGE_BUCKET_NAME, rf"{PUBLIC_MEDIA_LOCATION}/{yaml_config_file}")
        input_file_path = load_content_from_s3(AWS_STORAGE_BUCKET_NAME, rf"{PUBLIC_MEDIA_LOCATION}/{template_file}")

        # Extra keyword arguments for the spider are stored JSON-encoded.
        kwargs = json.loads(spider.kwargs) if spider.kwargs else {}

        # Create a module from the contents of the spider_config_file.
        spider_module = import_module(spider_config_file_contents, "spider_config")

        async def run_spider():
            try:
                await spider_module.run(
                    yaml_config_path=yaml_config_path,
                    input_file_path=input_file_path,
                    mongodb_name=mongodb_database_name,
                    mongodb_collection_name=mongodb_collection_name,
                    task_object=task_object,
                    mode="sf-lab",
                    **kwargs
                )
            except Exception as e:
                raise Exception(f"An error occurred while running the spider: {e}") from e

        async def run_until_finished_or_aborted():
            spider_task = asyncio.ensure_future(run_spider())
            try:
                while True:
                    if task_object.is_aborted():
                        print("Parralel function detected that task was cancelled.")
                        raise Exception("task was cancelled")
                    done, _ = await asyncio.wait({spider_task}, timeout=0.1)
                    if done:
                        # Re-raises the spider's exception, if any.
                        return spider_task.result()
            finally:
                if not spider_task.done():
                    # Cancel the spider so it does not keep running on the (reused)
                    # event loop during later tasks. asyncio.wait never raises the
                    # task's exception. NOTE(review): blocks if the spider
                    # suppresses cancellation.
                    spider_task.cancel()
                    await asyncio.wait({spider_task})

        # NOTE(review): get_event_loop() may emit a DeprecationWarning on newer
        # Python versions when no loop is set; kept to preserve loop reuse.
        loop = asyncio.get_event_loop()
        loop.run_until_complete(run_until_finished_or_aborted())
    finally:
        sys.stdout, sys.stderr = old_outs  # needed for logging part

    return f"[spider: {spider_id}, task_id: {task_object.request.id}]"


Solution

  • You're on the right track. Rather than always disabling scale-in protection when a protected task is finished, you need to check whether or not the current worker has any remaining tasks that need to be protected, and only disable scale-in protection if not (Assuming a single worker per ECS task).

    I handled this by setting up a dedicated queue for Celery tasks that should be protected from scale-in events:

    from celery import Celery
    from kombu import Exchange, Queue
    
    # Init App
    app = Celery(...)
    
    # Two queues: "default" for ordinary tasks and "scale-in-protection" for tasks
    # that must hold ECS scale-in protection while they run.
    app.conf.update(
      task_default_queue="default",
      task_default_exchange_type="direct",
      task_default_routing_key="default",
      task_queues=(
        Queue("default", Exchange("default"), routing_key="default"),
        Queue(
          "scale-in-protection",
          Exchange("scale-in-protection"),
          routing_key="scale-in-protection",
        ),
      ),
    )
    

    Next, I used the celeryd_after_setup signal to store the Celery worker ID in an environment variable so that it can be accessed from within the task.

    @celeryd_after_setup.connect
    def store_celery_worker_id(sender, instance, **kwargs):
        """Expose this worker's ID to tasks via the CELERY_CURRENT_WORKER_ID env var.

        Celery passes the worker's node name as ``sender`` for this signal; it is
        stored once, right after the worker has been set up.
        """
        os.environ.update(CELERY_CURRENT_WORKER_ID=sender)
    

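    Lastly, I modified the decorator you wrote as follows. Note that it relies on `app.control.inspect()`; Celery's documentation lists worker remote-control commands as unsupported on the SQS transport, so with an SQS broker the inspection receives no replies and the decorator has to fall back to a safe default: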

    def _worker_has_other_protected_tasks(worker_id, current_task_id, routing_key="scale-in-protection"):
        """Return whether ``worker_id`` holds protected tasks other than ``current_task_id``.

        Looks at the worker's active, reserved and scheduled tasks via
        ``app.control.inspect``. Returns ``None`` when the answer is unknown
        because the worker did not reply to one of the queries.
        """
        inspector = app.control.inspect(destination=[worker_id])
        for query in (inspector.active, inspector.reserved, inspector.scheduled):
            replies = query()
            if not replies or worker_id not in replies:
                return None
            for entry in replies[worker_id] or ():
                # scheduled() nests the task under "request"; active()/reserved()
                # return the task info directly.
                info = entry.get("request", entry)
                if info.get("id") == current_task_id:
                    continue
                if (info.get("delivery_info") or {}).get("routing_key") == routing_key:
                    return True
        return False
    
    def _release_protection_if_idle(client, task_id, log):
        """Disable scale-in protection unless this worker may still have protected work.

        Protection is left enabled whenever the state is unknown (no worker ID,
        inspection error, no reply). Inspection and agent errors are logged, not
        raised, so they cannot mask the finished task's result or exception.
        """
        current_worker_id = os.getenv("CELERY_CURRENT_WORKER_ID")
        if not current_worker_id:
            log.warning("Current Worker ID not set. Leaving scale-in protection enabled.")
            return
    
        try:
            remaining = _worker_has_other_protected_tasks(current_worker_id, task_id)
        except Exception as e:  # broker/connection errors during inspection
            log.warning(f"Could not inspect worker tasks ({e!r}). Leaving scale-in protection enabled.")
            return
    
        if remaining is None:
            log.warning("Worker did not reply to inspection. Leaving scale-in protection enabled.")
        elif remaining:
            log.info("Worker has remaining protected tasks. Leaving scale-in protection enabled.")
        else:
            log.info("Worker has no remaining protected tasks. Disabling scale-in protection.")
            try:
                client.toggle_scale_in_protection(enable=False)
            except client.Error as e:
                log.warning(f"Scale-in protection not disabled. {e}")
    
    def task_with_scale_in_protection(func=None, *, logger: Optional[logging.Logger] = None):
        """Register a Celery task on the "scale-in-protection" queue that holds ECS
        scale-in protection while it runs.

        Usable bare (``@task_with_scale_in_protection``) or called
        (``@task_with_scale_in_protection(logger=...)``).

        Protection is enabled before each run. Afterwards it is disabled only if
        this worker is known to have no other protected tasks (active, reserved or
        scheduled); if that cannot be determined it is left enabled.

        NOTE: still vulnerable to a race condition. If two or more protected tasks
        finish at the same time, they may see each other as active and both skip
        disabling scale-in protection.
        """
        log = logger or logging
    
        def decorator(f):
            client = None
            ecs_agent_uri = os.getenv("ECS_AGENT_URI")
            if not ecs_agent_uri:
                log.warning(f"Scale-in protection not enabled. {ecs_agent_uri}")
            else:
                client = ContainerAgent(ecs_agent_uri=ecs_agent_uri, logger=logger)
    
            @app.task(bind=True, queue="scale-in-protection")
            @wraps(f)
            def wrapper(self, *args, **kwargs):
                protection_set = False
                if client is not None:
                    try:
                        client.toggle_scale_in_protection(enable=True)
                        protection_set = True
                    except client.Error as e:
                        log.warning(f"Scale-in protection not enabled. {e}")
    
                try:
                    # Run the celery task
                    return f(*args, **kwargs)
                finally:
                    if protection_set:
                        _release_protection_if_idle(client, self.request.id, log)
    
            return wrapper
    
        if func is not None:
            # Used bare: @task_with_scale_in_protection
            return decorator(func)
        return decorator
    

    Usage as follows:

    # task_with_scale_in_protection is a decorator factory, so it must be called.
    @task_with_scale_in_protection()
    def protected_task():
        ...