I am trying to create a version of the U-Net CNN that takes in one type of MRI image volume as the source and uses a corresponding MRI image volume as the target.
After quite a bit of trial and error I am still getting a small mismatch between the size of the CNN's output and the dimensions of the target. The CNN output is 208x224x160, but the source/target data are both 210x224x160. This causes a runtime error during the calculation of the loss. What's strange is that the dimension mismatch doesn't occur when I put in randomly generated data: there the output has the same dimensions as the input.
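To rule out my data loading and preprocessing, I also tried to reproduce the mismatch with just the pooling/upsampling pattern from my network applied to dummy tensors (a quick debugging sketch, not part of the training code below):

import torch
import torch.nn as nn

# Same spatial down/up pattern as my U-Net: two 2x2x2 max-pools followed by two
# stride-2 transposed convolutions, applied to dummy single-channel volumes.
# The other dimensions are kept tiny so this runs quickly on CPU.
down_up = nn.Sequential(
    nn.MaxPool3d(kernel_size=2, stride=2),
    nn.MaxPool3d(kernel_size=2, stride=2),
    nn.ConvTranspose3d(1, 1, kernel_size=2, stride=2),
    nn.ConvTranspose3d(1, 1, kernel_size=2, stride=2),
)
print(down_up(torch.randn(1, 1, 210, 8, 8)).shape)  # torch.Size([1, 1, 208, 8, 8])
print(down_up(torch.randn(1, 1, 64, 8, 8)).shape)   # torch.Size([1, 1, 64, 8, 8])

The 210-sized dimension comes back as 208 even in this stripped-down version, while a 64-sized dimension is unchanged, so the problem doesn't seem to come from the loading or preprocessing.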
What could be causing this error and how should I go about fixing it?
Here is the code:
import nibabel as nib
import os
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
# function using nibabel to load a single volume from the disk
def load_Volume(filepath):
    img = nib.load(filepath)
    data = img.get_fdata()
    return data

def preprocess_mri_data(data):
    # Normalize the data, other pre-processing can be added
    mean = np.mean(data)
    std = np.std(data)
    data = (data - mean) / std
    return data
# Dataset class to use with the data loader. Pairs sources with targets.
class MRISource_Target(Dataset):
    def __init__(self, source_dir, target_dir, transform=None):
        self.source_dir = source_dir
        self.target_dir = target_dir
        # Sort the listings so source and target files pair up deterministically
        self.source_filenames = sorted(os.listdir(source_dir))
        self.target_filenames = sorted(os.listdir(target_dir))
        self.transform = transform

    def __len__(self):
        return len(self.source_filenames)

    def __getitem__(self, idx):
        source_filepath = os.path.join(self.source_dir, self.source_filenames[idx])
        target_filepath = os.path.join(self.target_dir, self.target_filenames[idx])
        source_data = load_Volume(source_filepath)
        target_data = load_Volume(target_filepath)
        source_data = preprocess_mri_data(source_data)
        target_data = preprocess_mri_data(target_data)
        if self.transform:
            source_data = self.transform(source_data)
            target_data = self.transform(target_data)
        return {'source': source_data, 'target': target_data}
# directories for the training and testing data
train_source_dir = '/content/drive/MyDrive/qsmData/Train/Source'
train_target_dir = '/content/drive/MyDrive/qsmData/Train/Target/'
test_source_dir = '/content/drive/MyDrive/qsmData/Test/Source/'
test_target_dir = '/content/drive/MyDrive/qsmData/Test/Target/'
# create the paired datasets
train_dataset = MRISource_Target(train_source_dir, train_target_dir)
test_dataset = MRISource_Target(test_source_dir, test_target_dir)
# make the datasets iterable for training
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)
# visualize an arbitrary slice
def plot_mri_slice(volume, slice_num):
    plt.imshow(volume[:, :, slice_num], cmap='gray')
    plt.axis('off')
    plt.show()
import torch
import torch.nn as nn
# Define the U-Net architecture
class UNet(nn.Module):
    def __init__(self, input_channels, output_channels):
        super(UNet, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv3d(input_channels, 32, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv3d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool3d(kernel_size=2, stride=2)
        )
        self.middle = nn.Sequential(
            nn.Conv3d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv3d(128, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool3d(kernel_size=2, stride=2)
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose3d(128, 64, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
            nn.ConvTranspose3d(64, 32, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
            nn.Conv3d(32, output_channels, kernel_size=3, padding=1),
            # nn.Tanh()  # Assuming magnetic susceptibility values are in a specific range
        )

    def forward(self, x):
        x1 = self.encoder(x)
        x2 = self.middle(x1)
        x3 = self.decoder(x2)
        return x3
# Example usage:
batch_size = 1
input_channels = 1 # Number of input channels (MRI phase)
output_channels = 1 # Number of output channels (Magnetic susceptibility)
depth = 64 # Updated depth to match cropped data
height = 64
width = 64
# Create the U-Net model
generator = UNet(input_channels, output_channels)
# Example input data
input_data = torch.randn(batch_size, input_channels, depth, height, width)
# Generate output
output = generator(input_data)
# Print the generated output shape
print("Generated Output Shape:", output.shape)
import nibabel as nib
def get_data_dimensions(filepath):
    img = nib.load(filepath)
    data = img.get_fdata()
    return data.shape
source_filepath = '/content/drive/MyDrive/qsmData/Train/Source/normPhaseSubj1.nii'
target_filepath = '/content/drive/MyDrive/qsmData/Train/Target/cosmos1.nii.gz'
source_dimensions = get_data_dimensions(source_filepath)
target_dimensions = get_data_dimensions(target_filepath)
print("Source data dimensions:", source_dimensions)
print("Target data dimensions:", target_dimensions)
# Define the loss function and optimizer
criterion = nn.MSELoss()  # default reduction='mean'; the deprecated 'reduce' argument is not needed
optimizer = torch.optim.Adam(generator.parameters(), lr=0.001)
# Move the model to the device (CPU or GPU)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
generator.to(device)
num_epochs = 5
print_interval = 10
for epoch in range(num_epochs):
    generator.train()
    running_loss = 0.0
    for i, batch in enumerate(train_loader, 1):  # Enumerate to track batch index
        source_data = batch['source'].to(device).unsqueeze(1).float()  # Add the channel dimension
        target_data = batch['target'].to(device).unsqueeze(1).float()  # Add the channel dimension
        # Zero the parameter gradients
        optimizer.zero_grad()
        # Forward pass
        outputs = generator(source_data)
        print(outputs.shape)
        print("Target shape:", target_data.shape)
        # Compute loss
        loss = criterion(outputs, target_data)
        # Backpropagation and optimization
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        # Print average loss every print_interval batches
        if i % print_interval == 0:
            avg_loss = running_loss / print_interval
            print(f'Epoch [{epoch + 1}/{num_epochs}], Batch [{i}/{len(train_loader)}], Loss: {avg_loss:.4f}')
            running_loss = 0.0
predictions = []
generator.eval()  # Set the model to evaluation mode
with torch.no_grad():
    for batch in test_loader:
        source_patches = batch['source'].to(device).unsqueeze(1).float()  # Add the channel dimension
        # Forward pass and get the predictions
        outputs = generator(source_patches)
        # Store the predictions in the list
        predictions.append(outputs.cpu().squeeze().numpy())
I tried making a simpler architecture and still got dimension errors; in fact, they were even larger. When I wasn't getting dimension errors, I would get out-of-memory errors instead. I have also tried verifying the dimensions of the data at different stages of the program, and even though the randomly generated data has no mismatch between input and output, my MRI data still does once it is put through the network.
I was able to fix the dimension errors by applying max pooling after the middle layer and after the decoder.
I am not sure why this works, but the output and input sizes are now consistent (I have been verifying this with per-stage shape checks like the one below).
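For reference, this is the kind of per-stage shape check I have been running after each architecture change (shown here against the code posted above; the in-plane size is shrunk so it runs quickly):

# Run each block of the network separately on a small dummy volume that keeps the
# odd 210-sized dimension, to see at which stage the spatial size changes.
with torch.no_grad():
    dummy = torch.randn(1, 1, 210, 16, 16).to(device)
    enc = generator.encoder(dummy)
    mid = generator.middle(enc)
    dec = generator.decoder(mid)
    print("input  :", dummy.shape)
    print("encoder:", enc.shape)
    print("middle :", mid.shape)
    print("decoder:", dec.shape)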
The prediction results look bad right now, but I'm pretty sure that's because I set the number of channels down to between 2 and 8 and only trained for one epoch.
It will be interesting to see how this architecture works once I apply normal hyperparameters. I'm just glad that there are no more runtime errors or out-of-memory issues.