Tags: python, matplotlib, seaborn, swarmplot

Seaborn swarmplot: break wide swarms into rows


I'm trying to make a swarm plot with seaborn.

My problem is that the swarms are too wide. I want to break them up into rows of at most 3 dots each.

This is my code:

# Import modules
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
###

# Import and clean dataset
url = "https://raw.githubusercontent.com/amirnakar/scratchboard/master/Goodreads/goodreads_library_export.csv"
Books = pd.read_csv(url)
# Keep only books that have a read date. .copy() gives an independent frame so
# the column added below is not written to a slice of the original.
Books = Books[Books['Date Read'].notna()].copy()
Books['Year'] = pd.to_datetime(Books['Date Read']).dt.year  # keep only the year
###

# Mean rating per year, as a two-column frame (Years, Rates)
RateMeans = (
    Books.groupby('Year')['My Rating'].mean()
    .rename_axis('Years')
    .reset_index(name='Rates')
)
###

# Colour palette, reversed so the first (lowest) rating level gets "#000BFF"
# and the last (highest) gets "#FAFF04".
colorset = ["#FAFF04", "#FFD500", "#9BFF00", "#0099FF", "#000BFF"][::-1]

# Plot
fig, ax = plt.subplots(figsize=(20, 10))

## Violin plot: outline of the rating distribution for each year
sns.violinplot(
    data=Books,
    x="Year",
    y='My Rating',
    ax=ax,
    color="white",
    inner=None,
)

## Swarm plot: one dot per book, coloured by rating. The palette is passed
## explicitly so the colours apply to this plot.
sns.swarmplot(
    data=Books,
    x="Year",
    y='My Rating',
    ax=ax,
    hue="My Rating",
    palette=colorset,
    size=10,
)

## Style

### Title
ax.text(x=0.5, y=1.1, s='Book Ratings: Distribution per Year', fontsize=32, weight='bold', ha='center', va='bottom', transform=ax.transAxes)
ax.text(x=0.5, y=1.05, s='Source: Goodreads.com (amirnakar)', fontsize=24, alpha=0.75, ha='center', va='bottom', transform=ax.transAxes)

### Axis
ax.set(xlim=(4.5, None), ylim=(0, 6))
ax.set_ylabel('Rating (out of 5 stars)', fontsize=24)
ax.set_xlabel('Year', fontsize=24)
# Enlarge the existing tick labels instead of replacing them. On seaborn's
# categorical x-axis, get_xticks() returns the positions 0..n-1, so relabelling
# with those values would replace the years with indices.
ax.tick_params(axis='both', labelsize=20)

### Legend
ax.legend(loc="lower center", ncol=5)

# Save the plot. bbox_inches="tight" keeps the titles drawn above the axes
# (y > 1 in axes coordinates) inside the saved image.
fig.savefig("Rate-Python.svg", format="svg", bbox_inches="tight")

This is the output:

[Image: output of the code above]

What I want to have happen:

I want to be able to specify that each row holds at most 3 dots; if there are more, the extra dots should wrap onto a new row. I demonstrate this below (done manually in PowerPoint) on two groups, but I want it for the entire plot.

BEFORE:

[Image: the two example groups before the change]

AFTER:

[Image: the same two groups after manually breaking them into rows of at most 3 dots]


Solution

  • Here is an attempt that moves the overflowing dots slightly up or down into extra rows. The value of delta was found by experimenting.

    import pandas as pd

    URL = "https://raw.githubusercontent.com/amirnakar/scratchboard/master/Goodreads/goodreads_library_export.csv"

    # One colour per hue level, in ascending rating order
    # (the question's list ["#FAFF04", ..., "#000BFF"] reversed).
    PALETTE = ["#000BFF", "#0099FF", "#9BFF00", "#FFD500", "#FAFF04"]


    def spread_rating_rows(books, max_per_row=3, delta=0.2):
        """Return one row per book, with y positions spread so no row exceeds *max_per_row* dots.

        For every (Year, My Rating) group, up to ``max_per_row`` dots stay at the
        rating itself. The overflow is split between extra rows above
        (rating + delta, rating + 2*delta, ...) and below (rating - delta, ...).
        The upper side receives the odd dot. Each extra row holds at most
        ``max_per_row`` dots.

        Parameters
        ----------
        books : DataFrame with ``Year`` and ``My Rating`` columns.
        max_per_row : maximum number of dots per horizontal row; must be >= 1.
        delta : vertical distance between neighbouring rows (found by experiment).

        Returns
        -------
        DataFrame with columns ``Year`` (int), ``My Rating`` (float, the shifted
        y position) and ``Rating`` (the unshifted rating, used for colouring so
        shifted dots keep the colour of their own rating).

        Raises
        ------
        ValueError
            If ``max_per_row`` is less than 1.
        """
        if max_per_row < 1:
            raise ValueError("max_per_row must be at least 1")

        records = []
        counts = books.groupby(["Year", "My Rating"]).size()
        for (year, rating), count in counts.items():
            count = int(count)
            overflow = max(0, count - max_per_row)
            above = (overflow + 1) // 2  # the odd extra dot goes on top
            below = overflow // 2
            records += [(year, rating, rating)] * (count - above - below)
            for sign, remaining in ((1, above), (-1, below)):
                level = 1
                while remaining > 0:
                    n_dots = min(max_per_row, remaining)
                    records += [(year, rating + sign * level * delta, rating)] * n_dots
                    remaining -= n_dots
                    level += 1

        return pd.DataFrame(records, columns=["Year", "My Rating", "Rating"]).astype(
            {"Year": int, "My Rating": float}
        )


    def main():
        """Download the Goodreads export and draw a violin plot overlaid with a row-limited swarm plot."""
        # Plotting libraries are imported here so spread_rating_rows can be
        # reused without matplotlib/seaborn installed.
        import matplotlib.pyplot as plt
        import seaborn as sns

        books = pd.read_csv(URL)
        books = books[books["Date Read"].notna()].copy()  # keep books with a read date
        books["Year"] = pd.to_datetime(books["Date Read"]).dt.year

        spread = spread_rating_rows(books)

        _, ax = plt.subplots(figsize=(20, 10))
        sns.violinplot(data=books, x="Year", y="My Rating", ax=ax, color="white", inner=None)
        sns.swarmplot(data=spread, x="Year", y="My Rating", ax=ax, hue="Rating",
                      size=10, palette=PALETTE)

        # Categorical x positions are 0, 1, 2, ...; start after the fifth year.
        ax.set(xlim=(4.5, None), ylim=(0, 6))
        ax.legend(loc="lower center", ncol=5)
        plt.tight_layout()
        plt.show()


    if __name__ == "__main__":
        main()
    

    sns.swarmplot limiting row widths to 3