I am very new to PyTorch and PyTorch Geometric. I need to load a dataset and then attach an attribute to every object in the set, which I will use later in the script. However, I can't figure out how to do it.
I start the loading as
from torch_geometric.datasets import TUDataset
dataset = TUDataset(root='data/TUDataset', name='PROTEINS')
Then I add the attribute. I tried the following (the value 3 is only an example; it will come from a db query):
for data in dataset:
    data.keys.append('szemeredi_id')
    data.szemeredi_id = 3
or
for data in dataset:
    data['szemeredi_id'] = 3
or
for i, s in enumerate(dataset):
    dataset[i]['szemeredi_id'] = 3
or
for data in dataset:
    setattr(data, 'szemeredi_id', 3)
But the attribute is always empty. I even tried to write a subclass of the Data class:
class SzeData(Data):
    def __init__(self, x=None, edge_index=None, edge_attr=None, y=None,
                 pos=None, normal=None, face=None, **kwargs):
        super(SzeData, self).__init__(x, edge_index, edge_attr, y, pos, normal, face)
        self.szemeredi_id = None
But if I try to replace the Data objects, it raises the error TypeError: 'TUDataset' object does not support item assignment, or it does nothing if I use this solution.
Any suggestion is much appreciated. Thank you.
You can write the per-sample modification as a transform function and pass it as the transform or pre_transform parameter (whichever fits your need) when constructing the dataset:
from torch_geometric.datasets import TUDataset
def transform(data):
    data.szemeredi_id = 3
    return data
dataset = TUDataset(root='data/TUDataset', name='PROTEINS', transform=transform)
# or dataset = TUDataset(root='data/TUDataset', name='PROTEINS', pre_transform=transform)
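Since the value 3 only stands in for a database query, it may help to note that a transform can be any callable, including an object that carries its own state. The following is a minimal sketch, not the library's API: AddAttribute and fake_query are hypothetical names, and types.SimpleNamespace stands in for a Data object so that the snippet runs without torch_geometric installed.

```python
from types import SimpleNamespace


class AddAttribute:
    """Callable transform that sets one attribute on each sample.

    Hypothetical helper, not part of torch_geometric.
    """

    def __init__(self, name, value_fn):
        self.name = name          # attribute to create, e.g. 'szemeredi_id'
        self.value_fn = value_fn  # function: sample -> value (e.g. a db lookup)

    def __call__(self, data):
        setattr(data, self.name, self.value_fn(data))
        return data  # a transform must return the (modified) sample


def fake_query(data):
    # Placeholder for the real database query; here the value is simply
    # derived from the sample's node count.
    return data.num_nodes % 5


transform = AddAttribute('szemeredi_id', fake_query)

# SimpleNamespace stands in for a Data object in this sketch.
sample = transform(SimpleNamespace(num_nodes=42))
print(sample.szemeredi_id)  # 2
```

With the real library, the same object would be passed in the usual way, i.e. TUDataset(root='data/TUDataset', name='PROTEINS', transform=transform).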
See the documentation of torch_geometric.data.Dataset:
- transform (callable, optional) – A function/transform that takes in a Data object and returns a transformed version. The data object will be transformed before every access. (default: None)
- pre_transform (callable, optional) – A function/transform that takes in a Data object and returns a transformed version. The data object will be transformed before being saved to disk. (default: None)
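The difference between the two hooks, and a plausible reason why the in-place loops from the question have no effect, can be illustrated with a toy stand-in. This is a simplified model under stated assumptions, not PyG's actual implementation: ToyDataset is hypothetical and only mimics the behaviour that indexing hands back a copy of the stored sample rather than the stored object itself.

```python
import copy
from types import SimpleNamespace


class ToyDataset:
    """Simplified stand-in for a dataset with transform / pre_transform hooks."""

    def __init__(self, raw, transform=None, pre_transform=None):
        self.transform = transform
        # pre_transform runs once, when the samples are first processed/stored.
        if pre_transform is not None:
            raw = [pre_transform(copy.copy(d)) for d in raw]
        self._stored = list(raw)

    def __len__(self):
        return len(self._stored)

    def __getitem__(self, idx):
        data = copy.copy(self._stored[idx])  # a fresh copy on every access
        if self.transform is not None:
            data = self.transform(data)      # transform runs on every access
        return data


def add_id(data):
    data.szemeredi_id = 3
    return data


raw = [SimpleNamespace(y=i) for i in range(4)]

# Mutating the copies returned by indexing does not reach the stored samples.
plain = ToyDataset(raw)
for i in range(len(plain)):
    plain[i].szemeredi_id = 3
print(hasattr(plain[0], 'szemeredi_id'))  # False

# Both hooks make the attribute visible on access.
via_transform = ToyDataset(raw, transform=add_id)
via_pre_transform = ToyDataset(raw, pre_transform=add_id)
print(via_transform[0].szemeredi_id, via_pre_transform[0].szemeredi_id)  # 3 3
```

With the real library, pre_transform results are cached on disk, so a changed pre_transform will typically take effect only after the processed files have been regenerated.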
Edit:
The above method is unaware of the index of the data in the dataset, so if you want to add some index-related attributes, it won't help.
To add an index-related attribute (e.g. simply index), I use the following less elegant but more general approach:
from torch_geometric.datasets import TUDataset
dataset = TUDataset(root='data/TUDataset', name='PROTEINS')
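def add_attributes(dataset):
    data_list = []
    for i, data in enumerate(dataset):
        data.index = i  # attach the position of the sample in the dataset
        data_list.append(data)
    # Re-collate the modified samples so the change is stored in the dataset itself.
    dataset.data, dataset.slices = dataset.collate(data_list)
    return dataset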
dataset = add_attributes(dataset)
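An alternative that avoids re-collating is a thin wrapper that attaches the index-dependent attribute at access time. The sketch below is only an illustration under stated assumptions: IndexedView and the ids list are hypothetical, only integer indexing is handled, and a plain list of SimpleNamespace objects stands in for the TUDataset so that the snippet runs on its own.

```python
from types import SimpleNamespace


class IndexedView:
    """Wraps any indexable dataset and attaches per-index attributes on access.

    Hypothetical helper, not part of torch_geometric.
    """

    def __init__(self, base, values):
        assert len(base) == len(values), "need one value per sample"
        self.base = base
        self.values = values  # e.g. the result of one database query per graph

    def __len__(self):
        return len(self.base)

    def __getitem__(self, i):
        data = self.base[i]
        data.index = i
        data.szemeredi_id = self.values[i]
        return data


# Stand-ins for the real dataset and for the values fetched from the database.
base = [SimpleNamespace(y=0), SimpleNamespace(y=1), SimpleNamespace(y=0)]
ids = [7, 3, 9]

view = IndexedView(base, ids)
print(view[1].index, view[1].szemeredi_id)  # 1 3
```

Whether such a wrapper can be handed directly to a particular data loader depends on what that loader expects from a dataset, so treat it as a starting point rather than a drop-in replacement.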