Tags: python, ctypes, jax

How can we cast a `ctypes.POINTER(ctypes.c_float)` to `int`?


I think this is a simple task, but I could not find a solution on the web. I use an external C++ library from my Python code, and it returns a ctypes.POINTER(ctypes.c_float) to me. I want to pass an array of these pointers to a function mapped with jax.vmap. The problem is that JAX does not accept the ctypes.POINTER(ctypes.c_float) type. So, can I somehow cast this pointer to an ordinary int? Technically this is clearly possible, but how do I do it in Python?
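A minimal, stand-alone sketch of the round trip being asked about: pointer to int, then int back to pointer (which lib.bar would presumably need). It assumes a local float buffer stands in for the memory the library returns; `buf` and `ptr` are illustrative names, not part of the library.

```python
import ctypes

CFloatPtr = ctypes.POINTER(ctypes.c_float)

# Stand-in for memory owned by the C++ library (assumption: the real pointer comes from lib.foo)
buf = (ctypes.c_float * 3)(1.5, 2.5, 3.5)
ptr = ctypes.cast(buf, CFloatPtr)

# Pointer -> plain Python int holding the address
addr = ctypes.cast(ptr, ctypes.c_void_p).value

# int -> pointer again, e.g. right before calling the C function.
# The int does not keep the memory alive: its owner (buf here, the library in practice) must.
ptr_again = ctypes.cast(addr, CFloatPtr)

print(type(addr).__name__, ptr_again[1])  # int 2.5
```

The integer is just a number to Python, so whoever owns the memory has to keep it valid for as long as the address is in use.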

Here is an example:

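import ctypes
import jax

lib = ctypes.cdll.LoadLibrary(lib_path)  # lib_path: path to the C++ library
lib.foo.argtypes = []  # foo takes no arguments
lib.foo.restype = ctypes.POINTER(ctypes.c_float)

# bar should hold 16 pointers, one per (dummy) element
bar = jax.vmap(lambda dummy: lib.foo())(jax.numpy.empty(16))

x = jax.numpy.empty((16, 256, 256, 1))  # the shape must be passed as a tuple
y = jax.vmap(lib.bar, in_axes=(0, 0))(x, bar)  # bar is 1-D, so it is mapped along axis 0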

So, I want to invoke lib.foo 16 times so that I have an array bar containing all the pointers. Then I want to invoke another library function, lib.bar, which expects bar together with another (batched) parameter x.

The problem is that jax claims that ctypes.POINTER(ctypes.c_float) is not a valid jax type. This is why I think the solution is to cast the pointers to ints and store those ints in bar instead.
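One practical catch when storing addresses as ints: on a 64-bit Python they are 64-bit values that routinely exceed 2**31 - 1, while JAX creates 32-bit integer arrays unless the jax_enable_x64 option is turned on. The sketch below stays within the standard library, using array typecode "Q" as a stand-in for a 64-bit integer array, and only checks the sizes involved; the buffers are illustrative, not from the library.

```python
import array
import ctypes

INT32_MAX = 2**31 - 1

# Pointer width on this interpreter: 8 bytes on a 64-bit build
ptr_size = ctypes.sizeof(ctypes.c_void_p)

# Stand-in buffers (kept alive in the list) and their addresses as plain ints
bufs = [(ctypes.c_float * 1)(float(i)) for i in range(4)]
addrs = [ctypes.addressof(b) for b in bufs]

# Count the addresses that would not fit into a signed 32-bit integer
overflowing = sum(a > INT32_MAX for a in addrs)

# array typecode "Q" is unsigned long long, wide enough for any address
packed = array.array("Q", addrs)

print(ptr_size, overflowing, packed.itemsize)
print(list(packed) == addrs)  # True
```

With JAX, the analogous step would be enabling x64 and using an unsigned 64-bit dtype for bar; that part is an assumption not exercised here.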


Solution

  • Listing:

    Here's a piece of code showing how to handle pointers and their addresses. The trick is to use ctypes.addressof on the pointer's contents (see the ctypes documentation).

    code00.py:

    #!/usr/bin/env python
    
    import ctypes as cts
    import sys
    
    
    CType = cts.c_float
    CTypePtr = cts.POINTER(CType)
    
    
    def ctype_pointer(seq):  # Helper
        CTypeArr = (CType * len(seq))
        ctype_arr = CTypeArr(*seq)
        return cts.cast(ctype_arr, CTypePtr)
    
    
    def pointer_elements(addr, count):  # Helper
        return tuple(CType.from_address(addr + i * cts.sizeof(CType)).value for i in range(count))
    
    
    def main(*argv):
        seq = (2.718182, -3.141593, 1.618034, -0.618034, 0)
        ptr = ctype_pointer(seq)
        print(f"Pointer: {ptr}")
        print(f"\nPointer elements: {tuple(ptr[i] for i in range(len(seq)))}")  # Check if pointer has correct data
        ptr_addr = cts.addressof(ptr.contents)  # @TODO - cfati: Straightforward
        print(f"\nAddress: {ptr_addr} (0x{ptr_addr:016X})\nElements from address: {pointer_elements(ptr_addr, len(seq))}")
        ptr_addr0 = cts.cast(ptr, cts.c_void_p).value  # @TODO - cfati: Alternative
        print(f"\nAddresses match: {ptr_addr == ptr_addr0}")
    
    
    if __name__ == "__main__":
        print(
            "Python {:s} {:03d}bit on {:s}\n".format(
                " ".join(elem.strip() for elem in sys.version.split("\n")),
                64 if sys.maxsize > 0x100000000 else 32,
                sys.platform,
            )
        )
        rc = main(*sys.argv[1:])
        print("\nDone.\n")
        sys.exit(rc)
    

    Notes:

    Output:

    (py_pc064_03.08_test0_lancer) [cfati@cfati-5510-0:/mnt/e/Work/Dev/StackExchange/StackOverflow/q078366208]> python ./code00.py 
    Python 3.8.19 (default, Apr  6 2024, 17:58:10) [GCC 11.4.0] 064bit on linux
    
    Pointer: <__main__.LP_c_float object at 0x7203e97e7d40>
    
    Pointer elements: (2.71818208694458, -3.1415929794311523, 1.6180340051651, -0.6180340051651001, 0.0)
    
    Address: 125361127594576 (0x00007203E97A9A50)
    Elements from address: (2.71818208694458, -3.1415929794311523, 1.6180340051651, -0.6180340051651001, 0.0)
    
    Addresses match: True
    
    Done.
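Two details worth checking when using either approach from the listing: ctypes.addressof(ptr) returns where the pointer object itself lives, not the address it holds, and a NULL pointer has no contents to take the address of. A minimal sketch, where `pointer_to_int` is a hypothetical helper built on the c_void_p cast and mapping NULL to 0:

```python
import ctypes

CFloatPtr = ctypes.POINTER(ctypes.c_float)


def pointer_to_int(ptr):
    """Address held by ptr as an int; 0 for a NULL pointer (hypothetical helper)."""
    addr = ctypes.cast(ptr, ctypes.c_void_p).value  # None for NULL
    return 0 if addr is None else addr


buf = (ctypes.c_float * 2)(0.5, 0.25)
ptr = ctypes.cast(buf, CFloatPtr)

# addressof(ptr) is the pointer object's own storage, not the data it points to
print(ctypes.addressof(ptr) == pointer_to_int(ptr))  # False
print(ctypes.addressof(ptr.contents) == pointer_to_int(ptr))  # True

null_ptr = CFloatPtr()  # NULL float*
print(pointer_to_int(null_ptr))  # 0
try:
    null_ptr.contents  # so addressof(null_ptr.contents) cannot work either
except ValueError as exc:
    print("ValueError:", exc)
```

The cast-based variant is the safer one if the library may return NULL; the addressof variant raises instead.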