I am currently working on a classifier using PyWavelets; here is my calculation block:
import pywt
import torch
import torch.nn as nn

class WaveletLayer(nn.Module):
    def __init__(self):
        super(WaveletLayer, self).__init__()

    def forward(self, x):
        def wavelet_transform(img):
            # PyWavelets works on NumPy arrays, so the data is moved to the CPU here
            coeffs = pywt.dwt2(img.cpu().numpy(), "haar")
            LL, (LH, HL, HH) = coeffs
            return (
                torch.from_numpy(LL).to(img.device),
                torch.from_numpy(LH).to(img.device),
                torch.from_numpy(HL).to(img.device),
                torch.from_numpy(HH).to(img.device),
            )

        # Apply wavelet transform to each channel separately
        LL, LH, HL, HH = zip(
            *[wavelet_transform(x[:, i : i + 1]) for i in range(x.shape[1])]
        )

        # Concatenate the results
        LL = torch.cat(LL, dim=1)
        LH = torch.cat(LH, dim=1)
        HL = torch.cat(HL, dim=1)
        HH = torch.cat(HH, dim=1)
        return torch.cat([LL, LH, HL, HH], dim=1)
The output of this module goes into a ResNet block for learning. While doing this, I find my CPU clogged, which slows down my training process.
I am trying to run these calculations on the GPU instead.
Since you only seem to be interested in the Haar wavelet, you can pretty much implement it yourself. The following code achieves this in pure PyTorch:
import torch
import torch.nn as nn

class HaarWaveletLayer(nn.Module):

    def l_0(self, t):  # sum ("low") along cols
        t = torch.cat([t, t[..., -1:, :]], dim=-2) if t.shape[-2] % 2 else t
        return t[..., ::2, :] + t[..., 1::2, :]

    def l_1(self, t):  # sum ("low") along rows
        t = torch.cat([t, t[..., :, -1:]], dim=-1) if t.shape[-1] % 2 else t
        return t[..., :, ::2] + t[..., :, 1::2]

    def h_0(self, t):  # diff ("hi") along cols
        t = torch.cat([t, t[..., -1:, :]], dim=-2) if t.shape[-2] % 2 else t
        return t[..., ::2, :] - t[..., 1::2, :]

    def h_1(self, t):  # diff ("hi") along rows
        t = torch.cat([t, t[..., :, -1:]], dim=-1) if t.shape[-1] % 2 else t
        return t[..., :, ::2] - t[..., :, 1::2]

    def forward(self, x):
        x = 0.5 * x  # normalization, see the note below
        l_1 = self.l_1(x)
        h_1 = self.h_1(x)
        ll = self.l_0(l_1)  # low-pass along rows and cols
        lh = self.h_0(l_1)  # low-pass along rows, high-pass along cols
        hl = self.l_0(h_1)  # high-pass along rows, low-pass along cols
        hh = self.h_0(h_1)  # high-pass along rows and cols
        return torch.cat([ll, lh, hl, hh], dim=1)
In combination with your given code, you can convince yourself of the equivalence as follows:
t = torch.rand((7, 3, 127, 128)).to("cuda:0")
result_given = WaveletLayer()(t)
result_proposed = HaarWaveletLayer()(t)
# Same result?
assert (result_given - result_proposed).abs().max() < 1e-5
# Time comparison
from timeit import Timer
num_timings = 100
print("time given: ", Timer(lambda: WaveletLayer()(t)).timeit(num_timings))
print("time proposed:", Timer(lambda: HaarWaveletLayer()(t)).timeit(num_timings))
The timing shows a speedup of more than a factor of 10 on my machine.
The t = torch.cat... parts are only necessary if you want to be able to handle odd-shaped images: in that case, we pad by replicating the last row or column, respectively, mimicking the default padding of PyWavelets.
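To see what that padding does, here is a minimal sketch on a made-up 1x3 row (the values are arbitrary and only for illustration):

import torch

t = torch.tensor([[1., 2., 3.]])               # odd width
padded = torch.cat([t, t[..., -1:]], dim=-1)   # replicate last column -> [[1., 2., 3., 3.]]
low = padded[..., ::2] + padded[..., 1::2]     # pairwise sums -> [[3., 6.]]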
The multiplication with .5 is done for normalization. Compare this discussion on the Signal Processing Stack Exchange for more details.
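As a quick sanity check of that factor (a made-up 2x2 block, not part of the original code): PyWavelets' orthonormal Haar transform divides the four-pixel sum by 2, and pre-scaling the input by .5 before the two unnormalized sum passes yields the same LL value:

import numpy as np
import pywt

block = np.array([[1., 2.], [3., 4.]])
LL, (LH, HL, HH) = pywt.dwt2(block, "haar")
print(LL)                                        # [[5.]] == (1 + 2 + 3 + 4) / 2

x = 0.5 * block                                  # the layer's pre-scaling
ll = (x[0, 0] + x[0, 1]) + (x[1, 0] + x[1, 1])   # sum along rows, then along cols
print(ll)                                        # 5.0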