
Writing to multidimensional arrays with variable dimension


I've initialized a list of arrays of different dimensions in Python and want to write to the entries of the arrays. If I want to write to the entry list[i][m1, m1, ..., m1, m2, m2, ..., m2], where m1 appears n_m1 times and m2 appears n_m2 times, how would I go about doing that, taking n_m1 and n_m2 as input variables?

I know that if n_m1 and n_m2 were fixed, I could just type out the index [m1, m1, ..., m1, m2, m2, ..., m2] with the exact number of m1s and m2s that I need, but this is part of a larger for loop and I only have those numbers as variables.

Example code:

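    import numpy as np

    m = 10
    i_range = 6
    # the i-th entry is an i-dimensional array of zeros
    # (note: naming it `list` shadows Python's built-in list type)
    list = [np.zeros((m,) * i) for i in range(i_range)]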

The above generates a list whose i-th entry is an array of zeros with shape (m, m, ..., m), where m appears i times.

Running

    for i, item in enumerate(list):
        print("i =", i, "shape: ", item.shape)

prints

    i = 0 shape:  ()
    i = 1 shape:  (10,)
    i = 2 shape:  (10, 10)
    i = 3 shape:  (10, 10, 10)
    i = 4 shape:  (10, 10, 10, 10)
    i = 5 shape:  (10, 10, 10, 10, 10)

What I'm looking for is a way to take inputs i, m1, m2, n_m1 and n_m2, with n_m1 + n_m2 = i, and write something to the entry list[i][m1, m1, ..., m1, m2, m2, ..., m2]. Does anyone know of a way to do this?


Solution

  • Just keep in mind what you are working with: you have a list object where every element is a numpy.ndarray (of varying shapes). The simple way to accomplish this is to get the i-th numpy.ndarray out of the list, build a tuple with the appropriate indices (easy to do with Python's sequence operators + and *), and then use index assignment on the numpy.ndarray. Indexing an array with a tuple, as in arr[(a, b, c)], is the same as writing arr[a, b, c]:

    def write(jagged, *, i, m1, m2, n_m1, n_m2, val):
        arr = jagged[i]                  # the i-th ndarray in the list
        idx = (m1,)*n_m1 + (m2,)*n_m2    # e.g. (m1, m1, m2, m2, m2)
        arr[idx] = val                   # same as arr[m1, m1, m2, m2, m2] = val
    

    This could be simplified further (but why? Isn't the above just nice and readable?):

    def write(jagged, *, i, m1, m2, n_m1, n_m2, val):
        jagged[i][(m1,)*n_m1 + (m2,)*n_m2] = val
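A quick check of the answer's `write` against the question's setup might look like the sketch below. It renames the question's `list` to `arrays` so the built-in is not shadowed, and `write_checked` is a hypothetical variant (not part of the answer) that rejects calls where n_m1 + n_m2 does not equal the array's number of dimensions. Without such a check, too few indices select a whole slice, so the assignment silently overwrites many entries. It also shows that i = 0 works: the index tuple is then empty, and `()` addresses the single element of a 0-d array.

```python
import numpy as np

def write(jagged, *, i, m1, m2, n_m1, n_m2, val):
    # Same approach as the answer: build the index tuple with + and *
    jagged[i][(m1,) * n_m1 + (m2,) * n_m2] = val

def write_checked(jagged, *, i, m1, m2, n_m1, n_m2, val):
    # Hypothetical variant: only allow indices that address a single element
    arr = jagged[i]
    if n_m1 + n_m2 != arr.ndim:
        raise ValueError(
            f"need {arr.ndim} indices for jagged[{i}], got {n_m1 + n_m2}"
        )
    arr[(m1,) * n_m1 + (m2,) * n_m2] = val

m = 10
i_range = 6
arrays = [np.zeros((m,) * i) for i in range(i_range)]

# Write 7 into arrays[5][3, 3, 8, 8, 8]
write(arrays, i=5, m1=3, m2=8, n_m1=2, n_m2=3, val=7)
print(arrays[5][3, 3, 8, 8, 8])  # 7.0

# i = 0: the empty tuple () indexes the only element of the 0-d array
write(arrays, i=0, m1=3, m2=8, n_m1=0, n_m2=0, val=1)
print(arrays[0][()])  # 1.0

# Pitfall: one index into a 2-d array selects row 3, so all 10 entries become 5
write(arrays, i=2, m1=3, m2=0, n_m1=1, n_m2=0, val=5)
print(arrays[2][3].sum())  # 50.0

# The checked variant refuses the same call
try:
    write_checked(arrays, i=2, m1=3, m2=0, n_m1=1, n_m2=0, val=5)
except ValueError as exc:
    print(exc)  # need 2 indices for jagged[2], got 1
```

The check costs one comparison per call; whether it is worth it depends on how much you trust the surrounding loop to keep n_m1 + n_m2 equal to i.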