diff --git a/net/src/main/java/com/zfoo/net/router/IRouter.java b/net/src/main/java/com/zfoo/net/router/IRouter.java index 623fb319..54b05873 100644 --- a/net/src/main/java/com/zfoo/net/router/IRouter.java +++ b/net/src/main/java/com/zfoo/net/router/IRouter.java @@ -52,8 +52,7 @@ public interface IRouter { * @param argument 参数,主要用来计算一致性hashId。 * 1.IConsumer会使用这个参数计算负载到哪个服务提供者; * 2.服务提供者收到请求过后会使用这个参数来计算再哪个线程执行任务; - * 3.如果是异步请求,消费者收到消息过后会通过这个参数计算再哪个线程执行回调。 - * 综上所述,这个参数会在上面三种情况使用。 + * 综上所述,这个参数会在上面两种情况使用。 * @return 服务器返回的消息Response * @throws Exception 如果超时或者其它异常 */ diff --git a/net/src/main/java/com/zfoo/net/router/Router.java b/net/src/main/java/com/zfoo/net/router/Router.java index 0fc25947..026196a9 100644 --- a/net/src/main/java/com/zfoo/net/router/Router.java +++ b/net/src/main/java/com/zfoo/net/router/Router.java @@ -303,7 +303,7 @@ public class Router implements IRouter { } } - }, TaskBus.executor(executorConsistentHash)); + }, TaskBus.currentThreadExecutor()); SignalBridge.addSignalAttachment(clientSignalAttachment); diff --git a/net/src/main/java/com/zfoo/net/task/TaskBus.java b/net/src/main/java/com/zfoo/net/task/TaskBus.java index 069869dd..26ee579d 100644 --- a/net/src/main/java/com/zfoo/net/task/TaskBus.java +++ b/net/src/main/java/com/zfoo/net/task/TaskBus.java @@ -17,12 +17,19 @@ import com.zfoo.net.NetContext; import com.zfoo.net.task.dispatcher.AbstractTaskDispatch; import com.zfoo.net.task.dispatcher.ITaskDispatch; import com.zfoo.net.task.model.PacketReceiverTask; +import com.zfoo.protocol.util.AssertionUtils; import com.zfoo.protocol.util.StringUtils; +import com.zfoo.util.math.RandomUtils; +import io.netty.util.concurrent.FastThreadLocalThread; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; +import java.util.concurrent.ThreadFactory; +import java.util.concurrent.atomic.AtomicInteger; /** * @author jaysunxiao @@ -43,6 +50,8 @@ public final class TaskBus { */ private static final ExecutorService[] executors; + private static final Map threadMap = new ConcurrentHashMap<>(); + static { var localConfig = NetContext.getConfigManager().getLocalConfig(); var providerConfig = localConfig.getProvider(); @@ -55,11 +64,40 @@ public final class TaskBus { executors = new ExecutorService[EXECUTOR_SIZE]; for (int i = 0; i < executors.length; i++) { - var namedThreadFactory = new TaskThreadFactory(); - executors[i] = Executors.newSingleThreadExecutor(namedThreadFactory); + var namedThreadFactory = new TaskThreadFactory(i + 1); + var executor = Executors.newSingleThreadExecutor(namedThreadFactory); + namedThreadFactory.executor = executor; + executors[i] = executor; } } + public static class TaskThreadFactory implements ThreadFactory { + + public ExecutorService executor; + + private final int poolNumber; + private final AtomicInteger threadNumber = new AtomicInteger(1); + private final ThreadGroup group; + + public TaskThreadFactory(int poolNumber) { + var s = System.getSecurityManager(); + group = (s != null) ? s.getThreadGroup() : Thread.currentThread().getThreadGroup(); + this.poolNumber = poolNumber; + } + + @Override + public Thread newThread(Runnable runnable) { + var threadName = StringUtils.format("task-p{}-t{}", poolNumber, threadNumber.getAndIncrement()); + var t = new FastThreadLocalThread(group, runnable, threadName, 0); + t.setDaemon(false); + t.setPriority(Thread.NORM_PRIORITY); + t.setUncaughtExceptionHandler((thread, e) -> logger.error(thread.toString(), e)); + AssertionUtils.notNull(executor); + var threadId = t.getId(); + threadMap.put(threadId, executor); + return t; + } + } /** * Actor模型,最主要的就是线程模型,Actor模型保证了某个Actor所代表的任务永远不会同时在两条线程同时处理任务,这就就避免了并发。 @@ -87,4 +125,15 @@ public final class TaskBus { public static ExecutorService executor(int executorConsistentHash) { return executors[Math.abs(executorConsistentHash % EXECUTOR_SIZE)]; } + + // 在task线程的异步请求,请求成功过后依然在相同的task线程执行回调任务 + public static ExecutorService currentThreadExecutor() { + var threadId = Thread.currentThread().getId(); + var executor = threadMap.get(threadId); + if (executor == null) { + return executor(RandomUtils.randomInt()); + } + return executor; + } + } diff --git a/net/src/main/java/com/zfoo/net/task/TaskThreadFactory.java b/net/src/main/java/com/zfoo/net/task/TaskThreadFactory.java deleted file mode 100644 index 9b5f71cc..00000000 --- a/net/src/main/java/com/zfoo/net/task/TaskThreadFactory.java +++ /dev/null @@ -1,51 +0,0 @@ -/* - * Copyright (C) 2020 The zfoo Authors - * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except - * in compliance with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed - * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and limitations under the License. - * - */ - -package com.zfoo.net.task; - -import io.netty.util.concurrent.FastThreadLocalThread; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.util.concurrent.ThreadFactory; -import java.util.concurrent.atomic.AtomicInteger; - -/** - * @author jaysunxiao - * @version 3.0 - */ -public class TaskThreadFactory implements ThreadFactory { - - private static final Logger logger = LoggerFactory.getLogger(TaskThreadFactory.class); - - private static final AtomicInteger poolNumber = new AtomicInteger(1); - private final ThreadGroup group; - private final AtomicInteger threadNumber = new AtomicInteger(1); - private final String namePrefix; - - TaskThreadFactory() { - var s = System.getSecurityManager(); - group = (s != null) ? s.getThreadGroup() : Thread.currentThread().getThreadGroup(); - namePrefix = "task-p" + poolNumber.getAndIncrement() + "-t"; - } - - @Override - public Thread newThread(Runnable runnable) { - var t = new FastThreadLocalThread(group, runnable, namePrefix + threadNumber.getAndIncrement(), 0); - t.setDaemon(false); - t.setPriority(Thread.NORM_PRIORITY); - t.setUncaughtExceptionHandler((thread, e) -> logger.error(thread.toString(), e)); - return t; - } - -}