diff --git a/net/src/main/java/com/zfoo/net/consumer/balancer/ConsistentHashConsumerLoadBalancer.java b/net/src/main/java/com/zfoo/net/consumer/balancer/ConsistentHashConsumerLoadBalancer.java index 89d2f3d2..9595e09c 100644 --- a/net/src/main/java/com/zfoo/net/consumer/balancer/ConsistentHashConsumerLoadBalancer.java +++ b/net/src/main/java/com/zfoo/net/consumer/balancer/ConsistentHashConsumerLoadBalancer.java @@ -24,9 +24,7 @@ import com.zfoo.protocol.registration.ProtocolModule; import com.zfoo.util.math.ConsistentHash; import org.springframework.lang.Nullable; -import java.util.HashSet; -import java.util.Map; -import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicReferenceArray; import java.util.stream.Collectors; /** @@ -42,7 +40,7 @@ public class ConsistentHashConsumerLoadBalancer extends AbstractConsumerLoadBala public static final ConsistentHashConsumerLoadBalancer INSTANCE = new ConsistentHashConsumerLoadBalancer(); private volatile int lastClientSessionChangeId = 0; - private static final Map> consistentHashMap = new ConcurrentHashMap<>(); + private static final AtomicReferenceArray> consistentHashMap = new AtomicReferenceArray<>(ProtocolManager.MAX_MODULE_NUM); private static final int VIRTUAL_NODE_NUMS = 200; private ConsistentHashConsumerLoadBalancer() { @@ -68,17 +66,19 @@ public class ConsistentHashConsumerLoadBalancer extends AbstractConsumerLoadBala // 如果更新时间不匹配,则更新到最新的服务提供者 var currentClientSessionChangeId = NetContext.getSessionManager().getClientSessionChangeId(); if (currentClientSessionChangeId != lastClientSessionChangeId) { - var modules = new HashSet<>(consistentHashMap.keySet()); - - for (var module : modules) { + for (byte i = 0; i < ProtocolManager.MAX_MODULE_NUM; i++) { + var consistentHash = consistentHashMap.get(i); + if (consistentHash == null) { + continue; + } + var module = ProtocolManager.moduleByModuleId(i); updateModuleToConsistentHash(module); } - lastClientSessionChangeId = currentClientSessionChangeId; } var module = ProtocolManager.moduleByProtocolId(packet.protocolId()); - var consistentHash = consistentHashMap.get(module); + var consistentHash = consistentHashMap.get(module.getId()); if (consistentHash == null) { consistentHash = updateModuleToConsistentHash(module); } @@ -87,25 +87,18 @@ public class ConsistentHashConsumerLoadBalancer extends AbstractConsumerLoadBala } var sid = consistentHash.getRealNode(argument).getValue(); return NetContext.getSessionManager().getClientSession(sid); - } @Nullable private ConsistentHash updateModuleToConsistentHash(ProtocolModule module) { - var sessionStringList = getSessionsByModule(module) - .stream() + var sessionStringList = getSessionsByModule(module).stream() .map(session -> new Pair<>(session.getConsumerAttribute().toString(), session.getSid())) .sorted((a, b) -> a.getKey().compareTo(b.getKey())) .collect(Collectors.toList()); - if (CollectionUtils.isEmpty(sessionStringList) && !consistentHashMap.containsKey(module)) { - consistentHashMap.remove(module); - return null; - } - - var consistentHash = new ConsistentHash<>(sessionStringList, VIRTUAL_NODE_NUMS); - consistentHashMap.put(module, consistentHash); + var consistentHash = CollectionUtils.isNotEmpty(sessionStringList) ? new ConsistentHash<>(sessionStringList, VIRTUAL_NODE_NUMS) : null; + consistentHashMap.set(module.getId(), consistentHash); return consistentHash; }