
Adding edge features in heterographs using PyGeometric


For some reason I'm not able to assign features to edges using HeteroData from the PyTorch Geometric package. The installed versions are Python 3.11.5, PyTorch 2.0.1+cu117 and PyTorch Geometric 2.3.1, running on Arch Linux in a virtual environment.

Given the following code snippet:

import torch
from torch_geometric.data import HeteroData

data = HeteroData()

# Params
num_papers, num_paper_features = 5, 7
num_authors, num_author_features = 4, 10
num_edges = torch.randint(5, num_papers*num_authors, [1]).item()
author_writes_paper_num_features = 4

# Adding features to nodes
data['paper'].x = torch.randn(num_papers, num_paper_features)
data['author'].x = torch.randn(num_authors, num_author_features)

# Creating some random edges
author_edge_index = torch.randint(0, num_authors, [num_edges])
paper_edge_index = torch.randint(0, num_papers, [num_edges])
edge_index = torch.stack((author_edge_index, paper_edge_index))
data['author', 'writes', 'paper'].edge_index = edge_index

data = data.coalesce()

# Adding features to edges
data['author', 'writes', 'paper'].x = torch.randn(author_writes_paper_num_features, data.num_edges)

# Also tried using the transpose edge feature matrix 
# data['author', 'writes', 'paper'].x = torch.randn(data.num_edges, author_writes_paper_num_features)

But while I can correctly fetch the node features:

data.num_node_features
# {'paper': 7, 'author': 10}

data.node_stores
# [{'x': tensor([[...]])},
#  {'x': tensor([[...]])}]

data.node_attrs()
# ['x']

the same does not happen with the edge features:

data.num_edge_features
# {('author', 'writes', 'paper'): 0}

The number of edge features is zero, even though there is an x entry in edge_stores:

data.edge_stores
#[{'edge_index': tensor([[0, 0, 0, 1, 1, 2, 2, 3, 3],
#         [0, 2, 4, 0, 1, 0, 3, 0, 2]]), 'x': tensor([[...]])}]

And edge_attrs() does not list x as an edge-level attribute:

data.edge_attrs()
# ['edge_index']

Any suggestions?


Solution

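  • Edge features must be stored under the edge_attr attribute instead of x: num_edge_features only inspects edge_attr, so an x tensor on an edge type is never counted. The feature matrix should also have shape [num_edges, num_edge_features], one row per edge (the same convention as node features), so the edge count goes in the first dimension:

    # Adding features to edges: one row per edge
    data['author', 'writes', 'paper'].edge_attr = torch.randn(data.num_edges, author_writes_paper_num_features)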