python machine-learning deep-learning pytorch semantic-segmentation

Multiclass UNet with n-dimensional satellite images


I'm trying to use a UNet in PyTorch to extract prediction masks from multiband (8-band) satellite images, but the predicted masks don't look coherent or anything like what I expect. I'm not sure whether the issue is the way my training data is formatted, my training code, or the code I use to make predictions. My suspicion is that it's the way the training data is fed to the model. I have 8-band satellite images and single-band masks whose values range from 0 to n (the number of classes), with 0 being background and 1 to n being target labels.

The image shape is (8, 512, 512). The mask shape is (512, 512) in the single-channel case, (512, 512, 8) in the one-hot encoded (OHE) case, and (512, 512, 3) in the stacked case.

Some masks contain every class label; others contain only a couple, or only background. I've tried using these single-channel masks directly. I've also converted them into 3-channel masks, with the first channel holding all the labels for a given image, and I've tried one-hot encoding them so that each mask has one channel per label, with binary 0/1 values for background/target.

EDIT: After changing the softmax dimension (dim=2 in the code below), the outputs started looking a little better. However, the model does not seem to learn after the first few warm-up epochs: the training loss decreases initially, then immediately plateaus or increases, and the prediction masks stop making sense (either all black or random blobs). I suspect an issue with my training pipeline (below), or possibly the class imbalance caused by class 0 (background).

import os
import torch
import numpy as np
from skimage import io
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim
import segmentation_models_pytorch as smp

image_dir = r'test_segmentation\images'
mask_dir = r'test_segmentation\masks'

data_dir=r'unet_training'
os.makedirs(data_dir, exist_ok=True)

model_dir = os.path.join(data_dir, 'models')
os.makedirs(model_dir, exist_ok=True)

pred_dir = os.path.join(data_dir, 'predictions')
os.makedirs(pred_dir, exist_ok=True)

num_bands = 8
num_classes = 9
epochs = 10
learning_rate = 0.001
weight_decay = 0
encoder = 'resnet50'
encoder_weights = 'imagenet'

# Define the device before the model is moved to it
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = smp.Unet(in_channels=num_bands, encoder_name=encoder, encoder_weights=encoder_weights, classes=num_classes).to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
loss_function = nn.CrossEntropyLoss() if num_classes > 1 else nn.BCEWithLogitsLoss()

for epoch in range(1, epochs + 1):
    train_loss = 0
    val_loss = 0 

    train_loop = tqdm(enumerate(train_loader), total=len(train_loader), desc=f"Epoch {epoch} Training")
    
    model.train()
    
    for batch_idx, (data, targets) in train_loop:
        optimizer.zero_grad()
        
        data = data.float().to(device)
        targets = targets.long().to(device)
        predictions = model(data)
        loss = loss_function(predictions, targets)
        
        train_loss += loss.item()

        loss.backward()
        optimizer.step()
    
        train_loop.set_postfix(loss=train_loss)

    val_loop = tqdm(enumerate(val_loader), total=len(val_loader), desc=f"Epoch {epoch} Validation")

    model.eval()
    
    for batch_idx, (data, targets) in val_loop:
        data, targets = data.to(device).float(), targets.to(device).long()
        preds = model(data)

        # accumulate so the per-epoch average computed after the loop is meaningful
        val_loss += loss_function(preds, targets).item()

        softmax = torch.nn.Softmax(dim=2)
        preds = torch.argmax(softmax(preds), dim=1).cpu().numpy()
        preds = np.array(preds[0, :, :], dtype=np.uint8)
        labels = np.array(targets.cpu().numpy()[0, :, :], dtype=np.uint8)

        #save prediction and label mask
        pred_path = os.path.join(pred_dir, f"{epoch}_{batch_idx}_pred.png")
        label_path = os.path.join(pred_dir, f"{epoch}_{batch_idx}_label.png")
        io.imsave(pred_path, preds)
        io.imsave(label_path, labels)

        val_loop.set_postfix(loss=val_loss)
    
    # average each loss over its own loader (batch_idx here belongs to the validation loop)
    avg_train_loss = train_loss / len(train_loader)
    avg_val_loss = val_loss / len(val_loader)

    print(f"\nEpoch {epoch} Train Loss: {avg_train_loss}, Val Loss: {avg_val_loss}")

    checkpoint_name = os.path.join(model_dir, f"{modeltype}_bands{num_bands}_classes{num_classes}_{encoder}_{learning_rate}_{epoch}.pt")
    
    # save a checkpoint after the first epoch, every 10th epoch, and the final epoch
    if epoch == 1 or epoch % 10 == 0 or epoch == epochs:
        torch.save(model.state_dict(), checkpoint_name)

Solution

  • Change to softmax = torch.nn.Softmax(dim=1) so that the softmax is taken over the channel (class) dimension, dim=1.
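
To illustrate why the dimension matters, here is a minimal sketch that uses random logits in place of the model output; the tensor sizes are made up for the example. For an output shaped (batch, classes, height, width), only dim=1 gives per-pixel class probabilities. Because softmax preserves the ordering of values along the dimension it is applied to, taking argmax over dim=1 of the raw logits should give the same mask, so the softmax can probably be dropped when only the mask is needed.

```python
import torch

torch.manual_seed(0)

# Dummy logits shaped like a segmentation model output:
# (batch, classes, height, width). Sizes are illustrative only.
logits = torch.randn(2, 9, 16, 16)

probs_channels = torch.softmax(logits, dim=1)  # normalizes across the 9 classes
probs_rows = torch.softmax(logits, dim=2)      # normalizes across image rows instead

# Summing over the class dimension gives 1 per pixel only for dim=1.
channel_sums = probs_channels.sum(dim=1)
row_case_sums = probs_rows.sum(dim=1)

# Class mask: argmax over the class dimension, with or without softmax.
mask_from_probs = probs_channels.argmax(dim=1)
mask_from_logits = logits.argmax(dim=1)

print(channel_sums.min().item(), channel_sums.max().item())
print(row_case_sums.min().item(), row_case_sums.max().item())
print(mask_from_probs.shape)
```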

    In the training loop that starts for batch_idx, (data, targets) in train_loop:, check the following:


    Original reply:

    Does the train loss go down at all? Useful to print it out at each epoch.

    Try fitting just a single image or one small batch, and keep training on it until the loss goes down further and further. A sensible output should start emerging, like blobs roughly in the right place. If it doesn't, the pipeline is probably broken somewhere, since the model is failing to learn at all.
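
The single-batch check could be sketched as follows. The small convolutional model is a hypothetical stand-in for the UNet so that the example runs on its own, and the synthetic batch (with a mask derived from the image so that it is learnable) is an assumption; in practice you would take one batch from your own train_loader and use your own model.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
num_bands, num_classes = 8, 9

# Hypothetical stand-in for the UNet (swap in the real model here).
model = nn.Sequential(
    nn.Conv2d(num_bands, 16, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(16, num_classes, kernel_size=1),
)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
loss_function = nn.CrossEntropyLoss()

# One fixed synthetic batch: images are (N, C, H, W) floats and masks are
# (N, H, W) integer class indices, the layout CrossEntropyLoss expects.
data = torch.randn(2, num_bands, 32, 32)
targets = data.argmax(dim=1)  # values 0..7, all valid for 9 classes

losses = []
model.train()
for step in range(100):
    optimizer.zero_grad()
    loss = loss_function(model(data), targets)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())
    if step % 20 == 0:
        print(f"step {step}: loss {loss.item():.4f}")
```

If the loss on this one batch does not keep falling, the data formatting or the training step is the likely culprit rather than class imbalance or model capacity.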

    It might also be worth working with down-sampled images initially and limiting the mask to a single channel. Both will help the net converge more quickly and make convergence issues easier to spot.
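
One way to down-sample an image and its single-channel mask together is sketched below. The helper name downsample_pair and the 128x128 target size are assumptions for illustration. The image is resized with bilinear interpolation, while the mask uses nearest-neighbour interpolation so that the resized mask still contains only valid integer class labels.

```python
import torch
import torch.nn.functional as F


def downsample_pair(image, mask, size=(128, 128)):
    """Resize a (C, H, W) float image and an (H, W) integer mask to `size`."""
    # interpolate expects a batch dimension, so add one and remove it afterwards
    image_small = F.interpolate(
        image.unsqueeze(0), size=size, mode="bilinear", align_corners=False
    ).squeeze(0)
    # nearest-neighbour avoids blending labels into meaningless in-between values
    mask_small = F.interpolate(
        mask[None, None].float(), size=size, mode="nearest"
    )[0, 0].long()
    return image_small, mask_small


# Synthetic example with the shapes from the question: 8 bands, 512x512.
torch.manual_seed(0)
image = torch.rand(8, 512, 512)
mask = torch.randint(0, 9, (512, 512))

image_small, mask_small = downsample_pair(image, mask)
print(image_small.shape, mask_small.shape, mask_small.dtype)
```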