
FailedPreconditionError: Error while reading resource variable module/bilm/CNN_proj/W_proj from Container: localhost


I am trying to use pre-trained ELMo embeddings in a Jupyter notebook with Python 3.7 and TensorFlow 1.14.0.

This is my code:

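# Imports are assumed here; the original post did not show them.
import tensorflow as tf
import tensorflow_hub as hub
from tensorflow.keras.layers import Activation, Dense, Input, Lambda
from tensorflow.keras.models import Model

def ElmoEmbeddingLayer(x):
    print(x.shape)
    module = hub.Module("https://tfhub.dev/google/elmo/3", trainable=False)
    embeddings = module(tf.squeeze(tf.cast(x, tf.string)), signature="default", as_dict=True)["elmo"]
    return embeddings

elmo_dim = 1024
elmo_input = Input(shape=(None,), dtype=tf.string)
elmo_embedding = Lambda(ElmoEmbeddingLayer, output_shape=(None, elmo_dim))(elmo_input)
x = Dense(1)(elmo_embedding)
x = Activation('relu')(x)
model = Model(inputs=[elmo_input], outputs=x)
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
# x_train, y_train, x_test and y_test are defined elsewhere (not shown in the post).
model.fit(x_train, y_train, epochs=1, validation_data=(x_test, y_test))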

However, I get the following runtime error:

FailedPreconditionError: Error while reading resource variable module/bilm/CNN_proj/W_proj from Container: localhost. This could mean that the variable was uninitialized. Not found: Resource localhost/module/bilm/CNN_proj/W_proj/N10tensorflow3VarE does not exist. [[{{node lambda/module_apply_default/bilm/MatMul_9/ReadVariableOp}}]]


Solution

  • To build a Keras model from TF Hub pieces, use the hub.KerasLayer class. It implements Keras's way of collecting variables for initialization. With tensorflow_hub 0.7.0 (and preferably tensorflow 1.15), you can also use it for older TF Hub modules (like the https://tfhub.dev/google/elmo/3 in your example), subject to some caveats; see tensorflow.org/hub/migration_tf2.

    For context: the older hub.Module class is for building models in the classic TF1 way (like tf.layers). It implements the old-style way of collecting variables for initialization via the GLOBAL_VARIABLES collection of the tf.Graph. Keras does not initialize the variables in that collection, which is why your code fails with an uninitialized-variable error. (You could try to initialize them manually in the Session returned by tf.compat.v1.keras.backend.get_session(), but that's getting weird.)
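A rough, untested sketch of both options, assuming tensorflow 1.15 and tensorflow_hub 0.7.0 or later. The function names, the scalar-string input shape, and the choice of the "elmo" output are illustrative assumptions, not part of the original answer, and the caveats in the migration guide may still apply to this module.

```python
ELMO_HANDLE = "https://tfhub.dev/google/elmo/3"  # module from the question


def build_elmo_model(handle=ELMO_HANDLE):
    """Option 1: wrap the module in hub.KerasLayer instead of hub.Module.

    Hypothetical helper; returns a Keras model that maps a batch of
    sentences to their per-token ELMo embeddings.
    """
    # Imported lazily so this file can be loaded without TensorFlow installed.
    import tensorflow as tf
    import tensorflow_hub as hub

    # One untokenized sentence per example, so the input tensor has shape [batch].
    sentences = tf.keras.layers.Input(shape=(), dtype=tf.string)
    # For TF1-format modules, KerasLayer accepts a signature name and an output key.
    elmo = hub.KerasLayer(handle, signature="default", output_key="elmo",
                          trainable=False)
    embeddings = elmo(sentences)
    return tf.keras.Model(inputs=sentences, outputs=embeddings)


def initialize_hub_module_variables():
    """Option 2: keep hub.Module and run the initializers by hand.

    Hypothetical helper; this is the manual workaround the answer calls "weird".
    """
    import tensorflow as tf

    sess = tf.compat.v1.keras.backend.get_session()
    # Runs the initializers of everything in GLOBAL_VARIABLES plus lookup tables.
    sess.run([tf.compat.v1.global_variables_initializer(),
              tf.compat.v1.tables_initializer()])
```

If you go with the second option, call initialize_hub_module_variables() after the model is built and before model.fit(...): the global initializer re-runs the initializers of all variables in the graph, so calling it after training would discard the trained Keras weights.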