python tensorflow keras graph-neural-network

How to plot a graph neural network's model graph when using the TensorFlow Model Subclassing API with Spektral layers?


I am unable to plot the model graph of my graph neural network. I have seen a few related questions (1, 2, 3) on this topic, but their answers do not apply to graph neural networks.

What makes it different is that the inputs include objects of different dimensions: e.g. the node properties (feature) matrix has shape [n_nodes, n_node_features], the adjacency matrix has shape [n_nodes, n_nodes], etc. Here is an example of my model:

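from tensorflow.keras import Model
from tensorflow.keras.layers import Dense, Dropout
from spektral.layers import GINConv, GlobalAvgPool

class GIN0(Model):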
    def __init__(self, channels, n_layers):
        super().__init__()
        self.conv1 = GINConv(channels, epsilon=0, mlp_hidden=[channels, channels])
        self.convs = []
        for _ in range(1, n_layers):
            self.convs.append(
                GINConv(channels, epsilon=0, mlp_hidden=[channels, channels])
            )
        self.pool = GlobalAvgPool()
        self.dense1 = Dense(channels, activation="relu")
        self.dropout = Dropout(0.5)
        self.dense2 = Dense(channels, activation="relu")

    def call(self, inputs):
        x, a, i = inputs
        x = self.conv1([x, a])
        for conv in self.convs:
            x = conv([x, a])
        x = self.pool([x, i])
        x = self.dense1(x)
        x = self.dropout(x)
        return self.dense2(x)
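To make the mixed shapes concrete: in Spektral's disjoint mode (the layout its DisjointLoader produces), a batch of graphs is merged into one large graph, so the node dimension n_nodes changes from batch to batch. The sketch below reproduces that layout with dense NumPy arrays only for illustration; make_disjoint_batch is a hypothetical helper, not part of Spektral (which, as far as I know, passes the adjacency as a sparse tensor instead).

```python
import numpy as np


def make_disjoint_batch(graphs):
    """Hypothetical helper: merge (features, adjacency) pairs into one disjoint batch.

    Returns x with shape [n_nodes, n_node_features], a block-diagonal
    adjacency matrix with shape [n_nodes, n_nodes], and a graph index
    vector i with shape [n_nodes] mapping each node to its graph.
    """
    x = np.concatenate([feats for feats, _ in graphs], axis=0)
    n_nodes = x.shape[0]
    a = np.zeros((n_nodes, n_nodes), dtype=np.float32)
    i = np.empty(n_nodes, dtype=np.int64)
    offset = 0
    for graph_id, (feats, adj) in enumerate(graphs):
        n = feats.shape[0]
        # Place each graph's adjacency on the diagonal so graphs stay disconnected.
        a[offset:offset + n, offset:offset + n] = adj
        # Every node of this graph gets the same graph index.
        i[offset:offset + n] = graph_id
        offset += n
    return x, a, i


# Two toy graphs with 3 and 2 nodes, each node carrying 4 features.
g1 = (np.ones((3, 4)), np.array([[0, 1, 0], [1, 0, 1], [0, 1, 0]]))
g2 = (np.ones((2, 4)), np.array([[0, 1], [1, 0]]))
x, a, i = make_disjoint_batch([g1, g2])
print(x.shape, a.shape, i.shape)  # (5, 4) (5, 5) (5,)
print(i)  # [0 0 0 1 1]
```

The three arrays have three different ranks, which is why a single `Input(shape=dim)` cannot describe them.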

One of the answers to (2) suggested adding a build_graph method, as follows:

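from tensorflow.keras import Input, Model
from tensorflow.keras import layers as L
from tensorflow.keras.applications import VGG16

class my_model(Model):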
    def __init__(self, dim):
        super(my_model, self).__init__()
        self.Base  = VGG16(input_shape=(dim), include_top = False, weights = 'imagenet')
        self.GAP   = L.GlobalAveragePooling2D()
        self.BAT   = L.BatchNormalization()
        self.DROP  = L.Dropout(rate=0.1)
        self.DENS  = L.Dense(256, activation='relu', name = 'dense_A')
        self.OUT   = L.Dense(1, activation='sigmoid')
    
    def call(self, inputs):
        x  = self.Base(inputs)
        g  = self.GAP(x)
        b  = self.BAT(g)
        d  = self.DROP(b)
        d  = self.DENS(d)
        return self.OUT(d)
    
    # AFAIK, this is the most convenient way to get a model.summary()
    # like the one produced by the Sequential or Functional API.
    def build_graph(self):
        x = Input(shape=(dim))
        return Model(inputs=[x], outputs=self.call(x))

dim = (124,124,3)
model = my_model((dim))
model.build((None, *dim))
model.build_graph().summary()

However, I am not sure how to define dim, or the Input layers created with tf.keras.layers.Input, for such a hybrid data structure as described above.

Any suggestions?


Solution

  • Here is minimal code to plot such a subclassed multi-input model. Note that, as stated in the comment above, there are some issues with your GINConv, which comes from Spektral and is not related to the main question. So I will give a general solution for such multi-input modeling scenarios. To make it work with Spektral, please reach out to the package author for further discussion.


    From the Spektral repo (here), I got an idea of the shapes of the input tensors.

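    # loader_tr: the training DisjointLoader from the Spektral GIN example
    x, y = next(iter(loader_tr))
    
    bs_x = list(x[0].shape)  # node features: [n_nodes, n_node_features]
    bs_y = list(x[1].shape)  # adjacency matrix: [n_nodes, n_nodes]
    bs_z = list(x[2].shape)  # graph index of each node: [n_nodes]
    
    bs_x, bs_y, bs_z
    # ([1067, 4], [1067, 1067], [1067]); n_nodes = 1067 in this batch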
    

    Here is a similar model: it takes the same number of inputs, with the same shapes, but without GINConv.

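    import tensorflow as tf
    from tensorflow.keras import Model
    from tensorflow.keras.layers import Dense, Dropout

    class GIN0(Model):
        def __init__(self, channels, n_layers):
            super().__init__()
            # n_layers is kept for signature compatibility; this toy model ignores it.
            self.conv1 = tf.keras.layers.Conv1D(channels, 3, activation='relu')
            self.conv2 = tf.keras.layers.Conv1D(channels, 3, activation='relu')
            # Create the stateless layers once here instead of inside call()
            self.pool1 = tf.keras.layers.GlobalAveragePooling1D()
            self.pool2 = tf.keras.layers.GlobalAveragePooling1D()
            self.concat = tf.keras.layers.Concatenate(axis=1)
    
            self.dense1 = Dense(channels, activation="relu")
            self.dropout = Dropout(0.5)
            # n_out (number of target classes) is assumed to be defined globally,
            # e.g. from the dataset's labels as in the Spektral example
            self.dense2 = Dense(n_out, activation="softmax")
    
        def call(self, inputs):
            x, a, i = inputs
    
            x = self.conv1(x)
            x = self.pool1(x)  # (batch, channels)
            a = self.conv2(a)
            a = self.pool2(a)  # (batch, channels)
    
            x = self.concat([a, x, i])
            x = self.dense1(x)
            x = self.dropout(x)
            return self.dense2(x)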
        
        def build_graph(self):
            x = tf.keras.Input(shape=bs_x)
            y = tf.keras.Input(shape=bs_y)
            z = tf.keras.Input(shape=bs_z)
            return tf.keras.Model(
                inputs=[x, y, z], 
                outputs=self.call([x, y, z])
            )
    
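    # channels and layers are assumed to be defined earlier, e.g. as the
    # hidden units and number of GIN layers in the Spektral example
    model = GIN0(channels, layers)
    model.build(
        [
            (None, *bs_x), 
            (None, *bs_y), 
            (None, *bs_z)
        ]
    )
    
    # OK: prints the layer-by-layer summary
    model.build_graph().summary()
    
    # OK: requires the pydot and graphviz packages to be installed
    tf.keras.utils.plot_model(
        model.build_graph(), show_shapes=True
    )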
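The build_graph idea can also be pulled out of the class into a reusable function that takes the input shapes as an argument instead of reading globals such as bs_x. This is a minimal sketch, not a Keras or Spektral API: TwoInputNet and functional_view are hypothetical names, and the toy model uses only standard tf.keras layers with two inputs of different ranks. Calling model.call(...) rather than model(...) on the symbolic Inputs is what exposes the inner layers individually in summary() and plot_model; since the wrapper reuses the same layer objects, it shares weights with the original model.

```python
import numpy as np
import tensorflow as tf


class TwoInputNet(tf.keras.Model):
    """Hypothetical toy subclassed model with a rank-2 and a rank-1 input."""

    def __init__(self, units):
        super().__init__()
        self.flatten = tf.keras.layers.Flatten()
        self.concat = tf.keras.layers.Concatenate(axis=-1)
        self.hidden = tf.keras.layers.Dense(units, activation="relu")
        self.head = tf.keras.layers.Dense(1)

    def call(self, inputs):
        matrix, vector = inputs
        h = self.concat([self.flatten(matrix), vector])
        return self.head(self.hidden(h))


def functional_view(model, input_shapes):
    """Wrap a subclassed model's call() in a functional Model (hypothetical helper)."""
    # One symbolic Input per tensor; each shape excludes the batch dimension.
    inputs = [tf.keras.Input(shape=shape) for shape in input_shapes]
    # model.call (not model(...)) so the inner layers appear as separate nodes.
    outputs = model.call(inputs)
    return tf.keras.Model(inputs=inputs, outputs=outputs)


model = TwoInputNet(units=8)
view = functional_view(model, [(5, 5), (3,)])
view.summary()

# The wrapper reuses the same layer objects, so both give the same predictions.
batch = [np.random.rand(2, 5, 5).astype("float32"),
         np.random.rand(2, 3).astype("float32")]
print(np.allclose(np.asarray(view(batch)), np.asarray(model(batch))))
```

Passing view to tf.keras.utils.plot_model(view, show_shapes=True) should then draw one box per inner layer, again provided pydot and graphviz are installed.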