python matplotlib scatter-plot graph-coloring

Python scatterplot: how to use a colormap that has the same colors as the colorcycle


I am trying to color clusters in a scatter plot, and I managed to do it with two different methods.

In the first, I plot each cluster iteratively; in the second, I plot all the data at once and color the clusters according to their labels [0, 1, 2, 3, 4].

I am happy with the result I get in example1 and example3 but I don't understand why the coloring changes so dramatically when coloring the clusters according to the labels instead of iteratively plotting each cluster.

Additionally, why does the second cluster (despite always having label "1") have a different color in example1 and example3?

import matplotlib.pyplot as plt
plt.style.use('fivethirtyeight') #irrelevant here, but coherent with the examples=)

fig, ax = plt.subplots(figsize=(6,4))
for clust in range(kmeans.n_clusters):
    ax.scatter(X[kmeans.labels_==clust],Y[kmeans.labels_==clust])
ax.set_title("example1")

example1

and

plt.figure(figsize = (6, 4))
plt.scatter(X,Y,c=kmeans.labels_.astype(float))
plt.title("example2")

example2

(I know I can explicitly define a colormap for the second method but I couldn't find any that reproduces the results in example 1)

Here is a minimal working example:

import matplotlib.pyplot as plt
import pandas as pd
plt.style.use('fivethirtyeight') #irrelevant here, but coherent with the examples=)
X=pd.Series([1, 2, 3, 4, 5, 11, 12, 13, 14, 15])
Y=pd.Series([1,1,1,1,1,2,2,2,2,2])
clusters=pd.Series([0,0,0,0,0,1,1,1,1,1])


fig, ax = plt.subplots(figsize=(6,4))
for clust in range(2):
    ax.scatter(X[clusters==clust],Y[clusters==clust])
ax.set_title("example3")

example3

plt.figure(figsize = (6, 4))
plt.scatter(X,Y, c=clusters)
plt.title("example4")

example4


Solution

  • When you loop over the clusters and call scatter without specifying a color, matplotlib takes the colors from the active property cycler (color cycle). The active property cycler is defined in the rcParams and is set by the style in use, in your case 'fivethirtyeight':

    print(plt.rcParams["axes.prop_cycle"])
    > cycler('color', ['#008fd5', '#fc4f30', '#e5ae38', '#6d904f', '#8b8b8b', '#810f7c'])
    

    The first two colors of this cycle ('#008fd5', '#fc4f30') are the ones you see in the plot.

    When you call scatter with the clusters as the color argument, those values are mapped to colors via a colormap. If no colormap is specified, the default colormap defined in the rcParams is used.

    print(plt.rcParams["image.cmap"])
    > "viridis"
    

    The 'fivethirtyeight' style does not define any special colormap, so the default is unchanged. (If you see a colormap other than viridis in your picture, that is because some other code, not shown in the question, was still in effect.)

    At this point I need to start interpreting: I think your real question is how to make the single scatter call use a colormap that contains the same colors as the color cycle. None of the predefined colormaps contains the fivethirtyeight cycler colors, so you need to define that colormap manually, taking the colors from the cycle:

    import matplotlib.colors as mcolors
    cmap = mcolors.ListedColormap(plt.rcParams['axes.prop_cycle'].by_key()['color'])
    

    Now you need a way to index the colormap, because you have discrete clusters.

    n = len(clusters.unique())
    norm = mcolors.BoundaryNorm(np.arange(n+1)-0.5, n)
    

    Of course, this requires that the number of colors in the colormap is greater than or equal to the number of classes, which is the case here.

    Putting it all together (I added another category to make it more illustrative):

    import numpy as np
    import matplotlib.pyplot as plt
    import pandas as pd
    import matplotlib.colors as mcolors
    
    plt.style.use('fivethirtyeight') #relevant here!!
    
    X=pd.Series([1, 2, 3, 4, 5, 11, 12, 13, 14, 15])
    Y=pd.Series([1,1,1,1,1,2,2,2,2,2])
    clusters=pd.Series([0,0,0,0,0,1,1,1,1,2])
    
    cmap = mcolors.ListedColormap(plt.rcParams['axes.prop_cycle'].by_key()['color'])
    n = len(clusters.unique())
    norm = mcolors.BoundaryNorm(np.arange(n+1)-0.5, n)
    
    plt.figure(figsize = (6, 4))
    sc = plt.scatter(X,Y, c=clusters, cmap=cmap, norm=norm)
    plt.colorbar(sc, ticks=clusters.unique())
    plt.title("example4")
    
    plt.show()
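
As a quick sanity check, the following sketch compares the colors produced by the two approaches described above: looping over the clusters (color cycle) versus a single call with a `ListedColormap` and `BoundaryNorm`. The sample `labels` array, the variable names, and the `Agg` backend selection are assumptions made for illustration only; they are not part of the original question.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend so the check runs without a display
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import numpy as np

plt.style.use("fivethirtyeight")
cycle_colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]

# Hypothetical cluster labels, used only for this check
labels = np.array([0, 0, 1, 1, 2])
x = np.arange(len(labels))
n = len(np.unique(labels))

# The colormap must hold at least as many colors as there are classes
assert n <= len(cycle_colors)

# Method 1: one scatter call per cluster; each call takes the next cycle color
fig, ax = plt.subplots()
loop_colors = []
for k in range(n):
    coll = ax.scatter(x[labels == k], x[labels == k])
    loop_colors.append(mcolors.to_hex(coll.get_facecolor()[0]))
plt.close(fig)

# Method 2: map each label to a color through the colormap built from the cycle
cmap = mcolors.ListedColormap(cycle_colors)
norm = mcolors.BoundaryNorm(np.arange(n + 1) - 0.5, n)
mapped_colors = [mcolors.to_hex(cmap(norm(k))) for k in range(n)]

print(loop_colors)
print(mapped_colors)
```

If both printed lists are identical, label `k` receives the same color in the single-call plot as the `k`-th cluster does in the loop.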
    
