How to add a new attribute to a torch_geometric.data Data object element?


I am trying to extend the elements of a TUDataset dataset. In particular, I have a dataset obtained via

dataset = TUDataset("PROTEIN", name="PROTEINS", use_node_attr=True)

I want to add a new vector-like feature to every entry of the dataset.

for i, current_g in enumerate(dataset):
    nxgraph = nx.to_numpy_array(torch_geometric.utils.to_networkx(current_g))
    feature = do_something(nxgraph)
    dataset[i].new_feature = feature

However, this code doesn't work. Assigning an attribute to an element of dataset raises no error, but the attribute is gone on the next access:

In [80]: dataset[2].test = 1

In [81]: dataset[2].test
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
~/workspace/grouptheoretical/new-experiments/HGP-SL-myfork/main.py in <cell line: 1>()
----> 1 dataset[2].test

AttributeError: 'Data' object has no attribute 'test'

In [82]: dataset[2].__setattr__('test', 1)

In [83]: dataset[2].test
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
~/workspace/grouptheoretical/new-experiments/HGP-SL-myfork/main.py in <cell line: 1>()
----> 1 dataset[2].test

AttributeError: 'Data' object has no attribute 'test'
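The session above is consistent with dataset[i] handing back a fresh copy of the stored graph on every access, so an attribute set on one copy never reaches the next one. The following is a minimal sketch of that behavior in plain Python. ToyGraph and ToyDataset are hypothetical stand-ins written for illustration, not PyG classes, and the copy-on-access behavior is an assumption about how the dataset indexes its elements.

```python
import copy


class ToyGraph:
    """Hypothetical stand-in for a Data object (not a PyG class)."""

    def __init__(self, x):
        self.x = x


class ToyDataset:
    """Hypothetical stand-in for a dataset whose indexing returns copies."""

    def __init__(self, graphs):
        self._graphs = list(graphs)

    def __getitem__(self, idx):
        # Hand back a copy, so the caller never touches the stored object.
        return copy.copy(self._graphs[idx])


toy = ToyDataset([ToyGraph([1.0]), ToyGraph([2.0])])

toy[0].test = 1                  # set on a temporary copy, then discarded
print(hasattr(toy[0], "test"))   # False: this access returns another copy

g = toy[0]                       # keep a reference to one copy
g.test = 1
print(g.test)                    # 1: the attribute lives on `g` only
```

Under that assumption, the attribute has to be attached either to an object you keep a reference to, or at the point where the dataset produces its elements.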

Each element of dataset is a torch_geometric.data.Data object.

I can create a new Data element with all the features I want by using:

tmp = dataset[i].to_dict()
tmp['new_feature'] = feature
new_dataset[i] = torch_geometric.data.Data.from_dict(tmp)

However, I don't know how to create a TUDataset (or an instance of its parent class) from a list of Data elements. Do you know how?

Any idea on how to solve this problem? Thanks.


Solution

  • One elegant way to reach your goal is to define your own transform.

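    from torch_geometric.data import Data
    from torch_geometric.transforms import BaseTransform
    
    class Add_Node_Feature(BaseTransform):
        def __init__(self, parameters):
            self.parameters = parameters  # parameters you need
    
        def __call__(self, data: Data) -> Data:
            # Replace the node features with the transformed ones.
            # To keep data.x and add a separate attribute instead,
            # assign to data.new_feature here.
            node_feature = data.x
            data.x = do_something(node_feature)
            return data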
    

    Then attach the transform to the dataset. It runs every time an element is accessed, so each Data object you get back carries the new features. The files stored on disk are left unchanged.

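    import torch_geometric.transforms as T
    from torch_geometric.datasets import TUDataset
    
    dataset = TUDataset("PROTEIN", name="PROTEINS", use_node_attr=True)
    dataset.transform = T.Compose([Add_Node_Feature(parameters)])  # pass the parameters your transform needs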