I'm using `TextIteratorStreamer` to generate text as a stream, and I use a `Thread` to run `model.generate`:
```python
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()
```
I want to introduce `cancel_event = asyncio.Event()` and check `cancel_event.is_set()` in the streamer loop, so that I can stop `model.generate` from consuming GPU resources. How can I stop `model.generate`? Do I need to kill the thread? If so, how?
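You don't need to kill the thread (Python has no safe way to do that anyway). Instead, pass a custom `StoppingCriteria` that checks the event. `generate` evaluates its stopping criteria after every generated token, so once the event is set it returns, the streamer receives its end signal (your `for` loop over the streamer finishes), and the thread exits on its own. Use a `threading.Event` rather than an `asyncio.Event`: the flag is read from the generation thread, and asyncio synchronization primitives are not thread-safe. Something like this should work: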
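```python
from threading import Event, Thread

from transformers import StoppingCriteria, StoppingCriteriaList


class StopCriteria(StoppingCriteria):
    """Ends generation once the shared event is set."""

    def __init__(self, event):
        self.event = event

    def __call__(self, input_ids, scores, **kwargs):
        # generate() calls this after every new token; True stops the whole batch.
        return self.event.is_set()


# threading.Event, not asyncio.Event: it is read from the generation thread.
cancel_event = Event()
generation_kwargs["stopping_criteria"] = StoppingCriteriaList([StopCriteria(cancel_event)])
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()

# your existing streamer loop, e.g. `for new_text in streamer: ...`

cancel_event.set()  # call this wherever you decide to cancel
```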
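To see the control flow without loading a model, here is a transformers-free sketch of the same pattern. `fake_generate` is a hypothetical stand-in for `model.generate`, and a plain `queue.Queue` with a sentinel stands in for `TextIteratorStreamer`. It roughly mirrors how `generate` hands each new token to the streamer, then checks its stopping criteria, and closes the stream when it returns.

```python
import queue
import threading
import time

END = object()  # sentinel marking the end of the stream


def fake_generate(out_q, stop_event, result, max_new_tokens=1000):
    """Hypothetical stand-in for model.generate(streamer=..., stopping_criteria=...)."""
    for step in range(max_new_tokens):
        time.sleep(0.001)          # pretend to run one decoding step
        out_q.put(f"tok{step} ")   # hand the new token to the "streamer"
        result["steps"] = step + 1
        if stop_event.is_set():    # like a StoppingCriteria checked after each token
            break
    out_q.put(END)                 # the stream is closed once generation returns


def stream_until(n_chunks):
    """Consume the stream and request cancellation after n_chunks chunks."""
    cancel_event = threading.Event()
    out_q = queue.Queue()
    result = {"steps": 0}
    worker = threading.Thread(target=fake_generate, args=(out_q, cancel_event, result))
    worker.start()

    received = []
    while (chunk := out_q.get()) is not END:  # plays the role of `for text in streamer:`
        received.append(chunk)
        if len(received) == n_chunks:
            cancel_event.set()                # cancel from the consumer side
    worker.join()                             # generation has returned; no more work
    return received, result["steps"]


received, steps = stream_until(5)
print(len(received), steps)
```

Expect `steps` to be 5 or slightly more: the consumer sets the event after the fifth chunk, and the worker finishes the step it is on before it sees the flag. The consumer keeps reading until the end sentinel, so it receives every token that was produced.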
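Setting the event does not stop generation instantly; `generate` returns after it finishes the current step. If you need to wait for that (for example, before freeing GPU memory), join the thread after setting the event:

```python
thread.join()
```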
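If your streamer loop runs inside asyncio code, you can keep the async structure while still using a `threading.Event` as the flag the worker reads. The sketch below assumes Python 3.9+ for `asyncio.to_thread`; `FakeStreamer` and `fake_generate` are hypothetical stand-ins for `TextIteratorStreamer` and `model.generate`. Iterating a streamer blocks, so each `next()` call runs off the event loop, and the `finally` block sets the event so generation also stops if the task is cancelled.

```python
import asyncio
import queue
import threading
import time


class FakeStreamer:
    """Hypothetical stand-in for TextIteratorStreamer: a blocking iterator over a queue."""
    _END = object()

    def __init__(self):
        self.q = queue.Queue()

    def put(self, text):
        self.q.put(text)

    def end(self):
        self.q.put(self._END)

    def __iter__(self):
        return self

    def __next__(self):
        item = self.q.get()  # blocks, which is why it is called off the event loop
        if item is self._END:
            raise StopIteration
        return item


def fake_generate(streamer, stop_event, max_new_tokens=1000):
    """Hypothetical stand-in for model.generate(streamer=..., stopping_criteria=...)."""
    for step in range(max_new_tokens):
        time.sleep(0.001)             # pretend to run one decoding step
        streamer.put(f"tok{step} ")
        if stop_event.is_set():       # checked after each token, like StopCriteria
            break
    streamer.end()


async def stream(max_chunks):
    cancel_event = threading.Event()  # thread-safe flag shared with the worker
    streamer = FakeStreamer()
    worker = threading.Thread(target=fake_generate, args=(streamer, cancel_event))
    worker.start()
    chunks = []
    try:
        while True:
            # next() blocks, so run it in a thread to keep the event loop responsive
            chunk = await asyncio.to_thread(next, streamer, None)
            if chunk is None:         # stream ended
                break
            chunks.append(chunk)
            if len(chunks) >= max_chunks:
                break
    finally:
        cancel_event.set()            # also runs if this task is cancelled
        await asyncio.to_thread(worker.join)
    return chunks


print(asyncio.run(stream(3)))  # ['tok0 ', 'tok1 ', 'tok2 ']
```

With the real classes, you would put `streamer=streamer` and the `StopCriteria` list into `generation_kwargs` and start `model.generate` in the thread instead of `fake_generate`.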