Tags: python, matplotlib

Attention weights on top of image


import torch
import matplotlib.pyplot as plt

h = 16  # the attention map is h x h
fig, ax = plt.subplots(ncols=3, nrows=1, figsize=(15, 5))

for i, q_id in enumerate(sorted_indices[0]):
    logit = itm_logit[:, q_id, :]
    prob = torch.nn.functional.softmax(logit, dim=1)
    name = f'{prob[0, 1]:.3f}_query_id_{q_id}'
    
    # Attention map
    attention_map = avg_cross_att[0, q_id, :-1].view(h, h).detach().cpu().numpy()
    
    # Image
    raw_image_resized = raw_image.resize((596, 596))
    
    ax[0].set_title(name)
    ax[0].imshow(attention_map, cmap='viridis')
    ax[0].axis('off')
    
    ax[1].set_title(caption)
    ax[1].imshow(raw_image_resized)
    ax[1].axis('off')
    

    ax[2].set_title(f'Overlay: {name}')
    ax[2].imshow(raw_image_resized)
    ax[2].imshow(attention_map, cmap='viridis', alpha=0.6)  
    ax[2].axis('off')
    

    ax[0].set_aspect('equal')
    ax[1].set_aspect('equal')
    ax[2].set_aspect('equal')
    
    plt.tight_layout()
    plt.savefig(f"./att_maps/{name}.jpg")
    plt.show()
    break

[output figure: the attention map, the resized image, and the attempted overlay]

What I am trying to do is overlay the attention weights on top of the image (on the third axes), so I can see which parts of the image the attention weights focus on.

However, with this code the attention map only ends up covering a small region instead of being stretched over the whole image.

What might be the problem in this case?


Solution

  • The root cause is the different resolution of the image and the attention map. Each imshow call rescales the axes limits to the extent of the array it draws, so the second call shrinks the displayed area to a tiny corner of the original picture, with the 16x16 attention map drawn over just that corner (a small check that makes this visible is sketched after the example below).

    To fix this, the attention map needs to be upscaled (e.g. via np.repeat) to the image resolution. Here's an example (an alternative that stretches the map with imshow's extent argument is sketched further below):

    [comparison figure: the overlay with mismatched shapes vs. the upscaled overlay]

    import numpy as np
    from matplotlib import pyplot as plt
    from matplotlib import image
    
    attention_map = np.random.rand(16, 16)
    img = image.imread("merlion.jpg")
    
    plt.figure("uneven shapes")
    plt.imshow(img)
    plt.imshow(attention_map, cmap='viridis', alpha=0.3)
    
    # naive upscaling via np.repeat in both dimensions
    attention_map_upscale = np.repeat(np.repeat(attention_map, img.shape[0] // attention_map.shape[0], axis=0),
                                      img.shape[1] // attention_map.shape[1], axis=1)
    
    plt.figure("even shapes")
    plt.imshow(img)
    plt.imshow(attention_map_upscale, cmap='viridis', alpha=0.3)
    
    plt.show()
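
    The axes-limit reset can be seen directly by printing the limits after each imshow call. A minimal sketch, assuming the same merlion.jpg file and a random 16x16 array standing in for the real attention weights:

    import numpy as np
    from matplotlib import pyplot as plt
    from matplotlib import image
    
    attention_map = np.random.rand(16, 16)  # placeholder for the real attention weights
    img = image.imread("merlion.jpg")
    
    fig, ax = plt.subplots()
    ax.imshow(img)
    print(ax.get_xlim())  # limits span the full-size image
    ax.imshow(attention_map, cmap='viridis', alpha=0.3)
    print(ax.get_xlim())  # limits shrink to the 16x16 map: (-0.5, 15.5)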
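
    Alternatively, instead of resampling the array yourself, matplotlib can stretch the low-resolution map over the picture via imshow's extent argument (optionally with an interpolation mode for a smoother look). This is a sketch under the same assumptions as above, not the original code:

    import numpy as np
    from matplotlib import pyplot as plt
    from matplotlib import image
    
    attention_map = np.random.rand(16, 16)  # placeholder for the real attention weights
    img = image.imread("merlion.jpg")
    
    plt.figure("extent overlay")
    plt.imshow(img)
    # extent=(left, right, bottom, top) in the image's pixel coordinates,
    # so the 16x16 map is stretched across the whole picture
    plt.imshow(attention_map, cmap='viridis', alpha=0.3,
               extent=(0, img.shape[1], img.shape[0], 0),
               interpolation='bilinear')
    
    plt.show()

    This keeps the 16x16 data untouched and avoids the slight size mismatch that integer np.repeat factors leave when the image side is not an exact multiple of 16.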