RPC小记
源码参考:https://github.com/Snailclimb/guide-rpc-framework
参考文档:https://www.yuque.com/books/share/b7a2512c-6f7a-4afe-9d7e-5936b4c4cab0
动态代理 首先看客户端的代码
1 2 3 4 5 6 7 8 @RpcScan(basePackage = {"com.richcoder"}) public class ClientMain { public static void main (String[] args) throws InterruptedException { AnnotationConfigApplicationContext applicationContext = new AnnotationConfigApplicationContext (ClientMain.class); HelloController controller = (HelloController) applicationContext.getBean("helloController" ); controller.test(); } }
在这里我们调用了 controller.test();
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 @Component public class HelloController { @RpcReference(version = "version1", group = "test1") private HelloService helloService; public void test () throws InterruptedException { String hello = this .helloService.hello(new Hello ("111" , "222" )); assert "Hello description is 222" .equals(hello); Thread.sleep(12000 ); for (int i = 0 ; i < 10 ; i++) { System.out.println(helloService.hello(new Hello ("111" , "222" ))); } } }
可以看出我们客户端是直接调用了 helloService.hello(),并没有在客户端的代码内看到将调用方法转换为网络传输对象
这些隐藏的操作都是由动态代理 来完成的,接下来看看如何实现动态代理
首先我们需要一个客户端的代理类,为我们实现将调用方法转换成网络传输对象
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 @Slf4j public class RpcClientProxy implements InvocationHandler { private static final String INTERFACE_NAME = "interfaceName" ; private final RpcRequestTransport transport; private final RpcServiceConfig config; public RpcClientProxy (RpcRequestTransport transport, RpcServiceConfig config) { this .transport = transport; this .config = config; } public <T> T getProxy (Class<T> clazz) { return (T) Proxy.newProxyInstance(clazz.getClassLoader(), new Class []{clazz}, this ); } @Override public Object invoke (Object proxy, Method method, Object[] args) throws Throwable { log.info("invoke method: [{}]" , method.getName()); RpcRequest request = RpcRequest.builder() .methodName(method.getName()) .interfaceName(method.getDeclaringClass().getName()) .paramTypes(method.getParameterTypes()) .parameters(args) .requestId(UUID.randomUUID().toString()) .group(config.getGroup()) .version(config.getVersion()) .build(); CompletableFuture<RpcResponse<Object>> future = (CompletableFuture<RpcResponse<Object>>) transport.sendRequest(request); RpcResponse<Object> response = future.get(); this .check(response, request); return response.getData(); } private void check (RpcResponse<Object> rpcResponse, RpcRequest rpcRequest) { if (rpcResponse == null ) { throw new RpcException (RpcErrorMessageEnum.SERVICE_INVOCATION_FAILURE, INTERFACE_NAME + ":" + rpcRequest.getInterfaceName()); } if (!rpcRequest.getRequestId().equals(rpcResponse.getRequestId())) { throw new RpcException (RpcErrorMessageEnum.REQUEST_NOT_MATCH_RESPONSE, INTERFACE_NAME + ":" + rpcRequest.getInterfaceName()); } if (rpcResponse.getCode() == null || !rpcResponse.getCode().equals(RpcResponseCodeEnum.SUCCESS.getCode())) { throw new RpcException (RpcErrorMessageEnum.SERVICE_INVOCATION_FAILURE, INTERFACE_NAME + ":" + rpcRequest.getInterfaceName()); } } }
我们采用一个类来实现 InvocationHandler 接口,完成自定义调用处理器,如果某个代理类将其指定为处理器,则该代理类调用方法时会自动将方法分配到 #invoke() 来调用
我们看 #invoke() 方法
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 @Override public Object invoke (Object proxy, Method method, Object[] args) throws Throwable { log.info("invoke method: [{}]" , method.getName()); RpcRequest request = RpcRequest.builder() .methodName(method.getName()) .interfaceName(method.getDeclaringClass().getName()) .paramTypes(method.getParameterTypes()) .parameters(args) .requestId(UUID.randomUUID().toString()) .group(config.getGroup()) .version(config.getVersion()) .build(); CompletableFuture<RpcResponse<Object>> future = (CompletableFuture<RpcResponse<Object>>) transport.sendRequest(request); RpcResponse<Object> response = future.get(); this .check(response, request); return response.getData(); }
当方法调用时,#invoke() 方法会根据调用方法将其转换为网络传输对象并发送给服务端
当服务端调用成功后,返回结果,并将结果返回给客户端
以上就是一个代理类做的实际工作,那么为什么 HelloController 中的 HelloService 的实例对象是一个代理对象呢 ?
我们看 HelloController 有一个 @Component 注解,接下来的操作是在容器初始化过程 完成的
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 @Slf4j @Component public class SpringBeanPostProcessor implements BeanPostProcessor { private final RpcRequestTransport rpcClient; public SpringBeanPostProcessor () { this .rpcClient = ExtensionLoader.getExtensionLoader(RpcRequestTransport.class).getExtension("netty" ); } @Override public Object postProcessBeforeInitialization (Object bean, String beanName) throws BeansException { if (bean.getClass().isAnnotationPresent(RpcService.class)) { log.info("[{}] is annotated with [{}]" , bean.getClass().getName(), RpcService.class.getCanonicalName()); RpcService rpcService = bean.getClass().getAnnotation(RpcService.class); RpcServiceConfig.builder() .version(rpcService.version()) .group(rpcService.group()) .service(bean) .build(); } return bean; } @Override public Object postProcessAfterInitialization (Object bean, String beanName) throws BeansException { Class<?> clazz = bean.getClass(); Field[] fields = clazz.getDeclaredFields(); for (Field field : fields) { RpcReference rpcReference = field.getAnnotation(RpcReference.class); if (rpcReference != null ) { RpcServiceConfig config = RpcServiceConfig.builder() .version(rpcReference.version()) .group(rpcReference.group()) .build(); RpcClientProxy rpcClientProxy = new RpcClientProxy (rpcClient, config); Object clientProxy = rpcClientProxy.getProxy(field.getType()); field.setAccessible(true ); try { field.set(bean, clientProxy); } catch (IllegalAccessException e) { e.printStackTrace(); } } } return bean; } }
当我们用一个类来实现 BeanPostProcessor 接口时,能实现 IOC 容器功能的扩展,能在 bean 实例化过程中实现前置处理和后置处理,分别对应下图中的两个部分
我们首先看一下前置处理做了什么
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 @Override public Object postProcessBeforeInitialization (Object bean, String beanName) throws BeansException { if (bean.getClass().isAnnotationPresent(RpcService.class)) { log.info("[{}] is annotated with [{}]" , bean.getClass().getName(), RpcService.class.getCanonicalName()); RpcService rpcService = bean.getClass().getAnnotation(RpcService.class); RpcServiceConfig.builder() .version(rpcService.version()) .group(rpcService.group()) .service(bean) .build(); } return bean; }
判断是否包含注解 @RpcService
根据注解的内容来指定服务的版本和组
根据服务的版本和组将该服务注册到 ZooKeeper
其中 @RpcService 如下
1 2 3 4 5 6 7 8 9 @Documented @Retention(RetentionPolicy.RUNTIME) @Target({ElementType.TYPE}) @Inherited public @interface RpcService { String version () default "" ; String group () default "" ; }
综上,如果我们有某个类被 @RpcService 注解标记,则该类会在实例化时,将自身注册到服务中心
接下来我们看看后置处理做了什么
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 public Object postProcessAfterInitialization (Object bean, String beanName) throws BeansException { Class<?> clazz = bean.getClass(); Field[] fields = clazz.getDeclaredFields(); for (Field field : fields) { RpcReference rpcReference = field.getAnnotation(RpcReference.class); if (rpcReference != null ) { RpcServiceConfig config = RpcServiceConfig.builder() .version(rpcReference.version()) .group(rpcReference.group()) .build(); RpcClientProxy rpcClientProxy = new RpcClientProxy (rpcClient, config); Object proxy = rpcClientProxy.getProxy(field.getType()); field.setAccessible(true ); try { field.set(bean, proxy); } catch (IllegalAccessException e) { e.printStackTrace(); } } } return bean; }
获取 bean 实例对象的所有属性,并判断属性是否包含注解 @RpcReference
根据注解的内容来指定需要调用服务的版本和组
获取该属性的类,并实例化一个代理对象
将这个属性的值设置为代理对象
其中,@RpcReference 如下
1 2 3 4 5 6 7 8 9 10 11 @Documented @Retention(RetentionPolicy.RUNTIME) @Target({ElementType.FIELD}) @Inherited public @interface RpcReference { String version () default "" ; String group () default "" ; }
其实关键在于以下几个部分
1 2 3 4 5 6 RpcClientProxy rpcClientProxy = new RpcClientProxy (rpcClient, config);Object proxy = rpcClientProxy.getProxy(field.getType()); field.setAccessible(true ); field.set(bean, proxy);
我们实例化了一个 RpcClientProxy 对象,并调用了其 #getProxy() 方法返回了一个 field 的代理对象
1 2 3 4 public <T> T getProxy (Class<T> clazz) { return (T) Proxy.newProxyInstance(clazz.getClassLoader(), new Class []{clazz}, this ); }
这里传入的参数分别是 需要代理的类的加载器、需要代理的类和一个 InvocationHandler 的实现
回顾之前的代码,可以看出这个实现指的就是我们的 RpcClientProxy 自定义调用处理器
最后,指定 field 的值为代理对象
通过以上步骤,我们可以看出 HelloService 的实现类此时被指定为一个代理对象了,而调用这个代理对象的方法时,会被转发到 RpcClientProxy#invoke()
自定义注解
参考:https://juejin.cn/post/6844903942808141832
之前我们提到了 @RpcService 和 @RpcReference 自定义注解,那么 Spring 是如何识别这些注解的呢?
我们通过 @RpcScan 这个注解来扫描 com.richcoder 包下的内容
1 2 3 4 5 6 7 8 9 @Target({ElementType.TYPE, ElementType.METHOD}) @Retention(RetentionPolicy.RUNTIME) @Import(CustomScannerRegistrar.class) @Documented public @interface RpcScan { String[] basePackage(); }
其实现逻辑主要是在 CustomScannerRegistrar
CustomScannerRegistrar 实现了接口 ImportBeanDefinitionRegistrar 接口,Spring 容器启动时,会扫描所有实现该接口的类,并执行其 #registerBeanDefinitions() 方法,生成 BeanDefinition 对象,为后续实例化做准备
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 public class CustomScannerRegistrar implements ImportBeanDefinitionRegistrar , ResourceLoaderAware { private static final String SPRING_BEAN_BASE_PACKAGE = "com.richcoder" ; private static final String BASE_PACKAGE_ATTRIBUTE_NAME = "basePackage" ; private ResourceLoader resourceLoader; @Override public void setResourceLoader (ResourceLoader resourceLoader) { this .resourceLoader = resourceLoader; } @Override public void registerBeanDefinitions (AnnotationMetadata annotationMetadata, BeanDefinitionRegistry registry) { AnnotationAttributes rpcScanAnnotationAttributes = AnnotationAttributes.fromMap(annotationMetadata.getAnnotationAttributes(RpcScan.class.getName())); String[] rpcBasePackages = new String [0 ]; if (rpcScanAnnotationAttributes != null ) { rpcBasePackages = rpcScanAnnotationAttributes.getStringArray(BASE_PACKAGE_ATTRIBUTE_NAME); } if (rpcBasePackages.length == 0 ) { rpcBasePackages = new String []{((StandardAnnotationMetadata) annotationMetadata).getIntrospectedClass().getPackage().getName()}; } CustomScanner rpcServiceScanner = new CustomScanner (registry, RpcService.class); CustomScanner springBeanScanner = new CustomScanner (registry, Component.class); if (resourceLoader != null ) { rpcServiceScanner.setResourceLoader(resourceLoader); springBeanScanner.setResourceLoader(resourceLoader); } int springBeanAmount = springBeanScanner.scan(SPRING_BEAN_BASE_PACKAGE); log.info("springBeanScanner扫描的数量 [{}]" , springBeanAmount); int rpcServiceCount = rpcServiceScanner.scan(rpcBasePackages); log.info("rpcServiceScanner扫描的数量 [{}]" , rpcServiceCount); } }
其中,CustomScanner 如下,在构造器中指定过滤器,过滤掉那些不是 @RpcService 注解标记的内容
1 2 3 4 5 6 7 8 9 10 11 12 public class CustomScanner extends ClassPathBeanDefinitionScanner { public CustomScanner (BeanDefinitionRegistry registry, Class<? extends Annotation> annotationType) { super (registry); super .addIncludeFilter(new AnnotationTypeFilter (annotationType)); } @Override public int scan (String... basePackages) { return super .scan(basePackages); } }
服务中心 服务注册 服务器启动时,会把标有注解 @RpcService 的类发布到 ZooKeeper 上,看看服务端是如何把服务发布到 Zookeeper 上的
首先还是看之前 SpringBeanPostProcessor#postProcessBeforeInitialization() 方法,在这里开始的服务注册
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 private final ServiceProvider serviceProvider;private final RpcRequestTransport rpcClient;public SpringBeanPostProcessor () { this .serviceProvider = SingletonFactory.getInstance(ZkServiceProviderImpl.class); this .rpcClient = ExtensionLoader.getExtensionLoader(RpcRequestTransport.class).getExtension("netty" ); } public Object postProcessBeforeInitialization (Object bean, String beanName) throws BeansException { if (bean.getClass().isAnnotationPresent(RpcService.class)) { log.info("[{}] is annotated with [{}]" , bean.getClass().getName(), RpcService.class.getCanonicalName()); RpcService rpcService = bean.getClass().getAnnotation(RpcService.class); RpcServiceConfig config = RpcServiceConfig.builder() .version(rpcService.version()) .group(rpcService.group()) .service(bean) .build(); serviceProvider.publishService(config); } return bean; }
ServiceProvider 是一个接口,我们主要采用其实现类 ZkServiceProvider 来实现服务注册
接下来我们看看 ZkServiceProvider#publishService() 方法
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 private final Map<String, Object> serviceMap;private final Set<String> registeredService;private final ServiceRegistry serviceRegistry;public ZkServiceProviderImpl () { serviceMap = new ConcurrentHashMap <>(); registeredService = ConcurrentHashMap.newKeySet(); serviceRegistry = ExtensionLoader.getExtensionLoader(ServiceRegistry.class).getExtension("zk" ); } public void publishService (RpcServiceConfig rpcServiceConfig) { try { String host = InetAddress.getLocalHost().getHostAddress(); this .addService(rpcServiceConfig); serviceRegistry.registerService(rpcServiceConfig.getRpcServiceName(), new InetSocketAddress (host, RpcServer.PORT)); } catch (UnknownHostException e) { log.error("occur exception when getHostAddress" , e); } } public void addService (RpcServiceConfig rpcServiceConfig) { String rpcServiceName = rpcServiceConfig.getRpcServiceName(); if (registeredService.contains(rpcServiceName)) { return ; } registeredService.add(rpcServiceName); serviceMap.put(rpcServiceName, rpcServiceConfig.getService()); log.info("Add service: {} and interfaces:{}" , rpcServiceName, rpcServiceConfig.getService().getClass().getInterfaces()); }
可以看出 publishService 做了以下的事:
添加服务
检查该服务是否已经被添加
将服务添加到 serviceMap 中,其中 key 为服务名, value 为服务的实现类
1 2 3 public String getRpcServiceName () { return this .getServiceName() + this .getGroup() + this .getVersion(); }
注册服务
添加服务并没有用到 ZooKeeper,ZooKeeper 用于注册服务
ServiceRegistry 是一个用于服务注册的接口,我们看其实现类 ZkServiceRegistryImpl
1 2 3 4 5 6 7 8 9 10 11 public class ZkServiceRegistryImpl implements ServiceRegistry { @Override public void registerService (String rpcServiceName, InetSocketAddress inetSocketAddress) { String servicePath = CuratorUtils.ZK_REGISTER_ROOT_PATH + "/" + rpcServiceName + inetSocketAddress.toString(); CuratorFramework zkClient = CuratorUtils.getZkClient(); CuratorUtils.createPersistentNode(zkClient, servicePath); } }
可以看到,我们在这里实现了服务的注册
接下来我们具体看看服务注册是如何实现的
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 private static final int BASE_SLEEP_TIME = 1000 ;private static final int MAX_RETRIES = 3 ;private static final String DEFAULT_ZOOKEEPER_ADDRESS = "127.0.0.1:2181" ;private static CuratorFramework zkClient;public static final String ZK_REGISTER_ROOT_PATH = "/rich-rpc" ;private static final Map<String, List<String>> SERVICE_ADDRESS_MAP = new ConcurrentHashMap <>();private static final Set<String> REGISTERED_PATH_SET = ConcurrentHashMap.newKeySet();private CuratorUtils () {} public static CuratorFramework getZkClient () { Properties properties = PropertiesFileUtil.readPropertiesFile(RpcConfigEnum.RPC_CONFIG_PATH.getPropertyValue()); String zookeeperAddress = properties != null && properties.getProperty(RpcConfigEnum.ZK_ADDRESS.getPropertyValue()) != null ? properties.getProperty(RpcConfigEnum.ZK_ADDRESS.getPropertyValue()) : DEFAULT_ZOOKEEPER_ADDRESS; if (zkClient != null && zkClient.getState() == CuratorFrameworkState.STARTED) { return zkClient; } RetryPolicy retryPolicy = new ExponentialBackoffRetry (BASE_SLEEP_TIME, MAX_RETRIES); zkClient = CuratorFrameworkFactory.builder() .connectString(zookeeperAddress) .retryPolicy(retryPolicy) .build(); zkClient.start(); try { if (!zkClient.blockUntilConnected(30 , TimeUnit.SECONDS)) { throw new RuntimeException ("Time out waiting to connect to ZK!" ); } } catch (InterruptedException e) { e.printStackTrace(); } return zkClient; } public static void createPersistentNode (CuratorFramework zkClient, String path) { try { if (REGISTERED_PATH_SET.contains(path) || zkClient.checkExists().forPath(path) != null ) { log.info("The node already exists. The node is:[{}]" , path); } else { zkClient.create().creatingParentsIfNeeded().withMode(CreateMode.PERSISTENT).forPath(path); log.info("The node was created successfully. The node is:[{}]" , path); } REGISTERED_PATH_SET.add(path); } catch (Exception e) { log.error("create persistent node for path [{}] fail" , path); } }
我们先看 #getZkClient() 做了什么
读取一个文件,这里 RpcConfigEnum.RPC_CONFIG_PATH 指定为 rpc.properties,即读取 resources 目录下的文件名为 rpc.properties 的文件
如果文件存在且文件内属性名为 RpcConfigEnum.ZK_ADDRESS 的属性值不为空,则根据文件内的 Zookeeper 地址进行设置,否则设置为默认地址,这里的 RpcConfigEnum.ZK_ADDRESS 为 rpc.zookeeper.address
如果 ZooKeeper 客户端已经启动且运行,则直接返回
设置重试策略并创建和启动客户端
至此,我们已经可以成功的启动 ZooKeeper 客户端,接下来看看如何将服务注册到 ZooKeeper 客户端内
判断我们是否已经注册过该服务
注册服务,在 ZooKeeper 客户端内添加一个永久节点
永久节点路径 :根路径 / 服务名(接口实现类名+组号+版本号) / 服务器ip+端口号
往 ZooKeeper 中添加一个永久节点意味着当 ZooKeeper 客户端关闭重启后,其节点依旧存在
到此,我们已经可以看到服务可以成功的注册到 ZooKeeper 上了
那么,如果某台服务器需要下线某个服务重启,该怎么办呢?
如果不进行操作,由于先前注册是一个永久节点,意味着 ZooKeeper 上会一直存在这个服务的服务路径,当客户端选择该路径连接服务端时,会发现服务器不提供该服务
因此,我们每次重启服务端时,需要将其原本注册的服务清除,重新注册
我们看 RpcServer 是怎么做的
1 2 3 4 5 public void start () { CustomShutdownHook.getCustomShutdownHook().clearAll(); }
通过调用一个工具类,在每次启动服务器之前清除 ZooKeeper 上的服务
接下来看工具类做了什么
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 public class CustomShutdownHook { private static final CustomShutdownHook CUSTOM_SHUTDOWN_HOOK = new CustomShutdownHook (); public static CustomShutdownHook getCustomShutdownHook () { return CUSTOM_SHUTDOWN_HOOK; } public void clearAll () { log.info("addShutdownHook for clearAll" ); Runtime.getRuntime().addShutdownHook(new Thread (() -> { try { InetSocketAddress inetSocketAddress = new InetSocketAddress (InetAddress.getLocalHost().getHostAddress(), RpcServer.PORT); CuratorUtils.clearRegistry(CuratorUtils.getZkClient(), inetSocketAddress); } catch (UnknownHostException ignored) { } ThreadPoolFactoryUtil.shutDownAllThreadPool(); })); } }
这里我们通过一个新线程将之前注册的服务清除
接下来我们看看清除服务的具体实现
1 2 3 4 5 6 7 8 9 10 11 12 public static void clearRegistry (CuratorFramework zkClient, InetSocketAddress address) { REGISTERED_PATH_SET.stream().parallel().forEach(p -> { try { if (p.endsWith(address.toString())) { zkClient.delete().forPath(p); } } catch (Exception e) { log.error("clear registry for path [{}] fail" , p); } }); log.info("All registered services on the server are cleared:[{}]" , REGISTERED_PATH_SET.toString()); }
遍历 REGISTERED_PATH_SET 集合,查看集合中哪一项是以当前 ip + 端口为结尾的,清除该节点
举个例子,之前我们往 ZooKeeper 中注册有两个服务,其节点分别为 /rich-rpc/com.richcoder.HelloServicegroup1version1/127.0.0.1:9999 和 /rich-rpc/com.richcoder.HelloService2group2version2/127.0.0.1:9999,此时我们下线了 HelloServicegroup1version1 并重启服务器,由于服务器的 ip + 端口是固定的,启动时便会删除以ip + 端口为结尾的路径的节点,即将两个节点都删除了,然后重新注册服务,此时便只有 HelloServicegroup2version2 了
服务发现 客户端发送请求时,需要知道向哪个服务器发送请求,此时就需要在 ZooKeeper 注册中心中获取该服务支持的服务器地址
在代理类 RpcClientProxy 中我们可以看到调用服务时,通过 rpcRequestTransport.sendRpcRequest(rpcRequest); 来发送服务请求
那我们就看看 RpcRequestTransport 的实现类 NettyRpcClient
1 2 3 4 5 6 7 8 9 10 11 private final ServiceDiscovery serviceDiscovery;@Override public Object sendRpcRequest (RpcRequest rpcRequest) { CompletableFuture<RpcResponse<Object>> resultFuture = new CompletableFuture <>(); InetSocketAddress inetSocketAddress = serviceDiscovery.lookupService(rpcRequest); return resultFuture; }
可以看到我们通过 ServiceDiscovery#lookupService() 来确定我们要向哪个服务器发送请求
我们看其实现类 ZkServiceDiscoveryImpl
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 @Slf4j public class ZkServiceDiscoveryImpl implements ServiceDiscovery { private final LoadBalance loadBalance; public ZkServiceDiscoveryImpl () { this .loadBalance = ExtensionLoader.getExtensionLoader(LoadBalance.class).getExtension("loadBalance" ); } @Override public InetSocketAddress lookupService (RpcRequest rpcRequest) { String rpcServiceName = rpcRequest.getRpcServiceName(); CuratorFramework zkClient = CuratorUtils.getZkClient(); List<String> serviceUrlList = CuratorUtils.getChildrenNodes(zkClient, rpcServiceName); if (CollectionUtil.isEmpty(serviceUrlList)) { throw new RpcException (RpcErrorMessageEnum.SERVICE_CAN_NOT_BE_FOUND, rpcServiceName); } String targetServiceUrl = loadBalance.selectServiceAddress(serviceUrlList, rpcRequest); log.info("Successfully found the service address:[{}]" , targetServiceUrl); String[] socketAddressArray = targetServiceUrl.split(":" ); String host = socketAddressArray[0 ]; int port = Integer.parseInt(socketAddressArray[1 ]); return new InetSocketAddress (host, port); } }
可以看出 ZkServiceDiscoveryImpl 主要做了两件事
获取提供服务的地址集合
根据负载均衡策略选择出一个服务器地址
负载均衡的内容之后了解,在这里我们看看如何实现获取提供服务的地址集合
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 public static List<String> getChildrenNodes (CuratorFramework zkClient, String rpcServiceName) { if (SERVICE_ADDRESS_MAP.containsKey(rpcServiceName)) { return SERVICE_ADDRESS_MAP.get(rpcServiceName); } List<String> result = null ; String servicePath = ZK_REGISTER_ROOT_PATH + "/" + rpcServiceName; try { result = zkClient.getChildren().forPath(servicePath); SERVICE_ADDRESS_MAP.put(rpcServiceName, result); registerWatcher(rpcServiceName, zkClient); } catch (Exception e) { log.error("get children nodes for path [{}] fail" , servicePath); } return result; } private static void registerWatcher (String rpcServiceName, CuratorFramework zkClient) throws Exception { String servicePath = ZK_REGISTER_ROOT_PATH + "/" + rpcServiceName; PathChildrenCache pathChildrenCache = new PathChildrenCache (zkClient, servicePath, true ); PathChildrenCacheListener pathChildrenCacheListener = (curatorFramework, pathChildrenCacheEvent) -> { List<String> serviceAddresses = curatorFramework.getChildren().forPath(servicePath); SERVICE_ADDRESS_MAP.put(rpcServiceName, serviceAddresses); }; pathChildrenCache.getListenable().addListener(pathChildrenCacheListener); pathChildrenCache.start(); }
查看之前是否已经获取过该服务的地址列表
获取服务的子节点,这里服务的节点是**根路径 / 服务名(接口实现类名+组号+版本号)**,并将子节点保存到 map 中
监听该节点的变化
监听节点的参考:https://www.cnblogs.com/weihuang6620/p/10821800.html
PathChildrenCache 有以下几个特点
永久监听指定节点下的节点
只能监听指定节点下一级节点的变化,比如说指定节点 ”/example” , 在下面添加 ”node1” 可以监听到,但是添加 ”node1/n1” 就不能被监听到了
可以监听到的事件:节点创建、节点数据的变化、节点删除等
当我们监听到节点的子节点发生变化时,如 /rich-rpc/com.richcoder.HelloServicegroup1version1 本来有两个提供服务的服务器,地址分别为 A 和 B,当服务器 A 下线了这个服务,通过之前的分析我们可以看出服务器 A 重启后,该服务的子节点下就没有服务器 A 的地址了,此时监听到此事件后,会更新提供该服务的地址列表
负载均衡 最后我们看以下负载均衡策略,客户端会根据负载均衡策略来选择一个合适的服务端连接
1 2 3 public interface LoadBalance { String selectServiceAddress (List<String> serviceUrlList, RpcRequest rpcRequest) ; }
看其实现类 AbstractLoadBalance
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 public abstract class AbstractLoadBalance implements LoadBalance { @Override public String selectServiceAddress (List<String> serviceAddresses, RpcRequest rpcRequest) { if (CollectionUtil.isEmpty(serviceAddresses)) { return null ; } if (serviceAddresses.size() == 1 ) { return serviceAddresses.get(0 ); } return doSelect(serviceAddresses, rpcRequest); } protected abstract String doSelect (List<String> serviceAddresses, RpcRequest rpcRequest) ; }
这是一个抽象类,其实现 LoadBalance 接口后,在 selectServiceAddress() 方法中进行一些简单的判断,最后通过 doSelect() 方法来实现选择
我们主要实现 doSelect() 方法来实现不同的负载均衡策略
首先我们看简单的随机负载均衡策略
通过一个随机值在服务列表中选择一个即可
1 2 3 4 5 6 7 public class RandomLoadBalance extends AbstractLoadBalance { @Override protected String doSelect (List<String> serviceAddresses, RpcRequest rpcRequest) { Random random = new Random (); return serviceAddresses.get(random.nextInt(serviceAddresses.size())); } }
接着看一致性哈希负载均衡算法
参考:https://dubbo.apache.org/zh/blog/2019/05/01/dubbo-%E4%B8%80%E8%87%B4%E6%80%A7hash%E8%B4%9F%E8%BD%BD%E5%9D%87%E8%A1%A1%E5%AE%9E%E7%8E%B0%E5%89%96%E6%9E%90/
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 @Slf4j public class ConsistentHashLoadBalance extends AbstractLoadBalance { private final ConcurrentHashMap<String, ConsistentHashSelector> selectors = new ConcurrentHashMap <>(); @Override protected String doSelect (List<String> serviceAddresses, RpcRequest rpcRequest) { int identityHashCode = System.identityHashCode(serviceAddresses); String rpcServiceName = rpcRequest.getRpcServiceName(); ConsistentHashSelector selector = selectors.get(rpcServiceName); if (selector == null || selector.identityHashCode != identityHashCode) { selectors.put(rpcServiceName, new ConsistentHashSelector (serviceAddresses, 160 , identityHashCode)); selector = selectors.get(rpcServiceName); } return selector.select(rpcServiceName + Arrays.stream(rpcRequest.getParameters())); } static class ConsistentHashSelector { private final TreeMap<Long, String> virtualInvokers; private final int identityHashCode; ConsistentHashSelector(List<String> invokers, int replicaNumber, int identityHashCode) { this .virtualInvokers = new TreeMap <>(); this .identityHashCode = identityHashCode; for (String invoker : invokers) { for (int i = 0 ; i < replicaNumber / 4 ; i++) { byte [] digest = md5(invoker + i); for (int h = 0 ; h < 4 ; h++) { long m = hash(digest, h); virtualInvokers.put(m, invoker); } } } } static byte [] md5(String key) { MessageDigest md; try { md = MessageDigest.getInstance("MD5" ); byte [] bytes = key.getBytes(StandardCharsets.UTF_8); md.update(bytes); } catch (NoSuchAlgorithmException e) { throw new IllegalStateException (e.getMessage(), e); } return md.digest(); } static long hash (byte [] digest, int idx) { return ((long ) (digest[3 + idx * 4 ] & 255 ) << 24 | (long ) (digest[2 + idx * 4 ] & 255 ) << 16 | (long ) (digest[1 + idx * 4 ] & 255 ) << 8 | (long ) (digest[idx * 4 ] & 255 )) & 4294967295L ; } public String select (String rpcServiceKey) { byte [] digest = md5(rpcServiceKey); return selectForKey(hash(digest, 0 )); } public String selectForKey (long hashCode) { Map.Entry<Long, String> entry = virtualInvokers.tailMap(hashCode, true ).firstEntry(); if (entry == null ) { entry = virtualInvokers.firstEntry(); } return entry.getValue(); } } }
其中,ConsistentHashSelector#md5() 方法和 ConsistentHashSelector#hash() 方法可以看作是工具方法,分别实现对字符串的编码和哈希值计算
当调用 ConsistentHashLoadBalance#doSelect() 方法时,会先根据我们调用的服务来获取一个选择器
如果选择器不存在则会创建一个选择器,并将选择器保存在 selectors 集合中,最后调用 ConsistentHashSelector#select() 方法来完成结果的选择
我们看以下 select() 方法,它对传进来的值进行 md5 编码后,调用了 selectForKey()
1 2 3 4 5 6 7 8 9 public String selectForKey (long hashCode) { Map.Entry<Long, String> entry = virtualInvokers.tailMap(hashCode, true ).firstEntry(); if (entry == null ) { entry = virtualInvokers.firstEntry(); } return entry.getValue(); }
在 selectForKey() 中我们会选择 TreeMap 中键值大于等于 当前 hashCode 的 Entry 列表,并从列表中返回第一个 值来作为负载均衡选择的服务器地址
那么 TreeMap 中的值是从哪儿来的,这我们得看 ConsistentHashSelector 的构造方法
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 ConsistentHashSelector(List<String> invokers, int replicaNumber, int identityHashCode) { this .virtualInvokers = new TreeMap <>(); this .identityHashCode = identityHashCode; for (String invoker : invokers) { for (int i = 0 ; i < replicaNumber / 4 ; i++) { byte [] digest = md5(invoker + i); for (int h = 0 ; h < 4 ; h++) { long m = hash(digest, h); virtualInvokers.put(m, invoker); } } } }
假设我们服务列表中包含地址 127.0.0.1:20880,传入 replicaNumber 的值为160,那么它会对地址实现拼接 ,从 127.0.0.1:2088000 一直到 127.0.0.1:2088039 来进行编码,并对编码后的结果实现4次散列 ,将最后的散列值作为 key,127.0.0.1:20880 作为值 存入 TreeMap
@SPI
dubbo的SPI介绍:https://mp.weixin.qq.com/s?__biz=MzAwNDA2OTM1Ng==&mid=2453145662&idx=1&sn=0ba56d58eedca7f04b4d013b84080f31&scene=21#wechat_redirect
SPI (Service Provider Interface)
Dubbo 就依靠 SPI 机制实现了插件化功能,几乎将所有的功能组件做成基于 SPI 实现,并且默认提供了很多可以直接使用的扩展点,实现了面向功能进行拆分的对扩展开放的架构 。
先贴一个 Java 的 SPI 机制
1 2 3 4 5 6 7 8 public class SPIMain { public static void main (String[] args) { ServiceLoader<IShout> shouts = ServiceLoader.load(IShout.class); for (IShout s : shouts) { s.shout(); } } }
其中有资源文件 org.foo.demo.IShout
1 2 org.foo.demo.animal.Dog org.foo.demo.animal.Cat
这就是 Java 实现的最简单的 SPI 机制,但是它有缺陷
Java SPI 在查找扩展实现类的时候遍历 SPI 的配置文件并且将实现类全部实例化 ,假设一个实现类初始化过程比较消耗资源且耗时,但是你的代码里面又用不上它,这就产生了资源的浪费。
Dubbo SPI 除了可以按需加载实现类之外,增加了 IOC 和 AOP 的特性,还有个自适应扩展机制。
我们在自己的 RPC 中使用了简略的 Dubbo SPI 机制,接下来看看代码中的实际用法
以 ServiceDiscovery 为例
1 2 3 4 @SPI public interface ServiceDiscovery { InetSocketAddress lookupService (RpcRequest rpcRequest) ; }
我们给 ServiceDiscovery 添加了注解 @SPI,在使用过程中可以自定义配置使用其实现类 ExtensionLoader.getExtensionLoader(ServiceDiscovery.class).getExtension("zk");
这是怎么做到的?
我们先看 ExtensionLoader#getExtensionLoader() 这个方法
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 private static final Map<Class<?>, ExtensionLoader<?>> EXTENSION_LOADERS = new ConcurrentHashMap <>();public static <S> ExtensionLoader<S> getExtensionLoader (Class<S> type) { if (type == null ) { throw new IllegalArgumentException ("Extension type should not be null." ); } if (!type.isInterface()) { throw new IllegalArgumentException ("Extension type must be an interface." ); } if (type.getAnnotation(SPI.class) == null ) { throw new IllegalArgumentException ("Extension type must be annotated by @SPI" ); } ExtensionLoader<S> extensionLoader = (ExtensionLoader<S>) EXTENSION_LOADERS.get(type); if (extensionLoader == null ) { EXTENSION_LOADERS.putIfAbsent(type, new ExtensionLoader <S>(type)); extensionLoader = (ExtensionLoader<S>) EXTENSION_LOADERS.get(type); } return extensionLoader; }
通过缓存来获取一个类加载器,如果缓存中不存在,则自己创建一个加载器并缓存
通过这个加载器,我们可以通过 ExtensionLoader#getExtension() 加载实现类
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 private final Map<String, Holder<Object>> cachedInstances = new ConcurrentHashMap <>();public T getExtension (String name) { if (StringUtil.isBlank(name)) { throw new IllegalArgumentException ("Extension name should not be null or empty." ); } Holder<Object> holder = cachedInstances.get(name); if (holder == null ) { cachedInstances.putIfAbsent(name, new Holder <>()); holder = cachedInstances.get(name); } Object instance = holder.get(); if (instance == null ) { synchronized (holder) { instance = holder.get(); if (instance == null ) { instance = createExtension(name); holder.set(instance); } } } return (T) instance; }
我们看看如何来创建一个实例对象的
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 private static final Map<Class<?>, Object> EXTENSION_INSTANCES = new ConcurrentHashMap <>();private T createExtension (String name) { Class<?> clazz = getExtensionClasses().get(name); if (clazz == null ) { throw new RuntimeException ("No such extension of name " + name); } T instance = (T) EXTENSION_INSTANCES.get(clazz); if (instance == null ) { try { EXTENSION_INSTANCES.putIfAbsent(clazz, clazz.newInstance()); instance = (T) EXTENSION_INSTANCES.get(clazz); } catch (Exception e) { log.error(e.getMessage()); } } return instance; }
接下来看看是怎么通过全类名来获取类对象的
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 private final Holder<Map<String, Class<?>>> cachedClasses = new Holder <>(); private Map<String, Class<?>> getExtensionClasses() { Map<String, Class<?>> classes = cachedClasses.get(); if (classes == null ) { synchronized (cachedClasses) { classes = cachedClasses.get(); if (classes == null ) { classes = new HashMap <>(); loadDirectory(classes); cachedClasses.set(classes); } } } return classes; }
我们看如何通过目录加载类的实例对象
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 private void loadDirectory (Map<String, Class<?>> extensionClasses) { String fileName = ExtensionLoader.SERVICE_DIRECTORY + type.getName(); try { Enumeration<URL> urls; ClassLoader classLoader = ExtensionLoader.class.getClassLoader(); urls = classLoader.getResources(fileName); if (urls != null ) { while (urls.hasMoreElements()) { URL resourceUrl = urls.nextElement(); loadResource(extensionClasses, classLoader, resourceUrl); } } } catch (IOException e) { log.error(e.getMessage()); } } private void loadResource (Map<String, Class<?>> extensionClasses, ClassLoader classLoader, URL resourceUrl) { try (BufferedReader reader = new BufferedReader (new InputStreamReader (resourceUrl.openStream(), UTF_8))) { String line; while ((line = reader.readLine()) != null ) { final int ci = line.indexOf('#' ); if (ci >= 0 ) { line = line.substring(0 , ci); } line = line.trim(); if (line.length() > 0 ) { try { final int ei = line.indexOf('=' ); String name = line.substring(0 , ei).trim(); String clazzName = line.substring(ei + 1 ).trim(); if (name.length() > 0 && clazzName.length() > 0 ) { Class<?> clazz = classLoader.loadClass(clazzName); extensionClasses.put(name, clazz); } } catch (ClassNotFoundException e) { log.error(e.getMessage()); } } } } catch (IOException e) { log.error(e.getMessage()); } }
参考 Dubbo 的 SPI 机制:
可以看出其实已经省略了很多的功能,如依赖注入、包装、Adaptive等
请求流程 客户端 从之前的动态代理部分,我们已经知道,从请求到返回响应结果这部分,实际上都是代理类在做,也就是 RpcClientProxy#invoke()
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 private final RpcRequestTransport transport;private final RpcServiceConfig config;public Object invoke (Object proxy, Method method, Object[] args) throws Throwable { log.info("invoke method: [{}]" , method.getName()); RpcRequest request = RpcRequest.builder() .methodName(method.getName()) .interfaceName(method.getDeclaringClass().getName()) .paramTypes(method.getParameterTypes()) .parameters(args) .requestId(UUID.randomUUID().toString()) .group(config.getGroup()) .version(config.getVersion()) .build(); CompletableFuture<RpcResponse<Object>> future = (CompletableFuture<RpcResponse<Object>>) transport.sendRequest(request); RpcResponse<Object> response = future.get(); this .check(response, request); return response.getData(); }
我们看 RpcRequestTransport 接口的实现类 RpcClient 的 sendRequest() 方法
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 public Object sendRequest (RpcRequest rpcRequest) { CompletableFuture<RpcResponse<Object>> resultFuture = new CompletableFuture <>(); InetSocketAddress address = serviceDiscovery.lookupService(rpcRequest); Channel channel = getChannel(address); if (channel.isActive()) { unprocessedRequests.put(rpcRequest.getRequestId(), resultFuture); RpcMessage rpcMessage = RpcMessage.builder().data(rpcRequest) .codec(SerializationTypeEnum.KYRO.getCode()) .compress(CompressTypeEnum.GZIP.getCode()) .messageType(RpcConstant.REQUEST_TYPE).build(); channel.writeAndFlush(rpcMessage).addListener((ChannelFutureListener) future -> { if (future.isSuccess()) { log.info("client send message: [{}]" , rpcMessage); } else { future.channel().close(); resultFuture.completeExceptionally(future.cause()); log.error("Send failed:" , future.cause()); } }); } else { throw new IllegalStateException (); } return resultFuture; } public Channel getChannel (InetSocketAddress address) { Channel channel = channelProvider.get(address); if (channel == null ) { channel = doConnect(address); channelProvider.set(address, channel); } return channel; } @SneakyThrows public Channel doConnect (InetSocketAddress inetSocketAddress) { ChannelFuture channel = bootstrap.connect(inetSocketAddress); channel.sync(); return channel.channel(); }
其主要做了这几件事:
根据负载均衡策略获取服务地址
根据服务地址获取该地址的 Channel
发送协议并返回一个异步结果
这里采用硬编码 的方式来实例化协议对象,后续有待改进
那么我们看发送协议的部分
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 private final ServiceDiscovery serviceDiscovery;private final UnprocessedRequests unprocessedRequests;private final ChannelProvider channelProvider;private final Bootstrap bootstrap;private final EventLoopGroup group;public RpcClient () { bootstrap = new Bootstrap (); group = new NioEventLoopGroup (); bootstrap.group(group) .channel(NioSocketChannel.class) .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, 5000 ) .handler(new ChannelInitializer <NioSocketChannel>() { @Override protected void initChannel (NioSocketChannel ch) throws Exception { ch.pipeline().addLast(new IdleStateHandler (0 , 5 , 0 , TimeUnit.SECONDS)); ch.pipeline().addLast(new RpcMessageEncoder ()); ch.pipeline().addLast(new RpcMessageDecoder ()); ch.pipeline().addLast(new RpcClientHandler ()); } }); this .unprocessedRequests = SingletonFactory.getInstance(UnprocessedRequests.class); this .channelProvider = SingletonFactory.getInstance(ChannelProvider.class); this .serviceDiscovery = ExtensionLoader.getExtensionLoader(ServiceDiscovery.class).getExtension("zk" ); }
可以看到我们的客户端启动类添加了4个通道处理器,其作用分别为 心跳包处理器、协议编码器、协议解码器、结果处理器
前三个不做过多的介绍,在这里我们看第四个处理器
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 public class RpcClientHandler extends ChannelInboundHandlerAdapter { private final UnprocessedRequests unprocessedRequests; private final RpcClient client; public RpcClientHandler () { this .unprocessedRequests = SingletonFactory.getInstance(UnprocessedRequests.class); this .client = SingletonFactory.getInstance(RpcClient.class); } @Override public void channelRead (ChannelHandlerContext ctx, Object msg) throws Exception { try { log.info("client receive msg: [{}]" , msg); if (msg instanceof RpcMessage) { RpcMessage tmp = (RpcMessage) msg; byte messageType = tmp.getMessageType(); if (messageType == RpcConstant.HEARTBEAT_RESPONSE_TYPE) { log.info("heart [{}]" , tmp.getData()); } else if (messageType == RpcConstant.RESPONSE_TYPE) { RpcResponse<Object> rpcResponse = (RpcResponse<Object>) tmp.getData(); log.info("response: {}" , rpcResponse); unprocessedRequests.complete(rpcResponse); } } } finally { ReferenceCountUtil.release(msg); } } @Override public void userEventTriggered (ChannelHandlerContext ctx, Object evt) throws Exception { if (evt instanceof IdleStateEvent) { IdleState state = ((IdleStateEvent) evt).state(); if (state == IdleState.WRITER_IDLE) { log.info("write idle happen [{}]" , ctx.channel().remoteAddress()); Channel channel = client.getChannel((InetSocketAddress) ctx.channel().remoteAddress()); RpcMessage rpcMessage = new RpcMessage (); rpcMessage.setCodec(SerializationTypeEnum.KYRO.getCode()); rpcMessage.setCompress(CompressTypeEnum.GZIP.getCode()); rpcMessage.setMessageType(RpcConstant.HEARTBEAT_REQUEST_TYPE); rpcMessage.setData(RpcConstant.PING); channel.writeAndFlush(rpcMessage).addListener(ChannelFutureListener.CLOSE_ON_FAILURE); } } else { super .userEventTriggered(ctx, evt); } } @Override public void exceptionCaught (ChannelHandlerContext ctx, Throwable cause) { log.error("client catch exception:" , cause); cause.printStackTrace(); ctx.close(); } }
在此主要关注的是 #channelRead() 方法,其主要作用是处理服务端的响应结果
在通过前面解码器的处理后,最后发送到该处理器的内容实际上就是我们自定义的协议
可以看到,在之前发送请求和后面处理服务器的返回结果时,我们都用到了 UnprocessedRequests 这个类,那么它做了什么呢
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 public class UnprocessedRequests { private static final Map<String, CompletableFuture<RpcResponse<Object>>> UNPROCESSED_RESPONSE_FUTURES = new ConcurrentHashMap <>(); public void put (String requestId, CompletableFuture<RpcResponse<Object>> future) { UNPROCESSED_RESPONSE_FUTURES.put(requestId, future); } public void complete (RpcResponse<Object> response) { CompletableFuture<RpcResponse<Object>> future = UNPROCESSED_RESPONSE_FUTURES.remove(response.getRequestId()); if (future != null ) { future.complete(response); } else { throw new IllegalStateException (); } } }
当发送请求时,我们会将 requestId 作为 Map 的 key,异步返回结果 CompletableFuture 作为 Map 的 value 存入。
当客户端接收到返回结果时,会将这对 key-value 组合从 Map 中移除,并设置异步结果完成
为什么要使用 CompletableFuture ?主要原因有二
我们在 RpcClientProxy#invoke() 中使用 CompletableFuture#get() 时,第一时间会阻塞,直到获取到返回结果。而获取返回结果我们通过 UnprocessedRequests#complete() 中的 CompletableFuture#complete() 来实现。这样可以保证客户端能够在服务端返回结果后获取结果内容 。
我们在 RpcClientHandler 中获取的 RpcResponse 实际上是在 NIO 线程中获取的,而我们最后返回结果是需要在主线程中返回结果。因此,CompletableFuture 的作用也可以是主线程和 NIO 线程之间数据的传输。
服务端 接下来我们看看服务端的内容
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 @Component public class RpcServer { public static final int PORT = 9999 ; @SneakyThrows public void start () { CustomShutdownHook.getCustomShutdownHook().clearAll(); DefaultEventExecutorGroup serviceHandlerGroup = new DefaultEventExecutorGroup ( RuntimeUtil.cpus() * 2 , ThreadPoolFactoryUtil.createThreadFactory("service-handler-group" , false ) ); NioEventLoopGroup boss = new NioEventLoopGroup (1 ); NioEventLoopGroup worker = new NioEventLoopGroup (); String host = InetAddress.getLocalHost().getHostAddress(); try { ServerBootstrap bootstrap = new ServerBootstrap (); bootstrap.group(boss, worker) .channel(NioServerSocketChannel.class) .childHandler(new ChannelInitializer <NioSocketChannel>() { @Override protected void initChannel (NioSocketChannel ch) throws Exception { ch.pipeline().addLast(new IdleStateHandler (30 , 0 , 0 , TimeUnit.SECONDS)); ch.pipeline().addLast(new RpcMessageEncoder ()); ch.pipeline().addLast(new RpcMessageDecoder ()); ch.pipeline().addLast(serviceHandlerGroup, new RpcServerHandler ()); } }); ChannelFuture future = bootstrap.bind(host, PORT).sync(); future.channel().closeFuture().sync(); } catch (InterruptedException e) { log.error("occur exception when start server:" , e); } finally { log.error("shutdown bossGroup and workerGroup" ); boss.shutdownGracefully(); worker.shutdownGracefully(); } } }
和之前一样,我们关注第四个处理器请求处理器
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 @Slf4j public class RpcServerHandler extends ChannelInboundHandlerAdapter { private final RpcRequestHandler rpcRequestHandler; public RpcServerHandler () { this .rpcRequestHandler = SingletonFactory.getInstance(RpcRequestHandler.class); } @Override public void channelRead (ChannelHandlerContext ctx, Object msg) throws Exception { try { if (msg instanceof RpcMessage) { log.info("server receive msg: [{}] " , msg); byte messageType = ((RpcMessage) msg).getMessageType(); RpcMessage responseMessage = new RpcMessage (); responseMessage.setCodec(SerializationTypeEnum.KYRO.getCode()); responseMessage.setCompress(CompressTypeEnum.GZIP.getCode()); if (messageType == RpcConstant.HEARTBEAT_REQUEST_TYPE) { responseMessage.setMessageType(RpcConstant.HEARTBEAT_RESPONSE_TYPE); responseMessage.setData(RpcConstant.PONG); } else { responseMessage.setMessageType(RpcConstant.RESPONSE_TYPE); RpcRequest rpcRequest = (RpcRequest) ((RpcMessage) msg).getData(); Object result = rpcRequestHandler.handle(rpcRequest); log.info(String.format("server get result: %s" , result.toString())); if (ctx.channel().isActive() && ctx.channel().isWritable()) { RpcResponse<Object> rpcResponse = RpcResponse.success(result, rpcRequest.getRequestId()); responseMessage.setData(rpcResponse); } else { RpcResponse<Object> rpcResponse = RpcResponse.fail(RpcResponseCodeEnum.FAIL); responseMessage.setData(rpcResponse); log.error("not writable now, message dropped" ); } } ctx.writeAndFlush(responseMessage).addListener(ChannelFutureListener.CLOSE_ON_FAILURE); } } finally { ReferenceCountUtil.release(msg); } } @Override public void userEventTriggered (ChannelHandlerContext ctx, Object evt) throws Exception { if (evt instanceof IdleStateEvent) { IdleState state = ((IdleStateEvent) evt).state(); if (state == IdleState.READER_IDLE) { log.info("idle check happen, so close the connection" ); ctx.close(); } } else { super .userEventTriggered(ctx, evt); } } @Override public void exceptionCaught (ChannelHandlerContext ctx, Throwable cause) { log.error("server catch exception" ); cause.printStackTrace(); ctx.close(); } }
还是看 channelRead() 方法,主要做了以下几件事
解析客户端发送过来的协议,判断发送内容是否为心跳包
如果是心跳包则同样返回一个心跳包回去,否则调用服务来处理方法调用
将 response 作为消息体写入协议,并发送协议给客户端
看服务端的处理器时,可以看到我们为 RpcServerHandler 指定了一个线程池 serviceHandlerGroup,同样这个线程池在每次服务端启动时 CustomShutdownHook.getCustomShutdownHook().clearAll() 会被清理,网上说这是为了将处理过程交给自定义的线程池来处理,而不是 NIO 线程池,这里有啥好处暂时还不是很清楚。