java java-ffm

How can I write an array-like datastructure in Java that takes longs as indices using modern unsafe APIs?


Arrays in Java are limited to Integer.MAX_VALUE for initial capacity & indexable elements (around 2 billion). I would like to write a data structure class that uses a long for this instead.

I know there are two methods in wide circulation:

  1. Use an array of arrays
  2. Use APIs in sun.misc.Unsafe to manually allocate and access large slabs of memory

I don't want to use an array of arrays and using sun.misc.Unsafe is heavily discouraged, producing compilation warnings that cannot be silenced using ordinary methods.

Starting in Java 9 there began efforts to standardize & replace sun.misc.Unsafe with the addition of java.lang.invoke.VarHandle in JEP 193. Then in Java 22 there was the addition of java.lang.foreign.MemorySegment in JEP 454. JEP 471 coming in Java 23 is going to deprecate the memory access methods in sun.misc.Unsafe for removal.

So it seems like there should be a way to use the existing VarHandle and MemorySegment APIs to write a long array in Java. How do I do this?


Solution

  • You can use SegmentAllocator::allocate(MemoryLayout,long) to create a MemorySegment that can be used as an array of "objects" represented by the given MemoryLayout. Then you can wrap the segment in a Java class to encapsulate the "array access".

    Note this means the data has to be able to be put into off-heap memory. In other words, the data has to be a Java primitive or, for more complex types, a MemorySegment. You won't be able to fill the array with arbitrary Java reference types. If you want to treat complex elements as Java objects, you'll have to write a class that wraps the MemorySegment. Or at least, that's the only approach I'm aware of.


    Primitive Types

    For primitive types, this is relatively easy:

    import java.lang.foreign.MemorySegment;
    import java.lang.foreign.SegmentAllocator;
    import java.lang.foreign.ValueLayout;
    import java.util.Objects;
    
    /**
     * Fixed-length sequence of {@code int} values that is indexed with a {@code long} and
     * backed by a single {@link MemorySegment}.
     */
    public final class LargeIntArray {
    
      /** Layout used both to size the backing segment and to read/write each element. */
      public static final ValueLayout.OfInt LAYOUT = ValueLayout.JAVA_INT_UNALIGNED;
    
      private final long length;
      private final MemorySegment segment;
    
      /**
       * Allocates room for {@code length} ints from the supplied allocator.
       *
       * @param allocator source of the backing memory
       * @param length number of {@code int} elements to reserve
       */
      public LargeIntArray(SegmentAllocator allocator, long length) {
        this.length = length;
        this.segment = allocator.allocate(LAYOUT, length);
      }
    
      /**
       * Returns the number of elements requested at construction time.
       *
       * @return the element count
       */
      public long length() {
        return this.length;
      }
    
      /**
       * Reads the element stored at {@code index}.
       *
       * @param index position of the element, counted in elements (not bytes)
       * @return the stored value
       */
      public int get(long index) {
        final int value = this.segment.getAtIndex(LAYOUT, index);
        return value;
      }
    
      /**
       * Overwrites the element stored at {@code index}.
       *
       * @param index position of the element, counted in elements (not bytes)
       * @param element new value
       */
      public void set(long index, int element) {
        this.segment.setAtIndex(LAYOUT, index, element);
      }
    
      /**
       * Returns a segment created from the raw address of the backing memory only; it is
       * built with {@link MemorySegment#ofAddress(long)}, not sliced from the backing segment.
       *
       * @return a segment pointing at the first element
       */
      public MemorySegment address() {
        final long base = this.segment.address();
        return MemorySegment.ofAddress(base);
      }
    }
    

    There are ValueLayout.OfXXX interfaces for each of the primitive Java types.


    Complex Types

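    For more complex data types, you will be working with MemorySegment instead of primitive types. Note that the array below stores the address of each element (a pointer), not the element's bytes themselves, which is why it reads and writes its slots with an AddressLayout: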

    import java.lang.foreign.AddressLayout;
    import java.lang.foreign.MemoryLayout;
    import java.lang.foreign.MemorySegment;
    import java.lang.foreign.SegmentAllocator;
    import java.lang.foreign.ValueLayout;
    
    public final class LargeArray {
    
      private final MemorySegment segment;
      private final long length;
      private final AddressLayout layout;
    
      public LargeArray(SegmentAllocator allocator, MemoryLayout elementLayout, long length) {
        this.segment = allocator.allocate(elementLayout, length);
        this.layout = ValueLayout.ADDRESS.withTargetLayout(elementLayout);
        this.length = length;
      }
    
      public AddressLayout layout() {
        return layout;
      }
    
      public MemorySegment address() {
        return MemorySegment.ofAddress(segment.address());
      }
    
      public MemorySegment get(long index) {
        return segment.getAtIndex(layout, index);
      }
    
      public void set(long index, MemorySegment element) {
        segment.setAtIndex(layout, index, element);
      }
    
      public long length() {
        return length;
      }
    }
    

    Better Encapsulation

    One potential improvement is to create a Java class that wraps the MemorySegment representing the "objects". This will make working with the array more natural on the Java side. First, you need a generic way to map between MemorySegment and a Java class:

    import java.lang.foreign.MemoryLayout;
    import java.lang.foreign.MemorySegment;
    import java.util.Objects;
    import java.util.function.Function;
    
    /**
     * Describes how elements of type {@code T} are laid out in memory and how to convert
     * between a {@link MemorySegment} and a {@code T}.
     *
     * @param <T> Java-side element type
     */
    public interface ElementDescriptor<T> {
    
      /**
       * Returns the memory layout of a single element.
       *
       * @return the element layout
       */
      MemoryLayout layout();
    
      /**
       * Converts a segment into its Java-side element.
       *
       * @param segment segment to convert
       * @return the element for {@code segment}
       */
      T elementFrom(MemorySegment segment);
    
      /**
       * Converts a Java-side element into the segment that represents it.
       *
       * @param element element to convert
       * @return the segment for {@code element}
       */
      MemorySegment addressOf(T element);
    
      /**
       * Builds a descriptor from a layout and a pair of conversion functions.
       *
       * <p>The returned descriptor maps {@link MemorySegment#NULL} to {@code null} and
       * {@code null} to {@link MemorySegment#NULL} without calling the supplied functions.
       *
       * @param layout layout of one element
       * @param toElement converts a non-NULL segment to an element
       * @param toAddress converts a non-null element to a segment
       * @param <T> Java-side element type
       * @return a descriptor backed by the given arguments
       * @throws NullPointerException if any argument is {@code null}
       */
      static <T> ElementDescriptor<T> of(
          MemoryLayout layout,
          Function<MemorySegment, T> toElement,
          Function<T, MemorySegment> toAddress) {
        final MemoryLayout elementLayout = Objects.requireNonNull(layout);
        final Function<MemorySegment, T> reader = Objects.requireNonNull(toElement);
        final Function<T, MemorySegment> writer = Objects.requireNonNull(toAddress);
    
        return new ElementDescriptor<T>() {
          @Override
          public MemoryLayout layout() {
            return elementLayout;
          }
    
          @Override
          public T elementFrom(MemorySegment segment) {
            // The NULL segment stands for "no element" on the Java side.
            return segment.equals(MemorySegment.NULL) ? null : reader.apply(segment);
          }
    
          @Override
          public MemorySegment addressOf(T element) {
            // A null element is written out as the NULL segment.
            return element != null ? writer.apply(element) : MemorySegment.NULL;
          }
        };
      }
    }
    

    Then you need to update LargeArray to work with the above:

    import java.lang.foreign.AddressLayout;
    import java.lang.foreign.MemorySegment;
    import java.lang.foreign.SegmentAllocator;
    import java.lang.foreign.ValueLayout;
    
    public final class LargeArray<T> {
    
      private final MemorySegment segment;
      private final long length;
      private final AddressLayout layout;
      private final ElementDescriptor<T> descriptor;
    
      public LargeArray(SegmentAllocator allocator, long length, ElementDescriptor<T> descriptor) {
        this.segment = allocator.allocate(descriptor.layout(), length);
        this.layout = ValueLayout.ADDRESS.withTargetLayout(descriptor.layout());
        this.length = length;
        this.descriptor = descriptor;
      }
    
      public AddressLayout layout() {
        return layout;
      }
    
      public MemorySegment address() {
        return MemorySegment.ofAddress(segment.address());
      }
    
      public T get(long index) {
        return descriptor.elementFrom(segment.getAtIndex(layout, index));
      }
    
      public void set(long index, T element) {
        segment.setAtIndex(layout, index, descriptor.addressOf(element));
      }
    
      public long length() {
        return length;
      }
    }
    

    And finally, you need a data structure. For example, here is a Point struct with x and y coordinates:

    import static java.lang.foreign.ValueLayout.JAVA_INT;
    
    import java.lang.foreign.MemoryLayout;
    import java.lang.foreign.MemoryLayout.PathElement;
    import java.lang.foreign.MemorySegment;
    import java.lang.foreign.SegmentAllocator;
    import java.lang.foreign.StructLayout;
    import java.lang.invoke.MethodHandles;
    import java.lang.invoke.VarHandle;
    import java.util.Objects;
    
    /**
     * Java-side view of a struct with two {@code int} fields, {@code x} and {@code y}, stored
     * in a {@link MemorySegment}.
     */
    public final class Point {
    
      /** Struct layout: {@code int x} followed by {@code int y}. */
      public static final StructLayout LAYOUT =
          MemoryLayout.structLayout(JAVA_INT.withName("x"), JAVA_INT.withName("y"));
    
      /** Descriptor that wraps a segment in a {@code Point} and unwraps it via {@link #address()}. */
      public static final ElementDescriptor<Point> DESCRIPTOR =
          ElementDescriptor.of(LAYOUT, Point::new, Point::address);
    
      private static final VarHandle X = fieldHandle("x");
      private static final VarHandle Y = fieldHandle("y");
    
      private final MemorySegment segment;
    
      /**
       * Allocates a fresh struct from the supplied allocator.
       *
       * @param allocator source of the struct's memory
       */
      public Point(SegmentAllocator allocator) {
        this.segment = allocator.allocate(LAYOUT);
      }
    
      /**
       * Wraps an existing segment.
       *
       * @param segment memory holding the struct
       * @throws NullPointerException if {@code segment} is {@code null}
       */
      public Point(MemorySegment segment) {
        this.segment = Objects.requireNonNull(segment);
      }
    
      /** Looks up the named field's handle and binds its base-offset coordinate to zero. */
      private static VarHandle fieldHandle(String fieldName) {
        VarHandle unbound = LAYOUT.varHandle(PathElement.groupElement(fieldName));
        return MethodHandles.insertCoordinates(unbound, 1, 0L);
      }
    
      /**
       * Returns a segment created from this struct's raw address only; it is built with
       * {@link MemorySegment#ofAddress(long)}, not sliced from the wrapped segment.
       *
       * @return a segment pointing at this struct
       */
      public MemorySegment address() {
        final long base = segment.address();
        return MemorySegment.ofAddress(base);
      }
    
      /**
       * Reads the {@code x} field.
       *
       * @return the current x coordinate
       */
      public int getX() {
        return (int) X.get(this.segment);
      }
    
      /**
       * Writes the {@code x} field.
       *
       * @param x new x coordinate
       */
      public void setX(int x) {
        X.set(this.segment, x);
      }
    
      /**
       * Reads the {@code y} field.
       *
       * @return the current y coordinate
       */
      public int getY() {
        return (int) Y.get(this.segment);
      }
    
      /**
       * Writes the {@code y} field.
       *
       * @param y new y coordinate
       */
      public void setY(int y) {
        Y.set(this.segment, y);
      }
    
      @Override
      public String toString() {
        StringBuilder text = new StringBuilder("Point(x=");
        text.append(getX());
        text.append(", y=");
        text.append(getY());
        text.append(")");
        return text.toString();
      }
    }
    

    Example Use

    Here is an example of using a LargeArray<Point>:

    import java.lang.foreign.Arena;
    
    /** Demonstrates filling, mutating and printing a {@code LargeArray<Point>}. */
    public class Main {
    
      public static void main(String[] args) throws Throwable {
        // Everything allocated from the arena is released when the try block exits.
        try (Arena arena = Arena.ofConfined()) {
          LargeArray<Point> points = new LargeArray<>(arena, 10L, Point.DESCRIPTOR);
    
          fill(points, arena);
    
          // Mutating the object returned by get() changes the element the array points to.
          Point sixth = points.get(5L);
          sixth.setX(42);
          sixth.setY(117);
    
          print(points);
        }
      }
    
      /** Stores a new point (x = i, y = 2 * i) in every slot of the array. */
      private static void fill(LargeArray<Point> points, Arena arena) {
        long slot = 0;
        while (slot < points.length()) {
          int coordinate = (int) slot;
          Point created = new Point(arena);
          created.setX(coordinate);
          created.setY(coordinate * 2);
          points.set(slot, created);
          slot++;
        }
      }
    
      /** Writes one line per slot to standard output. */
      private static void print(LargeArray<Point> points) {
        long slot = 0;
        while (slot < points.length()) {
          System.out.printf("array[%d] = %s%n", slot, points.get(slot));
          slot++;
        }
      }
    }
    

    Output:

    array[0] = Point(x=0, y=0)
    array[1] = Point(x=1, y=2)
    array[2] = Point(x=2, y=4)
    array[3] = Point(x=3, y=6)
    array[4] = Point(x=4, y=8)
    array[5] = Point(x=42, y=117)
    array[6] = Point(x=6, y=12)
    array[7] = Point(x=7, y=14)
    array[8] = Point(x=8, y=16)
    array[9] = Point(x=9, y=18)
    

    Notes

    A few notes: