Tags: java, java-ffm, java-22

New Java Foreign Function API JEP 454 - Working with C strings and arrays of C strings


I'm playing with the new Foreign Function API. I followed the example shown in the Description section of JEP 454, but I tried using qsort instead of radix sort.

Unfortunately, the document doesn't stay consistent with this example. In the Upcalls section, instead of building on it, the JEP calls qsort from Java with ints, while the first example (which is incomplete, some code is missing) uses strings and radixsort.

I noticed that the FunctionDescriptor differs between downcalls and upcalls. When a pointer has to be described, a downcall can express it via a plain ADDRESS, while an upcall has to specify the target layout, and with it the total size of the MemorySegment:

MethodHandle qsort = linker.downcallHandle(
    linker.defaultLookup().find("qsort").get(),
    FunctionDescriptor.ofVoid(ADDRESS, JAVA_LONG, JAVA_LONG, ADDRESS)
);

MemorySegment comparFunc
    = linker.upcallStub(comparHandle,
                        FunctionDescriptor.of(JAVA_INT,
                                              ADDRESS.withTargetLayout(JAVA_INT),
                                              ADDRESS.withTargetLayout(JAVA_INT)),
                        Arena.ofAuto());

Instead of:

MemorySegment comparFunc
    = linker.upcallStub(comparHandle,
                        FunctionDescriptor.of(JAVA_INT,
                                              ADDRESS,
                                              ADDRESS),
                        Arena.ofAuto());

Following this, I wrote my upcall code:

static MemorySegment compareStringsFuncDescriptor(Arena arena) {
    return nativeLinker().upcallStub(compareStringFunction(),
        // qsort takes a function "int (*compar)(const void*, const void*)", but in our case there is a cast, so it is effectively "int (*compar)(const char*, const char*)"
        FunctionDescriptor.of(
            JAVA_INT, // Return type
            ADDRESS.withTargetLayout(MemoryLayout.sequenceLayout(10, JAVA_CHAR)),
            ADDRESS.withTargetLayout(MemoryLayout.sequenceLayout(10, JAVA_CHAR))
        ), arena);
}

private static MethodHandle compareStringFunction() {
    try {
        return MethodHandles.lookup().findStatic(CFunctionImplementations.class,
            "strcmpJavaImpl", methodType(int.class, MemorySegment.class, MemorySegment.class));
    } catch (NoSuchMethodException | IllegalAccessException e) {
        throw new RuntimeException(e);
    }
}

private static int strcmpJavaImpl(MemorySegment pString1, MemorySegment pString2) {
    String s1 = pString1.getAtIndex(ADDRESS.withTargetLayout(MemoryLayout.sequenceLayout(10, JAVA_CHAR)), 0).getString(0);
    String s2 = pString2.getAtIndex(ADDRESS.withTargetLayout(MemoryLayout.sequenceLayout(10, JAVA_CHAR)), 0).getString(0);

    assert s1 != null;
    assert s2 != null;
    System.out.println("Comparing: " + s1 + " with " + s2);
    return s1.compareTo(s2); // Use java method
} 

And this is the method that copies the sorted array from native memory to the heap:

private static void copySortedArrayFromNativeMemoryToHeap(String[] strings, MemorySegment pStrings) {
    for (int i = 0; i < strings.length; i++) {
        strings[i] = pStrings.getAtIndex(ADDRESS.withTargetLayout(MemoryLayout.sequenceLayout(10, JAVA_CHAR)), i)
            .getString(0);
    }
}

The question regards the AddressLayout used:

ADDRESS.withTargetLayout(MemoryLayout.sequenceLayout(10, JAVA_CHAR)) 

I'm not convinced about my own implementation but I didn't find a better way to do it. The array used to test my code is:

String[] strings = {"mouse", "cat", "crocodile", "dog", "car"};

At the beginning I was using:

ADDRESS.withTargetLayout(ADDRESS)

but it started to fail when I added "crocodile" because this string doesn't fit in just 8 bytes.

How can I make this code more "dynamic"? Should one use the maximum string length in the array and replace the hard-coded 10 with it? I'm not convinced! What a waste of memory if I use a memory layout of 10 * JAVA_CHAR or, even worse, MAX_STRING_LENGTH * JAVA_CHAR.

I saw that in the JEP 454 example, the copy from off-heap to on-heap is done this way:

// 7. Copy the (reordered) strings from off-heap to on-heap
for (int i = 0; i < javaStrings.length; i++) {
    MemorySegment cString = pointers.getAtIndex(ValueLayout.ADDRESS, i);
    javaStrings[i] = cString.reinterpret(...).getString(0);
}

The author didn't specify the argument of the reinterpret call. And how could it be specified at all, if we are not sure of the size of the string we are trying to copy, given that the array is now sorted and its order may differ from the input array? That's why I didn't use reinterpret in my solution.

EDIT:

I updated the code. Now:

ADDRESS.withTargetLayout(MemoryLayout.sequenceLayout(10, JAVA_CHAR)) 

has been replaced with:

ADDRESS.withTargetLayout(MemoryLayout.sequenceLayout(maxStringByteSizeWithCEndChar, JAVA_BYTE));

In my example, maxStringByteSizeWithCEndChar equals "crocodile".length() + 1. Assuming the charset is UTF_8, each character of these ASCII strings takes 1 byte, instead of JAVA_CHAR.byteSize(), which is 2. So the target layout size is 10 bytes instead of 20 bytes. Still, the question remains: can I do any better? And does it make sense to do any better? The total off-heap allocation for the array should now be arraySize * 10 bytes = 50 bytes!


Solution

    Byte Sizes

    The byte size of a MemorySegment is not directly related to the size of the region of memory. For instance, if you do the following:

    try (var arena = Arena.ofConfined()) {
      var segment = arena.allocate(50);
      segment = segment.reinterpret(Long.MAX_VALUE);
    }
    

    That does not cause Long.MAX_VALUE bytes of memory to be allocated. The underlying region of memory is not changed, and so the amount of memory allocated remains 50 bytes. All the reinterpret did is return a copy of the segment (the Java object, not the native memory) but with its size set to the new value.

    A segment's bounds are meant to help keep the code relatively safe. Similar to arrays in Java, trying to access memory outside the bounds of a segment will result in an exception. It's better to get an exception in Java than it is to silently corrupt memory or crash the JVM with a segmentation fault. Changing the bounds, or otherwise defining the bounds in such a way that Java cannot guarantee they're correct, can lead to unsafe code. Hence methods like reinterpret and withTargetLayout are restricted.

    However, Java doesn't always know how large a segment should be, particularly if the segment came from native code. So, we help Java out by passing around appropriate MemoryLayout objects or calling methods like reinterpret.
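    As a minimal sketch of the points above (the class and method names are hypothetical, not part of the FFM API), the following shows that an access outside a segment's bounds is rejected with an exception, and that reinterpret only changes the bounds of the Java-side view while the address stays the same:

```java
import static java.lang.foreign.ValueLayout.JAVA_BYTE;

import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;

// Hypothetical demo class, for illustration only.
final class SegmentBoundsDemo {

  // Returns true if reading one byte at 'offset' is rejected by the segment's bounds check.
  static boolean isOutOfBounds(MemorySegment segment, long offset) {
    try {
      segment.get(JAVA_BYTE, offset);
      return false;
    } catch (IndexOutOfBoundsException ex) {
      return true;
    }
  }

  // Allocates 'allocated' bytes, then reinterprets the segment to 'claimed' bytes.
  // Returns {original address, reinterpreted address, reinterpreted byte size}.
  static long[] reinterpretInfo(long allocated, long claimed) {
    try (Arena arena = Arena.ofConfined()) {
      MemorySegment original = arena.allocate(allocated);
      MemorySegment resized = original.reinterpret(claimed); // restricted method
      return new long[] {original.address(), resized.address(), resized.byteSize()};
    }
  }

  public static void main(String[] args) {
    try (Arena arena = Arena.ofConfined()) {
      MemorySegment segment = arena.allocate(50);
      System.out.println("offset 49 out of bounds? " + isOutOfBounds(segment, 49));
      System.out.println("offset 50 out of bounds? " + isOutOfBounds(segment, 50));
    }
    long[] info = reinterpretInfo(50, 1024);
    System.out.println("same address? " + (info[0] == info[1]));
    System.out.println("new byte size: " + info[2]);
  }
}
```

    Note that the sketch never reads through the enlarged segment: the extra bytes were not allocated, so touching them would be exactly the kind of unsafe access the bounds exist to prevent.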

    Strings

    Strings in C are simply arrays of characters. As such, you run into the common problem of not knowing how large the array is. The general solution is to pass around the size of the array alongside the array itself. For strings specifically, however, the convention is to have the string be null terminated. In C, if you don't already know the length, you would loop over the array until encountering the null character, at which point you know you've reached the end of the string. But in Java, how do you know what the bounds of the MemorySegment should be? If all you have is the string, then unfortunately the answer is that you don't know.

    If you want to read any string, and you're confident the strings will be properly null terminated, then basically you want to tell Java to "read the characters until the null character is found". In other words, make the bounds "infinite".

    MemorySegment stringSegment = ...;
    String string = stringSegment.reinterpret(Long.MAX_VALUE).getString(0L);
    

    Remember that this does not allocate any memory; it only changes the bounds of the segment. However, there is now no real limit on how large the Java String might be. You might not have enough memory to create the String, or there might be too many characters to fit inside a byte[] (which is how String stores its characters from Java 9 to at least Java 24).

    If you're not confident the strings are always null-terminated, or you simply want to limit how large the strings can be, then you can reinterpret to a smaller size and handle out-of-bounds exceptions. Which size you choose depends on context. In other words, you need to choose a size that's "big enough", where that size is determined by factors unique to your use case.
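    A minimal sketch of the bounded variant, assuming UTF-8 strings; the class BoundedCString, its methods, and the maxBytes parameter are hypothetical names chosen for illustration. It relies on getString throwing IndexOutOfBoundsException when no null terminator is found inside the segment's bounds:

```java
import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;
import java.util.Optional;

// Hypothetical helper class, for illustration only.
final class BoundedCString {

  // Reads a null-terminated UTF-8 string from 'pointer', looking at no more than
  // 'maxBytes' bytes (terminator included). Returns empty if no terminator is
  // found within that limit.
  static Optional<String> read(MemorySegment pointer, long maxBytes) {
    MemorySegment bounded = pointer.reinterpret(maxBytes); // restricted method
    try {
      return Optional.of(bounded.getString(0L));
    } catch (IndexOutOfBoundsException ex) {
      return Optional.empty();
    }
  }

  // Demo helper: copies 'text' off-heap, then reads it back through a zero-length
  // segment, which is what native code typically hands to Java.
  static Optional<String> roundTrip(String text, long maxBytes) {
    try (Arena arena = Arena.ofConfined()) {
      MemorySegment cString = arena.allocateFrom(text);
      MemorySegment pointer = MemorySegment.ofAddress(cString.address());
      return read(pointer, maxBytes);
    }
  }
}
```

    With this sketch, roundTrip("crocodile", 10) yields the string, while roundTrip("crocodile", 4) yields an empty Optional because the terminator lies beyond the 4-byte bound.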

    That said, if you know the offset and length of the string, then you can do the following:

    byte[] bytes = new byte[length];
    MemorySegment.copy(segment, JAVA_BYTE, offset, bytes, 0, length);
    return new String(bytes, charset);
    

    Target Layout - Downcall vs Upcall

    The other thing MemoryLayout helps with is telling Java how to interact with native functions. Specifically, how to translate arguments and return values.

    Downcall Handle

    When creating a downcall handle where the native function accepts a pointer, you tell Java that parameter is an AddressLayout. What type of "object" does the pointer point to? Doesn't matter for calling the function. Fundamentally, the address layout is just telling Java to pass a memory address to the native function. What the native function does with the memory at that address is outside Java's control, so telling Java the object's expected memory layout doesn't really help. If the wrong type of object is passed then that's a bug in the Java program and should be fixed.

    That's why you don't (need to) define the target layout when creating the downcall handle for qsort.

    If the return value of the native function is a pointer, then you will likely want to define the target layout. The reasons are the same as discussed in the next section.
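    To illustrate both cases, here is a hedged sketch that calls the C standard library function strchr, whose parameter and return value are both pointers. It assumes strchr is available through the linker's default lookup on your platform; the class StrchrDemo and its method names are hypothetical. The parameter is a plain ADDRESS in both handles, and only the return layout differs:

```java
import static java.lang.foreign.ValueLayout.ADDRESS;
import static java.lang.foreign.ValueLayout.JAVA_BYTE;
import static java.lang.foreign.ValueLayout.JAVA_INT;

import java.lang.foreign.AddressLayout;
import java.lang.foreign.Arena;
import java.lang.foreign.FunctionDescriptor;
import java.lang.foreign.Linker;
import java.lang.foreign.MemoryLayout;
import java.lang.foreign.MemorySegment;
import java.lang.invoke.MethodHandle;

// Hypothetical demo class, for illustration only.
final class StrchrDemo {

  // "Unbounded" C string pointer: a segment returned with this layout gets this size.
  private static final AddressLayout C_STRING =
      ADDRESS.withTargetLayout(MemoryLayout.sequenceLayout(Long.MAX_VALUE, JAVA_BYTE));

  // char *strchr(const char *s, int c); the parameter is a plain ADDRESS either way.
  private static MethodHandle strchr(AddressLayout returnLayout) {
    Linker linker = Linker.nativeLinker();
    return linker.downcallHandle(
        linker.defaultLookup().find("strchr").orElseThrow(),
        FunctionDescriptor.of(returnLayout, ADDRESS, JAVA_INT));
  }

  // Byte size of the returned segment when the return layout has no target layout.
  static long returnedSizeWithoutTargetLayout(String text, char c) {
    try (Arena arena = Arena.ofConfined()) {
      MemorySegment result =
          (MemorySegment) strchr(ADDRESS).invokeExact(arena.allocateFrom(text), (int) c);
      return result.byteSize();
    } catch (Throwable t) {
      throw new RuntimeException(t);
    }
  }

  // Returns the part of 'text' starting at the first occurrence of 'c'.
  static String suffixFrom(String text, char c) {
    try (Arena arena = Arena.ofConfined()) {
      MemorySegment result =
          (MemorySegment) strchr(C_STRING).invokeExact(arena.allocateFrom(text), (int) c);
      if (result.address() == 0L) { // strchr returned NULL: 'c' does not occur
        throw new IllegalArgumentException("character not found: " + c);
      }
      return result.getString(0L); // must be read before the arena is closed
    } catch (RuntimeException ex) {
      throw ex;
    } catch (Throwable t) {
      throw new RuntimeException(t);
    }
  }
}
```

    Without a target layout the returned segment is zero-length and would have to be reinterpreted before use; with the target layout it can be read directly.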

    Upcall Stub

    When creating an upcall stub, things are a little different. You're essentially translating a Java method into a native function to be passed to native code. Again, if this native function is supposed to accept a pointer, then you tell Java the parameter is an AddressLayout, and in the Java method the parameter will be a MemorySegment. Now the object's memory layout and size do matter, at least if you want to do anything with the segment beyond passing it around.

    When you define the target layout, the MemorySegment passed to the Java method will have its size set to the size of the target layout. Otherwise, you'll get a zero-length segment. Note you still don't have to define the target layout. You can instead reinterpret the segment. Defining the target layout is, in a way, a convenience for when you know the layout of the memory given to Java by the native code.
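    As a sketch of the "reinterpret instead" alternative (the class and method names are hypothetical, and size_t is assumed to be 64 bits wide, as in the qsort descriptor used in the question), the upcall descriptor below uses plain ADDRESS parameters, so the comparator receives zero-length segments and resizes them itself:

```java
import static java.lang.foreign.ValueLayout.ADDRESS;
import static java.lang.foreign.ValueLayout.JAVA_INT;
import static java.lang.foreign.ValueLayout.JAVA_LONG;

import java.lang.foreign.Arena;
import java.lang.foreign.FunctionDescriptor;
import java.lang.foreign.Linker;
import java.lang.foreign.MemorySegment;
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;

// Hypothetical demo class, for illustration only.
final class UpcallReinterpretDemo {

  // Comparator for 'qsort'. Both parameters arrive as zero-length segments because
  // the upcall descriptor uses plain ADDRESS, so each is reinterpreted to 4 bytes.
  static int compareInts(MemorySegment a, MemorySegment b) {
    int x = a.reinterpret(JAVA_INT.byteSize()).get(JAVA_INT, 0L);
    int y = b.reinterpret(JAVA_INT.byteSize()).get(JAVA_INT, 0L);
    return Integer.compare(x, y);
  }

  // Copies 'values' off-heap, sorts them with the native 'qsort', and copies them back.
  static int[] sort(int... values) {
    Linker linker = Linker.nativeLinker();
    try (Arena arena = Arena.ofConfined()) {
      MethodHandle qsort = linker.downcallHandle(
          linker.defaultLookup().find("qsort").orElseThrow(),
          FunctionDescriptor.ofVoid(ADDRESS, JAVA_LONG, JAVA_LONG, ADDRESS));
      MethodHandle compare = MethodHandles.lookup().findStatic(
          UpcallReinterpretDemo.class, "compareInts",
          MethodType.methodType(int.class, MemorySegment.class, MemorySegment.class));

      // No target layout here: the comparator is responsible for the bounds.
      MemorySegment stub = linker.upcallStub(
          compare, FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS), arena);

      MemorySegment array = arena.allocateFrom(JAVA_INT, values);
      qsort.invokeExact(array, (long) values.length, JAVA_INT.byteSize(), stub);
      return array.toArray(JAVA_INT);
    } catch (Throwable t) {
      throw new RuntimeException(t);
    }
  }
}
```

    Replacing both ADDRESS parameters with ADDRESS.withTargetLayout(JAVA_INT) would make the two reinterpret calls unnecessary, which is the convenience described above.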

    Strings

    Again, if all you have is the strings, then it's hard to know the length of each string ahead of time, which means it can be hard to define the memory layout of the strings. Your options are basically the same as those discussed previously for reinterpret: either choose a byte size that's "big enough" for your use case, or simply choose the maximum possible byte size.

    Note if you do:

    ValueLayout.OfChar ch = ...;
    SequenceLayout str = MemoryLayout.sequenceLayout(Long.MAX_VALUE / ch.byteSize(), ch);
    AddressLayout addr = ADDRESS.withTargetLayout(str);
    try (var arena = Arena.ofConfined()) {
      var segment = arena.allocate(addr, 10);
    }
    

    Then the allocated memory is only enough for 10 addresses. And just because you said each sequence contains Long.MAX_VALUE / ch.byteSize() characters does not mean you're forced to allocate that much memory for each element. Each element can be just big enough to hold its string. For example:

    segment.setAtIndex(addr, 0, arena.allocateFrom("foo"));
    

    The string that the element at index 0 points to occupies only 4 bytes (three characters plus the null terminator), not the Long.MAX_VALUE bytes or so that the target layout describes.


    Example qsort

    Source code

    QSort.java

    import static java.lang.foreign.ValueLayout.ADDRESS;
    import static java.lang.foreign.ValueLayout.JAVA_INT;
    import static java.lang.foreign.ValueLayout.JAVA_LONG;
    import static java.lang.invoke.MethodType.methodType;
    
    import java.lang.foreign.AddressLayout;
    import java.lang.foreign.Arena;
    import java.lang.foreign.FunctionDescriptor;
    import java.lang.foreign.Linker;
    import java.lang.foreign.MemoryLayout;
    import java.lang.foreign.MemorySegment;
    import java.lang.invoke.MethodHandle;
    import java.lang.invoke.MethodHandles;
    import java.util.Comparator;
    import java.util.Objects;
    
    public final class QSort {
    
      // A 'qsort' overload that works with any given 'Comparator' instance. The comparator's
      // 'compare' method will be used as the upcall stub function.
      public static void qsort(
          MemorySegment array, MemoryLayout elementLayout, Comparator<MemorySegment> comparator) {
        Objects.requireNonNull(array, "array");
        Objects.requireNonNull(elementLayout, "elementLayout");
        Objects.requireNonNull(comparator, "comparator");
    
        // The 'qsort' function passes pointers to the comparator function, so make sure the
        // layout for the comparator function's parameters is an ADDRESS.
        var compareLayout = elementLayout instanceof AddressLayout
            ? elementLayout
            : ADDRESS.withTargetLayout(elementLayout);
    
        // "Dynamically" determine comparator function's descriptor based on element layout
        var compareDesc = FunctionDescriptor.of(JAVA_INT, compareLayout, compareLayout);
    
        // Bind 'Comparator::compare' to the given 'comparator' instance, then adapt it to
        // accept 'MemorySegment' instead of 'Object'. Plus, the type of the target 'MethodHandle'
        // must match the 'FunctionDescriptor' when calling 'Linker::upcallStub'.
        var compareHandle = COMPARE_HANDLE.bindTo(comparator).asType(compareDesc.toMethodType());
    
        // Adapt the 'compareHandle' to catch any exceptions and handle them in the
        // 'QSort::handleCompareException' method. This prevents exceptions thrown by the
        // comparator from crashing the JVM.
        compareHandle = MethodHandles.catchException(compareHandle, Throwable.class, CATCH_HANDLE);
    
        // Allocate the upcall stub function and invoke 'qsort'. The upcall stub will be
        // freed when the 'try' block exits.
        try (var arena = Arena.ofConfined()) {
          var compareStub = LINKER.upcallStub(compareHandle, compareDesc, arena);
          qsort(array, elementLayout, compareStub);
        }
      }
    
      // A 'qsort' overload that works with an upcall stub allocated by the caller. This allows
      // callers to use any appropriate Java method or even a native function as the comparator.
      public static void qsort(
          MemorySegment array, MemoryLayout elementLayout, MemorySegment comparator) {
        Objects.requireNonNull(array, "array");
        Objects.requireNonNull(elementLayout, "elementLayout");
        Objects.requireNonNull(comparator, "comparator");
    
        long count = array.byteSize() / elementLayout.byteSize();
        long size = elementLayout.byteSize();
    
        try {
          QSORT_HANDLE.invokeExact(array, count, size, comparator);
        } catch (Throwable t) {
          throw new RuntimeException(t);
        }
      }
    
      // If an upcall stub throws an exception "out of" the stub, then the JVM
      // will crash. This method is used to handle any exception thrown by the
      // 'compare' stub created from a 'Comparator'. There's not much we can do
      // with the exception though, so all we do is print the stack trace and 
      // return a default value.
      private static int handleCompareException(Throwable t) {
        new RuntimeException("'qsort' comparator threw exception; returning 0", t).printStackTrace();
        return 0;
      }
    
      private QSort() {}
    
      /* *****************************************************************************
       *                                                                             *
       * FFM State                                                                   *
       *                                                                             *
       *******************************************************************************/
      
      private static final Linker LINKER = Linker.nativeLinker(); // need Linker to create upcall stubs
      private static final MethodHandle QSORT_HANDLE; // handle to 'qsort' native function
    
      private static final MethodHandle COMPARE_HANDLE; // handle to 'Comparator::compare'
      private static final MethodHandle CATCH_HANDLE; // handle to 'QSort::handleCompareException'
    
      static {
        var symbols = LINKER.defaultLookup();
        QSORT_HANDLE = LINKER.downcallHandle(
            // 'find(...).orElseThrow()' also compiles on Java 22; 'findOrThrow' needs Java 23+
            symbols.find("qsort").orElseThrow(),
            FunctionDescriptor.ofVoid(ADDRESS, JAVA_LONG, JAVA_LONG, ADDRESS));
    
        var lookup = MethodHandles.lookup();
        try {
          var compareType = methodType(int.class, Object.class, Object.class);
          COMPARE_HANDLE = lookup.findVirtual(Comparator.class, "compare", compareType);
    
          var catchType = methodType(int.class, Throwable.class);
          CATCH_HANDLE = lookup.findStatic(QSort.class, "handleCompareException", catchType);
        } catch (ReflectiveOperationException ex) {
          throw new RuntimeException(ex);
        }
      }
    }
    

    Main.java

    import static java.lang.foreign.MemoryLayout.sequenceLayout;
    import static java.lang.foreign.ValueLayout.ADDRESS;
    import static java.lang.foreign.ValueLayout.JAVA_CHAR;
    import static java.lang.foreign.ValueLayout.JAVA_INT;
    import static java.util.Comparator.comparing;
    import static java.util.Comparator.comparingInt;
    
    import java.lang.foreign.AddressLayout;
    import java.lang.foreign.Arena;
    import java.lang.foreign.MemorySegment;
    import java.lang.foreign.SequenceLayout;
    import java.lang.foreign.ValueLayout;
    
    public class Main {
    
      private static final SequenceLayout STRING_LAYOUT =
          sequenceLayout(Long.MAX_VALUE / JAVA_CHAR.byteSize(), JAVA_CHAR);
      private static final AddressLayout STRING_ELEMENT_LAYOUT =
          ADDRESS.withTargetLayout(STRING_LAYOUT);
    
      private static final ValueLayout.OfInt INT_ELEMENT_LAYOUT = JAVA_INT;
    
      public static void main(String[] args) {
        testQSortWithIntArray();
        testQSortWithStringArray();
      }
    
      static void testQSortWithIntArray() {
        System.out.println("===== INT ARRAY =====");
        try (var arena = Arena.ofConfined()) {
          var array = arena.allocateFrom(INT_ELEMENT_LAYOUT, 8, 4, 3, 6, 1, 9, 5, 7, 2);
    
          System.out.println("Before qsort:");
          System.out.print("   ");
          printIntArray(array);
          System.out.println();
    
          var comparator = comparingInt((MemorySegment ms) -> ms.get(INT_ELEMENT_LAYOUT, 0L));
          QSort.qsort(array, INT_ELEMENT_LAYOUT, comparator);
    
          System.out.println("After qsort:");
          System.out.print("   ");
          printIntArray(array);
          System.out.println();
        }
      }
    
      static void testQSortWithStringArray() {
        String[] strings = {"pear", "banana", "pineapple", "apple", "orange"};
    
        System.out.println("===== STRING ARRAY =====");
        try (var arena = Arena.ofConfined()) {
          var array = arena.allocate(STRING_ELEMENT_LAYOUT, strings.length);
          for (int i = 0; i < strings.length; i++) {
            array.setAtIndex(STRING_ELEMENT_LAYOUT, i, arena.allocateFrom(strings[i]));
          }
    
          System.out.println("Before qsort:");
          System.out.print("   ");
          printStringArray(array);
          System.out.println();
    
          var comparator =
              comparing((MemorySegment ms) -> ms.get(STRING_ELEMENT_LAYOUT, 0L).getString(0L));
          QSort.qsort(array, STRING_ELEMENT_LAYOUT, comparator);
    
          System.out.println("After qsort:");
          System.out.print("   ");
          printStringArray(array);
          System.out.println();
        }
      }
    
      static void printIntArray(MemorySegment array) {
        long count = array.byteSize() / INT_ELEMENT_LAYOUT.byteSize();
    
        System.out.print("[");
        for (long i = 0; i < count; i++) {
          if (i != 0) System.out.print(", ");
          System.out.print(array.getAtIndex(INT_ELEMENT_LAYOUT, i));
        }
        System.out.println("]");
      }
    
      static void printStringArray(MemorySegment array) {
        long count = array.byteSize() / STRING_ELEMENT_LAYOUT.byteSize();
    
        System.out.print("[");
        for (long i = 0; i < count; i++) {
          if (i != 0) System.out.print(", ");
          var segment = array.getAtIndex(STRING_ELEMENT_LAYOUT, i);
          var string = segment.getString(0L);
          System.out.print("\"" + string + "\"");
        }
        System.out.println("]");
      }
    }
    

    Output

    ===== INT ARRAY =====
    Before qsort:
       [8, 4, 3, 6, 1, 9, 5, 7, 2]
    
    After qsort:
       [1, 2, 3, 4, 5, 6, 7, 8, 9]
    
    ===== STRING ARRAY =====
    Before qsort:
       ["pear", "banana", "pineapple", "apple", "orange"]
    
    After qsort:
       ["apple", "banana", "orange", "pear", "pineapple"]