Tags: python, tensorflow, class, oop, tf.keras

Aren't instances in __init__() reusable?


So I'm currently learning how to create a model in TensorFlow using subclassing. According to the tutorial, the following snippet of code should run perfectly:

from tensorflow.keras import Model, Input
from tensorflow.keras.layers import Conv2D, BatchNormalization, MaxPool2D

#Defining the class
class FeatureExtractor(Model):
    def __init__(self):
        super().__init__()

        self.conv_1 = Conv2D(filters = 6, kernel_size = 4, padding = "valid", activation = "relu")
        self.batchnorm_1 = BatchNormalization()
        self.maxpool_1 = MaxPool2D(pool_size = 2, strides=2)

        self.conv_2 = Conv2D(filters = 16, kernel_size = 4, padding = "valid", activation = "relu")
        self.batchnorm_2 = BatchNormalization()
        self.maxpool_2 = MaxPool2D(pool_size = 2, strides=2)


    def call(self, x):
        x = self.conv_1(x)
        x = self.batchnorm_1(x)
        x = self.maxpool_1(x)

        x = self.conv_2(x)
        x = self.batchnorm_2(x)
        x = self.maxpool_2(x)

        return x

#Calling and using the class
feature_extractor = FeatureExtractor()

func_input = Input(shape=(IMG_SIZE, IMG_SIZE, 3), name="Input_Image")  # IMG_SIZE is defined earlier in the tutorial

x = feature_extractor(func_input)

And it does indeed run flawlessly. But then I noticed that in __init__(), the BatchNormalization() and MaxPool2D() layers look identical but are defined twice, so I edited the class to define each of them only once:

#Defining the class
class FeatureExtractor(Model):
    def __init__(self):
        super().__init__()

        self.conv_1 = Conv2D(filters = 6, kernel_size = 4, padding = "valid", activation = "relu")
        #Defining batchnorm and maxpool only once
        self.batchnorm = BatchNormalization()
        self.maxpool = MaxPool2D(pool_size = 2, strides=2)

        self.conv_2 = Conv2D(filters = 16, kernel_size = 4, padding = "valid", activation = "relu")


    def call(self, x):
        x = self.conv_1(x)
        x = self.batchnorm(x)
        x = self.maxpool(x)

        x = self.conv_2(x)
        x = self.batchnorm(x)
        x = self.maxpool(x)

        return x

But this time I got a dimension error.

I thought instances in __init__() were reusable? Is it because when layers are called in call(), they adapt to the dimensions of the input and then keep those dimensions for later calls?

Thank you in advance for your answers. I'm still inexperienced in both Python and TensorFlow, so this might be something basic that I overlooked while learning.


Solution

  • I thought instances in __init__() were reusable?

    In general, objects are reusable, but, as tdelaney pointed out in the comments, you need to make sure that the object's state management doesn't get in your way. This holds everywhere, not just in __init__().
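
    Here's a tiny plain-Python sketch of that general point (a made-up Counter class, nothing Keras-specific): two names bound to one stateful object share that object's state.

    class Counter:
        def __init__(self):
            self.count = 0

        def __call__(self):
            self.count += 1
            return self.count

    shared = Counter()
    a = shared  # both names refer to the same object,
    b = shared  # so calls through one are visible through the other
    a()
    a()
    print(b())  # prints 3, not 1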

    To demonstrate that reuse works in principle, let's define MaxPool2D only once and reuse it, but define BatchNormalization twice. See here:

    from keras import Model, Input
    from keras.layers import Conv2D, BatchNormalization, MaxPool2D
    
    
    class FeatureExtractor(Model):
        def __init__(self):
            super().__init__()
    
            # Defining maxpool only once
            self.maxpool = MaxPool2D(pool_size=2, strides=2)
    
            self.conv_1 = Conv2D(filters=6, kernel_size=4, padding="valid", activation="relu")
            self.batchnorm1 = BatchNormalization()
    
            self.conv_2 = Conv2D(filters=16, kernel_size=4, padding="valid", activation="relu")
            self.batchnorm2 = BatchNormalization()
    
        def call(self, x):
            x = self.conv_1(x)
            x = self.batchnorm1(x)
            x = self.maxpool(x)
    
            x = self.conv_2(x)
            x = self.batchnorm2(x)
            x = self.maxpool(x)
    
            return x
    
    
    if __name__ == '__main__':
        # Calling and using the class
        feature_extractor = FeatureExtractor()
        func_input = Input(shape=(28, 28, 3), name="Input_Image")
        x = feature_extractor(func_input)
    
    

    This version reuses the maxpool member and works fine. That's because MaxPool2D has no weights: it simply downsamples whatever tensor it receives, so nothing about its state depends on the input shape.
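
    As a quick sanity check (the shapes here are arbitrary, my own example), one MaxPool2D instance accepts inputs with different channel counts, precisely because pooling holds no weights tied to a shape:

    import numpy as np
    from keras.layers import MaxPool2D

    pool = MaxPool2D(pool_size=2, strides=2)
    # The same instance handles 6-channel and 16-channel inputs.
    print(pool(np.zeros((1, 24, 24, 6))).shape)   # (1, 12, 12, 6)
    print(pool(np.zeros((1, 12, 12, 16))).shape)  # (1, 6, 6, 16)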

    Is it because when layers are called in call(), they adapt to the dimensions of the input and then keep those dimensions for later calls?

    This seems to be exactly what's happening here. Keras layers build their weights the first time they are called: BatchNormalization creates gamma, beta, and its moving statistics with one entry per input channel. In your edited class, the shared instance gets built for the 6 channels coming out of conv_1, so the second call on the 16-channel output of conv_2 no longer matches and raises the dimension error.
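
    To see the failing case in isolation, here is a minimal sketch (again with arbitrary shapes): the first call builds BatchNormalization's weights for a fixed channel count, and a later call with a different channel count fails.

    import numpy as np
    from keras.layers import BatchNormalization

    bn = BatchNormalization()
    bn(np.zeros((1, 24, 24, 6)))  # first call builds the layer for 6 channels
    print(bn.gamma.shape)         # (6,)

    try:
        bn(np.zeros((1, 12, 12, 16)))  # 16 channels, but the layer was built for 6
    except Exception as e:
        print(e)  # incompatible-shape error

    This is also why the working versions keep one BatchNormalization per conv block: each instance gets built for the channel count it actually sees.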