Tags: python, tensorflow, keras, lstm, lstm-stateful

Is there a way to pass along temporal weights to a loss function?


Background

Currently, I'm using an LSTM to perform a regression. I'm using small batch sizes with a reasonably large number of timesteps (though still far, far fewer than the total number of timesteps I have).

I'm attempting to transition to larger batches with fewer timesteps, but with stateful enabled, so that a larger amount of generated training data can be used.

However, I am currently using a regularization based on sqrt(timestep) (this is ablation-tested and helps with convergence speed; it works because of the statistical nature of the problem: the expected error decreases by a factor of sqrt(timestep)). This is implemented by using tf.range to generate a list of the proper size within the loss function. This approach will no longer be correct once stateful is enabled, because it will count the wrong number of timesteps (the number of timesteps in the current batch, rather than the number seen so far overall).
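To make the mismatch concrete, here is a minimal sketch (the chunk length and offset are hypothetical, not taken from the model below): with stateful training, a long sequence is split across batches, so a per-batch tf.range restarts at 1 for every chunk instead of continuing from the global timestep index.

import numpy as np

# Hypothetical illustration: one long sequence of 6 timesteps fed to a
# stateful model as two chunks of 3 timesteps each.
chunk_len = 3

# What the current loss computes for EVERY chunk: tf.range restarts at 1.
per_batch_weights = np.sqrt(np.arange(1, chunk_len + 1))
# -> [1.0, 1.414, 1.732] for both the first and the second chunk

# What is actually intended for the second chunk: weights based on the
# global timestep index, offset by the timesteps already seen.
offset = chunk_len
intended_weights = np.sqrt(np.arange(offset + 1, offset + chunk_len + 1))
# -> [2.0, 2.236, 2.449], which the per-batch computation never produces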

Question

Is there a way to pass an offset or list of ints or floats to the loss function? Preferably without modifying the model, but I recognize that a hack of this nature might be required.

Code

Simplified model:

from tensorflow.keras.layers import Input, Dense, TimeDistributed, Dropout, LSTM, add
from tensorflow.keras.models import Model
from tensorflow.keras.regularizers import l2

def create_model():
    # Hyperparameters (input_nodes, dense_layers, batch_size, etc.) are
    # assumed to be defined elsewhere. Stateful LSTMs require a fixed
    # batch size, hence batch_shape rather than shape.
    inputs = Input(batch_shape=(batch_size, None, input_nodes))
    next_input = inputs
    for i in range(dense_layers):
        dense = TimeDistributed(Dense(units=dense_nodes,
                activation='relu',
                kernel_regularizer=l2(regularization_weight),
                activity_regularizer=l2(regularization_weight)))\
            (next_input)
        next_input = TimeDistributed(Dropout(dropout_dense))(dense)

    for i in range(lstm_layers):
        prev_input = next_input
        next_input = LSTM(units=lstm_nodes,
                dropout=dropout_lstm,
                recurrent_dropout=dropout_lstm,
                kernel_regularizer=l2(regularization_weight),
                recurrent_regularizer=l2(regularization_weight),
                activity_regularizer=l2(regularization_weight),
                stateful=True,
                return_sequences=True)\
            (prev_input)
        # Residual connection; assumes the widths of prev_input and
        # next_input match so the shapes are compatible.
        next_input = add([prev_input, next_input])

    outputs = TimeDistributed(Dense(output_nodes,
            kernel_regularizer=l2(regularization_weight),
            activity_regularizer=l2(regularization_weight)))\
        (next_input)

    model = Model(inputs=inputs, outputs=outputs)
    return model
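
As an aside on the stateful transition, here is a hedged sketch of a typical training loop for a model like this (the chunking scheme and the names num_epochs, sequences, x_batch, and y_batch are assumptions, not from the original code): because states carry over between batches, they must be reset manually at each sequence boundary.

model = create_model()
model.compile(optimizer='adam', loss=loss_function)

for epoch in range(num_epochs):
    for x_chunks, y_chunks in sequences:   # consecutive chunks of one long sequence
        model.reset_states()               # new sequence: clear carried-over state
        for x_batch, y_batch in zip(x_chunks, y_chunks):
            model.train_on_batch(x_batch, y_batch)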

Loss function

import tensorflow as tf
from tensorflow.keras import backend as K

def loss_function(y_true, y_pred):
    length = K.shape(y_pred)[1]

    # Per-timestep weights: uniform by default, sqrt(timestep) when
    # sqrt-loss scaling is enabled. tf.ones is used because length is
    # a dynamic tensor, not a Python int.
    seq = tf.ones(shape=(length,))
    if use_sqrt_loss_scaling:
        seq = tf.range(1, length + 1, dtype='int32')
        seq = K.sqrt(tf.cast(seq, tf.float32))

    seq = K.reshape(seq, (-1, 1))

    if separate_theta_phi:
        angle_loss = phi_loss_weight * phi_metric(y_true, y_pred, angle_loss_fun)
        angle_loss += theta_loss_weight * theta_metric(y_true, y_pred, angle_loss_fun)
    else:
        angle_loss = angle_loss_weight * total_angle_metric(y_true, y_pred, angle_loss_fun)

    norm_loss = norm_loss_weight * norm_loss_fun(y_true, y_pred)
    energy_loss = energy_loss_weight * energy_metric(y_true, y_pred)
    stability_loss = stability_loss_weight * stab_loss_fun(y_true, y_pred)
    act_loss = act_loss_weight * act_loss_fun(y_true, y_pred)

    # Weight the combined per-timestep losses by seq and reduce to a scalar.
    total_loss = angle_loss + norm_loss + energy_loss + stability_loss + act_loss
    return K.sum(K.dot(total_loss, seq))

(The functions that compute the individual pieces of the loss shouldn't be especially relevant here. Simply put, they're also loss functions.)


Solution

  • For that purpose, you can use the sample_weight argument of the fit method and pass sample_weight_mode='temporal' to the compile method, so that you can assign a weight to each timestep of each sample in the batch:

    model.compile(..., sample_weight_mode='temporal')
    model.fit(..., sample_weight=sample_weight)
    

    The sample_weight should be an array of shape (num_samples, num_timesteps).

    Note that if you are using an input data generator or an instance of Sequence, you instead need to pass the sample weights as the third element of the tuple/list produced by the generator or Sequence instance. A sketch of this approach for the stateful case follows.
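
    Putting this together for the stateful case, here is a minimal sketch (the names chunk_index, num_timesteps, per_timestep_loss, x_batch, and y_batch are illustrative assumptions, not from the original post): the sqrt(timestep) scaling moves out of the loss function and into the sample weights, with a per-chunk offset so the weights follow the global timestep index rather than the position within the batch.

        import numpy as np

        # Illustrative sizes, assumed for this sketch.
        batch_size = 32       # samples per stateful batch
        num_timesteps = 100   # timesteps per chunk
        chunk_index = 3       # which chunk of the long sequence this batch is

        # Global timestep indices for this chunk, offset by the timesteps
        # already consumed in previous chunks of the same sequence.
        offset = chunk_index * num_timesteps
        global_steps = np.arange(offset + 1, offset + num_timesteps + 1,
                                 dtype=np.float32)

        # sqrt(timestep) weights tiled across the batch:
        # shape (num_samples, num_timesteps), as 'temporal' mode requires.
        sample_weight = np.tile(np.sqrt(global_steps), (batch_size, 1))

        model.compile(optimizer='adam', loss=per_timestep_loss,
                      sample_weight_mode='temporal')
        model.train_on_batch(x_batch, y_batch, sample_weight=sample_weight)

    With this setup, per_timestep_loss should return unweighted per-timestep values of shape (batch_size, num_timesteps); Keras multiplies them by sample_weight before reducing, so the tf.range-based scaling can be dropped from the loss function entirely.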