python arrays numpy jax kronecker-product

How to map the Kronecker product along array dimensions?


I have two tensors A and B with the same number of dimensions (d>=2) and shapes [A_{1},...,A_{d-2},A_{d-1},A_{d}] and [A_{1},...,A_{d-2},B_{d-1},B_{d}]; the sizes of the first d-2 dimensions are identical.

Is there a way to calculate the Kronecker product over the last two dimensions? The shape of my_kron(A,B) should be [A_{1},...,A_{d-2},A_{d-1}*B_{d-1},A_{d}*B_{d}]. For example, with d=3:

A.shape=[2,3,3]
B.shape=[2,4,4]
C=my_kron(A,B)

C[0,...] should be the Kronecker product of A[0,...] and B[0,...], and C[1,...] the Kronecker product of A[1,...] and B[1,...].

For d=2 this is simply what jnp.kron (or np.kron) does.

For d=3 this can be achieved with jax.vmap: jax.vmap(jnp.kron)(A, B)

But I was not able to find a solution for general (unknown) dimensions. Any suggestions?


Solution

  • In numpy terms I think this is what you are doing:

    In [104]: A = np.arange(2*3*3).reshape(2,3,3)
    In [105]: B = np.arange(2*4*4).reshape(2,4,4)
    
    In [106]: C = np.array([np.kron(a,b) for a,b in zip(A,B)])
    In [107]: C.shape
    Out[107]: (2, 12, 12)
    

    That treats the initial dimension, the 2, as a batch. One obvious generalization is to reshape the arrays so that all leading dimensions collapse into a single batch dimension, e.g. reshape(-1,3,3) and reshape(-1,4,4), and afterwards reshape C back to the desired n dimensions.
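
The reshape idea described above could look roughly like the following. This is only a sketch: the helper name kron_via_flatten is made up for illustration, and it assumes A and B share exactly the same leading dimensions.

```python
import numpy as np

def kron_via_flatten(A, B):
    """Batched Kronecker product over the last two axes (hypothetical helper)."""
    batch = A.shape[:-2]                 # leading dimensions, assumed equal for A and B
    A2 = A.reshape(-1, *A.shape[-2:])    # collapse all leading dimensions into one
    B2 = B.reshape(-1, *B.shape[-2:])
    C2 = np.array([np.kron(a, b) for a, b in zip(A2, B2)])
    return C2.reshape(*batch, *C2.shape[-2:])  # restore the leading dimensions

A = np.arange(2 * 5 * 3 * 3).reshape(2, 5, 3, 3)
B = np.arange(2 * 5 * 4 * 4).reshape(2, 5, 4, 4)
C = kron_via_flatten(A, B)
print(C.shape)  # (2, 5, 12, 12)
```

The Python loop is still there, so this mainly shows that the batch handling is independent of how many leading dimensions exist.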

    np.kron does accept 3d (and higher) arrays, but it also applies the Kronecker rule to the shared leading dimension of size 2, pairing every A[i] with every B[j]:

    In [108]: np.kron(A,B).shape
    Out[108]: (4, 12, 12)
    

    Viewing that dimension of size 4 as (2,2), I can take the diagonal and get your C:

    In [109]: np.allclose(np.kron(A,B)[[0,3]], C)
    Out[109]: True
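
For a batch size other than 2, the hard-coded [0,3] would have to change. A possible way to write the same diagonal selection for any batch size n is sketched below; the variable names are illustrative, and it still computes all n*n pairings before discarding most of them.

```python
import numpy as np

A = np.arange(2 * 3 * 3).reshape(2, 3, 3)
B = np.arange(2 * 4 * 4).reshape(2, 4, 4)
n = A.shape[0]

full = np.kron(A, B)                           # shape (n*n, 12, 12)
blocks = full.reshape(n, n, *full.shape[1:])   # blocks[i, j] pairs A[i] with B[j]
idx = np.arange(n)
C_diag = blocks[idx, idx]                      # keep only the i == j pairs

C_loop = np.array([np.kron(a, b) for a, b in zip(A, B)])
print(np.array_equal(C_diag, C_loop))  # True
```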
    

    The full kron does more calculations than needed, but is still faster than the Python loop:

    In [110]: timeit C = np.array([np.kron(a,b) for a,b in zip(A,B)])
    108 µs ± 2.23 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
    
    In [111]: timeit np.kron(A,B)[[0,3]]
    76.4 µs ± 1.36 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
    

    I'm sure it's possible to do your calculation in a more direct way, but doing that requires a better understanding of how kron works. A quick glance at the np.kron code suggests that it does an outer(A,B)

    In [114]: np.outer(A,B).shape
    Out[114]: (18, 32)
    

    which has the same number of elements, but it then reshapes and concatenates to produce the kron layout.
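
To make the link between the outer product and the kron layout concrete for plain 2-D matrices, here is one way to rearrange np.outer into np.kron. It is a sketch, not a copy of what np.kron does internally: it uses a transpose rather than concatenation, and the variable names are illustrative.

```python
import numpy as np

a = np.arange(2 * 3).reshape(2, 3)
b = np.arange(4 * 5).reshape(4, 5)

outer = np.outer(a, b)                       # shape (6, 20), entries a[i, k] * b[j, l]
blocks = outer.reshape(*a.shape, *b.shape)   # axes ordered (i, k, j, l)
kron_layout = blocks.transpose(0, 2, 1, 3)   # axes ordered (i, j, k, l)
result = kron_layout.reshape(a.shape[0] * b.shape[0], a.shape[1] * b.shape[1])

print(np.array_equal(result, np.kron(a, b)))  # True
```

Row index i*4+j and column index k*5+l of the result hold a[i,k]*b[j,l], which is the Kronecker product layout.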

    But following a hunch, I found that this is equivalent to what you want:

    In [123]: D = A[:,:,None,:,None]*B[:,None,:,None,:]
    In [124]: np.allclose(D.reshape(2,12,12),C)
    Out[124]: True
    In [125]: timeit np.reshape(A[:,:,None,:,None]*B[:,None,:,None,:],(2,12,12))
    14.3 µs ± 184 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
    

    That is easily generalized to more leading dimensions.

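    def my_kron(A,B):
        # pair every row/column index of A with every row/column index of B
        D = A[...,:,None,:,None]*B[...,None,:,None,:]
        ds = D.shape
        # merge the two row axes and the two column axes
        newshape = (*ds[:-4],ds[-4]*ds[-3],ds[-2]*ds[-1])
        return D.reshape(newshape)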
    
    In [137]: my_kron(A.reshape(1,2,1,3,3),B.reshape(1,2,1,4,4)).shape
    Out[137]: (1, 2, 1, 12, 12)
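
As a sanity check on the generalization, the sketch below repeats my_kron and compares it slice by slice against np.kron, using random data with three leading dimensions and non-square matrices. The shapes and the seed are arbitrary choices. Since only indexing with None, broadcasting and reshape are involved, the same function body should carry over to jax.numpy arrays, though that is not tested here.

```python
import numpy as np

def my_kron(A, B):
    # Same broadcasting trick as above: interleave the row/column axes of A and B.
    D = A[..., :, None, :, None] * B[..., None, :, None, :]
    ds = D.shape
    return D.reshape(*ds[:-4], ds[-4] * ds[-3], ds[-2] * ds[-1])

rng = np.random.default_rng(0)
A = rng.normal(size=(2, 3, 5, 2, 3))   # leading dims (2, 3, 5), matrices 2x3
B = rng.normal(size=(2, 3, 5, 4, 2))   # same leading dims, matrices 4x2
C = my_kron(A, B)
print(C.shape)  # (2, 3, 5, 8, 6)

# Compare every leading index against np.kron on the matching 2-D slices.
ok = all(
    np.allclose(C[idx], np.kron(A[idx], B[idx]))
    for idx in np.ndindex(*A.shape[:-2])
)
print(ok)  # True
```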