diff --git a/protocol/src/main/java/com/zfoo/protocol/registration/ProtocolAnalysis.java b/protocol/src/main/java/com/zfoo/protocol/registration/ProtocolAnalysis.java index 5561f9f1..0eb28e46 100644 --- a/protocol/src/main/java/com/zfoo/protocol/registration/ProtocolAnalysis.java +++ b/protocol/src/main/java/com/zfoo/protocol/registration/ProtocolAnalysis.java @@ -99,7 +99,9 @@ public class ProtocolAnalysis { try { // 检查协议类是否合法 for (var protocolClass : protocolClassSet) { - checkProtocol(protocolClass); + var protocolId = getProtocolIdAndCheckClass(protocolClass); + AssertionUtils.isTrue(protocolId >= 0, "[class:{}]必须使用注解@Protocol注解标注或者使用[{}]字段", protocolClass.getCanonicalName(), PROTOCOL_ID); + initProtocolClass(protocolId, protocolClass); } // 协议id和协议信息对应起来 @@ -134,17 +136,22 @@ public class ProtocolAnalysis { for (var protocolDefinition : moduleDefinition.getProtocols()) { var location = protocolDefinition.getLocation(); var clazz = Class.forName(location); - var id = getProtocolIdAndCheckClass(clazz); + var protocolId = protocolDefinition.getId(); - AssertionUtils.isTrue(id >= moduleDefinition.getMinId(), "模块[{}]中的协议[{}]的协议号必须大于或者等于[{}]", moduleDefinition.getName(), clazz.getSimpleName(), moduleDefinition.getMinId()); - AssertionUtils.isTrue(id < moduleDefinition.getMaxId(), "模块[{}]中的协议[{}]的协议号必须小于[{}]", moduleDefinition.getName(), clazz.getSimpleName(), moduleDefinition.getMaxId()); - AssertionUtils.isNull(protocols[id], "duplicate definition [id:{}] Exception!", id); - - // 协议号是否和id是否相等,如果xml文件中没有填protocolId则不检测 - if (protocolDefinition.getId() >= 0) { - AssertionUtils.isTrue(protocolDefinition.getId() == id, "[class:{}]协议序列号[{}]和协议文件里的协议序列号不相等", clazz.getCanonicalName(), PROTOCOL_ID); + // 如果xml文件中没有填protocolId则只需要获取到protocolId即可 + if (protocolId < 0) { + protocolId = getProtocolIdAndCheckClass(clazz); + AssertionUtils.isTrue(protocolId >= 0, "[class:{}]在使用xml方式注册协议,如果xml没有提供协议号,则需要使用注解或者协议字段标注协议号", clazz.getCanonicalName()); + } else { + var id = getProtocolIdAndCheckClass(clazz); + // 使用xml方式注册协议可以,协议class不需要使用注解或者字段标注协议号 + if (id >= 0) { + AssertionUtils.isTrue(protocolId == id, "[class:{}]协议序列号[{}]和协议文件里的协议序列号不相等", clazz.getCanonicalName(), PROTOCOL_ID); + } } - checkProtocol(clazz); + AssertionUtils.isTrue(protocolId >= moduleDefinition.getMinId(), "模块[{}]中的协议[{}]的协议号必须大于或者等于[{}]", moduleDefinition.getName(), clazz.getSimpleName(), moduleDefinition.getMinId()); + AssertionUtils.isTrue(protocolId < moduleDefinition.getMaxId(), "模块[{}]中的协议[{}]的协议号必须小于[{}]", moduleDefinition.getName(), clazz.getSimpleName(), moduleDefinition.getMaxId()); + initProtocolClass(protocolId, clazz); } } @@ -169,30 +176,11 @@ public class ProtocolAnalysis { } private static void enhance(GenerateOperation generateOperation, List enhanceList) throws IOException, ClassNotFoundException, NotFoundException, CannotCompileException, NoSuchFieldException, InvocationTargetException, NoSuchMethodException, IllegalAccessException, InstantiationException { - initProtocolIdMap(); enhanceProtocolBefore(generateOperation); enhanceProtocolRegistration(enhanceList); enhanceProtocolAfter(generateOperation); } - private static void initProtocolIdMap() { - for (var protocol : protocols) { - if (protocol == null) { - continue; - } - var clazz = protocol.protocolConstructor().getDeclaringClass(); - var protocolId = protocol.protocolId(); - protocolIdMap.put(clazz, protocolId); - protocolIdPrimitiveMap.putPrimitive(clazz.hashCode(), protocolId); - } - var distinctHashcode = protocolIdMap.keySet().stream().map(Object::hashCode).distinct().count(); - if (distinctHashcode == protocolIdMap.size()) { - protocolIdMap = null; - } else { - protocolIdPrimitiveMap = null; - } - } - private static void enhanceProtocolBefore(GenerateOperation generateOperation) throws IOException, ClassNotFoundException { // 检查协议格式 checkAllProtocolClass(); @@ -221,6 +209,13 @@ public class ProtocolAnalysis { } private static void enhanceProtocolAfter(GenerateOperation generateOperation) { + var distinctHashcode = protocolIdMap.keySet().stream().map(Object::hashCode).distinct().count(); + if (distinctHashcode == protocolIdMap.size()) { + protocolIdMap = null; + } else { + protocolIdPrimitiveMap = null; + } + subProtocolIdMap = null; protocolReserved = null; baseSerializerMap = null; @@ -426,7 +421,7 @@ public class ProtocolAnalysis { } else { // 是一个协议引用变量 var referenceProtocolId = getProtocolIdAndCheckClass(clazz); - checkSubProtocol(clazz, referenceProtocolId, clazz); + checkSubProtocol(currentProtocolClass, referenceProtocolId, clazz); subProtocolIdMap.computeIfAbsent(getProtocolIdAndCheckClass(currentProtocolClass), it -> new HashSet<>()).add(referenceProtocolId); return ObjectProtocolField.valueOf(referenceProtocolId); } @@ -469,17 +464,9 @@ public class ProtocolAnalysis { // 协议智能语法分析,错误的协议定义将无法启动程序并给出错误警告 //----------------------------------------------------------------------- - - private static void checkProtocol(Class clazz) { - // 是否为一个简单的javabean - ReflectionUtils.assertIsPojoClass(clazz); - // 是否实现了IPacket接口 - AssertionUtils.isTrue(IPacket.class.isAssignableFrom(clazz), "[class:{}]没有实现接口[IPacket:{}]", clazz.getCanonicalName(), IPacket.class.getCanonicalName()); - // 不能是泛型类 - AssertionUtils.isTrue(ArrayUtils.isEmpty(clazz.getTypeParameters()), "[class:{}]不能是泛型类", clazz.getCanonicalName()); - - var protocolId = getProtocolIdAndCheckClass(clazz); - + private static void initProtocolClass(short protocolId, Class clazz) { + protocolIdMap.put(clazz, protocolId); + protocolIdPrimitiveMap.putPrimitive(clazz.hashCode(), protocolId); var previous = protocolClassMap.put(protocolId, clazz); if (previous != null) { throw new RunException("[{}][{}]协议号[protocolId:{}]重复", clazz.getCanonicalName(), previous.getCanonicalName(), protocolId); @@ -487,6 +474,13 @@ public class ProtocolAnalysis { } public static short getProtocolIdAndCheckClass(Class clazz) { + // 是否为一个简单的javabean + ReflectionUtils.assertIsPojoClass(clazz); + // 是否实现了IPacket接口 + AssertionUtils.isTrue(IPacket.class.isAssignableFrom(clazz), "[class:{}]没有实现接口[IPacket:{}]", clazz.getCanonicalName(), IPacket.class.getCanonicalName()); + // 不能是泛型类 + AssertionUtils.isTrue(ArrayUtils.isEmpty(clazz.getTypeParameters()), "[class:{}]不能是泛型类", clazz.getCanonicalName()); + Field protocolIdField = null; try { protocolIdField = clazz.getDeclaredField(PROTOCOL_ID);