I wrote code to segment a satellite image into seven regions (city, forest, water, ...). The problem is that when I execute the script I get exactly the following error:
Traceback (most recent call last):
  File "/Users/.../pytorch_test.py", line 215, in <module>
    model = train_model(model, train_dataloader, val_dataloader, loss_fn, optimizer, device, num_epochs)
  File "/Users/.../pytorch_test.py", line 179, in train_model
    train_loss = train(model, train_dataloader, loss_fn, optimizer, device)
  File "/Users/.../pytorch_test.py", line 134, in train
    loss = loss_fn(outputs, labels)
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/torch/nn/modules/loss.py", line 1174, in forward
    return F.cross_entropy(input, target, weight=self.weight,
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/torch/nn/functional.py", line 3029, in cross_entropy
    return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing)
ValueError: Expected input batch_size (3) to match target batch_size (9).
No matter which batch_size I choose, the target batch size is always triple the input batch size, and I cannot find the bug. I feel like I have searched the whole internet but found nothing. I hope some of you can help me!
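The mismatch can be reproduced in isolation with random tensors using the shapes from the error message (just a sketch, not my real pipeline):

import torch
import torch.nn.functional as F

# Shapes taken from the error message: input batch 3, target batch 9.
outputs = torch.randn(3, 7, 512, 512)        # model output: (batch, classes, height, width)
labels = torch.randint(0, 7, (9, 512, 512))  # target: batch 9 instead of 3
loss = F.cross_entropy(outputs, labels)
# ValueError: Expected input batch_size (3) to match target batch_size (9).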
That's my code:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image, ImageEnhance
import os
class SatelliteDataset(Dataset):
    def __init__(self, image_folder, label_folder, transform=None):
        self.image_folder = image_folder
        self.label_folder = label_folder
        self.transform = transform
        self.image_paths = sorted([os.path.join(image_folder, filename) for filename in os.listdir(image_folder)])
        self.label_paths = sorted([os.path.join(label_folder, filename) for filename in os.listdir(label_folder)])

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        label_path = self.label_paths[idx]
        image = Image.open(image_path)
        label = Image.open(label_path)
        image = self.adjust_brightness(image, brightness_factor=1.8)
        # Resize images to 512x512
        image = image.resize((512, 512), Image.BILINEAR)
        label = label.resize((512, 512), Image.NEAREST)
        # Apply transformations if specified
        if self.transform:
            image = self.transform(image)
            label = self.transform(label)
        return image, label

    def adjust_brightness(self, image, brightness_factor=1.0):
        enhancer = ImageEnhance.Brightness(image)
        enhanced_image = enhancer.enhance(brightness_factor)
        return enhanced_image
# Define paths to the folders containing satellite images and corresponding labels
image_folder = "/Users/.../train_data/images"
label_folder = "/Users/.../train_data/masks"
# Define transformations for normalization and scaling
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])
# Create an instance of the SatelliteDataset
dataset = SatelliteDataset(image_folder, label_folder, transform=transform)
# Create DataLoader for training
def collate_fn(batch):
    images, labels = zip(*batch)
    images = torch.stack(images)
    labels = torch.stack(labels)
    return images, labels
batch_size = 3
train_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
# cnn
class CNN(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=3, stride=1, padding=1)
        self.relu1 = nn.ReLU()
        self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.relu2 = nn.ReLU()
        self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
        self.relu3 = nn.ReLU()
        self.maxpool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv4 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1)
        self.relu4 = nn.ReLU()
        self.conv5 = nn.Conv2d(512, out_channels=7, kernel_size=1, stride=1)

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.maxpool1(x)
        x = self.conv2(x)
        x = self.relu2(x)
        x = self.maxpool2(x)
        x = self.conv3(x)
        x = self.relu3(x)
        x = self.maxpool3(x)
        x = self.conv4(x)
        x = self.relu4(x)
        x = self.conv5(x)
        x = torch.softmax(x, dim=1)  # add Softmax layer
        return x
# Training loop
def train(model, dataloader, loss_fn, optimizer, device):
    model.train()  # Set the model to training mode
    running_loss = 0.0
    for images, labels in dataloader:
        images = images.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        # Forward pass
        outputs = model(images)
        # Reshape labels dimensions
        labels = labels.view(-1, 512, 512)
        labels = labels.long()
        # Calculate loss
        loss = loss_fn(outputs, labels)
        # Backward pass and weight update
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * images.size(0)
    epoch_loss = running_loss / len(dataloader.dataset)
    return epoch_loss
# Validation
def evaluate(model, dataloader, loss_fn, device):
    model.eval()  # Set the model to evaluation mode
    running_loss = 0.0
    with torch.no_grad():
        for images, labels in dataloader:
            images = images.to(device)
            labels = labels.to(device)
            # Forward pass
            outputs = model(images)
            # Calculate loss
            loss = loss_fn(outputs, labels)
            running_loss += loss.item() * images.size(0)
    epoch_loss = running_loss / len(dataloader.dataset)
    return epoch_loss
# Perform training
def train_model(model, train_dataloader, val_dataloader, loss_fn, optimizer, device, num_epochs):
    best_val_loss = float('inf')
    best_model_weights = None
    for epoch in range(num_epochs):
        print(f"Epoch {epoch+1}/{num_epochs}")
        # Training step
        train_loss = train(model, train_dataloader, loss_fn, optimizer, device)
        print(f"Train Loss: {train_loss}")
        # Validation step
        val_loss = evaluate(model, val_dataloader, loss_fn, device)
        print(f"Val Loss: {val_loss}")
        # Check for improvement in validation loss
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_model_weights = model.state_dict()
    # Return the best model
    model.load_state_dict(best_model_weights)
    return model
# Example usage of the model
model = CNN(in_channels=3, out_channels=7)
# Select device (CPU)
device = torch.device("cpu")
# Loss function
loss_fn = nn.CrossEntropyLoss()
# Optimizer
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
# Number of epochs (adjust as needed)
num_epochs = 10
# Create a separate DataLoader for validation with batch_size=3
val_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
# Perform training
model = train_model(model, train_dataloader, val_dataloader, loss_fn, optimizer, device, num_epochs)
By the way, I am programming on macOS with VS Code.
What I tried: I copied all the error messages into Google and read almost every article on this problem, but I still could not resolve it. I tried changing the batch sizes (I have 804 pictures to train the model with) and changed the data loader. I also asked ChatGPT; it explained the problem very well and gave a couple of ideas, but none of them helped.
You have the line

labels = labels.view(-1, 512, 512)

that changes the dimensions of your labels. From your dataloader, it seems that your labels are images, and they go through the same transform as the inputs, so ToTensor turns an RGB label into a (3, 512, 512) tensor and a batch of 3 labels has shape (3, 3, 512, 512). view(-1, 512, 512) then merges the batch and channel dimensions into one, producing a target of shape (9, 512, 512). That is exactly why the target batch size is always triple your input batch size.
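Here is a minimal sketch of a fix, assuming each mask stores one class index (0-6) per pixel that is recoverable from a single channel; if your masks encode classes as RGB colors instead, you would need a color-to-index lookup table. The idea is to apply ToTensor/Normalize to the image only and build the label as an integer class-index tensor, so the stacked batch already has shape (batch, 512, 512) and the view(-1, 512, 512) line can simply be deleted:

import numpy as np
import torch
from PIL import Image

def mask_to_tensor(label_path):
    # Load the mask as a single channel of class indices instead of
    # passing it through ToTensor/Normalize like the input image.
    label = Image.open(label_path).convert("L")      # assumes one class index per pixel
    label = label.resize((512, 512), Image.NEAREST)  # NEAREST keeps indices intact
    return torch.from_numpy(np.array(label)).long()  # shape (512, 512), dtype int64

# In SatelliteDataset.__getitem__, transform only the image:
#     if self.transform:
#         image = self.transform(image)
#     label = mask_to_tensor(label_path)
# and remove `labels = labels.view(-1, 512, 512)` from train().

Two related points: nn.CrossEntropyLoss expects raw logits and applies log-softmax internally, so the torch.softmax call at the end of forward() should be removed; and after three MaxPool2d layers your output is only 64x64, so you will also need to upsample the output (or downsample the labels) before the spatial dimensions match the 512x512 targets.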