python pandas seaborn jointplot

How to make a jointplot in Seaborn with multiple groups or categories?


I am trying to make a jointplot in Seaborn. The goal is a scatter plot of all [x, z] values, color-coded by [cat], together with the distributions for these two categories. I also want a scatter and distribution plot of [x, alt_Z], ignoring the alt_Z values that are NaN.

Using Python 3.7

Here is a stand-alone dataset and my goal (made in Excel, so the distributions are not shown).

import matplotlib.pyplot as plt
%matplotlib inline
import pandas as pd
import seaborn as sns

col1 = [1,1.5,3.1,3.4,2,-1]
col2 = [1,-3,2,8,2.5,-1.3]
col3 = [4,3,4,0.5,1,0.3]
col4 = [10,12,10,'NaN',13,'NaN']
col5 = ['A','A','A','B','A','B']
df = pd.DataFrame(list(zip(col1, col2, col3, col4, col5)), 
                  columns =['x', 'y', 'z', 'alt_Z', 'cat'])
display(df)

[Images omitted: the dataset and the target plot made in Excel]

The code below doesn't finish the plot and returns TypeError: The y variable is categorical, but one of ['numeric', 'datetime'] is required. I also don't know how, in the code below, to group by [cat] A and B; as it stands, it is shown in red and only the A category is plotted.

df2 = df[['x', 'y', 'z', 'alt_Z', 'cat']]\
    .melt(id_vars=['x', 'y'], value_vars=['z', 'alt_Z'])
    
g = sns.jointplot(data=df2, x='x', y='value', hue='variable', 
                  palette={'z': 'black', 'alt_Z': 'red'})



Solution

  • One problem with the dataframe is that col4 mixes integers with the string 'NaN'. Such a column has no common numeric type, so pandas stores it with dtype object, and seaborn then treats it as categorical (hence the TypeError). Converting the column to float produces a proper numeric column in which each 'NaN' string becomes a real NaN.
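
The dtype problem described above can be checked directly. This is a minimal sketch that assumes only pandas is installed; the names alt_z, as_float and coerced are illustrative, and pd.to_numeric is shown as a possible alternative to astype(float), not as something this answer relies on.

```python
import pandas as pd

# Same values as col4 in the question: integers mixed with the string 'NaN'.
alt_z = pd.Series([10, 12, 10, 'NaN', 13, 'NaN'], name='alt_Z')
print(alt_z.dtype)  # object, because of the mixed int/str content

# Conversion used in this answer: the string 'NaN' parses as a float.
as_float = alt_z.astype(float)

# Alternative (assumption, not from the answer): coerce anything
# unparseable to NaN instead of raising an error.
coerced = pd.to_numeric(alt_z, errors='coerce')

print(as_float.dtype, coerced.dtype)
print(int(as_float.isna().sum()), int(coerced.isna().sum()))
```

For this dataset both conversions give the same result; errors='coerce' would only matter if the column also held strings that float() cannot parse.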

    To create the scatter plot, two calls to sns.scatterplot() will do:

    import matplotlib.pyplot as plt
    import pandas as pd
    import seaborn as sns
    
    col1 = [1, 1.5, 3.1, 3.4, 2, -1]
    col2 = [1, -3, 2, 8, 2.5, -1.3]
    col3 = [4, 3, 4, 0.5, 1, 0.3]
    col4 = [10, 12, 10, 'NaN', 13, 'NaN']
    col5 = ['A', 'A', 'A', 'B', 'A', 'B']
    df = pd.DataFrame(list(zip(col1, col2, col3, col4, col5)),
                      columns=['x', 'y', 'z', 'alt_Z', 'cat'])
    df['alt_Z'] = df['alt_Z'].astype(float)
    
    ax = sns.scatterplot(data=df, x='x', y='alt_Z', color='black', label='alt_Z')
    sns.scatterplot(data=df, x='x', y='z', hue='cat', ax=ax)
    
    plt.show()
    

    [Figure: scatterplot]

    From here, we can create two dataframes: df1, containing x, z and cat, and df2, containing x and alt_Z. Renaming alt_Z to z and adding a cat column filled with the string 'alt_Z' gives df2 the same layout as df1.

    The jointplot() can then operate on the concatenation of both dataframes:

    df1 = df[['x', 'z', 'cat']]
    df2 = df[['x', 'alt_Z']].rename(columns={'alt_Z': 'z'}).dropna()
    df2['cat'] = 'alt_Z'
    
    # DataFrame.append() was removed in pandas 2.0; pd.concat() does the same job
    g = sns.jointplot(data=pd.concat([df1, df2]), x='x', y='z', hue='cat',
                      palette={'alt_Z': 'black', 'A': 'orange', 'B': 'green'})
    g.ax_joint.set_xlim(-3, 6) # the default limits are too wide for these reduced test data
    plt.show()
    

    [Figure: jointplot]
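
Before plotting, it can help to confirm that the stacked dataframe contains the groups you expect. The sketch below repeats the reshaping step using pandas only; stack_groups is a hypothetical helper name introduced for illustration, and the data are the question's values with alt_Z already converted to float.

```python
import pandas as pd

df = pd.DataFrame({
    'x': [1, 1.5, 3.1, 3.4, 2, -1],
    'z': [4, 3, 4, 0.5, 1, 0.3],
    'alt_Z': [10, 12, 10, float('nan'), 13, float('nan')],
    'cat': ['A', 'A', 'A', 'B', 'A', 'B'],
})


def stack_groups(frame):
    """Hypothetical helper: return a long-format frame with columns x, z, cat."""
    base = frame[['x', 'z', 'cat']]
    extra = (frame[['x', 'alt_Z']]
             .rename(columns={'alt_Z': 'z'})  # same column name as in base
             .dropna()                        # drop rows where alt_Z is NaN
             .assign(cat='alt_Z'))            # label these rows as their own group
    return pd.concat([base, extra], ignore_index=True)


combined = stack_groups(df)
print(combined.groupby('cat').size())  # number of points per hue group
print(combined.dtypes)                 # z should be numeric, not object
```

The resulting combined frame can be passed as data= to sns.jointplot() with x='x', y='z' and hue='cat', exactly as in the call above.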