python-3.x, numpy

How to find all combinations of values in rows of an M x N array


Apologies if the title is not very clear. I'm having some difficulty framing the question.

I have two NumPy arrays of the same M x N shape. For simplicity, assume they both have shape (2,10). The first array is composed of floats. For example:

[[0.1,0.02,0.2,0.3,0.013,0.7,0.7,0.11,0.18,0.6],
 [0.23,0.02,0.1,0.1,0.011,0.3,0.4,0.4,0.4,0.5]]

The second array is composed of 0's and 1's. For example:

[[0,1,0,0,0,0,1,1,0,1],
 [1,0,1,0,0,0,0,1,1,0]]

I am trying to do the following: for a given configuration of values down a column of the second array, select the elements of the first array at those same columns. To give an example, the second array has two rows, so there are 4 possible column configurations of 1's and 0's: (1,1), (0,0), (1,0) and (0,1). If we take the (1,1) case (that is, columns where the elements in both row 1 and row 2 of the second array are equal to 1), I would want to locate those columns and look up the values at the same locations in the first array. This would return (0.11, 0.4) from the first array.

Again, apologies if this is not clearly communicated. Grateful for any feedback. Thanks.


Solution

  • IIUC, you want to select the values of arr1 in the columns where every row of arr2 is 1:

    import numpy as np
    
    arr1 = np.array(
        [
            [0.1, 0.02, 0.2, 0.3, 0.013, 0.7, 0.7, 0.11, 0.18, 0.6],
            [0.23, 0.02, 0.1, 0.1, 0.011, 0.3, 0.4, 0.4, 0.4, 0.5],
        ]
    )
    
    arr2 = np.array([[0, 1, 0, 0, 0, 0, 1, 1, 0, 1], [1, 0, 1, 0, 0, 0, 0, 1, 1, 0]])
    
    # Column mask: True where every row of arr2 is nonzero (i.e. 1)
    out = arr1[:, np.all(arr2, axis=0)]
    print(out)
    

    Prints:

    [[0.11]
     [0.4 ]]
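The question asks for the values themselves, e.g. (0.11, 0.4), rather than a 2-D array. A minimal sketch, assuming you want one tuple per matching column (the names `mask` and `matches` are illustrative). It compares with `arr2 == 1` explicitly, because `np.all(arr2, axis=0)` treats any nonzero value as True, so the explicit test is stricter if arr2 ever holds values other than 0 and 1:

```python
import numpy as np

arr1 = np.array(
    [
        [0.1, 0.02, 0.2, 0.3, 0.013, 0.7, 0.7, 0.11, 0.18, 0.6],
        [0.23, 0.02, 0.1, 0.1, 0.011, 0.3, 0.4, 0.4, 0.4, 0.5],
    ]
)
arr2 = np.array([[0, 1, 0, 0, 0, 0, 1, 1, 0, 1], [1, 0, 1, 0, 0, 0, 0, 1, 1, 0]])

# Boolean mask over columns: True where every row of arr2 equals exactly 1
mask = (arr2 == 1).all(axis=0)

# Column indices that matched (useful if you also need the locations)
cols = np.flatnonzero(mask).tolist()

# Transpose so each selected column becomes one row, then convert to plain tuples
matches = [tuple(col) for col in arr1[:, mask].T.tolist()]

print(cols)     # [7]
print(matches)  # [(0.11, 0.4)]
```

Converting through `.tolist()` yields plain Python floats, so the tuples print as `(0.11, 0.4)` rather than as NumPy scalar reprs.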
    

    If you want to find all combinations that occur in arr2:

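    # Each distinct column of arr2 (a row of arr2.T) is one configuration
    unique = np.unique(arr2.T, axis=0)
    
    for row in unique:
        print("Combination:")
        print(row)
        print()
        # Reshape to an (M, 1) column so it broadcasts against every column of arr2
        print(arr1[:, np.all(arr2 == row.reshape(arr2.shape[0], -1), axis=0)])
        print("-" * 80)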
    

    Prints:

    Combination:
    [0 0]
    
    [[0.3   0.013 0.7  ]
     [0.1   0.011 0.3  ]]
    --------------------------------------------------------------------------------
    Combination:
    [0 1]
    
    [[0.1  0.2  0.18]
     [0.23 0.1  0.4 ]]
    --------------------------------------------------------------------------------
    Combination:
    [1 0]
    
    [[0.02 0.7  0.6 ]
     [0.02 0.4  0.5 ]]
    --------------------------------------------------------------------------------
    Combination:
    [1 1]
    
    [[0.11]
     [0.4 ]]
    --------------------------------------------------------------------------------
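Note that `np.unique` only yields the configurations that actually appear in arr2. If you need all 2**M configurations, as in the question's "4 possible configurations" for two rows, a minimal sketch using `itertools.product` is below. It assumes you want a dict keyed by the configuration tuple (`results`, `combo` and `target` are illustrative names); configurations that never occur map to an empty array of shape (M, 0):

```python
from itertools import product

import numpy as np

arr1 = np.array(
    [
        [0.1, 0.02, 0.2, 0.3, 0.013, 0.7, 0.7, 0.11, 0.18, 0.6],
        [0.23, 0.02, 0.1, 0.1, 0.011, 0.3, 0.4, 0.4, 0.4, 0.5],
    ]
)
arr2 = np.array([[0, 1, 0, 0, 0, 0, 1, 1, 0, 1], [1, 0, 1, 0, 0, 0, 0, 1, 1, 0]])

n_rows = arr2.shape[0]
results = {}

# Enumerate every 0/1 configuration for M rows: 2**M tuples in total
for combo in product((0, 1), repeat=n_rows):
    # (M, 1) column vector broadcasts against arr2 of shape (M, N)
    target = np.array(combo)[:, None]
    mask = (arr2 == target).all(axis=0)
    # Absent configurations give an array of shape (M, 0)
    results[combo] = arr1[:, mask]

for combo, values in results.items():
    print(combo, values.shape[1], "column(s)")
# (0, 0) 3 column(s)
# (0, 1) 3 column(s)
# (1, 0) 3 column(s)
# (1, 1) 1 column(s)
```

Because 2**M grows quickly with the number of rows, the `np.unique` loop above is the cheaper choice when M is large and you only care about configurations that actually occur.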