
Use color map from seaborn sns.pairplot tab10_r on mpl_toolkits 3D scatter plot


I am trying to reuse the two colors shown in the sns.pairplot below on a 3D scatter plot made with mpl_toolkits. The pairplot's palette was built with

colors = dict(zip(['2', '1', '0'], sns.color_palette('tab10_r', 3)))

[Image: seaborn pairplot colored with the tab10_r palette]

When I pass 'tab10_r' as the cmap of the 3D scatter plot, I get this output:

[Image: 3D scatter plot drawn with cmap='tab10_r']

Here is my code:

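import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits import mplot3d  # registers the "3d" projection on older Matplotlib versions

# x_pca (PCA-transformed features) and y_train_orig (one-hot labels) come from earlier code

# Create a single 15x12-inch figure and add a 3D axes to it
fig = plt.figure(figsize=(15, 12))
ax = fig.add_subplot(projection="3d")

x = x_pca[:, 0]
y = x_pca[:, 1]
z = x_pca[:, 2]

# Collapse the one-hot labels to integer class ids (0 or 1)
y_train_new = y_train_orig.argmax(1)

ax.scatter3D(x, y, z, c=y_train_new, marker='o', alpha=0.6, s=55, cmap='tab10_r', edgecolor='k')
plt.title("3D Scatterplot: 95% of the variability captured", pad=15)
ax.set_xlabel('First principal component')
ax.set_ylabel('Second principal component')
ax.set_zlabel('Third principal component')
plt.show()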
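The colors in that output differ from the pairplot because, when c holds numbers and no vmin/vmax is given, Matplotlib rescales c linearly onto [0, 1] across the whole colormap. With integer labels 0 and 1, label 0 therefore gets the first entry of tab10_r and label 1 the last, independent of what seaborn's 3-color palette contained. A minimal sketch of that mapping, assuming Matplotlib 3.5 or newer (for matplotlib.colormaps) and a small hypothetical one-hot array standing in for y_train_orig:

```python
import numpy as np
import matplotlib
from matplotlib.colors import Normalize, to_hex

# Hypothetical one-hot labels standing in for y_train_orig
y_onehot = np.array([[0, 1], [0, 1], [1, 0], [1, 0]])
labels = y_onehot.argmax(1)  # integer class ids, like y_train_new

cmap = matplotlib.colormaps["tab10_r"]

# scatter(c=labels, cmap=...) autoscales a Normalize to the data range
norm = Normalize(vmin=labels.min(), vmax=labels.max())

# Color each class receives: first and last entries of the reversed tab10
label_colors = {int(k): to_hex(cmap(norm(k))) for k in np.unique(labels)}
print(label_colors)  # {0: '#17becf', 1: '#1f77b4'}
```

These are cyan and blue, the two ends of tab10_r, which is why the 3D plot does not match the pairplot.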

Here is some data:

x y z not used Label
-3.8481877 -0.47685334 0.63422906 1.0396314 1
-2.320888 0.65347993 1.1519914 0.12997247 1
1.5827686 1.4119303 -1.7410104 -4.6962333 1
-0.1337152 0.13315737 -1.6648949 -1.4205348 1
-0.4028037 1.332986 1.3618442 0.3292255 1
-0.015517877 1.346349 1.4083523 0.87017965 1
-0.2669228 0.5478992 -0.06730786 -1.5959451 1
-0.03318152 0.3263167 -2.116833 -5.4616213 1
0.4588691 0.6723614 -1.617398 -4.3511734 1
0.5899199 0.66525555 -1.694493 -3.9452586 1
1.610061 2.4186094 1.8807093 1.3764497 0
1.7985699 2.4387648 1.6306056 1.1184534 0
-9.222036 -9.9776 -9.832 -9.909746 0
0.21364458 -1.0171559 -4.9093766 -6.2154694 0
-0.019955145 -1.1677283 -4.6549516 -5.9503417 0
0.44730473 -0.77167743 -4.7527356 -5.971007 0
-0.16508447 -0.005777468 -1.5020386 -4.49326 0
-0.8654994 -0.54387957 -1.300646 -4.621529 0
-1.7471086 -2.0005553 -1.7533782 -2.6065414 0
-1.5313624 -1.6995796 -1.4394685 -2.600004 0
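To reproduce the plot from these rows without the rest of the pipeline, the table can be parsed with NumPy and drawn directly. A minimal sketch, assuming the rows are pasted into a string exactly as shown (only four rows are copied here; the remaining rows go in the same way) and that the fourth column is ignored as in the header. On a machine without a display, plt.show() does nothing and only warns.

```python
import io

import numpy as np
import matplotlib.pyplot as plt

# Four rows copied from the table above; paste the remaining rows the same way.
# Columns: x, y, z, (not used), Label
DATA = """\
-3.8481877 -0.47685334 0.63422906 1.0396314 1
-2.320888 0.65347993 1.1519914 0.12997247 1
1.610061 2.4186094 1.8807093 1.3764497 0
1.7985699 2.4387648 1.6306056 1.1184534 0
"""

table = np.loadtxt(io.StringIO(DATA))
x, y, z = table[:, 0], table[:, 1], table[:, 2]
labels = table[:, 4].astype(int)  # integer class ids, like y_train_new

fig = plt.figure(figsize=(8, 6))
ax = fig.add_subplot(projection="3d")  # "3d" is registered automatically on Matplotlib >= 3.2
ax.scatter(x, y, z, c=labels, cmap="tab10_r", marker="o", s=55, edgecolor="k")
ax.set_xlabel("First principal component")
ax.set_ylabel("Second principal component")
ax.set_zlabel("Third principal component")
plt.show()
```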

How can I get the colors from the seaborn plot onto the 3D scatter plot?


Solution

  • Thanks for the clarification. From your comments, I believe you want to use the green and purple colors available in tab10_r, which would automatically get picked up if you had 3 categories.

    I think you already know most of the answer. The code you used to create the pairplot already builds a dictionary of RGB values; use it to get the two RGB colors you need:

    colors = dict(zip(['2', '1', '0'], sns.color_palette('tab10_r', 3))) #Your code
    new_colors = list(colors.values()) #Get the rgb values for 3 colors into a list
    del new_colors[0] #Remove first color - Grey
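Deleting by position relies on the order in which dict(zip(...)) stored the palette: the remaining list is [colors['1'], colors['0']]. Assuming the pairplot's hue values '0' and '1' are the same classes as the integer labels 0 and 1, a colormap built from that list draws label 0 in the pairplot color of label '1', and vice versa. If the 3D plot should keep the pairplot's label-to-color assignment, one option is to index the dictionary by label instead. A sketch, with Matplotlib's named colors tab:gray, tab:purple and tab:green as placeholders for the tuples sns.color_palette('tab10_r', 3) returns (their order here is an assumption; use the real palette in your code):

```python
from matplotlib.colors import ListedColormap, to_rgb

# Placeholder palette standing in for sns.color_palette('tab10_r', 3);
# the order of these three colors is an assumption.
palette = [to_rgb("tab:gray"), to_rgb("tab:purple"), to_rgb("tab:green")]
colors = dict(zip(['2', '1', '0'], palette))

# Positional approach: drop the first entry; label '1' ends up first
positional = list(colors.values())
del positional[0]

# Label-keyed approach: entry i is the pairplot color of label str(i)
by_label = [colors[str(i)] for i in range(2)]
mycmap = ListedColormap(by_label)
```

Indexing by key keeps the colormap correct even if the order of the keys passed to zip() changes later.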
    

    Once you have the list, pass it to ListedColormap() to create your own colormap and use that as cmap. Replace your ax.scatter3D() call with this updated code:

    from matplotlib.colors import ListedColormap
    mycmap = ListedColormap(new_colors)  # two-color colormap without the grey entry
    
    ax.scatter3D(x, y, z, c=y_train_new, marker='o', alpha=1, s=85, edgecolor='k', cmap=mycmap)
    

    Output

    [Image: 3D scatter plot drawn with the custom ListedColormap]
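Which label gets which entry of the two-color ListedColormap can be checked without drawing anything: with integer labels 0 and 1, the autoscaled normalization sends label 0 to the first list entry and label 1 to the second. An alternative that skips the colormap entirely is to pass one color per point through c, looked up from a label-to-color dictionary. A minimal sketch with hypothetical labels and placeholder named colors in place of the seaborn tuples:

```python
import numpy as np
from matplotlib.colors import ListedColormap, Normalize, to_hex

# Placeholder two-color list standing in for new_colors
new_colors = ["tab:purple", "tab:green"]
mycmap = ListedColormap(new_colors)

labels = np.array([1, 1, 0, 0])  # hypothetical y_train_new
norm = Normalize(vmin=labels.min(), vmax=labels.max())  # what scatter autoscales to

# Color each label receives from the custom colormap
mapped = {int(k): to_hex(mycmap(norm(k))) for k in np.unique(labels)}
print(mapped)  # {0: '#9467bd', 1: '#2ca02c'}

# Alternative: explicit per-point colors from a label -> color dict,
# passed as ax.scatter3D(x, y, z, c=point_colors, ...) with no cmap
label_to_color = {0: "tab:green", 1: "tab:purple"}  # hypothetical assignment
point_colors = [label_to_color[int(k)] for k in labels]
```

The per-point version makes the label-to-color assignment explicit, so it cannot silently swap when the palette order changes.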