For some reason I am not able to assign features to edges using HeteroData from the PyTorch Geometric (PyG) package. The installed versions are Python 3.11.5, PyTorch 2.0.1+cu117 and PyTorch Geometric 2.3.1, running on Arch Linux in a virtual environment.
Given the following code snippet:
import torch
from torch_geometric.data import HeteroData
data = HeteroData()
# Params
num_papers, num_paper_features = 5, 7
num_authors, num_author_features = 4, 10
num_edges = torch.randint(5, num_papers*num_authors, [1]).item()
author_writes_paper_num_features = 4
# Adding features to nodes
data['paper'].x = torch.randn(num_papers, num_paper_features)
data['author'].x = torch.randn(num_authors, num_author_features)
# Creating some random edges
author_edge_index = torch.randint(0, num_authors, [num_edges])
paper_edge_index = torch.randint(0, num_papers, [num_edges])
edge_index = torch.stack((author_edge_index, paper_edge_index))
data['author', 'writes', 'paper'].edge_index = edge_index
data = data.coalesce()
# Adding features to edges
data['author', 'writes', 'paper'].x = torch.randn(author_writes_paper_num_features, data.num_edges)
# Also tried using the transpose edge feature matrix
# data['author', 'writes', 'paper'].x = torch.randn(data.num_edges, author_writes_paper_num_features)
However, while I can correctly fetch the node features:
data.num_node_features
# {'paper': 7, 'author': 10}
data.node_stores
# [{'x': tensor([[...]])},
# {'x': tensor([[...]])}]
data.node_attrs()
# ['x']
the same does not happen with the edge features:
data.num_edge_features
# {('author', 'writes', 'paper'): 0}
The count is zero, even though there is an x entry in edge_stores:
data.edge_stores
#[{'edge_index': tensor([[0, 0, 0, 1, 1, 2, 2, 3, 3],
# [0, 2, 4, 0, 1, 0, 3, 0, 2]]), 'x': tensor([[...]])}]
We can see that no feature attribute x is associated with the edges:
data.edge_attrs()
# ['edge_index']
Any suggestions?
Edge features should be stored under the edge_attr attribute instead of x. That is, you should instead use:
# Adding features to edges
# edge_attr is expected to have shape [num_edges, num_edge_features]
data['author', 'writes', 'paper'].edge_attr = torch.randn(data.num_edges, author_writes_paper_num_features)