I am trying to extend the elements of a TUDataset dataset.
In particular, I have a dataset obtained via
dataset = TUDataset("PROTEIN", name=PROTEIN, use_node_attr=True)
I want to add a new vector-like feature to every entry of the dataset.
for i, current_g in enumerate(dataset):
    nxgraph = nx.to_numpy_array(torch_geometric.utils.to_networkx(current_g))
    feature = do_something(nxgraph)
    dataset[i].new_feature = feature
However, this code doesn't work: an attribute assigned to an element of dataset does not persist, as the following session shows.
In [80]: dataset[2].test = 1
In [81]: dataset[2].test
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
~/workspace/grouptheoretical/new-experiments/HGP-SL-myfork/main.py in <cell line: 1>()
----> 1 dataset[2].test
AttributeError: 'Data' object has no attribute 'test'
In [82]: dataset[2].__setattr__('test', 1)
In [83]: dataset[2].test
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
~/workspace/grouptheoretical/new-experiments/HGP-SL-myfork/main.py in <cell line: 1>()
----> 1 dataset[2].test
AttributeError: 'Data' object has no attribute 'test'
An element of dataset is a Data object (torch_geometric.data.Data).
I can create a new Data element with all the features I want by using:
tmp = dataset[i].to_dict()
tmp['new_feature'] = feature
new_dataset[i] = torch_geometric.data.Data.from_dict(tmp)
However, I don't know how to create a TUDataset dataset (or an instance of its parent class) from a list of Data elements. Do you know how?
Any idea on how to solve this problem? Thanks.
One elegant way to reach your goal is to define your own transform.
from torch_geometric.data import Data
from torch_geometric.transforms import BaseTransform

class Add_Node_Feature(BaseTransform):
    def __init__(self, parameters):
        self.parameters = parameters  # any parameters your computation needs

    def __call__(self, data: Data) -> Data:
        node_feature = data.x
        # do_something is the placeholder from the question for the feature computation
        data.x = do_something(node_feature)
        return data
Then attach this transform to the dataset. The transform is applied every time an element is accessed, so each Data object the dataset returns carries the modified features.
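import torch_geometric.transforms as T
from torch_geometric.datasets import TUDataset

dataset = TUDataset("PROTEIN", name=PROTEIN, use_node_attr=True)
# Add_Node_Feature takes a constructor argument; pass whatever your computation needs
dataset.transform = T.Compose([Add_Node_Feature(parameters)])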
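To see why a transform helps where direct assignment fails, here is a minimal sketch in plain Python. It is a toy model, not PyG's actual implementation: it only assumes that indexing the dataset hands back a fresh copy of the stored element and then runs the transform on it, which matches the behaviour seen in the question. The names ToyDataset and AddFeature are made up for illustration.

```python
import copy
from types import SimpleNamespace


class ToyDataset:
    """Toy stand-in for a dataset whose __getitem__ returns a fresh copy."""

    def __init__(self, items, transform=None):
        self._items = items
        self.transform = transform

    def __len__(self):
        return len(self._items)

    def __getitem__(self, idx):
        # A new object is created on every access (assumed behaviour).
        item = copy.copy(self._items[idx])
        if self.transform is not None:
            item = self.transform(item)
        return item


class AddFeature:
    """Hypothetical transform that attaches a new attribute to each item."""

    def __init__(self, scale):
        self.scale = scale

    def __call__(self, data):
        data.new_feature = [self.scale * v for v in data.x]
        return data


dataset = ToyDataset([SimpleNamespace(x=[1, 2]), SimpleNamespace(x=[3, 4])])

# Assigning on an indexed element only modifies a temporary copy.
dataset[1].test = 1
lost = not hasattr(dataset[1], "test")

# The transform runs on every access, so the feature is always present.
dataset.transform = AddFeature(scale=10)
feature = dataset[1].new_feature
print(lost, feature)
```

In this sketch the stored elements are never changed. The new attribute is recomputed on each access, so an expensive do_something would run every time an element is read.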