I am trying to implement a vectorized version of an algorithm from computational geometry using JAX. I have made a minimal working example using a LinkedList to express my question concretely (in my real code I am using a DCEL, a doubly connected edge list).
The idea is that this vectorized algorithm will check certain criteria over a DCEL. For the sake of simplicity, I have replaced this criteria-checking procedure with a simple summation.
```python
import jax
from jax import vmap
import jax.numpy as jnp


class Node:
    # Constructor to initialize the node object
    def __init__(self, data):
        self.data = data
        self.next = None


class LinkedList:
    def __init__(self):
        self.head = None

    def push(self, new_data):
        # Insert a new node at the front of the list
        new_node = Node(new_data)
        new_node.next = self.head
        self.head = new_node

    def printList(self):
        temp = self.head
        while temp:
            print(temp.data, end=" ")
            temp = temp.next


def summate(linked_list):
    # Walk the list from head to tail, accumulating node values
    current = linked_list.head
    total = 0
    while current is not None:
        total += current.data
        current = current.next
    return total


list1 = LinkedList()
list1.push(20)
list1.push(4)
list1.push(15)
list1.push(85)

list2 = LinkedList()
list2.push(19)
list2.push(13)
list2.push(2)
list2.push(13)

# Plain Python works: list(map(summate, [list1, list2])) gives [124, 47]
# The JAX version below raises the TypeError shown next
vmap(summate)(jnp.array([list1, list2]))
```
I get the following error:

```
TypeError: Value '<__main__.LinkedList object at 0x1193799d0>' with dtype object is not a valid JAX array type. Only arrays of numeric types are supported by JAX.
```
The objective is: if I have a set of, say, 10,000 linked lists, I should be able to apply this summate function to each one in a vectorized fashion. I have implemented what I want in plain Python, but I want to do it in JAX because this subprocedure will be used inside a larger probabilistic function (a Markov chain).
It might be that I simply cannot work with such data structures in JAX, since the error suggests that only numeric types are supported. Can I use pytrees in some way to work around this constraint?
It will be tempting to suggest that I use a plain jnp array, but I am using a linked list only as an example of the simplest data structure. As mentioned earlier, I am actually working with a DCEL.
PS: the LinkedList code was taken from GeeksForGeeks, as I wanted to come up with a minimal working example quickly.
> The objective is, if I have a set of say, 10,000 Linkedlists, I should be able to apply this summate function over each LinkedList in a vectorized fashion.
This goal is not feasible using JAX. You could register your class as a custom Pytree to make it work with JAX functions (see Extending pytrees), but this won't mean you can vectorize an operation over a list of such objects.
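As a hedged sketch of what registering the class as a custom pytree could look like, the following uses `jax.tree_util.register_pytree_node` and exposes the node values as the leaves (`_flatten` and `_unflatten` are hypothetical helper names, and this leaf layout is only one possible choice). It lets a single list pass through `jax.jit`, but the number of nodes becomes part of the static tree structure, so lists of different lengths trigger retracing, and a Python list of such objects still only has scalar leaves, which `vmap` cannot map over.

```python
import jax
from jax import tree_util


class Node:
    def __init__(self, data):
        self.data = data
        self.next = None


class LinkedList:
    def __init__(self):
        self.head = None

    def push(self, new_data):
        new_node = Node(new_data)
        new_node.next = self.head
        self.head = new_node


def _flatten(ll):
    # Hypothetical helper: walk the list once in Python and expose the
    # node values (head first) as pytree leaves. No auxiliary data needed.
    values = []
    node = ll.head
    while node is not None:
        values.append(node.data)
        node = node.next
    return values, None


def _unflatten(aux_data, children):
    # Hypothetical helper: rebuild a list whose head-to-tail order matches
    # the flattened order (push inserts at the front, hence reversed).
    ll = LinkedList()
    for value in reversed(list(children)):
        ll.push(value)
    return ll


tree_util.register_pytree_node(LinkedList, _flatten, _unflatten)


@jax.jit
def summate(ll):
    # Under jit the node values are traced scalars; the Python loop over
    # nodes is unrolled at trace time because the list length is static.
    total = 0
    node = ll.head
    while node is not None:
        total = total + node.data
        node = node.next
    return total


list1 = LinkedList()
for value in (20, 4, 15, 85):
    list1.push(value)
print(summate(list1))  # 124
```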
JAX transformations like `vmap` and `jit` work for data stored with a struct-of-arrays pattern (e.g. a single `LinkedList` object containing arrays that represent multiple batched linked lists), not an array-of-structs pattern (e.g. a list of multiple `LinkedList` objects).
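To make the struct-of-arrays idea concrete, here is a minimal sketch that assumes each list has already been reduced to its sequence of values and padded to a fixed maximum length (`pack` and `summate_one` are hypothetical names). One `values` array of shape `(batch, max_len)` plus a `lengths` array replaces the Python list of `LinkedList` objects, and `vmap` maps over the shared batch axis.

```python
import jax
import jax.numpy as jnp
import numpy as np


def pack(lists_of_values, max_len):
    """Hypothetical helper: pad many lists into one struct-of-arrays batch."""
    values = np.zeros((len(lists_of_values), max_len), dtype=np.int32)
    lengths = np.zeros(len(lists_of_values), dtype=np.int32)
    for i, v in enumerate(lists_of_values):
        values[i, : len(v)] = v
        lengths[i] = len(v)
    return jnp.asarray(values), jnp.asarray(lengths)


def summate_one(values, length):
    # values has shape (max_len,) for ONE list; slots at index >= length
    # are padding and are masked out of the sum.
    mask = jnp.arange(values.shape[0]) < length
    return jnp.sum(jnp.where(mask, values, 0))


# vmap maps over the leading (batch) axis of both arrays at once.
batched_summate = jax.jit(jax.vmap(summate_one))

values, lengths = pack([[85, 15, 4, 20], [13, 2, 13, 19], [5, 6]], max_len=4)
print(batched_summate(values, lengths).tolist())  # [124, 47, 11]
```

The packing step runs once on the host in NumPy; everything after it has static shapes, which is what lets `jit` and `vmap` apply.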
Further, the algorithm you're using, based on a `while` loop, is not compatible with JAX transformations (see JAX Sharp Bits: Control Flow), and the dynamically sized structure of nodes will not fit into the static shape constraints of JAX programs.
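If the pointer structure itself matters (one common DCEL layout, for instance, stores half-edges as index arrays such as `next` and `twin`), the Python `while` loop can be replaced with `jax.lax.while_loop` over fixed-size index arrays. This is a sketch under stated assumptions: each list is stored as padded `data` and `next_idx` arrays with `-1` marking the tail, and `summate_linked` is a hypothetical name. It is no longer a dynamic linked list, since node storage has a static maximum size.

```python
import jax
import jax.numpy as jnp
from jax import lax


def summate_linked(data, next_idx, head):
    """Sum one array-backed linked list by following next_idx from head.

    data:     (max_nodes,) node values
    next_idx: (max_nodes,) index of the next node, or -1 at the tail
    head:     scalar index of the first node, or -1 for an empty list
    """

    def cond(carry):
        idx, _ = carry
        return idx >= 0

    def body(carry):
        idx, total = carry
        safe = jnp.maximum(idx, 0)  # guard so we never index with -1
        return next_idx[safe], total + data[safe]

    _, total = lax.while_loop(cond, body, (head, jnp.zeros((), data.dtype)))
    return total


batched = jax.jit(jax.vmap(summate_linked))

# Three lists, each padded to 4 node slots; storage order need not match
# list order, only the next_idx chain does.
data = jnp.array([[20, 4, 15, 85], [19, 13, 2, 13], [5, 6, 0, 0]])
next_idx = jnp.array([[-1, 0, 1, 2], [1, 2, 3, -1], [1, -1, -1, -1]])
head = jnp.array([3, 0, 0])
print(batched(data, next_idx, head).tolist())  # [124, 47, 11]
```

Under `vmap`, `lax.while_loop` keeps iterating until every list in the batch has reached its tail, so the cost of a batch is set by its longest chain.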
I'd love to point you in the right direction, but I think you either need to give up on using JAX, or give up on using dynamic linked lists. You won't be able to do both.