import numpy as np

# Side length of the square attention grid (h*h patch tokens per query).
h = 16
# Image is resized to a fixed square so the upscaled map can match it exactly.
img_size = 596

fig, ax = plt.subplots(ncols=3, nrows=1, figsize=(15, 5))
for i, q_id in enumerate(sorted_indices[0]):
    logit = itm_logit[:, q_id, :]
    prob = torch.nn.functional.softmax(logit, dim=1)
    name = f'{prob[0, 1]:.3f}_query_id_{q_id}'

    # Attention map: drop the last token (presumably a CLS/global token —
    # TODO confirm) and reshape the remaining h*h patch weights into a grid.
    attention_map = avg_cross_att[0, q_id, :-1].view(h, h).detach().cpu().numpy()

    # Upscale the h x h map to the image resolution with nearest-neighbour
    # index mapping; without this, the second imshow on ax[2] would shrink
    # the visible area to a tiny h x h pixel corner of the image.
    idx = np.arange(img_size) * h // img_size
    attention_map_big = attention_map[idx[:, None], idx]

    raw_image_resized = raw_image.resize((img_size, img_size))

    # Panel 0: raw attention map.
    ax[0].set_title(name)
    ax[0].imshow(attention_map, cmap='viridis')
    ax[0].axis('off')
    # Panel 1: the image alone, captioned.
    ax[1].set_title(caption)
    ax[1].imshow(raw_image_resized)
    ax[1].axis('off')
    # Panel 2: attention overlaid on the image at matching resolution.
    ax[2].set_title(f'Overlay: {name}')
    ax[2].imshow(raw_image_resized)
    ax[2].imshow(attention_map_big, cmap='viridis', alpha=0.6)
    ax[2].axis('off')

    ax[0].set_aspect('equal')
    ax[1].set_aspect('equal')
    ax[2].set_aspect('equal')
    plt.tight_layout()
    plt.savefig(f"./att_maps/{name}.jpg")
    plt.show()
    # Only visualize the top-ranked query.
    break
What I am trying to do is overlay the attention weights on top of the image (on the third axes), so I can see which parts of the image the attention is focused on.
However, the code I have only overlays the attention weights on a small corner of the image.
What might be the problem in this case?
The root cause is the different resolution of the image and the attention map: the second imshow
call reduces the displayed area to a tiny corner of the original image, overlaid with the 16x16 attention map.
To fix this, the attention map needs to be upscaled (e.g. via np.repeat) to the image resolution. Here's an example:
import numpy as np
from matplotlib import pyplot as plt
from matplotlib import image


def _repeat_counts(n_out, n_in):
    """Per-cell repeat counts whose sum is exactly n_out.

    Unlike a single floor-divided count (n_out // n_in), this still covers
    the full n_out pixels when n_out is not an integer multiple of n_in
    (e.g. a 596-pixel image over a 16-cell map: 596 // 16 * 16 == 592 < 596).
    """
    edges = np.linspace(0, n_out, n_in + 1).astype(int)
    return np.diff(edges)


attention_map = np.random.rand(16, 16)
img = image.imread("merlion.jpg")

# Mismatched shapes: the 16x16 map only covers a 16x16-pixel corner,
# so matplotlib crops the view to that corner.
plt.figure("uneven shapes")
plt.imshow(img)
plt.imshow(attention_map, cmap='viridis', alpha=0.3)

# Upscale via np.repeat in both dimensions so the overlay spans the
# whole image, whatever its exact size.
attention_map_upscale = np.repeat(
    np.repeat(attention_map, _repeat_counts(img.shape[0], attention_map.shape[0]), axis=0),
    _repeat_counts(img.shape[1], attention_map.shape[1]), axis=1)

plt.figure("even shapes")
plt.imshow(img)
plt.imshow(attention_map_upscale, cmap='viridis', alpha=0.3)
plt.show()