
Implementing a vectorized function over LinkedLists using JAX's vmap function

I am trying to implement a vectorized version of an algorithm from computational geometry using JAX. I have made a minimal working example with a linked list to state my question precisely (otherwise I am using a DCEL, a doubly connected edge list).

The idea is that this vectorized algorithm will check certain criteria over a DCEL. For the sake of simplicity, I have replaced this "criteria checking procedure" with a simple summation.

import jax
from jax import vmap
import jax.numpy as jnp

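class Node: 
    # Constructor to initialize the node object 
    def __init__(self, data): 
        self.data = data   # payload stored in this node
        self.next = None   # reference to the following node

class LinkedList: 
    def __init__(self): 
        self.head = None

    def push(self, new_data): 
        # Insert a new node at the front of the list
        new_node = Node(new_data) 
        new_node.next = self.head 
        self.head = new_node 

    def printList(self): 
        temp = self.head 
        while temp: 
            print(temp.data, end=" ") 
            temp = temp.next

def summate(llist):
    # Walk the list from head to tail, accumulating the payloads
    current = llist.head
    total = 0
    while current is not None: 
        total += current.data
        current = current.next
    return total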

list1 = LinkedList() 

list2 = LinkedList() 

#list(map(summate, ([list1, list2])))

vmap(summate)(jnp.array([list1, list2]))

I get the following error.

TypeError: Value '<__main__.LinkedList object at 0x1193799d0>' with dtype object is not a valid JAX array type. Only arrays of numeric types are supported by JAX.

The objective is that, given a set of, say, 10,000 LinkedLists, I should be able to apply this summate function to each LinkedList in a vectorized fashion. I have implemented what I want in plain Python, but I want to do it in JAX because this subprocedure will be used inside a larger probabilistic function (a Markov chain).

It might be that I cannot work with such data structures in JAX at all, since the error says that only numeric types are supported. Can I use pytrees in some way to get around this constraint?

It will be tempting to suggest that I use a plain jnp array instead, but I am using a linked list only as an example of the simplest kind of data structure. As mentioned earlier, I am actually working with a DCEL.

PS: the LinkedList code was taken from GeeksforGeeks, as I wanted to put together a minimal working example quickly.


  • The objective is, if I have a set of say, 10,000 Linkedlists, I should be able to apply this summate function over each LinkedList in a vectorized fashion.

This goal is not feasible using JAX. You could register your class as a custom pytree to make it work with JAX functions (see Extending pytrees), but this won't mean you can vectorize an operation over a list of such objects.
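A minimal sketch of the pytree registration mentioned above, assuming the node class is rewritten to take its successor as a constructor argument; the helper `from_values` and the argument name `nxt` are hypothetical. It shows that a registered linked list can be passed to `jax.jit`, but the tree structure (and therefore the list length) is static: each new length triggers a retrace, and `jax.vmap` has nothing to map over when given a Python list of such objects, because every leaf is a scalar.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class Node:
    """Linked-list node registered as a pytree (illustrative sketch)."""

    def __init__(self, data, nxt=None):
        self.data = data  # leaf: the payload
        self.next = nxt   # subtree: the rest of the list, or None

    def tree_flatten(self):
        # Children are the payload and the tail; JAX treats None as an
        # empty pytree, so the last node flattens cleanly.
        return (self.data, self.next), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


def from_values(values):
    """Hypothetical helper: build a linked list from a Python sequence."""
    head = None
    for v in reversed(values):
        head = Node(jnp.asarray(v, dtype=jnp.float32), head)
    return head


def summate(head):
    total = 0.0
    current = head
    # This Python loop walks the *static* tree structure at trace time,
    # so under jit it is unrolled once per distinct list length.
    while current is not None:
        total = total + current.data
        current = current.next
    return total


l1 = from_values([1.0, 2.0, 3.0])
print(jax.jit(summate)(l1))
```

Registration only tells JAX how to find the array leaves inside the object; it does not turn a collection of differently shaped objects into a batch.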

JAX transformations like vmap and jit work for data stored with a struct-of-arrays pattern (e.g. a single LinkedList object containing arrays that represent multiple batched linked lists), not an array-of-structs pattern (e.g. a list of multiple LinkedList objects).
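To illustrate the struct-of-arrays pattern, here is a sketch that assumes every list in the batch is padded to the same maximum number of nodes `N`, with next pointers stored as integer indices and `-1` meaning "no node". The layout and the names `data`, `next_idx`, `head` and `summate_one` are hypothetical, not part of JAX. Because the loop runs a fixed `N` steps, all shapes are static and `jax.vmap` can batch over the leading axis.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Hypothetical layout for a batch of B linked lists, each padded to N slots:
#   data[b, i]      payload of slot i in list b
#   next_idx[b, i]  slot holding the node after i in list b (-1 = end of list)
#   head[b]         slot of the first node of list b (-1 = empty list)


def summate_one(data, next_idx, head):
    """Sum one padded linked list by following next pointers for N steps."""
    n = data.shape[0]

    def body(_, carry):
        total, current = carry
        valid = current >= 0
        # Clamp the index so the gather stays in bounds; invalid steps add 0.
        safe = jnp.maximum(current, 0)
        total = total + jnp.where(valid, data[safe], 0.0)
        current = jnp.where(valid, next_idx[safe], -1)
        return total, current

    init = (jnp.zeros((), dtype=data.dtype), head)
    total, _ = lax.fori_loop(0, n, body, init)
    return total


summate_batch = jax.jit(jax.vmap(summate_one))

# list 0: 1 -> 2 -> 3 (slots stored out of order to show pointer chasing)
# list 1: 10 -> 20    (last slot is padding)
data = jnp.array([[2.0, 3.0, 1.0],
                  [10.0, 20.0, 0.0]])
next_idx = jnp.array([[1, -1, 0],
                      [1, -1, -1]])
head = jnp.array([2, 0])
print(summate_batch(data, next_idx, head))  # sums of list 0 and list 1
```

Padding wastes some memory for short lists, but it is what makes the shapes static. A DCEL could presumably be encoded in the same spirit, with integer index arrays for its half-edge pointers.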

Further, the algorithm you're using, based on a Python while loop, is not compatible with JAX transformations (see JAX sharp bits: control flow), and the dynamically sized tree of nodes will not fit into the static shape constraints of JAX programs.
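Where a data-dependent loop is needed, the Python `while` can be replaced by `jax.lax.while_loop`, which `jit` and `vmap` both support. This sketch assumes the same hypothetical padded index layout as before (`-1` marks the end of a list). Under `vmap`, the loop keeps running until every list in the batch has finished, and lists that are already done simply keep their previous carry.

```python
import jax
import jax.numpy as jnp
from jax import lax


def summate_while(data, next_idx, head):
    """Sum one linked list stored as (data, next_idx, head); -1 means 'no node'."""

    def cond(carry):
        _, current = carry
        return current >= 0  # continue while there is a node to visit

    def body(carry):
        total, current = carry
        return total + data[current], next_idx[current]

    init = (jnp.zeros((), dtype=data.dtype), head)
    total, _ = lax.while_loop(cond, body, init)
    return total


summate_batch = jax.jit(jax.vmap(summate_while))

data = jnp.array([[2.0, 3.0, 1.0],
                  [10.0, 20.0, 0.0]])
next_idx = jnp.array([[1, -1, 0],
                      [1, -1, -1]])
head = jnp.array([2, 0])
print(summate_batch(data, next_idx, head))
```

This assumes acyclic lists. A cyclic structure, such as the half-edge loop around a DCEL face, would need a different stopping condition, for example stopping when the walk returns to its starting index.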

I'd love to point you in the right direction, but I think you either need to give up on using JAX, or give up on using dynamic linked lists. You won't be able to do both.