Tags: java, spring-boot, spring-native

Spring Native process-aot: unable to register lambda-capturing types


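To use (serializable) lambda expressions in a native image, the class that declares them must be registered under `lambdaCapturingTypes` in serialization-config.json. For example, the lambda `SysRegion::getRegionCode` below is declared in `SysRegionServiceImpl`: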


/**
 * Region service backed by {@link SysRegionRepository}.
 *
 * <p>The method reference {@code SysRegion::getRegionCode} is declared in this class, which is
 * why this class has to be registered as a lambda-capturing type for the native image.
 */
@Service
@AllArgsConstructor
public class SysRegionServiceImpl implements ISysRegionService {

    /** Repository used for region lookups; supplied through the generated constructor. */
    private final SysRegionRepository sysRegionRepository;

    /**
     * Fetches the single region whose region code equals {@code regionCode}.
     *
     * @param regionCode region code to match
     * @return whatever {@code sysRegionRepository.getOne(...)} returns for that condition
     */
    @Override
    public SysRegion selectSysRegionByRegionCodel(String regionCode) {
        var byRegionCode = Wrappers.<SysRegion>lambdaQuery()
                .eq(SysRegion::getRegionCode, regionCode);
        return sysRegionRepository.getOne(byRegionCode);
    }

}

The class must be registered in serialization-config.json:

{
  "types":[
  ],
  "lambdaCapturingTypes":[
    {"name":"com.shushang.system.service.impl.SysRegionServiceImpl"}
  ],
  "proxies":[
  ]
}

Alternatively, register it programmatically with a GraalVM Feature, enabled through the build argument `--features=x.x.x.LambdaRegistrationFeature`:

/**
 * GraalVM native-image {@code Feature} that registers {@code SysRegionServiceImpl} as a
 * lambda-capturing class, so lambdas declared in it are covered by serialization support.
 * Enabled through the {@code --features=...} build argument.
 */
public class LambdaRegistrationFeature implements Feature {

    /** Performs the registration during image-build setup. */
    @Override
    public void duringSetup(DuringSetupAccess access) {
        Class<?> lambdaHost = SysRegionServiceImpl.class;
        RuntimeSerialization.registerLambdaCapturingClass(lambdaHost);
    }

}

Both approaches are tedious, because every such class has to be listed by hand.

However, Spring provides a runtime-hints API that generates these JSON files ahead of time; serialization entries are registered through `org.springframework.aot.hint.SerializationHints`.

Using a `RuntimeHintsRegistrar`:

// Registers SysRegionServiceImpl as a plain serialization type; the generated
// serialization-config.json gets a "name" entry, not a lambdaCapturingTypes entry.
hints.serialization().registerType(SysRegionServiceImpl.class);

However, this API offers no way to register a type under `lambdaCapturingTypes`.

The serialization-config.json that Spring's process-aot generates looks like this, with a plain type entry and no `lambdaCapturingTypes`:

[
  {
    "name": "com.shushang.system.service.impl.SysRegionServiceImpl"
  }
]

How can I, during process-aot, dynamically collect the classes that need to be registered (for example, all classes annotated with @Service) and add them to `lambdaCapturingTypes`?


Solution

  • Register the classes manually from a classpath scan.
    This approach scans the classes in the jar during the native:compile stage and registers them from a GraalVM Feature.

    Nothing is scanned during Spring's process-aot stage, because Spring does not expose an equivalent of `registerLambdaCapturingClass`.

    /**
     * GraalVM native-image {@code Feature} that scans application packages at image-build time and
     * registers every class found via {@code RuntimeSerialization.registerLambdaCapturingClass}.
     * Enabled through the {@code --features=...} build argument.
     */
    public class LambdaRegistrationFeature implements Feature {

        /**
         * Scans the configured packages and registers each non-generated class.
         * Checked exceptions from scanning or class loading are rethrown unchecked (Lombok
         * {@code @SneakyThrows}).
         */
        @SneakyThrows
        @Override
        public void duringSetup(DuringSetupAccess access) {
            // The scanner supports only a single "**" wildcard per pattern.
            List<String> candidates = MultiPackageScanner.scanClasses("com.shushang.**.service.impl", "com.shushang.system.repository");
            // Name fragments of Spring-generated classes (AOT bean-definition/autowiring helpers
            // and CGLIB proxies) that do not need to be registered.
            List<String> ignoreClass = CollUtil.newArrayList("__BeanDefinitions","__Autowiring","SpringCGLIB");
            for (String candidate : candidates) {
                boolean generated = ignoreClass.stream().anyMatch(candidate::contains);
                if (!generated) {
                    Class<?> target = Thread.currentThread().getContextClassLoader().loadClass(candidate);
                    RuntimeSerialization.registerLambdaCapturingClass(target);
                    // Print each registration so the build log shows the class was actually picked up.
                    System.out.println("成功加载类: " + candidate);
                }
            }
        }

    }
    
    
    import java.io.File;
    import java.net.JarURLConnection;
    import java.net.URL;
    import java.util.ArrayList;
    import java.util.Enumeration;
    import java.util.List;
    import java.util.jar.JarEntry;
    import java.util.jar.JarFile;
    
    /**
     * Lists fully-qualified class names found under package patterns on the current thread's
     * context class loader. Class files are never loaded; only their names are collected.
     *
     * <p>A pattern is either a plain package name (for example {@code com.shushang.system.repository}),
     * which matches only classes directly in that package, or a package with a single {@code .**}
     * wildcard (for example {@code com.shushang.**.service.impl}). With a wildcard, everything below
     * the part before {@code .**} is scanned recursively. A class is kept when its package contains
     * the part after {@code .**} as whole dot-separated segments, so
     * {@code com.shushang.a.service.impl.X} matches but {@code com.shushang.myservice.impl.Y} does not.
     * A trailing {@code .**} keeps every class. Only one {@code .**} per pattern is supported.
     *
     * <p>Both exploded directories ({@code file:} URLs) and jars ({@code jar:} URLs) are scanned;
     * other URL protocols are ignored. {@code package-info} and {@code module-info} are skipped.
     */
    public class MultiPackageScanner {

        /** File-name suffix of compiled class files. */
        private static final String CLASS_SUFFIX = ".class";

        /**
         * Returns the names of classes matching any of the given patterns.
         *
         * @param packagePatterns package names, each optionally containing one {@code .**} wildcard
         * @return matching class names in discovery order; a class matched by several patterns or
         *     present in several classpath roots may appear more than once
         * @throws Exception if a classpath resource URL cannot be resolved or a jar cannot be opened
         */
        public static List<String> scanClasses(String... packagePatterns) throws Exception {
            List<String> classes = new ArrayList<>();
            for (String pattern : packagePatterns) {
                String basePackage = pattern;
                String targetPath = "";
                int wildcard = pattern.indexOf(".**");
                if (wildcard >= 0) {
                    basePackage = pattern.substring(0, wildcard);
                    // Keep what follows ".**" without its leading dot; a trailing ".**" leaves this
                    // empty, which matches every class under the base package.
                    String rest = pattern.substring(wildcard + ".**".length());
                    targetPath = rest.startsWith(".") ? rest.substring(1) : rest;
                }
                boolean recursive = pattern.contains("**");
                String resourcePath = basePackage.replace('.', '/');
                Enumeration<URL> resources = Thread.currentThread().getContextClassLoader().getResources(resourcePath);

                while (resources.hasMoreElements()) {
                    URL url = resources.nextElement();
                    if ("file".equals(url.getProtocol())) {
                        scanFileSystem(url, basePackage, targetPath, classes, recursive);
                    } else if ("jar".equals(url.getProtocol())) {
                        scanJar(url, basePackage, targetPath, classes, recursive);
                    }
                }
            }
            return classes;
        }

        /** Scans an exploded classpath directory identified by a {@code file:} URL. */
        private static void scanFileSystem(URL url, String basePackage, String targetPath, List<String> classes, boolean recursive) throws Exception {
            // toURI() decodes escapes such as %20; url.getPath() keeps them and would point at a
            // non-existent directory whenever the classpath contains spaces or non-ASCII characters.
            File dir = new File(url.toURI());
            if (dir.exists()) {
                scanDirectory(dir, basePackage, targetPath, classes, recursive);
            }
        }

        /** Collects matching class files in {@code dir}, descending into sub-packages when recursive. */
        private static void scanDirectory(File dir, String basePackage, String targetPath, List<String> classes, boolean recursive) {
            File[] files = dir.listFiles();
            if (files == null) {
                return;
            }

            for (File file : files) {
                String name = file.getName();
                if (file.isDirectory()) {
                    if (recursive) { // sub-directories are sub-packages
                        scanDirectory(file, basePackage + "." + name, targetPath, classes, true);
                    }
                } else if (name.endsWith(CLASS_SUFFIX)) {
                    String className = basePackage + '.' + stripClassSuffix(name);
                    if (matchesTarget(className, targetPath)) {
                        loadClass(className, classes);
                    }
                }
            }
        }

        /** Collects matching class entries from the jar behind a {@code jar:} URL. */
        private static void scanJar(URL url, String basePackage, String targetPath, List<String> classes, boolean recursive) throws Exception {
            JarFile jar = ((JarURLConnection) url.openConnection()).getJarFile();
            Enumeration<JarEntry> entries = jar.entries();
            // The trailing '/' keeps sibling names such as "com/x/repoExtra.class" from matching "com/x/repo".
            String packagePrefix = basePackage.replace('.', '/') + '/';

            while (entries.hasMoreElements()) {
                String entryName = entries.nextElement().getName();
                if (!entryName.endsWith(CLASS_SUFFIX) || !entryName.startsWith(packagePrefix)) {
                    continue;
                }
                boolean inSubPackage = entryName.indexOf('/', packagePrefix.length()) != -1;
                if (inSubPackage && !recursive) {
                    continue;
                }
                String className = stripClassSuffix(entryName).replace('/', '.');
                if (matchesTarget(className, targetPath)) {
                    loadClass(className, classes);
                }
            }
        }

        /** Removes only the trailing ".class", so package segments such as "classes" stay intact. */
        private static String stripClassSuffix(String name) {
            return name.substring(0, name.length() - CLASS_SUFFIX.length());
        }

        /**
         * Returns whether the package of {@code className} contains {@code targetPath} as whole
         * dot-separated segments; an empty {@code targetPath} matches everything.
         */
        private static boolean matchesTarget(String className, String targetPath) {
            if (targetPath.isEmpty()) {
                return true;
            }
            // Surrounding dots force segment boundaries and require at least one segment after
            // the target (the simple class name), so only package segments can match.
            return ("." + className).contains("." + targetPath.replace('/', '.') + ".");
        }

        /** Records {@code className} (it is not loaded), skipping package/module descriptors. */
        private static void loadClass(String className, List<String> classes) {
            String simpleName = className.substring(className.lastIndexOf('.') + 1);
            if ("package-info".equals(simpleName) || "module-info".equals(simpleName)) {
                return;
            }
            classes.add(className);
        }

        /** Manual check: prints the classes found for the project's patterns. */
        public static void main(String[] args) throws Exception {
            List<String> classes = scanClasses("com.shushang.**.service.impl", "com.shushang.system.repository");
            classes.forEach(clazz -> System.out.println("加载类: " + clazz));
        }
    }