plt.figure(figsize=(6,8))
sns.barplot(data=csat_korean,x='grade',y='percentage').set_title('Korean');
plt.show()
sns.barplot(data=csat_math,x='grade',y='percentage').set_title('Math');
plt.show()
sns.barplot(data=csat_english,x='grade',y='percentage').set_title('English');
plt.show()
Hello, with the code above I'm trying to make three bar plots with sns (everything is imported). However, the plots I get are not consistent in size. This is the result I get. How can I fix this code so that all the plots are the same size and the last graph is not annoyingly smaller than the others?
Greatly appreciate it!
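In your code, plt.figure() only applies to the first figure: every plot drawn after a plt.show() lands in a new figure with matplotlib's default size, which is why the sizes differ.
With seaborn's catplot(), you can generate a grid of bar plots from a single combined dataframe. The size of each subplot is set via the height= and aspect= parameters (width = height * aspect), and every subplot in the grid gets the same size.
By default, the x and y axes are shared between the subplots, so they look very similar.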
from matplotlib import pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
# generate some dummy test data
courses = ['Math', 'English', 'Korean']
csat_data = pd.DataFrame({'course': np.repeat(courses, 9),
                          'grade': np.tile(np.arange(1, 10), 3),
                          'count': np.random.randint(10, 1000, size=27)})
csat_data['percentage'] = csat_data.groupby('course')['count'].transform(lambda x: (x / x.sum()) * 100)
# one row per course; each facet is height=8 tall and height * aspect = 6 wide
g = sns.catplot(data=csat_data, kind='bar', x='grade', y='percentage', row='course',
                height=8, aspect=6/8)
plt.show()
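A quick way to confirm what height= and aspect= do is to inspect the grid after it is drawn. The sketch below is a sanity check, not part of the answer itself: it assumes seaborn 0.11.2 or newer (for g.figure), uses flat dummy percentages since only the layout matters, and switches to the off-screen Agg backend only so it runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # render off-screen; this snippet only inspects sizes
import numpy as np
import pandas as pd
import seaborn as sns

courses = ['Math', 'English', 'Korean']
# Flat dummy percentages; only the layout matters here
csat_data = pd.DataFrame({'course': np.repeat(courses, 9),
                          'grade': np.tile(np.arange(1, 10), 3),
                          'percentage': np.full(27, 100 / 9)})

height, aspect = 8, 6 / 8
g = sns.catplot(data=csat_data, kind='bar', x='grade', y='percentage',
                row='course', height=height, aspect=aspect)

# FacetGrid makes the figure (ncols * height * aspect) wide and (nrows * height) tall
nrows, ncols = g.axes.shape
print(g.figure.get_size_inches(), (ncols * height * aspect, nrows * height))

# Each facet occupies a box of the same size (width, height in figure fractions)
for ax in g.axes.flat:
    print(ax.get_title(), ax.get_position().size)
```

Because all facets live in one figure and one grid, there is no way for one of them to end up smaller than the others.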
PS: For your original dataframes, you can combine them like:
csat_math['course'] = 'Math'
csat_english['course'] = 'English'
csat_korean['course'] = 'Korean'
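# ignore_index=True gives the combined frame a fresh 0..n-1 index instead of repeated labels
csat_data = pd.concat([csat_math, csat_english, csat_korean], ignore_index=True)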
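If you'd rather keep the three dataframes separate, the original sns.barplot calls also work once all three plots go into a single figure created with a fixed size. This is a sketch of that alternative, not the approach above: plt.subplots() creates the grid, each barplot is drawn into its own ax=, and the frames dict filled with random percentages is a hypothetical stand-in for csat_korean, csat_math and csat_english.

```python
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Hypothetical stand-ins for the question's three dataframes
rng = np.random.default_rng(0)
grades = np.arange(1, 10)
frames = {}
for course in ['Korean', 'Math', 'English']:
    counts = rng.integers(10, 1000, size=9)
    frames[course] = pd.DataFrame({'grade': grades,
                                   'percentage': counts / counts.sum() * 100})

# One figure with a fixed size (3 rows of roughly 6 x 8 inches each);
# every subplot gets an equal share of it, and the axes are shared like catplot does
fig, axes = plt.subplots(nrows=3, ncols=1, figsize=(6, 24), sharex=True, sharey=True)
for ax, (course, df) in zip(axes, frames.items()):
    sns.barplot(data=df, x='grade', y='percentage', ax=ax)  # draw into this subplot
    ax.set_title(course)
fig.tight_layout()
plt.show()
```

The key change compared with the question's code is that plt.show() is called once, after all three plots, so they are never split across figures of different sizes.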
You can also let seaborn calculate the histograms directly from the raw data (one row per student) instead of precomputing the percentages.
courses = ['Math', 'English', 'Korean']
csat_data = pd.DataFrame({'course': np.repeat(courses, 50),
                          'grade': np.clip(np.random.normal(loc=6, scale=1.8, size=150).round().astype(int), 1, 9)})
g = sns.displot(csat_data, kind='hist', stat='percent',
                x='grade', discrete=True, row='course',
                common_norm=False,  # normalize each course's histogram on its own
                height=8, aspect=6/8)
g.set(xticks=range(1, 10))
plt.show(block=True)
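To check the bar heights by hand, or to turn raw grades into the precomputed percentage column that the catplot version expects, pandas can do the counting. This is a sketch assuming one row per student, as in the dummy data above; it is a pandas cross-check, not seaborn's internal code. value_counts(normalize=True) on the grouped column normalizes within each course, so each course's percentages add up to 100.

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
courses = ['Math', 'English', 'Korean']
csat_data = pd.DataFrame({'course': np.repeat(courses, 50),
                          'grade': np.clip(rng.normal(loc=6, scale=1.8, size=150).round().astype(int), 1, 9)})

# Share of students per grade, computed separately within each course
pct = (csat_data.groupby('course')['grade']
       .value_counts(normalize=True)
       .mul(100)
       .rename('percentage')  # rename first so reset_index doesn't clash with the 'grade' level
       .reset_index())

print(pct.head())
print(pct.groupby('course')['percentage'].sum())  # should be 100 for every course (up to rounding)
```

The resulting pct frame has course, grade and percentage columns, so it can be passed straight to the sns.catplot(...) call from the first example.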