I already asked this question here, but I thought StackOverflow would have more traffic and more people who might know the answer.
I'm building a custom Keras layer similar to an example found here. I want the call method inside the class to be able to know the batch_size of the inputs data flowing through it, but inputs.shape shows as (None, 3) during model prediction. Here's a concrete example:
I initialize a simple data set like this:
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras import layers, Model
# Create fake data to use for model testing
n = 1000
np.random.seed(123)
x1 = np.random.random(n)
x2 = np.random.normal(0, 1, size=n)
x3 = np.random.lognormal(0, 1, size=n)
X = pd.DataFrame(np.concatenate([
    np.reshape(x1, (-1, 1)),
    np.reshape(x2, (-1, 1)),
    np.reshape(x3, (-1, 1)),
], axis=1))
Then I define a custom class to test/show what I'm talking about:
class TestClass(tf.keras.layers.Layer):
    def __init__(self, **kwargs):
        super(TestClass, self).__init__(**kwargs)
    def get_config(self):
        config = super(TestClass, self).get_config()
        return config
    def call(self, inputs: tf.Tensor):
        if inputs.dtype.base_dtype != self._compute_dtype_object.base_dtype:
            inputs = tf.cast(inputs, dtype=self._compute_dtype_object)
        print(inputs)
        record_count, n = inputs.shape
        print(f'inputs.shape = {inputs.shape}')
        return inputs
Then, when I create a simple model and force it to do a forward pass...
input_layer = layers.Input(3)
test = TestClass()(input_layer)
optimizer = tf.keras.optimizers.Adam(learning_rate=0.00025)
model = Model(input_layer, test)
model.compile(loss='mse', optimizer=optimizer, metrics=['mae', 'mse'])
model.predict(X.loc[:9, :])
... I get this output printed to the screen:
model.predict(X.loc[:9, :])
Tensor("model_1/Cast:0", shape=(None, 3), dtype=float32)
inputs.shape = (None, 3)
1/1 [==============================] - 0s 28ms/step
Out[34]:
array([[ 0.5335418 , 0.7788839 , 0.64132416],
[ 0.2924202 , -0.08321562, 0.412311 ],
[ 0.5118007 , -0.6822934 , 1.1782378 ],
[ 0.03780456, -0.19350041, 0.7637337 ],
[ 0.86494124, -3.196387 , 4.8535166 ],
[ 0.26708454, -0.49397194, 0.91296834],
[ 0.49734482, -1.6618049 , 0.50054324],
[ 0.8563762 , 0.7956695 , 0.29466265],
[ 0.7682351 , 0.86538637, 0.6633331 ],
[ 0.85322225, 0.868021 , 0.1776046 ]], dtype=float32)
You can see that during the model.predict call, inputs.shape prints a value of (None, 3), but that clearly isn't the shape of the data at run time, since the call method returns an output with a shape of (10, 3). How can I capture the value 10 in this example while inside the call method?
When I use tf.shape, as suggested in the current answer, I can print the value to the screen, but I get an error when I try to capture that value in a variable.
class TestClass(tf.keras.layers.Layer):
    def __init__(self, **kwargs):
        super(TestClass, self).__init__(**kwargs)
    def get_config(self):
        config = super(TestClass, self).get_config()
        return config
    def call(self, inputs: tf.Tensor):
        if inputs.dtype.base_dtype != self._compute_dtype_object.base_dtype:
            inputs = tf.cast(inputs, dtype=self._compute_dtype_object)
        record_count, n = tf.shape(inputs)  # this line raises the error below
        tf.print("Dynamic batch size", tf.shape(inputs)[0])
        return inputs
This code causes an error on the record_count, ... line:
Traceback (most recent call last):
File "/Users/username/opt/miniconda3/envs/myenv/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3378, in run_code
exec(code_obj, self.user_global_ns, self.user_ns)
File "<ipython-input-22-104d812c32e6>", line 1, in <module>
test = TestClass()(input_layer)
File "/Users/username/opt/miniconda3/envs/myenv/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 70, in error_handler
raise e.with_traceback(filtered_tb) from None
File "/Users/username/opt/miniconda3/envs/myenv/lib/python3.8/site-packages/tensorflow/python/autograph/impl/api.py", line 692, in wrapper
raise e.ag_error_metadata.to_exception(e)
tensorflow.python.framework.errors_impl.OperatorNotAllowedInGraphError: Exception encountered when calling layer "test_class_4" (type TestClass).
in user code:
File "<ipython-input-21-2dec1d5b9547>", line 12, in call *
record_count, n = tf.shape(inputs)
OperatorNotAllowedInGraphError: Iterating over a symbolic `tf.Tensor` is not allowed in Graph execution. Use Eager execution or decorate this function with @tf.function.
Call arguments received by layer "test_class_4" (type TestClass):
• inputs=tf.Tensor(shape=(None, 3), dtype=float32)
I tried decorating the call method with @tf.function, but I get the same error.
I tried a couple of other things and found that, oddly, TensorFlow doesn't seem to like the tuple assignment. It works fine if I write it like this instead:
class TestClass(tf.keras.layers.Layer):
    def __init__(self, **kwargs):
        super(TestClass, self).__init__(**kwargs)
    def get_config(self):
        config = super(TestClass, self).get_config()
        return config
    def call(self, inputs: tf.Tensor):
        if inputs.dtype.base_dtype != self._compute_dtype_object.base_dtype:
            inputs = tf.cast(inputs, dtype=self._compute_dtype_object)
        shape = tf.shape(inputs)
        record_count = shape[0]
        n = shape[1]
        tf.print("Dynamic batch size", tf.shape(inputs)[0])
        return inputs
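Why the tuple assignment fails: record_count, n = tf.shape(inputs) makes Python iterate over the tensor to unpack it, and the traceback states that iterating over a symbolic tf.Tensor is not allowed in graph execution. Adding @tf.function does not change that, because the code is already being traced into a graph. Indexing the shape tensor, as in the working version above, avoids iteration. Another option is tf.unstack, which splits a tensor into a Python list along axis 0 when that axis has a statically known length (here 2, because inputs has rank 2). The sketch below assumes TensorFlow 2.x, and unpack_shape is a hypothetical name. It uses a plain tf.function with an unknown batch dimension in place of a Keras layer.

```python
import tensorflow as tf

# Hypothetical helper: the batch dimension is left unknown (None), as in a Keras graph.
@tf.function(input_signature=[tf.TensorSpec(shape=[None, 3], dtype=tf.float32)])
def unpack_shape(inputs):
    shape = tf.shape(inputs)             # 1-D int32 tensor with static length 2
    record_count, n = tf.unstack(shape)  # returns a Python list of 2 scalar tensors; no tensor iteration
    return record_count, n

record_count, n = unpack_shape(tf.zeros([10, 3]))
print(int(record_count), int(n))  # 10 3
```

Inside a layer, the same line would read record_count, n = tf.unstack(tf.shape(inputs)). Both values are tensors, so use them in TensorFlow ops rather than in Python control flow.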
TL;DR --> Use tf.shape(inputs)[0] if you want to capture the dynamic batch size in the call method, or use a static batch size, which can be specified when creating the model.
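For the static option, here is a minimal sketch assuming TensorFlow 2.x Keras. The batch_size argument of tf.keras.Input fixes the first dimension, so the static shape that call sees while the functional model is built is (10, 3) rather than (None, 3). ShapeRecorder and seen_shapes are hypothetical names used only for illustration. The example checks only the build-time shape. Whether the same fixed dimension is visible during model.predict depends on how Keras batches the input data, so tf.shape remains the more general choice.

```python
import tensorflow as tf

seen_shapes = []  # hypothetical list: records the static shape each time call() is traced

class ShapeRecorder(tf.keras.layers.Layer):
    """Hypothetical pass-through layer that records the static input shape."""
    def call(self, inputs):
        seen_shapes.append(tuple(inputs.shape))  # Python side effect, runs at trace time
        return inputs

# batch_size fixes the first dimension, so it is no longer None in the static shape
inp = tf.keras.Input(shape=(3,), batch_size=10)
out = ShapeRecorder()(inp)
model = tf.keras.Model(inp, out)

print(seen_shapes[0])  # (10, 3)
```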
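Under the hood, Keras does not run call as ordinary eager Python. The call method (which __call__ invokes) is converted with AutoGraph and traced into a graph, much as if it were wrapped in tf.function. For example, model.predict runs its prediction step inside a tf.function. As a result, print and .shape will not work as you might expect.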
With tf.function, Python code is traced and converted into native TensorFlow operations. The result is a static graph (an instance of tf.Graph), and the operations are then executed in that graph.
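Python's print function runs only while the function is being traced (typically on the first call, or whenever it is retraced), not each time the graph executes. That makes it the wrong way to print values in graph mode (code decorated with tf.function). tf.print, by contrast, becomes an op in the graph and runs on every call.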
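To make the trace-time versus run-time distinction concrete, here is a small sketch assuming TensorFlow 2.x. It uses a plain tf.function with an input_signature whose batch dimension is None, which stands in for the graph Keras builds. The names forward and trace_log are hypothetical. The Python side effects (print and the list append) run once, during tracing, and see the static shape (None, 3). tf.print runs on each call and reports the actual batch size.

```python
import tensorflow as tf

trace_log = []  # hypothetical list: Python appends to it only while tracing

@tf.function(input_signature=[tf.TensorSpec(shape=[None, 3], dtype=tf.float32)])
def forward(inputs):
    trace_log.append(tuple(inputs.shape))                  # trace time only
    print("Python print, static shape:", inputs.shape)     # printed once: (None, 3)
    tf.print("tf.print, batch size:", tf.shape(inputs)[0])  # printed on every call
    return inputs

forward(tf.zeros([10, 3]))  # traces, then runs: tf.print shows 10
forward(tf.zeros([4, 3]))   # reuses the same graph: tf.print shows 4, no Python print
print(trace_log)            # [(None, 3)]
```

The input_signature is what keeps this to a single trace. Without it, tf.function may retrace for new input shapes, and the Python print would appear again.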
The batch dimension is only known at run time, so the static shape reports it as None. Use tf.shape(inputs)[0], which evaluates to the actual size of each batch.
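The dynamic batch size can also be used inside the computation itself, not just printed. The following is a sketch assuming TensorFlow 2.x. AppendRowIndex is a hypothetical layer that appends each row's position within its batch as an extra column, and run is a hypothetical wrapper that traces the layer with an unknown batch dimension, similar to the (None, 3) shape seen in the question.

```python
import tensorflow as tf

class AppendRowIndex(tf.keras.layers.Layer):
    """Hypothetical layer: appends each row's index within its batch as a new column."""
    def call(self, inputs):
        batch_size = tf.shape(inputs)[0]                          # scalar int32 tensor, resolved per batch
        row_index = tf.cast(tf.range(batch_size), inputs.dtype)   # 0, 1, ..., batch_size - 1
        return tf.concat([inputs, tf.expand_dims(row_index, axis=1)], axis=1)

layer = AppendRowIndex()

# Trace with an unknown batch dimension so the static shape is (None, 3)
@tf.function(input_signature=[tf.TensorSpec(shape=[None, 3], dtype=tf.float32)])
def run(x):
    return layer(x)

y = run(tf.zeros([5, 3]))
print(y.shape)  # (5, 4)
```

Because batch_size is a tensor, tf.range and tf.concat handle any batch size without retracing. A Python int taken from inputs.shape[0] would be None here.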
If you really want to see that 10 inside call:
class TestClass(tf.keras.layers.Layer):
    def __init__(self, **kwargs):
        super(TestClass, self).__init__(**kwargs)
    def get_config(self):
        config = super(TestClass, self).get_config()
        return config
    def call(self, inputs: tf.Tensor):
        if inputs.dtype.base_dtype != self._compute_dtype_object.base_dtype:
            inputs = tf.cast(inputs, dtype=self._compute_dtype_object)
        tf.print("Dynamic batch size", tf.shape(inputs)[0])
        return inputs
Running:
input_layer = layers.Input(3)
test = TestClass()(input_layer)
optimizer = tf.keras.optimizers.Adam(learning_rate=0.00025)
model = Model(input_layer, test)
model.compile(loss='mse', optimizer=optimizer, metrics=['mae', 'mse'])
model.predict(X.loc[:9, :])
Will return:
Dynamic batch size 10
1/1 [==============================] - 0s 65ms/step
array([[ 6.9646919e-01, -1.0032653e-02, 3.7556963e+00],
[ 2.8613934e-01, -8.4564441e-01, 9.9685013e-01],
[ 2.2685145e-01, 9.1146064e-01, 6.5008003e-01],
[ 5.5131477e-01, -1.3744969e+00, 8.6379850e-01],
[ 7.1946895e-01, -5.4706562e-01, 3.1904945e+00],
[ 4.2310646e-01, -7.5526608e-05, 5.2649558e-01],
[ 9.8076421e-01, -1.2116680e-01, 7.4064606e-01],
[ 6.8482971e-01, -2.0085855e+00, 5.3138912e-01],
[ 4.8093191e-01, -9.2064655e-01, 8.1520426e-01],
[ 3.9211753e-01, 1.6823435e-01, 1.2382457e+00]], dtype=float32)