
Celery: Rate limit on tasks with the same parameters

I am looking for a way to rate-limit calls to a task, with a separate limit for each distinct value of the input parameters. That is:

@app.task(rate_limit="60/s")
def api_call(user):
    do_the_api_call()

for i in range(0, 100):
    api_call.delay("antoine")
    api_call.delay("oscar")

So I would like api_call("antoine") to be called 60 times per second and api_call("oscar") 60 times per second as well.

Any help on how I can do that?

EDIT 27/04/2015: I have tried calling a subtask with rate_limit from within a task, but that does not work either. The rate_limit is always applied across all the instantiated subtasks or tasks (which is logical).

@app.task(rate_limit="60/s")
def sub_api_call(user):
    do_the_api_call()

@app.task
def api_call(user):
    sub_api_call.delay(user)

for i in range(0, 100):
    api_call.delay("antoine")
    api_call.delay("oscar")

Best!

Antoine Brunel asked Apr 24 '15 17:04


1 Answer

Update

See comments section for a link to a much better approach that incorporates most of what's here, but fixes a ping-pong problem that the version here has. The version here retries tasks naively. That is, it just tries them again later, with some jitter. If you have 1,000 tasks that are all queued, this creates chaos as they all vie for the next available spot. They all just ping-pong into and out of the task worker, getting tried hundreds of times before finally getting an opportunity to be run.

Instead of that naive approach, the next thing I tried was an exponential backoff, in which each time a task is throttled it backs off for a little longer than the time before. This concept can work, but it requires storing the number of retries each task has had, which is annoying and has to be centralized. It's not optimal either, because you can have long stretches of no activity while you wait for a scheduled task to run. (Imagine a task that is throttled for the 50th time and rescheduled an hour out, while the throttle itself expires a few seconds later. In that case, the worker would sit idle for nearly an hour waiting for that task to run.)
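
As a rough illustration of the exponential backoff idea described above (not the code from this answer or from the linked version), the delay can be derived from the number of times a task has already been throttled. The function name `backoff_delay`, the 2-second base, and the one-hour cap are assumptions chosen for the example.

```python
import random


def backoff_delay(throttle_count: int, base: float = 2.0, cap: float = 3600.0) -> float:
    """Return a delay in seconds that doubles with each throttle, up to cap."""
    # Hypothetical helper: throttle_count is how often this task has already
    # been throttled, which is the state that has to be stored centrally.
    delay = min(cap, base * (2 ** throttle_count))
    # Add up to 10% random jitter so tasks do not all wake at the same instant.
    return delay + random.uniform(0, delay * 0.1)
```

With these assumed values the delay hits the one-hour cap once `throttle_count` reaches 11, which is how a heavily throttled task can end up scheduled long after the throttle has actually cleared.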

The better way to attempt this, instead of a naive retry or an exponential backoff, is with a scheduler. The updated version linked in the comments section maintains a basic scheduler that knows when to retry the task. It keeps track of the order in which tasks are throttled, and knows when the next window for a task to run will occur. So, imagine a throttle of 1 task per minute, with the following timeline:

00:00:00 - Task 1 is attempted and begins running
00:00:01 - Task 2 is attempted. Oh no! It gets throttled. The current
           throttle expires at 00:01:00, so it is rescheduled then.
00:00:02 - Task 3 is attempted. Oh no! It gets throttled. The current
           throttle expires at 00:01:00, but something is already  
           scheduled then, so it is rescheduled for 00:02:00.
00:01:00 - Task 2 attempts to run again. All clear! It runs.
00:02:00 - Task 3 attempts to run again. All clear! It runs.

In other words, depending on the length of the backlog, it will reschedule the task after the current throttle expires and all other rescheduled, throttled tasks have had their opportunity to run. (This took weeks to figure out.)
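
The timeline above can be modeled with a small in-memory sketch. This is not the implementation linked in the comments (which is not shown here); the class name `ThrottleSchedule` and its method are hypothetical, and a real version would need to keep this state somewhere shared, such as Redis, so that every worker sees it.

```python
class ThrottleSchedule:
    """Hands out one run slot per window, in the order tasks ask for one."""

    def __init__(self, window: float) -> None:
        self.window = window  # seconds per allowed run, e.g. 60 for 1/m
        self.next_free = 0.0  # earliest time at which no run is booked

    def request_slot(self, now: float) -> float:
        """Return the time at which the task asking at `now` may run."""
        # Run immediately if the window is free, otherwise queue behind
        # every task that was rescheduled earlier.
        run_at = max(now, self.next_free)
        self.next_free = run_at + self.window
        return run_at


schedule = ThrottleSchedule(window=60)
# Tasks 1, 2 and 3 are attempted at 0s, 1s and 2s, as in the timeline.
slots = [schedule.request_slot(now) for now in (0, 1, 2)]
```

Here the three requests are assigned 0s, 60s and 120s, matching 00:00:00, 00:01:00 and 00:02:00 in the timeline, so no task has to be retried more than once.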


Original Answer

I spent some time on this today and came up with a nice solution. All the other solutions to this have one of the following problems:

  • They require tasks to have infinite retries, thereby making celery's retry mechanism useless.
  • They don't throttle based on parameters.
  • They fail with multiple workers or queues.
  • They're clunky, etc.

Basically, you wrap your task like this:

@app.task(bind=True, max_retries=10)
@throttle_task("2/s", key="domain", jitter=(2, 15))
def scrape_domain(self, domain):
    do_stuff()

The result is that the task is throttled to 2 runs per second per domain value, with a random retry delay of between 2 and 15 seconds. The key parameter is optional, but must correspond to the name of a parameter in your task. If key is not given, the task is simply throttled to the given rate. If it is provided, the throttle applies to each (task, key value) pair.
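
To illustrate how a decorator can turn the `key` name into the value passed at call time, here is a small standalone sketch using `inspect.signature` from the standard library. The helper `resolve_key_value` and the stand-in function are hypothetical names used only for this example.

```python
import inspect


def resolve_key_value(func, key, *args, **kwargs):
    """Return the value bound to the parameter named `key` for this call."""
    bound = inspect.signature(func).bind(*args, **kwargs)
    return bound.arguments[key]


def scrape_domain(self, domain):  # stand-in for the task body
    pass


# Works whether the argument is passed positionally or by keyword.
positional = resolve_key_value(scrape_domain, "domain", None, "example.com")
keyword = resolve_key_value(scrape_domain, "domain", None, domain="example.org")
```

Because the lookup is by parameter name, the throttle key stays the same however the caller passes the argument.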

The other way to look at this is without the decorator. This gives a little more flexibility, but counts on you to do the retrying yourself. Instead of the above, you could do:

@app.task(bind=True, max_retries=10)
def scrape_domain(self, domain):
    proceed = is_rate_okay(self, "2/s", key=domain)
    if proceed:
        do_stuff()
    else:
        # Don't count this against max_retries.
        self.request.retries = self.request.retries - 1
        return self.retry(countdown=random.uniform(2, 15))

I think that's identical to the first example. A bit longer, and more branchy, but shows how it works a bit more clearly. I'm hoping to always use the decorator, myself.

This all works by keeping a tally in redis. The implementation is very simple. You create a key in redis for the task (and the key parameter, if given), and you expire the redis key according to the schedule provided. If the user sets a rate of 10/m, you make a redis key for 60s, and you increment it every time a task with the correct name is attempted. If your incrementor gets too high, retry the task. Otherwise, run it.

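import functools
import inspect
import logging
import random
from typing import Any, Callable, Tuple

from celery import Task

# make_redis_interface (used in is_rate_okay below) is a helper from the
# author's project that is not shown here; it is assumed to return a Redis
# client.
logger = logging.getLogger(__name__)


def parse_rate(rate: str) -> Tuple[int, int]:
    """
    Given the request rate string, return a two tuple of: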
    <allowed number of requests>, <period of time in seconds>

    (Stolen from Django Rest Framework.)
    """
    num, period = rate.split("/")
    num_requests = int(num)
    if len(period) > 1:
        # It takes the form of a 5d, or 10s, or whatever
        duration_multiplier = int(period[0:-1])
        duration_unit = period[-1]
    else:
        duration_multiplier = 1
        duration_unit = period[-1]
    duration_base = {"s": 1, "m": 60, "h": 3600, "d": 86400}[duration_unit]
    duration = duration_base * duration_multiplier
    return num_requests, duration


def throttle_task(
    rate: str,
    jitter: Tuple[float, float] = (1, 10),
    key: Any = None,
) -> Callable:
    """A decorator for throttling tasks to a given rate.

    :param rate: The maximum rate that you want your task to run. Takes the
    form of '1/m', or '10/2h' or similar.
    :param jitter: A tuple of the range of backoff times you want for throttled
    tasks. If the task is throttled, it will wait a random amount of time
    between these values before being tried again.
    :param key: An argument name whose value should be used as part of the
    throttle key in redis. This allows you to create per-argument throttles by
    simply passing the name of the argument you wish to key on.
    :return: The decorated function
    """

    def decorator_func(func: Callable) -> Callable:
        @functools.wraps(func)
        def wrapper(*args, **kwargs) -> Any:
            # Inspect the decorated function's parameters to get the task
            # itself and the value of the parameter referenced by key.
            sig = inspect.signature(func)
            bound_args = sig.bind(*args, **kwargs)
            task = bound_args.arguments["self"]
            key_value = None
            if key:
                try:
                    key_value = bound_args.arguments[key]
                except KeyError:
                    raise KeyError(
                        f"Unknown parameter '{key}' in throttle_task "
                        f"decorator of function {task.name}. "
                        f"`key` parameter must match a parameter "
                        f"name from function signature: '{sig}'"
                    )
            proceed = is_rate_okay(task, rate, key=key_value)
            if not proceed:
                logger.info(
                    "Throttling task %s (%s) via decorator.",
                    task.name,
                    task.request.id,
                )
                # Decrement the number of times the task has retried. If you
                # fail to do this, it gets auto-incremented, and you'll expend
                # retries during the backoff.
                task.request.retries = task.request.retries - 1
                return task.retry(countdown=random.uniform(*jitter))
            else:
                # All set. Run the task.
                return func(*args, **kwargs)

        return wrapper

    return decorator_func


def is_rate_okay(task: Task, rate: str = "1/s", key=None) -> bool:
    """Keep a global throttle for tasks

    Can be used via the `throttle_task` decorator above.

    This implements the timestamp-based algorithm detailed here:

        https://www.figma.com/blog/an-alternative-approach-to-rate-limiting/

    Basically, you keep track of the number of requests and use the key
    expiration as a reset of the counter.

    So you have a rate of 5/m, and your first task comes in. You create a key:

        celery_throttle:task_name = 1
        celery_throttle:task_name.expires = 60

    Another task comes in a few seconds later:

        celery_throttle:task_name = 2
        Do not update the ttl, it now has 58s remaining

    And so forth, until:

        celery_throttle:task_name = 6
        (10s remaining)

    We're over the threshold. Re-queue the task for later. 10s later:

        Key expires b/c no more ttl.

    Another task comes in:

        celery_throttle:task_name = 1
        celery_throttle:task_name.expires = 60

    And so forth.

    :param task: The task that is being checked
    :param rate: How many times the task can be run during the time period.
    Something like, 1/s, 2/h or similar.
    :param key: If given, add this to the key placed in Redis for the item.
    Typically, this will correspond to the value of an argument passed to the
    throttled task.
    :return: True if the task may run now, False if it should be throttled.
    """
    key = f"celery_throttle:{task.name}{':' + str(key) if key else ''}"

    r = make_redis_interface("CACHE")

    num_tasks, duration = parse_rate(rate)

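    # Check the count in redis
    count = r.get(key)
    if count is None:
        # No key. Set the value to 1 and set the ttl of the key.
        r.set(key, 1)
        r.expire(key, duration)
        return True
    else:
        # Key found. Check it. The count already includes every task allowed
        # in this window, so only proceed while it is below the limit
        # (`<=` would let num_tasks + 1 tasks through).
        if int(count) < num_tasks:
            # We're OK, run it.
            r.incr(key, 1)
            return True
        else:
            return False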
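
One caveat with the check above is that reading the counter and then writing it are separate Redis commands, so two workers can interleave between them. A possible variant, offered as a sketch rather than as the author's code, increments first with `INCR` (which is atomic and creates the key at 1 if it is missing) and then decides. The function name `is_rate_okay_atomic` and the `FakeRedis` stand-in are assumptions made so the example runs without a server; with redis-py, a real client's `incr` and `expire` methods would be used in its place.

```python
class FakeRedis:
    """Minimal in-memory stand-in exposing only incr() and expire()."""

    def __init__(self) -> None:
        self.counts = {}
        self.ttls = {}

    def incr(self, key: str) -> int:
        self.counts[key] = self.counts.get(key, 0) + 1
        return self.counts[key]

    def expire(self, key: str, seconds: int) -> None:
        self.ttls[key] = seconds  # recorded only; nothing actually expires


def is_rate_okay_atomic(r, key: str, num_tasks: int, duration: int) -> bool:
    """Allow at most num_tasks runs per window of `duration` seconds."""
    count = r.incr(key)
    if count == 1:
        # First hit in this window: start the countdown.
        r.expire(key, duration)
    return count <= num_tasks


r = FakeRedis()
results = [is_rate_okay_atomic(r, "celery_throttle:demo", 5, 60) for _ in range(7)]
```

With a limit of 5, the first five calls are allowed and the remaining two are throttled. Note that `incr` followed by `expire` is still two commands, so if a process died between them the key would never expire; a pipeline or Lua script is a common way to close that gap.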
mlissner answered Sep 28 '22 10:09