From 389de2b6b3d248a4daa428f5c650f2e3fa1e0385 Mon Sep 17 00:00:00 2001 From: jaysunxiao Date: Sun, 22 May 2022 20:52:48 +0800 Subject: [PATCH] =?UTF-8?q?perf[protocol]:=20=E6=94=AF=E6=8C=81cpp?= =?UTF-8?q?=E5=8D=8F=E8=AE=AE=E5=A2=9E=E5=8A=A0=E5=AD=97=E6=AE=B5=E5=90=91?= =?UTF-8?q?=E5=89=8D=E5=85=BC=E5=AE=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 2 +- protocol/README.md | 3 +- .../com/zfoo/protocol/ProtocolManager.java | 1 - .../protocol/registration/EnhanceUtils.java | 73 ++-- .../registration/ProtocolAnalysis.java | 380 +++++++++--------- .../registration/ProtocolRegistration.java | 5 + .../registration/anno/Compatible.java | 29 ++ .../serializer/cpp/GenerateCppUtils.java | 32 +- .../com/zfoo/protocol/util/JsonUtils.java | 4 +- protocol/src/main/resources/cpp/ByteBuffer.h | 4 + .../src/test/cpp/cppProtocol/ByteBuffer.h | 4 + .../cpp/cppProtocol/Packet/ComplexObject.h | 20 +- .../zfoo/protocol/packet/ComplexObject.java | 23 ++ 13 files changed, 327 insertions(+), 253 deletions(-) create mode 100644 protocol/src/main/java/com/zfoo/protocol/registration/anno/Compatible.java diff --git a/README.md b/README.md index 3f6a099c..506238f3 100644 --- a/README.md +++ b/README.md @@ -43,7 +43,7 @@ - 性能需求极高的项目,如网站和游戏服务器框架,单服滚服,全球服,直播聊天,IM系统,实时推送 - 节省研发成本的项目,如想节省,开发,部署,运维成本 - 适合作为 **Godot,Unity,Cocos,Webgl,H5** 的后端基础框架,网络通信协议支持 tcp udp websocket http -- 语言支持 **Java Javascript C# Lua GDScript**,可以轻易实现跨平台 +- 协议目前原生支持 **C++ Java Javascript C# Lua GDScript**,可以轻易实现跨平台 - 喜欢 [KISS法则](https://baike.baidu.com/item/KISS原则/3242383) 的项目 ,简单的配置,优雅的代码 Ⅲ. 详细的教程和完整的工程案例 diff --git a/protocol/README.md b/protocol/README.md index 4d4c8c19..fe317cfa 100644 --- a/protocol/README.md +++ b/protocol/README.md @@ -2,7 +2,7 @@ - [zfoo protocol](https://github.com/zfoo-project/zfoo/blob/main/protocol/README.md) 是目前的Java二进制序列化和反序列化最快的框架,并且为序列化后字节最少的框架 -- 协议目前原生支持Java Javascript C# Lua GDScript,协议理论上可以跨平台 +- 协议目前原生支持 **C++ Java Javascript C# Lua GDScript**,可以轻易实现跨平台 - 使用Javassist字节码增强动态生成顺序执行的序列化和反序列化函数,顺序执行的函数可以轻易的被JIT编译以达到极致的性能 - 兼容protobuf,支持生成protobuf协议文件,提供从pojo到proto的生成方式 @@ -54,7 +54,6 @@ cpu: i9900k ``` 无漏洞注入风险,只有初始化时会进行字节码增强,后期不会再进行任何字节码的操作 数据压缩体积小,压缩体积比kryo和protobuf都要小;比kryo小是因为kryo需要写入每个对象的注册号 -跨平台可以轻易实现,目前已经原生支持Java,Javascript,C#,Lua,目前kryo无法跨平台,protobuf可以跨平台 智能语法分析,错误的协议定义将无法启动程序并给出错误警告 提升开发效率,完全支持POJO方式开发,使用非常简单 ``` diff --git a/protocol/src/main/java/com/zfoo/protocol/ProtocolManager.java b/protocol/src/main/java/com/zfoo/protocol/ProtocolManager.java index 223127c4..e549fb47 100644 --- a/protocol/src/main/java/com/zfoo/protocol/ProtocolManager.java +++ b/protocol/src/main/java/com/zfoo/protocol/ProtocolManager.java @@ -36,7 +36,6 @@ public class ProtocolManager { public static final String PROTOCOL_ID = "PROTOCOL_ID"; public static final short MAX_PROTOCOL_NUM = Short.MAX_VALUE; public static final byte MAX_MODULE_NUM = Byte.MAX_VALUE; - public static final Comparator PACKET_FIELD_COMPARATOR = (a, b) -> a.getName().compareTo(b.getName()); public static final IProtocolRegistration[] protocols = new IProtocolRegistration[MAX_PROTOCOL_NUM]; public static final ProtocolModule[] modules = new ProtocolModule[MAX_MODULE_NUM]; diff --git a/protocol/src/main/java/com/zfoo/protocol/registration/EnhanceUtils.java b/protocol/src/main/java/com/zfoo/protocol/registration/EnhanceUtils.java index 458c33cd..cd6afc90 100644 --- a/protocol/src/main/java/com/zfoo/protocol/registration/EnhanceUtils.java +++ b/protocol/src/main/java/com/zfoo/protocol/registration/EnhanceUtils.java @@ -18,6 +18,7 @@ import com.zfoo.protocol.buffer.ByteBufUtils; import com.zfoo.protocol.collection.ArrayUtils; import com.zfoo.protocol.collection.CollectionUtils; import com.zfoo.protocol.generate.GenerateProtocolFile; +import com.zfoo.protocol.registration.anno.Compatible; import com.zfoo.protocol.registration.field.IFieldRegistration; import com.zfoo.protocol.serializer.enhance.*; import com.zfoo.protocol.serializer.reflect.*; @@ -51,12 +52,7 @@ public abstract class EnhanceUtils { public static String byteBufUtilsWriteInt0 = byteBufUtils + ".writeInt($1, 0);"; static { - var classArray = new Class[]{ - IPacket.class, - IProtocolRegistration.class, - IFieldRegistration.class, - ByteBuf.class - }; + var classArray = new Class[]{IPacket.class, IProtocolRegistration.class, IFieldRegistration.class, ByteBuf.class}; var classPool = ClassPool.getDefault(); @@ -118,12 +114,11 @@ public abstract class EnhanceUtils { * @return 返回类的名称格式:EnhanceUtilsProtocolRegistration1 */ public static IProtocolRegistration createProtocolRegistration(ProtocolRegistration registration) throws NotFoundException, CannotCompileException, NoSuchMethodException, IllegalAccessException, InvocationTargetException, InstantiationException { - var classPool = ClassPool.getDefault(); - GenerateProtocolFile.index.set(0); - short protocolId = registration.getId(); - IFieldRegistration[] packetFields = registration.getFieldRegistrations(); + var classPool = ClassPool.getDefault(); + var protocolId = registration.getId(); + var packetFields = registration.getFieldRegistrations(); // 定义类名称 CtClass enhanceClazz = classPool.makeClass(ProtocolRegistration.class.getCanonicalName() + protocolId); @@ -198,22 +193,18 @@ public abstract class EnhanceUtils { // see: ProtocolRegistration.write() private static String writeMethodBody(ProtocolRegistration registration) { - short protocolId = registration.getId(); - Constructor constructor = registration.getConstructor(); - Field[] fields = registration.getFields(); - IFieldRegistration[] fieldRegistrations = registration.getFieldRegistrations(); + var constructor = registration.getConstructor(); + var fields = registration.getFields(); + var fieldRegistrations = registration.getFieldRegistrations(); + var packetClazz = constructor.getDeclaringClass(); - Class packetClazz = constructor.getDeclaringClass(); - - StringBuilder builder = new StringBuilder(); - builder.append("{"); - builder.append(packetClazz.getCanonicalName() + " packet = (" + packetClazz.getCanonicalName() + ")$2;"); - builder.append("if(ByteBufUtils.writePacketFlag($1, packet)){") - .append("return;}"); - for (int i = 0; i < fields.length; i++) { - Field field = fields[i]; - IFieldRegistration fieldRegistration = fieldRegistrations[i]; + var builder = new StringBuilder(); + builder.append("{").append(packetClazz.getCanonicalName() + " packet = (" + packetClazz.getCanonicalName() + ")$2;"); + builder.append("if(ByteBufUtils.writePacketFlag($1, packet)){").append("return;}"); + for (var i = 0; i < fields.length; i++) { + var field = fields[i]; + var fieldRegistration = fieldRegistrations[i]; if (Modifier.isPublic(field.getModifiers())) { enhanceSerializer(fieldRegistration.serializer()) @@ -223,31 +214,29 @@ public abstract class EnhanceUtils { .writeObject(builder, StringUtils.format("packet.{}()", ReflectionUtils.fieldToGetMethod(packetClazz, field)), field, fieldRegistration); } } - - builder.append("}"); return builder.toString(); } // see: ProtocolRegistration.read() - private static String readMethodBody(ProtocolRegistration registration) throws NoSuchMethodException { - short protocolId = registration.getId(); - Constructor constructor = registration.getConstructor(); - Field[] fields = registration.getFields(); - IFieldRegistration[] fieldRegistrations = registration.getFieldRegistrations(); + private static String readMethodBody(ProtocolRegistration registration) { + var constructor = registration.getConstructor(); + var fields = registration.getFields(); + var fieldRegistrations = registration.getFieldRegistrations(); - StringBuilder builder = new StringBuilder(); - builder.append("{"); - builder.append("if(!" + EnhanceUtils.byteBufUtilsReadBoolean + "){") - .append("return null;}"); - Class packetClazz = constructor.getDeclaringClass(); + var builder = new StringBuilder(); + builder.append("{").append("if(!" + EnhanceUtils.byteBufUtilsReadBoolean + "){").append("return null;}"); + var packetClazz = constructor.getDeclaringClass(); builder.append(packetClazz.getCanonicalName() + " packet=new " + packetClazz.getCanonicalName() + "();"); - for (int i = 0; i < fields.length; i++) { - Field field = fields[i]; - IFieldRegistration fieldRegistration = fieldRegistrations[i]; - - String readObject = enhanceSerializer(fieldRegistration.serializer()).readObject(builder, field, fieldRegistration); + for (var i = 0; i < fields.length; i++) { + var field = fields[i]; + var fieldRegistration = fieldRegistrations[i]; + var readObject = enhanceSerializer(fieldRegistration.serializer()).readObject(builder, field, fieldRegistration); + // 协议向后兼容 + if (field.isAnnotationPresent(Compatible.class)) { + builder.append(StringUtils.format("if(!$1.isReadable()){ return packet; }")); + } if (Modifier.isPublic(field.getModifiers())) { builder.append(StringUtils.format("packet.{}={};", field.getName(), readObject)); @@ -260,9 +249,7 @@ public abstract class EnhanceUtils { return builder.toString(); } - public static String getProtocolRegistrationFieldNameByProtocolId(short id) { return StringUtils.format("{}{}", StringUtils.uncapitalize(ProtocolRegistration.class.getSimpleName()), id); } - } diff --git a/protocol/src/main/java/com/zfoo/protocol/registration/ProtocolAnalysis.java b/protocol/src/main/java/com/zfoo/protocol/registration/ProtocolAnalysis.java index 694c725d..0b0c17f6 100644 --- a/protocol/src/main/java/com/zfoo/protocol/registration/ProtocolAnalysis.java +++ b/protocol/src/main/java/com/zfoo/protocol/registration/ProtocolAnalysis.java @@ -15,12 +15,14 @@ package com.zfoo.protocol.registration; import com.zfoo.protocol.IPacket; import com.zfoo.protocol.ProtocolManager; import com.zfoo.protocol.collection.ArrayUtils; +import com.zfoo.protocol.collection.CollectionUtils; import com.zfoo.protocol.exception.RunException; import com.zfoo.protocol.exception.UnknownException; import com.zfoo.protocol.generate.GenerateOperation; import com.zfoo.protocol.generate.GenerateProtocolDocument; import com.zfoo.protocol.generate.GenerateProtocolFile; import com.zfoo.protocol.generate.GenerateProtocolPath; +import com.zfoo.protocol.registration.anno.Compatible; import com.zfoo.protocol.registration.field.*; import com.zfoo.protocol.serializer.cpp.GenerateCppUtils; import com.zfoo.protocol.serializer.csharp.GenerateCsUtils; @@ -49,6 +51,11 @@ import static com.zfoo.protocol.ProtocolManager.*; */ public class ProtocolAnalysis { + // 临时变量,启动完成就会销毁,协议Id对应的Class类 + private static final Map> protocolClassMap = new HashMap<>(MAX_PROTOCOL_NUM); + + // 临时变量,启动完成就会销毁,协议下包含的子协议,只包含一层子协议 + private static Map> subProtocolIdMap = new HashMap<>(MAX_PROTOCOL_NUM); // 临时变量,启动完成就会销毁,协议名称保留字符,即协议的名称不能用以下名称命名 private static Set protocolReserved = Set.of("Buffer", "ByteBuf", "ByteBuffer", "LittleEndianByteBuffer", "NormalByteBuffer" @@ -58,13 +65,7 @@ public class ProtocolAnalysis { , "Collections", "Iterator", "List", "ArrayList", "Map", "HashMap", "Set", "HashSet"); // 临时变量,启动完成就会销毁,是一个基本类型序列化器 - private static Map, ISerializer> baseSerializerMap = new HashMap<>(); - - // 临时变量,启动完成就会销毁,协议Id对应的Class类 - private static final Map> protocolClassMap = new HashMap<>(MAX_PROTOCOL_NUM); - - // 临时变量,启动完成就会销毁,协议下包含的子协议,只包含一层子协议 - private static Map> subProtocolIdMap = new HashMap<>(MAX_PROTOCOL_NUM); + private static Map, ISerializer> baseSerializerMap = new HashMap<>(128); static { // 初始化基础类型序列化器 @@ -91,29 +92,15 @@ public class ProtocolAnalysis { AssertionUtils.notNull(subProtocolIdMap, "[{}]已经初始完成,请不要重复初始化", ProtocolManager.class.getSimpleName()); try { for (var protocolClass : protocolClassSet) { - var id = getProtocolIdByClass(protocolClass); - var previous = protocolClassMap.put(id, protocolClass); - if (previous != null) { - throw new RunException("[{}][{}]协议号[protocolId:{}]重复", protocolClass.getCanonicalName(), previous.getCanonicalName(), id); - } + checkProtocol(protocolClass); } - for (var protocolClass : protocolClassSet) { - try { - var registration = parseProtocolRegistration(protocolClass, ProtocolModule.DEFAULT_PROTOCOL_MODULE); - // 注册协议 - protocols[registration.protocolId()] = registration; - } catch (Exception e) { - throw new RuntimeException(StringUtils.format("解析协议[class:{}]异常", protocolClass), e); - } + var registration = parseProtocolRegistration(protocolClass, ProtocolModule.DEFAULT_PROTOCOL_MODULE); + protocols[registration.protocolId()] = registration; } - - enhanceProtocolBefore(generateOperation); - // 通过指定类注册的协议,全部使用字节码增强 - enhanceProtocolRegistration(Arrays.stream(protocols).filter(it -> Objects.nonNull(it)).collect(Collectors.toList())); - - enhanceProtocolAfter(generateOperation); + var enhanceList = Arrays.stream(protocols).filter(it -> Objects.nonNull(it)).collect(Collectors.toList()); + enhance(generateOperation, enhanceList); } catch (Exception e) { throw new RuntimeException(e); } @@ -143,15 +130,9 @@ public class ProtocolAnalysis { AssertionUtils.isTrue(id < moduleDefinition.getMaxId(), "模块[{}]中的协议[{}]的协议号必须小于[{}]", moduleDefinition.getName(), clazz.getSimpleName(), moduleDefinition.getMaxId()); AssertionUtils.isNull(protocols[id], "duplicate definition [id:{}] Exception!", id); - var packet = (IPacket) ReflectionUtils.newInstance(clazz); - // 协议号是否和id是否相等 - AssertionUtils.isTrue(packet.protocolId() == id, "[class:{}]协议序列号[{}]和协议文件里的协议序列号不相等", clazz.getCanonicalName(), PROTOCOL_ID); - - var previous = protocolClassMap.put(id, clazz); - if (previous != null) { - throw new RunException("[{}][{}]协议号[protocolId:{}]重复", clazz.getCanonicalName(), previous.getCanonicalName(), id); - } + AssertionUtils.isTrue(getProtocolIdByClass(clazz) == id, "[class:{}]协议序列号[{}]和协议文件里的协议序列号不相等", clazz.getCanonicalName(), PROTOCOL_ID); + checkProtocol(clazz); } } @@ -160,29 +141,35 @@ public class ProtocolAnalysis { for (var protocolDefinition : moduleDefinition.getProtocols()) { var id = protocolDefinition.getId(); var clazz = protocolClassMap.get(id); - try { - var registration = parseProtocolRegistration(clazz, module); - if (protocolDefinition.isEnhance()) { - enhanceList.add(registration); - } - // 注册协议 - protocols[id] = registration; - } catch (Exception e) { - throw new UnknownException(e, "解析协议[id:{}][class:{}]异常", id, clazz); + var registration = parseProtocolRegistration(clazz, module); + if (protocolDefinition.isEnhance()) { + enhanceList.add(registration); } + // 注册协议 + protocols[id] = registration; } } - - enhanceProtocolBefore(generateOperation); - - enhanceProtocolRegistration(enhanceList); - - enhanceProtocolAfter(generateOperation); + enhance(generateOperation, enhanceList); } catch (Exception e) { throw new UnknownException(e); } } + private static void enhance(GenerateOperation generateOperation, List enhanceList) throws IOException, ClassNotFoundException, NotFoundException, CannotCompileException, NoSuchFieldException, InvocationTargetException, NoSuchMethodException, IllegalAccessException, InstantiationException { + enhanceProtocolBefore(generateOperation); + enhanceProtocolRegistration(enhanceList); + enhanceProtocolAfter(generateOperation); + } + + private static void enhanceProtocolBefore(GenerateOperation generateOperation) throws IOException, ClassNotFoundException { + // 检查协议格式 + checkAllProtocolClass(); + // 检查模块格式 + checkAllModules(); + // 生成协议 + GenerateProtocolFile.generate(generateOperation); + } + private static void enhanceProtocolRegistration(List enhanceList) throws NoSuchMethodException, IllegalAccessException, InstantiationException, CannotCompileException, NotFoundException, InvocationTargetException, NoSuchFieldException { // 字节码增强 for (var registration : enhanceList) { @@ -201,29 +188,14 @@ public class ProtocolAnalysis { } } - private static void enhanceProtocolBefore(GenerateOperation generateOperation) throws IOException, ClassNotFoundException { - // 检查协议格式 - checkAllProtocolClass(); - - // 检查模块格式 - checkAllModules(); - - // 生成协议 - GenerateProtocolFile.generate(generateOperation); - } - private static void enhanceProtocolAfter(GenerateOperation generateOperation) { - subProtocolIdMap.clear(); subProtocolIdMap = null; - protocolReserved = null; - - baseSerializerMap.clear(); baseSerializerMap = null; EnhanceUtils.clear(); - if (generateOperation.getGenerateLanguages().isEmpty()) { + if (CollectionUtils.isEmpty(generateOperation.getGenerateLanguages())) { return; } @@ -238,146 +210,70 @@ public class ProtocolAnalysis { GenerateProtobufUtils.clear(); } - - private static short checkProtocol(Class clazz) throws IllegalAccessException, InvocationTargetException, InstantiationException { - // 是否为一个简单的javabean - ReflectionUtils.assertIsPojoClass(clazz); - // 是否实现了IPacket接口 - AssertionUtils.isTrue(IPacket.class.isAssignableFrom(clazz), "[class:{}]没有实现接口[IPacket:{}]", clazz.getCanonicalName(), IPacket.class.getCanonicalName()); - // 不能是泛型类 - AssertionUtils.isTrue(ArrayUtils.isEmpty(clazz.getTypeParameters()), "[class:{}]不能是泛型类", clazz.getCanonicalName()); - - Field protocolIdField; - try { - protocolIdField = clazz.getDeclaredField(PROTOCOL_ID); - } catch (NoSuchFieldException e) { - throw new UnknownException(e, "[class:{}]没有[{}]协议序列号", clazz.getCanonicalName(), PROTOCOL_ID); - } - - // 是否被public修饰 - AssertionUtils.isTrue(Modifier.isPublic(protocolIdField.getModifiers()), "[class:{}]协议序列号[{}]没有被public修饰", clazz.getCanonicalName(), PROTOCOL_ID); - // 是否被static修饰 - AssertionUtils.isTrue(Modifier.isStatic(protocolIdField.getModifiers()), "[class:{}]协议序列号[{}]没有被static修饰", clazz.getCanonicalName(), PROTOCOL_ID); - // 是否被final修饰 - AssertionUtils.isTrue(Modifier.isFinal(protocolIdField.getModifiers()), "[class:{}]协议序列号[{}]没有被final修饰", clazz.getCanonicalName(), PROTOCOL_ID); - // 是否被transient修饰 - AssertionUtils.isTrue(Modifier.isTransient(protocolIdField.getModifiers()), "[class:{}]协议序列号[{}]没有被transient修饰", clazz.getCanonicalName(), PROTOCOL_ID); - // 命名只能包含字母,数字,下划线 - AssertionUtils.isTrue(clazz.getSimpleName().matches("[a-zA-Z0-9_]*"), "[class:{}]的命名只能包含字母,数字,下划线", clazz.getCanonicalName(), PROTOCOL_ID); - - // 必须要有一个空的构造器 - Constructor constructor = ReflectionUtils.publicEmptyConstructor(clazz); - - ReflectionUtils.makeAccessible(protocolIdField); - IPacket packet = (IPacket) constructor.newInstance(); - - // 验证protocol()方法的返回是否和PROTOCOL_ID相等 - AssertionUtils.isTrue(Short.valueOf(packet.protocolId()).equals(protocolIdField.get(null)), "[class:{}]的protocolId返回的值和协议号的静态变量[{}]不相等", clazz.getCanonicalName(), PROTOCOL_ID); - return packet.protocolId(); - } - - private static void checkAllModules() { - // 模块id不能重复 - var moduleIdSet = new HashSet(); - Arrays.stream(modules) - .filter(it -> Objects.nonNull(it)) - .peek(it -> AssertionUtils.isTrue(!moduleIdSet.contains(it.getId()), "模块[{}]存在重复的id,模块的id不能重复", it)) - .forEach(it -> moduleIdSet.add(it.getId())); - - // 模块名称不能重复 - var moduleNameSet = new HashSet(); - Arrays.stream(modules) - .filter(it -> Objects.nonNull(it)) - .peek(it -> AssertionUtils.isTrue(!moduleNameSet.contains(it.getName()), "模块[{}]存在重复的name,模块名称不能重复", it)) - .forEach(it -> moduleNameSet.add(it.getName())); - } - - private static void checkAllProtocolClass() { - // 检查协议格式 - - // 协议的名称不能重复 - var allProtocolNameMap = new HashMap>(); - for (var protocolRegistration : protocols) { - if (protocolRegistration == null) { - continue; - } - - var protocolClass = protocolRegistration.protocolConstructor().getDeclaringClass(); - var protocolName = protocolClass.getSimpleName(); - if (allProtocolNameMap.containsKey(protocolName)) { - throw new RunException("[class:{}]和[class:{}]协议名称重复,协议不能含有重复的名称", protocolClass.getCanonicalName(), allProtocolNameMap.get(protocolName).getCanonicalName()); - } - - if (protocolReserved.stream().anyMatch(it -> it.equalsIgnoreCase(protocolName))) { - throw new RunException("协议的名称[class:{}]不能是保留名称[{}]", protocolClass.getCanonicalName(), protocolName); - } - - allProtocolNameMap.put(protocolName, protocolClass); - } - - - // 检查循环协议 - for (var protocolEntry : subProtocolIdMap.entrySet()) { - var protocolId = protocolEntry.getKey(); - var subProtocolSet = protocolEntry.getValue(); - if (subProtocolSet.contains(protocolId)) { - var protocolClass = protocols[protocolId].protocolConstructor().getDeclaringClass(); - throw new RunException("[class:{}]在第一层包含循环引用协议[class:{}]", protocolClass.getSimpleName(), protocolClass.getSimpleName()); - } - - getAllSubProtocolIds(protocolId); - } - } - - private static void checkSubProtocol(Class clazz, short id, Class subClass) { - var registerProtocolClass = protocolClassMap.get(id); - if (registerProtocolClass == null || !registerProtocolClass.equals(subClass)) { - throw new RunException("协议[{}]的子协议[{}][{}]没有注册", clazz.getCanonicalName(), id, subClass.getCanonicalName()); - } - } - - private static ProtocolRegistration parseProtocolRegistration(Class clazz, ProtocolModule module) throws IllegalAccessException, NoSuchMethodException, InvocationTargetException, InstantiationException { - var protocolId = checkProtocol(clazz); - - // 对象需要被序列化的属性 - var fields = new ArrayList(); + private static List customFieldOrder(Class clazz) { + var notCompatibleFields = new ArrayList(); + var compatibleFieldMap = new HashMap(); for (var field : clazz.getDeclaredFields()) { var modifiers = field.getModifiers(); if (Modifier.isTransient(modifiers) || Modifier.isStatic(modifiers)) { continue; } - if (Modifier.isFinal(modifiers)) { - throw new RunException("[{}]协议号[protocolId:{}]中的[filed:{}]属性的访问修饰符不能为final" - , clazz.getCanonicalName(), protocolId, field.getName()); + throw new RunException("[{}]协议号中的[field:{}]属性的访问修饰符不能为final", clazz.getCanonicalName(), field.getName()); } - if (!Modifier.isPublic(modifiers) && !Modifier.isPrivate(modifiers)) { - throw new RunException("[{}]协议号[protocolId:{}]中的[filed:{}]属性的访问修饰符必须是public或者private" - , clazz.getCanonicalName(), protocolId, field.getName()); + throw new RunException("[{}]协议号中的[field:{}]属性的访问修饰符必须是public或者private", clazz.getCanonicalName(), field.getName()); } ReflectionUtils.makeAccessible(field); - fields.add(field); + if (field.isAnnotationPresent(Compatible.class)) { + var order = field.getAnnotation(Compatible.class).order(); + var oldField = compatibleFieldMap.put(order, field); + if (oldField != null) { + throw new RunException("[{}]协议号中的[field:{}]和[field:{}]不能有相同的Compatible顺序[order:{}]", clazz.getCanonicalName(), oldField.getName(), field.getName(), oldField, order); + } + } else { + notCompatibleFields.add(field); + } } - // 按变量名称从小到大排序 - fields.sort(PACKET_FIELD_COMPARATOR); + // 默认无法兼容的协议变量名称从小到大排序,如果想自定义私有协议规则,修改这个排序规则即可 + // 如果为了增加协议的安全性,每个版本都可以重新修改协议排序规则,让每个版本的协议都不相同,间接实现加密 + notCompatibleFields.sort((a, b) -> a.getName().compareTo(b.getName())); - var registrationList = new ArrayList(); - for (var field : fields) { - registrationList.add(toRegistration(clazz, field)); + // 可兼容的协议变量默认都添加到最后 + var compatibleFields = compatibleFieldMap.entrySet() + .stream() + .sorted((a, b) -> a.getKey() - b.getKey()) + .map(it -> it.getValue()) + .collect(Collectors.toList()); + notCompatibleFields.addAll(compatibleFields); + return notCompatibleFields; + } + + private static ProtocolRegistration parseProtocolRegistration(Class clazz, ProtocolModule module) { + var protocolId = getProtocolIdByClass(clazz); + // 对象需要被序列化的属性 + var fields = customFieldOrder(clazz); + + try { + var registrationList = new ArrayList(); + for (var field : fields) { + registrationList.add(toRegistration(clazz, field)); + } + + var constructor = clazz.getDeclaredConstructor(); + ReflectionUtils.makeAccessible(constructor); + var protocol = new ProtocolRegistration(); + protocol.setId(protocolId); + protocol.setConstructor(constructor); + protocol.setFields(ArrayUtils.listToArray(fields, Field.class)); + protocol.setFieldRegistrations(ArrayUtils.listToArray(registrationList, IFieldRegistration.class)); + protocol.setModule(module.getId()); + return protocol; + } catch (Exception e) { + throw new RuntimeException(StringUtils.format("解析协议[class:{}]异常", clazz), e); } - - var constructor = clazz.getDeclaredConstructor(); - ReflectionUtils.makeAccessible(constructor); - var protocol = new ProtocolRegistration(); - protocol.setId(protocolId); - protocol.setConstructor(constructor); - protocol.setFields(ArrayUtils.listToArray(fields, Field.class)); - protocol.setFieldRegistrations(ArrayUtils.listToArray(registrationList, IFieldRegistration.class)); - protocol.setModule(module.getId()); - return protocol; } private static IFieldRegistration toRegistration(Class clazz, Field field) { @@ -543,4 +439,108 @@ public class ProtocolAnalysis { return allSubProtocolIdSet; } + // 协议智能语法分析,错误的协议定义将无法启动程序并给出错误警告 + //----------------------------------------------------------------------- + + private static void checkProtocol(Class clazz) throws IllegalAccessException, InvocationTargetException, InstantiationException { + // 是否为一个简单的javabean + ReflectionUtils.assertIsPojoClass(clazz); + // 是否实现了IPacket接口 + AssertionUtils.isTrue(IPacket.class.isAssignableFrom(clazz), "[class:{}]没有实现接口[IPacket:{}]", clazz.getCanonicalName(), IPacket.class.getCanonicalName()); + // 不能是泛型类 + AssertionUtils.isTrue(ArrayUtils.isEmpty(clazz.getTypeParameters()), "[class:{}]不能是泛型类", clazz.getCanonicalName()); + + Field protocolIdField; + try { + protocolIdField = clazz.getDeclaredField(PROTOCOL_ID); + } catch (NoSuchFieldException e) { + throw new UnknownException(e, "[class:{}]没有[{}]协议序列号", clazz.getCanonicalName(), PROTOCOL_ID); + } + + // 是否被public修饰 + AssertionUtils.isTrue(Modifier.isPublic(protocolIdField.getModifiers()), "[class:{}]协议序列号[{}]没有被public修饰", clazz.getCanonicalName(), PROTOCOL_ID); + // 是否被static修饰 + AssertionUtils.isTrue(Modifier.isStatic(protocolIdField.getModifiers()), "[class:{}]协议序列号[{}]没有被static修饰", clazz.getCanonicalName(), PROTOCOL_ID); + // 是否被final修饰 + AssertionUtils.isTrue(Modifier.isFinal(protocolIdField.getModifiers()), "[class:{}]协议序列号[{}]没有被final修饰", clazz.getCanonicalName(), PROTOCOL_ID); + // 是否被transient修饰 + AssertionUtils.isTrue(Modifier.isTransient(protocolIdField.getModifiers()), "[class:{}]协议序列号[{}]没有被transient修饰", clazz.getCanonicalName(), PROTOCOL_ID); + // 命名只能包含字母,数字,下划线 + AssertionUtils.isTrue(clazz.getSimpleName().matches("[a-zA-Z0-9_]*"), "[class:{}]的命名只能包含字母,数字,下划线", clazz.getCanonicalName(), PROTOCOL_ID); + + // 必须要有一个空的构造器 + Constructor constructor = ReflectionUtils.publicEmptyConstructor(clazz); + + ReflectionUtils.makeAccessible(protocolIdField); + IPacket packet = (IPacket) constructor.newInstance(); + var protocolId = (short) protocolIdField.get(null); + // 验证protocol()方法的返回是否和PROTOCOL_ID相等 + AssertionUtils.isTrue(packet.protocolId() == protocolId, "[class:{}]的protocolId返回的值和协议号的静态变量[{}]不相等", clazz.getCanonicalName(), PROTOCOL_ID); + + var previous = protocolClassMap.put(protocolId, clazz); + if (previous != null) { + throw new RunException("[{}][{}]协议号[protocolId:{}]重复", clazz.getCanonicalName(), previous.getCanonicalName(), protocolId); + } + } + + private static void checkSubProtocol(Class clazz, short id, Class subClass) { + var registerProtocolClass = protocolClassMap.get(id); + if (registerProtocolClass == null || !registerProtocolClass.equals(subClass)) { + throw new RunException("协议[{}]的子协议[{}][{}]没有注册", clazz.getCanonicalName(), id, subClass.getCanonicalName()); + } + } + + private static void checkAllModules() { + // 模块id不能重复 + var moduleIdSet = new HashSet(); + Arrays.stream(modules) + .filter(it -> Objects.nonNull(it)) + .peek(it -> AssertionUtils.isTrue(!moduleIdSet.contains(it.getId()), "模块[{}]存在重复的id,模块的id不能重复", it)) + .forEach(it -> moduleIdSet.add(it.getId())); + + // 模块名称不能重复 + var moduleNameSet = new HashSet(); + Arrays.stream(modules) + .filter(it -> Objects.nonNull(it)) + .peek(it -> AssertionUtils.isTrue(!moduleNameSet.contains(it.getName()), "模块[{}]存在重复的name,模块名称不能重复", it)) + .forEach(it -> moduleNameSet.add(it.getName())); + } + + private static void checkAllProtocolClass() { + // 检查协议格式 + + // 协议的名称不能重复 + var allProtocolNameMap = new HashMap>(); + for (var protocolRegistration : protocols) { + if (protocolRegistration == null) { + continue; + } + + var protocolClass = protocolRegistration.protocolConstructor().getDeclaringClass(); + var protocolName = protocolClass.getSimpleName(); + if (allProtocolNameMap.containsKey(protocolName)) { + throw new RunException("[class:{}]和[class:{}]协议名称重复,协议不能含有重复的名称", protocolClass.getCanonicalName(), allProtocolNameMap.get(protocolName).getCanonicalName()); + } + + if (protocolReserved.stream().anyMatch(it -> it.equalsIgnoreCase(protocolName))) { + throw new RunException("协议的名称[class:{}]不能是保留名称[{}]", protocolClass.getCanonicalName(), protocolName); + } + + allProtocolNameMap.put(protocolName, protocolClass); + } + + + // 检查循环协议 + for (var protocolEntry : subProtocolIdMap.entrySet()) { + var protocolId = protocolEntry.getKey(); + var subProtocolSet = protocolEntry.getValue(); + if (subProtocolSet.contains(protocolId)) { + var protocolClass = protocols[protocolId].protocolConstructor().getDeclaringClass(); + throw new RunException("[class:{}]在第一层包含循环引用协议[class:{}]", protocolClass.getSimpleName(), protocolClass.getSimpleName()); + } + + getAllSubProtocolIds(protocolId); + } + } + } diff --git a/protocol/src/main/java/com/zfoo/protocol/registration/ProtocolRegistration.java b/protocol/src/main/java/com/zfoo/protocol/registration/ProtocolRegistration.java index 1c23528b..4c261a93 100644 --- a/protocol/src/main/java/com/zfoo/protocol/registration/ProtocolRegistration.java +++ b/protocol/src/main/java/com/zfoo/protocol/registration/ProtocolRegistration.java @@ -15,6 +15,7 @@ package com.zfoo.protocol.registration; import com.zfoo.protocol.IPacket; import com.zfoo.protocol.buffer.ByteBufUtils; +import com.zfoo.protocol.registration.anno.Compatible; import com.zfoo.protocol.registration.field.IFieldRegistration; import com.zfoo.protocol.serializer.reflect.ISerializer; import com.zfoo.protocol.util.ReflectionUtils; @@ -99,6 +100,10 @@ public class ProtocolRegistration implements IProtocolRegistration { for (int i = 0, length = fields.length; i < length; i++) { Field field = fields[i]; + // 协议向后兼容 + if (field.isAnnotationPresent(Compatible.class) && !buffer.isReadable()) { + break; + } IFieldRegistration packetFieldRegistration = fieldRegistrations[i]; ISerializer serializer = packetFieldRegistration.serializer(); Object fieldValue = serializer.readObject(buffer, packetFieldRegistration); diff --git a/protocol/src/main/java/com/zfoo/protocol/registration/anno/Compatible.java b/protocol/src/main/java/com/zfoo/protocol/registration/anno/Compatible.java new file mode 100644 index 00000000..34be182f --- /dev/null +++ b/protocol/src/main/java/com/zfoo/protocol/registration/anno/Compatible.java @@ -0,0 +1,29 @@ +/* + * Copyright (C) 2020 The zfoo Authors + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and limitations under the License. + * + */ + +package com.zfoo.protocol.registration.anno; + +import java.lang.annotation.*; + +/** + * @author jaysunxiao + * @version 3.0 + */ +@Documented +@Retention(RetentionPolicy.RUNTIME) +@Target({ElementType.FIELD}) +public @interface Compatible { + + int order(); + +} diff --git a/protocol/src/main/java/com/zfoo/protocol/serializer/cpp/GenerateCppUtils.java b/protocol/src/main/java/com/zfoo/protocol/serializer/cpp/GenerateCppUtils.java index 4cc3aea3..ee2476b0 100644 --- a/protocol/src/main/java/com/zfoo/protocol/serializer/cpp/GenerateCppUtils.java +++ b/protocol/src/main/java/com/zfoo/protocol/serializer/cpp/GenerateCppUtils.java @@ -23,6 +23,7 @@ import com.zfoo.protocol.model.Pair; import com.zfoo.protocol.registration.IProtocolRegistration; import com.zfoo.protocol.registration.ProtocolAnalysis; import com.zfoo.protocol.registration.ProtocolRegistration; +import com.zfoo.protocol.registration.anno.Compatible; import com.zfoo.protocol.registration.field.IFieldRegistration; import com.zfoo.protocol.serializer.enhance.EnhanceObjectProtocolSerializer; import com.zfoo.protocol.serializer.reflect.*; @@ -267,18 +268,18 @@ public abstract class GenerateCppUtils { var fields = registration.getFields(); var fieldRegistrations = registration.getFieldRegistrations(); - var csBuilder = new StringBuilder(); + var cppBuilder = new StringBuilder(); for (int i = 0; i < fields.length; i++) { var field = fields[i]; var fieldRegistration = fieldRegistrations[i]; var serializer = cppSerializer(fieldRegistration.serializer()); if (IPacket.class.isAssignableFrom(field.getType())) { - serializer.writeObject(csBuilder, "&message->" + field.getName(), 3, field, fieldRegistration); + serializer.writeObject(cppBuilder, "&message->" + field.getName(), 3, field, fieldRegistration); } else { - serializer.writeObject(csBuilder, "message->" + field.getName(), 3, field, fieldRegistration); + serializer.writeObject(cppBuilder, "message->" + field.getName(), 3, field, fieldRegistration); } } - return csBuilder.toString(); + return cppBuilder.toString(); } @@ -286,21 +287,26 @@ public abstract class GenerateCppUtils { var fields = registration.getFields(); var fieldRegistrations = registration.getFieldRegistrations(); - var csBuilder = new StringBuilder(); + var cppBuilder = new StringBuilder(); for (int i = 0; i < fields.length; i++) { var field = fields[i]; var fieldRegistration = fieldRegistrations[i]; - var readObject = cppSerializer(fieldRegistration.serializer()).readObject(csBuilder, 3, field, fieldRegistration); - csBuilder.append(TAB + TAB + TAB); - if (IPacket.class.isAssignableFrom(field.getType())) { - csBuilder.append(StringUtils.format("packet->{} = *{};", field.getName(), readObject)); - } else { - csBuilder.append(StringUtils.format("packet->{} = {};", field.getName(), readObject)); + + if (field.isAnnotationPresent(Compatible.class)) { + cppBuilder.append(TAB + TAB + TAB).append(StringUtils.format("if (!buffer.isReadable()) { return packet; }")).append(LS); } - csBuilder.append(LS); + var readObject = cppSerializer(fieldRegistration.serializer()).readObject(cppBuilder, 3, field, fieldRegistration); + cppBuilder.append(TAB + TAB + TAB); + if (IPacket.class.isAssignableFrom(field.getType())) { + cppBuilder.append(StringUtils.format("packet->{} = *{};", field.getName(), readObject)); + } else { + cppBuilder.append(StringUtils.format("packet->{} = {};", field.getName(), readObject)); + } + + cppBuilder.append(LS); } - return csBuilder.toString(); + return cppBuilder.toString(); } diff --git a/protocol/src/main/java/com/zfoo/protocol/util/JsonUtils.java b/protocol/src/main/java/com/zfoo/protocol/util/JsonUtils.java index 927e548d..4e06cc91 100644 --- a/protocol/src/main/java/com/zfoo/protocol/util/JsonUtils.java +++ b/protocol/src/main/java/com/zfoo/protocol/util/JsonUtils.java @@ -193,9 +193,9 @@ public abstract class JsonUtils { // 循环遍历子节点下的信息 while (iterator.hasNext()) { var node = iterator.next(); - var filed = node.getKey(); + var field = node.getKey(); var value = node.getValue().asText(); - jsonMap.put(filed, value); + jsonMap.put(field, value); } } return jsonMap; diff --git a/protocol/src/main/resources/cpp/ByteBuffer.h b/protocol/src/main/resources/cpp/ByteBuffer.h index fafe9735..d6d45445 100644 --- a/protocol/src/main/resources/cpp/ByteBuffer.h +++ b/protocol/src/main/resources/cpp/ByteBuffer.h @@ -132,6 +132,10 @@ namespace zfoo { } } + inline bool isReadable() { + return m_writerIndex > m_readerIndex; + } + inline void writeBool(const bool &value) { ensureCapacity(1); int8_t v = value ? 1 : 0; diff --git a/protocol/src/test/cpp/cppProtocol/ByteBuffer.h b/protocol/src/test/cpp/cppProtocol/ByteBuffer.h index fafe9735..d6d45445 100644 --- a/protocol/src/test/cpp/cppProtocol/ByteBuffer.h +++ b/protocol/src/test/cpp/cppProtocol/ByteBuffer.h @@ -132,6 +132,10 @@ namespace zfoo { } } + inline bool isReadable() { + return m_writerIndex > m_readerIndex; + } + inline void writeBool(const bool &value) { ensureCapacity(1); int8_t v = value ? 1 : 0; diff --git a/protocol/src/test/cpp/cppProtocol/Packet/ComplexObject.h b/protocol/src/test/cpp/cppProtocol/Packet/ComplexObject.h index 6a1f23d4..99d16584 100644 --- a/protocol/src/test/cpp/cppProtocol/Packet/ComplexObject.h +++ b/protocol/src/test/cpp/cppProtocol/Packet/ComplexObject.h @@ -69,10 +69,13 @@ namespace zfoo { set> sss; set ssss; set> sssss; + // 如果要修改协议并且兼容老协议,需要加上Compatible注解,按照增加的顺序添加order + int32_t myCompatible; + ObjectA myObject; ~ComplexObject() override = default; - static ComplexObject valueOf(int8_t a, int8_t aa, vector aaa, vector aaaa, int16_t b, int16_t bb, vector bbb, vector bbbb, int32_t c, int32_t cc, vector ccc, vector cccc, int64_t d, int64_t dd, vector ddd, vector dddd, float e, float ee, vector eee, vector eeee, double f, double ff, vector fff, vector ffff, bool g, bool gg, vector ggg, vector gggg, char h, char hh, vector hhh, vector hhhh, string jj, vector jjj, ObjectA kk, vector kkk, list l, list>> ll, list> lll, list llll, list> lllll, map m, map mm, map> mmm, map>, list>>> mmmm, map>, set>> mmmmm, set s, set>> ss, set> sss, set ssss, set> sssss) { + static ComplexObject valueOf(int8_t a, int8_t aa, vector aaa, vector aaaa, int16_t b, int16_t bb, vector bbb, vector bbbb, int32_t c, int32_t cc, vector ccc, vector cccc, int64_t d, int64_t dd, vector ddd, vector dddd, float e, float ee, vector eee, vector eeee, double f, double ff, vector fff, vector ffff, bool g, bool gg, vector ggg, vector gggg, char h, char hh, vector hhh, vector hhhh, string jj, vector jjj, ObjectA kk, vector kkk, list l, list>> ll, list> lll, list llll, list> lllll, map m, map mm, map> mmm, map>, list>>> mmmm, map>, set>> mmmmm, set s, set>> ss, set> sss, set ssss, set> sssss, int32_t myCompatible, ObjectA myObject) { auto packet = ComplexObject(); packet.a = a; packet.aa = aa; @@ -125,6 +128,8 @@ namespace zfoo { packet.sss = sss; packet.ssss = ssss; packet.sssss = sssss; + packet.myCompatible = myCompatible; + packet.myObject = myObject; return packet; } @@ -235,6 +240,10 @@ namespace zfoo { if (_.ssss < ssss) { return false; } if (sssss < _.sssss) { return true; } if (_.sssss < sssss) { return false; } + if (myCompatible < _.myCompatible) { return true; } + if (_.myCompatible < myCompatible) { return false; } + if (myObject < _.myObject) { return true; } + if (_.myObject < myObject) { return false; } return false; } }; @@ -353,6 +362,8 @@ namespace zfoo { for (auto i18 : message->sssss) { buffer.writeIntStringMap(i18); } + buffer.writeInt(message->myCompatible); + buffer.writePacket(&message->myObject, 102); } IPacket *read(ByteBuffer &buffer) override { @@ -547,6 +558,13 @@ namespace zfoo { result119.emplace(map122); } packet->sssss = result119; + if (!buffer.isReadable()) { return packet; } + int32_t result123 = buffer.readInt(); + packet->myCompatible = result123; + if (!buffer.isReadable()) { return packet; } + auto result124 = buffer.readPacket(102); + auto *result125 = (ObjectA *) result124.get(); + packet->myObject = *result125; return packet; } }; diff --git a/protocol/src/test/java/com/zfoo/protocol/packet/ComplexObject.java b/protocol/src/test/java/com/zfoo/protocol/packet/ComplexObject.java index 8bdb3a11..d16516f4 100644 --- a/protocol/src/test/java/com/zfoo/protocol/packet/ComplexObject.java +++ b/protocol/src/test/java/com/zfoo/protocol/packet/ComplexObject.java @@ -14,6 +14,7 @@ package com.zfoo.protocol.packet; import com.zfoo.protocol.IPacket; +import com.zfoo.protocol.registration.anno.Compatible; import java.util.*; @@ -99,6 +100,12 @@ public class ComplexObject implements IPacket { private Set ssss; private Set> sssss; + // 如果要修改协议并且兼容老协议,需要加上Compatible注解,按照增加的顺序添加order + @Compatible(order = 1) + private int myCompatible; + @Compatible(order = 2) + private ObjectA myObject; + @Override public short protocolId() { return PROTOCOL_ID; @@ -512,6 +519,22 @@ public class ComplexObject implements IPacket { this.sssss = sssss; } + public int getMyCompatible() { + return myCompatible; + } + + public void setMyCompatible(int myCompatible) { + this.myCompatible = myCompatible; + } + + public ObjectA getMyObject() { + return myObject; + } + + public void setMyObject(ObjectA myObject) { + this.myObject = myObject; + } + @Override public boolean equals(Object o) { if (this == o) return true;