python matplotlib seaborn facet-grid catplot

How to customize the text ticklabels for each subplot of a seaborn catplot


Consider the following example (from the Seaborn documentation):

import seaborn as sns

titanic = sns.load_dataset("titanic")

fg = sns.catplot(x="age", y="embark_town",
                hue="sex", row="class",
                data=titanic[titanic.embark_town.notnull()],
                orient="h", height=2, aspect=3, palette="Set3",
                kind="violin", dodge=True, cut=0, bw=.2)

Output:

[Image: output of the catplot above]

I want to change the tick labels on the y axis, for example by prepending a number in parentheses: (1) Southampton, (2) Cherbourg, (3) Queenstown. I have seen this answer and tried to use a FuncFormatter, but I get an unexpected result. Here is my code:

import seaborn as sns
from matplotlib.ticker import FuncFormatter

titanic = sns.load_dataset("titanic")

fg = sns.catplot(x="age", y="embark_town",
                hue="sex", row="class",
                data=titanic[titanic.embark_town.notnull()],
                orient="h", height=2, aspect=3, palette="Set3",
                kind="violin", dodge=True, cut=0, bw=.2)

for ax in fg.axes.flat:
    ax.yaxis.set_major_formatter(FuncFormatter(lambda x, pos: f'({1 + pos}) {x}'))

And here is the output:

[Image: the same plot with the FuncFormatter applied to the y tick labels]

It looks like x is the same as pos in the lambda. I was expecting x to be the value of the tick label (i.e. Southampton, Cherbourg, Queenstown). What am I doing wrong?
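For reference, `FuncFormatter` calls its function as `func(x, pos)`, where `x` is the numeric tick value and `pos` is the tick's index. On a categorical axis the categories sit at integer positions 0, 1, 2 and so on, so `x` and `pos` coincide and neither contains the label text. The small check below illustrates this; it uses a plain matplotlib horizontal bar chart with the same three town names as an assumed stand-in for the catplot, and records every call the formatter receives.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend, no window needed
import matplotlib.pyplot as plt
from matplotlib.ticker import FuncFormatter

towns = ["Southampton", "Cherbourg", "Queenstown"]
fig, ax = plt.subplots()
ax.barh(towns, [1, 2, 3])  # string categories are placed at y = 0, 1, 2

calls = []  # every (x, pos) pair the formatter is called with

def record(x, pos):
    calls.append((x, pos))
    return f"({pos + 1}) {x}"

ax.yaxis.set_major_formatter(FuncFormatter(record))
fig.canvas.draw()  # tick labels are formatted at draw time
print(calls[:3])   # x is the numeric position, not the category name
plt.close(fig)
```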


Software versions:

matplotlib                         3.4.3
seaborn                            0.11.2

Solution

The formatter only receives the numeric tick position (0, 1, 2), not the label text, so instead read each subplot's existing labels and write the modified text back:

import seaborn as sns

titanic = sns.load_dataset("titanic")

fg = sns.catplot(x="age", y="embark_town",
                hue="sex", row="class",
                data=titanic[titanic.embark_town.notnull()],
                orient="h", height=2, aspect=3, palette="Set3",
                kind="violin", dodge=True, cut=0, bw=.2)

for ax in fg.axes.flat:  # iterate through each subplot
    labels = ax.get_yticklabels()  # get the position and text for each subplot
    for label in labels:
        _, y = label.get_position()  # extract the y tick position (0, 1, 2, ...)
        txt = label.get_text()  # extract the text
        txt = f'({int(y) + 1}) {txt}'  # update the text string; int() avoids "(1.0)" if positions are floats
        label.set_text(txt)  # set the text
    ax.set_yticklabels(labels)  # update the yticklabels
    

[Image: the catplot with the numbered y tick labels]
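If you would rather keep the `FuncFormatter` approach from the question, one option is to snapshot each subplot's label text first and have the formatter look the text up by tick position. This is a sketch under assumptions: `make_prefix_formatter` is a hypothetical helper, and a plain matplotlib figure with two rows sharing the y axis stands in for the catplot, with ticks and labels set explicitly in much the same way seaborn 0.11 does for categorical axes.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs anywhere
import matplotlib.pyplot as plt
from matplotlib.ticker import FuncFormatter


def make_prefix_formatter(lookup):
    """Hypothetical helper: map a tick position to '(n) <label>'."""
    def fmt(x, pos):
        i = int(round(x))  # categorical ticks sit at 0, 1, 2, ...
        return f"({i + 1}) {lookup.get(i, '')}"
    return FuncFormatter(fmt)


# Stand-in for the catplot (assumption): two stacked rows sharing the y axis.
towns = ["Southampton", "Cherbourg", "Queenstown"]
fig, axes = plt.subplots(2, 1, sharey=True)
for ax in axes.flat:
    ax.barh(range(len(towns)), [1, 2, 3])
    ax.set_yticks(range(len(towns)))
    ax.set_yticklabels(towns)

# 1) Snapshot the original labels of every subplot before changing anything.
snapshots = [
    (ax, {int(round(loc)): t.get_text()
          for loc, t in zip(ax.get_yticks(), ax.get_yticklabels())})
    for ax in axes.flat
]

# 2) Install a formatter that looks the label text up by tick position.
for ax, lookup in snapshots:
    ax.yaxis.set_major_formatter(make_prefix_formatter(lookup))

fig.canvas.draw()  # re-format the tick labels with the new formatter
new_labels = [t.get_text() for t in axes.flat[1].get_yticklabels()]
print(new_labels)
plt.close(fig)
```

For the figure in the question, replace `axes.flat` with `fg.axes.flat`. Taking the snapshot before touching any formatter is deliberate: catplot rows share the y axis by default, and shared matplotlib axes share one formatter, so reading the labels after a swap could return already-prefixed text on some matplotlib versions.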