pytorch · stable-diffusion · diffusers · generative-programming

stabilityai/stable-cascade takes 7+ hours to generate an image


I am using this model: https://huggingface.co/stabilityai/stable-cascade

import torch
from diffusers import StableCascadeCombinedPipeline

print("LOADING MODEL")
pipe = StableCascadeCombinedPipeline.from_pretrained(
    "stabilityai/stable-cascade", variant="bf16", torch_dtype=torch.bfloat16
)
print("MODEL LOADED")

prompt = "a lawyer"
pipe(
    prompt=prompt,
    negative_prompt="",
    num_inference_steps=10,
    prior_num_inference_steps=20,
    prior_guidance_scale=3.0,
    width=1024,
    height=1024,
).images[0].save("cascade-combined2.png")

The model loads almost instantly, but the generation step took more than 7 hours.

Loading pipeline components...: 100%|█████████████████████████████████████████| 5/5 [00:00<00:00,  9.73it/s]
Loading pipeline components...: 100%|█████████████████████████████████████████| 6/6 [00:00<00:00, 11.08it/s]
MODEL LOADED
  0%|                                                                                | 0/20 [00:00<?, ?it/s]
  0%|                                                                                | 0/20 [04:10<?, ?it/s]

I am using

Apple M2 Pro (32 GB)

Python 3.10.2

Is there anything I can do to speed this up? Because I would like to generate maybe 50 images, and that doesn't seem possible at the current speed.

Edit:

import torch
from diffusers import StableCascadeCombinedPipeline

pipe = StableCascadeCombinedPipeline.from_pretrained(
    "stabilityai/stable-cascade", variant="bf16", torch_dtype=torch.float32
)

device = torch.device('mps')
pipe.to(device)

prompt = "a football"
pipe(
    prompt=prompt,
    negative_prompt="",
    num_inference_steps=10,
    prior_num_inference_steps=20,
    prior_guidance_scale=3.0,
    width=1024,
    height=1024,
).images[0].save("cascade-combined2.png")

Solution

  • Assuming you want inference to run on your GPU, first check that the MPS backend is available and built:

    >>> import torch
    >>> torch.backends.mps.is_available()  # True if the MPS device is usable right now
    True
    >>> torch.backends.mps.is_built()      # True if this PyTorch build includes MPS support
    True
    

    If both return True, you can move the pipeline to that device and perform inference there instead of on the CPU:

    device = torch.device('mps')
    pipe = pipe.to(device)
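
    Putting it together, here is a minimal sketch of the whole flow, including a loop for generating many images. It assumes a PyTorch build with MPS support; the prompt list and output file names are placeholders, and the enable_attention_slicing() call (a standard diffusers memory optimization) is guarded because not every pipeline exposes it:

    import torch
    from diffusers import StableCascadeCombinedPipeline

    # Fall back to the CPU if the MPS backend is unavailable.
    device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")

    pipe = StableCascadeCombinedPipeline.from_pretrained(
        "stabilityai/stable-cascade", variant="bf16", torch_dtype=torch.float32
    )
    pipe = pipe.to(device)

    # Optional: lower peak memory at a small speed cost, if supported.
    if hasattr(pipe, "enable_attention_slicing"):
        pipe.enable_attention_slicing()

    prompts = ["a lawyer", "a football"]  # placeholder list; swap in your ~50 prompts
    for i, prompt in enumerate(prompts):
        image = pipe(
            prompt=prompt,
            negative_prompt="",
            num_inference_steps=10,
            prior_num_inference_steps=20,
            prior_guidance_scale=3.0,
            width=1024,
            height=1024,
        ).images[0]
        image.save(f"cascade-{i}.png")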