Tags: python, tensorflow, tf.keras, calibration, tensorflow-probability

Keras TensorFlow Probability model not learning distribution spread


I built and trained a Keras TensorFlow Probability model. It's basically a fully connected neural network with a DistributionLambda as the output layer. Here's the code for that last layer:

tfp.layers.DistributionLambda(
    lambda t: tfd.Independent(
        tfd.Normal(loc=t[..., :n], scale=1e-5 + tf.nn.softplus(c + t[..., n:])),
        reinterpreted_batch_ndims=1))
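
For context, here is a minimal sketch of how an output layer like this might sit at the end of a fully connected network; the layer sizes, input shape, n, and c are illustrative placeholders, not the original architecture:

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

n = 1    # number of target dimensions (illustrative)
c = 0.0  # placeholder for the constant offset applied before the softplus

model = tf.keras.Sequential([
    tf.keras.layers.Dense(64, activation='relu', input_shape=(1,)),
    tf.keras.layers.Dense(2 * n),  # first n units become loc, the rest become the raw scale
    tfp.layers.DistributionLambda(
        lambda t: tfd.Independent(
            tfd.Normal(loc=t[..., :n], scale=1e-5 + tf.nn.softplus(c + t[..., n:])),
            reinterpreted_batch_ndims=1)),
])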

During training I'm using Mean Squared Error as my loss function. The training seems to progress well and is numerically stable.

After training, I start by removing the last layer of the model and then making forward-pass predictions with my test-set data. This basically gives me the "learned" loc and scale of the distribution the model learned for each data point in the test set. However, because of the softplus correction in the DistributionLambda, I also have to apply that same correction to the chopped model's prediction for the scale.
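
One way to do that chopping and correction, assuming a Sequential/functional model like the sketch above, with x_test standing in for the test inputs (not necessarily how the original code does it):

import tensorflow as tf

# Rebuild the network without the DistributionLambda so the forward pass returns
# the raw parameter tensor instead of a distribution object.
chopped = tf.keras.Model(inputs=model.input, outputs=model.layers[-2].output)

params = chopped.predict(x_test)          # shape: (num_samples, 2 * n)
pred_loc = params[..., :n]                # first half: predicted means
# Apply the same softplus correction the DistributionLambda uses for the scale.
pred_scale = 1e-5 + tf.nn.softplus(c + params[..., n:]).numpy()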

I'm trying to verify that the model learned the appropriate distributions contingent on the input values. So, with these predictions for the loc (mean) and scale (standard deviation) I can create calibration plots to see how well the model learned the latent distributions. The calibration plot for the mean looks great. I'm also creating a calibration plot for the scale/stdev parameter with code like this:

from typing import Optional

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd


def create_stdev_calibration_plot(df: pd.DataFrame,
                                  y_true: str = 'y_true',
                                  y_pred_mean: str = 'y_pred_mean',
                                  y_pred_std: str = 'y_pred_std',
                                  title: Optional[str] = None,
                                  save_path: Optional[str] = None):

    # Compute the residuals
    df['residual'] = df[y_true] - df[y_pred_mean]

    # Bin data based on predicted standard deviation
    bins = np.linspace(df[y_pred_std].min(), df[y_pred_std].max(), 10)
    df['bin'] = np.digitize(df[y_pred_std], bins)

    # For each bin, compute the mean predicted variance and the actual std of the residuals
    df['y_pred_variance'] = df[y_pred_std] ** 2
    bin_means_variance = df.groupby('bin')['y_pred_variance'].mean()

    # Convert back to standard deviation
    bin_means = np.sqrt(bin_means_variance)
    bin_residual_stds = df.groupby('bin')['residual'].std()

    # Create the calibration plot
    plt.figure(figsize=(8, 8))
    plt.plot(bin_means, bin_residual_stds, 'o-')

    # Draw the ideal y = x reference line across the current axis limits
    xrange = plt.xlim()
    yrange = plt.ylim()
    max_val = max(xrange[1], yrange[1])
    min_val = min(xrange[0], yrange[0])
    plt.axline((min_val, min_val), (max_val, max_val), linestyle='--', color='k', linewidth=2)

    plt.xlabel('Mean Predicted Standard Deviation')
    plt.ylabel('Actual Standard Deviation of Residuals')
    plt.title(title or 'Spread Calibration Plot')
    plt.grid(True)
    if save_path:
        plt.savefig(save_path)
    plt.show()

To verify that this standard-deviation calibration plot works as expected, I generated some synthetic data:

# Number of samples
n_samples = 1000

# Input feature
x = np.random.uniform(-10, 10, size=n_samples)

# True mean and standard deviation as functions of the input feature
true_mean = 2 * x + 3
true_std = 0.5 * np.abs(x) + 1

# Generate synthetic data
y_true = np.random.normal(loc=true_mean, scale=true_std)

# Simulate model predictions (with some error)
y_pred_mean = true_mean + np.random.normal(loc=0, scale=1, size=n_samples)
y_pred_std = true_std + np.random.normal(loc=0, scale=0.5, size=n_samples)

# Ensure standard deviations are positive
y_pred_std = np.abs(y_pred_std)

df = pd.DataFrame({
    'y_true': y_true,
    'y_pred_mean': y_pred_mean,
    'y_pred_std': y_pred_std
})

create_stdev_calibration_plot(df)

Here's what the calibration plot looks like with the synthetic data:

[image: spread calibration plot for the synthetic data]

When I run the same function on the output data from my model, the plot looks like this:

[image: spread calibration plot for the model's predictions]

Based on the calibration plot, it looks like the model is NOT learning the spread; it is just learning the mean and keeping the spread tight to minimize the loss. What changes can I make to my training to incentivize the model to accurately learn the spread?

Update:

One thought I had was to create a custom loss function based on the average expected calibration error from both the mean and spread calibrations. However, the inputs to a loss function are the y_true tensor and the y_pred tensor from the model. The y_pred would just be samples from the currently learned distribution(s), and I wouldn't know the distribution parameters (loc and scale), which makes the spread calibration impossible. Also, expected calibration error is not differentiable because of the binning it requires, so that makes learning with backpropagation impossible as well.

Update 2:

I'm currently looking into changing the loss function to the negative log likelihood (NLL). I'll have the "learned" distribution parameters, so I can calculate the loss as the NLL of each data point under its "learned" distribution. I'm not confident this will work, though, because the NLL of only one data point (one per row/distribution combination) might just do the same thing as MSE, since the likelihood of a single data point is maximized when it equals the distribution mean.
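
As a side note on that concern, the per-point NLL of a Normal also includes a log-scale term, which is what separates it from a pure squared-error objective; a minimal NumPy sketch (names illustrative):

import numpy as np

# Per-point NLL of a Normal: 0.5 * ((y - mu) / sigma)**2 + log(sigma) + 0.5 * log(2 * pi).
# The log(sigma) term penalizes overconfident (too small) scales, while the squared
# term penalizes underconfident (too large) ones, so the spread enters the loss directly.
def normal_nll(y, mu, sigma):
    return 0.5 * ((y - mu) / sigma) ** 2 + np.log(sigma) + 0.5 * np.log(2 * np.pi)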


Solution

  • Your biggest issue is the loss function you're using. Minimizing MSE only focuses on minimizing the error between the predicted mean values and the training data. If you switch over to negative log likelihood, as suggested in your Update 2, you're minimizing a loss that decreases the better your learned probability distribution fits the training data. See this tutorial, and notice how they use negative log likelihood as their loss function (a sketch of that change follows below).
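
For illustration, here's a minimal sketch of that change, assuming the model still ends in the DistributionLambda layer from the question so that Keras passes the predicted distribution straight to the loss (model, x_train, and y_train are placeholders):

import tensorflow as tf

# With a DistributionLambda output, the y_pred that reaches the loss is a
# distribution object, so the NLL is just the negative log-probability of y_true.
def negloglik(y_true, y_pred_dist):
    return -y_pred_dist.log_prob(y_true)

model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
              loss=negloglik)
model.fit(x_train, y_train, epochs=100, verbose=False)

With this loss the network can no longer keep the predicted scale arbitrarily tight for free: an overconfident scale makes points in the tails extremely unlikely, which drives the NLL up.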