`to_dense_batch` (doc) converts a mini-batch into a dense batch. How can the dense batch be converted back into the mini-batch? Is there something like a `from_dense_batch` that accepts `dense_batch` and `mask` and returns the mini-batched data?
I found a possible workaround, but I'm not sure whether it is the best implementation. The following is my code with a test example.
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader
import torch
from torch_geometric.utils import to_dense_batch
def from_dense_batch(dense_bath, mask):
    """Invert ``to_dense_batch``: turn a padded dense batch back into a mini-batch.

    Args:
        dense_bath: Padded node features of shape ``(B, N, *)``, e.g. the first
            output of ``to_dense_batch``. (The parameter name is kept as-is for
            backward compatibility with keyword callers.)
        mask: Tensor of shape ``(B, N)`` that is true for real (non-padding)
            nodes, e.g. the second output of ``to_dense_batch``.

    Returns:
        A tuple ``(data_x, data_batch)``:
        ``data_x`` has shape ``(num_nodes, *)`` and holds the real nodes in
        graph-major order. ``data_batch`` is a ``torch.long`` vector of shape
        ``(num_nodes,)`` that gives the graph index of each node and lives on
        the same device as ``mask``.
    """
    mask = mask.to(torch.bool)
    # Boolean indexing selects the real nodes in row-major (graph, then node)
    # order. This matches the layout to_dense_batch produces. Unlike .view(),
    # it also works on non-contiguous inputs and any number of feature dims.
    data_x = dense_bath[mask]
    # Broadcast each graph's index across its node slots and apply the same
    # mask. Graphs with zero real nodes simply contribute no entries, so later
    # graph ids stay correct. The result is long and on mask's device.
    batch_size, max_nodes = mask.shape
    graph_ids = torch.arange(batch_size, device=mask.device)
    data_batch = graph_ids.view(-1, 1).expand(batch_size, max_nodes)[mask]
    return data_x, data_batch
# Round-trip check: draw one shuffled ENZYMES mini-batch, pack it into a dense
# batch, unpack it again, and report whether the node features and the
# node-to-graph assignment both survive unchanged.
dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES', use_node_attr=True)
loader = DataLoader(dataset, batch_size=32, shuffle=True)
test_batch = next(iter(loader))

dense_data, mask = to_dense_batch(test_batch.x, test_batch.batch)
output_data, output_batch = from_dense_batch(dense_data, mask)

for expected, actual in ((test_batch.x, output_data), (test_batch.batch, output_batch)):
    print(expected.eq(actual).all())