I have a set of images stored as (c, h, w) JAX arrays. These arrays have been converted to (patch_index, patch_dim) arrays where patch_dim == c * h * w.
I am trying to reconstruct the original images from the patches. Here is vanilla Python code that works:
kernel = jnp.ones((PATCH_DIM, IMG_CHANNELS, PATCH_HEIGHT, PATCH_WIDTH), dtype=jnp.float32)
def fwd(x):
    xcv = lax.conv_general_dilated_patches(x, (PATCH_HEIGHT, PATCH_WIDTH), (PATCH_HEIGHT, PATCH_WIDTH), padding='VALID')
    # return channels last
    return jnp.transpose(xcv, [0,2,3,1])
patches = fwd(bfrc)
patch_reshaped_pn_c_h_w = patch_reshaped_ph_pw_c_h_w = jnp.reshape(patches, (V_PATCHES, H_PATCHES, IMG_CHANNELS, PATCH_HEIGHT, PATCH_WIDTH))
# V_PATCHES == IMG_HEIGHT // PATCH_HEIGHT
# H_PATCHES == IMG_WIDTH // PATCH_WIDTH
reconstructed = np.zeros(EXPECTED_IMG_SHAPE)
for vpatch in range(0, patch_reshaped_ph_pw_c_h_w.shape[0]):
    for hpatch in range(0, patch_reshaped_ph_pw_c_h_w.shape[1]):
        for ch in range(0, patch_reshaped_ph_pw_c_h_w.shape[2]):
            for prow in range(0, patch_reshaped_ph_pw_c_h_w.shape[3]):
                for pcol in range(0, patch_reshaped_ph_pw_c_h_w.shape[4]):
                    row = vpatch * PATCH_HEIGHT + prow
                    col = hpatch * PATCH_WIDTH + pcol
                    reconstructed[0, ch, row, col] = patch_reshaped_ph_pw_c_h_w[vpatch, hpatch, ch, prow, pcol]
# This assert passes
assert jnp.max(jnp.abs(reconstructed - bfrc[0])) == 0
Of course, this vanilla Python code is very inefficient. How can I convert the for loops into efficient JAX code?
I'm not sure what happened here:
patch_reshaped_pn_c_h_w = patch_reshaped_ph_pw_c_h_w = jnp.reshape(patches, (V_PATCHES, H_PATCHES, IMG_CHANNELS, PATCH_HEIGHT, PATCH_WIDTH))
but I assume it's some kind of mistake.
Assuming bfrc has shape (batch, channels, height, width), and
V_PATCHES = IMG_HEIGHT // PATCH_HEIGHT
H_PATCHES = IMG_WIDTH // PATCH_WIDTH
then patch_reshaped_pn_c_h_w will have the shape (V_PATCHES, H_PATCHES, IMG_CHANNELS, PATCH_HEIGHT, PATCH_WIDTH).
Keeping this in mind, you can reconstruct the image just by transposing and reshaping, which is far cheaper than the nested loops.
V, H, C, PH, PW = patch_reshaped_ph_pw_c_h_w.shape  # here H is the horizontal patch count
H_total = V * PH  # full image height
W_total = H * PW  # full image width
patches = jnp.transpose(patch_reshaped_ph_pw_c_h_w, (0, 1, 3, 4, 2))  # (V, H, PH, PW, C)
reconstructed = patches.transpose(0, 2, 1, 3, 4)  # (V, PH, H, PW, C)
reconstructed = reconstructed.reshape(H_total, W_total, C)  # merge the patch grid into rows and columns
reconstructed = jnp.transpose(reconstructed, (2, 0, 1))[jnp.newaxis, ...]  # (1, C, H_total, W_total)
You can additionally put this in a function decorated with @jax.jit, which should be slightly faster.
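For reference, here is a minimal self-contained sketch of the same transpose-and-reshape idea, extended to keep the batch axis and wrapped in @jax.jit. The sizes and the helper names to_patches and from_patches are made up for illustration and are not part of the original code; passing precision=lax.Precision.HIGHEST is an assumption made so the round trip stays exact on accelerators.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Hypothetical sizes, chosen only for illustration.
BATCH, C, IMG_H, IMG_W = 2, 3, 8, 12
PH, PW = 4, 3  # patch height and width; they divide IMG_H and IMG_W exactly


def to_patches(x):
    # (N, C, H, W) -> (N, C*PH*PW, V, Hp), with the feature axis ordered as (C, PH, PW)
    p = lax.conv_general_dilated_patches(
        x, (PH, PW), (PH, PW), padding='VALID',
        precision=lax.Precision.HIGHEST,
    )
    n, _, v, h = p.shape
    p = jnp.transpose(p, (0, 2, 3, 1))     # (N, V, Hp, C*PH*PW)
    return p.reshape(n, v, h, C, PH, PW)   # (N, V, Hp, C, PH, PW)


@jax.jit
def from_patches(p):
    n, v, h, c, ph, pw = p.shape           # shapes are static under jit
    # Put each patch-row axis next to its grid axis: (N, C, V, PH, Hp, PW)
    p = jnp.transpose(p, (0, 3, 1, 4, 2, 5))
    # Merge (V, PH) into the image height and (Hp, PW) into the image width
    return p.reshape(n, c, v * ph, h * pw)


x = jnp.arange(BATCH * C * IMG_H * IMG_W, dtype=jnp.float32)
x = x.reshape(BATCH, C, IMG_H, IMG_W)
restored = from_patches(to_patches(x))
assert restored.shape == x.shape
assert bool(jnp.array_equal(restored, x))
```

A single transpose followed by a single reshape is enough here because each pair of axes to be merged, (V, PH) and (Hp, PW), is adjacent and in row-major order after the transpose.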