machine-learning keras neural-network symmetry

Enforcing symmetry in Keras


I am looking for a way to build a neural-network model in Keras for a function that is symmetric with respect to interchanging its inputs. For simplicity, let's assume that the function of interest depends on two variables x, y and returns a scalar f = f(x, y). Furthermore, we know that f(x, y) = f(y, x) holds for all x, y. What is the preferred method for ensuring that my Keras model reproduces this symmetry exactly?

Clearly, I could train the model on symmetric data, but what I am looking for is a way to "hardcode" this symmetry into the model itself.

I know this question may seem very basic. Sorry if there is an obvious answer that I have overlooked, and thank you in advance for your help!


Solution

  • From your question, it seems that what you are looking for is a convenient way to apply a layer, or a set of layers with shared weights, to the inputs in both forward and reverse order.

    This is similar to how a convolution detects a pattern across a sequence of time steps, except that the input buffer is treated as circular.

    A convenient way to achieve this is to put the shared layers inside a reusable auxiliary model, apply that model to both the original and the swapped inputs, and combine the two outputs with an element-wise maximum. Because the maximum does not depend on the order of its arguments, the resulting model is symmetric by construction. Something like the following:

    import tensorflow as tf
    from tensorflow.keras.layers import Dense, Input, Lambda, Maximum
    from tensorflow.keras.models import Model
    
    def make_inner_model():
      # Shared sub-network g(x, y); the same weights are used for both input orders
      inp = Input(shape=(2,))
      h1 = Dense(8, activation='relu')(inp)
      out = Dense(1)(h1)
      model = Model(inp, out)
      return model
    
    def make_model(inner_model):
      inp = Input(shape=(2,))
      # Swap the two inputs: (x, y) -> (y, x)
      rev = Lambda(lambda x: tf.concat([x[:, 1:], x[:, 0:1]], axis=1))(inp)
      # Calling the same inner model twice shares its weights between branches
      r1 = inner_model(inp)
      r2 = inner_model(rev)
      # max(g(x, y), g(y, x)) is unchanged when x and y are swapped
      out = Maximum()([r1, r2])
      model = Model(inp, out)
      model.compile('adam', 'mse')
      return model
    
    inner = make_inner_model()
    model = make_model(inner)
    model.summary()
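A quick way to convince yourself that the symmetry is built into the architecture, rather than learned, is to compare predictions for (x, y) and (y, x) before and after training. The sketch below assumes TensorFlow 2.x with `tf.keras`; the helper names `build_symmetric_model` and `asymmetry`, the random data, and the target x * y are illustrative choices, not part of the original answer. Because f(x, y) = max(g(x, y), g(y, x)) for any weights of g, the difference should be zero (up to floating-point noise) regardless of how the weights change during training.

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Dense, Input, Lambda, Maximum
from tensorflow.keras.models import Model


def build_symmetric_model():
    # Shared sub-network g(x, y) (hypothetical sizes, mirroring the answer)
    inner_inp = Input(shape=(2,))
    hidden = Dense(8, activation="relu")(inner_inp)
    inner = Model(inner_inp, Dense(1)(hidden))

    # Outer model: f(x, y) = max(g(x, y), g(y, x))
    inp = Input(shape=(2,))
    swapped = Lambda(lambda t: tf.concat([t[:, 1:], t[:, 0:1]], axis=1))(inp)
    out = Maximum()([inner(inp), inner(swapped)])
    model = Model(inp, out)
    model.compile(optimizer="adam", loss="mse")
    return model


def asymmetry(model, xy):
    """Largest |f(x, y) - f(y, x)| over a batch of (x, y) pairs."""
    yx = xy[:, ::-1].copy()  # swap the two columns
    pred_xy = model.predict(xy, verbose=0)
    pred_yx = model.predict(yx, verbose=0)
    return float(np.max(np.abs(pred_xy - pred_yx)))


rng = np.random.default_rng(0)
xy = rng.normal(size=(256, 2)).astype("float32")
target = xy[:, 0:1] * xy[:, 1:2]  # illustrative symmetric target f(x, y) = x * y

model = build_symmetric_model()
before = asymmetry(model, xy)
model.fit(xy, target, epochs=2, batch_size=32, verbose=0)
after = asymmetry(model, xy)
print(f"asymmetry before training: {before:.2e}, after training: {after:.2e}")
```

The maximum is only one possible choice: any order-independent combination of the two branch outputs, such as Keras's `Add()` or `Average()` layers, also yields an exactly symmetric model. One design difference is that the maximum passes gradients to only one branch per sample, whereas a sum or average updates the shared weights through both branches.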