pytorch, pytorch-geometric

What is the inverse of `to_dense_batch` in PyTorch Geometric?


`to_dense_batch` (doc) converts a mini-batch into a dense batch. How can I convert the dense batch back into a mini-batch? Is there something like "from_dense_batch" that accepts the dense batch and the mask and returns the mini-batched data?


Solution

  • I found a possible workaround, but I am not sure whether it is the best implementation. The following is my code with a test example.

    from torch_geometric.datasets import TUDataset
    from torch_geometric.loader import DataLoader
    import torch
    from torch_geometric.utils import to_dense_batch
    
    
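    def from_dense_batch(dense_batch, mask):
        # dense_batch: [B, N, F] padded node features
        # mask: [B, N] boolean, True where a real node exists
        # Note: the indicator trick below assumes every graph has at least one node.
        B, N, F = dense_batch.size()
        flatten_dense_batch = dense_batch.reshape(-1, F)
        flatten_mask = mask.reshape(-1)
        data_x = flatten_dense_batch[flatten_mask, :]  # [num_of_nodes, F]
        num_nodes = torch.sum(mask, dim=1)  # [B], e.g. 3,4,3
        pr_value = torch.cumsum(num_nodes, dim=0)  # [B], e.g. 3,7,10
        # Integer vector on the mask's device, so the result matches the dtype of `batch`
        indicator_vector = torch.zeros(int(num_nodes.sum()), dtype=torch.long, device=mask.device)
        indicator_vector[pr_value[:-1]] = 1  # [num_of_nodes], e.g. 0,0,0,1,0,0,0,1,0,0
        data_batch = torch.cumsum(indicator_vector, dim=0)  # [num_of_nodes], e.g. 0,0,0,1,1,1,1,2,2,2
        return data_x, data_batch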
    
    
    dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES', use_node_attr=True)
    loader = DataLoader(dataset, batch_size=32, shuffle=True)
    
    test_batch = next(iter(loader))
    dense_data, mask = to_dense_batch(test_batch.x, test_batch.batch)
    output_data, output_batch = from_dense_batch(dense_data, mask)
    print((test_batch.x == output_data).all())
    print((test_batch.batch == output_batch).all())
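
The workaround above can probably be shortened with boolean-mask indexing. This is only a minimal sketch, assuming `mask` is the boolean tensor returned by `to_dense_batch`; the name `from_dense_batch_simple` is made up for illustration and is not part of PyTorch Geometric. It uses only plain PyTorch calls (`torch.arange`, `torch.repeat_interleave`).

```python
import torch


def from_dense_batch_simple(dense_batch, mask):
    # dense_batch: [B, N, F] padded node features
    # mask: [B, N] boolean, True where a real node exists
    # Boolean indexing over the first two dims keeps the real nodes in
    # row-major order, i.e. graph by graph -> [num_of_nodes, F]
    data_x = dense_batch[mask]
    num_nodes = mask.sum(dim=1)  # [B], number of real nodes per graph
    graph_ids = torch.arange(mask.size(0), device=mask.device)
    # Repeat each graph index once per real node -> [num_of_nodes]
    data_batch = torch.repeat_interleave(graph_ids, num_nodes)
    return data_x, data_batch


# Tiny hand-made example: 2 graphs, padded to 3 nodes, 4 features each
dense = torch.arange(24, dtype=torch.float).view(2, 3, 4)
mask = torch.tensor([[True, True, False], [True, True, True]])
x, batch = from_dense_batch_simple(dense, mask)
print(x.shape)  # torch.Size([5, 4])
print(batch)    # tensor([0, 0, 1, 1, 1])
```

Because `repeat_interleave` repeats a graph index zero times when a graph has no nodes, this variant should also cope with empty graphs, which the cumulative-sum indicator approach does not.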