
How to get the probability density function from CoxPHSurvivalAnalysis in scikit-survival?


I am using sksurv.linear_model.CoxPHSurvivalAnalysis to fit a Cox proportional hazards regression and I would like to recover the density function f(t). The sksurv class has methods to predict the survival function S(t) = 1 - F(t), where F(t) is the cumulative distribution function, and the cumulative hazard function H(t), but it doesn't seem to provide the density function.

My use case has no censoring. Here is an example:

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sksurv.linear_model import CoxPHSurvivalAnalysis

data = np.random.randint(5, 30, size=10)
X_train = pd.DataFrame(data, columns=['covariate'])

# No censoring: every sample is an observed event (status=True)
event_times = np.random.randint(1, 100, size=10) / 100
y_train = np.array([(True, t) for t in event_times], dtype=[('status', bool), ('target', float)])

estimator = CoxPHSurvivalAnalysis()
estimator.fit(X_train, y_train)

X_test = pd.DataFrame({'covariate': [12, 2]})
chf = estimator.predict_cumulative_hazard_function(X_test)
sf = estimator.predict_survival_function(X_test)

fig, ax = plt.subplots(1, 2)
for fn_h, fn_s in zip(chf, sf):
    ax[0].step(fn_h.x, fn_h(fn_h.x), where='post')
    ax[1].step(fn_s.x, fn_s(fn_s.x), where='post')

ax[0].set_title('Cumulative Hazard Functions')
ax[1].set_title('Survival Functions')
plt.show()


(Figure: the resulting "Cumulative Hazard Functions" and "Survival Functions" plots.)

How can I also access and plot the density function?


Solution

  • The probability density function (PDF) can be obtained from the cumulative distribution function (CDF) as:

    f(t) = dF(t)/dt
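For a fitted CoxPHSurvivalAnalysis model, F(t) = 1 - S(t) is a step function that only jumps at the observed event times, so its derivative is zero between events. A discrete counterpart of f(t) = dF(t)/dt is the probability mass at each event time, F(t_i) - F(t_{i-1}) = S(t_{i-1}) - S(t_i). The sketch below computes these masses with NumPy; pmf_from_survival is a hypothetical helper, the input values are toy numbers rather than model output, and it assumes S equals s0 = 1 just before the first time.

```python
import numpy as np


def pmf_from_survival(times, surv, s0=1.0):
    """Probability mass at each step of a right-continuous survival step function.

    times : increasing event times t_1 < ... < t_n
    surv  : S(t_i) evaluated at those times
    s0    : assumed value of S just before t_1 (1.0 if nothing happened earlier)
    """
    times = np.asarray(times, dtype=float)
    surv = np.asarray(surv, dtype=float)
    # Prepend S just before the first event, then take the successive drops:
    # f(t_i) = S(t_{i-1}) - S(t_i) = F(t_i) - F(t_{i-1})
    previous = np.concatenate(([s0], surv[:-1]))
    return times, previous - surv


# Toy step survival function (hypothetical values, not fitted model output)
t = np.array([0.1, 0.3, 0.7, 0.9])
s = np.array([0.8, 0.5, 0.2, 0.0])
times, pmf = pmf_from_survival(t, s)
print(times, pmf)
```

With a predicted survival function sf from sksurv, the call would be pmf_from_survival(sf.x, sf(sf.x)); the masses sum to 1 - S(t_n), which is below 1 whenever the last survival value is above zero.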
    
    

    In survival analysis, the PDF f(t) can also be expressed in terms of the survival function S(t) and the hazard function h(t):

    f(t) = h(t) × S(t)
    
    

    where S(t) = 1 - F(t) and h(t) = -(dS(t)/dt) / S(t) = dH(t)/dt, because H(t) = -log S(t).

    So the PDF can be written as: f(t) = dH(t)/dt × S(t)
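The identities above can be checked numerically on a distribution whose functions are known in closed form. This sketch assumes a Weibull model with hypothetical shape k = 1.5 and scale lam = 2.0 (nothing fitted in the question), approximates the derivatives with np.gradient, and compares f(t) = dH(t)/dt × S(t) against scipy.stats.weibull_min.pdf.

```python
import numpy as np
from scipy.stats import weibull_min

# Hypothetical Weibull model (shape k, scale lam), used only to check the identities
k, lam = 1.5, 2.0
t = np.linspace(0.05, 5.0, 2000)

S = np.exp(-(t / lam) ** k)                  # survival function S(t)
H = -np.log(S)                               # cumulative hazard H(t) = -log S(t)
h_exact = (k / lam) * (t / lam) ** (k - 1)   # closed-form Weibull hazard

# h(t) = -(dS/dt) / S(t) and h(t) = dH/dt, both via finite differences
h_from_S = -np.gradient(S, t) / S
h_from_H = np.gradient(H, t)

# f(t) = dH/dt * S(t) should match the Weibull density
f_from_H = h_from_H * S
f_exact = weibull_min(c=k, scale=lam).pdf(t)

print("max |h_from_S - h|:", np.max(np.abs(h_from_S - h_exact)))
print("max |f_from_H - f|:", np.max(np.abs(f_from_H - f_exact)))
```

The small remaining error comes from the finite-difference approximation, which is least accurate at the two ends of the grid.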

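    To compute the hazard function h(t), we therefore need the derivative of the cumulative hazard function (CHF) H(t). The CHF returned by sksurv is a step function known only at the observed event times, so it cannot be differentiated directly. InterpolatedUnivariateSpline from scipy fits a smooth spline through these points, which can then be differentiated to approximate h(t). Note that an interpolating cubic spline is not guaranteed to be monotone, so the estimated hazard (and hence the density) may take small negative values between event times. Here's a slightly modified version of the code from the question: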

    # Import the necessary libraries
    import pandas as pd
    import numpy as np
    import matplotlib.pyplot as plt
    from sksurv.linear_model import CoxPHSurvivalAnalysis
    from scipy.interpolate import InterpolatedUnivariateSpline
    
    # Define a function to compute the probability density function (pdf) 
    # from the cumulative hazard function (chf) and survival function (sf).
    def compute_pdf_from_chf_and_sf(chf, sf):
        # The hazard function is the derivative of the cumulative hazard function.
        # We use InterpolatedUnivariateSpline for spline interpolation to create a smooth 
        # function approximation of the CHF. This provides us with a smooth curve that 
        # passes through each data point, allowing us to differentiate the function and obtain 
        # the hazard function.
        # The default cubic spline (k=3) needs at least 4 distinct event times.
        chf_spline = InterpolatedUnivariateSpline(chf.x, chf(chf.x))
        hazard_function = chf_spline.derivative()(chf.x)
        
        # The pdf can be computed using the formula: pdf(t) = hazard(t) * survival(t)
        pdf = hazard_function * sf(chf.x)
        return chf.x, pdf
    
    # Generate random data for demonstration purposes
    # Here, we create a random dataset with one covariate and survival times.
    
    np.random.seed(42)  # Setting a fixed seed.
    data = np.random.randint(5, 30, size=10)
    X_train = pd.DataFrame(data, columns=['covariate'])
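    # No censoring: every sample is an observed event (status=True)
    event_times = np.random.randint(1, 100, size=10) / 100
    y_train = np.array([(True, t) for t in event_times], dtype=[('status', bool), ('target', float)])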
    
    # Initialize and fit the Cox Proportional Hazards model
    estimator = CoxPHSurvivalAnalysis()
    estimator.fit(X_train, y_train)
    
    # Predict for new data points
    X_test = pd.DataFrame({'covariate': [12, 2]})
    cumulative_hazard_functions = estimator.predict_cumulative_hazard_function(X_test)
    survival_functions = estimator.predict_survival_function(X_test)
    
    # Plot the Cumulative Hazard, Survival, and PDF side by side in a single row
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    for chf, sf in zip(cumulative_hazard_functions, survival_functions):
        # Compute the pdf using our defined function
        times, pdf_values = compute_pdf_from_chf_and_sf(chf, sf)
        
        # Plotting the cumulative hazard function
        axes[0].step(chf.x, chf(chf.x), where='post')
        
        # Plotting the survival function
        axes[1].step(sf.x, sf(sf.x), where='post')
        
        # Plotting the probability density function (smooth estimate evaluated at the event times)
        axes[2].plot(times, pdf_values, marker='o')
    
    # Setting titles for each subplot
    axes[0].set_title('Cumulative Hazard Functions')
    axes[1].set_title('Survival Functions')
    axes[2].set_title('Probability Density Functions')
    
    # Display the plots
    plt.tight_layout()
    plt.show()
    
    
    

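    which results in a figure with three panels: Cumulative Hazard Functions, Survival Functions, and Probability Density Functions.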
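If negative density values from the cubic spline are a problem, one option (a swapped-in technique, not part of the answer above) is scipy's PchipInterpolator, which preserves the monotonicity of the data: the interpolated CHF stays nondecreasing, so its derivative, and therefore the estimated hazard, is nonnegative. compute_pdf_pchip is a hypothetical helper, and the toy arrays below stand in for chf.x, chf(chf.x) and sf(chf.x).

```python
import numpy as np
from scipy.interpolate import PchipInterpolator


def compute_pdf_pchip(times, chf_values, surv_values):
    """Approximate f(t) = h(t) * S(t) using a monotone (PCHIP) interpolant of H(t).

    times       : strictly increasing event times
    chf_values  : cumulative hazard H(t) at those times (nondecreasing)
    surv_values : survival S(t) at those times
    """
    chf_interp = PchipInterpolator(times, chf_values)
    # The derivative of a monotone interpolant of nondecreasing data is >= 0
    hazard = chf_interp.derivative()(times)
    return hazard * np.asarray(surv_values, dtype=float)


# Hypothetical toy data standing in for chf.x, chf(chf.x) and sf(chf.x)
t = np.array([0.05, 0.12, 0.30, 0.31, 0.55, 0.80, 0.95])
H = np.array([0.10, 0.25, 0.45, 0.90, 1.20, 1.90, 2.80])
S = np.exp(-H)  # survival and cumulative hazard are linked by S(t) = exp(-H(t))

pdf = compute_pdf_pchip(t, H, S)
print(pdf)
```

In the answer's loop, this would replace compute_pdf_from_chf_and_sf(chf, sf) with compute_pdf_pchip(chf.x, chf(chf.x), sf(chf.x)). PCHIP also works with as few as two event times, whereas the default cubic InterpolatedUnivariateSpline needs at least four.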

    Reference: Machine Learning for Survival Analysis: A Survey