diff --git a/orm/src/main/java/com/zfoo/orm/lpmap/FileChannelMap.java b/orm/src/main/java/com/zfoo/orm/lpmap/FileChannelMap.java index cd3eda11..a35199cc 100644 --- a/orm/src/main/java/com/zfoo/orm/lpmap/FileChannelMap.java +++ b/orm/src/main/java/com/zfoo/orm/lpmap/FileChannelMap.java @@ -64,7 +64,7 @@ public class FileChannelMap implements LpMap, Closeable { this.indexFileRandomAccess = new RandomAccessFile(indexFile, "rw"); this.indexFileChannel = this.indexFileRandomAccess.getChannel(); - var protocolId = ProtocolManager.getProtocolIdByClass(clazz); + var protocolId = ProtocolManager.protocolId(clazz); protocolRegistration = ProtocolManager.getProtocol(protocolId); indexBuffer = ByteBufAllocator.DEFAULT.ioBuffer(16); diff --git a/orm/src/main/java/com/zfoo/orm/lpmap/FileHeapMap.java b/orm/src/main/java/com/zfoo/orm/lpmap/FileHeapMap.java index 2e61f95d..579a6793 100644 --- a/orm/src/main/java/com/zfoo/orm/lpmap/FileHeapMap.java +++ b/orm/src/main/java/com/zfoo/orm/lpmap/FileHeapMap.java @@ -46,7 +46,7 @@ public class FileHeapMap implements LpMap { try { this.dbFile = FileUtils.getOrCreateFile(dbPath, StringUtils.format("{}.db", clazz.getSimpleName())); - var protocolId = ProtocolManager.getProtocolIdByClass(clazz); + var protocolId = ProtocolManager.protocolId(clazz); protocolRegistration = ProtocolManager.getProtocol(protocolId); heapMap = new HeapMap<>(); diff --git a/protocol/src/main/java/com/zfoo/protocol/IPacket.java b/protocol/src/main/java/com/zfoo/protocol/IPacket.java index cf3ddf90..d96d80c3 100644 --- a/protocol/src/main/java/com/zfoo/protocol/IPacket.java +++ b/protocol/src/main/java/com/zfoo/protocol/IPacket.java @@ -31,7 +31,7 @@ public interface IPacket { * @return 协议号Id */ default short protocolId() { - return ProtocolManager.getProtocolIdByClass(this.getClass()); + return ProtocolManager.protocolId(this.getClass()); } } diff --git a/protocol/src/main/java/com/zfoo/protocol/ProtocolManager.java b/protocol/src/main/java/com/zfoo/protocol/ProtocolManager.java index 1a974245..d7750c81 100644 --- a/protocol/src/main/java/com/zfoo/protocol/ProtocolManager.java +++ b/protocol/src/main/java/com/zfoo/protocol/ProtocolManager.java @@ -18,13 +18,10 @@ import com.zfoo.protocol.registration.IProtocolRegistration; import com.zfoo.protocol.registration.ProtocolAnalysis; import com.zfoo.protocol.registration.ProtocolModule; import com.zfoo.protocol.util.AssertionUtils; -import com.zfoo.protocol.util.ReflectionUtils; import com.zfoo.protocol.xml.XmlProtocols; import io.netty.buffer.ByteBuf; -import java.util.Arrays; -import java.util.Objects; -import java.util.Set; +import java.util.*; /** * @author jaysunxiao @@ -39,6 +36,8 @@ public class ProtocolManager { public static final IProtocolRegistration[] protocols = new IProtocolRegistration[MAX_PROTOCOL_NUM]; public static final ProtocolModule[] modules = new ProtocolModule[MAX_MODULE_NUM]; + private static final Map, Short> protocolIdMap = new HashMap<>(); + static { // 初始化默认协议模块 modules[0] = ProtocolModule.DEFAULT_PROTOCOL_MODULE; @@ -86,9 +85,15 @@ public class ProtocolManager { return moduleOptional.get(); } - public static short getProtocolIdByClass(Class clazz) { - var protocolIdField = ReflectionUtils.getFieldByNameInPOJOClass(clazz, PROTOCOL_ID); - return (short) ReflectionUtils.getField(protocolIdField, null); + public static short protocolId(Class clazz) { + var protocolId = protocolIdMap.get(clazz); + if (protocolId == null) { + protocolId = ProtocolAnalysis.getProtocolIdByClass(clazz); + synchronized (protocolIdMap) { + protocolIdMap.put(clazz, protocolId); + } + } + return protocolId; } public static void initProtocol(Set> protocolClassSet) { diff --git a/protocol/src/main/java/com/zfoo/protocol/registration/ProtocolAnalysis.java b/protocol/src/main/java/com/zfoo/protocol/registration/ProtocolAnalysis.java index 8938b352..e0f82304 100644 --- a/protocol/src/main/java/com/zfoo/protocol/registration/ProtocolAnalysis.java +++ b/protocol/src/main/java/com/zfoo/protocol/registration/ProtocolAnalysis.java @@ -435,6 +435,11 @@ public class ProtocolAnalysis { return allSubProtocolIdSet; } + public static short getProtocolIdByClass(Class clazz) { + var protocolIdField = ReflectionUtils.getFieldByNameInPOJOClass(clazz, PROTOCOL_ID); + return (short) ReflectionUtils.getField(protocolIdField, null); + } + // 协议智能语法分析,错误的协议定义将无法启动程序并给出错误警告 //----------------------------------------------------------------------- diff --git a/protocol/src/main/java/com/zfoo/protocol/serializer/protobuf/GenerateProtobufUtils.java b/protocol/src/main/java/com/zfoo/protocol/serializer/protobuf/GenerateProtobufUtils.java index 8bbf3ff6..df000cc2 100644 --- a/protocol/src/main/java/com/zfoo/protocol/serializer/protobuf/GenerateProtobufUtils.java +++ b/protocol/src/main/java/com/zfoo/protocol/serializer/protobuf/GenerateProtobufUtils.java @@ -21,6 +21,7 @@ import com.zfoo.protocol.generate.GenerateOperation; import com.zfoo.protocol.generate.GenerateProtocolDocument; import com.zfoo.protocol.model.Pair; import com.zfoo.protocol.registration.IProtocolRegistration; +import com.zfoo.protocol.registration.ProtocolAnalysis; import com.zfoo.protocol.registration.ProtocolRegistration; import com.zfoo.protocol.registration.field.*; import com.zfoo.protocol.serializer.reflect.*; @@ -121,7 +122,7 @@ public abstract class GenerateProtobufUtils { for (var protos : xmlProtobuf.getProtos()) { for (var protocol : protos.getProtocols()) { var protocolClass = Class.forName(protocol.getLocation()); - var protocolId = ProtocolManager.getProtocolIdByClass(protocolClass); + var protocolId = ProtocolAnalysis.getProtocolIdByClass(protocolClass); var protocolRegistration = ProtocolManager.getProtocol(protocolId); if (allGenerateProtocols.contains(protocolRegistration)) { @@ -188,7 +189,7 @@ public abstract class GenerateProtobufUtils { for (var protocol : protos.getProtocols()) { var protocolClass = Class.forName(protocol.getLocation()); - var protocolId = ProtocolManager.getProtocolIdByClass(protocolClass); + var protocolId = ProtocolAnalysis.getProtocolIdByClass(protocolClass); var protocolRegistration = ProtocolManager.getProtocol(protocolId); var protocolDocument = GenerateProtocolDocument.getProtocolDocument(protocolId);