python machine-learning deep-learning computer-vision jax

Looking for an efficient JAX function to reconstruct an image from patches


I have a set of images stored as (c, h, w) JAX arrays. These arrays have been converted to (patch_index, patch_dim) arrays, where patch_dim == c * patch_h * patch_w.

I am trying to reconstruct the original images from the patches. Here is plain Python code that works:

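import numpy as np
import jax.numpy as jnp
from jax import lax

# Note: this kernel is not used by the code below.
kernel = jnp.ones((PATCH_DIM, IMG_CHANNELS, PATCH_HEIGHT, PATCH_WIDTH), dtype=jnp.float32)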

def fwd(x):
    xcv = lax.conv_general_dilated_patches(x, (PATCH_HEIGHT, PATCH_WIDTH), (PATCH_HEIGHT, PATCH_WIDTH), padding='VALID')

    # return channels last
    return jnp.transpose(xcv, [0,2,3,1])

patches = fwd(bfrc)

patch_reshaped_pn_c_h_w = patch_reshaped_ph_pw_c_h_w = jnp.reshape(patches, (V_PATCHES, H_PATCHES, IMG_CHANNELS, PATCH_HEIGHT, PATCH_WIDTH))

# V_PATCHES == IMG_HEIGHT // PATCH_HEIGHT
# H_PATCHES == IMG_WIDTH // PATCH_WIDTH

reconstructed = np.zeros(EXPECTED_IMG_SHAPE)

for vpatch in range(0, patch_reshaped_ph_pw_c_h_w.shape[0]):
    for hpatch in range(0, patch_reshaped_ph_pw_c_h_w.shape[1]):
        for ch in range(0, patch_reshaped_ph_pw_c_h_w.shape[2]):
            for prow in range(0, patch_reshaped_ph_pw_c_h_w.shape[3]):
                for pcol in range(0, patch_reshaped_ph_pw_c_h_w.shape[4]):
                    row = vpatch * PATCH_HEIGHT + prow
                    col = hpatch * PATCH_WIDTH + pcol
                    reconstructed[0, ch, row , col] = patch_reshaped_ph_pw_c_h_w[vpatch, hpatch, ch, prow, pcol]

# This assert passes
assert jnp.max(jnp.abs(reconstructed - bfrc[0])) == 0
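The reshape to (V_PATCHES, H_PATCHES, IMG_CHANNELS, PATCH_HEIGHT, PATCH_WIDTH) relies on lax.conv_general_dilated_patches flattening each patch in channel-major (c, h, w) order with its default NCHW dimension numbers. A small self-contained check of that ordering, using made-up sizes (C, H, W = 3, 4, 6 and 2x3 patches are illustrative assumptions, not values from the question):

```python
import jax.numpy as jnp
from jax import lax

# Illustrative sizes (assumptions, not taken from the question).
C, H, W = 3, 4, 6
PH, PW = 2, 3

# One image with distinct values so any misordering is detectable: (1, C, H, W).
img = jnp.arange(C * H * W, dtype=jnp.float32).reshape(1, C, H, W)

# Non-overlapping patches: window size == stride, no padding.
p = lax.conv_general_dilated_patches(img, (PH, PW), (PH, PW), padding='VALID')
print(p.shape)  # (1, 18, 2, 2) == (1, C*PH*PW, H//PH, W//PW)

# Take the patch at grid position (row 0, column 1) and unflatten it as (c, h, w).
patch = p[0, :, 0, 1].reshape(C, PH, PW)
expected = img[0, :, 0:PH, PW:2 * PW]
print(bool(jnp.array_equal(patch, expected)))  # True
```

If the flattening order were (h, w, c) instead, the comparison above would fail, and the 5-D reshape would need a different axis order.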

Of course, this plain Python code is very inefficient. How can I convert the for loops into efficient JAX code?


Solution

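  • This line binds two names to the same array:

    patch_reshaped_pn_c_h_w = patch_reshaped_ph_pw_c_h_w = jnp.reshape(patches, (V_PATCHES, H_PATCHES, IMG_CHANNELS, PATCH_HEIGHT, PATCH_WIDTH))
    

    I assume the chained assignment is unintentional; since both names refer to the same array, I only use patch_reshaped_ph_pw_c_h_w below.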

    Assuming bfrc has shape (batch, channels, height, width) with batch == 1 (the reshape above drops the batch axis), and

    V_PATCHES = IMG_HEIGHT // PATCH_HEIGHT
    H_PATCHES = IMG_WIDTH // PATCH_WIDTH
    

    then patch_reshaped_ph_pw_c_h_w has shape (V_PATCHES, H_PATCHES, IMG_CHANNELS, PATCH_HEIGHT, PATCH_WIDTH).

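    With this in mind, you can reconstruct the image with a few transposes and a reshape, which is much cheaper than the nested loops. The key step is moving each patch-grid axis next to its matching in-patch axis (V next to PH, H next to PW), so that a single reshape merges them into full image rows and columns.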

    V, H, C, PH, PW = patch_reshaped_ph_pw_c_h_w.shape  # note: H is the number of horizontal patches
    
    H_total = V * PH  # image height
    W_total = H * PW  # image width
    
    x = jnp.transpose(patch_reshaped_ph_pw_c_h_w, (0, 1, 3, 4, 2))  # (V, H, PH, PW, C)
    x = x.transpose(0, 2, 1, 3, 4)                                  # (V, PH, H, PW, C)
    x = x.reshape(H_total, W_total, C)                              # (height, width, C)
    reconstructed = jnp.transpose(x, (2, 0, 1))[jnp.newaxis, ...]   # (1, C, height, width)
    

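    You can additionally wrap this in a function decorated with @jax.jit, which should be slightly faster because the operations are compiled together instead of being dispatched one at a time.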
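A minimal jitted sketch of the same transpose-and-reshape idea, assuming the (batch, V, H, C*PH*PW) output of the question's fwd and non-overlapping patches that tile the image exactly. The helper name patches_to_image is hypothetical. The patch height and width are marked static so jax.jit can compute the reshape shapes at trace time, and unlike the snippet above it keeps the batch axis instead of assuming batch == 1. It also folds the three transposes into one.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax


@partial(jax.jit, static_argnames=("patch_h", "patch_w"))
def patches_to_image(patches, patch_h, patch_w):
    """Hypothetical inverse of non-overlapping patch extraction.

    patches: (N, V, H, C*PH*PW), each patch flattened in (c, h, w) order.
    Returns: (N, C, V*PH, H*PW).
    """
    n, v, h, patch_dim = patches.shape
    c = patch_dim // (patch_h * patch_w)  # shapes are static under jit
    x = patches.reshape(n, v, h, c, patch_h, patch_w)
    # (N, V, H, C, PH, PW) -> (N, C, V, PH, H, PW): pair each grid axis with
    # its in-patch axis so one reshape merges them into rows and columns.
    x = jnp.transpose(x, (0, 3, 1, 4, 2, 5))
    return x.reshape(n, c, v * patch_h, h * patch_w)


# Round trip through the same extraction as the question's fwd(),
# using illustrative sizes (2 images, 3 channels, 4x6 pixels, 2x3 patches).
PH, PW = 2, 3
imgs = jnp.arange(2 * 3 * 4 * 6, dtype=jnp.float32).reshape(2, 3, 4, 6)
p = lax.conv_general_dilated_patches(imgs, (PH, PW), (PH, PW), padding='VALID')
p = jnp.transpose(p, (0, 2, 3, 1))  # channels last, as in fwd()
recon = patches_to_image(p, patch_h=PH, patch_w=PW)
print(bool(jnp.array_equal(recon, imgs)))  # True
```

Because the function only permutes and reshapes, the reconstruction is exact; it does not handle overlapping patches or images whose size is not a multiple of the patch size.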