fix[storage]: Fix storage functional method reference bug

This commit is contained in:
凌星
2023-09-18 10:14:42 +08:00
parent beb0f17cfc
commit 05ae9073fb
13 changed files with 231 additions and 75 deletions
@@ -15,6 +15,7 @@ package com.zfoo.boot;
import com.zfoo.boot.graalvm.GraalvmStorageHints;
import com.zfoo.storage.StorageContext;
import com.zfoo.storage.config.StorageConfig;
import com.zfoo.storage.graalvm.feature.RuntimeRegistrationFeature;
import com.zfoo.storage.manager.StorageManager;
import org.springframework.boot.autoconfigure.condition.ConditionalOnBean;
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
@@ -36,6 +37,7 @@ public class StorageAutoConfiguration {
public StorageManager storageManager(StorageConfig storageConfig) {
var storageManager = new StorageManager();
storageManager.setStorageConfig(storageConfig);
RuntimeRegistrationFeature.setLambdaCapturePackage(storageConfig.getLambdaCapturePackage());
return storageManager;
}
@@ -0,0 +1,15 @@
package com.zfoo.protocol.util;
/**
* 类扫描过滤器
*
* @author veione
*/
@FunctionalInterface
public interface ClassFilter {
/**
* 是否满足条件
*/
boolean accept(Class<?> clazz);
}
@@ -16,8 +16,10 @@ import com.zfoo.protocol.collection.ArrayUtils;
import com.zfoo.protocol.exception.RunException;
import java.io.File;
import java.io.FileFilter;
import java.io.IOException;
import java.io.InputStream;
import java.lang.annotation.Annotation;
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.net.JarURLConnection;
@@ -25,8 +27,10 @@ import java.net.URI;
import java.net.URISyntaxException;
import java.net.URL;
import java.net.URLConnection;
import java.net.URLDecoder;
import java.util.Enumeration;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
@@ -48,19 +52,10 @@ public abstract class ClassUtils {
public final static String JAR_URL_SEPARATOR = "!/";
private static ClassLoader systemClassLoader;
static {
try {
systemClassLoader = ClassLoader.getSystemClassLoader();
} catch (SecurityException ignored) {
// AccessControlException on Google App Engine
}
}
private ClassUtils() {
}
/**
* 默认过滤器(无实现)
*/
private final static ClassFilter DEFAULT_FILTER = clazz -> true;
public static Class<?> forName(String className) {
try {
@@ -512,37 +507,154 @@ public abstract class ClassUtils {
}
/**
* @param name
* @param classLoader
* @return
* @since 3.4.3
* 扫描目录下的所有class文件
*
* @param scanPackage 搜索的包根路径
*/
public static Class<?> toClassConfident(String name, ClassLoader classLoader) {
try {
return loadClass(name, getClassLoaders(classLoader));
} catch (ClassNotFoundException e) {
throw new RuntimeException(e);
}
public static Set<Class<?>> getClasses(String scanPackage) {
return getClasses(scanPackage, DEFAULT_FILTER);
}
private static Class<?> loadClass(String className, ClassLoader[] classLoaders) throws ClassNotFoundException {
for (ClassLoader classLoader : classLoaders) {
if (classLoader != null) {
/**
* 返回所有带制定注解的class列表
*
* @param scanPackage 搜索的包根路径
*/
public static <A extends Annotation> Set<Class<?>> listClassesWithAnnotation(String scanPackage, Class<A> annotation) {
return getClasses(scanPackage, (clazz) -> clazz.getAnnotation(annotation) != null);
}
/**
* 扫描目录下的所有class文件
*
* @param pack 包路径
* @param filter 自定义类过滤器
*/
public static Set<Class<?>> getClasses(String pack, ClassFilter filter) {
Set<Class<?>> result = new LinkedHashSet<>();
// 是否循环迭代
boolean recursive = true;
// 获取包的名字 并进行替换
String packageName = pack;
String packageDirName = packageName.replace('.', '/');
// 定义一个枚举的集合 并进行循环来处理这个目录下的things
Enumeration<URL> dirs;
try {
dirs = Thread.currentThread().getContextClassLoader().getResources(packageDirName);
// 循环迭代下去
while (dirs.hasMoreElements()) {
// 获取下一个元素
URL url = dirs.nextElement();
// 得到协议的名称
String protocol = url.getProtocol();
// 如果是以文件的形式保存在服务器上
if ("file".equals(protocol)) {
// 获取包的物理路径
String filePath = URLDecoder.decode(url.getFile(), "UTF-8");
// 以文件的方式扫描整个包下的文件 并添加到集合中
findAndAddClassesInPackageByFile(packageName, filePath, recursive, result, filter);
} else if ("jar".equals(protocol)) {
// 如果是jar包文件
Set<Class<?>> jarClasses = findClassFromJar(url, packageName, packageDirName, recursive, filter);
result.addAll(jarClasses);
}
}
} catch (IOException e) {
throw new RunException(e);
}
return result;
}
private static Set<Class<?>> findClassFromJar(URL url, String packageName, String packageDirName,
boolean recursive, ClassFilter filter) {
Set<Class<?>> result = new LinkedHashSet<>();
try {
// 获取jar
JarFile jar = ((JarURLConnection) url.openConnection()).getJarFile();
// 从此jar包 得到一个枚举类
Enumeration<JarEntry> entries = jar.entries();
// 同样的进行循环迭代
while (entries.hasMoreElements()) {
// 获取jar里的一个实体 可以是目录 和一些jar包里的其他文件 如META-INF等文件
JarEntry entry = entries.nextElement();
String name = entry.getName();
// 如果是以/开头的
if (name.charAt(0) == '/') {
// 获取后面的字符串
name = name.substring(1);
}
// 如果前半部分和定义的包名相同
if (name.startsWith(packageDirName)) {
int idx = name.lastIndexOf('/');
// 如果以"/"结尾 是一个包
if (idx != -1) {
// 获取包名 把"/"替换成"."
packageName = name.substring(0, idx)
.replace('/', '.');
}
// 如果可以迭代下去 并且是一个包
if ((idx != -1) || recursive) {
// 如果是一个.class文件 而且不是目录
if (name.endsWith(".class")
&& !entry.isDirectory()) {
// 去掉后面的".class" 获取真正的类名
String className = name.substring(packageName.length() + 1,
name.length() - 6);
try {
// 添加到classes
Class<?> c = Class.forName(packageName + '.' + className);
if (filter.accept(c)) {
result.add(c);
}
} catch (ClassNotFoundException e) {
throw new RunException(e);
}
}
}
}
}
} catch (IOException e) {
throw new RunException(e);
}
return result;
}
private static void findAndAddClassesInPackageByFile(String packageName,
String packagePath, final boolean recursive, Set<Class<?>> classes,
ClassFilter filter) {
// 获取此包的目录 建立一个File
File dir = new File(packagePath);
// 如果不存在或者 也不是目录就直接返回
if (!dir.exists() || !dir.isDirectory()) {
return;
}
// 如果存在 就获取包下的所有文件 包括目录
File[] dirs = dir.listFiles(new FileFilter() {
// 自定义过滤规则 如果可以循环(包含子目录) 或则是以.class结尾的文件(编译好的java类文件)
public boolean accept(File file) {
return (recursive && file.isDirectory()) || (file.getName().endsWith(".class"));
}
});
// 循环所有文件
for (File file : dirs) {
// 如果是目录 则继续扫描
if (file.isDirectory()) {
findAndAddClassesInPackageByFile(packageName + "." + file.getName(), file.getAbsolutePath(), recursive, classes, filter);
} else {
// 如果是java类文件 去掉后面的.class 只留下类名
String className = file.getName().substring(0,
file.getName().length() - 6);
try {
return Class.forName(className, true, classLoader);
// 添加到集合中去
Class<?> clazz = Thread.currentThread().getContextClassLoader().loadClass(packageName + '.' + className);
if (filter.accept(clazz)) {
classes.add(clazz);
}
} catch (ClassNotFoundException e) {
// ignore
throw new RunException(e);
}
}
}
throw new ClassNotFoundException("Cannot find class: " + className);
}
private static ClassLoader[] getClassLoaders(ClassLoader classLoader) {
return new ClassLoader[]{
classLoader,
Thread.currentThread().getContextClassLoader(),
ClassUtils.class.getClassLoader(),
systemClassLoader};
}
}
@@ -10,7 +10,7 @@ import java.util.Locale;
*
* @author veione
*/
public class FieldUtils {
public abstract class FieldUtils {
public static String fieldToGetMethod(Class<?> clazz, Field field) {
var fieldName = field.getName();
@@ -0,0 +1,17 @@
package com.zfoo.storage.anno;
import java.lang.annotation.Documented;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
/**
* @author veione
*/
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.TYPE)
public @interface GraalvmNative {
}
@@ -29,6 +29,9 @@ public class StorageConfig {
// 未被使用的Storage是否回收,默认开启节省资源
private boolean recycle;
// 函数式capture扫描包
private String lambdaCapturePackage;
public String getId() {
return id;
}
@@ -68,4 +71,12 @@ public class StorageConfig {
public void setRecycle(boolean recycle) {
this.recycle = recycle;
}
public String getLambdaCapturePackage() {
return lambdaCapturePackage;
}
public void setLambdaCapturePackage(String lambdaCapturePackage) {
this.lambdaCapturePackage = lambdaCapturePackage;
}
}
@@ -0,0 +1,33 @@
package com.zfoo.storage.graalvm.feature;
import com.zfoo.protocol.util.ClassUtils;
import com.zfoo.storage.anno.GraalvmNative;
import org.graalvm.nativeimage.hosted.Feature;
import org.graalvm.nativeimage.hosted.RuntimeSerialization;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* 用于注册graalvm lambda capture的类
*
* @author veione
*/
public class RuntimeRegistrationFeature implements Feature {
private static final Logger log = LoggerFactory.getLogger(RuntimeRegistrationFeature.class);
private static String lambdaCapturePackage;
@Override
public void duringSetup(DuringSetupAccess access) {
log.info("Runtime registration feature on duringSetup");
var filterClasses = ClassUtils.getClasses(lambdaCapturePackage, c -> c.isAnnotationPresent(GraalvmNative.class));
filterClasses.forEach(cls -> {
log.info("Starting register lambda capture class: {}", cls);
RuntimeSerialization.registerLambdaCapturingClass(cls);
});
}
public static void setLambdaCapturePackage(String lambdaCapturePackage) {
RuntimeRegistrationFeature.lambdaCapturePackage = lambdaCapturePackage;
}
}
@@ -37,7 +37,7 @@ public final class LambdaUtils {
Class<? extends Serializable> clazz = func.getClass();
Method method = clazz.getDeclaredMethod("writeReplace");
ReflectionUtils.makeAccessible(method);
return new ReflectLambdaMeta((java.lang.invoke.SerializedLambda) method.invoke(func), clazz.getClassLoader());
return new ReflectLambdaMeta((java.lang.invoke.SerializedLambda) method.invoke(func));
} catch (Throwable e) {
// 3. 反射失败使用序列化的方式读取
return new ShadowLambdaMeta(SerializedLambda.extract(func));
@@ -27,11 +27,6 @@ public class IdeaProxyLambdaMeta implements LambdaMeta {
return name;
}
@Override
public Class<?> getInstantiatedClass() {
return clazz;
}
@Override
public String toString() {
return clazz.getSimpleName() + "::" + name;
@@ -14,11 +14,4 @@ public interface LambdaMeta {
*/
String getImplMethodName();
/**
* 实例化该方法的类
*
* @return 返回对应的类名称
*/
Class<?> getInstantiatedClass();
}
@@ -1,30 +1,17 @@
package com.zfoo.storage.util.support;
import com.zfoo.protocol.util.ClassUtils;
import com.zfoo.protocol.util.StringUtils;
/**
* Created by hcl at 2021/5/14
*/
public class ReflectLambdaMeta implements LambdaMeta {
private final java.lang.invoke.SerializedLambda lambda;
private final ClassLoader classLoader;
public ReflectLambdaMeta(java.lang.invoke.SerializedLambda lambda, ClassLoader classLoader) {
public ReflectLambdaMeta(java.lang.invoke.SerializedLambda lambda) {
this.lambda = lambda;
this.classLoader = classLoader;
}
@Override
public String getImplMethodName() {
return lambda.getImplMethodName();
}
@Override
public Class<?> getInstantiatedClass() {
String instantiatedMethodType = lambda.getInstantiatedMethodType();
String instantiatedType = instantiatedMethodType.substring(2, instantiatedMethodType.indexOf(StringUtils.SEMICOLON)).replace(StringUtils.SLASH, StringUtils.PERIOD);
return ClassUtils.toClassConfident(instantiatedType, this.classLoader);
}
}
@@ -1,8 +1,5 @@
package com.zfoo.storage.util.support;
import com.zfoo.protocol.util.StringUtils;
import com.zfoo.protocol.util.ClassUtils;
/**
* 基于 {@link SerializedLambda} 创建的元信息
* <p>
@@ -20,11 +17,4 @@ public class ShadowLambdaMeta implements LambdaMeta {
return lambda.getImplMethodName();
}
@Override
public Class<?> getInstantiatedClass() {
String instantiatedMethodType = lambda.getInstantiatedMethodType();
String instantiatedType = instantiatedMethodType.substring(2, instantiatedMethodType.indexOf(StringUtils.SEMICOLON)).replace(StringUtils.SLASH, StringUtils.PERIOD);
return ClassUtils.toClassConfident(instantiatedType, lambda.getCapturingClass().getClassLoader());
}
}
@@ -91,6 +91,7 @@ public class ExportBinaryTesting {
config.setResourceLocation("classpath:/excel");
config.setWriteable(true);
config.setRecycle(false);
config.setLambdaCapturePackage("com.zfoo.storage");
var storageManager = new StorageManager();
storageManager.setStorageConfig(config);
storageManager.initBefore();