I am trying to customize a layer in TensorFlow. The layer has to take a ragged tensor of unspecified length as input, but the code gets stuck when trying to build the layer. Even the simple code attached below does not work properly.
import tensorflow as tf

class myLayer(tf.keras.layers.Layer):
    def __init__(self):
        super(myLayer, self).__init__()
        self._supports_ragged_inputs = True

    def call(self, inputs):
        # Try to loop over ragged tensor
        for x in inputs:
            pass
        return tf.constant(0)

# Input is ragged tensor
inputs = tf.keras.layers.Input(shape=(None, 1), ragged=True)
layer1 = myLayer()
output = layer1(inputs)
When I ran your code in TensorFlow version 2.2.0, I got the error below in the for loop.

Error -
ValueError: in user code:

    <ipython-input-24-1681d59017fc>:10 call  *
        for x in inputs:
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/autograph/operators/control_flow.py:359 for_stmt
        iter_, extra_test, body, get_state, set_state, symbol_names, opts)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/autograph/operators/control_flow.py:491 _tf_ragged_for_stmt
        opts)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/autograph/operators/control_flow.py:885 _tf_while_stmt
        aug_test, aug_body, init_vars, **opts)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/control_flow_ops.py:2688 while_loop
        back_prop=back_prop)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/while_v2.py:104 while_loop
        maximum_iterations)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/while_v2.py:1258 _build_maximum_iterations_loop_var
        maximum_iterations, dtype=dtypes.int32, name="maximum_iterations")
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py:1317 convert_to_tensor
        (dtype.name, value.dtype.name, value))

    ValueError: Tensor conversion requested dtype int32 for Tensor with dtype int64: <tf.Tensor 'my_layer_15/strided_slice:0' shape=() dtype=int64>
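The int64 value in the error appears to come from the ragged input's row-partitioning tensors, which default to int64, while the autograph while loop in TF 2.2 builds its maximum_iterations counter as int32. A quick way to see that default, using any concrete ragged tensor -

rt = tf.ragged.constant([[1.0], [2.0, 3.0]])
print(rt.row_splits.dtype)  # tf.int64 by default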
So I performed the experiment below to understand the data types produced by the for loop and by enumerate when iterating over inputs. The for loop yields a Tensor object, whereas the index produced by enumerate is a plain Python int.
Experiment Code -
inputs = tf.keras.layers.Input(shape=(None, 1), ragged=True)

for x in inputs:
    print(type(x))
    break

for i, x in enumerate(inputs):
    print(type(i))
    break
Output -
<class 'tensorflow.python.framework.ops.Tensor'>
<class 'int'>
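As a side note, outside the symbolic Keras graph the same pattern behaves differently: looping eagerly over a concrete RaggedTensor yields one row per step as a variable-length tensor. A minimal sketch to confirm this, assuming eager execution (the TF 2.x default) -

rt = tf.ragged.constant([[1.0], [2.0, 3.0]])
for row in rt:
    print(type(row), row.shape)  # each row is a dense tensor of varying length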
So I modified your code as below and it worked fine -
Fixed Code -
import tensorflow as tf

class myLayer(tf.keras.layers.Layer):
    def __init__(self):
        super(myLayer, self).__init__()
        self._supports_ragged_inputs = True

    def call(self, inputs):
        # Try to loop over ragged tensor
        # for x in inputs:  # Throws the dtype error above
        for i, x in enumerate(inputs):  # enumerate works fine
            break  # Using break, as pass would go into the loop
        return tf.constant(0)

# Input is ragged tensor
inputs = tf.keras.layers.Input(shape=(None, 1), ragged=True)
layer1 = myLayer()
output = layer1(inputs)
print(output)
Output -
Tensor("my_layer_17/Identity:0", shape=(), dtype=int32)
Hope this answers your question. Happy Learning.