I'm trying to make this swarmplot with seaborn.
My problem is that the swarms are too wide. I want to break each swarm up into rows of at most 3 dots per row.
# Import modules
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
###
# Import and clean dataset
url = "https://raw.githubusercontent.com/amirnakar/scratchboard/master/Goodreads/goodreads_library_export.csv"
Books = pd.read_csv(url)
# Keep only books that have a read date; copy so that adding a column below
# operates on an independent frame rather than a slice of the original.
Books = Books[Books['Date Read'].notna()].copy()
# Extract the year in which each book was read.
Books['Year'] = pd.DatetimeIndex(Books['Date Read']).year
###
# Calculate mean rate by year
RateMeans = Books.groupby("Year")["My Rating"].mean()
Years = list(RateMeans.index.values)
Rates = list(RateMeans)
RateMeans = pd.DataFrame({'Years': Years, 'Rates': Rates})
###
# Colour palette
# Defined *before* plotting: a palette set after the plot calls has no effect
# on artists that were already drawn.
colorset = ["#FAFF04", "#FFD500", "#9BFF00", "#0099FF", "#000BFF"]
colorset.reverse()
sns.set_palette(sns.color_palette(colorset))
###
# Plot
fig, ax = plt.subplots(figsize=(20, 10))
## Violin Plot: white outline only, no inner markers
plot = sns.violinplot(
    data=Books,
    x="Year",
    y='My Rating',
    ax=ax,
    color="white",
    inner=None,
)
## Swarm Plot: one dot per book, coloured by its rating
plot = sns.swarmplot(
    data=Books,
    x="Year",
    y='My Rating',
    ax=ax,
    hue="My Rating",
    size=10,
    palette=colorset,  # pass explicitly so the custom colours are actually used
)
## Style
### Title
ax.text(x=0.5, y=1.1, s='Book Ratings: Distribution per Year', fontsize=32, weight='bold', ha='center', va='bottom', transform=ax.transAxes)
ax.text(x=0.5, y=1.05, s='Source: Goodreads.com (amirnakar)', fontsize=24, alpha=0.75, ha='center', va='bottom', transform=ax.transAxes)
### Axis
# The x axis is categorical (one position per year); the lower limit of 4.5
# hides the first five year categories.
ax.set(xlim=(4.5, None), ylim=(0, 6))
ax.set_ylabel('Rating (out of 5 stars)', fontsize=24)
ax.set_xlabel('Year', fontsize=24)
# Only resize the tick labels. Overwriting them with ax.get_xticks() would
# replace the year labels by the categorical positions 0, 1, 2, ...
ax.tick_params(axis="both", labelsize=20)
### Legend
plot.legend(loc="lower center", ncol=5)
# Save the plot
plt.savefig("Rate-Python.svg", format="svg")
I want to be able to define that each row of dots has a maximum of 3 dots; if there are more, then they should break into a new row. I demonstrate it here (done manually in PowerPoint) on two groups, but I want it for the entire plot.
Here is an attempt to relocate the dots a bit upward/downward. The value for delta
comes from experimenting.
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
# Import and clean dataset
url = "https://raw.githubusercontent.com/amirnakar/scratchboard/master/Goodreads/goodreads_library_export.csv"
Books = pd.read_csv(url)
# Keep only books that have a read date; copy so the new column is added to an
# independent frame rather than to a slice of the original.
Books = Books[Books['Date Read'].notna()].copy()
Books['Year'] = pd.DatetimeIndex(Books['Date Read']).year  # Take only years

# Count how many books received each rating in each year.
# The result is a Series indexed by (Year, My Rating).
RatePerYear = Books[["My Rating", "Year"]].groupby("Year")["My Rating"].value_counts()

max_per_row = 3  # maximum number of dots allowed side by side in one row
delta = 0.2  # vertical distance between stacked rows (found by experimenting)

# Each entry is [year, plotted y position, number of dots, original rating].
modified_ratings = []
# Series.items() replaces iteritems(), which was removed in pandas 2.0.
for (year, rating), count in RatePerYear.items():
    # Dots beyond the first row are split between rows above ("higher") and
    # rows below ("lower") the true rating; the odd one out goes above.
    higher = max(0, ((count - max_per_row) + 1) // 2)
    lower = max(0, (count - max_per_row) // 2)
    modified_ratings.append([year, rating, count - higher - lower, rating])
    for k in range((higher + max_per_row - 1) // max_per_row):
        # Full rows hold max_per_row dots; the last row holds the remainder.
        n_dots = max_per_row if (k + 1) * max_per_row <= higher else higher % max_per_row
        modified_ratings.append([year, rating + (k + 1) * delta, n_dots, rating])
    for k in range((lower + max_per_row - 1) // max_per_row):
        n_dots = max_per_row if (k + 1) * max_per_row <= lower else lower % max_per_row
        modified_ratings.append([year, rating - (k + 1) * delta, n_dots, rating])

modified_ratings = np.array(modified_ratings)
# Expand each [year, y, n, rating] entry into n identical rows, one per dot.
dots_per_entry = modified_ratings[:, 2].astype(int)
modified_ratings_df = pd.DataFrame(
    {'Year': np.repeat(modified_ratings[:, 0].astype(int), dots_per_entry),
     'My Rating': np.repeat(modified_ratings[:, 1], dots_per_entry),
     # Carry the true rating along for the hue instead of rounding the shifted
     # position, which would give the wrong rating once the offset reaches 0.5.
     'Rating': np.repeat(modified_ratings[:, 3].astype(int), dots_per_entry)})

fig, ax = plt.subplots(figsize=(20, 10))
sns.violinplot(data=Books, x="Year", y='My Rating', ax=ax, color="white", inner=None)
# list.reverse() works in place and returns None, so slice to get the reversed list.
palette = ["#FAFF04", "#FFD500", "#9BFF00", "#0099FF", "#000BFF"][::-1]
sns.swarmplot(data=modified_ratings_df, x="Year", y='My Rating', ax=ax, hue="Rating", size=10, palette=palette)
# The x axis is categorical; the lower limit of 4.5 hides the first five years.
ax.set(xlim=(4.5, None), ylim=(0, 6))
ax.legend(loc="lower center", ncol=5)
plt.tight_layout()
plt.show()