keras deep-learning keras-layer tensorflow2.0 tf.keras

Running the TensorFlow 2.0 code gives 'ValueError: tf.function-decorated function tried to create variables on non-first call'. What am I doing wrong?


error_giving_notebook

non_problematic_notebook

As can be seen, I used the tf.function decorator in 'error_giving_notebook' and it throws a ValueError, while the same notebook with nothing changed except the tf.function decorator removed runs smoothly in 'non_problematic_notebook'. What could be the reason?


Solution

  • The problem here is in the return values of the call method of class conv2d:

    if self.bias:
      if self.pad == 'REFLECT':
        self.p = (self.filter_size - 1) // 2
        self.x = tf.pad(inputs, [[0, 0], [self.p, self.p], [self.p, self.p], [0, 0]], 'REFLECT')
        return Conv2D(filters=self.filter_num, kernel_size=(self.filter_size, self.filter_size), strides=(self.stride, self.stride),
                                      padding='VALID', use_bias=True, kernel_initializer=self.w, bias_initializer=self.b)(self.x)
      else:
        return Conv2D(filters=self.filter_num, kernel_size=(self.filter_size, self.filter_size), strides=(self.stride, self.stride),
                                      padding=self.pad, use_bias=True, kernel_initializer=self.w, bias_initializer=self.b)(inputs)
    else:
      if self.pad == 'REFLECT':
        self.p = (self.filter_size - 1) // 2
        self.x = tf.pad(inputs, [[0, 0], [self.p, self.p], [self.p, self.p], [0, 0]], 'REFLECT')
        return Conv2D(filters=self.filter_num, kernel_size=(self.filter_size, self.filter_size), strides=(self.stride, self.stride),
                                      padding='VALID', use_bias=False, kernel_initializer=self.w)(self.x)
      else:
        return Conv2D(filters=self.filter_num, kernel_size=(self.filter_size, self.filter_size), strides=(self.stride, self.stride),
                                      padding=self.pad, use_bias=False, kernel_initializer=self.w)(inputs)
    

    Because call constructs and applies a brand-new Conv2D layer, new tf.Variable objects (the conv layer's weights and bias) are created each time you call

    predictions = model(images)
    

    in your tf.function-decorated function. tf.function only allows variables to be created the first time the function is called, hence the exception.
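
The behaviour can be reproduced without Keras at all. The sketch below is only an illustration, not code from the notebooks: it uses a bare tf.Variable in place of the Conv2D layer (a layer's weights are tf.Variables underneath), and the names creates_variable_every_call and Scale are made up for this example. The first function should raise a ValueError under tf.function, while the module that creates its variable once and reuses it can be called repeatedly.

```python
import tensorflow as tf


@tf.function
def creates_variable_every_call(x):
    # A new tf.Variable is constructed every time the function body runs,
    # much like constructing a new Conv2D layer inside call() would.
    v = tf.Variable(1.0)
    return v * x


class Scale(tf.Module):
    """Hypothetical module that creates its variable once, as state."""

    def __init__(self):
        super().__init__()
        self.v = tf.Variable(2.0)

    @tf.function
    def __call__(self, x):
        # Only reads the existing variable; nothing new is created here.
        return self.v * x


try:
    creates_variable_every_call(tf.constant(3.0))
    raised = False
except ValueError:
    raised = True

scale = Scale()
first = float(scale(tf.constant(3.0)))
second = float(scale(tf.constant(4.0)))
```

The rule of thumb this suggests: create state (layers, variables) once, outside the code path that runs on every call, and only use it inside the decorated function.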

    One possible way to solve this problem is to change the build and call methods in your conv2d class as follows:

    def build(self, input_shape):
      # Keras passes the input shape (not the input tensor) to build.
      self.w = tf.random_normal_initializer(mean=0.0, stddev=1e-4)
      if self.bias:
        self.b = tf.constant_initializer(0.0)
      else:
        self.b = None
    
      # Create every Conv2D variant once here; call() only picks one of them to apply.
      self.conv_a = Conv2D(filters=self.filter_num, kernel_size=(self.filter_size, self.filter_size), strides=(self.stride, self.stride), padding='VALID', use_bias=True, kernel_initializer=self.w, bias_initializer=self.b)
      self.conv_b = Conv2D(filters=self.filter_num, kernel_size=(self.filter_size, self.filter_size), strides=(self.stride, self.stride), padding=self.pad, use_bias=True, kernel_initializer=self.w, bias_initializer=self.b)
      self.conv_c = Conv2D(filters=self.filter_num, kernel_size=(self.filter_size, self.filter_size), strides=(self.stride, self.stride), padding='VALID', use_bias=False, kernel_initializer=self.w)
      self.conv_d = Conv2D(filters=self.filter_num, kernel_size=(self.filter_size, self.filter_size), strides=(self.stride, self.stride), padding=self.pad, use_bias=False, kernel_initializer=self.w)
    
    def call(self, inputs):
      if self.bias:
        if self.pad == 'REFLECT':
          self.p = (self.filter_size - 1) // 2
          self.x = tf.pad(inputs, [[0, 0], [self.p, self.p], [self.p, self.p], [0, 0]], 'REFLECT')
          return self.conv_a(self.x)
        else:
          return self.conv_b(inputs)
      else:
        if self.pad == 'REFLECT':
          self.p = (self.filter_size - 1) // 2
          self.x = tf.pad(inputs, [[0, 0], [self.p, self.p], [self.p, self.p], [0, 0]], 'REFLECT')
          return self.conv_c(self.x)
        else:
          return self.conv_d(inputs)
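
As a possible simplification of the same idea, the four candidate layers can be collapsed into a single Conv2D whose padding and use_bias are decided once. This is a sketch under assumptions, not the code from the notebooks: the class name ReflectConv2D, its constructor signature, and the forward helper are hypothetical, and tf.keras.initializers is used in place of the tf.*_initializer objects above.

```python
import tensorflow as tf


class ReflectConv2D(tf.keras.layers.Layer):
    """Hypothetical variant of the conv2d class with one reusable Conv2D."""

    def __init__(self, filter_num, filter_size, stride=1, pad='SAME', bias=True):
        super().__init__()
        self.filter_size = filter_size
        self.pad = pad
        # The sub-layer is created exactly once; its weights are built on
        # first use and reused on every later call.
        self.conv = tf.keras.layers.Conv2D(
            filters=filter_num,
            kernel_size=(filter_size, filter_size),
            strides=(stride, stride),
            # Reflect padding is applied manually in call(), so the
            # convolution itself must not pad again.
            padding='VALID' if pad == 'REFLECT' else pad,
            use_bias=bias,
            kernel_initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=1e-4),
            bias_initializer=tf.keras.initializers.Constant(0.0),
        )

    def call(self, inputs):
        if self.pad == 'REFLECT':
            p = (self.filter_size - 1) // 2
            # Local names instead of self.p / self.x: nothing is stored on
            # the layer during the forward pass.
            inputs = tf.pad(inputs, [[0, 0], [p, p], [p, p], [0, 0]], 'REFLECT')
        return self.conv(inputs)


layer = ReflectConv2D(filter_num=8, filter_size=3, pad='REFLECT')


@tf.function
def forward(images):
    return layer(images)


out1 = forward(tf.zeros([2, 16, 16, 3]))
out2 = forward(tf.zeros([2, 16, 16, 3]))  # second call reuses the same weights
num_weights = len(layer.trainable_variables)
```

Keeping the padded tensor in a local name rather than on self also avoids holding on to graph tensors between calls.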
    

    To better understand AutoGraph and how @tf.function works I suggest taking a look at this