I need a really fast vectorized maximal independent set algorithm implemented in PyTorch, so I can use it for tasks with thousands of nodes in reasonable time.
I cannot use networkx, it is way too slow for my needs.
I don't need an exact algorithm, a rough greedy approximation will do the job for me. It just needs to be really fast.
The input is a simple adjacency matrix, and the return value should be an independent set.
It turns out to be quite hard to beat the networkx implementation on CPU;
it is simply very good. On GPU, however, we can actually do better than networkx.
import torch
import networkx as nx
def drop_row_and_col(A: torch.Tensor, i) -> torch.Tensor:
    """
    Return a copy of square matrix ``A`` without the row(s) and column(s) at ``i``.

    Args:
        A (torch.Tensor): 2D square tensor of shape (n, n).
        i (int | list[int] | torch.Tensor): 0-based index or indices of the
            rows/columns to remove (anything that can index a 1D tensor).

    Returns:
        torch.Tensor: new tensor of shape (n - k, n - k), where k is the number
        of distinct positions selected by ``i``.

    Raises:
        ValueError: if ``A`` is not a 2D square tensor.
    """
    is_square_matrix = A.dim() == 2 and A.size(0) == A.size(1)
    if not is_square_matrix:
        raise ValueError("Input must be a 2D square tensor.")
    keep = torch.ones(A.size(0), dtype=torch.bool, device=A.device)
    keep[i] = False
    kept_positions = keep.nonzero().squeeze(1)
    # Keep the surviving rows first, then the surviving columns of that result.
    return A.index_select(0, kept_positions).index_select(1, kept_positions)
def drop_under_index(array: torch.Tensor, i):
    """
    Return the entries of ``array`` along dim 0 whose positions are not in ``i``.

    Args:
        array (torch.Tensor): tensor with at least one dimension.
        i (int | list[int] | torch.Tensor): 0-based position(s) along dim 0
            to drop.

    Returns:
        torch.Tensor: new tensor holding the remaining entries, in order.
    """
    removed = torch.zeros(array.size(0), dtype=torch.bool, device=array.device)
    removed[i] = True
    return array[~removed]
def get_maximal_ind_set_nx(adj_matrix: torch.Tensor, seed=None):
    """
    Maximal independent set computation using NetworkX.

    This function mirrors the interface of the torch-based greedy version,
    but internally converts the adjacency matrix to a NetworkX graph
    and calls its built-in `maximal_independent_set` algorithm.

    Args:
        adj_matrix (torch.Tensor): adjacency matrix of shape (N, N)
            (nonzero entries indicate edges; the diagonal is ignored).
        seed: optional random seed forwarded to `nx.maximal_independent_set`
            for reproducible results. With the default None the result
            depends on NetworkX's random state.

    Returns:
        torch.Tensor: sorted indices of nodes in the resulting maximal
            independent set (a 1D long tensor on the same device as input).

    Raises:
        ValueError: if ``adj_matrix`` is not a 2D square tensor.
    """
    if adj_matrix.ndim != 2 or adj_matrix.shape[0] != adj_matrix.shape[1]:
        raise ValueError("Input must be a 2D square tensor.")
    if adj_matrix.shape[0] == 0:
        # NetworkX draws its random start node from the graph and fails on an
        # empty one, so answer the trivial case directly.
        return torch.empty(0, dtype=torch.long, device=adj_matrix.device)

    # Drop self-loops: NetworkX treats such a node as its own neighbour and
    # raises NetworkXUnfeasible if the randomly chosen start node has one.
    adj = adj_matrix.detach() != 0
    adj.fill_diagonal_(False)

    # Convert torch adjacency matrix -> numpy -> networkx graph
    G = nx.from_numpy_array(adj.cpu().numpy())
    ind_set = nx.maximal_independent_set(G, seed=seed)
    # Sorted, so the output is ordered like the torch implementation's.
    return torch.tensor(sorted(ind_set), dtype=torch.long, device=adj_matrix.device)
def _greedy_removal(adj, drop_at_once):
    """Remove highest-degree nodes until no edges remain; return the survivor mask."""
    n = adj.shape[0]
    k = min(drop_at_once, n)
    alive = torch.ones(n, dtype=torch.bool, device=adj.device)
    # While a node is alive, degree[v] counts its alive neighbours. Removed
    # nodes are pinned to negative values so they are never selected again.
    degree = adj.sum(-1)
    while True:
        top_degree, top_nodes = torch.topk(degree, k)
        victims = top_nodes[top_degree > 0]
        if victims.numel() == 0:
            break  # every alive node is isolated, so the alive set is independent
        alive[victims] = False
        # Incremental O(k*N) update instead of re-summing the whole matrix.
        # adj is symmetric, so summing the victims' rows equals summing their columns.
        degree -= adj[victims].sum(0)
        degree[victims] = -1
    return alive


def _make_maximal(adj, alive):
    """Add, in place, every node with no neighbour in ``alive`` until none is left."""
    covered = alive | adj[alive].any(0)
    candidates = ~covered
    while candidates.any():
        cand_idx = candidates.nonzero().squeeze(1)
        sub = adj[cand_idx][:, cand_idx]
        order = torch.arange(cand_idx.numel(), device=adj.device)
        # A candidate waits while a lower-index candidate is its neighbour. The
        # lowest-index candidate is never blocked, so each round makes progress,
        # and the nodes added in one round are pairwise non-adjacent.
        blocked = (sub & (order.unsqueeze(0) < order.unsqueeze(1))).any(1)
        added = cand_idx[~blocked]
        alive[added] = True
        covered[added] = True
        covered |= adj[added].any(0)
        candidates = ~covered


# Note: the loops below have data-dependent trip counts, so @torch.compile would
# graph-break on them; benchmark before enabling it.
def get_maximal_ind_set(adj_matrix, drop_at_once=1, *, prefer_networkx_on_cpu=True):
    """
    Greedy vertex-removal approximation of a maximal independent set.

    Repeatedly removes the ``drop_at_once`` nodes with the largest remaining
    degree until no edges are left; the surviving nodes form an independent set.
    That removal can discard a node whose neighbours were all removed later, so
    a final repair pass re-adds such nodes. This makes the result maximal: no
    further node can be added without creating an edge.

    Args:
        adj_matrix (torch.Tensor): adjacency matrix of shape (N, N). Any nonzero
            entry is treated as an edge; the matrix is symmetrised and its
            diagonal (self-loops) is ignored.
        drop_at_once (int): how many nodes to remove per step. Larger values
            need fewer steps (faster) at the cost of a cruder greedy order.
        prefer_networkx_on_cpu (bool): when True (default) and ``adj_matrix`` is
            on the CPU, delegate to ``get_maximal_ind_set_nx``. Set to False to
            always use the torch implementation.

    Returns:
        torch.Tensor: 1D long tensor of node indices in the independent set, on
        the same device as ``adj_matrix`` (ascending order on the torch path).

    Raises:
        ValueError: if ``drop_at_once`` < 1 or ``adj_matrix`` is not a 2D
            square tensor.
    """
    if drop_at_once < 1:
        raise ValueError("drop_at_once must be a positive integer.")
    if adj_matrix.ndim != 2 or adj_matrix.shape[0] != adj_matrix.shape[1]:
        raise ValueError("Input must be a 2D square tensor.")
    device = adj_matrix.device
    if adj_matrix.shape[0] == 0:
        return torch.empty(0, dtype=torch.long, device=device)
    if prefer_networkx_on_cpu and device.type == "cpu":
        return get_maximal_ind_set_nx(adj_matrix)

    # Boolean, symmetric adjacency without self-loops; `!= 0` (not summed
    # weights) decides edges, so weights that cancel out cannot hide an edge.
    adj = adj_matrix != 0
    adj = adj | adj.T
    adj.fill_diagonal_(False)

    alive = _greedy_removal(adj, drop_at_once)
    _make_maximal(adj, alive)
    return alive.nonzero().squeeze(1)
The simplest usage looks like this:
# Random undirected graph: sample the strict upper triangle (~30% edge density)
# and mirror it, giving a symmetric adjacency matrix with an empty diagonal.
adj = torch.rand((500, 500)) < 0.3
adj = torch.triu(adj, diagonal=1)
adj = adj | adj.T
get_maximal_ind_set(adj, drop_at_once=3)
You can also run it on the GPU (a lot faster than networkx):
# Same kind of random undirected graph (symmetric, empty diagonal).
adj = torch.rand((500, 500)) < 0.3
adj = torch.triu(adj, diagonal=1)
adj = adj | adj.T
# Fall back to CPU so the snippet still runs on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"
get_maximal_ind_set(adj.to(device), drop_at_once=3).cpu()
Here is a simple visualization of how it works:
import matplotlib.pyplot as plt  # `plt` is used below but was never imported
import networkx as nx

# ==== Generate random adjacency matrix ====
torch.manual_seed(0)
n = 10
adj_matrix = (torch.rand((n, n)) > 0.5).int()
adj_matrix = torch.triu(adj_matrix, 1)  # upper triangle only
adj_matrix = adj_matrix + adj_matrix.T  # make symmetric (undirected graph)
adj_matrix.fill_diagonal_(0)

# ==== Find maximal independent set ====
# NOTE: adj_matrix is a CPU tensor, so get_maximal_ind_set delegates to the
# (unseeded) networkx fallback here; move it to CUDA to see the torch greedy.
ind_set = get_maximal_ind_set(adj_matrix, drop_at_once=1)
print("Independent set indices:", ind_set.tolist())

# ==== Visualize ====
# Nodes 0..n-1, with one undirected edge per nonzero entry of the matrix.
G = nx.from_numpy_array(adj_matrix.numpy())
pos = nx.spring_layout(G, seed=42)  # layout for consistent visualization

# Node colors: blue if in independent set, red otherwise
in_set = set(ind_set.tolist())
node_colors = ['tab:blue' if i in in_set else 'tab:red' for i in G.nodes()]

plt.figure(figsize=(6, 6))
nx.draw(
    G,
    pos,
    with_labels=True,
    node_color=node_colors,
    node_size=600,
    font_color='white',
    edge_color='gray',
)
plt.title("Graph with Maximal Independent Set (blue nodes)")
plt.show()