tensorflow, tf.keras, tensorflow-estimator

Two sets of shared embeddings from one tensorflow feature?


How can I create two sets of shared embeddings from the same tensorflow feature columns?

This small example

import tensorflow as tf

# Toy feature table: two integer features, four rows each.
data = dict(A=[0, 1, 2, 3], B=[2, 1, 0, 3])


def add_label(example):
    """Pair *example* with the constant label 1.

    Returns a ``(example, 1)`` tuple; *example* is passed through unchanged.
    """
    label = 1
    return (example, label)


def input_fn():
    """Build the input pipeline: slice ``data``, attach labels, batch by 2."""
    slices = tf.data.Dataset.from_tensor_slices(data)
    labelled = slices.map(add_label)
    return labelled.batch(2)


def model_fn(features, labels, mode, params):
    """Model function that adds the outputs of two shared-embedding branches.

    Two sets of shared embedding columns are requested over the same pair of
    categorical columns ("A" and "B"). Each set is turned into a dense tensor,
    summed along axis 1, and the two sums are added to form the prediction
    that is compared to ``labels`` with a mean-squared-error loss.

    Args:
        features: Feature input passed to both ``DenseFeatures`` layers.
        labels: Target values for the mean-squared-error loss.
        mode: Forwarded unchanged into the returned ``EstimatorSpec``.
        params: Unused in this function.

    Returns:
        A ``tf.estimator.EstimatorSpec`` carrying ``mode``, ``loss`` and
        ``train_op``. The train op is built regardless of ``mode``.
    """
    # One categorical column per feature, both over the vocabulary 0..3.
    colA = tf.feature_column.categorical_column_with_vocabulary_list("A", [0, 1, 2, 3])
    colB = tf.feature_column.categorical_column_with_vocabulary_list("B", [0, 1, 2, 3])

    # First branch: 2-dimensional embeddings shared between colA and colB.
    model1_embedddings = tf.feature_column.shared_embeddings(categorical_columns=[colA, colB], dimension=2)
    X1 = tf.keras.layers.DenseFeatures(model1_embedddings)(features)
    output1_output = tf.reduce_sum(X1, axis=1)

    # Second branch: the same request again, wrapped in a separate variable
    # scope in an attempt to obtain an independent set of embeddings.
    # NOTE(review): the reported run fails with "Variable A_B_shared_embedding
    # already exists"; presumably both calls resolve to the same default
    # variable name despite the scope -- confirm against the TF version in use.
    with tf.compat.v1.variable_scope("other", reuse=False):
        model2_embedddings = tf.feature_column.shared_embeddings(categorical_columns=[colA, colB], dimension=2)
    # The DenseFeatures layer for the second branch is applied outside the scope.
    X2 = tf.keras.layers.DenseFeatures(model2_embedddings)(features)
    output2_output = tf.reduce_sum(X2, axis=1)

    # Prediction is the element-wise sum of the two branch outputs.
    loss = tf.losses.mean_squared_error(labels, output1_output + output2_output)

    optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=0.01)
    # NOTE(review): no global_step is passed to minimize() -- confirm whether
    # the Estimator training loop needs it to be incremented here.
    train_op = optimizer.minimize(loss=loss)
    return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)


# Wrap model_fn in an Estimator (all other constructor arguments left at
# their defaults) and train it on the dataset produced by input_fn.
estimator = tf.estimator.Estimator(model_fn=model_fn)
estimator.train(input_fn=input_fn)

crashes with

ValueError: Variable A_B_shared_embedding already exists, disallowed. Did you mean to set reuse=True or reuse=tf.AUTO_REUSE in VarScope?

It seems one should be able to use variable_scope or name_scope to get it to work, but so far no luck.


Solution

  • shared_embeddings accepts a shared_embedding_collection_name argument; pass a distinct name for the second set of embeddings so that it gets its own embedding collection:

    # Second set of shared embeddings over the same columns, placed in its own
    # collection by giving it an explicit name ("other").
    model2_embedddings = tf.feature_column.shared_embeddings(
     categorical_columns=[colA, colB], dimension=2,
     shared_embedding_collection_name="other")