I'm trying to get batch_size in the call() function of a TF2 model.
However, I cannot get it because all the methods I know return None or a Tensor instead of a dimension tuple.
Here is a short example:
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.models import Model
class MyModel(Model):
    def __init__(self):
        super(MyModel, self).__init__()
    def call(self, x):
        print(len(x))
        print(x.shape)
        print(tf.size(x))
        print(np.shape(x))
        print(x.get_shape())
        print(x.get_shape().as_list())
        print(tf.rank(x))
        print(tf.shape(x))
        print(tf.shape(x)[0])
        print(tf.shape(x)[1])
        return tf.random.uniform((2, 10))
m = MyModel()
m.compile(optimizer="Adam", loss="sparse_categorical_crossentropy", metrics=['accuracy'])
m.fit(np.array([[1,2,3,4], [5,6,7,8]]), np.array([0, 1]), epochs=1)
The output is:
Tensor("my_model_26/strided_slice:0", shape=(), dtype=int32)
(None, 4)
Tensor("my_model_26/Size:0", shape=(), dtype=int32)
(None, 4)
(None, 4)
[None, 4]
Tensor("my_model_26/Rank:0", shape=(), dtype=int32)
Tensor("my_model_26/Shape_2:0", shape=(2,), dtype=int32)
Tensor("my_model_26/strided_slice_1:0", shape=(), dtype=int32)
Tensor("my_model_26/strided_slice_2:0", shape=(), dtype=int32)
1/1 [==============================] - 0s 1ms/step - loss: 3.1796 - accuracy: 0.0000e+00
In this example I fed a (2, 4) numpy array as input and a (2,) array as target to the model.
But as you can see, I cannot get batch_size in the call() function.
I need it because I have to iterate over tensors batch_size times, and batch_size is dynamic in my real model.
For example, if the dataset size is 10 and the batch size is 3, then the last batch contains only 1 sample. So I have to know the batch size dynamically.
Can anyone help me?
It's because you're using TensorFlow (which is unavoidable, since Keras now lives inside TensorFlow), and when using TensorFlow you need to be aware of the "compilation" of the dynamic graph into a static graph.
In short, your call method is (under the hood) decorated with the @tf.function decorator.
This decorator:
1. Traces the execution of the Python function.
2. Converts the Python operations into TensorFlow operations (e.g. if a > b becomes tf.cond(tf.greater(a,b), something, something_else)).
3. Creates a tf.Graph (the static graph).
4. Executes the static graph just created.
All your print calls are executed during the first step (the Python execution tracing); that's why, even if you train your model, you see the output only once.
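To make the tracing behavior concrete, here is a minimal sketch outside of Keras. The function name double and the list trace_log are made up for illustration; the point is only that a plain Python side effect runs while the function is being traced, whereas tf.print becomes a graph operation that runs on every call.

```python
import tensorflow as tf

# Hypothetical helper list: it is appended to only while the Python body runs,
# i.e. while tf.function is tracing the function.
trace_log = []

@tf.function
def double(x):
    # Plain Python side effect: executed during tracing only.
    trace_log.append("traced")
    # Graph operation: executed every time the compiled graph runs.
    tf.print("executing with runtime shape", tf.shape(x))
    return x * 2

a = tf.constant([[1.0, 2.0], [3.0, 4.0]])
double(a)
double(a)  # same dtype and shape: the cached graph is reused, no new trace
double(a)

print(len(trace_log))  # 1: the Python body ran once, although the graph ran three times
```

The same thing happens to the print calls inside call(): they fire once, at trace time, when only the static shape (None, 4) is known.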
That said, to get the runtime (dynamic) shape of a tensor, you must use tf.shape(x); the batch size is just batch_size = tf.shape(x)[0].
Please note that if you want to see the shape, you can't use print; you must use tf.print.
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.models import Model
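class MyModel(Model):
    def __init__(self):
        super(MyModel, self).__init__()
    def call(self, x):
        # Dynamic shape: evaluated every time the graph runs
        shape = tf.shape(x)
        batch_size = shape[0]
        # tf.print is a graph op, so it prints on every step
        tf.print(shape, batch_size)
        return tf.random.uniform((2, 10))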
m = MyModel()
m.compile(
    optimizer="Adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"]
)
m.fit(np.array([[1, 2, 3, 4], [5, 6, 7, 8]]), np.array([0, 1]), epochs=1)
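Since the question also asks about iterating batch_size times when the last batch is smaller, here is a rough sketch of one way to do that, using the dataset size 10 and batch size 3 from the question. The function row_sums, the list seen_static_shapes, and the per-row sum are assumptions chosen purely for illustration, not part of the original model. Looping over tf.range(batch_size) lets AutoGraph turn the Python for loop into a graph loop whose length is decided at run time.

```python
import numpy as np
import tensorflow as tf

# Hypothetical helper list: records the static shape seen at trace time.
seen_static_shapes = []

@tf.function(input_signature=[tf.TensorSpec(shape=[None, 4], dtype=tf.float32)])
def row_sums(x):
    # Static shape: known while tracing; the batch dimension is None.
    seen_static_shapes.append(x.shape.as_list())
    # Dynamic shape: a scalar int32 tensor evaluated on every call.
    batch_size = tf.shape(x)[0]
    sums = tf.TensorArray(tf.float32, size=batch_size)
    # Iterating over tf.range makes AutoGraph build a graph loop,
    # so the number of iterations follows the runtime batch size.
    for i in tf.range(batch_size):
        sums = sums.write(i, tf.reduce_sum(x[i]))
    return batch_size, sums.stack()

# 10 samples with batch size 3, as in the question: the last batch has 1 sample.
data = np.arange(40, dtype=np.float32).reshape(10, 4)
dataset = tf.data.Dataset.from_tensor_slices(data).batch(3)

batch_sizes = []
for batch in dataset:
    n, _ = row_sums(batch)
    batch_sizes.append(int(n.numpy()))

print(seen_static_shapes[0])  # [None, 4]
print(batch_sizes)            # [3, 3, 3, 1]
```

Inside a Keras call() the same pattern should apply: take batch_size = tf.shape(x)[0] and loop over tf.range(batch_size) rather than over a Python range.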
More information about static and dynamic shapes: https://pgaleone.eu/tensorflow/2018/07/28/understanding-tensorflow-tensors-shape-static-dynamic/
More info about the tf.function behavior: https://pgaleone.eu/tensorflow/tf.function/2019/03/21/dissecting-tf-function-part-1/
Note: I wrote these articles.