Tags: python, algorithm, pytorch

Fast vectorized greedy maximal independent set algorithm


I need a really fast vectorized maximal independent set algorithm implemented in PyTorch, so that I can use it for tasks with thousands of nodes in a reasonable time.

I cannot use NetworkX; it is way too slow for my needs.

I don't need an exact algorithm; a rough greedy approximation will do the job. It just needs to be really fast.

The input is a simple adjacency matrix, and the return value should be an independent set.


Solution

  • It turns out that it is hard to beat the NetworkX implementation on the CPU.

    It is simply very well optimized. On the GPU, however, we can actually do better than NetworkX.

    import torch
    import networkx as nx
    
    def drop_row_and_col(A: torch.Tensor, i) -> torch.Tensor:
        """
        Return a copy of the square matrix ``A`` with the rows at positions ``i``
        and the columns at the same positions removed.

        Args:
            A (torch.Tensor): 2D square tensor of shape (n, n).
            i: positions to remove, usable as an index into a length-n tensor
               (e.g. an int, a list of ints, or an integer tensor; 0-based,
               negative values count from the end as in normal indexing).

        Returns:
            torch.Tensor: tensor of shape (n - k, n - k), where k is the number
            of distinct positions selected by ``i``.

        Raises:
            ValueError: if ``A`` is not a 2D square tensor.
        """
        is_square_2d = A.ndim == 2 and A.shape[0] == A.shape[1]
        if not is_square_2d:
            raise ValueError("Input must be a 2D square tensor.")

        keep = torch.ones(A.shape[0], dtype=torch.bool, device=A.device)
        keep[i] = False
        kept_positions = keep.nonzero().squeeze(1)
        # Select surviving rows first, then surviving columns of those rows.
        return A.index_select(0, kept_positions).index_select(1, kept_positions)
    
    def drop_under_index(array: torch.Tensor, i):
        """
        Return a copy of ``array`` without the entries at positions ``i`` along
        its first dimension.

        Args:
            array (torch.Tensor): tensor with at least one dimension.
            i: positions to drop along dim 0 (e.g. an int, a list of ints, or an
               integer tensor; 0-based, negatives count from the end).

        Returns:
            torch.Tensor: tensor with the same trailing shape as ``array``; its
            first dimension shrinks by the number of distinct positions dropped.
        """
        survivors = torch.ones(array.shape[0], dtype=torch.bool, device=array.device)
        survivors[i] = False
        return array.index_select(0, survivors.nonzero().squeeze(1))
    
    
    def get_maximal_ind_set_nx(adj_matrix: torch.Tensor, *, seed=None):
        """
        Maximal independent set computation using NetworkX.

        Mirrors the interface of the torch-based greedy version, but converts the
        adjacency matrix to a NetworkX graph and calls the built-in
        ``nx.maximal_independent_set`` on it.

        Args:
            adj_matrix (torch.Tensor): adjacency matrix of shape (N, N); nonzero
                entries indicate edges. It is loaded into an undirected
                ``nx.Graph`` via ``nx.from_numpy_array``. Diagonal entries
                (self-loops) are ignored, as in the torch version.
            seed: optional seed forwarded to ``nx.maximal_independent_set``,
                which chooses nodes randomly; pass an int for reproducible output.

        Returns:
            torch.Tensor: sorted 1D ``long`` tensor with the indices of the nodes
            in the resulting maximal independent set, on the same device as the
            input.
        """
        if adj_matrix.shape[0] == 0:
            # nx.maximal_independent_set picks a random start node, which fails
            # on a graph without nodes; the answer here is simply empty.
            return torch.empty(0, dtype=torch.long, device=adj_matrix.device)

        # Convert torch adjacency matrix -> numpy -> networkx graph
        adj_np = adj_matrix.detach().cpu().numpy()
        G = nx.from_numpy_array(adj_np)
        # A self-looped start node is its own neighbour, which networkx rejects
        # as "not an independent set"; drop self-loops so the diagonal is ignored.
        G.remove_edges_from(list(nx.selfloop_edges(G)))

        # Compute maximal independent set
        ind_set = nx.maximal_independent_set(G, seed=seed)

        # Convert back to a torch tensor; sort so the order matches the torch version.
        return torch.tensor(sorted(ind_set), dtype=torch.long, device=adj_matrix.device)
    
    # The original answer reports that @torch.compile makes this faster and use
    # less memory. The loops below branch on tensor values, which typically
    # causes graph breaks, so benchmark before enabling it.
    # @torch.compile
    def get_maximal_ind_set(adj_matrix, drop_at_once=1, *, use_networkx_on_cpu=True):
        """
        Greedy maximal independent set for a graph given by an adjacency matrix.

        Phase 1 (greedy vertex removal): repeatedly remove the ``drop_at_once``
        nodes with the largest degree in the remaining graph and recompute the
        degrees, until no edges are left. The surviving nodes are independent.

        Phase 2 (completion): phase 1 can discard a node none of whose
        neighbours survived (e.g. both endpoints of a lone edge when dropping two
        at once), so its result is not necessarily maximal. Every node with no
        neighbour in the set is a candidate; candidates with no smaller-index
        candidate neighbour are added, and this repeats until no candidate is
        left. The set therefore stays independent and ends up maximal.

        Args:
            adj_matrix (torch.Tensor): (N, N) adjacency matrix. A nonzero entry
                at ``[i, j]`` or ``[j, i]`` is treated as an undirected edge
                between ``i`` and ``j``; the diagonal is ignored.
            drop_at_once (int): how many nodes phase 1 removes per iteration.
                Larger values need fewer iterations and speed things up a lot.
            use_networkx_on_cpu (bool): when True (default) and the input is on
                the CPU, delegate to ``get_maximal_ind_set_nx``. Pass False to
                run the torch implementation on the CPU too.

        Returns:
            torch.Tensor: sorted 1D ``long`` tensor of node indices forming a
            maximal independent set, on the same device as ``adj_matrix``.

        Raises:
            ValueError: if ``adj_matrix`` is not a 2D square tensor, or
                ``drop_at_once`` is smaller than 1.
        """
        if adj_matrix.ndim != 2 or adj_matrix.shape[0] != adj_matrix.shape[1]:
            raise ValueError("Input must be a 2D square tensor.")
        if drop_at_once < 1:
            # With 0 the removal loop would stop at once and return every node.
            raise ValueError("drop_at_once must be at least 1.")
        if use_networkx_on_cpu and adj_matrix.device.type == "cpu":
            return get_maximal_ind_set_nx(adj_matrix)

        n = adj_matrix.shape[0]
        device = adj_matrix.device
        all_nodes = torch.arange(n, device=device)

        # Boolean, symmetric adjacency without self-loops, so degrees count
        # neighbours (not edge weights) of an undirected graph.
        adj = adj_matrix != 0
        adj = adj | adj.T
        adj[all_nodes, all_nodes] = False

        # ---- Phase 1: greedy removal of the highest-degree nodes ----
        node_indices = all_nodes
        remaining = adj
        while True:
            degrees = remaining.sum(-1)
            if degrees.numel() == 0 or not bool(degrees.max() > 0):
                break  # no edges left: the survivors are independent
            k = min(drop_at_once, degrees.numel())
            ind = degrees.topk(k).indices
            ind = ind[degrees[ind] > 0]  # never remove already-isolated nodes
            keep = torch.ones(remaining.shape[0], dtype=torch.bool, device=device)
            keep[ind] = False
            node_indices = node_indices[keep]
            remaining = remaining[keep][:, keep]

        # ---- Phase 2: extend to a maximal independent set ----
        in_set = torch.zeros(n, dtype=torch.bool, device=device)
        in_set[node_indices] = True
        # adj_lower[i, j] is True iff i and j are adjacent and j < i.
        adj_lower = adj & (all_nodes[None, :] < all_nodes[:, None])
        while True:
            dominated = (adj & in_set[None, :]).any(-1)
            candidates = ~in_set & ~dominated
            if not bool(candidates.any()):
                break
            # The smallest-index candidate is never blocked, so every pass adds
            # at least one node; added nodes are pairwise non-adjacent.
            blocked = (adj_lower & candidates[None, :]).any(-1)
            in_set |= candidates & ~blocked

        return in_set.nonzero().flatten()
    
    

    The simplest usage looks like this:

    adj = torch.randn((500, 500)) > 0.5
    adj = adj | adj.T  # undirected graph: the adjacency matrix must be symmetric
    adj[torch.arange(500), torch.arange(500)] = False  # no self-loops
    ind_set = get_maximal_ind_set(adj, drop_at_once=3)
    print(ind_set)
    

    You can also run it on the GPU (a lot faster than NetworkX):

    adj = torch.randn((500, 500)) > 0.5
    adj = adj | adj.T  # undirected graph: the adjacency matrix must be symmetric
    adj[torch.arange(500), torch.arange(500)] = False  # no self-loops
    ind_set = get_maximal_ind_set(adj.cuda(), drop_at_once=3).cpu()
    print(ind_set)
    

    Here is a simple visualization of its output:

    import matplotlib.pyplot as plt  # was missing: plt is used below
    import networkx as nx
    import torch

    # ==== Generate random adjacency matrix ====
    torch.manual_seed(0)
    n = 10
    adj_matrix = (torch.rand((n, n)) > 0.5).int()
    adj_matrix = torch.triu(adj_matrix, 1)  # upper triangle only
    adj_matrix = adj_matrix + adj_matrix.T   # make symmetric (undirected graph)
    adj_matrix.fill_diagonal_(0)

    # ==== Find maximal independent set ====
    ind_set = get_maximal_ind_set(adj_matrix, drop_at_once=1)
    print("Independent set indices:", ind_set.tolist())

    # ==== Visualize ====
    # from_numpy_array adds nodes 0..n-1 and one edge per nonzero entry.
    G = nx.from_numpy_array(adj_matrix.numpy())

    pos = nx.spring_layout(G, seed=42)  # fixed seed -> consistent layout

    # Node colors: blue if in independent set, red otherwise
    members = set(ind_set.tolist())
    node_colors = ['tab:blue' if i in members else 'tab:red' for i in G.nodes()]

    plt.figure(figsize=(6, 6))
    nx.draw(
        G,
        pos,
        with_labels=True,
        node_color=node_colors,
        node_size=600,
        font_color='white',
        edge_color='gray',
    )
    plt.title("Graph with Maximal Independent Set (blue nodes)")
    plt.show()
    

    [Image: resulting graph visualization]