I'm porting PyTorch code to Flashlight. What is the ArrayFire or Flashlight equivalent of PyTorch's squeeze and unsqueeze?
processed_query = self.query_layer(query.unsqueeze(1))
energies = energies.squeeze(-1)
How do I convert this to ArrayFire (or Flashlight) code?
Squeezing or unsqueezing only removes or inserts a dimension of size 1, so it is just a reshape, which you can do with the af::moddims function:
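#include <arrayfire.h>
using namespace af;
array a = randu(10, 1, 10, 10);             // dims: 10 x 1 x 10 x 10
array squeezed_a = moddims(a, 10, 10, 10);  // drop the size-1 dim -> 10 x 10 x 10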
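Unsqueeze works the same way: pass moddims the shape with a 1 inserted. The sketch below only computes the target shape for a given axis, so it compiles without ArrayFire. Shape4, squeezeShape and unsqueezeShape are hypothetical helper names (not part of ArrayFire or Flashlight), and Shape4 is a stand-in for af::dim4, assuming ArrayFire's limit of 4 dimensions.

```cpp
#include <array>
#include <stdexcept>

// Hypothetical stand-in for af::dim4: ArrayFire arrays have at most 4 dims.
using Shape4 = std::array<long long, 4>;

// Remove the axis `dim` if it has size 1, shifting later axes down and
// padding the end with 1. Like torch.squeeze(x, dim), a non-1 axis is a no-op.
Shape4 squeezeShape(const Shape4& s, int dim) {
    if (dim < 0 || dim > 3) throw std::out_of_range("dim must be in [0, 3]");
    if (s[dim] != 1) return s;
    Shape4 out{1, 1, 1, 1};
    int j = 0;
    for (int i = 0; i < 4; ++i)
        if (i != dim) out[j++] = s[i];
    return out;
}

// Insert a size-1 axis at `dim`, shifting later axes up by one.
// The last axis must already be 1, otherwise there is no room in 4 dims.
Shape4 unsqueezeShape(const Shape4& s, int dim) {
    if (dim < 0 || dim > 3) throw std::out_of_range("dim must be in [0, 3]");
    if (s[3] != 1) throw std::length_error("no room: at most 4 dims");
    Shape4 out{1, 1, 1, 1};
    int j = 0;
    for (int i = 0; i < 4; ++i) {
        if (i == dim) continue;  // out[i] stays 1: the new singleton axis
        out[i] = s[j++];
    }
    return out;
}
```

To apply a result s to an ArrayFire array, pass it to moddims as `moddims(a, dim4(s[0], s[1], s[2], s[3]))`. Axis indices may not carry over one-to-one: if your port follows the common Flashlight practice of storing tensors with dimensions in the reverse order of PyTorch, then PyTorch's `squeeze(-1)` (last axis) would correspond to dim 0 here, so check which layout your code actually uses.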