huggingface-transformers · transformer-model · multimodal

How to extract image hidden states in LLaVa's transformers (Huggingface) implementation?


I am using the transformers library (Hugging Face) to extract all hidden units of LLaVa 1.5. The Hugging Face documentation shows that it is possible to extract image hidden states from the vision component.

Unfortunately, the outputs object only has the following keys available in the output dictionary: odict_keys(['sequences', 'attentions', 'hidden_states', 'past_key_values'])

How do I also extract the image_hidden_states from this LLaVa implementation along with the existing outputs?

I have implemented the following code in the hope of doing so.

import torch
from transformers import LlavaForConditionalGeneration, LlavaConfig, CLIPVisionConfig, LlamaConfig, AutoProcessor, LlavaProcessor
from PIL import Image
import requests
from torchinfo import summary

device = "cuda:0" if torch.cuda.is_available() else "cpu"
model_id = 'llava-hf/llava-1.5-7b-hf'

# Initializing a CLIP-vision config
vision_config = CLIPVisionConfig(output_hidden_states=True, output_attentions=True, return_dict=True)

# Initializing a Llama config
text_config = LlamaConfig(output_hidden_states=True, output_attentions=True, return_dict=True)

# Initializing a Llava llava-1.5-7b style configuration
configuration = LlavaConfig(vision_config, text_config, output_hidden_states=True, output_attentions=True, return_dict=True)

# Loading the pretrained llava-1.5-7b weights
model = LlavaForConditionalGeneration.from_pretrained(model_id, output_hidden_states=True, output_attentions=True, return_dict=True)

# Accessing the model configuration
configuration = model.config

model=model.to(device)
print(summary(model))

processor = LlavaProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf", output_hidden_states=True, output_attentions=True, return_dict=True)
prompt = "USER: <image>\nIs there sun in the image? ASSISTANT:"
url = "https://www.ilankelman.org/stopsigns/australia.jpg"
image = Image.open(requests.get(url, stream=True).raw)
inputs = processor(text=prompt, images=image, return_tensors="pt")
inputs=inputs.to(device)

with torch.no_grad():
    outputs = model.generate(**inputs, 
                             output_hidden_states=True, 
                             return_dict_in_generate=True, 
                             max_new_tokens=1, 
                             min_new_tokens=1,
                             return_dict=True)

print(outputs.keys())


Solution

  • OK, I will try to answer my own question. The solution was not directly available through the transformers library; I do not know why the functionality mentioned in the documentation does not work. However, I found a workaround by using PyTorch pre-hooks to capture the values of the hidden units.
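
    Below is a minimal sketch of that idea, reusing the model and inputs from the question. It assumes the Hugging Face LLaVA model exposes the vision encoder as model.vision_tower and the projection layer as model.multi_modal_projector (attribute names taken from the transformers LLaVA implementation; verify them with print(model) for your version). It registers forward hooks rather than pre-hooks, since a forward hook receives the module's output, which is where the image hidden states live; a pre-hook would only see the module's inputs.

    image_features = {}

    def make_hook(name):
        def hook(module, args, output):
            # Store the module's output: for the vision tower this is the CLIP
            # encoder output, for the projector the image embeddings fed to the LLM.
            image_features[name] = output
        return hook

    # Attribute names below are assumptions based on the transformers LLaVA
    # implementation; check them with print(model) if they differ in your version.
    handles = [
        model.vision_tower.register_forward_hook(make_hook("vision_tower")),
        model.multi_modal_projector.register_forward_hook(make_hook("projector")),
    ]

    with torch.no_grad():
        outputs = model.generate(**inputs,
                                 output_hidden_states=True,
                                 return_dict_in_generate=True,
                                 max_new_tokens=1,
                                 min_new_tokens=1)

    for h in handles:
        h.remove()  # detach the hooks once generation is done

    print(image_features["projector"].shape)     # projected image embeddings
    print(type(image_features["vision_tower"]))  # CLIP vision encoder output

    The hooks fire on the prefill step of generate (the vision tower is not re-run for subsequent tokens), so image_features is populated after the first forward pass, independently of what generate itself returns.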