Tags: python, python-3.x, pytorch, dataset, pytorch-geometric

Add attribute to object of dataset


I am very new to PyTorch and PyTorch Geometric. I need to load a dataset and then attach an attribute to every object in it, which I will use later in the script. However, I can't figure out how to do it.

I start the loading as

from torch_geometric.datasets import TUDataset

dataset = TUDataset(root='data/TUDataset', name='PROTEINS')

Then I add the attribute. I tried the following (the value 3 is only an example; in practice it will come from a db query):

for data in dataset:
    data.keys.append('szemeredi_id')
    data.szemeredi_id = 3

or

for data in dataset:
    data['szemeredi_id'] = 3

or

for i, s in enumerate(dataset):
    dataset[i]['szemeredi_id'] = 3

or

for data in dataset:
    setattr(data, 'szemeredi_id', 3)

but the attribute is always empty afterwards. I even tried to write a subclass of the Data class:

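from torch_geometric.data import Data

class SzeData(Data):
    def __init__(self, x=None, edge_index=None, edge_attr=None, y=None,
                 pos=None, normal=None, face=None, **kwargs):
        # Forward everything by keyword so that extra attributes are not dropped
        super().__init__(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y,
                         pos=pos, normal=normal, face=face, **kwargs)
        self.szemeredi_id = None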

But if I try to replace the Data objects, it raises TypeError: 'TUDataset' object does not support item assignment, and it does nothing if I use this solution.

Any suggestion is much appreciated. Thank you.


Solution

  • You can wrap the per-sample modification in a transform function and pass it as the transform or pre_transform parameter (whichever fits your needs) when constructing the dataset:

    from torch_geometric.datasets import TUDataset
    
    def transform(data):
        data.szemeredi_id = 3
        return data
    
    dataset = TUDataset(root='data/TUDataset', name='PROTEINS', transform=transform)
    # or dataset = TUDataset(root='data/TUDataset', name='PROTEINS', pre_transform=transform)
    

    See the documentation of torch_geometric.data.Dataset:

    • transform (callable, optional) – A function/transform that takes in a Data object and returns a transformed version. The data object will be transformed before every access. (default: None)
    • pre_transform (callable, optional) – A function/transform that takes in a Data object and returns a transformed version. The data object will be transformed before being saved to disk. (default: None)
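Since the real value will come from a db query rather than a constant, the transform can also be written as a small callable class that carries the lookup with it. The following is only a minimal sketch: `AddSzemerediId`, `lookup`, and the node-count rule are hypothetical names made up for illustration, and a `SimpleNamespace` stands in for a PyG Data object so that the snippet runs without downloading a dataset.

```python
from types import SimpleNamespace


class AddSzemerediId:
    """Callable transform that attaches `szemeredi_id` to each sample."""

    def __init__(self, lookup):
        # `lookup` is a hypothetical callable standing in for the db query.
        self.lookup = lookup

    def __call__(self, data):
        data.szemeredi_id = self.lookup(data)
        return data  # a transform must return the (modified) object


# Stand-in for a PyG Data object, used only so the sketch runs on its own.
sample = SimpleNamespace(num_nodes=42)

# Hypothetical rule: derive the id from the node count.
transform = AddSzemerediId(lookup=lambda d: d.num_nodes % 5)
result = transform(sample)
print(result.szemeredi_id)
```

With PyG, an instance of such a class would be passed as transform=... or pre_transform=... exactly like the plain function above. As far as I know, pre_transform only runs when the processed files are first created, so if the processed folder already exists it may need to be deleted before the change shows up.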

    Edit:

    The above method is unaware of the data index in the dataset, so if you want to add some index-related attributes, it won't help.

    To add an index-related attribute (e.g. simply the index), I use the following less elegant but more general approach:

    from torch_geometric.datasets import TUDataset
    
    dataset = TUDataset(root='data/TUDataset', name='PROTEINS')
    
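    def add_attributes(dataset):
        data_list = []
        for i, data in enumerate(dataset):
            # Each access yields a fresh Data object, so collect the modified copies
            data.index = i
            data_list.append(data)
        # Write the modified samples back into the dataset's collated storage
        dataset.data, dataset.slices = dataset.collate(data_list)
        return dataset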
    
    dataset = add_attributes(dataset)
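To see why the loops in the question appear to do nothing while the collate approach works, here is a toy model in plain Python. It is an illustration only and not PyG's actual implementation: `ToyDataset` and `rebuild` are made-up names, and the assumption is simply that indexing returns a new object built from internal storage rather than the stored object itself.

```python
import copy


class ToyDataset:
    """Toy stand-in for an in-memory dataset (not PyG's real implementation)."""

    def __init__(self, samples):
        self._storage = list(samples)

    def __len__(self):
        return len(self._storage)

    def __getitem__(self, idx):
        # Each access hands out a fresh copy, so edits to it are not stored.
        return copy.deepcopy(self._storage[idx])

    def rebuild(self, samples):
        # Plays the role of `dataset.data, dataset.slices = collate(...)`.
        self._storage = list(samples)


dataset = ToyDataset([{"y": 0}, {"y": 1}])

# Attempt 1: modify the objects returned by indexing; nothing is kept.
for i in range(len(dataset)):
    dataset[i]["index"] = i
lost = "index" not in dataset[0]

# Attempt 2: collect the modified copies and write them back.
modified = []
for i in range(len(dataset)):
    sample = dataset[i]
    sample["index"] = i
    modified.append(sample)
dataset.rebuild(modified)
kept = dataset[1]["index"]
print(lost, kept)
```

The first loop mirrors the attempts in the question, and the second mirrors add_attributes: the changes only survive once the modified objects are written back into the storage that indexing reads from.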