python, multithreading, python-3.x, python-asyncio, throttling

Throttling Async Functions in Python Asyncio


I have a list of awaitables that I want to pass to the asyncio.AbstractEventLoop, but I need to throttle the requests to a third-party API.

I would like to avoid an approach that waits before passing each future to the loop, because in the meantime the loop is blocked. What options do I have? Semaphores and thread pools will limit how many are running concurrently, but that is not my problem: I need to throttle my requests to 100/sec, and it doesn't matter how long each request takes to complete.
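To illustrate the distinction with a minimal sketch: a semaphore only caps how many requests are in flight at once, so the start rate drifts with request duration rather than staying at 100/sec:

import asyncio

sem = asyncio.Semaphore(100)

async def call_api():
    async with sem:             # at most 100 requests in flight at once
        await asyncio.sleep(5)  # a 5 s request holds its slot for 5 s, so
                                # only ~100/5 = 20 new requests start per second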

Below is a very concise (non)working example using the standard library that demonstrates the problem. It is supposed to throttle at 100/sec, but it actually throttles at 116.651/sec. What's the best way to throttle the scheduling of asynchronous requests in asyncio?

Working code:

import asyncio
from threading import Lock

class PTBNL:

    def __init__(self):
        self._req_id_seq = 0
        self._futures = {}
        self._results = {}
        self.token_bucket = TokenBucket()
        self.token_bucket.set_rate(100)

    def run(self, *awaitables):

        loop = asyncio.get_event_loop()

        if not awaitables:
            loop.run_forever()
        elif len(awaitables) == 1:
            return loop.run_until_complete(*awaitables)
        else:
            future = asyncio.gather(*awaitables)
            return loop.run_until_complete(future)

    def sleep(self, secs: float) -> bool:

        self.run(asyncio.sleep(secs))
        return True

    def get_req_id(self) -> int:

        new_id = self._req_id_seq
        self._req_id_seq += 1
        return new_id

    def start_req(self, key):

        loop = asyncio.get_event_loop()
        future = loop.create_future()
        self._futures[key] = future
        return future

    def end_req(self, key, result=None):

        future = self._futures.pop(key, None)
        if future:
            if result is None:
                result = self._results.pop(key, [])
            if not future.done():
                future.set_result(result)

    def req_data(self, req_id, obj):
        # Do Some Work Here, then signal completion
        self.req_data_end(req_id)

    def req_data_end(self, req_id):
        print(req_id, " has ended")
        self.end_req(req_id)

    async def req_data_async(self, obj):

        req_id = self.get_req_id()
        future = self.start_req(req_id)

        self.req_data(req_id, obj)

        await future
        return future.result()

    async def req_data_batch_async(self, contracts):

        futures = []
        FLAG = False

        for contract in contracts:
            req_id = self.get_req_id()
            future = self.start_req(req_id)
            futures.append(future)

            nap = self.token_bucket.consume(1)

            if FLAG is False:
                FLAG = True
                start = asyncio.get_event_loop().time()

            asyncio.get_event_loop().call_later(nap, self.req_data, req_id, contract)

        await asyncio.gather(*futures)
        elapsed = asyncio.get_event_loop().time() - start

        return futures, len(contracts)/elapsed

class TokenBucket:

    def __init__(self):
        self.tokens = 0
        self.rate = 0
        self.last = asyncio.get_event_loop().time()
        self.lock = Lock()

    def set_rate(self, rate):
        with self.lock:
            self.rate = rate
            self.tokens = self.rate

    def consume(self, tokens):
        with self.lock:
            if not self.rate:
                return 0

            now = asyncio.get_event_loop().time()
            lapse = now - self.last
            self.last = now
            self.tokens += lapse * self.rate

            if self.tokens > self.rate:
                self.tokens = self.rate

            self.tokens -= tokens

            if self.tokens >= 0:
                return 0
            else:
                return -self.tokens / self.rate


if __name__ == '__main__':

    asyncio.get_event_loop().set_debug(True)
    app = PTBNL()

    objs = list(range(500))

    results, rate = app.run(app.req_data_batch_async(objs))

    print(results)
    print(rate)

Edit: I've added a simple ThrottleTestApp example here using semaphores, but I still can't throttle the execution:

import asyncio
import time


class ThrottleTestApp:

    def __init__(self):
        self._req_id_seq = 0
        self._futures = {}
        self._results = {}
        self.sem = asyncio.Semaphore()

    async def allow_requests(self, sem):
        """Permit 100 requests per second; call 
           loop.create_task(allow_requests())
        at the beginning of the program to start this routine.  That call returns
        a task handle that can be canceled to end this routine.

        asyncio.Semaphore doesn't give us a great way to get at the value other
        than accessing sem._value.  We do that here, but creating a wrapper that
        adds a current_value method would make this cleaner"""

        while True:
        while sem._value < 100:
            sem.release()
            await asyncio.sleep(1)  # Or spread more evenly 
                                    # with a shorter sleep and 
                                    # increasing the value less

    async def do_request(self, req_id, obj):
        await self.sem.acquire()

        # this is the work for the request
        self.req_data(req_id, obj)

    def run(self, *awaitables):

        loop = asyncio.get_event_loop()

        if not awaitables:
            loop.run_forever()
        elif len(awaitables) == 1:
            return loop.run_until_complete(*awaitables)
        else:
            future = asyncio.gather(*awaitables)
            return loop.run_until_complete(future)

    def sleep(self, secs: float = 0.02) -> bool:

        self.run(asyncio.sleep(secs))
        return True

    def get_req_id(self) -> int:

        new_id = self._req_id_seq
        self._req_id_seq += 1
        return new_id

    def start_req(self, key):

        loop = asyncio.get_event_loop()
        future = loop.create_future()
        self._futures[key] = future
        return future

    def end_req(self, key, result=None):

        future = self._futures.pop(key, None)
        if future:
            if result is None:
                result = self._results.pop(key, [])
            if not future.done():
                future.set_result(result)

    def req_data(self, req_id, obj):
        # This is the method that "does" something
        self.req_data_end(req_id)

    def req_data_end(self, req_id):

        print(req_id, " has ended")
        self.end_req(req_id)

    async def req_data_batch_async(self, objs):

        futures = []
        FLAG = False

        for obj in objs:
            req_id = self.get_req_id()
            future = self.start_req(req_id)
            futures.append(future)

            if FLAG is False:
                FLAG = True
                start = time.time()

            # schedule the request as a task; calling the coroutine without
            # wrapping it would never actually run it
            asyncio.ensure_future(self.do_request(req_id, obj))

        await asyncio.gather(*futures)
        elapsed = time.time() - start
        print("Roughly %s per second" % (len(objs)/elapsed))

        return futures


if __name__ == '__main__':

    loop = asyncio.get_event_loop()
    loop.set_debug(True)
    app = ThrottleTestApp()
    # start the background routine that tops up the semaphore
    loop.create_task(app.allow_requests(app.sem))

    objs = list(range(10000))

    app.run(app.req_data_batch_async(objs))

Solution

  • You can do this by implementing the leaky bucket algorithm:

    import asyncio
    from contextlib import AbstractAsyncContextManager
    from functools import partial
    from heapq import heappop, heappush
    from itertools import count
    from types import TracebackType
    from typing import List, Optional, Tuple, Type
    
    
    class AsyncLimiter(AbstractAsyncContextManager):
        """A leaky bucket rate limiter.
    
        This is an :ref:`asynchronous context manager <async-context-managers>`;
        when used with :keyword:`async with`, entering the context acquires
        capacity::
    
            limiter = AsyncLimiter(10)
            for foo in bar:
                async with limiter:
                    # process foo elements at 10 items per minute
    
        :param max_rate: Allow up to `max_rate` / `time_period` acquisitions before
           blocking.
        :param time_period: duration, in seconds, of the time period in which to
           limit the rate. Note that up to `max_rate` acquisitions are allowed
           within this time period in a burst.
    
        """
    
        __slots__ = (
            "max_rate",
            "time_period",
            "_rate_per_sec",
            "_level",
            "_last_check",
            "_event_loop",
            "_waiters",
            "_next_count",
            "_waker_handle",
        )
    
        max_rate: float  #: The configured `max_rate` value for this limiter.
        time_period: float  #: The configured `time_period` value for this limiter.
    
        def __init__(self, max_rate: float, time_period: float = 60) -> None:
            self.max_rate = max_rate
            self.time_period = time_period
            self._rate_per_sec = max_rate / time_period
            self._level = 0.0
            self._last_check = 0.0
    
            # timer until next waiter can resume
            self._waker_handle: Optional[asyncio.TimerHandle] = None
            # min-heap with (amount requested, order, future) for waiting tasks
            self._waiters: List[Tuple[float, int, "asyncio.Future[None]"]] = []
            # counter used to order waiting tasks
            self._next_count = partial(next, count())
    
        @property
        def _loop(self) -> asyncio.AbstractEventLoop:
            self._event_loop: asyncio.AbstractEventLoop
            try:
                loop = self._event_loop
            except AttributeError:
                loop = self._event_loop = asyncio.get_running_loop()
            return loop
    
        def _leak(self) -> None:
            """Drip out capacity from the bucket."""
            now = self._loop.time()
            if self._level:
                # drip out enough level for the elapsed time since
                # we last checked
                elapsed = now - self._last_check
                decrement = elapsed * self._rate_per_sec
                self._level = max(self._level - decrement, 0)
            self._last_check = now
    
        def has_capacity(self, amount: float = 1) -> bool:
            """Check if there is enough capacity remaining in the limiter
    
            :param amount: How much capacity you need to be available.
    
            """
            self._leak()
            return self._level + amount <= self.max_rate
    
        async def acquire(self, amount: float = 1) -> None:
            """Acquire capacity in the limiter.
    
            If the limit has been reached, blocks until enough capacity has been
            freed before returning.
    
            :param amount: How much capacity you need to be available.
            :exception: Raises :exc:`ValueError` if `amount` is greater than
               :attr:`max_rate`.
    
            """
            if amount > self.max_rate:
                raise ValueError("Can't acquire more than the maximum capacity")
    
            loop = self._loop
            while not self.has_capacity(amount):
                # Add a future to the _waiters heapq to be notified when capacity
                # has come up. The future callback uses call_soon so other tasks
                # are checked *after* completing capacity acquisition in this task.
                fut = loop.create_future()
                fut.add_done_callback(partial(loop.call_soon, self._wake_next))
                heappush(self._waiters, (amount, self._next_count(), fut))
                self._wake_next()
                await fut
    
            self._level += amount
            # reset the waker to account for the new, lower level.
            self._wake_next()
    
            return None
    
        def _wake_next(self, *_args: object) -> None:
            """Wake the next waiting future or set a timer"""
            # clear timer and any cancelled futures at the top of the heap
            heap, handle, self._waker_handle = self._waiters, self._waker_handle, None
            if handle is not None:
                handle.cancel()
            while heap and heap[0][-1].done():
                heappop(heap)
    
            if not heap:
                # nothing left waiting
                return
    
            amount, _, fut = heap[0]
            self._leak()
            needed = amount - self.max_rate + self._level
            if needed <= 0:
                heappop(heap)
                fut.set_result(None)
                # fut.set_result triggers another _wake_next call
                return
    
            wake_next_at = self._last_check + (1 / self._rate_per_sec * needed)
            self._waker_handle = self._loop.call_at(wake_next_at, self._wake_next)
    
        def __repr__(self) -> str:  # pragma: no cover
            args = f"max_rate={self.max_rate!r}, time_period={self.time_period!r}"
            state = f"level: {self._level:f}, waiters: {len(self._waiters)}"
            if (handle := self._waker_handle) and not handle.cancelled():
                microseconds = int((handle.when() - self._loop.time()) * 10**6)
                if microseconds > 0:
                    state += f", waking in {microseconds} \N{MICRO SIGN}s"
            return f"<AsyncLimiter({args}) at {id(self):#x} [{state}]>"
    
        async def __aenter__(self) -> None:
            await self.acquire()
            return None
    
        async def __aexit__(
            self,
            exc_type: Optional[Type[BaseException]],
            exc: Optional[BaseException],
            tb: Optional[TracebackType],
        ) -> None:
            return None
    

    Note that capacity leaks from the bucket opportunistically; there is no need to run a separate async task just to lower the level. Instead, capacity is drained whenever a test for sufficient remaining capacity is made, as the sketch below shows.
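    A minimal sketch of that opportunistic draining, using the AsyncLimiter class above; merely checking for capacity is what lowers the level:

    import asyncio

    async def demo():
        limiter = AsyncLimiter(10, time_period=10)  # 1 unit leaks out per second
        await limiter.acquire(10)          # fill the bucket in one burst
        print(limiter.has_capacity())      # False: nothing has leaked yet
        await asyncio.sleep(2)             # no limiter code runs while sleeping
        print(limiter.has_capacity(2))     # True: this check itself drained ~2 units

    asyncio.run(demo())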

    Note that tasks waiting for capacity are kept in a min-heap keyed on the amount requested, and when there might be capacity to spare again, a timer wakes the first still-waiting task early. For example, with a maximum rate of 5 per 10-second time period (0.5 units leaking per second), a task waiting to acquire 1 unit from a full bucket needs 1 unit to drain first, so its waker fires 1 / 0.5 = 2 seconds after the last check.

    You can use this as a context manager; trying to acquire the bucket when it is full blocks until enough capacity has been freed again:

    bucket = AsyncLimiter(100)
    
    # ...
    
    async with bucket:
        # only reached once the bucket is no longer full
        ...
    

    or you can call acquire() directly:

    await bucket.acquire()  # blocks until there is space in the bucket
    

    or you can simply test if there is space first:

    if not bucket.has_capacity():
        # reject a request due to rate limiting
        ...
    

    Note that you can count some requests as 'heavier' or 'lighter' by increasing or decreasing the amount you 'drip' into the bucket:

    await bucket.acquire(10)
    if bucket.has_capacity(0.5):
        ...
    

    Do be careful with this, though; when mixing large and small drips at or close to the maximum rate, small drips tend to run before large drips, because there is a greater likelihood that enough free capacity exists for a smaller drip before there is space for a larger one; the sketch below illustrates this.
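    A hypothetical sketch of that effect, using the class above: a task waiting to acquire 3 units from a full bucket is overtaken by later 1-unit acquisitions, because the waiter heap is ordered by amount and each freed unit is enough to wake a small drip first:

    import asyncio

    async def heavy(limiter):
        await limiter.acquire(3)  # needs 3 free units at once
        print("heavy ran")

    async def light(limiter, n):
        await limiter.acquire(1)  # needs just 1 free unit
        print(f"light {n} ran")

    async def demo():
        limiter = AsyncLimiter(3, time_period=3)  # 1 unit leaks per second
        await limiter.acquire(3)                  # fill the bucket
        # "heavy" is queued first but prints last: each freed unit goes to a
        # waiting 1-unit drip before 3 units are ever available together
        await asyncio.gather(heavy(limiter), *(light(limiter, i) for i in range(3)))

    asyncio.run(demo())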

    Demo:

    >>> import asyncio, time
    >>> bucket = AsyncLimiter(5, 10)
    >>> async def task(id):
    ...     await asyncio.sleep(id * 0.01)
    ...     async with bucket:
    ...         print(f'{id:>2d}: Drip! {time.time() - ref:>5.2f}')
    ...
    >>> ref = time.time()
    >>> tasks = [task(i) for i in range(15)]
    >>> result = asyncio.run(asyncio.gather(*tasks))
     0: Drip!  0.00
     1: Drip!  0.02
     2: Drip!  0.02
     3: Drip!  0.03
     4: Drip!  0.04
     5: Drip!  2.05
     6: Drip!  4.06
     7: Drip!  6.06
     8: Drip!  8.06
     9: Drip! 10.07
    10: Drip! 12.07
    11: Drip! 14.08
    12: Drip! 16.08
    13: Drip! 18.08
    14: Drip! 20.09
    

    The bucket is filled up quickly at the start in a burst, causing the rest of the tasks to be spread out more evenly; every 2 seconds enough capacity is freed for another task to be handled.

    The maximum burst size is equal to the maximum rate value, in the above demo that was set to 5. If you do not want to permit bursts, set the maximum rate to 1, and the time period to the minimum time between drips:

    >>> bucket = AsyncLimiter(1, 1.5)  # no bursts, drip every 1.5 seconds
    >>> async def task():
    ...     async with bucket:
    ...         print(f'Drip! {time.time() - ref:>5.2f}')
    ...
    >>> ref = time.time()
    >>> tasks = [task() for _ in range(5)]
    >>> result = asyncio.run(asyncio.gather(*tasks))
    Drip!  0.00
    Drip!  1.50
    Drip!  3.01
    Drip!  4.51
    Drip!  6.02
    

    I've since packaged this up as a Python project: https://github.com/mjpieters/aiolimiter. I've kept the implementation in this answer up to date with improvements made in that project.
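    For example, after pip install aiolimiter, the packaged class can be used the same way as the implementation above:

    import asyncio
    from aiolimiter import AsyncLimiter

    async def main():
        limiter = AsyncLimiter(100, time_period=1)  # up to 100 acquisitions per second

        async def request(i):
            async with limiter:
                print(f"request {i} sent")  # stand-in for a real API call

        await asyncio.gather(*(request(i) for i in range(500)))

    asyncio.run(main())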