python-3.x, pytorch, networkx, unsupervised-learning, pytorch-geometric

Troubles in unsupervised domain adaptation with GCN


I am trying to implement an unsupervised domain adaptation network following the paper GCAN: Graph Convolutional Adversarial Network for Unsupervised Domain Adaptation, presented at CVPR 2019. I have some trouble understanding parts of the paper, starting from the figure that illustrates the structure of the model. First, I cannot tell whether the input of the model is a single image or multiple images, since there is a domain classification network that should classify the domain an image comes from, but at the same time there is a part in which the alignment of the class centroids is evaluated. Moreover, there is no indication of how to compute the class centroids themselves, and since I am not an expert in this matter, I wonder how they can be computed and optimized using the loss function given in the paper. The last thing I'm wondering about is an error I get in the code (I am using PyTorch to implement the solution). This is the code I wrote for the model:

class GCAN(nn.Module):

  def __init__(self, num_classes, gcn_in_channels=256, gcn_out_channels=150):

    super(GCAN, self).__init__()
    self.cnn = resnet50(pretrained=True)
    resnet_features = self.cnn.fc.in_features
    combined_features = resnet_features + gcn_out_channels
    self.cnn = nn.Sequential(*list(self.cnn.children())[:-1])
    self.dsa = alexnet(pretrained=True)
    self.gcn = geometric_nn.GCNConv(in_channels=gcn_in_channels, 
                                  out_channels=gcn_out_channels)
    self.domain_alignment = nn.Sequential(
      nn.Linear(in_features=combined_features, 
                out_features=1024),
      nn.ReLU(),
      nn.Linear(in_features=1024, out_features=1024),
      nn.ReLU(),
      nn.Linear(in_features=1024, out_features=1),
      nn.Sigmoid()
    )
    self.classifier = nn.Sequential(
      nn.Linear(in_features=combined_features, out_features=1024),
      nn.Dropout(p=0.2),
      nn.ReLU(),
      nn.Linear(in_features=1024, out_features=1024),
      nn.Dropout(p=0.2),
      nn.ReLU(),
      nn.Linear(in_features=1024, out_features=num_classes),
      nn.Softmax()
    )


  def forward(self, xs):
    resnet_features = self.cnn(xs)
    scores = self.dsa(xs)
    scores = scores.cpu().detach().numpy()
    adjacency_matrix = np.matmul(scores, np.transpose(scores))
    graph = nx.from_numpy_matrix(adjacency_matrix) # networkx
    gcn_features = self.gcn(graph)
    concat_features = torch.cat((resnet_features, gcn_features))

    domain_classification = self.domain_alignment(concat_features)
    pseudo_label = self.classifier(concat_features)

    return domain_classification, pseudo_label

when I try to plot the summary I get the following error:

forward() missing 1 required positional argument: 'edge_index'

But looking at the documentation of the GCN convolution (which is the part that raises the error), I have passed both in_channels and out_channels to the layer. What am I missing in this case?


Solution

  • So I managed to implement this work. It's very likely not the best implementation you've ever seen, but here it is in case you need it.
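    Regarding the error in the question: GCNConv takes in_channels and out_channels in its constructor, but its forward method additionally requires the node feature matrix and the graph connectivity, so the layer has to be called as conv(x, edge_index) rather than on a networkx graph. A minimal sketch of the call (the tensors here are made-up placeholders):

    import torch
    from torch_geometric.nn import GCNConv

    conv = GCNConv(in_channels=256, out_channels=150)

    x = torch.randn(10, 256)                # 10 nodes with 256 features each
    edge_index = torch.tensor([[0, 1, 2],   # source nodes
                               [1, 2, 0]])  # target nodes, shape [2, num_edges]

    out = conv(x, edge_index)               # forward() needs both arguments
    print(out.shape)                        # torch.Size([10, 150])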

    class GCANDataset(Dataset):

        def __init__(self, source_path, target_path, class_list, train_portion=0.8, transform=None):

            self.transform = transform
            self.file_counter = 0
            class_dict = dict()

            source_images = list()
            source_labels = list()
            target_images = list()
            target_labels = list()
            source_length = 0

            # collecting the source image paths (one folder per class)
            for folder in os.listdir(source_path):
                full_folder = os.path.join(source_path, folder)
                for img in os.listdir(full_folder):
                    full_path = os.path.join(full_folder, img)
                    self.file_counter += 1
                    source_length += 1
                    source_images.append(full_path)
                    source_labels.append(folder)

            # collecting the target image paths
            for folder in os.listdir(target_path):
                full_folder = os.path.join(target_path, folder)
                for img in os.listdir(full_folder):
                    full_path = os.path.join(full_folder, img)
                    self.file_counter += 1
                    target_images.append(full_path)
                    target_labels.append(folder)

            # mapping class names to integer indices
            for i in range(len(class_list)):
                class_dict[class_list[i]] = i

            self.number_of_train = int(train_portion * self.file_counter)
            number_of_test = self.file_counter - self.number_of_train

            indexes = list(range(len(target_images)))

            # the test split is drawn from the target domain only
            target_test_split = random.sample(indexes, number_of_test)
            target_train_split = [item for item in indexes if item not in target_test_split]
            self.train_images = source_images + [target_images[i] for i in target_train_split]
            self.train_labels = source_labels + [target_labels[i] for i in target_train_split]
            self.train_domain = [0.] * source_length + [1.] * len(target_train_split)
            self.test_images = [target_images[i] for i in target_test_split]
            self.test_labels = [target_labels[i] for i in target_test_split]

            self.train_labels = [torch.Tensor([float(class_dict[label])]).to(torch.long) for label in self.train_labels]
            self.test_labels = [torch.Tensor([float(class_dict[label])]).to(torch.long) for label in self.test_labels]

        def __getitem__(self, index):

            if index < self.number_of_train:
                image = Image.open(self.train_images[index])
                label = self.train_labels[index][0]
                domain = self.train_domain[index]
            else:
                index -= self.number_of_train
                image = Image.open(self.test_images[index])
                label = self.test_labels[index][0]
                domain = 1.  # test samples always come from the target domain

            if self.transform is not None:
                image = self.transform(image)

            return image, label, torch.Tensor([domain])

        def __len__(self):
            return self.file_counter


    Basically the dataset class implements the dataset and returns an image together with its label and domain (treated as 0 for source and 1 for target). The train/test split is already done inside the dataset itself, and the test set is drawn entirely from the target domain.
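
    For reference, this is how the dataset might be instantiated; the paths, class list and image size below are made-up placeholders:

    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])

    # hypothetical source/target folders, one subfolder per class
    dataset = GCANDataset(source_path='data/source',
                          target_path='data/target',
                          class_list=['class_a', 'class_b', 'class_c'],
                          train_portion=0.8,
                          transform=transform)

    image, label, domain = dataset[0]  # domain is 0. (source) or 1. (target)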

    class GCANModel(nn.Module):

        def __init__(self, num_classes,
                     gcn_hidden_channels=256, gcn_layers=5,
                     gcn_out_channels=150, gcn_dropout=0.2):

            super(GCANModel, self).__init__()

            self.cnn = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
            features = self.cnn.fc.in_features
            self.combined_features = features + gcn_out_channels

            # feature extractor: the resnet without its final fc layer
            self.cnn = nn.Sequential(*list(self.cnn.children())[:-1])
            # data structure analyzer (DSA): a full resnet, whose scores are
            # used both as node features and to build the graph
            self.dsa = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)

            gcn_in_channels = 1000  # output size of the full resnet (DSA module)

            self.gcn = geometric_nn.GCN(in_channels=gcn_in_channels,
                                        hidden_channels=gcn_hidden_channels,
                                        num_layers=gcn_layers,
                                        out_channels=gcn_out_channels,
                                        dropout=gcn_dropout)

            self.domain_classification = nn.Sequential(
                nn.Linear(in_features=self.combined_features, out_features=1024),
                nn.ReLU(),
                nn.Linear(in_features=1024, out_features=512),
                nn.ReLU(),
                nn.Linear(in_features=512, out_features=1),
                nn.Sigmoid()
            )

            self.fc1 = nn.Linear(in_features=self.combined_features, out_features=1024)
            self.fc2 = nn.Linear(in_features=1024, out_features=512)
            self.fc3 = nn.Linear(in_features=512, out_features=num_classes)

        def forward(self, x):

            features = self.cnn(x)
            scores = self.dsa(x)

            # each sample of the batch is a node of the graph; the adjacency
            # matrix is the similarity between the DSA scores of the samples
            transposed_scores = torch.transpose(scores, 0, 1)
            adjacency_matrix = torch.matmul(scores, transposed_scores)
            sparse_adj_matrix = dense_to_sparse(adjacency_matrix)

            edge_index = sparse_adj_matrix[0]
            graph = geometric_data(scores, edge_index=edge_index)

            gcn_features = self.gcn(graph.x, graph.edge_index)
            gcn_features = gcn_features.view(-1, 150, 1, 1)

            concat_features = torch.cat([features, gcn_features], dim=1)
            concat_features = concat_features.view(-1, self.combined_features)

            domain_classification = self.domain_classification(concat_features)

            pseudo_label = relu(self.fc1(concat_features))
            mid_out = self.fc2(pseudo_label)
            pseudo_label = relu(mid_out)
            pseudo_label = softmax(self.fc3(pseudo_label), dim=1)

            return domain_classification, pseudo_label, mid_out, scores


    The model itself follows the implementation described in the paper, apart from the GCN module, which in my implementation has 5 layers.
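
    As a quick sanity check, the sketch below runs a dummy batch through the model and prints the shapes of the four outputs (the batch size and image size are arbitrary assumptions):

    model = GCANModel(num_classes=20)
    model.eval()

    dummy_batch = torch.randn(8, 3, 224, 224)  # 8 fake RGB images

    with torch.no_grad():
        domain_pred, pseudo_label, mid_out, scores = model(dummy_batch)

    print(domain_pred.shape)   # torch.Size([8, 1])    - domain probability per sample
    print(pseudo_label.shape)  # torch.Size([8, 20])   - class distribution per sample
    print(mid_out.shape)       # torch.Size([8, 512])  - features used for class alignment
    print(scores.shape)        # torch.Size([8, 1000]) - DSA scores used as graph nodes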

    def uda_domain_alignment_loss(domain_pred, domain_target):

        loss_function = nn.BCELoss()
        return loss_function(domain_pred.to(torch.float), domain_target.to(torch.float))

    def uda_classification_loss(x, predicted, target, domain):
        loss_function = nn.CrossEntropyLoss()

        # removing target samples so that the loss is not computed on them
        x_copy = x.clone()
        predicted_copy = predicted.clone()
        target_copy = target.clone()

        for i in range(len(domain) - 1, -1, -1):
            if domain[i] == 1:
                x_copy = torch.cat([x_copy[:i], x_copy[i + 1:]])
                predicted_copy = torch.cat([predicted_copy[:i], predicted_copy[i + 1:]])
                target_copy = torch.cat([target_copy[:i], target_copy[i + 1:]])

        return loss_function(predicted_copy, target_copy)
    
    def uda_structure_aware_alignment_loss(scores, classes, domain, threshold=1):

        # removing target domain samples
        scores_copy = scores.clone()
        classes_copy = classes.clone()

        for i in range(len(domain) - 1, -1, -1):
            if domain[i] == 1:
                scores_copy = torch.cat([scores_copy[:i], scores_copy[i + 1:]])
                classes_copy = torch.cat([classes_copy[:i], classes_copy[i + 1:]])

        # choosing the two categories: one class with at least two samples
        # (for the positive pair) and any other class (for the negative)
        classes_copy = classes_copy.detach().cpu().numpy()
        unique, counts = np.unique(classes_copy, return_counts=True)
        class_count_dict = dict(zip(unique, counts))

        source_cat = -1
        counter = 0

        for key in class_count_dict:
            counter += 1
            if source_cat == -1:
                if class_count_dict[key] > 1:
                    source_cat = key
            else:
                if class_count_dict[key] > 1 and random.random() > 0.01:
                    source_cat = key

        # if no valid triplet can be built, skip this loss term
        if counter < 2 or source_cat == -1:
            return 0

        # choosing the three samples of the triplet at random:
        # two from source_cat and one from a different class
        first_sample = None
        second_sample = None
        third_sample = None

        for i in range(len(scores_copy)):
            if classes_copy[i] == source_cat:
                if first_sample is None:
                    first_sample = scores_copy[i]
                elif second_sample is None:
                    second_sample = scores_copy[i]
                else:
                    choice = random.choice([0, 1])
                    if random.random() > 0.5:
                        if choice == 0:
                            first_sample = scores_copy[i]
                        else:
                            second_sample = scores_copy[i]
            else:
                if third_sample is None:
                    third_sample = scores_copy[i]
                elif random.random() > 0.5:
                    third_sample = scores_copy[i]

        if third_sample is None:
            return 0

        # computing the actual triplet loss
        same_class_squared_dist = sum((first_sample - second_sample) ** 2)
        diff_class_squared_dist = sum((first_sample - third_sample) ** 2)

        l = same_class_squared_dist - diff_class_squared_dist + threshold

        return max(l, 0)
    
    def uda_class_alignment_loss(x, domain, pseudo_classes, classes):

        # dividing source and target samples, keeping only the necessary
        # labels (for the source) and pseudo-labels (for the target)
        x_source_copy = x.clone()
        x_target_copy = x.clone()

        pseudo_classes_target_copy = pseudo_classes.clone()
        pseudo_classes_target_copy = torch.argmax(pseudo_classes_target_copy, dim=1)
        classes_source_copy = classes.clone()

        for i in range(len(domain) - 1, -1, -1):

            if domain[i] == 0:
                x_target_copy = torch.cat([x_target_copy[:i], x_target_copy[i + 1:]])
                pseudo_classes_target_copy = torch.cat([pseudo_classes_target_copy[:i], pseudo_classes_target_copy[i + 1:]])
            else:
                x_source_copy = torch.cat([x_source_copy[:i], x_source_copy[i + 1:]])
                classes_source_copy = torch.cat([classes_source_copy[:i], classes_source_copy[i + 1:]])

        # computing the prototype (centroid) of each class as the mean of
        # the features extracted from the samples of that class
        source_dict = dict(zip(x_source_copy, classes_source_copy))

        final_source_dict = {}

        for key in source_dict:
            counter = 1
            total = key
            for inner_key in source_dict:
                if not torch.all(torch.eq(key, inner_key)) and source_dict[key].item() == source_dict[inner_key].item():
                    counter += 1
                    total = total + inner_key

            prototype = total / counter
            final_source_dict[source_dict[key].item()] = prototype

        target_dict = dict(zip(x_target_copy, pseudo_classes_target_copy))

        final_target_dict = {}

        for key in target_dict:
            counter = 1
            total = key
            for inner_key in target_dict:
                if not torch.all(torch.eq(key, inner_key)) and target_dict[key].item() == target_dict[inner_key].item():
                    counter += 1
                    total = total + inner_key

            prototype = total / counter
            final_target_dict[target_dict[key].item()] = prototype

        # adding the squared euclidean distances between prototypes of the
        # same class; a class present in the source domain but not in the
        # target domain is ignored
        sum_dists = 0

        for key in final_source_dict:
            if key in final_target_dict:
                s = ((final_source_dict[key] - final_target_dict[key]) ** 2).sum(axis=0)
                sum_dists = sum_dists + s

        return sum_dists
    
    
    def uda_loss(x, class_prediction, domain_prediction, target_class, target_domain, mid_results, scores):

        # weights of the four loss terms (hyperparameters)
        class_prediction_weight = 1
        domain_prediction_weight = 0.0005
        structure_aware_alignment_weight = 0.0005
        class_alignment_loss_weight = 0.0005

        domain_loss = uda_domain_alignment_loss(domain_prediction, target_domain)
        classification_loss = uda_classification_loss(x, class_prediction, target_class, target_domain)
        triplet_loss = uda_structure_aware_alignment_loss(scores, target_class, target_domain)
        class_alignment_loss = uda_class_alignment_loss(mid_results, target_domain, class_prediction, target_class)

        return class_prediction_weight * classification_loss + domain_prediction_weight * domain_loss + \
               structure_aware_alignment_weight * triplet_loss + class_alignment_loss_weight * class_alignment_loss


    This is the formulation of the loss, in which I tried to stick as closely as possible to the original implementation described in the paper. In particular, the class centroids asked about in the question are computed inside uda_class_alignment_loss as the per-class mean of the extracted features, separately for source samples (using the true labels) and target samples (using the pseudo-labels).
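
    As a side note, the prototype computation above loops over all pairs of samples; a vectorized per-class mean using a boolean mask would be a cheaper alternative. A minimal sketch (the function and tensor names are placeholders, not part of the original code):

    def class_prototypes(features, labels):
        # mean feature vector (centroid) of each class appearing in the batch
        prototypes = {}
        for c in torch.unique(labels):
            mask = labels == c                 # samples belonging to class c
            prototypes[int(c)] = features[mask].mean(dim=0)
        return prototypes

    # hypothetical usage: 16 samples with 512-dim features, 5 classes
    feats = torch.randn(16, 512)
    labs = torch.randint(0, 5, (16,))
    protos = class_prototypes(feats, labs)     # {class_index: centroid tensor}

    Lastly, the training loop looks like this: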

    def train_uda_epoch(model, data, train_batches, optimizer, train_num, batch_size, device='cuda:0'):

        samples = 0
        cumulative_acc = 0
        cumulative_loss = 0

        model.train()

        for i in tqdm(range(train_num), desc='Progress'):

            input, real_class, domain = take_batch(data, i, train_batches)

            input = torch.from_numpy(input).to(device)
            real_class = torch.from_numpy(real_class).to(device)
            domain = torch.from_numpy(domain).to(device)

            domain_classification, pseudo_label, mid_out, scores = model(input)

            loss = uda_loss(input, pseudo_label, domain_classification, real_class, domain, mid_out, scores)
            loss.backward()

            optimizer.step()
            optimizer.zero_grad(set_to_none=True)

            _, predicted = pseudo_label.max(1)
            acc_to_add = predicted.eq(real_class).sum().item()

            samples += batch_size
            cumulative_loss += loss.item()  # .item() avoids keeping the graph of every batch in memory
            cumulative_acc += acc_to_add

        return (cumulative_acc / samples) * 100, cumulative_loss / samples

    

    In the version I implemented I used "manual" data loading, since I thought PyTorch's DataLoader was giving me some memory issues; it turned out that wasn't the problem, so it is also possible to use PyTorch's DataLoader without any issue, as sketched below.
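
    A minimal sketch of what the DataLoader-based loading could look like, reusing the GCANDataset defined above (the batch size is an arbitrary assumption and the loop body is elided):

    loader = DataLoader(dataset, batch_size=32, shuffle=True)

    for images, labels, domains in loader:
        images = images.to('cuda:0')
        labels = labels.to('cuda:0')
        domains = domains.to('cuda:0')
        # ... forward pass, uda_loss and backward pass as in train_uda_epoch

    The imports I used are reported below.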

    import os
    import random
    
    from tqdm import tqdm
    import numpy as np
    from PIL import Image
    
    import torch
    import torch.nn as nn
    from torch.nn.functional import relu, softmax
    import torch.cuda as cuda
    from torch.optim import Adam
    from torch.utils.tensorboard import SummaryWriter
    from torch.utils.data import Dataset, DataLoader, random_split, Subset
    
    from torchvision.models import resnet50, ResNet50_Weights
    from torchvision.datasets import ImageFolder
    from torchvision import transforms
    
    from torch_geometric import nn as geometric_nn
    from torch_geometric.data import Data as geometric_data
    from torch_geometric.utils import dense_to_sparse