
How to concatenate a tensor to a keras layer along batch (without specifying batch size)?


I want to concatenate the output of an embedding layer with a custom tensor (myarr / myconst). I can specify everything with a fixed batch size, as follows:

import numpy as np
import tensorflow as tf

BATCH_SIZE = 100
myarr = np.ones((10, 5))
myconst = tf.constant(np.tile(myarr, (BATCH_SIZE, 1, 1)))

# Model definition
inputs = tf.keras.layers.Input((10,), batch_size=BATCH_SIZE)
x = tf.keras.layers.Embedding(10, 5)(inputs)
x = tf.keras.layers.Concatenate(axis=1)([x, myconst])
model = tf.keras.models.Model(inputs=inputs, outputs=x)

However, if I neither specify the batch size nor tile my array, i.e. just the following...

myarr = np.ones((10, 5))
myconst = tf.constant(myarr)

# Model definition
inputs = tf.keras.layers.Input((10,))
x = tf.keras.layers.Embedding(10, 5)(inputs)
x = tf.keras.layers.Concatenate(axis=1)([x, myconst])
model = tf.keras.models.Model(inputs=inputs, outputs=x)

... I get an error specifying that shapes [(None, 10, 5), (10, 5)] can't be concatenated. Is there a way to add this None / batch_size axis to avoid tiling?

Thanks in advance


Solution

  • You want to concatenate a constant of shape (10, 5) to a 3D tensor of shape (batch, 10, 5) along axis=1. For that, the constant must also be 3D and have the same batch dimension. So reshape it to (1, 10, 5) and repeat it along axis=0 until it matches the shape (batch, 10, 5); then the concatenation works.

    We do this inside a Lambda layer:

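    import numpy as np
    import tensorflow as tf

    # Dummy data: integer ids as input, targets of shape (batch, 20, 5)
    X = np.random.randint(0,10, (100,10))
    Y = np.random.uniform(0,1, (100,20,5))
    
    # Constant with a leading axis of size 1, cast to float32 like the Embedding output
    myarr = np.ones((1, 10, 5)).astype('float32')
    myconst = tf.convert_to_tensor(myarr)
    
    def repeat_const(tensor, myconst):
        # tf.shape returns the batch size at run time, even when the static size is None
        shapes = tf.shape(tensor)
        return tf.repeat(myconst, shapes[0], axis=0)
    
    inputs = tf.keras.layers.Input((10,))
    x = tf.keras.layers.Embedding(10, 5)(inputs)
    xx = tf.keras.layers.Lambda(lambda x: repeat_const(x, myconst))(x)
    x = tf.keras.layers.Concatenate(axis=1)([x, xx])  # (batch, 20, 5)
    model = tf.keras.models.Model(inputs=inputs, outputs=x)
    model.compile('adam', 'mse')
    
    model.fit(X, Y, epochs=3)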
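
To make the shape reasoning concrete, here is a minimal sketch in plain NumPy rather than Keras. It only mirrors the shape arithmetic (add a leading axis, repeat along it, concatenate along axis=1). The batch size of 4 and the names x, const_3d and const_batched are assumptions made up for illustration; in the model itself the batch size is only known at run time, which is why the repeat happens inside the Lambda layer.

```python
import numpy as np

batch = 4  # hypothetical batch size, chosen only for this illustration
x = np.zeros((batch, 10, 5), dtype="float32")  # stands in for the Embedding output
const = np.ones((10, 5), dtype="float32")      # the constant from the question

# A 3D and a 2D array cannot be concatenated: the ranks differ.
try:
    np.concatenate([x, const], axis=1)
    rank_mismatch = False
except ValueError:
    rank_mismatch = True

# Add a leading axis of size 1, then repeat along it to match the batch size.
const_3d = const[np.newaxis, ...]                   # shape (1, 10, 5)
const_batched = np.repeat(const_3d, batch, axis=0)  # shape (batch, 10, 5)

# Now both arrays agree on every axis except axis=1, so concatenation works.
out = np.concatenate([x, const_batched], axis=1)

print(rank_mismatch, const_3d.shape, const_batched.shape, out.shape)
# prints: True (1, 10, 5) (4, 10, 5) (4, 20, 5)
```

The result has 20 rows per sample along axis=1, which is why the dummy targets Y in the answer have shape (100, 20, 5).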