Tags: elixir, elixir-nx

Mapping slices of an Nx tensor


If I have a function that expects an Nx tensor of a specific shape, and I have a larger tensor that includes slices of that shape, is there an efficient way to map some function over those slices?

The specific function I have in mind is Nx.Random.randint_split/4, which expects a key of shape exactly {2}. In purely functional code, you could imagine creating a list of keys and then calling Enum.map/2, like:

def random_values(seed, count, max) do
  key0 = Nx.Random.key(seed)
  keys = 0..(count - 1) |> Enum.map(&Nx.Random.fold_in(key0, &1))
  Enum.map(keys, &Nx.Random.randint_split(&1, 0, max))
end

random_values(0, 6, 10) |> Enum.map(&Nx.to_number/1)
# [4, 0, 2, 2, 2, 7]
# Note, for all the "random", this is deterministic

But note that here keys is a list of exactly 6 tensors, each of shape exactly {2}; it would be more efficient to create a single tensor of shape {6, 2}. And in fact Nx.Random.fold_in/2 supports creating this matrix directly:

defn random_values(key0) do
  keys = Nx.Random.fold_in(key0, Nx.iota({6}))
  # Nx.shape(keys) => {6, 2}
  # ...
end

Now I'd like to be able to call Nx.Random.randint_split/4 over each individual vector in this matrix, get some sort of collection of scalars out, and then Nx.concatenate/2 them together into a vector of results. Is there an efficient way to do this?
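
Spelled out in plain Elixir outside of defn, the goal looks roughly like the sketch below (random_values_mapped is a hypothetical name for illustration; and since each randint_split call here returns a scalar, it's Nx.stack/2 rather than Nx.concatenate/2 that does the combining):

def random_values_mapped(keys, max) do
  # keys has shape {6, 2}, so keys[i] is one {2}-shaped key
  0..5
  |> Enum.map(fn i -> Nx.Random.randint_split(keys[i], 0, max) end)
  |> Nx.stack()
end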

I'm actually calling this to sample some larger range, so the keys won't be sequential integers like this, but it's important that they're deterministic.

defn random_values(key0, max, opts \\ []) do
  opts = keyword!(opts, start: 0, count: 6, stride: 1)
  count = opts[:count]
  offsets = Nx.iota({count}) * opts[:stride] + opts[:start]
  keys = Nx.Random.fold_in(key0, offsets)

  # This is slow and awkward, but works
  results = Nx.broadcast(0, {count})
  {_, _, results} = while {keys, max, results}, i <- 0..(count - 1) do
    updates = Nx.Random.randint_split(keys[i], 0, max, shape: {1})
    results = Nx.put_slice(results, [i], updates)
    {keys, max, results}
  end

  results
end

import Nx, only: :sigils
test "produces the same result as above"
  assert random_values(Nx.Random.key(0), 10, count: 6) == ~V[4 0 2 2 2 7]
end
test "produces the same value at position 6"
  key = Nx.Random.key(0)
  by_twos = random_values(key, 10, count: 7, stride: 2)
  by_threes = random_values(key, 10, count: 7, stride: 3)
  assert by_twos[3] == by_threes[2]
  assert by_twos[6] == by_threes[4]
end

Nx.map/3 only works on individual elements, not slices, so it's not an option here (illustrated just below). An Nx.Defn.Kernel.while/4 loop is possible, but clunky: it seems to involve preallocating the result tensor and putting individual values into it, so it's not the fastest thing; this is what the last example above demonstrates. I can write an ordinary recursive function, but that's even slower.
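
To illustrate the Nx.map/3 point: the function it takes only ever receives one scalar element at a time, never a {2}-shaped row, so the per-key computation can't be expressed through it at all:

# x is a single scalar element of keys here, not a {2} row
Nx.map(keys, fn x -> x + 1 end)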

I do not have a GPU. EXLA is not currently working on my system, so I'm using the default backend, which I know is also not the fastest. I do have ambitions of running this under EXLA some day, and I think that means minimizing the number of times data crosses in and out of defn functions. I don't consider the matrices involved huge, but they're not trivial either. For comparison, one test case creates a 131x131 matrix and runs code like this in 7-8 seconds; if I didn't need to control the keys, I could call Nx.Random.uniform_split(..., shape: {131, 131}) and get a response back effectively instantly.


Solution

  • If your major concern is performance and not readability, EXLA is in fact significantly faster. On the test case cited in the question, the default Nx backend (Nx.BinaryBackend) took 7-8 seconds to process a moderately sized matrix, but EXLA was able to do the same work in 200-300 ms.

    I mentioned that EXLA was not working, and the root cause of this was a compiler optimization problem that was resolved in EXLA 0.5.3. This means that the usual steps for enabling EXLA work here:

    1. In your mix.exs file, add a dependency:

      {:exla, "~> 0.5"}
      
    2. In your config.exs file (or in config/dev.exs and config/prod.exs), enable the EXLA Nx backend and compiler; a quick sanity check appears after these steps:

      config :nx,
             default_backend: EXLA.Backend,
             default_defn_options: [compiler: EXLA]
      

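    To verify the configuration took effect, Nx.default_backend/0 should now report the EXLA backend:

      iex> Nx.default_backend()
      {EXLA.Backend, []}
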
    The code continues to be that shown in the question: preallocate a result tensor, then use Nx.put_slice to fill it in inside a while loop. So long as this remains within a single defn function, data does not cross in and out of Nx space. A generalized sketch, where some_transformation stands in for whatever per-row defn you need:

    defn transform(input) do
      {rows, _cols} = Nx.shape(input)
      results = Nx.broadcast(0, input)
      {_, results} = while {input, results}, i <- 0..(rows - 1) do
        input_row = input[i]
        result_row = some_transformation(input_row)
        # put_slice wants one start index per axis, and the update must have
        # the same rank as the target, hence [i, 0] and the new_axis
        results = Nx.put_slice(results, [i, 0], Nx.new_axis(result_row, 0))
        {input, results}
      end
      results
    end
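
    With this configuration in place, the random_values/3 defn from the question runs unchanged under EXLA. As a minimal check (the expected values come from the question's own test):

      key = Nx.Random.key(0)
      random_values(key, 10, count: 6)
      # => an s64[6] tensor [4, 0, 2, 2, 2, 7]: same values, now computed by EXLA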