python matplotlib imshow multiple-axes

How do I plot the images produced by a function call into a grid of images?


I have a function that I use to output a photo that has had its pixels clustered using KMeans. I can input the k value as an argument, and it will fit the model and output the new image.

def cluster_image(k, img=img):
  img_flat = img.reshape(img.shape[0]*img.shape[1], 3)
  kmeans = KMeans(n_clusters = k, random_state = 42).fit(img_flat)
  new_img = img_flat.copy()
  
  for i in np.unique(kmeans.labels_):
    new_img[kmeans.labels_ == i, :] = kmeans.cluster_centers_[i]
  
  new_img = new_img.reshape(img.shape)

  return plt.imshow(new_img), plt.axis('off');

I want to write a loop to output the images for k=2 through k=10:

k_values = np.arange(2, 11)
for k in k_values:
  print('k = ' + str(k))
  cluster_image(k)
  plt.show()

This returns a vertical line of images. How do I do something like this, but output each image to a 3x3 grid of images?


Solution

  • If you are allowed to modify the signature of cluster_image, I would pass the target Axes in as an argument and draw on it with ax.imshow instead of plt.imshow:

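    def cluster_image(k, ax, img=img):
        # Flatten the image to one row per pixel, with 3 color channels
        img_flat = img.reshape(img.shape[0]*img.shape[1], 3)
        kmeans = KMeans(n_clusters = k, random_state = 42).fit(img_flat)
        new_img = img_flat.copy()

        # Replace every pixel with the center of the cluster it belongs to
        for i in np.unique(kmeans.labels_):
            new_img[kmeans.labels_ == i, :] = kmeans.cluster_centers_[i]

        new_img = new_img.reshape(img.shape)
        # Draw on the Axes that was passed in, not on the current pyplot Axes
        ax.imshow(new_img)
        ax.axis('off')

    # One figure holding a 3x3 grid; flatten so the Axes can be indexed 0..8
    fig, axs = plt.subplots(3, 3)
    axs = axs.flatten()
    k_values = np.arange(2, 11)
    for i, k in enumerate(k_values):
        print('k = ' + str(k))
        cluster_image(k, axs[i], img=img)
    # Show the whole grid once, after every panel has been drawn
    plt.show()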
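
  • If you cannot change the signature, one option that should work is to make each Axes the current one with plt.sca(ax) before calling the function, so its plt.imshow and plt.axis calls land in that panel. The sketch below is only illustrative: its cluster_image is a hypothetical stand-in that keeps the question's signature but quantizes a random placeholder image instead of fitting KMeans, so it runs without scikit-learn. The Agg backend and the per-panel titles are also assumptions of mine, not part of the original answer.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(0)
img = rng.random((8, 8, 3))  # placeholder image; substitute your own array


def cluster_image(k, img=img):
    """Hypothetical stand-in with the question's original signature.

    It quantizes each channel to k levels instead of fitting KMeans,
    and draws through pyplot exactly like the original function.
    """
    new_img = np.floor(img * k).clip(0, k - 1) / (k - 1)
    plt.imshow(new_img)
    plt.axis("off")


fig, axs = plt.subplots(3, 3, figsize=(6, 6))
k_values = np.arange(2, 11)
for ax, k in zip(axs.flat, k_values):
    plt.sca(ax)                # make this Axes the target of pyplot calls
    cluster_image(k)           # the unchanged function now draws into ax
    ax.set_title(f"k = {k}")   # label the panel instead of printing

fig.tight_layout()
fig.canvas.draw()  # render once; use plt.show() in an interactive session
```

    Passing the Axes explicitly, as in the first option, is usually easier to follow, because the function then does not depend on which Axes happens to be current.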