
A KerasTensor cannot be used as input to a TensorFlow function


I have been following François Chollet's machine-learning book, and I keep getting this error from the block of code below, specifically on the `tf.one_hot` line. It seems I am passing a Keras tensor into a TensorFlow function, but I don't know how to get around this.

import keras
import tensorflow as tf
from keras import layers

max_tokens = 20000  # vocabulary size; assumed here, defined earlier in the original code

inputs = keras.Input(shape=(None,), dtype="int64")
embedded = tf.one_hot(inputs, depth=max_tokens)  # this line raises the error
x = layers.Bidirectional(layers.LSTM(32))(embedded)
x = layers.Dropout(0.5)(x)
outputs = layers.Dense(1, activation="sigmoid")(x)
model = keras.Model(inputs, outputs)
model.compile(optimizer="rmsprop",
              loss="binary_crossentropy",
              metrics=["accuracy"])
model.summary()

I tried to follow the suggestion in the error message, which told me to create a new class, but I didn't know how to do it.
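The "new class" the error message asks for is an ordinary subclass of a Keras layer whose `call()` method runs the TensorFlow function. Below is a minimal sketch, assuming Keras 3 with the TensorFlow backend; the class name `OneHot` and its `depth` argument are hypothetical. Passing `depth` to the constructor (instead of reading a global) and returning it from `get_config()` lets Keras re-create the layer when a saved model is loaded.

```python
import keras
import tensorflow as tf


class OneHot(keras.layers.Layer):
    """Hypothetical layer that one-hot encodes integer token IDs."""

    def __init__(self, depth, **kwargs):
        super().__init__(**kwargs)
        self.depth = depth  # length of each one-hot vector (vocabulary size)

    def call(self, x):
        # Inside call(), x is a TensorFlow tensor, not a KerasTensor,
        # so TensorFlow ops such as tf.one_hot can be used here.
        return tf.one_hot(x, depth=self.depth)

    def get_config(self):
        # Store constructor arguments so the layer can be rebuilt on load.
        config = super().get_config()
        config.update({"depth": self.depth})
        return config


layer = OneHot(depth=5)
encoded = layer(tf.constant([[0, 2, 4]]))
print(encoded.shape)  # (1, 3, 5)
```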


Solution

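  • Wrap the TensorFlow function (tf.one_hot) in a custom Keras layer, then replace the direct call to tf.one_hot with a call to that layer. Inside a layer's call() method the input is a TensorFlow tensor rather than a symbolic KerasTensor, so TensorFlow functions can be used there.

    For example:

    import keras
    import tensorflow as tf
    from keras import layers

    max_tokens = 20000  # vocabulary size; assumed, use the value from your own code

    class OneHotLayer(keras.Layer):
        # Runs tf.one_hot inside call(), where x is a real TensorFlow tensor
        def call(self, x):
            return tf.one_hot(x, depth=max_tokens)

    inputs = keras.Input(shape=(None,), dtype="int64")
    embedded = OneHotLayer()(inputs)
    x = layers.Bidirectional(layers.LSTM(32))(embedded)
    x = layers.Dropout(0.5)(x)
    outputs = layers.Dense(1, activation="sigmoid")(x)
    model = keras.Model(inputs, outputs)
    model.compile(
        optimizer="rmsprop",
        loss="binary_crossentropy",
        metrics=["accuracy"],
    )
    model.summary()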