python, tensorflow, tensorflow-lite, tinyml

Is the keras function Flatten() supported by TensorFlow Lite?


I'm building my own CNN and trying to deploy it on a DISCO-F746NG board, following the "TensorFlow Lite for Microcontrollers" tutorials and the TinyML book. I know that the supported TensorFlow/Keras operations are listed here: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/micro/all_ops_resolver.cc However, Flatten() does not appear in that list. That confuses me because it is such a basic function, so I thought it might just have a different name in the all_ops_resolver. My model uses only operations that are listed there, plus Flatten().

When I run a test with my own model, I always get a segmentation fault, no matter how much memory I allocate. So my question is: is Flatten() supported by TensorFlow Lite?

This is the Python code that creates the CNN:

import tensorflow as tf
from tensorflow.keras import layers, models

model = models.Sequential()
# Only the first layer needs input_shape; later layers infer their input.
model.add(layers.Conv2D(16, (3, 3), activation='relu', input_shape=(36, 36, 1)))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(32, (3, 3), activation='relu'))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(64, (3, 3), activation='relu'))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Flatten())
model.add(layers.Dense(8, activation='softmax'))
model.add(layers.Dense(2))

model.summary()

# The final Dense layer outputs raw logits, hence from_logits=True.
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
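When debugging a model like this, it can help to trace the tensor shapes by hand so you know what Flatten() has to handle. The following is only a sketch: it assumes the Keras defaults (padding='valid' and stride 1 for Conv2D, pool stride equal to the pool size for MaxPooling2D), and the helper names are made up for illustration.

```python
def conv2d_valid(size, kernel=3):
    """Output side length of a stride-1 convolution with 'valid' padding."""
    return size - kernel + 1


def maxpool_valid(size, pool=2):
    """Output side length of a 'valid' max pool whose stride equals the pool size."""
    return size // pool


def flatten_size(side=36, filters=(16, 32, 64)):
    """Trace the spatial size through each Conv2D + MaxPooling2D pair."""
    channels = 1
    for f in filters:
        side = maxpool_valid(conv2d_valid(side))
        channels = f
    # Flatten() turns (side, side, channels) into one vector of this length.
    return side, channels, side * side * channels


side, channels, flat = flatten_size()
print(side, channels, flat)  # 2 64 256
```

Under these assumptions the spatial size goes 36 -> 34 -> 17 -> 15 -> 7 -> 5 -> 2, so Flatten() receives a 2 x 2 x 64 tensor and produces 256 values for the first Dense layer.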

Solution

  • OK, I think I've figured it out. The segmentation faults were caused by a different problem, which I have now solved. After that, I was able to check whether Flatten() is supported. It works!

    The CNN model code above works once the following builtins are added to the micro op resolver:

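    // The template argument is the number of ops registered below.
    tflite::MicroMutableOpResolver<5> micro_op_resolver;

    micro_op_resolver.AddConv2D();          // Conv2D layers
    micro_op_resolver.AddFullyConnected();  // Dense layers
    micro_op_resolver.AddMaxPool2D();       // MaxPooling2D layers
    micro_op_resolver.AddSoftmax();         // softmax activation
    micro_op_resolver.AddReshape();         // needed for Flatten()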
    

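    From my trial and error, adding the RESHAPE op via AddReshape() is necessary to use Flatten(). That would also explain why Flatten() is not listed under its own name in the all_ops_resolver: in the converted model it appears as a RESHAPE operation.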
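To see why a reshape is enough to implement Flatten(), here is a small pure-Python sketch. It assumes the default channels_last (height, width, channels) layout and a row-major buffer; the function names are hypothetical and not part of any TensorFlow API.

```python
def flat_index(h, w, c, width, channels):
    """Position of element (h, w, c) in a row-major H x W x C buffer."""
    return (h * width + w) * channels + c


def flatten_hwc(tensor):
    """Flatten a nested [H][W][C] list in row-major order."""
    return [value for row in tensor for pixel in row for value in pixel]


H, W, C = 2, 2, 3
# Store each element's own buffer position so the ordering is easy to check.
tensor = [[[flat_index(h, w, c, W, C) for c in range(C)]
           for w in range(W)]
          for h in range(H)]

flat = flatten_hwc(tensor)
print(flat == list(range(H * W * C)))  # True
```

The flattened order is exactly the order in which the values already sit in memory, so no data has to move and only the shape changes. That is what a reshape does.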