I'm running training code using PyTorch and NumPy.
This is the plot_example function:
def plot_example(low_res_folder, gen):
    files = os.listdir(low_res_folder)
    gen.eval()
    for file in files:
        image = Image.open("test_images/" + file)
        with torch.no_grad():
            upscaled_img = gen(
                config1.both_transform(image=np.asarray(image))["image"]
                .unsqueeze(0)
                .to(config1.DEVICE)
            )
        save_image(upscaled_img * 0.5 + 0.5, f"saved/{file}")
    gen.train()
The problem is that the .unsqueeze(0) call raises this error:
  File "E:\Downloads\esrgan-tf2-masteren\modules\train1.py", line 58, in train_fn
    plot_example("test_images/", gen)
  File "E:\Downloads\esrgan-tf2-masteren\modules\utils1.py", line 46, in plot_example
    config1.both_transform(image=np.asarray(image))["image"]
AttributeError: 'numpy.ndarray' object has no attribute 'unsqueeze'
The network is a GAN, and gen() is the generator.
Make sure image is a tensor of shape [batch size, channels, height, width] before it enters any PyTorch layer.
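To illustrate that layout, here is a minimal sketch, assuming the image arrives channels-last (height x width x channels), which is how np.asarray lays out an RGB PIL image. The random tensor and the small nn.Conv2d are stand-ins for illustration, not the question's actual image or generator:

```python
import torch
import torch.nn as nn

# Stand-in for an RGB image in channels-last (H x W x C) layout
image_hwc = torch.rand(24, 32, 3)

# Reorder to C x H x W, then add a batch dimension of size 1 -> 1 x C x H x W
batch = image_hwc.permute(2, 0, 1).unsqueeze(0)
print(batch.shape)  # torch.Size([1, 3, 24, 32])

# A Conv2d layer with 3 input channels accepts this [batch, channels, height, width] layout
conv = nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3, padding=1)
with torch.no_grad():
    out = conv(batch)
print(out.shape)  # torch.Size([1, 8, 24, 32])
```

unsqueeze(0) alone would only give 1 x H x W x C, so the permute is what moves the channels into the position a convolution expects.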
Here you have
image=np.asarray(image)
I would remove this NumPy conversion and keep the image as a torch.Tensor.
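The reason this matters: unsqueeze is a torch.Tensor method, and judging by the traceback, config1.both_transform(...)["image"] hands back a NumPy array, which has no such method. A NumPy-only reproduction, with an arbitrary 4 x 6 x 3 zeros array standing in for the transformed image:

```python
import numpy as np

# Stand-in for the transformed image: H x W x C, as np.asarray gives for an RGB PIL image
arr = np.zeros((4, 6, 3), dtype=np.uint8)

# ndarray has no .unsqueeze method; calling it reproduces the error from the question
try:
    arr.unsqueeze(0)
except AttributeError as err:
    print(err)  # 'numpy.ndarray' object has no attribute 'unsqueeze'

# np.expand_dims is the NumPy counterpart for adding a leading axis
batched = np.expand_dims(arr, 0)
print(batched.shape)  # (1, 4, 6, 3)
```

np.expand_dims fixes the shape but not the type: the generator still needs a torch.Tensor, so the array has to be converted either way.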
Or, if you really want it to be a NumPy array, call torch.from_numpy() on it right before it enters your generator, before it gets unsqueezed, as described in the documentation: https://pytorch.org/docs/stable/generated/torch.from_numpy.html
This function is, of course, only an alternative if you don't want to get rid of the original conversion.
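A sketch of that from_numpy route, assuming the transform returns an H x W x C array. The helper name numpy_image_to_batch is hypothetical, and the zeros array is a stand-in for the real transformed image. It also shows two documented properties of torch.from_numpy that are easy to trip over: the tensor shares memory with the array, and it keeps the array's dtype (often uint8 for images), which a float32 generator will not accept:

```python
import numpy as np
import torch

def numpy_image_to_batch(image_hwc: np.ndarray, device: str = "cpu") -> torch.Tensor:
    """Hypothetical helper: H x W x C NumPy image -> 1 x C x H x W float32 tensor."""
    # torch.from_numpy rejects arrays with negative strides, so make the array contiguous first
    tensor = torch.from_numpy(np.ascontiguousarray(image_hwc))
    # H x W x C -> C x H x W -> 1 x C x H x W
    tensor = tensor.permute(2, 0, 1).unsqueeze(0)
    # from_numpy keeps the NumPy dtype (e.g. uint8); .float() converts to float32
    return tensor.float().to(device)

arr = np.zeros((24, 32, 3), dtype=np.uint8)  # stand-in for the transformed image

# from_numpy does not copy: writing to the array is visible through the tensor
shared = torch.from_numpy(arr)
arr[0, 0, 0] = 255
print(shared[0, 0, 0].item())  # 255

batch = numpy_image_to_batch(arr)
print(batch.shape, batch.dtype)  # torch.Size([1, 3, 24, 32]) torch.float32
```

In the question's loop, this would replace the .unsqueeze(0).to(config1.DEVICE) chain, e.g. gen(numpy_image_to_batch(config1.both_transform(image=np.asarray(image))["image"], config1.DEVICE)). Whether the values also need rescaling depends on how the training images were normalized, which the question does not show.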
To get a clearer sense of how torch/NumPy conversions are done in practice, try looking at deep learning repos (especially their data classes). Websites such as repo-rift.com can be useful for this: they let you run text searches with queries like "Show how an opencv image loaded in numpy is converted to pytorch tensor", which can help you pinpoint how other coders do it.
Sarthak Jain