huggingface-transformers, diffusers

How to stop a Hugging Face pipeline operation


I need to stop a Hugging Face pipeline operation while it is running. I tried the method from the following question, but it didn't work: I set a breakpoint on the line `return flag`, expecting the debugger to stop there so I could change the value, but it was never hit.

How to implement `stopping_criteria` parameter in transformers library?

My code:

import torch
from diffusers import DiffusionPipeline
from transformers import StoppingCriteriaList

pipe = DiffusionPipeline.from_pretrained(
    "kandinsky-community/kandinsky-3", variant="fp16", torch_dtype=torch.float16
)
pipe.enable_sequential_cpu_offload()
prompt = "Cemetery of abandoned vehicles, many different rusty cars"

flag = False


def custom_stopping_criteria(
    input_ids: torch.LongTensor, score: torch.FloatTensor, **kwargs
) -> bool:
    return flag


stopping_criteria = StoppingCriteriaList([custom_stopping_criteria])


image = pipe(prompt, stopping_criteria=stopping_criteria).images[0]
image.show()

Solution

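  • The pipeline object you create comes from the diffusers library, which implements interruption differently from the transformers pipeline class. `stopping_criteria` belongs to the transformers `generate()` API, so the diffusers pipeline never calls your function, which is why your breakpoint is never hit.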

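    The diffusers library lets you hook into the diffusion process through the `callback_on_step_end` parameter. The callback is called at the end of each denoising step and receives the pipeline instance, the step index `i`, the current scheduler timestep `t`, and a `callback_kwargs` dict holding the tensors named in `callback_on_step_end_tensor_inputs` (by default `latents`). It must return that dict, optionally modified.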

    To interrupt the pipeline, set the private `pipeline._interrupt` attribute to `True` inside the callback; the denoising loop then skips all remaining steps. The following example mirrors the code from the question, using a smaller model:

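    import torch
    from diffusers import DiffusionPipeline

    # Using a smaller model for demonstration
    model_id = "OFA-Sys/small-stable-diffusion-v0"

    pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
    pipe.enable_sequential_cpu_offload()

    prompt = "Cemetery of abandoned vehicles, many different rusty cars"


    def interrupt_callback(pipeline, i, t, callback_kwargs):
        # Called at the end of every denoising step.
        # Setting the private _interrupt flag makes the pipeline skip all remaining steps.
        pipeline._interrupt = True
        # The callback must return the (possibly modified) tensor dict.
        return callback_kwargs


    # pipe(...) returns an output object; the generated images are in .images
    image = pipe(
        prompt,
        callback_on_step_end=interrupt_callback,
    ).images[0]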
    

    Output:

    1/50 [00:01<01:21,  1.67s/it]
    Potential NSFW content was detected in one or more images. A black image will be returned instead. Try again with a different prompt and/or seed.
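If, as in the question, the goal is to stop generation from outside the pipeline (by flipping a flag from another thread or from the debugger), the callback can check a shared flag on every step instead of interrupting unconditionally. The following is a minimal sketch under these assumptions: the flag is a `threading.Event`, `make_interrupt_callback` is a hypothetical helper name (not part of diffusers), and `_interrupt` is a private diffusers attribute that may change between versions. To keep the sketch runnable without downloading a model, `FakePipeline` is a hypothetical stand-in that only imitates the step loop (check the flag, run a step, call the callback).

```python
import threading


def make_interrupt_callback(stop_event):
    """Hypothetical helper: build a callback_on_step_end function driven by an external flag."""

    def callback(pipeline, step_index, timestep, callback_kwargs):
        # Runs once per denoising step; another thread (or a debugger) can call
        # stop_event.set() at any time to request a stop.
        if stop_event.is_set():
            # Private diffusers attribute checked by the denoising loop (assumption: may change).
            pipeline._interrupt = True
        # Always hand the tensor dict back to the pipeline.
        return callback_kwargs

    return callback


class FakePipeline:
    """Hypothetical stand-in imitating the step loop; no model or GPU needed."""

    def __init__(self, num_inference_steps=50):
        # Placeholder timesteps, counting down like a scheduler would.
        self.timesteps = list(range(num_inference_steps, 0, -1))
        self._interrupt = False
        self.steps_run = 0

    def __call__(self, callback_on_step_end=None):
        self._interrupt = False
        self.steps_run = 0
        latents = 0.0
        for i, t in enumerate(self.timesteps):
            if self._interrupt:
                continue  # remaining steps are skipped once the flag is set
            latents += 1.0  # placeholder for one denoising step
            self.steps_run += 1
            if callback_on_step_end is not None:
                outputs = callback_on_step_end(self, i, t, {"latents": latents})
                latents = outputs.pop("latents", latents)
        return latents


stop_event = threading.Event()
pipe = FakePipeline(num_inference_steps=50)
callback = make_interrupt_callback(stop_event)

pipe(callback_on_step_end=callback)
print(pipe.steps_run)  # 50: no stop requested

stop_event.set()
pipe(callback_on_step_end=callback)
print(pipe.steps_run)  # 1: the first step runs, then the loop skips the rest
```

With a real pipeline, the same callback would be passed as `pipe(prompt, callback_on_step_end=make_interrupt_callback(stop_event))` while another thread calls `stop_event.set()`. Because diffusers does call this function once per step, a breakpoint placed inside `callback` is hit, unlike the one in the `stopping_criteria` function from the question.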