swift tensorflow keras metal metal-performance-shaders

MPSNNGraph Reshape node


Is there an analog to Keras/TensorFlow's Reshape layer in Metal/Metal-Performance-Shaders? I'm converting a trained Keras model to an MPSNNGraph and I've been having trouble handling a conv2d->dense->conv2dTranspose section. I am trying to build the following configuration:

Conv2D_1, output_shape: (1,1,128) <-- I've handled up to here
Reshape_1(Conv2D_1), output_shape: (128)
Dense_1(Reshape_1), output_shape: (1024)
Dense_2(Dense_1), output_shape: (8192)
Reshape_2(Dense_2), output_shape: (4,4,512) <-- hung up here
Conv2DTranspose(Reshape_2), output_shape: (8,8,256) <-- I know how to do this

I understand the concept of expressing dense layers as 4D blocks in Metal (I think?), but I don't understand how to get from shape (8192) to (4,4,512) in Metal. I feel like there is something clever one could do to merge Dense_2 and the reshape into a single MPSNN node, but I don't know what it is.

I am trying to use only objects from https://developer.apple.com/documentation/metalperformanceshaders/objects_that_simplify_the_creation_of_neural_networks


Solution

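  • The MPSCNNFullyConnected layer outputs a 1x1 image, in your case with 8192 channels. An MPSImage packs 4 channels into each texture-array slice, so that is 2048 slices, which is the maximum array length of a Metal texture; 8192 is therefore also the maximum number of channels an image can hold.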

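    A reshape would take that 1x1x8192 image and output a 4x4 image with 512 channels. In Keras/TensorFlow that's just a matter of changing the shape and strides of the underlying tensor, without moving any data, but in Metal you'll have to copy the data, because each value ends up in a different pixel and texture slice of the 4x4 image.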

    As of iOS 11.3, there is an MPSNNReshape kernel that can do this kind of rearranging. I haven't used it myself, but it looks like you just give it the source and destination images when encoding the kernel, and it figures out for itself how to copy the data between them.

    However, MPSNNReshape doesn't appear to be usable as a node in an MPSNNGraph, since it is a subclass of MPSCNNKernel, not of MPSNNFilterNode. This seems like a bit of an oversight.

    So you'll have to create two graphs: one that ends with Dense_2, and one that starts at the Conv2DTranspose and contains the rest of the network. Then encode the first graph, encode the MPSNNReshape kernel with the first graph's output as its source, and finally encode the second graph with the reshaped image as its input. All three can be encoded into the same command buffer, so this should be fast enough, but it's a little annoying that you can't do the whole thing in a single graph.
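For reference, here is a minimal sketch, in plain Swift with no Metal, of the index mapping that copy has to implement to reproduce Keras's Reshape((4, 4, 512)). It assumes Keras's row-major (height, width, channels) layout and MPSImage's packing of 4 feature channels per texel into texture-array slices. The names `kerasReshapeTarget`, `MPSTexel`, `mpsTexel` and `reshapeReference` are hypothetical helpers, not part of any API. The sketch makes no assumption about which ordering MPSNNReshape itself uses; it is mainly useful for checking that kernel's output (or a hand-written compute kernel) against the Keras model on the CPU.

```swift
/// Hypothetical helper: where Keras's Reshape((height, width, channels)) sends
/// element `i` of the flat input vector. Keras tensors are row-major (C order)
/// in (height, width, channels), so i == y * width * channels + x * channels + c.
func kerasReshapeTarget(flatIndex i: Int, height: Int = 4, width: Int = 4,
                        channels: Int = 512) -> (y: Int, x: Int, c: Int) {
    precondition(i >= 0 && i < height * width * channels, "index out of range")
    let y = i / (width * channels)
    let x = (i / channels) % width
    let c = i % channels
    return (y, x, c)
}

/// Hypothetical description of where one feature channel lives in an MPSImage:
/// channels are packed 4 per texel, one group of 4 per texture-array slice.
struct MPSTexel: Equatable {
    let x: Int, y: Int, slice: Int, component: Int
}

/// Location of feature channel `c` at pixel (x, y) of an MPSImage.
func mpsTexel(x: Int, y: Int, channel c: Int) -> MPSTexel {
    MPSTexel(x: x, y: y, slice: c / 4, component: c % 4)
}

/// CPU reference for the whole copy: takes the 1x1x8192 Dense_2 output as a
/// flat array and builds the 4x4x512 tensor Keras expects, indexed [y][x][c].
func reshapeReference(_ flat: [Float], height: Int = 4, width: Int = 4,
                      channels: Int = 512) -> [[[Float]]] {
    precondition(flat.count == height * width * channels, "size mismatch")
    var out = Array(repeating: Array(repeating: Array(repeating: Float(0), count: channels),
                                     count: width),
                    count: height)
    for i in flat.indices {
        let t = kerasReshapeTarget(flatIndex: i, height: height, width: width, channels: channels)
        out[t.y][t.x][t.c] = flat[i]
    }
    return out
}
```

For example, Dense_2 output channel 4100 sits in slice 1025, component 0 of the single 1x1 texel (`mpsTexel(x: 0, y: 0, channel: 4100)`), and Keras places it at y = 2, x = 0, c = 4, which is slice 1, component 0 of texel (0, 2) in the 4x4x512 image.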