I am using Flask with Celery and I am trying to lock a specific task so that only one instance of it can run at a time. The Celery docs give an example of this under "Ensuring a task is only executed one at a time". That example is written for Django, but I am using Flask. I have done my best to convert it to work with Flask, yet I still see myTask1, which holds the lock, run multiple times.
One thing that is not clear to me is whether I am using the cache correctly; I have never used a cache before, so all of this is new to me. One thing from the docs that is mentioned but not explained is this:
In order for this to work correctly you need to be using a cache backend where the .add operation is atomic. memcached is known to work well for this purpose.
I'm not sure what that means. Should I be using the cache in conjunction with a database, and if so, how would I do that? I am using MongoDB. In my code I just have this setup for the cache, as that is what is shown in the Flask-Cache docs:
cache = Cache(app, config={'CACHE_TYPE': 'simple'})
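If I understand the atomicity note correctly, cache.add has to perform "set this key only if it does not already exist" as a single operation, and every worker process has to be talking to the same store. My guess is that the 'simple' cache fails the second requirement because it is just a dict inside the current process. Here is a sketch of my understanding (illustration only, not my real code; I believe Flask-Caching's 'simple' type is backed by cachelib's SimpleCache or an equivalent):
from cachelib import SimpleCache

cache_in_worker_1 = SimpleCache()  # imagine this dict living inside worker process 1
cache_in_worker_2 = SimpleCache()  # ...and this completely separate dict inside process 2

print(cache_in_worker_1.add('app.myTask1', 'owner-1', 60))  # True -> "lock" acquired
print(cache_in_worker_2.add('app.myTask1', 'owner-2', 60))  # True again -> no mutual exclusion
With memcached or Redis, both processes would hit the same shared store, so the second add would return False.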
Another thing that is not clear to me is whether there is anything different I need to do because I am calling myTask1 from within my Flask route task1.
Here is the code I am using:
from flask import (Flask, render_template, flash, redirect,
url_for, session, logging, request, g, render_template_string, jsonify)
from flask_caching import Cache
from contextlib import contextmanager
from celery import Celery
from Flask_celery import make_celery
from celery.result import AsyncResult
from celery.utils.log import get_task_logger
from celery.five import monotonic
from flask_pymongo import PyMongo
from hashlib import md5
import pymongo
import time
app = Flask(__name__)
cache = Cache(app, config={'CACHE_TYPE': 'simple'})
app.config['SECRET_KEY']= 'super secret key for me123456789987654321'
######################
# MONGODB SETUP
#####################
app.config['MONGO_HOST'] = 'localhost'
app.config['MONGO_DBNAME'] = 'celery-test-db'
app.config["MONGO_URI"] = 'mongodb://localhost:27017/celery-test-db'
mongo = PyMongo(app)
##############################
# CELERY ARGUMENTS
##############################
app.config['CELERY_BROKER_URL'] = 'amqp://localhost//'
app.config['CELERY_RESULT_BACKEND'] = 'mongodb'  # connection details come from CELERY_MONGODB_BACKEND_SETTINGS below
app.config['CELERY_MONGODB_BACKEND_SETTINGS'] = {
"host": "localhost",
"port": 27017,
"database": "celery-test-db",
"taskmeta_collection": "celery_jobs",
}
app.config['CELERY_TASK_SERIALIZER'] = 'json'
celery = make_celery(app)  # make_celery builds the Celery instance from the Flask app config
LOCK_EXPIRE = 60 * 2 # Lock expires in 2 minutes
@contextmanager
def memcache_lock(lock_id, oid):
timeout_at = monotonic() + LOCK_EXPIRE - 3
# cache.add fails if the key already exists
status = cache.add(lock_id, oid, LOCK_EXPIRE)
try:
yield status
finally:
# memcache delete is very slow, but we have to use it to take
# advantage of using add() for atomic locking
if monotonic() < timeout_at and status:
# don't release the lock if we exceeded the timeout
# to lessen the chance of releasing an expired lock
# owned by someone else
# also don't release the lock if we didn't acquire it
cache.delete(lock_id)
@celery.task(bind=True, name='app.myTask1')
def myTask1(self):
self.update_state(state='IN TASK')
lock_id = self.name
with memcache_lock(lock_id, self.app.oid) as acquired:
if acquired:
# do work if we got the lock
print('acquired is {}'.format(acquired))
self.update_state(state='DOING WORK')
time.sleep(90)
return 'result'
# otherwise, the lock was already in use
raise self.retry(countdown=60) # redeliver message to the queue, so the work can be done later
@celery.task(bind=True, name='app.myTask2')
def myTask2(self):
print('you are in task2')
self.update_state(state='STARTING')
time.sleep(120)
print('task2 done')
@app.route('/', methods=['GET', 'POST'])
def index():
return render_template('index.html')
@app.route('/task1', methods=['GET', 'POST'])
def task1():
print('running task1')
result = myTask1.delay()
# get async task id
taskResult = AsyncResult(result.task_id)
# push async taskid into db collection job_task_id
mongo.db.job_task_id.insert_one({'taskid': str(taskResult), 'TaskName': 'task1'})
return render_template('task1.html')
@app.route('/task2', methods=['GET', 'POST'])
def task2():
print('running task2')
result = myTask2.delay()
# get async task id
taskResult = AsyncResult(result.task_id)
# push async taskid into db collection job_task_id
mongo.db.job_task_id.insert_one({'taskid': str(taskResult), 'TaskName': 'task2'})
return render_template('task2.html')
@app.route('/status', methods=['GET', 'POST'])
def status():
taskid_list = []
task_state_list = []
TaskName_list = []
allAsyncData = mongo.db.job_task_id.find()
for doc in allAsyncData:
try:
taskid_list.append(doc['taskid'])
except KeyError:
    print('error reading taskid in the status route')
TaskName_list.append(doc['TaskName'])
# PASS TASK ID TO ASYNC RESULT TO GET TASK RESULT FOR THAT SPECIFIC TASK
for item in taskid_list:
try:
task_state_list.append(myTask1.AsyncResult(item).state)
except:
task_state_list.append('UNKNOWN')
return render_template('status.html', data_list=zip(task_state_list, TaskName_list))
After switching the Celery broker and the Flask-Caching backend to Redis, I get the expected behavior. Here is my updated, working code:
from flask import (Flask, render_template, flash, redirect,
url_for, session, logging, request, g, render_template_string, jsonify)
from flask_caching import Cache
from contextlib import contextmanager
from celery import Celery
from Flask_celery import make_celery
from celery.result import AsyncResult
from celery.utils.log import get_task_logger
from celery.five import monotonic
from flask_pymongo import PyMongo
from hashlib import md5
import pymongo
import time
import redis
from flask_redis import FlaskRedis
app = Flask(__name__)
# ADDING REDIS
redis_store = FlaskRedis(app)
# POINTING CACHE_TYPE TO REDIS
cache = Cache(app, config={'CACHE_TYPE': 'redis'})
app.config['SECRET_KEY']= 'super secret key for me123456789987654321'
######################
# MONGODB SETUP
#####################
app.config['MONGO_HOST'] = 'localhost'
app.config['MONGO_DBNAME'] = 'celery-test-db'
app.config["MONGO_URI"] = 'mongodb://localhost:27017/celery-test-db'
mongo = PyMongo(app)
##############################
# CELERY ARGUMENTS
##############################
# CELERY USING REDIS
app.config['CELERY_BROKER_URL'] = 'redis://localhost:6379/0'
app.config['CELERY_RESULT_BACKEND'] = 'mongodb'  # connection details come from CELERY_MONGODB_BACKEND_SETTINGS below
app.config['CELERY_MONGODB_BACKEND_SETTINGS'] = {
"host": "localhost",
"port": 27017,
"database": "celery-test-db",
"taskmeta_collection": "celery_jobs",
}
app.config['CELERY_TASK_SERIALIZER'] = 'json'
celery = make_celery(app)  # make_celery builds the Celery instance from the Flask app config
LOCK_EXPIRE = 60 * 2 # Lock expires in 2 minutes
@contextmanager
def memcache_lock(lock_id, oid):
timeout_at = monotonic() + LOCK_EXPIRE - 3
print('in memcache_lock and timeout_at is {}'.format(timeout_at))
# cache.add fails if the key already exists
status = cache.add(lock_id, oid, LOCK_EXPIRE)
try:
yield status
print('memcache_lock and status is {}'.format(status))
finally:
# memcache delete is very slow, but we have to use it to take
# advantage of using add() for atomic locking
if monotonic() < timeout_at and status:
# don't release the lock if we exceeded the timeout
# to lessen the chance of releasing an expired lock
# owned by someone else
# also don't release the lock if we didn't acquire it
cache.delete(lock_id)
@celery.task(bind=True, name='app.myTask1')
def myTask1(self):
self.update_state(state='IN TASK')
print('dir is {} '.format(dir(self)))
lock_id = self.name
print('lock_id is {}'.format(lock_id))
with memcache_lock(lock_id, self.app.oid) as acquired:
print('in memcache_lock and lock_id is {} self.app.oid is {} and acquired is {}'.format(lock_id, self.app.oid, acquired))
if acquired:
# do work if we got the lock
print('acquired is {}'.format(acquired))
self.update_state(state='DOING WORK')
time.sleep(90)
return 'result'
# otherwise, the lock was already in use
raise self.retry(countdown=60) # redeliver message to the queue, so the work can be done later
@celery.task(bind=True, name='app.myTask2')
def myTask2(self):
print('you are in task2')
self.update_state(state='STARTING')
time.sleep(120)
print('task2 done')
@app.route('/', methods=['GET', 'POST'])
def index():
return render_template('index.html')
@app.route('/task1', methods=['GET', 'POST'])
def task1():
print('running task1')
result = myTask1.delay()
# get async task id
taskResult = AsyncResult(result.task_id)
# push async taskid into db collection job_task_id
mongo.db.job_task_id.insert_one({'taskid': str(taskResult), 'TaskName': 'myTask1'})
return render_template('task1.html')
@app.route('/task2', methods=['GET', 'POST'])
def task2():
print('running task2')
result = myTask2.delay()
# get async task id
taskResult = AsyncResult(result.task_id)
# push async taskid into db collection job_task_id
mongo.db.job_task_id.insert_one({'taskid': str(taskResult), 'TaskName': 'task2'})
return render_template('task2.html')
@app.route('/status', methods=['GET', 'POST'])
def status():
taskid_list = []
task_state_list = []
TaskName_list = []
allAsyncData = mongo.db.job_task_id.find()
for doc in allAsyncData:
try:
taskid_list.append(doc['taskid'])
except KeyError:
    print('error reading taskid in the status route')
TaskName_list.append(doc['TaskName'])
# PASS TASK ID TO ASYNC RESULT TO GET TASK RESULT FOR THAT SPECIFIC TASK
for item in taskid_list:
try:
task_state_list.append(myTask1.AsyncResult(item).state)
except:
task_state_list.append('UNKNOWN')
return render_template('status.html', data_list=zip(task_state_list, TaskName_list))
if __name__ == '__main__':
app.secret_key = 'super secret key for me123456789987654321'
app.run(port=1234, host='localhost')
Here is also a screenshot; you can see that I ran myTask1 two times and myTask2 a single time. I now have the expected behavior for myTask1: it is run by a single worker, and if another worker attempts to pick it up, it just keeps retrying based on whatever countdown I define.
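A quick sanity check that helped me convince myself the Redis-backed cache.add really is atomic and shared (a hypothetical snippet, assuming Redis is running on the default localhost:6379): the second add for the same key returns False until the key is deleted or its timeout expires, which is exactly what memcache_lock relies on.
from flask import Flask
from flask_caching import Cache

app = Flask(__name__)
cache = Cache(app, config={'CACHE_TYPE': 'redis'})

with app.app_context():
    print(cache.add('app.myTask1', 'worker-1', 10))  # True  -> lock acquired
    print(cache.add('app.myTask1', 'worker-2', 10))  # False -> key exists, lock refused
    cache.delete('app.myTask1')                      # release the lock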
I also found this to be a surprisingly hard problem. Inspired mainly by Sebastian's work on implementing a distributed locking algorithm in Redis, I wrote up a decorator function.
A key point to bear in mind about this approach is that we lock tasks at the level of the task's argument space: e.g., we allow multiple game update/process order tasks to run concurrently, but only one per game. That's what argument_signature achieves in the code below. You can see documentation on how we use this in our stack at this gist:
import base64
from contextlib import contextmanager
import json
import pickle as pkl
import uuid
from backend.config import Config
from redis import StrictRedis
from redis_cache import RedisCache
from redlock import Redlock
rds = StrictRedis(Config.REDIS_HOST, decode_responses=True, charset="utf-8")
rds_cache = StrictRedis(Config.REDIS_HOST, decode_responses=False, charset="utf-8")
redis_cache = RedisCache(redis_client=rds_cache, prefix="rc", serializer=pkl.dumps, deserializer=pkl.loads)
dlm = Redlock([{"host": Config.REDIS_HOST}])
TASK_LOCK_MSG = "Task execution skipped -- another task already has the lock"
DEFAULT_ASSET_EXPIRATION = 8 * 24 * 60 * 60 # by default keep cached values around for 8 days
DEFAULT_CACHE_EXPIRATION = 1 * 24 * 60 * 60 # we can keep cached values around for a shorter period of time
REMOVE_ONLY_IF_OWNER_SCRIPT = """
if redis.call("get",KEYS[1]) == ARGV[1] then
return redis.call("del",KEYS[1])
else
return 0
end
"""
@contextmanager
def redis_lock(lock_name, expires=60):
# https://breadcrumbscollector.tech/what-is-celery-beat-and-how-to-use-it-part-2-patterns-and-caveats/
random_value = str(uuid.uuid4())
lock_acquired = bool(
rds.set(lock_name, random_value, ex=expires, nx=True)
)
yield lock_acquired
if lock_acquired:
rds.eval(REMOVE_ONLY_IF_OWNER_SCRIPT, 1, lock_name, random_value)
def argument_signature(*args, **kwargs):
arg_list = [str(x) for x in args]
kwarg_list = [f"{str(k)}:{str(v)}" for k, v in kwargs.items()]
return base64.b64encode(f"{'_'.join(arg_list)}-{'_'.join(kwarg_list)}".encode()).decode()
def task_lock(func=None, main_key="", timeout=None):
def _dec(run_func):
def _caller(*args, **kwargs):
with redis_lock(f"{main_key}_{argument_signature(*args, **kwargs)}", timeout) as acquired:
if not acquired:
return TASK_LOCK_MSG
return run_func(*args, **kwargs)
return _caller
return _dec(func) if func is not None else _dec
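For example (illustrative values; the base64 strings are just what these inputs happen to encode to), two calls with the same arguments map to the same lock key and exclude each other, while different arguments map to different keys and can run concurrently:
print(argument_signature(42))  # 'NDIt' -> lock key 'async_test_task_lock_NDIt'
print(argument_signature(42))  # 'NDIt' again -> same key, so the second call is skipped
print(argument_signature(43))  # 'NDMt' -> different key, free to run concurrently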
Implementation in our task definitions file:
@celery.task(name="async_test_task_lock")
@task_lock(main_key="async_test_task_lock", timeout=UPDATE_GAME_DATA_TIMEOUT)
def async_test_task_lock(game_id):
print(f"processing game_id {game_id}")
time.sleep(TASK_LOCK_TEST_SLEEP)
How we test against a local celery cluster:
import time
from unittest import TestCase
from backend.tasks.definitions import async_test_task_lock, TASK_LOCK_TEST_SLEEP
from backend.tasks.redis_handlers import rds, TASK_LOCK_MSG
class TestTaskLocking(TestCase):
def test_task_locking(self):
rds.flushall()
res1 = async_test_task_lock.delay(3)
res2 = async_test_task_lock.delay(5)
self.assertFalse(res1.ready())
self.assertFalse(res2.ready())
res3 = async_test_task_lock.delay(5)
res4 = async_test_task_lock.delay(5)
self.assertEqual(res3.get(), TASK_LOCK_MSG)
self.assertEqual(res4.get(), TASK_LOCK_MSG)
time.sleep(TASK_LOCK_TEST_SLEEP)
res5 = async_test_task_lock.delay(3)
self.assertFalse(res5.ready())
(As a goodie, there's also a quick example of how to set up a redis_cache.)
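A minimal sketch of how the redis_cache object defined above can be used, assuming the python-redis-cache package's decorator API (expensive_lookup is a made-up example function):
from backend.tasks.redis_handlers import redis_cache, DEFAULT_CACHE_EXPIRATION

@redis_cache.cache(ttl=DEFAULT_CACHE_EXPIRATION)
def expensive_lookup(game_id):
    # the first call computes the result and stores the pickled value in Redis;
    # later calls with the same argument return the cached value until the TTL expires
    return {"game_id": game_id}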