Tags: keras, keras-3

Keras 3 equivalent of getting inbound layers of a given layer


I would like to get the inputs of a given layer, i.e. the previous layers (and their output tensors) that are connected to it.

My code with Keras 2.x was the following:

def get_inputs(layer):
    """Describe every inbound connection of *layer* (Keras 2.x API).

    Walks all of the layer's inbound nodes and, for each tuple yielded by
    ``node.iterate_inbound()``, records the source layer's name, the
    ``__name__`` of its activation (``None`` when it has none) and the
    tensor index.

    Returns:
        A list of ``(source_name, activation_name_or_None, tensor_index)``.
    """
    def _activation_name(source):
        # Layers lacking an `activation` (or whose activation has no
        # `__name__`) are reported with None.
        return getattr(getattr(source, "activation", None), "__name__", None)

    connections = []
    for node in layer.inbound_nodes:
        for source, _node_index, position, _tensor in node.iterate_inbound():
            act_name = _activation_name(source)
            connections.append((source.name, act_name, position))
    return connections

For the following model, I had this result:

from tensorflow import keras

# Two sequence inputs: LSTM1 and LSTM2 both read `input1` and their outputs
# are passed as LSTM3's `initial_state`; LSTM3 reads `input2`, and LSTM4 is
# fed only LSTM3's first output (index 0).
first_input = keras.Input(shape=(5, 2), name='input1')
forward_lstm = keras.layers.LSTM(3, name='LSTM1')(first_input)
backward_lstm = keras.layers.LSTM(3, name='LSTM2', go_backwards=True)(first_input)
second_input = keras.Input(shape=(10, 7), name='input2')
lstm3_outputs = keras.layers.LSTM(
    3, name='LSTM3', return_state=True, return_sequences=True,
)(second_input, initial_state=[forward_lstm, backward_lstm])
last_output = keras.layers.LSTM(
    7, name='LSTM4', activation='exponential', recurrent_activation='relu',
)(lstm3_outputs[0])
model = keras.Model(
    inputs=[first_input, second_input], outputs=[last_output], name='model_lstm',
)

print("INPUTS OF EACH LAYER")
for current in model.layers:
    print(f"{current.name} => {get_inputs(current)}")

# INPUTS OF EACH LAYER
# input1 => []
# input2 => []
# LSTM1 => [('input1', None, 0)]
# LSTM2 => [('input1', None, 0)]
# LSTM3 => [('input2', None, 0), ('LSTM1', 'tanh', 0), ('LSTM2', 'tanh', 0)]
# LSTM4 => [('LSTM3', 'tanh', 0)]

[Image not available in this copy]

In Keras 3.x, I get the following error: `AttributeError: 'InputLayer' object has no attribute 'inbound_nodes'`.

What is the equivalent of these lines in Keras 3?

# Keras 2.x API: the public `layer.inbound_nodes` no longer exists in Keras 3.
for node in layer.inbound_nodes:
        for inbound_layer, _, tensor_index, _ in node.iterate_inbound():
           ...

I tried the following change, but for the last layer (LSTM4) it reports every output of LSTM3 instead of only the one that is actually connected. There should be a single input, ('LSTM3', 'tanh', 0), not three:

def get_inputs_k3(layer):
    """Get the inputs of a layer (first Keras 3 attempt).

    Walks ``layer._inbound_nodes`` and, for every parent node of each kept
    node, emits one ``(operation_name, activation_name_or_None, tensor_index)``
    tuple per output tensor of that parent.

    Known limitation: every output tensor of each parent node is reported, not
    only the tensors this layer actually consumes. A layer fed by just one
    output of a multi-output parent therefore gets one entry per parent output.
    """
    # Keras 3.0 removed the public `layer.inbound_nodes`; `layer._inbound_nodes`
    # stands in for it but may hold several nodes with the same name, so keep
    # one node per operation name (the last one seen wins).
    node_by_name = {}
    for candidate in layer._inbound_nodes:
        node_by_name[candidate.operation.name] = candidate

    def _activation_name(parent_node):
        # Operations without an `activation` (e.g. the input layers) give None.
        try:
            return parent_node.operation.activation.__name__
        except AttributeError:
            return None

    collected = []
    for node in node_by_name.values():
        for parent in node.parent_nodes:
            for out_tensor in parent.output_tensors:
                idx = out_tensor._keras_history.tensor_index
                act = _activation_name(parent)
                collected.append((parent.operation.name, act, idx))
    return collected

# INPUTS OF EACH LAYER
# input1 => []
# input2 => []
# LSTM1 => [('input1', None, 0)]
# LSTM2 => [('input1', None, 0)]
# LSTM3 => [('input2', None, 0), ('LSTM1', 'tanh', 0), ('LSTM2', 'tanh', 0)]
# LSTM4 => [('LSTM3', 'tanh', 0), ('LSTM3', 'tanh', 1), ('LSTM3', 'tanh', 2)]

Solution

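  • You can use the `_keras_history` attribute, which records how a given tensor was created. It is an instance of a namedtuple called `KerasHistory` that stores `operation` (the layer that produced the tensor), `node_index` (which call of that layer produced it) and `tensor_index` (the position of the tensor among that call's outputs).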

    This can help identify the layer connections in the model.

    def get_inputs_k3(layer):
        """Return ``(layer_name, activation_name, tensor_index)`` for each input of *layer*.

        Each input tensor's ``_keras_history`` names the operation (layer) that
        produced it and the tensor's index among that operation's outputs, so
        only the tensors this layer actually consumes are reported.

        Args:
            layer: a Keras 3 layer, typically taken from ``model.layers``.

        Returns:
            A list of ``(inbound_layer.name, activation_name_or_None,
            tensor_index)`` tuples. The list is empty when no input tensor
            carries a ``_keras_history``, or when reading ``layer.input`` fails
            with AttributeError/ValueError (presumably because the layer was
            never called).

        NOTE(review): ``layer.input`` presumably reflects only the layer's first
        call; a layer shared across several calls may need per-node handling.
        """
        try:
            raw_inputs = layer.input  # read once; it is a property
        except (AttributeError, ValueError):
            # No input to inspect: mirror the Keras 2 behaviour of returning
            # an empty list for a layer without inbound nodes.
            return []

        if isinstance(raw_inputs, (list, tuple)):
            input_tensors = raw_inputs
        else:
            input_tensors = [raw_inputs]

        inputs = []
        for tensor in input_tensors:
            keras_history = getattr(tensor, "_keras_history", None)
            if keras_history is None:
                # Not produced by a tracked operation; nothing to report.
                continue
            inbound_layer = keras_history.operation  # layer that produced the tensor
            tensor_index = keras_history.tensor_index  # which of its outputs

            try:
                activation = inbound_layer.activation.__name__
            except AttributeError:
                # Layers without an activation (e.g. input layers) report None.
                activation = None

            inputs.append((inbound_layer.name, activation, tensor_index))

        return inputs