diff --git a/protocol/src/main/java/com/zfoo/protocol/ProtocolManager.java b/protocol/src/main/java/com/zfoo/protocol/ProtocolManager.java index 8d368baa..36690f3c 100644 --- a/protocol/src/main/java/com/zfoo/protocol/ProtocolManager.java +++ b/protocol/src/main/java/com/zfoo/protocol/ProtocolManager.java @@ -13,34 +13,15 @@ package com.zfoo.protocol; import com.zfoo.protocol.buffer.ByteBufUtils; -import com.zfoo.protocol.collection.CollectionUtils; import com.zfoo.protocol.exception.RunException; -import com.zfoo.protocol.exception.UnknownException; import com.zfoo.protocol.generate.GenerateOperation; -import com.zfoo.protocol.generate.GenerateProtocolDocument; -import com.zfoo.protocol.generate.GenerateProtocolFile; -import com.zfoo.protocol.generate.GenerateProtocolPath; -import com.zfoo.protocol.registration.EnhanceUtils; -import com.zfoo.protocol.registration.IProtocolRegistration; -import com.zfoo.protocol.registration.ProtocolModule; -import com.zfoo.protocol.registration.ProtocolRegistration; -import com.zfoo.protocol.registration.field.*; -import com.zfoo.protocol.serializer.*; -import com.zfoo.protocol.serializer.cs.GenerateCsUtils; -import com.zfoo.protocol.serializer.js.GenerateJsUtils; -import com.zfoo.protocol.serializer.lua.GenerateLuaUtils; +import com.zfoo.protocol.registration.*; import com.zfoo.protocol.util.AssertionUtils; -import com.zfoo.protocol.util.ReflectionUtils; -import com.zfoo.protocol.util.StringUtils; import com.zfoo.protocol.xml.XmlProtocols; import io.netty.buffer.ByteBuf; -import javassist.CannotCompileException; -import javassist.NotFoundException; -import java.io.IOException; import java.lang.reflect.*; import java.util.*; -import java.util.stream.Collectors; /** * @author jaysunxiao @@ -48,7 +29,6 @@ import java.util.stream.Collectors; */ public class ProtocolManager { - /** * 包体的头部的长度,一个int字节长度 */ @@ -56,48 +36,15 @@ public class ProtocolManager { public static final String PROTOCOL_ID = "PROTOCOL_ID"; public static final short MAX_PROTOCOL_NUM = Short.MAX_VALUE; public static final byte MAX_MODULE_NUM = Byte.MAX_VALUE; - private static final Comparator PACKET_FIELD_COMPARATOR = (a, b) -> a.getName().compareTo(b.getName()); + public static final Comparator PACKET_FIELD_COMPARATOR = (a, b) -> a.getName().compareTo(b.getName()); - private static final IProtocolRegistration[] protocols = new IProtocolRegistration[MAX_PROTOCOL_NUM]; - private static final ProtocolModule[] modules = new ProtocolModule[MAX_MODULE_NUM]; + public static final IProtocolRegistration[] protocols = new IProtocolRegistration[MAX_PROTOCOL_NUM]; + public static final ProtocolModule[] modules = new ProtocolModule[MAX_MODULE_NUM]; - // 临时变量,启动完成就会销毁,协议名称保留字符,即协议的名称不能用以下名称命名 - private static Set tempProtocolReserved = Set.of("Buffer", "ByteBuf", "ByteBuffer", "LittleEndianByteBuffer", "NormalByteBuffer" - , "IPacket", "IProtocolRegistration", "ProtocolManager", "IFieldRegistration" - , "ByteBufUtils", "ArrayUtils", "CollectionUtils" - , "Boolean", "Byte", "Short", "Integer", "Long", "Float", "Double", "String", "Character", "Object"); - - // 临时变量,启动完成就会销毁,是一个基本类型序列化器 - private static Map, ISerializer> tempBaseSerializerMap = new HashMap<>(); - - // 临时变量,启动完成就会销毁,协议Id对应的Class类 - private static Map> tempProtocolClassMap = new HashMap<>(MAX_PROTOCOL_NUM); - - // 临时变量,启动完成就会销毁,协议下包含的子协议,只包含一层子协议 - private static Map> tempSubProtocolIdMap = new HashMap<>(MAX_PROTOCOL_NUM); static { // 初始化默认协议模块 modules[0] = ProtocolModule.DEFAULT_PROTOCOL_MODULE; - - // 初始化基础类型序列化器 - tempBaseSerializerMap.put(boolean.class, BooleanSerializer.getInstance()); - tempBaseSerializerMap.put(Boolean.class, BooleanSerializer.getInstance()); - tempBaseSerializerMap.put(byte.class, ByteSerializer.getInstance()); - tempBaseSerializerMap.put(Byte.class, ByteSerializer.getInstance()); - tempBaseSerializerMap.put(short.class, ShortSerializer.getInstance()); - tempBaseSerializerMap.put(Short.class, ShortSerializer.getInstance()); - tempBaseSerializerMap.put(int.class, IntSerializer.getInstance()); - tempBaseSerializerMap.put(Integer.class, IntSerializer.getInstance()); - tempBaseSerializerMap.put(long.class, LongSerializer.getInstance()); - tempBaseSerializerMap.put(Long.class, LongSerializer.getInstance()); - tempBaseSerializerMap.put(float.class, FloatSerializer.getInstance()); - tempBaseSerializerMap.put(Float.class, FloatSerializer.getInstance()); - tempBaseSerializerMap.put(double.class, DoubleSerializer.getInstance()); - tempBaseSerializerMap.put(Double.class, DoubleSerializer.getInstance()); - tempBaseSerializerMap.put(char.class, CharSerializer.getInstance()); - tempBaseSerializerMap.put(Character.class, CharSerializer.getInstance()); - tempBaseSerializerMap.put(String.class, StringSerializer.getInstance()); } public static void write(ByteBuf buffer, IPacket packet) { @@ -143,459 +90,16 @@ public class ProtocolManager { public static synchronized void initProtocol(Set> protocolClassSet) { - initProtocol(protocolClassSet, GenerateOperation.NO_OPERATION); + ProtocolAnalysis.analyze(protocolClassSet); } public static synchronized void initProtocol(Set> protocolClassSet, GenerateOperation generateOperation) { - AssertionUtils.notNull(tempSubProtocolIdMap, "[{}]已经初始完成,只能parseProtocol一次,请不要重复初始化", ProtocolManager.class.getSimpleName()); - try { - for (var protocolClass : protocolClassSet) { - var id = getProtocolIdByClass(protocolClass); - var previous = tempProtocolClassMap.put(id, protocolClass); - if (previous != null) { - throw new RunException("[{}][{}]协议号[protocolId:{}]重复", protocolClass.getCanonicalName(), previous.getCanonicalName(), id); - } - } - - for (var protocolClass : protocolClassSet) { - try { - var registration = parseProtocolRegistration(protocolClass, ProtocolModule.DEFAULT_PROTOCOL_MODULE); - // 注册协议 - protocols[registration.protocolId()] = registration; - } catch (Exception e) { - throw new RuntimeException(StringUtils.format("解析协议[class:{}]异常", protocolClass), e); - } - } - - enhanceProtocolBefore(generateOperation); - - // 通过指定类注册的协议,全部使用字节码增强 - enhanceProtocolRegistration(Arrays.stream(protocols).filter(it -> Objects.nonNull(it)).collect(Collectors.toList())); - - enhanceProtocolAfter(); - } catch (Exception e) { - throw new RuntimeException(e); - } + ProtocolAnalysis.analyze(protocolClassSet, generateOperation); } public static synchronized void initProtocol(XmlProtocols xmlProtocols, GenerateOperation generateOperation) { - AssertionUtils.notNull(tempSubProtocolIdMap, "[{}]已经初始完成,只能parseProtocol一次,请不要重复初始化", ProtocolManager.class.getSimpleName()); - try { - var enhanceList = new ArrayList(); - - for (var moduleDefinition : xmlProtocols.getModules()) { - var module = new ProtocolModule(moduleDefinition.getId(), moduleDefinition.getName(), moduleDefinition.getVersion()); - - AssertionUtils.isTrue(module.getId() > 0, "[module:{}] [id:{}] 模块必须大于等于1", module.getName(), module.getId()); - AssertionUtils.isNull(modules[module.getId()], "duplicate [module:{}] [id:{}] Exception!", module.getName(), module.getId()); - AssertionUtils.notNull(moduleDefinition.getProtocols(), "[module:{}] does not have any protocols", module.getName()); - - modules[module.getId()] = module; - - for (var protocolDefinition : moduleDefinition.getProtocols()) { - var id = protocolDefinition.getId(); - var location = protocolDefinition.getLocation(); - var clazz = Class.forName(location); - - AssertionUtils.isTrue(id >= moduleDefinition.getMinId(), "模块[{}]中的协议[{}]的协议号必须大于或者等于[{}]", moduleDefinition.getName(), clazz.getSimpleName(), moduleDefinition.getMinId()); - AssertionUtils.isTrue(id < moduleDefinition.getMaxId(), "模块[{}]中的协议[{}]的协议号必须小于[{}]", moduleDefinition.getName(), clazz.getSimpleName(), moduleDefinition.getMaxId()); - AssertionUtils.isNull(protocols[id], "duplicate definition [id:{}] Exception!", id); - - var packet = (IPacket) ReflectionUtils.newInstance(clazz); - - // 协议号是否和id是否相等 - AssertionUtils.isTrue(packet.protocolId() == id, "[class:{}]协议序列号[{}]和协议文件里的协议序列号不相等", clazz.getCanonicalName(), PROTOCOL_ID); - - var previous = tempProtocolClassMap.put(id, clazz); - if (previous != null) { - throw new RunException("[{}][{}]协议号[protocolId:{}]重复", clazz.getCanonicalName(), previous.getCanonicalName(), id); - } - } - } - - for (var moduleDefinition : xmlProtocols.getModules()) { - var module = modules[moduleDefinition.getId()]; - for (var protocolDefinition : moduleDefinition.getProtocols()) { - var id = protocolDefinition.getId(); - var clazz = tempProtocolClassMap.get(id); - try { - var registration = parseProtocolRegistration(clazz, module); - if (protocolDefinition.isEnhance()) { - enhanceList.add(registration); - } - // 注册协议 - protocols[id] = registration; - } catch (Exception e) { - throw new UnknownException(e, "解析协议[id:{}][class:{}]异常", id, clazz); - } - } - } - - enhanceProtocolBefore(generateOperation); - - enhanceProtocolRegistration(enhanceList); - - enhanceProtocolAfter(); - } catch (Exception e) { - throw new UnknownException(e); - } + ProtocolAnalysis.analyze(xmlProtocols, generateOperation); } - private static void enhanceProtocolRegistration(List enhanceList) throws NoSuchMethodException, IllegalAccessException, InstantiationException, CannotCompileException, NotFoundException, InvocationTargetException, NoSuchFieldException { - // 字节码增强 - for (var registration : enhanceList) { - protocols[registration.protocolId()] = EnhanceUtils.createProtocolRegistration((ProtocolRegistration) registration); - } - - // 字节码增强过后,初始化各个子协议成员变量 - for (var registration : enhanceList) { - var enhanceProtocolRegistration = protocols[registration.protocolId()]; - var subProtocolIds = getAllSubProtocolIds(registration.protocolId()); - for (var subProtocolId : subProtocolIds) { - var protocolRegistrationField = enhanceProtocolRegistration.getClass().getDeclaredField(EnhanceUtils.getProtocolRegistrationFieldNameByProtocolId(subProtocolId)); - ReflectionUtils.makeAccessible(protocolRegistrationField); - ReflectionUtils.setField(protocolRegistrationField, enhanceProtocolRegistration, protocols[subProtocolId]); - } - } - } - - private static void enhanceProtocolBefore(GenerateOperation generateOperation) throws IOException { - // 检查协议格式 - checkAllProtocolClass(); - - // 检查模块格式 - checkAllModules(); - - // 生成协议 - GenerateProtocolFile.generate(protocols, generateOperation); - } - - private static void enhanceProtocolAfter() { - tempSubProtocolIdMap.clear(); - tempSubProtocolIdMap = null; - - tempProtocolReserved = null; - - tempBaseSerializerMap.clear(); - tempBaseSerializerMap = null; - - GenerateProtocolDocument.clear(); - GenerateProtocolPath.clear(); - GenerateProtocolFile.clear(); - GenerateCsUtils.clear(); - GenerateJsUtils.clear(); - GenerateLuaUtils.clear(); - - EnhanceUtils.clear(); - } - - - private static short checkProtocol(Class clazz) throws IllegalAccessException, InvocationTargetException, InstantiationException { - // 是否为一个简单的javabean - AssertionUtils.isTrue(clazz.getSuperclass().equals(Object.class), "[class:{}]不是简单的javabean,不能继承别的类", clazz.getCanonicalName()); - // 是否实现了IPacket接口 - AssertionUtils.isTrue(IPacket.class.isAssignableFrom(clazz), "[class:{}]没有实现接口[IPacket:{}]", clazz.getCanonicalName(), IPacket.class.getCanonicalName()); - // 不能是泛型类 - AssertionUtils.isTrue(CollectionUtils.isEmpty(clazz.getTypeParameters()), "[class:{}]不能是泛型类", clazz.getCanonicalName()); - - Field protocolIdField; - try { - protocolIdField = clazz.getDeclaredField(PROTOCOL_ID); - } catch (NoSuchFieldException e) { - throw new UnknownException(e, "[class:{}]没有[{}]协议序列号", clazz.getCanonicalName(), PROTOCOL_ID); - } - - // 是否被public修饰 - AssertionUtils.isTrue(Modifier.isPublic(protocolIdField.getModifiers()), "[class:{}]协议序列号[{}]没有被public修饰", clazz.getCanonicalName(), PROTOCOL_ID); - // 是否被static修饰 - AssertionUtils.isTrue(Modifier.isStatic(protocolIdField.getModifiers()), "[class:{}]协议序列号[{}]没有被static修饰", clazz.getCanonicalName(), PROTOCOL_ID); - // 是否被final修饰 - AssertionUtils.isTrue(Modifier.isFinal(protocolIdField.getModifiers()), "[class:{}]协议序列号[{}]没有被final修饰", clazz.getCanonicalName(), PROTOCOL_ID); - // 是否被transient修饰 - AssertionUtils.isTrue(Modifier.isTransient(protocolIdField.getModifiers()), "[class:{}]协议序列号[{}]没有被transient修饰", clazz.getCanonicalName(), PROTOCOL_ID); - // 命名只能包含字母,数字,下划线 - AssertionUtils.isTrue(clazz.getSimpleName().matches("[a-zA-Z0-9_]*"), "[class:{}]的命名只能包含字母,数字,下划线", clazz.getCanonicalName(), PROTOCOL_ID); - - // 必须要有一个空的构造器 - Constructor constructor; - try { - constructor = clazz.getDeclaredConstructor(); - } catch (NoSuchMethodException e) { - throw new UnknownException(e, "[class:{}]协议序列号[{}]必须有一个空的构造器", clazz.getCanonicalName(), PROTOCOL_ID); - } - ReflectionUtils.makeAccessible(protocolIdField); - IPacket packet = (IPacket) constructor.newInstance(); - - // 验证protocol()方法的返回是否和PROTOCOL_ID相等 - AssertionUtils.isTrue(Short.valueOf(packet.protocolId()).equals(protocolIdField.get(null)), "[class:{}]的protocolId返回的值和协议号的静态变量[{}]不相等", clazz.getCanonicalName(), PROTOCOL_ID); - return packet.protocolId(); - } - - - public static short getProtocolIdByClass(Class clazz) { - var protocolIdField = ReflectionUtils.getFieldByNameInPOJOClass(clazz, PROTOCOL_ID); - ReflectionUtils.makeAccessible(protocolIdField); - return (short) ReflectionUtils.getField(protocolIdField, null); - } - - private static void checkAllModules() { - // 模块id不能重复 - var moduleIdSet = new HashSet(); - Arrays.stream(modules) - .filter(it -> Objects.nonNull(it)) - .peek(it -> AssertionUtils.isTrue(!moduleIdSet.contains(it.getId()), "模块[{}]存在重复的id,模块的id不能重复", it)) - .forEach(it -> moduleIdSet.add(it.getId())); - - // 模块名称不能重复 - var moduleNameSet = new HashSet(); - Arrays.stream(modules) - .filter(it -> Objects.nonNull(it)) - .peek(it -> AssertionUtils.isTrue(!moduleNameSet.contains(it.getName()), "模块[{}]存在重复的name,模块名称不能重复", it)) - .forEach(it -> moduleNameSet.add(it.getName())); - } - - private static void checkAllProtocolClass() { - // 检查协议格式 - - // 协议的名称不能重复 - var allProtocolNameMap = new HashMap>(); - for (var protocolRegistration : protocols) { - if (protocolRegistration == null) { - continue; - } - - var protocolClass = protocolRegistration.protocolConstructor().getDeclaringClass(); - var protocolName = protocolClass.getSimpleName(); - if (allProtocolNameMap.containsKey(protocolName)) { - throw new RunException("[class:{}]和[class:{}]协议名称重复,协议不能含有重复的名称", protocolClass.getCanonicalName(), allProtocolNameMap.get(protocolName).getCanonicalName()); - } - - if (tempProtocolReserved.stream().anyMatch(it -> it.equalsIgnoreCase(protocolName))) { - throw new RunException("协议的名称[class:{}]不能是保留名称[{}]", protocolClass.getCanonicalName(), protocolName); - } - - allProtocolNameMap.put(protocolName, protocolClass); - } - - - // 检查循环协议 - for (var protocolEntry : tempSubProtocolIdMap.entrySet()) { - var protocolId = protocolEntry.getKey(); - var subProtocolSet = protocolEntry.getValue(); - if (subProtocolSet.contains(protocolId)) { - var protocolClass = protocols[protocolId].protocolConstructor().getDeclaringClass(); - throw new RunException("[class:{}]在第一层包含循环引用协议[class:{}]", protocolClass.getSimpleName(), protocolClass.getSimpleName()); - } - - getAllSubProtocolIds(protocolId); - } - } - - private static void checkSubProtocol(Class clazz, short id, Class subClass) { - var registerProtocolClass = tempProtocolClassMap.get(id); - if (registerProtocolClass == null || !registerProtocolClass.equals(subClass)) { - throw new RunException("协议[{}]的子协议[{}][{}]没有注册", clazz.getCanonicalName(), id, subClass.getCanonicalName()); - } - } - - private static ProtocolRegistration parseProtocolRegistration(Class clazz, ProtocolModule module) throws IllegalAccessException, NoSuchMethodException, InvocationTargetException, InstantiationException { - var protocolId = checkProtocol(clazz); - - var fields = new ArrayList(); - for (var field : clazz.getDeclaredFields()) { - var modifiers = field.getModifiers(); - if (Modifier.isTransient(modifiers) || Modifier.isStatic(modifiers)) { - continue; - } - - if (Modifier.isFinal(modifiers)) { - throw new RunException("[{}]协议号[protocolId:{}]中的[filed:{}]属性的访问修饰符不能为final" - , clazz.getCanonicalName(), protocolId, field.getName()); - } - - if (!Modifier.isPublic(modifiers) && !Modifier.isPrivate(modifiers)) { - throw new RunException("[{}]协议号[protocolId:{}]中的[filed:{}]属性的访问修饰符必须是public或者private" - , clazz.getCanonicalName(), protocolId, field.getName()); - } - - ReflectionUtils.makeAccessible(field); - fields.add(field); - } - - // 按变量名称从小到大排序 - fields.sort(PACKET_FIELD_COMPARATOR); - - var registrationList = new ArrayList(); - for (var field : fields) { - registrationList.add(toRegistration(clazz, field)); - } - - var constructor = clazz.getDeclaredConstructor(); - ReflectionUtils.makeAccessible(constructor); - var protocol = new ProtocolRegistration(); - protocol.setId(protocolId); - protocol.setConstructor(constructor); - protocol.setFields(fields.toArray(new Field[fields.size()])); - protocol.setFieldRegistrations(registrationList.toArray(new IFieldRegistration[registrationList.size()])); - protocol.setModule(module.getId()); - return protocol; - } - - private static IFieldRegistration toRegistration(Class clazz, Field field) { - Class fieldTypeClazz = field.getType(); - - ISerializer serializer = tempBaseSerializerMap.get(fieldTypeClazz); - - // 是一个基本类型变量 - if (serializer != null) { - return BaseField.valueOf(serializer); - } else if (fieldTypeClazz.getComponentType() != null) { - // 是一个数组 - Class arrayClazz = fieldTypeClazz.getComponentType(); - - IFieldRegistration registration = typeToRegistration(clazz, arrayClazz); - return ArrayField.valueOf(field, registration); - } else if (Set.class.isAssignableFrom(fieldTypeClazz)) { - if (!fieldTypeClazz.equals(Set.class)) { - throw new RunException("[class:{}]类型声明不正确,必须是Set接口类型", clazz.getCanonicalName()); - } - - Type type = field.getGenericType(); - - if (!(type instanceof ParameterizedType)) { - throw new RunException("[class:{}]类型声明不正确,不是泛型类[field:{}]", clazz.getCanonicalName(), field.getName()); - } - - Type[] types = ((ParameterizedType) type).getActualTypeArguments(); - - if (types.length != 1) { - throw new RunException("[class:{}]中Set类型声明不正确,[field:{}]必须声明泛型类", clazz.getCanonicalName(), field.getName()); - } - - IFieldRegistration registration = typeToRegistration(clazz, types[0]); - return SetField.valueOf(registration, type); - } else if (List.class.isAssignableFrom(fieldTypeClazz)) { - // 是一个List - if (!fieldTypeClazz.equals(List.class)) { - throw new RunException("[class:{}]类型声明不正确,必须是List接口类型", clazz.getCanonicalName()); - } - - Type type = field.getGenericType(); - - if (!(type instanceof ParameterizedType)) { - throw new RunException("[class:{}]类型声明不正确,不是泛型类[field:{}]", clazz.getCanonicalName(), field.getName()); - } - - Type[] types = ((ParameterizedType) type).getActualTypeArguments(); - - if (types.length != 1) { - throw new RunException("[class:{}]中List类型声明不正确,[field:{}]必须声明泛型类", clazz.getCanonicalName(), field.getName()); - } - - IFieldRegistration registration = typeToRegistration(clazz, types[0]); - return ListField.valueOf(registration, type); - - } else if (Map.class.isAssignableFrom(fieldTypeClazz)) { - if (!fieldTypeClazz.equals(Map.class)) { - throw new RunException("[class:{}]类型声明不正确,必须是Map接口类型", clazz.getCanonicalName()); - } - - Type type = field.getGenericType(); - - if (!(type instanceof ParameterizedType)) { - throw new RunException("[class:{}]中数组类型声明不正确,[field:{}]不是泛型类", clazz.getCanonicalName(), field.getName()); - } - - Type[] types = ((ParameterizedType) type).getActualTypeArguments(); - - if (types.length != 2) { - throw new RunException("[class:{}]中数组类型声明不正确,[field:{}]必须声明泛型类", clazz.getCanonicalName(), field.getName()); - } - - IFieldRegistration keyRegistration = typeToRegistration(clazz, types[0]); - IFieldRegistration valueRegistration = typeToRegistration(clazz, types[1]); - - return MapField.valueOf(keyRegistration, valueRegistration, type); - } else { - // 是一个协议引用变量 - var referenceProtocolId = getProtocolIdByClass(field.getType()); - checkSubProtocol(clazz, referenceProtocolId, field.getType()); - tempSubProtocolIdMap.computeIfAbsent(getProtocolIdByClass(clazz), it -> new HashSet<>()).add(referenceProtocolId); - return ObjectProtocolField.valueOf(referenceProtocolId); - } - } - - private static IFieldRegistration typeToRegistration(Class currentProtocolClass, Type type) { - if (type instanceof ParameterizedType) { - // 泛型类 - Class clazz = (Class) ((ParameterizedType) type).getRawType(); - if (Set.class.equals(clazz)) { - // Set> - IFieldRegistration registration = typeToRegistration(currentProtocolClass, ((ParameterizedType) type).getActualTypeArguments()[0]); - return SetField.valueOf(registration, type); - } else if (List.class.equals(clazz)) { - // List> - IFieldRegistration registration = typeToRegistration(currentProtocolClass, ((ParameterizedType) type).getActualTypeArguments()[0]); - return ListField.valueOf(registration, (ParameterizedType) type); - } else if (Map.class.equals(clazz)) { - // Map, List> - IFieldRegistration keyRegistration = typeToRegistration(currentProtocolClass, ((ParameterizedType) type).getActualTypeArguments()[0]); - IFieldRegistration valueRegistration = typeToRegistration(currentProtocolClass, ((ParameterizedType) type).getActualTypeArguments()[1]); - return MapField.valueOf(keyRegistration, valueRegistration, type); - } - } else if (type instanceof Class) { - Class clazz = ((Class) type); - ISerializer serializer = tempBaseSerializerMap.get(clazz); - if (serializer != null) { - // 基础类型 - return BaseField.valueOf(serializer); - } else if (clazz.getComponentType() != null) { - // 是一个二维以上数组 - throw new RunException("不支持多维数组或集合嵌套数组[type:{}]类型,仅支持一维数组", type); - } else if (clazz.equals(List.class) || clazz.equals(Set.class) || clazz.equals(Map.class)) { - throw new RunException("不支持数组和集合联合使用[type:{}]类型", type); - } else { - // 是一个协议引用变量 - var referenceProtocolId = getProtocolIdByClass(clazz); - checkSubProtocol(clazz, referenceProtocolId, clazz); - tempSubProtocolIdMap.computeIfAbsent(getProtocolIdByClass(currentProtocolClass), it -> new HashSet<>()).add(referenceProtocolId); - return ObjectProtocolField.valueOf(referenceProtocolId); - } - } - throw new RunException("[type:{}]类型不正确", type); - } - - - /** - * 此方法仅在生成协议的时候调用,一旦运行,不能调用 - */ - public static Set getAllSubProtocolIds(short protocolId) { - AssertionUtils.notNull(tempSubProtocolIdMap, "[{}]已经初始完成,初始化完成过后不能调用getAllSubProtocolIds", ProtocolManager.class.getSimpleName()); - - if (!tempSubProtocolIdMap.containsKey(protocolId)) { - return Collections.emptySet(); - } - - var protocolClass = protocols[protocolId].protocolConstructor().getDeclaringClass(); - - var queue = new LinkedList<>(tempSubProtocolIdMap.get(protocolId)); - var allSubProtocolIdSet = new HashSet<>(queue); - while (!queue.isEmpty()) { - var firstSubProtocolId = queue.poll(); - if (tempSubProtocolIdMap.containsKey(firstSubProtocolId)) { - for (var subClassId : tempSubProtocolIdMap.get(firstSubProtocolId)) { - if (subClassId == protocolId) { - throw new RunException("[class:{}]在下层协议[class:{}]包含循环引用协议[class:{}]", protocolClass.getSimpleName(), protocols[firstSubProtocolId].protocolConstructor().getDeclaringClass(), protocolClass.getSimpleName()); - } - - if (!allSubProtocolIdSet.contains(subClassId)) { - allSubProtocolIdSet.add(subClassId); - queue.offer(subClassId); - } - } - } - } - return allSubProtocolIdSet; - } } diff --git a/protocol/src/main/java/com/zfoo/protocol/generate/GenerateProtocolFile.java b/protocol/src/main/java/com/zfoo/protocol/generate/GenerateProtocolFile.java index 7ab5a1e7..c43aff37 100644 --- a/protocol/src/main/java/com/zfoo/protocol/generate/GenerateProtocolFile.java +++ b/protocol/src/main/java/com/zfoo/protocol/generate/GenerateProtocolFile.java @@ -14,6 +14,7 @@ package com.zfoo.protocol.generate; import com.zfoo.protocol.ProtocolManager; import com.zfoo.protocol.registration.IProtocolRegistration; +import com.zfoo.protocol.registration.ProtocolAnalysis; import com.zfoo.protocol.registration.ProtocolRegistration; import com.zfoo.protocol.serializer.cs.GenerateCsUtils; import com.zfoo.protocol.serializer.js.GenerateJsUtils; @@ -75,7 +76,7 @@ public abstract class GenerateProtocolFile { // 需要生成的子协议,因为外层协议的内部有其它协议 var insideGenerateProtocols = outsideGenerateProtocols.stream() - .map(it -> ProtocolManager.getAllSubProtocolIds(it.protocolId())) + .map(it -> ProtocolAnalysis.getAllSubProtocolIds(it.protocolId())) .flatMap(it -> it.stream()) .map(it -> protocols[it]) .distinct() diff --git a/protocol/src/main/java/com/zfoo/protocol/registration/EnhanceUtils.java b/protocol/src/main/java/com/zfoo/protocol/registration/EnhanceUtils.java index 6cf52757..1747c9cb 100644 --- a/protocol/src/main/java/com/zfoo/protocol/registration/EnhanceUtils.java +++ b/protocol/src/main/java/com/zfoo/protocol/registration/EnhanceUtils.java @@ -136,7 +136,7 @@ public abstract class EnhanceUtils { enhanceClazz.addField(constructorFiled); // 定义类所包含的所有子协议成员 - var allSubProtocolIds = ProtocolManager.getAllSubProtocolIds(protocolId) + var allSubProtocolIds = ProtocolAnalysis.getAllSubProtocolIds(protocolId) .stream() .sorted((a, b) -> Short.compare(a, b)) .collect(Collectors.toList()); diff --git a/protocol/src/main/java/com/zfoo/protocol/registration/ProtocolAnalysis.java b/protocol/src/main/java/com/zfoo/protocol/registration/ProtocolAnalysis.java new file mode 100644 index 00000000..31e662cf --- /dev/null +++ b/protocol/src/main/java/com/zfoo/protocol/registration/ProtocolAnalysis.java @@ -0,0 +1,541 @@ +/* + * Copyright (C) 2020 The zfoo Authors + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and limitations under the License. + */ + +package com.zfoo.protocol.registration; + +import com.zfoo.protocol.IPacket; +import com.zfoo.protocol.collection.CollectionUtils; +import com.zfoo.protocol.exception.RunException; +import com.zfoo.protocol.exception.UnknownException; +import com.zfoo.protocol.generate.GenerateOperation; +import com.zfoo.protocol.generate.GenerateProtocolDocument; +import com.zfoo.protocol.generate.GenerateProtocolFile; +import com.zfoo.protocol.generate.GenerateProtocolPath; +import com.zfoo.protocol.registration.field.*; +import com.zfoo.protocol.serializer.*; +import com.zfoo.protocol.serializer.cs.GenerateCsUtils; +import com.zfoo.protocol.serializer.js.GenerateJsUtils; +import com.zfoo.protocol.serializer.lua.GenerateLuaUtils; +import com.zfoo.protocol.util.AssertionUtils; +import com.zfoo.protocol.util.ReflectionUtils; +import com.zfoo.protocol.util.StringUtils; +import com.zfoo.protocol.xml.XmlProtocols; +import javassist.CannotCompileException; +import javassist.NotFoundException; + +import java.io.IOException; +import java.lang.reflect.*; +import java.util.*; +import java.util.stream.Collectors; + +import static com.zfoo.protocol.ProtocolManager.*; + +/** + * @author jaysunxiao + * @version 3.0 + */ +public class ProtocolAnalysis { + + + // 临时变量,启动完成就会销毁,协议名称保留字符,即协议的名称不能用以下名称命名 + private static Set tempProtocolReserved = Set.of("Buffer", "ByteBuf", "ByteBuffer", "LittleEndianByteBuffer", "NormalByteBuffer" + , "IPacket", "IProtocolRegistration", "ProtocolManager", "IFieldRegistration" + , "ByteBufUtils", "ArrayUtils", "CollectionUtils" + , "Boolean", "Byte", "Short", "Integer", "Long", "Float", "Double", "String", "Character", "Object"); + + // 临时变量,启动完成就会销毁,是一个基本类型序列化器 + private static Map, ISerializer> tempBaseSerializerMap = new HashMap<>(); + + // 临时变量,启动完成就会销毁,协议Id对应的Class类 + private static Map> tempProtocolClassMap = new HashMap<>(MAX_PROTOCOL_NUM); + + // 临时变量,启动完成就会销毁,协议下包含的子协议,只包含一层子协议 + private static Map> tempSubProtocolIdMap = new HashMap<>(MAX_PROTOCOL_NUM); + + static { + // 初始化基础类型序列化器 + tempBaseSerializerMap.put(boolean.class, BooleanSerializer.getInstance()); + tempBaseSerializerMap.put(Boolean.class, BooleanSerializer.getInstance()); + tempBaseSerializerMap.put(byte.class, ByteSerializer.getInstance()); + tempBaseSerializerMap.put(Byte.class, ByteSerializer.getInstance()); + tempBaseSerializerMap.put(short.class, ShortSerializer.getInstance()); + tempBaseSerializerMap.put(Short.class, ShortSerializer.getInstance()); + tempBaseSerializerMap.put(int.class, IntSerializer.getInstance()); + tempBaseSerializerMap.put(Integer.class, IntSerializer.getInstance()); + tempBaseSerializerMap.put(long.class, LongSerializer.getInstance()); + tempBaseSerializerMap.put(Long.class, LongSerializer.getInstance()); + tempBaseSerializerMap.put(float.class, FloatSerializer.getInstance()); + tempBaseSerializerMap.put(Float.class, FloatSerializer.getInstance()); + tempBaseSerializerMap.put(double.class, DoubleSerializer.getInstance()); + tempBaseSerializerMap.put(Double.class, DoubleSerializer.getInstance()); + tempBaseSerializerMap.put(char.class, CharSerializer.getInstance()); + tempBaseSerializerMap.put(Character.class, CharSerializer.getInstance()); + tempBaseSerializerMap.put(String.class, StringSerializer.getInstance()); + } + + public static synchronized void analyze(Set> protocolClassSet) { + analyze(protocolClassSet, GenerateOperation.NO_OPERATION); + } + + public static synchronized void analyze(Set> protocolClassSet, GenerateOperation generateOperation) { + AssertionUtils.notNull(tempSubProtocolIdMap, "[{}]已经初始完成,只能parseProtocol一次,请不要重复初始化", ProtocolAnalysis.class.getSimpleName()); + try { + for (var protocolClass : protocolClassSet) { + var id = getProtocolIdByClass(protocolClass); + var previous = tempProtocolClassMap.put(id, protocolClass); + if (previous != null) { + throw new RunException("[{}][{}]协议号[protocolId:{}]重复", protocolClass.getCanonicalName(), previous.getCanonicalName(), id); + } + } + + for (var protocolClass : protocolClassSet) { + try { + var registration = parseProtocolRegistration(protocolClass, ProtocolModule.DEFAULT_PROTOCOL_MODULE); + // 注册协议 + protocols[registration.protocolId()] = registration; + } catch (Exception e) { + throw new RuntimeException(StringUtils.format("解析协议[class:{}]异常", protocolClass), e); + } + } + + enhanceProtocolBefore(generateOperation); + + // 通过指定类注册的协议,全部使用字节码增强 + enhanceProtocolRegistration(Arrays.stream(protocols).filter(it -> Objects.nonNull(it)).collect(Collectors.toList())); + + enhanceProtocolAfter(); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + + public static synchronized void analyze(XmlProtocols xmlProtocols, GenerateOperation generateOperation) { + AssertionUtils.notNull(tempSubProtocolIdMap, "[{}]已经初始完成,只能parseProtocol一次,请不要重复初始化", ProtocolAnalysis.class.getSimpleName()); + try { + var enhanceList = new ArrayList(); + + for (var moduleDefinition : xmlProtocols.getModules()) { + var module = new ProtocolModule(moduleDefinition.getId(), moduleDefinition.getName(), moduleDefinition.getVersion()); + + AssertionUtils.isTrue(module.getId() > 0, "[module:{}] [id:{}] 模块必须大于等于1", module.getName(), module.getId()); + AssertionUtils.isNull(modules[module.getId()], "duplicate [module:{}] [id:{}] Exception!", module.getName(), module.getId()); + AssertionUtils.notNull(moduleDefinition.getProtocols(), "[module:{}] does not have any protocols", module.getName()); + + modules[module.getId()] = module; + + for (var protocolDefinition : moduleDefinition.getProtocols()) { + var id = protocolDefinition.getId(); + var location = protocolDefinition.getLocation(); + var clazz = Class.forName(location); + + AssertionUtils.isTrue(id >= moduleDefinition.getMinId(), "模块[{}]中的协议[{}]的协议号必须大于或者等于[{}]", moduleDefinition.getName(), clazz.getSimpleName(), moduleDefinition.getMinId()); + AssertionUtils.isTrue(id < moduleDefinition.getMaxId(), "模块[{}]中的协议[{}]的协议号必须小于[{}]", moduleDefinition.getName(), clazz.getSimpleName(), moduleDefinition.getMaxId()); + AssertionUtils.isNull(protocols[id], "duplicate definition [id:{}] Exception!", id); + + var packet = (IPacket) ReflectionUtils.newInstance(clazz); + + // 协议号是否和id是否相等 + AssertionUtils.isTrue(packet.protocolId() == id, "[class:{}]协议序列号[{}]和协议文件里的协议序列号不相等", clazz.getCanonicalName(), PROTOCOL_ID); + + var previous = tempProtocolClassMap.put(id, clazz); + if (previous != null) { + throw new RunException("[{}][{}]协议号[protocolId:{}]重复", clazz.getCanonicalName(), previous.getCanonicalName(), id); + } + } + } + + for (var moduleDefinition : xmlProtocols.getModules()) { + var module = modules[moduleDefinition.getId()]; + for (var protocolDefinition : moduleDefinition.getProtocols()) { + var id = protocolDefinition.getId(); + var clazz = tempProtocolClassMap.get(id); + try { + var registration = parseProtocolRegistration(clazz, module); + if (protocolDefinition.isEnhance()) { + enhanceList.add(registration); + } + // 注册协议 + protocols[id] = registration; + } catch (Exception e) { + throw new UnknownException(e, "解析协议[id:{}][class:{}]异常", id, clazz); + } + } + } + + enhanceProtocolBefore(generateOperation); + + enhanceProtocolRegistration(enhanceList); + + enhanceProtocolAfter(); + } catch (Exception e) { + throw new UnknownException(e); + } + } + + private static void enhanceProtocolRegistration(List enhanceList) throws NoSuchMethodException, IllegalAccessException, InstantiationException, CannotCompileException, NotFoundException, InvocationTargetException, NoSuchFieldException { + // 字节码增强 + for (var registration : enhanceList) { + protocols[registration.protocolId()] = EnhanceUtils.createProtocolRegistration((ProtocolRegistration) registration); + } + + // 字节码增强过后,初始化各个子协议成员变量 + for (var registration : enhanceList) { + var enhanceProtocolRegistration = protocols[registration.protocolId()]; + var subProtocolIds = getAllSubProtocolIds(registration.protocolId()); + for (var subProtocolId : subProtocolIds) { + var protocolRegistrationField = enhanceProtocolRegistration.getClass().getDeclaredField(EnhanceUtils.getProtocolRegistrationFieldNameByProtocolId(subProtocolId)); + ReflectionUtils.makeAccessible(protocolRegistrationField); + ReflectionUtils.setField(protocolRegistrationField, enhanceProtocolRegistration, protocols[subProtocolId]); + } + } + } + + private static void enhanceProtocolBefore(GenerateOperation generateOperation) throws IOException { + // 检查协议格式 + checkAllProtocolClass(); + + // 检查模块格式 + checkAllModules(); + + // 生成协议 + GenerateProtocolFile.generate(protocols, generateOperation); + } + + private static void enhanceProtocolAfter() { + tempSubProtocolIdMap.clear(); + tempSubProtocolIdMap = null; + + tempProtocolReserved = null; + + tempBaseSerializerMap.clear(); + tempBaseSerializerMap = null; + + GenerateProtocolDocument.clear(); + GenerateProtocolPath.clear(); + GenerateProtocolFile.clear(); + GenerateCsUtils.clear(); + GenerateJsUtils.clear(); + GenerateLuaUtils.clear(); + + EnhanceUtils.clear(); + } + + + private static short checkProtocol(Class clazz) throws IllegalAccessException, InvocationTargetException, InstantiationException { + // 是否为一个简单的javabean + AssertionUtils.isTrue(clazz.getSuperclass().equals(Object.class), "[class:{}]不是简单的javabean,不能继承别的类", clazz.getCanonicalName()); + // 是否实现了IPacket接口 + AssertionUtils.isTrue(IPacket.class.isAssignableFrom(clazz), "[class:{}]没有实现接口[IPacket:{}]", clazz.getCanonicalName(), IPacket.class.getCanonicalName()); + // 不能是泛型类 + AssertionUtils.isTrue(CollectionUtils.isEmpty(clazz.getTypeParameters()), "[class:{}]不能是泛型类", clazz.getCanonicalName()); + + Field protocolIdField; + try { + protocolIdField = clazz.getDeclaredField(PROTOCOL_ID); + } catch (NoSuchFieldException e) { + throw new UnknownException(e, "[class:{}]没有[{}]协议序列号", clazz.getCanonicalName(), PROTOCOL_ID); + } + + // 是否被public修饰 + AssertionUtils.isTrue(Modifier.isPublic(protocolIdField.getModifiers()), "[class:{}]协议序列号[{}]没有被public修饰", clazz.getCanonicalName(), PROTOCOL_ID); + // 是否被static修饰 + AssertionUtils.isTrue(Modifier.isStatic(protocolIdField.getModifiers()), "[class:{}]协议序列号[{}]没有被static修饰", clazz.getCanonicalName(), PROTOCOL_ID); + // 是否被final修饰 + AssertionUtils.isTrue(Modifier.isFinal(protocolIdField.getModifiers()), "[class:{}]协议序列号[{}]没有被final修饰", clazz.getCanonicalName(), PROTOCOL_ID); + // 是否被transient修饰 + AssertionUtils.isTrue(Modifier.isTransient(protocolIdField.getModifiers()), "[class:{}]协议序列号[{}]没有被transient修饰", clazz.getCanonicalName(), PROTOCOL_ID); + // 命名只能包含字母,数字,下划线 + AssertionUtils.isTrue(clazz.getSimpleName().matches("[a-zA-Z0-9_]*"), "[class:{}]的命名只能包含字母,数字,下划线", clazz.getCanonicalName(), PROTOCOL_ID); + + // 必须要有一个空的构造器 + Constructor constructor; + try { + constructor = clazz.getDeclaredConstructor(); + } catch (NoSuchMethodException e) { + throw new UnknownException(e, "[class:{}]协议序列号[{}]必须有一个空的构造器", clazz.getCanonicalName(), PROTOCOL_ID); + } + ReflectionUtils.makeAccessible(protocolIdField); + IPacket packet = (IPacket) constructor.newInstance(); + + // 验证protocol()方法的返回是否和PROTOCOL_ID相等 + AssertionUtils.isTrue(Short.valueOf(packet.protocolId()).equals(protocolIdField.get(null)), "[class:{}]的protocolId返回的值和协议号的静态变量[{}]不相等", clazz.getCanonicalName(), PROTOCOL_ID); + return packet.protocolId(); + } + + private static void checkAllModules() { + // 模块id不能重复 + var moduleIdSet = new HashSet(); + Arrays.stream(modules) + .filter(it -> Objects.nonNull(it)) + .peek(it -> AssertionUtils.isTrue(!moduleIdSet.contains(it.getId()), "模块[{}]存在重复的id,模块的id不能重复", it)) + .forEach(it -> moduleIdSet.add(it.getId())); + + // 模块名称不能重复 + var moduleNameSet = new HashSet(); + Arrays.stream(modules) + .filter(it -> Objects.nonNull(it)) + .peek(it -> AssertionUtils.isTrue(!moduleNameSet.contains(it.getName()), "模块[{}]存在重复的name,模块名称不能重复", it)) + .forEach(it -> moduleNameSet.add(it.getName())); + } + + private static void checkAllProtocolClass() { + // 检查协议格式 + + // 协议的名称不能重复 + var allProtocolNameMap = new HashMap>(); + for (var protocolRegistration : protocols) { + if (protocolRegistration == null) { + continue; + } + + var protocolClass = protocolRegistration.protocolConstructor().getDeclaringClass(); + var protocolName = protocolClass.getSimpleName(); + if (allProtocolNameMap.containsKey(protocolName)) { + throw new RunException("[class:{}]和[class:{}]协议名称重复,协议不能含有重复的名称", protocolClass.getCanonicalName(), allProtocolNameMap.get(protocolName).getCanonicalName()); + } + + if (tempProtocolReserved.stream().anyMatch(it -> it.equalsIgnoreCase(protocolName))) { + throw new RunException("协议的名称[class:{}]不能是保留名称[{}]", protocolClass.getCanonicalName(), protocolName); + } + + allProtocolNameMap.put(protocolName, protocolClass); + } + + + // 检查循环协议 + for (var protocolEntry : tempSubProtocolIdMap.entrySet()) { + var protocolId = protocolEntry.getKey(); + var subProtocolSet = protocolEntry.getValue(); + if (subProtocolSet.contains(protocolId)) { + var protocolClass = protocols[protocolId].protocolConstructor().getDeclaringClass(); + throw new RunException("[class:{}]在第一层包含循环引用协议[class:{}]", protocolClass.getSimpleName(), protocolClass.getSimpleName()); + } + + getAllSubProtocolIds(protocolId); + } + } + + private static void checkSubProtocol(Class clazz, short id, Class subClass) { + var registerProtocolClass = tempProtocolClassMap.get(id); + if (registerProtocolClass == null || !registerProtocolClass.equals(subClass)) { + throw new RunException("协议[{}]的子协议[{}][{}]没有注册", clazz.getCanonicalName(), id, subClass.getCanonicalName()); + } + } + + private static ProtocolRegistration parseProtocolRegistration(Class clazz, ProtocolModule module) throws IllegalAccessException, NoSuchMethodException, InvocationTargetException, InstantiationException { + var protocolId = checkProtocol(clazz); + + var fields = new ArrayList(); + for (var field : clazz.getDeclaredFields()) { + var modifiers = field.getModifiers(); + if (Modifier.isTransient(modifiers) || Modifier.isStatic(modifiers)) { + continue; + } + + if (Modifier.isFinal(modifiers)) { + throw new RunException("[{}]协议号[protocolId:{}]中的[filed:{}]属性的访问修饰符不能为final" + , clazz.getCanonicalName(), protocolId, field.getName()); + } + + if (!Modifier.isPublic(modifiers) && !Modifier.isPrivate(modifiers)) { + throw new RunException("[{}]协议号[protocolId:{}]中的[filed:{}]属性的访问修饰符必须是public或者private" + , clazz.getCanonicalName(), protocolId, field.getName()); + } + + ReflectionUtils.makeAccessible(field); + fields.add(field); + } + + // 按变量名称从小到大排序 + fields.sort(PACKET_FIELD_COMPARATOR); + + var registrationList = new ArrayList(); + for (var field : fields) { + registrationList.add(toRegistration(clazz, field)); + } + + var constructor = clazz.getDeclaredConstructor(); + ReflectionUtils.makeAccessible(constructor); + var protocol = new ProtocolRegistration(); + protocol.setId(protocolId); + protocol.setConstructor(constructor); + protocol.setFields(fields.toArray(new Field[fields.size()])); + protocol.setFieldRegistrations(registrationList.toArray(new IFieldRegistration[registrationList.size()])); + protocol.setModule(module.getId()); + return protocol; + } + + private static IFieldRegistration toRegistration(Class clazz, Field field) { + Class fieldTypeClazz = field.getType(); + + ISerializer serializer = tempBaseSerializerMap.get(fieldTypeClazz); + + // 是一个基本类型变量 + if (serializer != null) { + return BaseField.valueOf(serializer); + } else if (fieldTypeClazz.getComponentType() != null) { + // 是一个数组 + Class arrayClazz = fieldTypeClazz.getComponentType(); + + IFieldRegistration registration = typeToRegistration(clazz, arrayClazz); + return ArrayField.valueOf(field, registration); + } else if (Set.class.isAssignableFrom(fieldTypeClazz)) { + if (!fieldTypeClazz.equals(Set.class)) { + throw new RunException("[class:{}]类型声明不正确,必须是Set接口类型", clazz.getCanonicalName()); + } + + Type type = field.getGenericType(); + + if (!(type instanceof ParameterizedType)) { + throw new RunException("[class:{}]类型声明不正确,不是泛型类[field:{}]", clazz.getCanonicalName(), field.getName()); + } + + Type[] types = ((ParameterizedType) type).getActualTypeArguments(); + + if (types.length != 1) { + throw new RunException("[class:{}]中Set类型声明不正确,[field:{}]必须声明泛型类", clazz.getCanonicalName(), field.getName()); + } + + IFieldRegistration registration = typeToRegistration(clazz, types[0]); + return SetField.valueOf(registration, type); + } else if (List.class.isAssignableFrom(fieldTypeClazz)) { + // 是一个List + if (!fieldTypeClazz.equals(List.class)) { + throw new RunException("[class:{}]类型声明不正确,必须是List接口类型", clazz.getCanonicalName()); + } + + Type type = field.getGenericType(); + + if (!(type instanceof ParameterizedType)) { + throw new RunException("[class:{}]类型声明不正确,不是泛型类[field:{}]", clazz.getCanonicalName(), field.getName()); + } + + Type[] types = ((ParameterizedType) type).getActualTypeArguments(); + + if (types.length != 1) { + throw new RunException("[class:{}]中List类型声明不正确,[field:{}]必须声明泛型类", clazz.getCanonicalName(), field.getName()); + } + + IFieldRegistration registration = typeToRegistration(clazz, types[0]); + return ListField.valueOf(registration, type); + + } else if (Map.class.isAssignableFrom(fieldTypeClazz)) { + if (!fieldTypeClazz.equals(Map.class)) { + throw new RunException("[class:{}]类型声明不正确,必须是Map接口类型", clazz.getCanonicalName()); + } + + Type type = field.getGenericType(); + + if (!(type instanceof ParameterizedType)) { + throw new RunException("[class:{}]中数组类型声明不正确,[field:{}]不是泛型类", clazz.getCanonicalName(), field.getName()); + } + + Type[] types = ((ParameterizedType) type).getActualTypeArguments(); + + if (types.length != 2) { + throw new RunException("[class:{}]中数组类型声明不正确,[field:{}]必须声明泛型类", clazz.getCanonicalName(), field.getName()); + } + + IFieldRegistration keyRegistration = typeToRegistration(clazz, types[0]); + IFieldRegistration valueRegistration = typeToRegistration(clazz, types[1]); + + return MapField.valueOf(keyRegistration, valueRegistration, type); + } else { + // 是一个协议引用变量 + var referenceProtocolId = getProtocolIdByClass(field.getType()); + checkSubProtocol(clazz, referenceProtocolId, field.getType()); + tempSubProtocolIdMap.computeIfAbsent(getProtocolIdByClass(clazz), it -> new HashSet<>()).add(referenceProtocolId); + return ObjectProtocolField.valueOf(referenceProtocolId); + } + } + + private static IFieldRegistration typeToRegistration(Class currentProtocolClass, Type type) { + if (type instanceof ParameterizedType) { + // 泛型类 + Class clazz = (Class) ((ParameterizedType) type).getRawType(); + if (Set.class.equals(clazz)) { + // Set> + IFieldRegistration registration = typeToRegistration(currentProtocolClass, ((ParameterizedType) type).getActualTypeArguments()[0]); + return SetField.valueOf(registration, type); + } else if (List.class.equals(clazz)) { + // List> + IFieldRegistration registration = typeToRegistration(currentProtocolClass, ((ParameterizedType) type).getActualTypeArguments()[0]); + return ListField.valueOf(registration, (ParameterizedType) type); + } else if (Map.class.equals(clazz)) { + // Map, List> + IFieldRegistration keyRegistration = typeToRegistration(currentProtocolClass, ((ParameterizedType) type).getActualTypeArguments()[0]); + IFieldRegistration valueRegistration = typeToRegistration(currentProtocolClass, ((ParameterizedType) type).getActualTypeArguments()[1]); + return MapField.valueOf(keyRegistration, valueRegistration, type); + } + } else if (type instanceof Class) { + Class clazz = ((Class) type); + ISerializer serializer = tempBaseSerializerMap.get(clazz); + if (serializer != null) { + // 基础类型 + return BaseField.valueOf(serializer); + } else if (clazz.getComponentType() != null) { + // 是一个二维以上数组 + throw new RunException("不支持多维数组或集合嵌套数组[type:{}]类型,仅支持一维数组", type); + } else if (clazz.equals(List.class) || clazz.equals(Set.class) || clazz.equals(Map.class)) { + throw new RunException("不支持数组和集合联合使用[type:{}]类型", type); + } else { + // 是一个协议引用变量 + var referenceProtocolId = getProtocolIdByClass(clazz); + checkSubProtocol(clazz, referenceProtocolId, clazz); + tempSubProtocolIdMap.computeIfAbsent(getProtocolIdByClass(currentProtocolClass), it -> new HashSet<>()).add(referenceProtocolId); + return ObjectProtocolField.valueOf(referenceProtocolId); + } + } + throw new RunException("[type:{}]类型不正确", type); + } + + private static short getProtocolIdByClass(Class clazz) { + var protocolIdField = ReflectionUtils.getFieldByNameInPOJOClass(clazz, PROTOCOL_ID); + ReflectionUtils.makeAccessible(protocolIdField); + return (short) ReflectionUtils.getField(protocolIdField, null); + } + + + /** + * 此方法仅在生成协议的时候调用,一旦运行,不能调用 + */ + public static Set getAllSubProtocolIds(short protocolId) { + AssertionUtils.notNull(tempSubProtocolIdMap, "[{}]已经初始完成,初始化完成过后不能调用getAllSubProtocolIds", ProtocolAnalysis.class.getSimpleName()); + + if (!tempSubProtocolIdMap.containsKey(protocolId)) { + return Collections.emptySet(); + } + + var protocolClass = protocols[protocolId].protocolConstructor().getDeclaringClass(); + + var queue = new LinkedList<>(tempSubProtocolIdMap.get(protocolId)); + var allSubProtocolIdSet = new HashSet<>(queue); + while (!queue.isEmpty()) { + var firstSubProtocolId = queue.poll(); + if (tempSubProtocolIdMap.containsKey(firstSubProtocolId)) { + for (var subClassId : tempSubProtocolIdMap.get(firstSubProtocolId)) { + if (subClassId == protocolId) { + throw new RunException("[class:{}]在下层协议[class:{}]包含循环引用协议[class:{}]", protocolClass.getSimpleName(), protocols[firstSubProtocolId].protocolConstructor().getDeclaringClass(), protocolClass.getSimpleName()); + } + + if (!allSubProtocolIdSet.contains(subClassId)) { + allSubProtocolIdSet.add(subClassId); + queue.offer(subClassId); + } + } + } + } + return allSubProtocolIdSet; + } + +} diff --git a/protocol/src/main/java/com/zfoo/protocol/serializer/js/GenerateJsUtils.java b/protocol/src/main/java/com/zfoo/protocol/serializer/js/GenerateJsUtils.java index 0f5ebbe6..3adfa7cd 100644 --- a/protocol/src/main/java/com/zfoo/protocol/serializer/js/GenerateJsUtils.java +++ b/protocol/src/main/java/com/zfoo/protocol/serializer/js/GenerateJsUtils.java @@ -19,6 +19,7 @@ import com.zfoo.protocol.generate.GenerateProtocolDocument; import com.zfoo.protocol.generate.GenerateProtocolFile; import com.zfoo.protocol.generate.GenerateProtocolPath; import com.zfoo.protocol.registration.IProtocolRegistration; +import com.zfoo.protocol.registration.ProtocolAnalysis; import com.zfoo.protocol.registration.ProtocolRegistration; import com.zfoo.protocol.registration.field.IFieldRegistration; import com.zfoo.protocol.serializer.*; @@ -129,7 +130,7 @@ public abstract class GenerateJsUtils { var jsBuilder = new StringBuilder(); // 如果协议包含子协议,则需要导入ProtocolManager - var subProtocols = ProtocolManager.getAllSubProtocolIds(protocolId); + var subProtocols = ProtocolAnalysis.getAllSubProtocolIds(protocolId); if (CollectionUtils.isNotEmpty(subProtocols)) { var path = GenerateProtocolPath.getProtocolPath(protocolId); if (StringUtils.isBlank(path)) { diff --git a/protocol/src/main/java/com/zfoo/protocol/serializer/lua/GenerateLuaUtils.java b/protocol/src/main/java/com/zfoo/protocol/serializer/lua/GenerateLuaUtils.java index f7c19fec..5747d2fb 100644 --- a/protocol/src/main/java/com/zfoo/protocol/serializer/lua/GenerateLuaUtils.java +++ b/protocol/src/main/java/com/zfoo/protocol/serializer/lua/GenerateLuaUtils.java @@ -19,6 +19,7 @@ import com.zfoo.protocol.generate.GenerateProtocolDocument; import com.zfoo.protocol.generate.GenerateProtocolFile; import com.zfoo.protocol.generate.GenerateProtocolPath; import com.zfoo.protocol.registration.IProtocolRegistration; +import com.zfoo.protocol.registration.ProtocolAnalysis; import com.zfoo.protocol.registration.ProtocolRegistration; import com.zfoo.protocol.registration.field.IFieldRegistration; import com.zfoo.protocol.serializer.*; @@ -169,7 +170,7 @@ public abstract class GenerateLuaUtils { // 如果协议包含子协议,则需要导入ProtocolManager - var subProtocols = ProtocolManager.getAllSubProtocolIds(protocolId); + var subProtocols = ProtocolAnalysis.getAllSubProtocolIds(protocolId); if (CollectionUtils.isNotEmpty(subProtocols)) { luaBuilder.append("local ProtocolManager = require(\"LuaProtocol.ProtocolManager\")").append(LS + LS); }