
Only Last TensorFlow Probability Layer Being Output (Multiple Times)


I'm building a TensorFlow Probability Bayesian network. In the example below, I've got a simple two-distribution output, but both outputs come from the last distribution added to the network (any distributions added before it are ignored). Here's a concrete code example that shows what I'm talking about.

Import some packages and some helper code:

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras import layers, Model
import tensorflow_probability as tfp
ZERO_BUFFER = 1e-5 

dist_lookup = {
    'normal': {
        'dist': tfp.distributions.Normal,
        'count': 2,
        'inputs': {
            'loc': False,
            'scale': True,
        }
    },
    'exponential': {
        'dist': tfp.distributions.Exponential,
        'count': 1,
        'inputs': {
            'rate': True,
        }
    }
}

Now let's create some fake data to use.

n = 100000
np.random.seed(123)
x1 = np.ones(shape=(n, 1))
x2 = 2 * np.ones(shape=(n, 1))
x3 = 3 * np.ones(shape=(n, 1))
X = pd.DataFrame(np.concatenate([x1, x2, x3], axis=1), columns=['x1', 'x2', 'x3']).astype(int)

Now let's build a toy model that demonstrates what I'm talking about. Notice that I'm building the distribution layers with a for loop. If I build each distribution layer manually by typing it out, I don't get the weird behavior below. It only happens when I define the layers in a for loop, BUT I need to build a larger model with a dynamic number of distributions, so I need to be able to build it using a loop of some kind.

def create_dist_lambda_kwargs(prior_input_count: int, input_dict: dict, t):
    kwargs = dict()
    for j, (param, use_softplus) in enumerate(input_dict.items()):
        col = prior_input_count + j  # column of t that feeds this parameter
        if use_softplus:
            # softplus keeps positive-only parameters (scale, rate) above zero
            kwargs[param] = ZERO_BUFFER + tf.nn.softplus(t[..., col])
        else:
            kwargs[param] = t[..., col]
    return kwargs
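To make the index arithmetic in create_dist_lambda_kwargs concrete, here is a small TensorFlow-free sketch. The helper param_columns and the trimmed table lookup_inputs are hypothetical; they only mirror the prior_input_count + j offset so you can see which input column each parameter is supposed to read.

```python
# Hypothetical trimmed version of dist_lookup: parameter names per distribution.
lookup_inputs = {
    'normal': ['loc', 'scale'],
    'exponential': ['rate'],
}


def param_columns(distributions):
    """Map (distribution, parameter) to the input column it should read.

    Mirrors the offset logic of create_dist_lambda_kwargs: each distribution
    starts reading where the previous one stopped (prior_input_count + j).
    """
    columns = {}
    total = 0
    for name in distributions:
        params = lookup_inputs[name]
        for j, param in enumerate(params):
            columns[(name, param)] = total + j
        total += len(params)
    return columns


print(param_columns(['normal', 'exponential']))
# {('normal', 'loc'): 0, ('normal', 'scale'): 1, ('exponential', 'rate'): 2}
```

So with distributions = ['normal', 'exponential'], the first layer should read columns 0 and 1 and the second layer column 2.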


input_layer = layers.Input(shape=(X.shape[1],))
# distributions = ['exponential', 'normal']
distributions = ['normal', 'exponential']
dists = list()
reshapes = list()
total = 0
for i in range(len(distributions)):
    param_count = dist_lookup[distributions[i]]['count']
    dist_class = dist_lookup[distributions[i]]['dist']
    dists.append(
        tfp.layers.DistributionLambda(
            lambda t: dist_class(
                **create_dist_lambda_kwargs(
                    prior_input_count=total,
                    input_dict=dist_lookup[distributions[i]]['inputs'],
                    t=t,
                )
            )
        )(input_layer)
    )
    reshapes.append(layers.Reshape((1,))(dists[i]))
    total += param_count
total = 0

output = layers.Concatenate()(reshapes)
model = Model(input_layer, output)
model.compile(loss='mse', optimizer='adam', metrics=['mae', 'mse'])

Oddly, if I remove the total = 0 line after the for loop above, the code crashes. I assume that's somehow related to the other issue below.

Now, if I make predictions with the input data (remember that all rows of the input data are the same), I should get a large sample from each of the two output distributions, which we can plot.

pred = model.predict(X)

fig, (ax1, ax2) = plt.subplots(1, 2)
fig.set_size_inches(10, 5)
for i, ax in enumerate((ax1, ax2)):
    ax.hist(pred[:, i], bins=50)
    ax.set_xlabel(f'Output{i + 1} Value')
    ax.set_title(f'Output{i + 1} Histogram')

If 'exponential' is the last value in the distributions list, the plot looks something like the image below; both outputs look like an exponential distribution.

[Image: Output1 and Output2 histograms, both exponential-shaped]

If 'normal' is the last value in the distributions list, the plot looks something like the image below; both outputs look like a normal distribution.

[Image: Output1 and Output2 histograms, both normal-shaped]

So my question is: WHY is the model build getting confused by the for loop and treating both outputs like the last distribution created in the loop, and how can the code be fixed to work as expected?


Solution

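  • I haven't read the whole question, but you're definitely hitting Python's late-binding closure behavior when you create the lambda in the loop. The lambda body looks up dist_class, total, and i when it is called, not when it is defined, so once the loop has finished, every DistributionLambda builds its distribution from the last dist_class, the last entry's inputs, and whatever value total has at that moment. That is why both outputs follow the last distribution in the list, and it likely explains the crash too: without the reset, total is 3 and the layers try to read a column that the 3-column input doesn't have.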
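The late-binding behavior can be reproduced without TensorFlow. In this sketch (a stand-in, not the original model), NumPy samplers replace the TFP distribution classes; the buggy loop mirrors the question, and the fixed loop uses default-argument binding so each lambda keeps the value of name from its own iteration.

```python
import numpy as np

rng = np.random.default_rng(123)

# Stand-ins for the two distribution types (NumPy samplers instead of TFP classes).
samplers = {
    'normal': lambda size: rng.normal(loc=0.0, scale=1.0, size=size),
    'exponential': lambda size: rng.exponential(scale=1.0, size=size),
}
distributions = ['normal', 'exponential']

# Buggy pattern: `name` is looked up when the lambda is called, after the loop ends.
buggy = []
for name in distributions:
    buggy.append(lambda size: samplers[name](size))

# Fixed pattern: the default argument captures the current `name`
# at the moment each lambda is defined.
fixed = []
for name in distributions:
    fixed.append(lambda size, name=name: samplers[name](size))

n = 100_000
buggy_out = [f(n) for f in buggy]
fixed_out = [f(n) for f in fixed]

# Exponential samples are never negative; normal samples are negative about half the time.
print([bool((out < 0).any()) for out in buggy_out])  # [False, False]: both exponential
print([bool((out < 0).any()) for out in fixed_out])  # [True, False]: normal, then exponential
```

Applied to the question's loop, the same idea would be to bind the per-iteration values as defaults, e.g. lambda t, dist_class=dist_class, total=total, inputs=dist_lookup[distributions[i]]['inputs']: dist_class(**create_dist_lambda_kwargs(prior_input_count=total, input_dict=inputs, t=t)). With the values captured per layer, the total = 0 reset after the loop should no longer be needed. A small factory function that takes dist_class, total, and inputs and returns the lambda, or functools.partial, achieves the same binding.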