python survival-analysis lifelines

Plot Kaplan Meier curves for all "object" data types in data frame


I am trying to plot Kaplan-Meier survival curves for every "object" dtype column in my data frame, with one subplot per column and one curve per unique value in that column.

The output should be "n" subplots, one per object column, each showing the survival curves for that column's unique values.

Some sample data:

import numpy as np
import random
import pandas as pd

duration = np.random.exponential(scale = 5, size = 100).round(1)
boolean = [random.randint(0, 1) for i in range(len(duration))]; boolean = [bool(x) for x in boolean]
group = np.random.choice(["A", "B", "C", "D"], size = len(duration))
house = np.random.choice(["Big", "Small"], p = [0.7, 0.3], size = len(duration))
provider = np.random.choice(["2Degrees", "Skinny", "Vodafone", "Spark"], p = [0.25, 0.25, 0.25, 0.25], size = len(duration))

df = pd.DataFrame(
    {"Duration":duration,
    "Boolean":boolean,
    "Group":group,
    "House":house,
    "Provider":provider}
)

The code that I have so far is the following:

import matplotlib.pyplot as plt
from lifelines import KaplanMeierFitter
kmf = KaplanMeierFitter()
fig, axes = plt.subplots(nrows = 1, ncols = 3, figsize = (15, 5))
for column in df:
    if df[column].dtype == object:
        for value in df[column].unique():
            mask = df[column] == value
            kmf.fit(
                durations = df["Duration"][mask],
                event_observed = df["Boolean"][mask],
                label = value
            )
            for ax in axes.flatten():
                kmf.plot_survival_function(ax = ax, ci_show = False)
    

This is very close to what I want, except that it draws the survival curves for every unique value of every column on every subplot.

The logic I'm trying to reproduce is similar to the following snippet, which works with the sample data above:

fig, axes = plt.subplots(nrows = 1, ncols = 3, figsize = (15, 5))
cat_cols = []
for column in df:
    if df[column].dtype == object:
        cat_cols.append(column)
for column, ax in zip(cat_cols, axes.flatten()):
    df[column].value_counts().plot(
        kind = "bar",
        ax = ax
    ).set_title("{} Counts".format(column))

Can someone point me in the right direction?


Solution

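  • After some trial and error I got what I wanted. The fix is to collect the "object" columns first and then pair each column with its own axis using zip, so that each fitted curve is drawn only on that column's subplot rather than on all of them. This answer is intended for those working with survival data who want a "shotgun" approach to see which groups differ in their survival.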

    import matplotlib.pyplot as plt
    from lifelines import KaplanMeierFitter
    kmf = KaplanMeierFitter()
    
    fig, axes = plt.subplots(nrows = 1, ncols = 3, figsize = (15, 5))
    
    # Collect the categorical ("object" dtype) columns first
    cat_cols = []
    for column in df:
        if df[column].dtype == object:
            cat_cols.append(column)
            
    for column, ax in zip(cat_cols, axes.flatten()):
        for value in df[column].unique():
            mask = df[column] == value
            kmf.fit(
                durations = df["Duration"][mask],
                event_observed = df["Boolean"][mask],
                label = value
            )
            kmf.plot_survival_function(ax = ax, ci_show = False, legend = True)
        ax.set_title(column.capitalize())
        ax.set_xlabel(None)