In the previous post, I showed how to implement a basic Celery task using the @task decorator, along with a pattern for removing circular dependencies when calling the task from a Flask view. Let's recall part of that code.
def run_task_async():
    task = chain(long_run_task.s(2, 2), long_map_task.s(4)).apply_async()
    return task
Here, I am chaining two Celery tasks. This worked for me at first, but I ran into issues such as how to check the task status. The variable task holds the task ID of the last job in the chain, which is long_map_task. To find the status of every other task, I have to recursively follow the parent of the last task and collect the statuses along the chain.
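To make that recursive lookup concrete, here is a minimal sketch of walking a chain backwards through its parents. The function name chain_states and the FakeResult stand-in are my own assumptions for illustration. The sketch only relies on each result exposing id, state and parent, which is what Celery's AsyncResult provides.

```python
from collections import namedtuple

# Hypothetical stand-in for celery.result.AsyncResult, used only so the
# sketch runs without a broker. It mimics three attributes: id, state, parent.
FakeResult = namedtuple("FakeResult", ["id", "state", "parent"])


def chain_states(result):
    """Return (task_id, state) pairs ordered from the first task to the last."""
    states = []
    node = result
    # Follow the parent links until we reach the first task in the chain
    while node is not None:
        states.append((node.id, node.state))
        node = node.parent
    states.reverse()
    return states


first = FakeResult(id="a1", state="SUCCESS", parent=None)
last = FakeResult(id="b2", state="PENDING", parent=first)
print(chain_states(last))
```

Even in this small form, the view has to know about the shape of the chain, which is the coupling I want to get rid of.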
This might work if you are running a few simple tasks chained together. In my case, the task is fairly complex: build a graph of companies, run a matrix inversion, save the result to the database, and generate some reports. I would like to keep all of this inside its own module and call that module from the view, so that the view stays clean. Let's look at how to create this module by implementing a class-based Celery task.
Implementation
from time import sleep

import requests
from celery import Task

from .factory import celery_app


class CalculationWorker(Task):
    def __init__(self, *args, **kwargs):
        self.database = kwargs.get('database', None)
        self.host = kwargs.get('host', None)

    # Main entry point, invoked when the task runs
    def run(self, *args, **kwargs):
        self.long_run_task()
        self.long_map_task()

    # Wrap the celery app within the Flask context
    def bind(self, app):
        # Name the class explicitly: super(self.__class__, self) recurses
        # forever as soon as this class is subclassed
        return super(CalculationWorker, self).bind(celery_app)

    # Notify another service once the task has finished successfully
    def on_success(self, retval, task_id, *args, **kwargs):
        requests.get('{host}/api/jobs/{jobid}/callback'.format(host='http://127.0.0.1:3000', jobid=task_id))

    def long_run_task(self):
        print('starting core job')
        self.update_state(state='PROGRESS', meta={'stage': 'calculating', 'percentage': 10})
        sleep(25)

    def long_map_task(self):
        print('mapping')
        self.update_state(state='PROGRESS', meta={'stage': 'mapping', 'percentage': 50})
        sleep(25)
Our custom task class inherits from celery.Task, and we override the run method to call the code we want to execute. When you call something like someTask.apply_async(), this run method is invoked. We override the bind method so that we can wrap the Flask context into our task. We also override on_success so that other services can be notified when the task has finished running. You can check the documentation for the other methods that you can override.
In the view, you can run the CalculationWorker task like this:
def run_job():
    worker = CalculationWorker(host='localhost', database='db')
    task = worker.apply_async()
    payload = dict(message='Job is running', jobid=task.id)
    return jsonify(status='success', data=payload), 200
This is cleaner, since there is only one task ID per actual job. We no longer need to recursively query parents to find all the related tasks.
You might ask how to make the status more fine-grained, now that multiple jobs run inside a single Celery task. We can leverage the update_state method for this: pass a dictionary of metadata to its meta keyword argument. The long_run_task and long_map_task methods in the class above show how this is done.
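If the number of stages grows, the repeated update_state calls could be factored out. The following is only a sketch of one way to do that, not what my worker actually does: run_stages and record are hypothetical names, and the percentages are spread evenly across the stages rather than hard-coded. The report argument is any callable that accepts the state and meta keyword arguments, such as self.update_state inside a task.

```python
def run_stages(stages, report):
    """Run (name, func) stages in order, reporting progress before each one.

    report is assumed to accept the same keyword arguments as
    Task.update_state, for example self.update_state inside a task.
    """
    total = len(stages)
    for index, (name, func) in enumerate(stages):
        # Percentage of stages completed before this one starts
        percentage = int(100 * index / total)
        report(state='PROGRESS', meta={'stage': name, 'percentage': percentage})
        func()


# Usage with a recording stand-in instead of a real Celery task
events = []


def record(state, meta):
    events.append((state, meta['stage'], meta['percentage']))


run_stages([('calculating', lambda: None), ('mapping', lambda: None)], record)
print(events)
```

Inside run, the call would look like run_stages([...], self.update_state), with each stage function doing the real work.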
To access this metadata from the view (for example, from an endpoint that reports task status), you can do something like someTask.info.get('stage', None). One caveat: the type of someTask.info depends on the task state. If someTask.state is SUCCESS, someTask.info contains the task's return value. If the task failed, it holds the raised exception. If the task is still running, it is a dictionary containing the meta values. That is why I use the get method on someTask.info: a missing key does not raise an exception but falls back to a default value that works for my case. Have a look at the documentation for more details.
@api.route('/jobs/<jobid>', methods=['GET'])
def job_status(jobid):
    """
    Query Celery for task status based on its taskid/jobid

    Celery status can be one of:
    PENDING - Job not yet run or unknown status
    PROGRESS - Job is currently running
    SUCCESS - Job completed successfully
    FAILURE - Job failed
    REVOKED - Job was canceled
    """
    # Query database for job's information
    job = Job.query.get(jobid)
    if job:
        worker = factory_worker.manufacture(job.jobtype)
        task = worker.AsyncResult(job.id)

        # If task state is SUCCESS, task.info holds the task's return value,
        # so reading task metadata from it would cause an error
        if task.state == 'PENDING':
            payload = dict(jobid=task.id, status=task.state)
            current_task_stage = None
        elif task.state == 'SUCCESS':
            payload = dict(jobid=task.id, status=task.state, percentage=100)
            current_task_stage = None
        else:
            # task.info is only a dict while the task reports progress;
            # on FAILURE it holds the exception instead
            info = task.info if isinstance(task.info, dict) else {}
            current_task_stage = info.get('stage', None)
            current_percentage = info.get('percentage', 0)
            payload = dict(jobid=task.id, status=task.state, stage=current_task_stage, percentage=current_percentage)

        # Update db with latest task status
        job.status = task.status
        job.stage = current_task_stage
        db.session.commit()

        # Return value for http response
        status = 'ok'
        status_code = 200
    else:
        status = 'error'
        status_code = 400
        payload = dict(message='No task with that jobid.')
    return jsonify(status=status, data=payload), status_code
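The branching on task state in this view can also be pulled out into a plain function, which makes it possible to test without Flask or a running broker. This is a sketch under my own naming: build_payload is a hypothetical helper, and it takes the task ID, state and info as plain values instead of an AsyncResult.

```python
def build_payload(task_id, state, info):
    """Build the status payload from a task's ID, state and info."""
    payload = {'jobid': task_id, 'status': state}
    if state == 'PENDING':
        return payload
    if state == 'SUCCESS':
        payload['percentage'] = 100
        return payload
    # Only trust info as metadata when it really is a dictionary;
    # for a failed task it is an exception, not progress metadata
    meta = info if isinstance(info, dict) else {}
    payload['stage'] = meta.get('stage')
    payload['percentage'] = meta.get('percentage', 0)
    return payload


running = build_payload('abc', 'PROGRESS', {'stage': 'mapping', 'percentage': 50})
failed = build_payload('abc', 'FAILURE', ValueError('boom'))
print(running)
print(failed)
```

The view would then call build_payload(task.id, task.state, task.info) and keep only the database update and the HTTP response for itself.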
Improvement
What I have shown here is the basic skeleton. On top of this, I also use the factory pattern to create task workers: I have multiple workers that handle different kinds of tasks, and a factory keeps the code that selects between them much cleaner.
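For illustration, a worker factory along these lines could look roughly like the sketch below. Only the manufacture call appears in the view earlier in this post. The WorkerFactory class, its register method, the job type strings and the ReportWorker class are assumptions made for this example, and plain classes stand in for the real Task subclasses so that it runs on its own.

```python
class WorkerFactory(object):
    """Map a job type string to the worker class that handles it."""

    def __init__(self):
        self._registry = {}

    def register(self, jobtype, worker_cls):
        self._registry[jobtype] = worker_cls

    def manufacture(self, jobtype, **kwargs):
        try:
            worker_cls = self._registry[jobtype]
        except KeyError:
            raise ValueError('Unknown job type: {0}'.format(jobtype))
        return worker_cls(**kwargs)


# Plain classes standing in for celery.Task subclasses (hypothetical)
class CalculationWorker(object):
    def __init__(self, **kwargs):
        self.host = kwargs.get('host')


class ReportWorker(object):
    def __init__(self, **kwargs):
        self.host = kwargs.get('host')


factory_worker = WorkerFactory()
factory_worker.register('calculation', CalculationWorker)
factory_worker.register('report', ReportWorker)

worker = factory_worker.manufacture('calculation', host='localhost')
print(type(worker).__name__)
```

With this in place, the view only needs to know the job type stored in the database, not which class implements it.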