I think this is a simple task, but I could not find a solution to it on the web. I have an external C++ library, which I'm using in my Python code, that returns a ctypes.POINTER(ctypes.c_float) to me. I want to pass an array of these pointers to a function vectorized with jax.vmap. The problem is that JAX does not accept the ctypes.POINTER(ctypes.c_float) type. So, can I somehow cast this pointer to an ordinary int? Technically, this is clearly possible, but how do I do it in Python?
Here is an example:
```python
import ctypes

import jax

lib = ctypes.cdll.LoadLibrary(lib_path)  # lib_path: path to the C++ library
lib.foo.argtypes = None
lib.foo.restype = ctypes.POINTER(ctypes.c_float)
bar = jax.vmap(lambda dummy: lib.foo())(jax.numpy.empty(16))
x = jax.numpy.empty((16, 256, 256, 1))
y = jax.vmap(lib.bar, in_axes=(0, 1))(x, bar)
```
So, I want to invoke lib.foo 16 times so that I get an array bar containing all the pointers. Then I want to invoke another library function, lib.bar, which expects bar together with another (batched) parameter x.
The problem is that JAX complains that ctypes.POINTER(ctypes.c_float) is not a valid JAX type. This is why I think the solution is to cast the pointers to ints and store those ints in bar instead.
Listing:
[SO]: C function called from Python via ctypes returns incorrect value (@CristiFati's answer) - a common pitfall when working with CTypes (calling functions)
[Python.Docs]: ctypes - A foreign function library for Python
Here's a piece of code exemplifying how to handle pointers and their addresses. The trick is to use ctypes.addressof (documented in the 2nd URL).
code00.py:
```python
#!/usr/bin/env python

import ctypes as cts
import sys


CType = cts.c_float
CTypePtr = cts.POINTER(CType)


def ctype_pointer(seq):  # Helper
    CTypeArr = (CType * len(seq))
    ctype_arr = CTypeArr(*seq)
    return cts.cast(ctype_arr, CTypePtr)


def pointer_elements(addr, count):  # Helper
    return tuple(CType.from_address(addr + i * cts.sizeof(CType)).value for i in range(count))


def main(*argv):
    seq = (2.718182, -3.141593, 1.618034, -0.618034, 0)
    ptr = ctype_pointer(seq)
    print(f"Pointer: {ptr}")
    print(f"\nPointer elements: {tuple(ptr[i] for i in range(len(seq)))}")  # Check if pointer has correct data
    ptr_addr = cts.addressof(ptr.contents)  # @TODO - cfati: Straightforward
    print(f"\nAddress: {ptr_addr} (0x{ptr_addr:016X})\nElements from address: {pointer_elements(ptr_addr, len(seq))}")
    ptr_addr0 = cts.cast(ptr, cts.c_void_p).value  # @TODO - cfati: Alternative
    print(f"\nAddresses match: {ptr_addr == ptr_addr0}")


if __name__ == "__main__":
    print(
        "Python {:s} {:03d}bit on {:s}\n".format(
            " ".join(elem.strip() for elem in sys.version.split("\n")),
            64 if sys.maxsize > 0x100000000 else 32,
            sys.platform,
        )
    )
    rc = main(*sys.argv[1:])
    print("\nDone.\n")
    sys.exit(rc)
```
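Applied to the question's scenario, the same conversion can be done for a whole batch of pointers and reversed when a pointer has to be handed back to the library. This is a minimal sketch using only the standard library, under these assumptions: `fake_foo` is a hypothetical stand-in for `lib.foo` (the real library isn't available here), the addresses are stored in an `array.array` with the `"Q"` typecode (unsigned long long, at least 64 bits), and NULL pointers are mapped to 0. It does not cover the JAX side; note, though, that JAX uses 32-bit integer types by default (unless `jax_enable_x64` is enabled), so 64-bit addresses would not fit in a default JAX integer array.

```python
import array
import ctypes as cts

FloatPtr = cts.POINTER(cts.c_float)


def fake_foo(value):
    # Hypothetical stand-in for lib.foo(): returns a pointer to 4 floats.
    # cast() keeps a reference to buf, so the buffer lives as long as the pointer.
    buf = (cts.c_float * 4)(*([value] * 4))
    return cts.cast(buf, FloatPtr)


def pointer_to_int(ptr):
    # c_void_p.value is None for a NULL pointer; store it as 0 instead
    return cts.cast(ptr, cts.c_void_p).value or 0


def int_to_pointer(addr):
    # Rebuild a typed pointer from a plain integer address
    return cts.cast(addr, FloatPtr)


# Keep the pointer objects alive: an integer address does not own any memory
ptrs = [fake_foo(float(i)) for i in range(16)]
addrs = array.array("Q", (pointer_to_int(p) for p in ptrs))

restored = int_to_pointer(addrs[3])
print(len(addrs), restored[0])  # 16 3.0
print(pointer_to_int(FloatPtr()))  # 0 (NULL pointer)
```

In the question's setting, lib.bar could then receive `int_to_pointer(addr)`, or the integer directly if its `argtypes` declared `ctypes.c_void_p` for that parameter. Either way, whatever owns the memory (here, the `ptrs` list) must outlive the stored integers.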
Notes:
Although it adds a bit of complexity, I introduced the CType "layer" to show that it should work with any type, not just float (as long as the values in the sequence are compatible with that type).
The only truly relevant lines are those marked with @TODO.
Output:
```
(py_pc064_03.08_test0_lancer) [cfati@cfati-5510-0:/mnt/e/Work/Dev/StackExchange/StackOverflow/q078366208]> python ./code00.py
Python 3.8.19 (default, Apr 6 2024, 17:58:10) [GCC 11.4.0] 064bit on linux

Pointer: <__main__.LP_c_float object at 0x7203e97e7d40>

Pointer elements: (2.71818208694458, -3.1415929794311523, 1.6180340051651, -0.6180340051651001, 0.0)

Address: 125361127594576 (0x00007203E97A9A50)
Elements from address: (2.71818208694458, -3.1415929794311523, 1.6180340051651, -0.6180340051651001, 0.0)

Addresses match: True

Done.
```
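A common pitfall related to the @TODO lines: ctypes.addressof returns the address of the ctypes object it is given, so `cts.addressof(ptr)` is the address where the pointer variable itself is stored, not the address it points to. A minimal sketch (the buffer and names are illustrative) comparing the two, and showing what a NULL pointer gives with each approach:

```python
import ctypes as cts

FloatPtr = cts.POINTER(cts.c_float)

buf = (cts.c_float * 3)(1.0, 2.0, 3.0)
ptr = cts.cast(buf, FloatPtr)

pointee_addr = cts.addressof(ptr.contents)  # where the floats live
holder_addr = cts.addressof(ptr)  # where the pointer value itself is stored

print(pointee_addr == cts.addressof(buf))  # True
print(pointee_addr == holder_addr)  # False
# The memory at holder_addr contains pointee_addr
print(cts.c_void_p.from_address(holder_addr).value == pointee_addr)  # True

null_ptr = FloatPtr()  # NULL pointer
print(bool(null_ptr), cts.cast(null_ptr, cts.c_void_p).value)  # False None
try:
    null_ptr.contents  # dereferencing NULL raises instead of giving an address
except ValueError as exc:
    print("ValueError:", exc)
```

So if the library function can return NULL, the `cts.cast(ptr, cts.c_void_p).value` variant is the more forgiving of the two, since it yields None rather than raising.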