
How to order facets when using the Seaborn objects interface


I am trying to order facets in a plot produced by the seaborn objects interface.

import numpy as np
import seaborn as sns
import seaborn.objects as so
import matplotlib.pyplot as plt

df = sns.load_dataset("iris")
df["species"] = df["species"].astype("category")
df["species"] = df["species"].cat.codes
rng = np.random.default_rng(seed=0)
df["subset"] = rng.choice(['A','B','C'], len(df), replace=True) 

fig = plt.figure(figsize=(6.4 * 2.0, 4.8 * 2.0))

_ = (
    so.Plot(df, x="sepal_length", y="sepal_width")
    .facet(row="species", col="subset")
    .add(so.Dot())
    .on(fig)
    .plot()
)

[Figure: a plot faceted by row and by column. The columns are, left to right, C, B, A; the rows are, top to bottom, 0, 1, 2.]

However, passing col_order or row_order to the .facet() call raises an "unexpected keyword argument" TypeError.

_ = (
    so.Plot(df, x="sepal_length", y="sepal_width")
    .facet(
        row="species",
        col="subset",
        row_order=[0, 2, 1],
        col_order=['A', 'C', 'B']
    )
    .add(so.Dot())
    .on(fig)
    .plot()
)
TypeError: Plot.facet() got an unexpected keyword argument 'row_order'

How should facets be ordered when using the seaborn.objects interface?

Note that this question is very similar to "Seaborn ordering of facets", which asks the same thing for plots generated with the classic seaborn functions rather than the seaborn.objects module.

Ideally, an answer should also work when using the wrap parameter of facet() in the seaborn.objects interface.


Solution

  • Plot.facet has a single order parameter. When only one of col or row is used, a plain list can be passed and it is applied to whichever variable is faceted. When both col and row are used, order should be a dictionary with "col" and/or "row" keys:

    (
        so.Plot(df, x="sepal_length", y="sepal_width")
        .facet(row="species", col="subset", order={"row": [2, 1, 0], "col": ["A", "B", "C"]})
        .add(so.Dot())
        .layout(size=(6.4 * 2, 4.8 * 2))
    )