I am trying to create a custom objective function in Keras (TensorFlow backend) with an additional parameter whose value depends on the batch being trained.
For example:
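```python
import tensorflow as tf
from keras.optimizers import Adam

def myLoss(self, stateValue):
    def sparse_loss(y_true, y_pred):
        foo = tf.nn.softmax_cross_entropy_with_logits(labels=y_true, logits=y_pred)
        # stateValue is whatever was passed to myLoss when compile() was called
        return tf.reduce_mean(foo * stateValue)
    return sparse_loss

self.model.compile(loss=self.myLoss(stateValue=self.stateValue),
                   optimizer=Adam(learning_rate=self.alpha))
```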
My training loop is as follows:

```python
for batch in batches:
    self.stateValue = computeStateValueVectorForCurrentBatch(batch)
    self.model.fit(xVals, yVals, batch_size=<num>)
```
However, stateValue is not updated inside the loss function: the loss keeps using the value stateValue had at the model.compile step.
I suspect this could be solved with a placeholder for stateValue, but I can't figure out how to do it. Can someone please help?
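The behavior described above is ordinary Python closure semantics and can be reproduced without Keras. The minimal sketch below uses a hypothetical Trainer class standing in for the class that owns myLoss, and a hypothetical Holder class playing the role that a backend variable would play; neither comes from Keras.

```python
class Trainer:
    """Hypothetical stand-in for the class that owns myLoss and stateValue."""

    def __init__(self):
        self.stateValue = 1.0

    def myLoss(self, stateValue):
        # sparse_loss closes over the value passed in here, not over self.stateValue.
        def sparse_loss(x):
            return x * stateValue
        return sparse_loss


class Holder:
    """Hypothetical mutable holder, playing the role of a backend variable."""

    def __init__(self, value):
        self.value = value


def make_loss(holder):
    def loss(x):
        return x * holder.value  # read every time the loss is called
    return loss


trainer = Trainer()
captured = trainer.myLoss(stateValue=trainer.stateValue)  # "compile" time
trainer.stateValue = 5.0  # rebinds the attribute; the closure still holds 1.0
print(captured(2.0))  # 2.0

holder = Holder(1.0)
live = make_loss(holder)
holder.value = 5.0  # mutates the object the closure refers to
print(live(2.0))  # 10.0
```

In Keras, the holder's role would be played by a backend variable (for example a tf.Variable), whose current value is read each time the compiled training step runs.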
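Your loss function is not updated because the loss is built only once, when compile() runs: myLoss(stateValue=self.stateValue) hands the current value to the sparse_loss closure, and rebinding self.stateValue afterwards does not change what the closure holds. Keras does not recompile the model after each batch, so it never sees the new value.
Instead, create stateValue as a backend variable before compiling (e.g. self.stateValue = K.variable(initialVector), where initialVector stands for any starting vector with the same shape as the per-batch vectors), pass that variable to myLoss, and use a custom callback that assigns the new value before each batch. Something like this:

```python
from keras import backend as K
from keras.callbacks import Callback

class UpdateLoss(Callback):
    def __init__(self, state_var):
        super().__init__()
        self.state_var = state_var  # the variable passed to myLoss() before compile()

    def on_batch_begin(self, batch, logs=None):
        # `batch` is the batch index within the epoch, not the batch data
        stateValue = computeStateValueVectorForCurrentBatch(batch)
        K.set_value(self.state_var, stateValue)
```

Then pass it to fit with callbacks=[UpdateLoss(self.stateValue)].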
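Because the loop in the question calls fit once per batch, a callback is not strictly required there. The sketch below assumes TensorFlow 2.x with tf.keras, a fixed batch size, and toy data: stateValue lives in a non-trainable tf.Variable that the loss reads at every training step, and the loop assigns a new value before each fit call. NUM_CLASSES, BATCH_SIZE, the toy model, and compute_state_value_for_batch (a stand-in for computeStateValueVectorForCurrentBatch) are hypothetical, not part of the original code.

```python
import numpy as np
import tensorflow as tf

# Hypothetical sizes for the toy example.
NUM_CLASSES = 3
BATCH_SIZE = 4

# Non-trainable variable the loss reads at every training step. Its shape is
# fixed, so every batch must contain exactly BATCH_SIZE samples.
state_value = tf.Variable(tf.ones([BATCH_SIZE]), trainable=False)


def my_loss(state):
    def sparse_loss(y_true, y_pred):
        y_true = tf.cast(y_true, y_pred.dtype)
        # Per-example cross-entropy, treating y_pred as logits.
        ce = tf.nn.softmax_cross_entropy_with_logits(labels=y_true, logits=y_pred)
        return tf.reduce_mean(ce * state)
    return sparse_loss


def compute_state_value_for_batch(step):
    # Hypothetical stand-in for computeStateValueVectorForCurrentBatch.
    return np.full(BATCH_SIZE, float(step), dtype="float32")


model = tf.keras.Sequential([
    tf.keras.Input(shape=(5,)),
    tf.keras.layers.Dense(NUM_CLASSES),  # no softmax: outputs are logits
])
model.compile(loss=my_loss(state_value),
              optimizer=tf.keras.optimizers.Adam(learning_rate=0.01))

rng = np.random.default_rng(0)
x = rng.normal(size=(BATCH_SIZE, 5)).astype("float32")
y = np.eye(NUM_CLASSES, dtype="float32")[rng.integers(0, NUM_CLASSES, size=BATCH_SIZE)]

losses = []
for step in range(3):
    # Update the variable in place; the compiled train step reads the new value.
    state_value.assign(compute_state_value_for_batch(step))
    history = model.fit(x, y, batch_size=BATCH_SIZE, epochs=1, verbose=0)
    losses.append(history.history["loss"][0])
```

With the all-zero state vector of the first iteration the reported loss is 0, and the later iterations report a positive loss, which shows that the compiled loss picks up each newly assigned value.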