I am using the ZINC graph dataset via PyTorch Geometric, which I access as:

```python
zinc_dataset = ZINC(root='my_path', split='train')
```

Each data element is a graph; for example, `zinc_dataset[0]` looks like:

```
Data(x=[33, 1], edge_index=[2, 72], edge_attr=[72], y=[1])
```
I have computed a tensor-valued feature for each graph in the dataset and stored these tensors in a list, where the i-th element of the list is the feature for the i-th graph in `zinc_dataset`.
I would like to add these new features to the data objects. Ideally, I want the result to be:

```
Data(x=[33, 1], edge_index=[2, 72], edge_attr=[72], y=[1], new_feature=[33, 12])
```
I have looked at the solution provided in "How to add a new attribute to a torch_geometric.data Data object element?", but that hasn't worked for me.
Could someone please help me take my list of new features and include them in the data object?
Thanks
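To add your list of new features (a `List[Tensor]` with one tensor per graph in the dataset) to each `torch_geometric.data.Data` object in a dataset like ZINC, you can assign each tensor as a new attribute of the corresponding `Data` object.

The catch is that indexing an `InMemoryDataset` such as ZINC returns a copy of the stored graph, so `zinc_dataset[i].new_feature = ...` only modifies that temporary copy and is not persisted in the dataset. This is likely why the linked solution did not work for you. Instead, collect the modified copies and build a new dataset from them.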
Here’s how you can do it step-by-step:
```python
import torch
from torch_geometric.datasets import ZINC
from torch_geometric.data import InMemoryDataset

# 1. Load the ZINC training dataset
zinc_dataset = ZINC(root='my_path', split='train')

# 2. Create a list of new features, one tensor per graph.
# Replace this with your actual feature list (each tensor needs one row per node).
new_features = []
for data in zinc_dataset:
    num_nodes = data.x.size(0)  # data.x is [num_nodes, feature_dim]
    new_feat = torch.randn(num_nodes, 12)  # Example: [num_nodes, 12]
    new_features.append(new_feat)

# 3. Define a custom dataset that injects new_feature into each graph's Data object
class ModifiedZINC(InMemoryDataset):
    def __init__(self, original_dataset, new_features_list):
        super().__init__(root=None, transform=None, pre_transform=None)
        data_list = []
        for i in range(len(original_dataset)):
            data = original_dataset[i]  # indexing returns a copy of graph i
            data.new_feature = new_features_list[i]
            data_list.append(data)
        # Collate into InMemoryDataset's internal storage; len() and
        # indexing are then handled by the base class.
        self.data, self.slices = self.collate(data_list)

# 4. Create the modified dataset with new features
modified_dataset = ModifiedZINC(zinc_dataset, new_features)

# 5. Check the result
sample = modified_dataset[0]
print(sample)
print("Shape of new feature:", sample.new_feature.shape)
```
Output:

```
Data(x=[33, 1], edge_index=[2, 72], edge_attr=[72], y=[1], new_feature=[33, 12])
Shape of new feature: torch.Size([33, 12])
```