Class-based Celery task

In the previous post, I showed you how to implement a basic Celery task that makes use of the @task decorator, along with a pattern for removing circular dependencies when calling the task from a Flask view. Let's recall part of the code.

from celery import chain

def run_task_async():
    # Chain two tasks: the result of long_run_task(2, 2) is prepended
    # to long_map_task's arguments
    task = chain(long_run_task.s(2, 2), long_map_task.s(4)).apply_async()

    return task

Here, I am chaining two Celery tasks. Initially this worked in my case, but I ran into issues such as checking the task status. The task variable holds the AsyncResult for the last job in the chain, which is long_map_task. To find the status of the other tasks, I have to recursively query the parent of the last task to collect all the statuses in the chain.
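For illustration, that traversal looks something like this (a minimal sketch; collect_chain_states is a hypothetical helper of mine, not part of Celery):

def collect_chain_states(result):
    # Walk the `parent` links from the tail of the chain back to its head
    states = []
    while result is not None:
        states.append((result.id, result.state))
        result = result.parent
    # Reverse so the head of the chain comes first
    return list(reversed(states))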

This might work if you're running a few simple tasks chained together. In my case, the task complexity is quite high - build a graph of companies, run a matrix inversion, save the result into the database, and generate some reports. I would like to have all of this inside its own module and call the module from the view, so that we have a cleaner view. Let's look at how we can create this module by implementing a class-based Celery task.

Implementation

from time import sleep

import requests
from celery import Task

from .factory import celery_app

class CalculationWorker(Task):

    # Accept per-task configuration when the worker is instantiated
    # from the view (see run_job below)
    def __init__(self, *args, **kwargs):
        self.database = kwargs.get('database', None)
        self.host = kwargs.get('host', None)

    # Main entry: invoked when the task is executed via apply_async()
    def run(self, *args, **kwargs):
        self.long_run_task()
        self.long_map_task()

    # Wrap the celery app within the Flask context
    def bind(self, app):
        # Name the class explicitly in super(); super(self.__class__, self)
        # recurses infinitely as soon as this class is subclassed
        return super(CalculationWorker, self).bind(celery_app)


    # Called once the task finishes successfully; the signature follows
    # celery.Task.on_success(retval, task_id, args, kwargs)
    def on_success(self, retval, task_id, args, kwargs):
        # Notify another service that this task has finished
        requests.get('{host}/api/jobs/{jobid}/callback'.format(host='http://127.0.0.1:3000', jobid=task_id))


    def long_run_task(self):
        print('starting core job')
        self.update_state(state='PROGRESS', meta={'stage': 'calculating', 'percentage': 10})
        sleep(25)


    def long_map_task(self):
        print('mapping')
        self.update_state(state='PROGRESS', meta={'stage': 'mapping', 'percentage': 50})
        sleep(25)

Our custom task class inherits from celery.Task, and we override the run method to call the custom code we want to run. So when you call someTask.apply_async(), the run method here is invoked. We override the bind method so that we can wrap the Flask context into our task. We then override on_success so that, for example, other services can be notified that this task has just finished running. You can check the documentation to see the other available methods you can override.
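One caveat, depending on your Celery version: Celery 4 removed the automatic (metaclass-based) registration of Task subclasses, so you may need to register an instance with the app explicitly. A minimal sketch, assuming the celery_app imported from the factory module above:

# Celery >= 4 no longer auto-registers Task subclasses, so register
# the instance explicitly (Celery 3.x did this via a metaclass)
calculation_worker = celery_app.register_task(CalculationWorker())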

In the view, you can run the CalculationWorker task like this.

def run_job():
    worker = CalculationWorker(host='localhost', database='db')
    task = worker.apply_async()

    payload = dict(message='Job is running', jobid=task.id)

    return jsonify(status='success', data=payload), 200

This is cleaner since there is only one task id per actual job. We no longer need to recursively query parents to find all the related tasks.

You might ask how to make the status more fine-grained, since there are now multiple jobs running inside a single Celery task. We can leverage the update_state method to achieve this: I pass a metadata dictionary to the meta keyword argument of update_state. See the code above for how this is implemented.

To access this metadata from the view (say you expose an endpoint to check task status), you can do something like someTask.info.get('stage', None). One caveat, though: the return type of someTask.info depends on the task state. If someTask.state is SUCCESS, someTask.info contains the task's return value. If it is still running, the return type is a dictionary containing the meta params. That's why I call the get method on someTask.info, so we don't throw an exception but instead fall back to whatever default value works for my case. Have a look at the documentation for more details.

@api.route('/jobs/<jobid>', methods=['GET'])
def job_status(jobid):
    """
    Query Celery for task status based on its taskid/jobid

    Celery status can be one of:
    PENDING - Job not yet run or unknown status
    PROGRESS - Job is currently running
    SUCCESS - Job completed successfully
    FAILURE - Job failed
    REVOKED - Job was canceled
    """
    # Query database for job's information
    job = Job.query.get(jobid)

    if job:
        worker = factory_worker.manufacture(job.jobtype)
        task = worker.AsyncResult(job.id)

        # If the task state is SUCCESS, task.info holds the task's return
        # value; on FAILURE it holds the exception instance. In both cases
        # reading the metadata keys from it would raise an error
        if task.state == 'PENDING':
            payload = dict(jobid=task.id, status=task.state)
            current_task_stage = None
        elif task.state == 'SUCCESS':
            payload = dict(jobid=task.id, status=task.state, percentage=100)
            current_task_stage = None
        elif task.state == 'FAILURE':
            payload = dict(jobid=task.id, status=task.state, message=str(task.info))
            current_task_stage = None
        else:
            current_task_stage = task.info.get('stage', None)
            current_percentage = task.info.get('percentage', 0)
            payload = dict(jobid=task.id, status=task.state, stage=current_task_stage, percentage=current_percentage)
        # Update db with latest task status
        job.status = task.state
        job.stage = current_task_stage
        db.session.commit()

        # Return values for the HTTP response
        status = 'ok'
        status_code = 200
    else:
        status = 'error'
        status_code = 400
        payload = dict(message='No task with that jobid.')
    return jsonify(status=status, data=payload), status_code
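For completeness, a client can poll this endpoint while the task is running. A hypothetical snippet (the host mirrors the callback URL used in on_success, and the jobid is a placeholder):

import requests

# `jobid` comes from the payload returned by run_job above
jobid = 'REPLACE-WITH-A-REAL-JOB-ID'
response = requests.get('http://127.0.0.1:3000/api/jobs/{0}'.format(jobid))
print(response.json())
# While the task is in flight, this prints something like:
# {'status': 'ok', 'data': {'jobid': '...', 'status': 'PROGRESS',
#                           'stage': 'mapping', 'percentage': 50}}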

Improvement

What I showed you here is the basic skeleton. On top of this, I also use the factory pattern to create task workers, since I have multiple workers that perform different kinds of tasks - you saw factory_worker.manufacture in the status endpoint above. It just makes the code much cleaner this way.
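A minimal sketch of what that factory might look like (the jobtype mapping is an assumption; only CalculationWorker and the factory_worker.manufacture call actually appear in this post):

class WorkerFactory(object):
    """Map a jobtype string to a task worker instance."""

    workers = {
        'calculation': CalculationWorker,
        # other jobtypes map to their own Task subclasses here
    }

    def manufacture(self, jobtype):
        try:
            return self.workers[jobtype]()
        except KeyError:
            raise ValueError('Unknown jobtype: {0}'.format(jobtype))


factory_worker = WorkerFactory()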