I am currently using the diffusers StableDiffusionPipeline (from Hugging Face) to generate AI images with a Discord bot that I use with my friends. Is it possible to get a preview of the image while it is being generated, before it is finished?
For example, if an image takes 20 seconds to generate, it starts off blurry and gradually gets better, since it is produced by diffusion. I want to save the image on each iteration (or every few seconds) and see how it progresses. How can I do this?
class ImageGenerator:
    def __init__(self, socket_listener, pretty_logger, prisma):
        self.model = StableDiffusionPipeline.from_pretrained(
            "runwayml/stable-diffusion-v1-5",
            revision="fp16",
            torch_dtype=torch.float16,
            use_auth_token=os.environ.get("HF_AUTH_TOKEN"),
        )
        self.model = self.model.to("cuda")

    async def generate_image(self, data):
        start_time = time.time()
        with autocast("cuda"):
            image = self.model(data.description, height=self.default_height, width=self.default_width,
                               num_inference_steps=self.default_inference_steps, guidance_scale=self.default_guidance_scale)
        image.save(...)
This is the code I currently have, but it only returns the image once it is completely done. I have tried to look into how the image is generated inside StableDiffusionPipeline, but I cannot find where that happens. Any pointers or tips on where to begin would be very helpful.
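You can use the callback argument of the Stable Diffusion pipeline to get the latent-space representation of the image at intermediate steps (see the pipeline documentation). The pipeline calls this function every callback_steps steps with the current step index, the timestep, and the latents.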
The implementation shows how the latents are converted back to an image. We just have to copy that code and decode the latents.
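To make the arithmetic of that conversion concrete, here is a minimal scalar sketch of the two numeric steps involved: undoing the latent scaling before decoding, and mapping a decoded value from [-1, 1] to an 8-bit channel value. The helper names `unscale_latent` and `decoded_to_uint8` are made up for illustration and are not part of diffusers; the real code applies the same operations to whole tensors.

```python
# Scaling factor used for latents by the Stable Diffusion v1 VAE.
LATENT_SCALE = 0.18215


def unscale_latent(latent: float) -> float:
    """Undo the latent scaling before the value is passed to the VAE decoder.

    Hypothetical helper; equivalent to `1 / 0.18215 * latents` on one value.
    """
    return latent / LATENT_SCALE


def decoded_to_uint8(value: float) -> int:
    """Map one decoded VAE output value from [-1, 1] to the range 0..255.

    Hypothetical helper; mirrors `(image / 2 + 0.5).clamp(0, 1)` followed by
    scaling to 8 bits, applied to a single number instead of a tensor.
    """
    unit = min(max(value / 2 + 0.5, 0.0), 1.0)  # shift to [0, 1] and clamp
    return round(unit * 255)


if __name__ == "__main__":
    for v in (-1.0, 0.5, 1.0, 3.0):
        print(v, "->", decoded_to_uint8(v))
```

Values outside [-1, 1] are clamped, which is why early, noisy previews can still be saved as valid images.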
Here is a small example that saves the generated image every 5 steps:
from diffusers import StableDiffusionPipeline
import torch

# load model
model = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    revision="fp16",
    torch_dtype=torch.float16,
    use_auth_token="YOUR TOKEN HERE",
)
model = model.to("cuda")


def callback(step, t, latents):
    # convert latents to image
    with torch.no_grad():
        latents = 1 / 0.18215 * latents
        image = model.vae.decode(latents).sample

        image = (image / 2 + 0.5).clamp(0, 1)

        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()

        # convert to PIL Images
        image = model.numpy_to_pil(image)

        # do something with the Images
        for i, img in enumerate(image):
            img.save(f"iter_{step}_img{i}.png")


# generate image (note the `callback` and `callback_steps` arguments)
image = model("tree", callback=callback, callback_steps=5)
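The question also asks about saving a preview every few seconds rather than every few steps. One possible way to do that, sketched below, is to wrap the callback so that it only runs when enough wall-clock time has passed. The wrapper name `every_n_seconds` and its `clock` parameter are assumptions made for this example, not part of diffusers; only the `(step, timestep, latents)` callback signature comes from the pipeline.

```python
import time


def every_n_seconds(callback, min_interval, clock=time.monotonic):
    """Wrap a pipeline callback so it runs at most once per `min_interval` seconds.

    Hypothetical helper: `clock` is injectable so the throttling can be tested
    without actually waiting.
    """
    last_call = None

    def wrapper(step, timestep, latents):
        nonlocal last_call
        now = clock()
        if last_call is not None and now - last_call < min_interval:
            return  # too soon since the last preview, skip this step
        last_call = now
        callback(step, timestep, latents)

    return wrapper


if __name__ == "__main__":
    # Stand-in callback that only prints, so the sketch runs without a GPU.
    throttled = every_n_seconds(lambda s, t, l: print("preview at step", s), 0.05)
    for step in range(10):
        time.sleep(0.02)  # pretend each denoising step takes 20 ms
        throttled(step, 0, None)
```

With the pipeline, this would be used as `model("tree", callback=every_n_seconds(callback, 3.0), callback_steps=1)`, so that the wrapper is consulted on every step but a preview is decoded and saved at most once every three seconds. Decoding the latents is not free, so a longer interval keeps the overhead down.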
To understand the Stable Diffusion model, I highly recommend this blog post.