I would like to create a custom Keras layer (a codebook for a VQ-VAE model). While training I would like to have a tf.Variable which tracks the usage of each code, so I can restart unused codes. So I created my Codebook layer as follows...
import tensorflow as tf
from tensorflow.keras import layers

class Codebook(layers.Layer):
    def __init__(self, num_codes, code_reset_limit=None, **kwargs):
        super().__init__(**kwargs)
        self.num_codes = num_codes
        self.code_reset_limit = code_reset_limit
        if self.code_reset_limit:
            self.code_counter = tf.Variable(tf.zeros(num_codes, dtype=tf.int32), trainable=False)

    def build(self, input_shape):
        self.codes = self.add_weight(name='codes',
                                     shape=(self.num_codes, input_shape[-1]),
                                     initializer='random_uniform',
                                     trainable=True)
        super().build(input_shape)
The issue I have is that the Layer class finds the member variable self.code_counter and adds it to the list of weights that are saved with the layer. It also expects self.code_counter to be present when weights are loaded, which is not the case when I run in inference mode. How can I make it so Keras does not track a variable in my layer? I do not want it persisted or to be part of layer.weights.
I am a bit late with the answer, but I had the same problem and came across this question, which had no answer. I have since found an answer that works for both Keras 2 and Keras 3, so I am sharing it here for others encountering the same question.
To prevent TensorFlow and Keras from tracking a variable, one needs to encapsulate the variable in a class that TensorFlow and Keras do not handle in their tracking module. The classes that are automatically tracked in Keras 3 are keras.Variable, list, dict, tuple, and NamedTuple (see here). For Keras 2 the list of tracked types is not as easy to find, but it appears to include tf.Variable (see the present question), dict, and list.
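Why this works can be sketched in plain Python, with no TensorFlow needed (the list below merely stands in for a variable): auto-tracking is driven by isinstance-style checks against the types above, and a dataclass instance passes none of them.

```python
from dataclasses import dataclass

@dataclass
class DoNotTrackContainer:
    data: object  # the annotation is documentation only; it is not enforced at runtime

# Keras's auto-tracking only descends into attribute types it recognizes
# (variables, lists, dicts, tuples/NamedTuples). A plain dataclass instance
# matches none of these, so whatever it holds stays out of layer.weights.
tracked_types = (list, dict, tuple)  # plus keras.Variable inside Keras itself

wrapped = DoNotTrackContainer(data=[0, 0, 0])  # a list standing in for a variable
assert not isinstance(wrapped, tracked_types)  # the wrapper itself is not tracked
wrapped.data[0] += 1                           # but its contents stay fully accessible
```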
The solution that worked in my context for keras.Variable and tf.Variable is to create a dataclass encapsulating the variable. Here is the setup for TensorFlow and Keras 2:
import tensorflow as tf
from dataclasses import dataclass

@dataclass
class DoNotTrackContainer:
    data: tf.Variable
In the code of the present question, this would then be used like this:

if self.code_reset_limit:
    self.code_counter = DoNotTrackContainer(data=tf.Variable(tf.zeros(num_codes, dtype=tf.int32), trainable=False))
When accessing the counter, the data attribute needs to be included in the path. Note that assign_add requires a value with the same shape as the variable, so incrementing every count looks like this:

# for accessing the counter
self.code_counter.data.assign_add(tf.ones_like(self.code_counter.data))
For Keras 3 the container becomes:

import keras
from dataclasses import dataclass

@dataclass
class DoNotTrackContainer:
    data: keras.Variable
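As a usage sketch tying this back to the original goal of restarting unused codes: reset_unused_codes below is a hypothetical helper (not from the question), and NumPy arrays stand in for the codebook weights and the wrapped counter variable.

```python
import numpy as np
from dataclasses import dataclass

@dataclass
class DoNotTrackContainer:
    data: np.ndarray  # stand-in for the wrapped tf.Variable / keras.Variable

def reset_unused_codes(codes, counter, limit, rng):
    """Re-initialize codes used fewer than `limit` times, then clear the counter."""
    unused = counter.data < limit
    codes[unused] = rng.uniform(-0.05, 0.05, size=(unused.sum(), codes.shape[1]))
    counter.data[:] = 0
    return int(unused.sum())

rng = np.random.default_rng(0)
codes = np.zeros((4, 2))                               # 4 codes of dimension 2
counter = DoNotTrackContainer(data=np.array([5, 0, 3, 1]))
n_reset = reset_unused_codes(codes, counter, limit=2, rng=rng)
# codes 1 and 3 were below the usage limit, so n_reset == 2
```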