Tags: python, machine-learning, vectorization, jax

Implementing a vectorized function over LinkedLists using JAX's vmap


I am trying to implement a vectorized version of an algorithm from computational geometry using JAX. I have made a minimal working example using a LinkedList to express my question concretely (my real code uses a DCEL, a doubly connected edge list).

The idea is that this vectorized algorithm will check certain criteria over a DCEL. For the sake of simplicity, I have replaced this "criteria checking procedure" with a simple summation.


import jax
from jax import vmap
import jax.numpy as jnp

class Node: 
  
    # Constructor to initialize the node object 
    def __init__(self, data): 
        self.data = data 
        self.next = None

class LinkedList: 
  
    def __init__(self): 
        self.head = None
  
    def push(self, new_data): 
        new_node = Node(new_data) 
        new_node.next = self.head 
        self.head = new_node 

    def printList(self):
        temp = self.head
        while temp:
            print(temp.data, end=" ")
            temp = temp.next

def summate(linked_list):
    # Walk the list from head to tail, accumulating each node's data
    total = 0
    current = linked_list.head
    while current is not None:
        total += current.data
        current = current.next
    return total

list1 = LinkedList() 
list1.push(20) 
list1.push(4) 
list1.push(15) 
list1.push(85) 

list2 = LinkedList() 
list2.push(19)
list2.push(13)
list2.push(2)
list2.push(13)

# Plain-Python equivalent: list(map(summate, [list1, list2]))

vmap(summate)(jnp.array([list1, list2]))

I get the following error.

TypeError: Value '<__main__.LinkedList object at 0x1193799d0>' with dtype object is not a valid JAX array type. Only arrays of numeric types are supported by JAX.

The objective is that, given a set of, say, 10,000 LinkedLists, I can apply this summate function to each LinkedList in a vectorized fashion. I have implemented what I want in plain Python, but I want to do it in JAX because this subprocedure will be used inside a larger probabilistic function (a Markov chain).

It might be the case that I simply cannot work with such data structures in JAX, since the error suggests that only numeric types are supported. Can I use pytrees in some way to get around this constraint?

It may be tempting to suggest that I use a plain jnp array instead, but I am using a LinkedList only as an example of a simple data structure. As mentioned earlier, I am actually working with a DCEL.

PS: The LinkedList code was taken from GeeksforGeeks, as I wanted to put together a minimal working example quickly.


Solution

  • The objective is, if I have a set of say, 10,000 LinkedLists, I should be able to apply this summate function over each LinkedList in a vectorized fashion.

    This goal is not feasible using JAX. You could register your class as a custom pytree to make it work with JAX functions (see Extending pytrees), but that alone won't let you vectorize an operation over a list of such objects.
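    A minimal sketch of what registering a linked-list node as a custom pytree could look like, assuming the list is represented by its head Node. The helper name from_values is hypothetical. Under jax.jit the Python while loop runs once at trace time over the pytree structure and is unrolled, so each distinct list length is a different structure and triggers a new compilation.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


@tree_util.register_pytree_node_class
class Node:
    """Linked-list node registered as a pytree (illustrative sketch)."""

    def __init__(self, data, next_node=None):
        self.data = data
        self.next = next_node

    def tree_flatten(self):
        # Children are the payload and the rest of the chain (a Node or None);
        # None is an empty pytree, so the tail contributes no leaves.
        return (self.data, self.next), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        data, next_node = children
        return cls(data, next_node)


def from_values(values):
    """Hypothetical helper: build a chain of Nodes from a Python sequence."""
    head = None
    for v in reversed(values):
        head = Node(jnp.asarray(v), head)
    return head


@jax.jit
def summate(head):
    # This Python loop runs at trace time over the pytree structure, not over
    # array values, so it is unrolled into the compiled program.
    total = jnp.zeros((), jnp.int32)
    node = head
    while node is not None:
        total = total + node.data
        node = node.next
    return total


list1 = from_values([85, 15, 4, 20])  # same head-to-tail order as list1 above
print(jax.tree_util.tree_leaves(list1))  # four scalar leaves
print(summate(list1))  # sum is 124
```

    A Python list such as [list1, list2] of these objects is still an array of structs, so vmap has no single array axis to map over.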

    JAX transformations like vmap and jit work with data stored in a struct-of-arrays pattern (e.g. a single LinkedList object containing arrays that represent multiple batched linked lists), not in an array-of-structs pattern (e.g. a list of multiple LinkedList objects).
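    A minimal sketch of the struct-of-arrays layout, assuming every list fits in a fixed capacity and that nodes refer to each other by integer slot index instead of Python references. LinkedListBatch, pack, and the -1 tail sentinel are hypothetical names chosen for this illustration. Because the loop always runs for exactly capacity steps, all shapes and trip counts are static, and vmap maps over the leading axis of every field at once.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp
import numpy as np
from jax import lax


class LinkedListBatch(NamedTuple):
    """Struct-of-arrays layout for B linked lists of at most C nodes each.

    data[b, i]: payload of slot i in list b
    next[b, i]: slot index of the successor of slot i, or -1 at the tail
    head[b]:    slot index of the first node of list b, or -1 if empty
    """
    data: jax.Array
    next: jax.Array
    head: jax.Array


def pack(lists_of_values, capacity):
    """Hypothetical helper: store each Python sequence as slots 0 -> 1 -> ..."""
    batch = len(lists_of_values)
    data = np.zeros((batch, capacity), dtype=np.int32)
    nxt = np.full((batch, capacity), -1, dtype=np.int32)
    head = np.full((batch,), -1, dtype=np.int32)
    for b, values in enumerate(lists_of_values):
        n = len(values)
        if n:
            data[b, :n] = values
            nxt[b, :n - 1] = np.arange(1, n)
            head[b] = 0
    return LinkedListBatch(jnp.asarray(data), jnp.asarray(nxt), jnp.asarray(head))


def summate(lst):
    """Sum a single list by following `next` for exactly `capacity` steps."""
    capacity = lst.data.shape[0]  # static: known at trace time

    def step(_, state):
        node, total = state
        valid = node >= 0
        safe = jnp.where(valid, node, 0)  # never index with the -1 sentinel
        total = total + jnp.where(valid, lst.data[safe], 0)
        node = jnp.where(valid, lst.next[safe], -1)
        return node, total

    init = (lst.head, jnp.zeros((), lst.data.dtype))
    _, total = lax.fori_loop(0, capacity, step, init)
    return total


# The two lists from the question, in head-to-tail order.
batch = pack([[85, 15, 4, 20], [13, 2, 13, 19]], capacity=8)
totals = jax.vmap(summate)(batch)  # one call over the whole batch
print(totals)  # expected sums: 124 and 47
```

    The conversion from Python objects to padded arrays happens once, outside JAX; after that, 10,000 lists are just a larger leading dimension.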

    Further, the algorithm you're using is built on a Python while loop, which cannot branch on traced array values inside JAX transformations (see JAX sharp bits: control flow), and the dynamically sized tree of nodes will not fit into the static shape constraints of JAX programs.
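    When the stop condition depends on data, jax.lax.while_loop is the traceable counterpart of a Python while loop: its condition is evaluated on array values inside the compiled program. A sketch, assuming each list is stored as padded data and next_idx arrays of a fixed capacity with -1 marking the tail (this layout is an assumption, not part of the question's code). The second list is deliberately stored out of order to show that the loop follows pointers rather than slot order.

```python
import jax
import jax.numpy as jnp
from jax import lax


def summate(data, next_idx, head):
    """Follow next pointers until the -1 sentinel using lax.while_loop.

    data, next_idx: (capacity,) arrays; head: scalar slot index (-1 if empty).
    """
    def cond(state):
        node, _ = state
        return node >= 0  # data-dependent stop condition, evaluated by XLA

    def body(state):
        node, total = state
        return next_idx[node], total + data[node]

    _, total = lax.while_loop(cond, body, (head, jnp.zeros((), data.dtype)))
    return total


# Two lists padded to capacity 5. List 0 is stored in slot order 0->1->2->3;
# list 1 starts at slot 1 and follows 1->2->3->0.
data = jnp.array([[85, 15, 4, 20, 0],
                  [19, 13, 2, 13, 0]], dtype=jnp.int32)
next_idx = jnp.array([[1, 2, 3, -1, -1],
                      [-1, 2, 3, 0, -1]], dtype=jnp.int32)
head = jnp.array([0, 1], dtype=jnp.int32)

totals = jax.vmap(summate)(data, next_idx, head)
print(totals)  # expected sums: 124 and 47
```

    Under vmap, the batched loop keeps running until every list in the batch has reached its sentinel; lists that finish early have their state frozen, so the cost is set by the longest list in the batch.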

    I'd love to point you in the right direction, but I think you need to either give up on using JAX or give up on using dynamic linked lists; you won't be able to have both.