feat[protocol]: 子协议会自动注册协议号protocolId,如果子协议没有指定protocolId则自动生成protocolId

This commit is contained in:
godotg
2022-10-26 16:38:25 +08:00
parent 248faeec13
commit 6476e0b4bb
5 changed files with 53 additions and 15 deletions
@@ -123,4 +123,11 @@ public class ProtocolManager {
ProtocolAnalysis.analyze(xmlProtocols, generateOperation);
}
/**
* 子协议会自动注册协议号protocolId,如果子协议没有指定protocolId则自动生成protocolId
*/
public static void initProtocolAuto(Set<Class<?>> protocolClassSet, GenerateOperation generateOperation) {
ProtocolAnalysis.analyzeAuto(protocolClassSet, generateOperation);
}
}
@@ -35,6 +35,7 @@ import com.zfoo.protocol.serializer.protobuf.GenerateProtobufUtils;
import com.zfoo.protocol.serializer.reflect.*;
import com.zfoo.protocol.serializer.typescript.GenerateTsUtils;
import com.zfoo.protocol.util.AssertionUtils;
import com.zfoo.protocol.util.ClassUtils;
import com.zfoo.protocol.util.ReflectionUtils;
import com.zfoo.protocol.util.StringUtils;
import com.zfoo.protocol.xml.XmlProtocols;
@@ -118,6 +119,50 @@ public class ProtocolAnalysis {
}
}
public static synchronized void analyzeAuto(Set<Class<?>> protocolClassSet, GenerateOperation generateOperation) {
AssertionUtils.notNull(subProtocolIdMap, "[{}]已经初始完成,请不要重复初始化", ProtocolManager.class.getSimpleName());
try {
// 获取所有协议类
var relevantClassSet = new HashSet<>(protocolClassSet);
for (var clazz : protocolClassSet) {
relevantClassSet.addAll(ClassUtils.relevantClass(clazz));
}
var relevantClassList = relevantClassSet.stream()
.sorted((a, b) -> a.getCanonicalName().compareTo(b.getCanonicalName()))
.collect(Collectors.toList());
// 检查协议类是否合法
var noProtocolIds = new ArrayList<Class<?>>();
for (var protocolClass : relevantClassList) {
var protocolId = getProtocolIdAndCheckClass(protocolClass);
if (protocolId >= 0) {
initProtocolClass(protocolId, protocolClass);
} else {
noProtocolIds.add(protocolClass);
}
}
var countProtocolId = (short) 0;
for (var protocolClass : noProtocolIds) {
while (protocolClassMap.containsKey(countProtocolId)) {
countProtocolId++;
}
initProtocolClass(countProtocolId, protocolClass);
}
// 协议id和协议信息对应起来
for (var protocolClass : relevantClassSet) {
var registration = parseProtocolRegistration(protocolClass, ProtocolModule.DEFAULT_PROTOCOL_MODULE);
protocols[registration.protocolId()] = registration;
}
// 通过指定类注册的协议,全部使用字节码增强
var enhanceList = Arrays.stream(protocols).filter(Objects::nonNull).collect(Collectors.toList());
enhance(generateOperation, enhanceList);
} catch (Exception e) {
throw new RuntimeException(e);
}
}
public static synchronized void analyze(XmlProtocols xmlProtocols, GenerateOperation generateOperation) {
AssertionUtils.notNull(subProtocolIdMap, "[{}]已经初始完成,请不要重复初始化", ProtocolManager.class.getSimpleName());
@@ -320,7 +320,7 @@ public class SpeedTest {
// op.getGenerateLanguages().add(CodeLanguage.Protobuf);
// zfoo协议注册(其实就是:将Set里面的协议号和对应的类注册好,这样子就可以根据协议号知道是反序列化为哪个类)
ProtocolManager.initProtocol(Set.of(ComplexObject.class, NormalObject.class, SimpleObject.class, ObjectA.class, ObjectB.class), op);
ProtocolManager.initProtocolAuto(Set.of(ComplexObject.class, NormalObject.class, SimpleObject.class), op);
for (int i = 0; i < executors.length; i++) {
executors[i] = Executors.newSingleThreadExecutor();
@@ -24,19 +24,12 @@ import java.util.Objects;
*/
public class ObjectA implements IPacket {
public static final transient short PROTOCOL_ID = 102;
private int a;
private Map<Integer, String> m;
private ObjectB objectB;
@Override
public short protocolId() {
return PROTOCOL_ID;
}
public int getA() {
return a;
}
@@ -23,15 +23,8 @@ import java.util.Objects;
*/
public class ObjectB implements IPacket {
public static final transient short PROTOCOL_ID = 103;
private boolean flag;
@Override
public short protocolId() {
return PROTOCOL_ID;
}
public boolean isFlag() {
return flag;
}