
torch.squeeze and torch.unsqueeze equivalent in Flashlight (ArrayFire)


I'm porting PyTorch code to Flashlight. What is the ArrayFire or Flashlight equivalent of PyTorch's squeeze and unsqueeze?

processed_query = self.query_layer(query.unsqueeze(1))

energies = energies.squeeze(-1)

How can I convert these lines to ArrayFire (or Flashlight) code?


Solution

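  • Both operations only add or remove dimensions of size 1, so you can do either one with af::moddims, which changes an array's dimensions without changing the order of its elements. For example, to squeeze dimension 1 and then unsqueeze it again: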

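    #include <arrayfire.h>
    af::array a = af::randu(10, 1, 10, 10);                // dims: 10 x 1 x 10 x 10
    af::array squeezed_a = af::moddims(a, 10, 10, 10);     // squeeze dim 1: 10 x 10 x 10 (x 1)
    af::array restored = af::moddims(squeezed_a, 10, 1, 10, 10);  // unsqueeze at dim 1 again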
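ArrayFire arrays always carry exactly four dimensions (af::dim4), so squeezing or unsqueezing a given axis comes down to computing the new shape and passing it to af::moddims. The sketch below computes that shape in plain C++ so it runs without ArrayFire installed. The helper names squeeze_dims and unsqueeze_dims are hypothetical, and the std::array is only a stand-in for af::dim4; the axis indices are ArrayFire axes, not PyTorch ones.

```cpp
#include <array>
#include <cstddef>
#include <stdexcept>

// Stand-in for af::dim4: ArrayFire arrays always have four dims.
using Dims4 = std::array<long long, 4>;

// Hypothetical helper, similar in spirit to torch.squeeze(x, axis):
// remove `axis` (which must have size 1), shift the later axes left,
// and pad the trailing axis with 1.
Dims4 squeeze_dims(const Dims4& d, std::size_t axis) {
    if (axis >= 4 || d[axis] != 1)
        throw std::invalid_argument("squeeze: axis must exist and have size 1");
    Dims4 out{1, 1, 1, 1};
    std::size_t j = 0;
    for (std::size_t i = 0; i < 4; ++i)
        if (i != axis) out[j++] = d[i];
    return out;
}

// Hypothetical helper, similar in spirit to torch.unsqueeze(x, axis):
// insert a size-1 axis at `axis` and shift the later axes right.
// The last axis must already be 1, because a fifth dimension
// would not fit in a four-dimensional ArrayFire array.
Dims4 unsqueeze_dims(const Dims4& d, std::size_t axis) {
    if (axis >= 4 || d[3] != 1)
        throw std::invalid_argument("unsqueeze: no room for another dimension");
    Dims4 out{};
    std::size_t j = 0;
    for (std::size_t i = 0; i < 4; ++i)
        out[i] = (i == axis) ? 1 : d[j++];
    return out;
}
```

With ArrayFire available, the computed shape could be applied as af::moddims(a, af::dim4(s[0], s[1], s[2], s[3])), where s is the result of one of the helpers applied to the values of a.dims(). Because inserting or removing a size-1 axis never changes the linear (column-major) position of any element, reinterpreting the shape with moddims is enough and no data needs to move.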