image  pytorch  neural-network  dimensions

Dimensions of an image's red channel as input to an FC network


As part of a diffusion model I built this fully connected network:

import torch
from torch import linalg as LA
import plotly.express as px
from plotly.subplots import make_subplots
import plotly.graph_objects as go
from torch.utils.data import TensorDataset, DataLoader, Dataset
import torch.optim as optim
import torch.nn as nn
import matplotlib.pyplot as plt
from tqdm import tqdm
lr = 0.0001
num_data = 3000
batch_size = 3000
num_epochs = 3000
x_dim = 2  # dimension of input
beta1 = 0.5  # Beta1 hyperparameter for Adam optimizers
l2 = nn.MSELoss()
colors = px.colors.qualitative.T10
MODEL_PATH = 'denoiser.pth'
CONDITIONAL_MODEL_PATH = 'denoiser_conditional.pth'
# the denoiser
class NetD(nn.Module):
    def __init__(self):
        super(NetD, self).__init__()
        # input: [x0, x1, t]
        self.fc = nn.Sequential(
            nn.Linear(x_dim + 1, (x_dim + 1) * 8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear((x_dim + 1) * 8, (x_dim + 1) * 20),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear((x_dim + 1) * 20, x_dim),
            # output: [e0, e1]
        )

    def forward(self, x_t):
        return self.fc(x_t)
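
For context, here is a minimal sketch of how this denoiser gets called (the names x_batch and t are just illustrative, not from my training code): each 2-D sample is concatenated with its timestep, so the network sees a [batch, 3] input and returns a [batch, 2] noise estimate.

netD = NetD()
x_batch = torch.randn(batch_size, x_dim)  # noisy 2-D samples, shape [3000, 2]
t = torch.rand(batch_size, 1)             # one timestep per sample, shape [3000, 1]
x_t = torch.cat([x_batch, t], dim=1)      # rows of [x0, x1, t], shape [3000, 3]
eps_pred = netD(x_t)                      # predicted noise [e0, e1], shape [3000, 2]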

First I trained it on uniform data that looks like this:

[image: the uniform 2-D training data]

Its shape was [3000, 2].

Now I want to train the model on the red channel of an image of size (50, 50):

import torchvision.transforms as transforms
import torchvision.models as models
from PIL import Image  # needed for Image.open below

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# desired size of the output image
imsize = (50, 50)  # use small size if no GPU

loader = transforms.Compose([
    transforms.Resize(imsize),  # scale imported image
    transforms.ToTensor()])     # transform it into a torch tensor

def image_loader(image_name):
    image = Image.open(image_name)
    image = loader(image).unsqueeze(0)  # add a batch dimension -> [1, 3, 50, 50]
    return image.to(device, torch.float)

cat = image_loader('/content/orange_cat.jpg')
red_channel = cat[:, 0, :, :]  # extract the red channel (R) of the image -> [1, 50, 50]
C = 200  # threshold value for the red channel on the 0-255 scale (0.78 ≈ 200/255 after ToTensor scaling)
density = torch.where(red_channel > 0.78, torch.tensor(1.0), torch.tensor(0.0))
data_1 = red_channel.squeeze().to(device)  # drop the batch dimension -> [50, 50]

As I understand it, in order to pass the red channel of an image into the FC network (netD), the dimensionality of the data should also be 2 (x, y). But I'm not sure how to transform the [50, 50] tensor correctly. I tried reshaping it like this:

# Reshape the tensor to a 2-dimensional tensor
reshaped_red_channel = data_1.view(-1, 2)

It created a tensor of size [1250, 2], and I'm not sure whether that's correct. As I understand it, the total number of elements in the original tensor must equal the total number of elements in the reshaped tensor, and [1250, 2] does account for the 2500 elements of the original [50, 50] tensor.
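
As a quick sanity check on the shapes (a minimal sketch using the tensors defined above):

print(data_1.shape)                   # torch.Size([50, 50])
print(reshaped_red_channel.shape)     # torch.Size([1250, 2])
print(data_1.numel())                 # 2500
print(reshaped_red_channel.numel())   # 2500, so no elements are lost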


Solution

  • From the PyTorch docs, nn.Linear expects an input of size [*, H], where * means there can be any number of leading dimensions and the last dimension H is the number of input features.

    In your initial training the input is [3000, 2], which meets this criterion.

    There are many ways you could view a [1, 50, 50] tensor (the red channel), and none will throw an error when passed to the linear layer, provided the resulting shape is [*, 2] and the product of the dimensions in * equals 1250 * batch_size. So view(-1, 2) accomplishes this and yields a result of size [1250 * batch_size, 2]. It is "correct" in a syntactic sense, though perhaps not in a semantic sense.

    It is worth considering that each possible way of viewing an image preserves a different part of the structure of the data. Since you are dealing with an image, the original data is highly spatially structured. By viewing it as [1250 * batch_size, 2], the only structure the linear layer can exploit is the correlation between pairs of pixel values that are consecutive along the last (probably column) dimension of the image (see the short sketch below). Without knowing more about the problem you intend to solve, I can't say whether this is a problem, but at first glance it seems odd to process highly structured image data this way, because it throws away almost all of the spatial information in the image; very little can be said about an image's content from pixel intensities alone.
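
    To make both points concrete, here is a small self-contained sketch (toy shapes chosen purely for illustration): nn.Linear(2, ...) accepts any leading shape as long as the last dimension is 2, and view(-1, 2) simply pairs up values that are adjacent along the last axis.

    import torch
    import torch.nn as nn

    fc = nn.Linear(2, 4)
    print(fc(torch.randn(1250, 2)).shape)       # torch.Size([1250, 4])
    print(fc(torch.randn(3, 50, 25, 2)).shape)  # torch.Size([3, 50, 25, 4]) -- any [*, 2] input works

    # view(-1, 2) pairs pixels that are consecutive along the last (column) dimension
    img = torch.arange(16).view(4, 4)  # a tiny 4x4 "image"
    print(img.view(-1, 2))             # rows are [0, 1], [2, 3], ..., [14, 15]: neighbours within an image row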