RPC小记

源码参考:https://github.com/Snailclimb/guide-rpc-framework

参考文档:https://www.yuque.com/books/share/b7a2512c-6f7a-4afe-9d7e-5936b4c4cab0

动态代理

首先看客户端的代码

1
2
3
4
5
6
7
8
@RpcScan(basePackage = {"com.richcoder"})
public class ClientMain {
public static void main(String[] args) throws InterruptedException {
AnnotationConfigApplicationContext applicationContext = new AnnotationConfigApplicationContext(ClientMain.class);
HelloController controller = (HelloController) applicationContext.getBean("helloController");
controller.test();
}
}

在这里我们调用了 controller.test();

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
@Component
public class HelloController {
@RpcReference(version = "version1", group = "test1")
private HelloService helloService;

public void test() throws InterruptedException {
String hello = this.helloService.hello(new Hello("111", "222"));
//如需使用 assert 断言,需要在 VM options 添加参数:-ea
assert "Hello description is 222".equals(hello);
Thread.sleep(12000);
for (int i = 0; i < 10; i++) {
System.out.println(helloService.hello(new Hello("111", "222")));
}
}
}

可以看出我们客户端是直接调用了 helloService.hello(),并没有在客户端的代码内看到将调用方法转换为网络传输对象

这些隐藏的操作都是由动态代理来完成的,接下来看看如何实现动态代理

首先我们需要一个客户端的代理类,为我们实现将调用方法转换成网络传输对象

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
@Slf4j
public class RpcClientProxy implements InvocationHandler {

private static final String INTERFACE_NAME = "interfaceName";
private final RpcRequestTransport transport;
private final RpcServiceConfig config;

public RpcClientProxy(RpcRequestTransport transport, RpcServiceConfig config) {
this.transport = transport;
this.config = config;
}
public <T> T getProxy(Class<T> clazz) {
return (T) Proxy.newProxyInstance(clazz.getClassLoader(), new Class[]{clazz}, this);
}

@Override
public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
log.info("invoke method: [{}]", method.getName());
// 构建请求
RpcRequest request = RpcRequest.builder()
.methodName(method.getName())
.interfaceName(method.getDeclaringClass().getName())
.paramTypes(method.getParameterTypes())
.parameters(args)
.requestId(UUID.randomUUID().toString())
.group(config.getGroup())
.version(config.getVersion())
.build();
// 发送请求
CompletableFuture<RpcResponse<Object>> future = (CompletableFuture<RpcResponse<Object>>) transport.sendRequest(request);
// 获取响应
RpcResponse<Object> response = future.get();
// 检查请求和响应的合理
this.check(response, request);
return response.getData();
}

private void check(RpcResponse<Object> rpcResponse, RpcRequest rpcRequest) {
if (rpcResponse == null) {
throw new RpcException(RpcErrorMessageEnum.SERVICE_INVOCATION_FAILURE, INTERFACE_NAME + ":" + rpcRequest.getInterfaceName());
}
if (!rpcRequest.getRequestId().equals(rpcResponse.getRequestId())) {
throw new RpcException(RpcErrorMessageEnum.REQUEST_NOT_MATCH_RESPONSE, INTERFACE_NAME + ":" + rpcRequest.getInterfaceName());
}
if (rpcResponse.getCode() == null || !rpcResponse.getCode().equals(RpcResponseCodeEnum.SUCCESS.getCode())) {
throw new RpcException(RpcErrorMessageEnum.SERVICE_INVOCATION_FAILURE, INTERFACE_NAME + ":" + rpcRequest.getInterfaceName());
}
}
}

我们采用一个类来实现 InvocationHandler 接口,完成自定义调用处理器,如果某个代理类将其指定为处理器,则该代理类调用方法时会自动将方法分配到 #invoke() 来调用

我们看 #invoke() 方法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
@Override
public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
log.info("invoke method: [{}]", method.getName());
// 构建请求
RpcRequest request = RpcRequest.builder()
.methodName(method.getName())
.interfaceName(method.getDeclaringClass().getName())
.paramTypes(method.getParameterTypes())
.parameters(args)
.requestId(UUID.randomUUID().toString())
.group(config.getGroup())
.version(config.getVersion())
.build();
// 发送请求
CompletableFuture<RpcResponse<Object>> future = (CompletableFuture<RpcResponse<Object>>) transport.sendRequest(request);
// 获取响应
RpcResponse<Object> response = future.get();
// 检查请求和响应的合理
this.check(response, request);
return response.getData();
}

当方法调用时,#invoke() 方法会根据调用方法将其转换为网络传输对象并发送给服务端

当服务端调用成功后,返回结果,并将结果返回给客户端

以上就是一个代理类做的实际工作,那么为什么 HelloController 中的 HelloService 的实例对象是一个代理对象呢

我们看 HelloController 有一个 @Component 注解,接下来的操作是在容器初始化过程完成的

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
@Slf4j
@Component
public class SpringBeanPostProcessor implements BeanPostProcessor {
//private final ServiceProvider serviceProvider;
private final RpcRequestTransport rpcClient;

public SpringBeanPostProcessor() {
//this.serviceProvider = SingletonFactory.getInstance(ZkServiceProviderImpl.class);
this.rpcClient = ExtensionLoader.getExtensionLoader(RpcRequestTransport.class).getExtension("netty");
}

@Override
public Object postProcessBeforeInitialization(Object bean, String beanName) throws BeansException {
// 判断 bean 是否含有注解 RpcService
if (bean.getClass().isAnnotationPresent(RpcService.class)) {
log.info("[{}] is annotated with [{}]", bean.getClass().getName(), RpcService.class.getCanonicalName());
RpcService rpcService = bean.getClass().getAnnotation(RpcService.class);
// 配置实现类的版本和组
RpcServiceConfig.builder()
.version(rpcService.version())
.group(rpcService.group())
.service(bean)
.build();
// TODO 往zk中注册服务
}
return bean;
}

@Override
public Object postProcessAfterInitialization(Object bean, String beanName) throws BeansException {
Class<?> clazz = bean.getClass();
Field[] fields = clazz.getDeclaredFields();
for (Field field : fields) {
// 判断 bean 是否含有标有注解RpcReference 的属性
RpcReference rpcReference = field.getAnnotation(RpcReference.class);
if (rpcReference != null) {
// 配置版本和组
RpcServiceConfig config = RpcServiceConfig.builder()
.version(rpcReference.version())
.group(rpcReference.group())
.build();
RpcClientProxy rpcClientProxy = new RpcClientProxy(rpcClient, config);
// 获取接口的代理类
// 当调用接口方法时 会调用代理类的invoke()方法
Object clientProxy = rpcClientProxy.getProxy(field.getType());
field.setAccessible(true);
try {
// 将bean的属性值设置为代理类
field.set(bean, clientProxy);
} catch (IllegalAccessException e) {
e.printStackTrace();
}
}
}
return bean;
}
}

当我们用一个类来实现 BeanPostProcessor 接口时,能实现 IOC 容器功能的扩展,能在 bean 实例化过程中实现前置处理和后置处理,分别对应下图中的两个部分

bean 的实例化

我们首先看一下前置处理做了什么

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
@Override
public Object postProcessBeforeInitialization(Object bean, String beanName) throws BeansException {
// 判断 bean 是否含有注解 RpcService
if (bean.getClass().isAnnotationPresent(RpcService.class)) {
log.info("[{}] is annotated with [{}]", bean.getClass().getName(), RpcService.class.getCanonicalName());
RpcService rpcService = bean.getClass().getAnnotation(RpcService.class);
// 配置实现类的版本和组
RpcServiceConfig.builder()
.version(rpcService.version())
.group(rpcService.group())
.service(bean)
.build();
// TODO 往zk中注册服务
}
return bean;
}
  1. 判断是否包含注解 @RpcService
  2. 根据注解的内容来指定服务的版本和组
  3. 根据服务的版本和组将该服务注册到 ZooKeeper

其中 @RpcService 如下

1
2
3
4
5
6
7
8
9
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.TYPE})
@Inherited
public @interface RpcService {
String version() default "";

String group() default "";
}

综上,如果我们有某个类被 @RpcService 注解标记,则该类会在实例化时,将自身注册到服务中心

接下来我们看看后置处理做了什么

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
public Object postProcessAfterInitialization(Object bean, String beanName) throws BeansException {
Class<?> clazz = bean.getClass();
Field[] fields = clazz.getDeclaredFields();
for (Field field : fields) {
// 判断 bean 是否含有标有注解RpcReference 的属性
RpcReference rpcReference = field.getAnnotation(RpcReference.class);
if (rpcReference != null) {
// 配置版本和组
RpcServiceConfig config = RpcServiceConfig.builder()
.version(rpcReference.version())
.group(rpcReference.group())
.build();
RpcClientProxy rpcClientProxy = new RpcClientProxy(rpcClient, config);
// 获取接口的代理类
// 当调用接口方法时 会调用代理类的invoke()方法
Object proxy = rpcClientProxy.getProxy(field.getType());
field.setAccessible(true);
try {
// 将bean的属性值设置为代理类
field.set(bean, proxy);
} catch (IllegalAccessException e) {
e.printStackTrace();
}
}
}
return bean;
}
  1. 获取 bean 实例对象的所有属性,并判断属性是否包含注解 @RpcReference
  2. 根据注解的内容来指定需要调用服务的版本和组
  3. 获取该属性的类,并实例化一个代理对象
  4. 将这个属性的值设置为代理对象

其中,@RpcReference 如下

1
2
3
4
5
6
7
8
9
10
11
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.FIELD})
@Inherited
public @interface RpcReference {

String version() default "";

String group() default "";

}

其实关键在于以下几个部分

1
2
3
4
5
6
//    @RpcReference(version = "version1", group = "test1")
// private HelloService helloService;
RpcClientProxy rpcClientProxy = new RpcClientProxy(rpcClient, config);
Object proxy = rpcClientProxy.getProxy(field.getType()); // field.getType() -> HelloService
field.setAccessible(true);
field.set(bean, proxy);

我们实例化了一个 RpcClientProxy 对象,并调用了其 #getProxy() 方法返回了一个 field 的代理对象

1
2
3
4
public <T> T getProxy(Class<T> clazz) {
// ClassLoader loader Class<?>[] interfaces InvocationHandler h
return (T) Proxy.newProxyInstance(clazz.getClassLoader(), new Class[]{clazz}, this);
}

这里传入的参数分别是 需要代理的类的加载器、需要代理的类和一个 InvocationHandler 的实现

回顾之前的代码,可以看出这个实现指的就是我们的 RpcClientProxy 自定义调用处理器

最后,指定 field 的值为代理对象

通过以上步骤,我们可以看出 HelloService 的实现类此时被指定为一个代理对象了,而调用这个代理对象的方法时,会被转发到 RpcClientProxy#invoke()

自定义注解

参考:https://juejin.cn/post/6844903942808141832

之前我们提到了 @RpcService@RpcReference 自定义注解,那么 Spring 是如何识别这些注解的呢?

我们通过 @RpcScan 这个注解来扫描 com.richcoder 包下的内容

1
2
3
4
5
6
7
8
9
@Target({ElementType.TYPE, ElementType.METHOD})
@Retention(RetentionPolicy.RUNTIME)
@Import(CustomScannerRegistrar.class) // 具体的实现逻辑
@Documented
public @interface RpcScan {

String[] basePackage();

}

其实现逻辑主要是在 CustomScannerRegistrar

CustomScannerRegistrar 实现了接口 ImportBeanDefinitionRegistrar 接口,Spring 容器启动时,会扫描所有实现该接口的类,并执行其 #registerBeanDefinitions() 方法,生成 BeanDefinition 对象,为后续实例化做准备

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
public class CustomScannerRegistrar implements ImportBeanDefinitionRegistrar, ResourceLoaderAware {

private static final String SPRING_BEAN_BASE_PACKAGE = "com.richcoder";
private static final String BASE_PACKAGE_ATTRIBUTE_NAME = "basePackage";
private ResourceLoader resourceLoader;

@Override
public void setResourceLoader(ResourceLoader resourceLoader) {
this.resourceLoader = resourceLoader;
}

@Override
public void registerBeanDefinitions(AnnotationMetadata annotationMetadata, BeanDefinitionRegistry registry) {
//获取注解的属性和值
AnnotationAttributes rpcScanAnnotationAttributes = AnnotationAttributes.fromMap(annotationMetadata.getAnnotationAttributes(RpcScan.class.getName()));
String[] rpcBasePackages = new String[0];
if (rpcScanAnnotationAttributes != null) {
//获取到basePackage的值
rpcBasePackages = rpcScanAnnotationAttributes.getStringArray(BASE_PACKAGE_ATTRIBUTE_NAME);
}
if (rpcBasePackages.length == 0) {
//如果没有设置basePackage 扫描路径,就扫描对应包下面的值
rpcBasePackages = new String[]{((StandardAnnotationMetadata) annotationMetadata).getIntrospectedClass().getPackage().getName()};
}
// 自定义包扫描器
// Scan the RpcService annotation
CustomScanner rpcServiceScanner = new CustomScanner(registry, RpcService.class);
// Scan the Component annotation
CustomScanner springBeanScanner = new CustomScanner(registry, Component.class);
if (resourceLoader != null) {
rpcServiceScanner.setResourceLoader(resourceLoader);
springBeanScanner.setResourceLoader(resourceLoader);
}
// 扫描注解
int springBeanAmount = springBeanScanner.scan(SPRING_BEAN_BASE_PACKAGE);
log.info("springBeanScanner扫描的数量 [{}]", springBeanAmount);
int rpcServiceCount = rpcServiceScanner.scan(rpcBasePackages);
log.info("rpcServiceScanner扫描的数量 [{}]", rpcServiceCount);
}
}

其中,CustomScanner 如下,在构造器中指定过滤器,过滤掉那些不是 @RpcService 注解标记的内容

1
2
3
4
5
6
7
8
9
10
11
12
public class CustomScanner extends ClassPathBeanDefinitionScanner {
public CustomScanner(BeanDefinitionRegistry registry, Class<? extends Annotation> annotationType) {
super(registry);
//添加过滤条件,这里是只添加了特定注解annotationType才会被扫描到
super.addIncludeFilter(new AnnotationTypeFilter(annotationType));
}

@Override
public int scan(String... basePackages) {
return super.scan(basePackages);
}
}

服务中心

服务注册

服务器启动时,会把标有注解 @RpcService 的类发布到 ZooKeeper 上,看看服务端是如何把服务发布到 Zookeeper 上的

首先还是看之前 SpringBeanPostProcessor#postProcessBeforeInitialization() 方法,在这里开始的服务注册

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
private final ServiceProvider serviceProvider;
private final RpcRequestTransport rpcClient;
public SpringBeanPostProcessor() {
this.serviceProvider = SingletonFactory.getInstance(ZkServiceProviderImpl.class);
this.rpcClient = ExtensionLoader.getExtensionLoader(RpcRequestTransport.class).getExtension("netty");
}
public Object postProcessBeforeInitialization(Object bean, String beanName) throws BeansException {
// 判断 bean 是否含有注解 RpcService
if (bean.getClass().isAnnotationPresent(RpcService.class)) {
log.info("[{}] is annotated with [{}]", bean.getClass().getName(), RpcService.class.getCanonicalName());
RpcService rpcService = bean.getClass().getAnnotation(RpcService.class);
// 配置实现类的版本和组
RpcServiceConfig config = RpcServiceConfig.builder()
.version(rpcService.version())
.group(rpcService.group())
.service(bean)
.build();
// 往zk中注册服务
serviceProvider.publishService(config);
}
return bean;
}

ServiceProvider 是一个接口,我们主要采用其实现类 ZkServiceProvider 来实现服务注册

接下来我们看看 ZkServiceProvider#publishService() 方法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
private final Map<String, Object> serviceMap;
private final Set<String> registeredService;
private final ServiceRegistry serviceRegistry;

public ZkServiceProviderImpl() {
serviceMap = new ConcurrentHashMap<>();
registeredService = ConcurrentHashMap.newKeySet();
serviceRegistry = ExtensionLoader.getExtensionLoader(ServiceRegistry.class).getExtension("zk");
}
public void publishService(RpcServiceConfig rpcServiceConfig) {
try {
// 获取本机的ip
String host = InetAddress.getLocalHost().getHostAddress();
this.addService(rpcServiceConfig);
serviceRegistry.registerService(rpcServiceConfig.getRpcServiceName(), new InetSocketAddress(host, RpcServer.PORT));
} catch (UnknownHostException e) {
log.error("occur exception when getHostAddress", e);
}
}
public void addService(RpcServiceConfig rpcServiceConfig) {
// 获取服务名
String rpcServiceName = rpcServiceConfig.getRpcServiceName();
if (registeredService.contains(rpcServiceName)) {
return;
}
// 将服务名作为 key 添加到 set
registeredService.add(rpcServiceName);
// 保存服务
serviceMap.put(rpcServiceName, rpcServiceConfig.getService());
log.info("Add service: {} and interfaces:{}", rpcServiceName, rpcServiceConfig.getService().getClass().getInterfaces());
}

可以看出 publishService 做了以下的事:

  1. 添加服务

    1. 检查该服务是否已经被添加

    2. 将服务添加到 serviceMap 中,其中 key 为服务名, value 为服务的实现类

      1
      2
      3
      public String getRpcServiceName() {
      return this.getServiceName() + this.getGroup() + this.getVersion();
      }
  2. 注册服务

添加服务并没有用到 ZooKeeper,ZooKeeper 用于注册服务

ServiceRegistry 是一个用于服务注册的接口,我们看其实现类 ZkServiceRegistryImpl

1
2
3
4
5
6
7
8
9
10
11
public class ZkServiceRegistryImpl implements ServiceRegistry {
@Override
public void registerService(String rpcServiceName, InetSocketAddress inetSocketAddress) {
// 获取 zookeeper 服务地址
String servicePath = CuratorUtils.ZK_REGISTER_ROOT_PATH + "/" + rpcServiceName + inetSocketAddress.toString();
// 获取 zookeeper 客户端
CuratorFramework zkClient = CuratorUtils.getZkClient();
// 注册服务
CuratorUtils.createPersistentNode(zkClient, servicePath);
}
}

可以看到,我们在这里实现了服务的注册

接下来我们具体看看服务注册是如何实现的

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
private static final int BASE_SLEEP_TIME = 1000;
private static final int MAX_RETRIES = 3;
private static final String DEFAULT_ZOOKEEPER_ADDRESS = "127.0.0.1:2181";
private static CuratorFramework zkClient;

public static final String ZK_REGISTER_ROOT_PATH = "/rich-rpc";
private static final Map<String, List<String>> SERVICE_ADDRESS_MAP = new ConcurrentHashMap<>();
private static final Set<String> REGISTERED_PATH_SET = ConcurrentHashMap.newKeySet();

private CuratorUtils() {
}

public static CuratorFramework getZkClient() {
// 读取文件
Properties properties = PropertiesFileUtil.readPropertiesFile(RpcConfigEnum.RPC_CONFIG_PATH.getPropertyValue());
// 读取文件下的zk地址
String zookeeperAddress = properties != null
&&
properties.getProperty(RpcConfigEnum.ZK_ADDRESS.getPropertyValue()) != null ?
properties.getProperty(RpcConfigEnum.ZK_ADDRESS.getPropertyValue()) : DEFAULT_ZOOKEEPER_ADDRESS;
// if zkClient has been started, return directly
if (zkClient != null && zkClient.getState() == CuratorFrameworkState.STARTED) {
return zkClient;
}
// Retry strategy. Retry 3 times, and will increase the sleep time between retries.
RetryPolicy retryPolicy = new ExponentialBackoffRetry(BASE_SLEEP_TIME, MAX_RETRIES);
zkClient = CuratorFrameworkFactory.builder()
// the server to connect to (can be a server list)
.connectString(zookeeperAddress)
.retryPolicy(retryPolicy)
.build();
zkClient.start();
try {
// wait 30s until connect to the zookeeper
if (!zkClient.blockUntilConnected(30, TimeUnit.SECONDS)) {
throw new RuntimeException("Time out waiting to connect to ZK!");
}
} catch (InterruptedException e) {
e.printStackTrace();
}
return zkClient;
}

public static void createPersistentNode(CuratorFramework zkClient, String path) {
try {
if (REGISTERED_PATH_SET.contains(path) || zkClient.checkExists().forPath(path) != null) {
log.info("The node already exists. The node is:[{}]", path);
} else {
//eg: /my-rpc/github.javaguide.HelloService/127.0.0.1:9999
zkClient.create().creatingParentsIfNeeded().withMode(CreateMode.PERSISTENT).forPath(path);
log.info("The node was created successfully. The node is:[{}]", path);
}
REGISTERED_PATH_SET.add(path);
} catch (Exception e) {
log.error("create persistent node for path [{}] fail", path);
}
}

我们先看 #getZkClient() 做了什么

  1. 读取一个文件,这里 RpcConfigEnum.RPC_CONFIG_PATH 指定为 rpc.properties,即读取 resources 目录下的文件名为 rpc.properties 的文件
  2. 如果文件存在且文件内属性名为 RpcConfigEnum.ZK_ADDRESS 的属性值不为空,则根据文件内的 Zookeeper 地址进行设置,否则设置为默认地址,这里的 RpcConfigEnum.ZK_ADDRESSrpc.zookeeper.address
  3. 如果 ZooKeeper 客户端已经启动且运行,则直接返回
  4. 设置重试策略并创建和启动客户端

至此,我们已经可以成功的启动 ZooKeeper 客户端,接下来看看如何将服务注册到 ZooKeeper 客户端内

  1. 判断我们是否已经注册过该服务
  2. 注册服务,在 ZooKeeper 客户端内添加一个永久节点

永久节点路径:根路径 / 服务名(接口实现类名+组号+版本号) / 服务器ip+端口号

往 ZooKeeper 中添加一个永久节点意味着当 ZooKeeper 客户端关闭重启后,其节点依旧存在

到此,我们已经可以看到服务可以成功的注册到 ZooKeeper 上了

那么,如果某台服务器需要下线某个服务重启,该怎么办呢?

如果不进行操作,由于先前注册是一个永久节点,意味着 ZooKeeper 上会一直存在这个服务的服务路径,当客户端选择该路径连接服务端时,会发现服务器不提供该服务

因此,我们每次重启服务端时,需要将其原本注册的服务清除,重新注册

我们看 RpcServer 是怎么做的

1
2
3
4
5
public void start() {
// 清除zk服务
CustomShutdownHook.getCustomShutdownHook().clearAll();
// 以下内容为启动服务器
}

通过调用一个工具类,在每次启动服务器之前清除 ZooKeeper 上的服务

接下来看工具类做了什么

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
public class CustomShutdownHook {
private static final CustomShutdownHook CUSTOM_SHUTDOWN_HOOK = new CustomShutdownHook();

public static CustomShutdownHook getCustomShutdownHook() {
return CUSTOM_SHUTDOWN_HOOK;
}

public void clearAll() {
log.info("addShutdownHook for clearAll");
Runtime.getRuntime().addShutdownHook(new Thread(() -> {
try {
InetSocketAddress inetSocketAddress = new InetSocketAddress(InetAddress.getLocalHost().getHostAddress(), RpcServer.PORT);
// 清除服务
CuratorUtils.clearRegistry(CuratorUtils.getZkClient(), inetSocketAddress);
} catch (UnknownHostException ignored) {
}
ThreadPoolFactoryUtil.shutDownAllThreadPool();
}));
}
}

这里我们通过一个新线程将之前注册的服务清除

接下来我们看看清除服务的具体实现

1
2
3
4
5
6
7
8
9
10
11
12
public static void clearRegistry(CuratorFramework zkClient, InetSocketAddress address) {
REGISTERED_PATH_SET.stream().parallel().forEach(p -> {
try {
if (p.endsWith(address.toString())) {
zkClient.delete().forPath(p);
}
} catch (Exception e) {
log.error("clear registry for path [{}] fail", p);
}
});
log.info("All registered services on the server are cleared:[{}]", REGISTERED_PATH_SET.toString());
}

遍历 REGISTERED_PATH_SET 集合,查看集合中哪一项是以当前 ip + 端口为结尾的,清除该节点

举个例子,之前我们往 ZooKeeper 中注册有两个服务,其节点分别为 /rich-rpc/com.richcoder.HelloServicegroup1version1/127.0.0.1:9999/rich-rpc/com.richcoder.HelloService2group2version2/127.0.0.1:9999,此时我们下线了 HelloServicegroup1version1 并重启服务器,由于服务器的 ip + 端口是固定的,启动时便会删除以ip + 端口为结尾的路径的节点,即将两个节点都删除了,然后重新注册服务,此时便只有 HelloServicegroup2version2

服务发现

客户端发送请求时,需要知道向哪个服务器发送请求,此时就需要在 ZooKeeper 注册中心中获取该服务支持的服务器地址

在代理类 RpcClientProxy 中我们可以看到调用服务时,通过 rpcRequestTransport.sendRpcRequest(rpcRequest); 来发送服务请求

那我们就看看 RpcRequestTransport 的实现类 NettyRpcClient

1
2
3
4
5
6
7
8
9
10
11
private final ServiceDiscovery serviceDiscovery;
@Override
public Object sendRpcRequest(RpcRequest rpcRequest) {
// build return value
CompletableFuture<RpcResponse<Object>> resultFuture = new CompletableFuture<>();
// get server address
InetSocketAddress inetSocketAddress = serviceDiscovery.lookupService(rpcRequest);
// get server address related channel
// 发送服务请求
return resultFuture;
}

可以看到我们通过 ServiceDiscovery#lookupService() 来确定我们要向哪个服务器发送请求

我们看其实现类 ZkServiceDiscoveryImpl

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
@Slf4j
public class ZkServiceDiscoveryImpl implements ServiceDiscovery {
private final LoadBalance loadBalance;

public ZkServiceDiscoveryImpl() {
this.loadBalance = ExtensionLoader.getExtensionLoader(LoadBalance.class).getExtension("loadBalance");
}

@Override
public InetSocketAddress lookupService(RpcRequest rpcRequest) {
String rpcServiceName = rpcRequest.getRpcServiceName();
CuratorFramework zkClient = CuratorUtils.getZkClient();
// 获取提供服务的服务器地址集合
List<String> serviceUrlList = CuratorUtils.getChildrenNodes(zkClient, rpcServiceName);
if (CollectionUtil.isEmpty(serviceUrlList)) {
throw new RpcException(RpcErrorMessageEnum.SERVICE_CAN_NOT_BE_FOUND, rpcServiceName);
}
// load balancing 通过负载均衡确认一个服务器地址
String targetServiceUrl = loadBalance.selectServiceAddress(serviceUrlList, rpcRequest);
log.info("Successfully found the service address:[{}]", targetServiceUrl);
String[] socketAddressArray = targetServiceUrl.split(":");
String host = socketAddressArray[0];
int port = Integer.parseInt(socketAddressArray[1]);
return new InetSocketAddress(host, port);
}
}

可以看出 ZkServiceDiscoveryImpl 主要做了两件事

  1. 获取提供服务的地址集合
  2. 根据负载均衡策略选择出一个服务器地址

负载均衡的内容之后了解,在这里我们看看如何实现获取提供服务的地址集合

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
public static List<String> getChildrenNodes(CuratorFramework zkClient, String rpcServiceName) {
// 根据服务名获取 ip 列表
if (SERVICE_ADDRESS_MAP.containsKey(rpcServiceName)) {
return SERVICE_ADDRESS_MAP.get(rpcServiceName);
}
List<String> result = null;
String servicePath = ZK_REGISTER_ROOT_PATH + "/" + rpcServiceName;
try {
// 查询服务名下节点的值
result = zkClient.getChildren().forPath(servicePath);
// 保存节点值
SERVICE_ADDRESS_MAP.put(rpcServiceName, result);
// 监听节点
registerWatcher(rpcServiceName, zkClient);
} catch (Exception e) {
log.error("get children nodes for path [{}] fail", servicePath);
}
return result;
}

private static void registerWatcher(String rpcServiceName, CuratorFramework zkClient) throws Exception {
String servicePath = ZK_REGISTER_ROOT_PATH + "/" + rpcServiceName;
PathChildrenCache pathChildrenCache = new PathChildrenCache(zkClient, servicePath, true);
PathChildrenCacheListener pathChildrenCacheListener = (curatorFramework, pathChildrenCacheEvent) -> {
// 子节点发生变化则将新的子节点存入map
List<String> serviceAddresses = curatorFramework.getChildren().forPath(servicePath);
SERVICE_ADDRESS_MAP.put(rpcServiceName, serviceAddresses);
};
pathChildrenCache.getListenable().addListener(pathChildrenCacheListener);
pathChildrenCache.start();
}
  1. 查看之前是否已经获取过该服务的地址列表
  2. 获取服务的子节点,这里服务的节点是**根路径 / 服务名(接口实现类名+组号+版本号)**,并将子节点保存到 map 中
  3. 监听该节点的变化

监听节点的参考:https://www.cnblogs.com/weihuang6620/p/10821800.html

PathChildrenCache 有以下几个特点

  • 永久监听指定节点下的节点
  • 只能监听指定节点下一级节点的变化,比如说指定节点 ”/example” , 在下面添加 ”node1” 可以监听到,但是添加 ”node1/n1” 就不能被监听到了
  • 可以监听到的事件:节点创建、节点数据的变化、节点删除等

当我们监听到节点的子节点发生变化时,如 /rich-rpc/com.richcoder.HelloServicegroup1version1 本来有两个提供服务的服务器,地址分别为 A 和 B,当服务器 A 下线了这个服务,通过之前的分析我们可以看出服务器 A 重启后,该服务的子节点下就没有服务器 A 的地址了,此时监听到此事件后,会更新提供该服务的地址列表

负载均衡

最后我们看以下负载均衡策略,客户端会根据负载均衡策略来选择一个合适的服务端连接

1
2
3
public interface LoadBalance {
String selectServiceAddress(List<String> serviceUrlList, RpcRequest rpcRequest);
}

看其实现类 AbstractLoadBalance

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
public abstract class AbstractLoadBalance implements LoadBalance {
@Override
public String selectServiceAddress(List<String> serviceAddresses, RpcRequest rpcRequest) {
if (CollectionUtil.isEmpty(serviceAddresses)) {
return null;
}
if (serviceAddresses.size() == 1) {
return serviceAddresses.get(0);
}
return doSelect(serviceAddresses, rpcRequest);
}

protected abstract String doSelect(List<String> serviceAddresses, RpcRequest rpcRequest);

}

这是一个抽象类,其实现 LoadBalance 接口后,在 selectServiceAddress() 方法中进行一些简单的判断,最后通过 doSelect() 方法来实现选择

我们主要实现 doSelect() 方法来实现不同的负载均衡策略

首先我们看简单的随机负载均衡策略

通过一个随机值在服务列表中选择一个即可

1
2
3
4
5
6
7
public class RandomLoadBalance extends AbstractLoadBalance {
@Override
protected String doSelect(List<String> serviceAddresses, RpcRequest rpcRequest) {
Random random = new Random();
return serviceAddresses.get(random.nextInt(serviceAddresses.size()));
}
}

接着看一致性哈希负载均衡算法

参考:https://dubbo.apache.org/zh/blog/2019/05/01/dubbo-%E4%B8%80%E8%87%B4%E6%80%A7hash%E8%B4%9F%E8%BD%BD%E5%9D%87%E8%A1%A1%E5%AE%9E%E7%8E%B0%E5%89%96%E6%9E%90/

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
@Slf4j
public class ConsistentHashLoadBalance extends AbstractLoadBalance {
private final ConcurrentHashMap<String, ConsistentHashSelector> selectors = new ConcurrentHashMap<>();

@Override
protected String doSelect(List<String> serviceAddresses, RpcRequest rpcRequest) {
// 获取对象的hash值
int identityHashCode = System.identityHashCode(serviceAddresses);
// build rpc service name by rpcRequest
String rpcServiceName = rpcRequest.getRpcServiceName();
// 获取一个选择器
ConsistentHashSelector selector = selectors.get(rpcServiceName);
// check for updates
if (selector == null || selector.identityHashCode != identityHashCode) {
selectors.put(rpcServiceName, new ConsistentHashSelector(serviceAddresses, 160, identityHashCode));
selector = selectors.get(rpcServiceName);
}
return selector.select(rpcServiceName + Arrays.stream(rpcRequest.getParameters()));
}

static class ConsistentHashSelector {
private final TreeMap<Long, String> virtualInvokers;

private final int identityHashCode;

ConsistentHashSelector(List<String> invokers, int replicaNumber, int identityHashCode) {
this.virtualInvokers = new TreeMap<>();
this.identityHashCode = identityHashCode;

// 遍历服务地址
for (String invoker : invokers) {
// 对服务地址进行编码 加强散列效果
// eg: 127.0.0.1:20880 -> 127.0.0.1:2088000 127.0.0.1:2088001 ... 127.0.0.1:2088039
// 再对这些进行4次位数级别的散列
for (int i = 0; i < replicaNumber / 4; i++) {
byte[] digest = md5(invoker + i);
for (int h = 0; h < 4; h++) {
long m = hash(digest, h);
virtualInvokers.put(m, invoker);
}
}
}
}

/**
* md5编码
* @param key 需要编码的字符串内容
* @return 以字节数组的形式返回编码结果
*/
static byte[] md5(String key) {
MessageDigest md;
try {
md = MessageDigest.getInstance("MD5");
byte[] bytes = key.getBytes(StandardCharsets.UTF_8);
md.update(bytes);
} catch (NoSuchAlgorithmException e) {
throw new IllegalStateException(e.getMessage(), e);
}

return md.digest();
}

/**
* 对字节数组内容进行hash散列
* @param digest 传入的字节数组
* @param idx 位数 相当于偏移量
* @return 散列结果
*/
static long hash(byte[] digest, int idx) {
return ((long) (digest[3 + idx * 4] & 255) << 24 | (long) (digest[2 + idx * 4] & 255) << 16 | (long) (digest[1 + idx * 4] & 255) << 8 | (long) (digest[idx * 4] & 255)) & 4294967295L;
}

public String select(String rpcServiceKey) {
byte[] digest = md5(rpcServiceKey);
return selectForKey(hash(digest, 0));
}

public String selectForKey(long hashCode) {
// tailMap() 方法返回一个键大于等于hashCode的Entry列表
Map.Entry<Long, String> entry = virtualInvokers.tailMap(hashCode, true).firstEntry();

// entry 为空说明当前计算的hash值是最大的
if (entry == null) {
entry = virtualInvokers.firstEntry();
}

return entry.getValue();
}
}
}

其中,ConsistentHashSelector#md5() 方法和 ConsistentHashSelector#hash() 方法可以看作是工具方法,分别实现对字符串的编码和哈希值计算

当调用 ConsistentHashLoadBalance#doSelect() 方法时,会先根据我们调用的服务来获取一个选择器

如果选择器不存在则会创建一个选择器,并将选择器保存在 selectors 集合中,最后调用 ConsistentHashSelector#select() 方法来完成结果的选择

我们看以下 select() 方法,它对传进来的值进行 md5 编码后,调用了 selectForKey()

1
2
3
4
5
6
7
8
9
public String selectForKey(long hashCode) {
// tailMap() 方法返回一个键大于等于hashCode的Entry列表
Map.Entry<Long, String> entry = virtualInvokers.tailMap(hashCode, true).firstEntry();
// entry 为空说明当前计算的hash值是最大的
if (entry == null) {
entry = virtualInvokers.firstEntry();
}
return entry.getValue();
}

selectForKey() 中我们会选择 TreeMap 中键值大于等于当前 hashCode 的 Entry 列表,并从列表中返回第一个值来作为负载均衡选择的服务器地址

那么 TreeMap 中的值是从哪儿来的,这我们得看 ConsistentHashSelector 的构造方法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
ConsistentHashSelector(List<String> invokers, int replicaNumber, int identityHashCode) {
this.virtualInvokers = new TreeMap<>();
this.identityHashCode = identityHashCode;

// 遍历服务地址
for (String invoker : invokers) {
// 对服务地址进行编码 加强散列效果
// eg: 127.0.0.1:20880 -> 127.0.0.1:2088000 127.0.0.1:2088001 ... 127.0.0.1:2088039
// 再对这些进行4次位数级别的散列
for (int i = 0; i < replicaNumber / 4; i++) {
byte[] digest = md5(invoker + i);
for (int h = 0; h < 4; h++) {
long m = hash(digest, h);
virtualInvokers.put(m, invoker);
}
}
}
}

假设我们服务列表中包含地址 127.0.0.1:20880,传入 replicaNumber 的值为160,那么它会对地址实现拼接,从 127.0.0.1:2088000 一直到 127.0.0.1:2088039 来进行编码,并对编码后的结果实现4次散列,将最后的散列值作为 key,127.0.0.1:20880 作为值存入 TreeMap

@SPI

dubbo的SPI介绍:https://mp.weixin.qq.com/s?__biz=MzAwNDA2OTM1Ng==&mid=2453145662&idx=1&sn=0ba56d58eedca7f04b4d013b84080f31&scene=21#wechat_redirect

SPI (Service Provider Interface)

Dubbo 就依靠 SPI 机制实现了插件化功能,几乎将所有的功能组件做成基于 SPI 实现,并且默认提供了很多可以直接使用的扩展点,实现了面向功能进行拆分的对扩展开放的架构

先贴一个 Java 的 SPI 机制

1
2
3
4
5
6
7
8
public class SPIMain {
public static void main(String[] args) {
ServiceLoader<IShout> shouts = ServiceLoader.load(IShout.class);
for (IShout s : shouts) {
s.shout();
}
}
}

其中有资源文件 org.foo.demo.IShout

1
2
org.foo.demo.animal.Dog
org.foo.demo.animal.Cat

这就是 Java 实现的最简单的 SPI 机制,但是它有缺陷

Java SPI 在查找扩展实现类的时候遍历 SPI 的配置文件并且将实现类全部实例化,假设一个实现类初始化过程比较消耗资源且耗时,但是你的代码里面又用不上它,这就产生了资源的浪费。

Dubbo SPI 除了可以按需加载实现类之外,增加了 IOC 和 AOP 的特性,还有个自适应扩展机制。

我们在自己的 RPC 中使用了简略的 Dubbo SPI 机制,接下来看看代码中的实际用法

以 ServiceDiscovery 为例

1
2
3
4
@SPI
public interface ServiceDiscovery {
InetSocketAddress lookupService(RpcRequest rpcRequest);
}

我们给 ServiceDiscovery 添加了注解 @SPI,在使用过程中可以自定义配置使用其实现类 ExtensionLoader.getExtensionLoader(ServiceDiscovery.class).getExtension("zk");

这是怎么做到的?

我们先看 ExtensionLoader#getExtensionLoader() 这个方法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
private static final Map<Class<?>, ExtensionLoader<?>> EXTENSION_LOADERS = new ConcurrentHashMap<>();
public static <S> ExtensionLoader<S> getExtensionLoader(Class<S> type) {
// 做一些参数合法性判断
if (type == null) {
throw new IllegalArgumentException("Extension type should not be null.");
}
if (!type.isInterface()) {
throw new IllegalArgumentException("Extension type must be an interface.");
}
if (type.getAnnotation(SPI.class) == null) {
throw new IllegalArgumentException("Extension type must be annotated by @SPI");
}
// firstly get from cache, if not hit, create one
// 从缓存中获取该类的加载器 如果加载器不存在则自己创建一个
ExtensionLoader<S> extensionLoader = (ExtensionLoader<S>) EXTENSION_LOADERS.get(type);
if (extensionLoader == null) {
EXTENSION_LOADERS.putIfAbsent(type, new ExtensionLoader<S>(type));
extensionLoader = (ExtensionLoader<S>) EXTENSION_LOADERS.get(type);
}
return extensionLoader;
}

通过缓存来获取一个类加载器,如果缓存中不存在,则自己创建一个加载器并缓存

通过这个加载器,我们可以通过 ExtensionLoader#getExtension() 加载实现类

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
private final Map<String, Holder<Object>> cachedInstances = new ConcurrentHashMap<>();
public T getExtension(String name) {
// 如果传入名称为空则抛出异常
if (StringUtil.isBlank(name)) {
throw new IllegalArgumentException("Extension name should not be null or empty.");
}
// firstly get from cache, if not hit, create one
// 还是先从缓存中获取一个Holder,如果缓存中没有则自己创建一个存入缓存
Holder<Object> holder = cachedInstances.get(name);
if (holder == null) {
cachedInstances.putIfAbsent(name, new Holder<>());
holder = cachedInstances.get(name);
}
// create a singleton if no instance exists
// 如果holder没有实例 则自己创建一个实例对象存进holder
Object instance = holder.get();
if (instance == null) {
synchronized (holder) {
instance = holder.get();
if (instance == null) {
instance = createExtension(name);
holder.set(instance);
}
}
}
return (T) instance;
}

我们看看如何来创建一个实例对象的

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
private static final Map<Class<?>, Object> EXTENSION_INSTANCES = new ConcurrentHashMap<>();
private T createExtension(String name) {
// load all extension classes of type T from file and get specific one by name
// 通过全类名来获取类对象
Class<?> clazz = getExtensionClasses().get(name);
if (clazz == null) {
throw new RuntimeException("No such extension of name " + name);
}
// 从缓存中获取这个类的对象
T instance = (T) EXTENSION_INSTANCES.get(clazz);
// 如果缓存中不包含这个类的实例对象 则自己创建一个存入缓存
if (instance == null) {
try {
EXTENSION_INSTANCES.putIfAbsent(clazz, clazz.newInstance());
instance = (T) EXTENSION_INSTANCES.get(clazz);
} catch (Exception e) {
log.error(e.getMessage());
}
}
return instance;
}

接下来看看是怎么通过全类名来获取类对象的

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
private final Holder<Map<String, Class<?>>> cachedClasses = new Holder<>();   
private Map<String, Class<?>> getExtensionClasses() {
// get the loaded extension class from the cache
// 从类加载器中获取一个缓存
Map<String, Class<?>> classes = cachedClasses.get();
// double check
if (classes == null) {
synchronized (cachedClasses) {
// 判断当前类是否已经被加载过 如果没有被加载则加载目录下的内容
classes = cachedClasses.get();
if (classes == null) {
classes = new HashMap<>();
// load all extensions from our extensions directory
loadDirectory(classes);
cachedClasses.set(classes);
}
}
}
return classes;
}

我们看如何通过目录加载类的实例对象

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
private void loadDirectory(Map<String, Class<?>> extensionClasses) {
// 资源地址
String fileName = ExtensionLoader.SERVICE_DIRECTORY + type.getName();
try {
Enumeration<URL> urls;
ClassLoader classLoader = ExtensionLoader.class.getClassLoader();
urls = classLoader.getResources(fileName);
if (urls != null) {
// 加载目录下的所有资源
while (urls.hasMoreElements()) {
URL resourceUrl = urls.nextElement();
loadResource(extensionClasses, classLoader, resourceUrl);
}
}
} catch (IOException e) {
log.error(e.getMessage());
}
}

private void loadResource(Map<String, Class<?>> extensionClasses, ClassLoader classLoader, URL resourceUrl) {
try (BufferedReader reader = new BufferedReader(new InputStreamReader(resourceUrl.openStream(), UTF_8))) {
String line;
// read every line
while ((line = reader.readLine()) != null) {
// get index of comment
final int ci = line.indexOf('#');
if (ci >= 0) {
// string after # is comment so we ignore it
line = line.substring(0, ci);
}
line = line.trim();
if (line.length() > 0) {
try {
final int ei = line.indexOf('=');
// key 就是我们通过name来获取实例
String name = line.substring(0, ei).trim();
// value 实例的全类名
String clazzName = line.substring(ei + 1).trim();
// our SPI use key-value pair so both of them must not be empty
// 通过加载器来获取类对象 并缓存到Map中
if (name.length() > 0 && clazzName.length() > 0) {
Class<?> clazz = classLoader.loadClass(clazzName);
extensionClasses.put(name, clazz);
}
} catch (ClassNotFoundException e) {
log.error(e.getMessage());
}
}

}
} catch (IOException e) {
log.error(e.getMessage());
}
}

参考 Dubbo 的 SPI 机制:

dubbo的SPI机制

可以看出其实已经省略了很多的功能,如依赖注入、包装、Adaptive等

请求流程

客户端

从之前的动态代理部分,我们已经知道,从请求到返回响应结果这部分,实际上都是代理类在做,也就是 RpcClientProxy#invoke()

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
private final RpcRequestTransport transport;
private final RpcServiceConfig config;
public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
log.info("invoke method: [{}]", method.getName());
// 构建请求
RpcRequest request = RpcRequest.builder()
.methodName(method.getName())
.interfaceName(method.getDeclaringClass().getName())
.paramTypes(method.getParameterTypes())
.parameters(args)
.requestId(UUID.randomUUID().toString())
.group(config.getGroup())
.version(config.getVersion())
.build();
// 发送请求
CompletableFuture<RpcResponse<Object>> future = (CompletableFuture<RpcResponse<Object>>) transport.sendRequest(request);
// 获取响应
RpcResponse<Object> response = future.get();
// 检查请求和响应的合理
this.check(response, request);
return response.getData();
}

我们看 RpcRequestTransport 接口的实现类 RpcClient 的 sendRequest() 方法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
public Object sendRequest(RpcRequest rpcRequest) {
// 用于接收返回结果
CompletableFuture<RpcResponse<Object>> resultFuture = new CompletableFuture<>();
// 选择服务器地址(发现服务)
InetSocketAddress address = serviceDiscovery.lookupService(rpcRequest);
Channel channel = getChannel(address);
if (channel.isActive()) {
unprocessedRequests.put(rpcRequest.getRequestId(), resultFuture);
RpcMessage rpcMessage = RpcMessage.builder().data(rpcRequest)
.codec(SerializationTypeEnum.KYRO.getCode())
.compress(CompressTypeEnum.GZIP.getCode())
.messageType(RpcConstant.REQUEST_TYPE).build();
channel.writeAndFlush(rpcMessage).addListener((ChannelFutureListener) future -> {
if (future.isSuccess()) {
log.info("client send message: [{}]", rpcMessage);
} else {
future.channel().close();
resultFuture.completeExceptionally(future.cause());
log.error("Send failed:", future.cause());
}
});
} else {
throw new IllegalStateException();
}
return resultFuture;
}

public Channel getChannel(InetSocketAddress address) {
Channel channel = channelProvider.get(address);
if (channel == null) {
channel = doConnect(address);
channelProvider.set(address, channel);
}
return channel;
}

@SneakyThrows
public Channel doConnect(InetSocketAddress inetSocketAddress) {
ChannelFuture channel = bootstrap.connect(inetSocketAddress);
channel.sync();
return channel.channel();
}

其主要做了这几件事:

  1. 根据负载均衡策略获取服务地址
  2. 根据服务地址获取该地址的 Channel
  3. 发送协议并返回一个异步结果

这里采用硬编码的方式来实例化协议对象,后续有待改进

那么我们看发送协议的部分

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
private final ServiceDiscovery serviceDiscovery;
private final UnprocessedRequests unprocessedRequests;
private final ChannelProvider channelProvider;
private final Bootstrap bootstrap;
private final EventLoopGroup group;
public RpcClient() {
bootstrap = new Bootstrap();
group = new NioEventLoopGroup();
bootstrap.group(group)
.channel(NioSocketChannel.class)
// 设置连接超时时间
.option(ChannelOption.CONNECT_TIMEOUT_MILLIS, 5000)
.handler(new ChannelInitializer<NioSocketChannel>() {
@Override
protected void initChannel(NioSocketChannel ch) throws Exception {
//15秒内没有请求发出 则会发送心跳请求
ch.pipeline().addLast(new IdleStateHandler(0, 5, 0, TimeUnit.SECONDS));
ch.pipeline().addLast(new RpcMessageEncoder());
ch.pipeline().addLast(new RpcMessageDecoder());
ch.pipeline().addLast(new RpcClientHandler());
}
});
this.unprocessedRequests = SingletonFactory.getInstance(UnprocessedRequests.class);
this.channelProvider = SingletonFactory.getInstance(ChannelProvider.class);
this.serviceDiscovery = ExtensionLoader.getExtensionLoader(ServiceDiscovery.class).getExtension("zk");
}

可以看到我们的客户端启动类添加了4个通道处理器,其作用分别为 心跳包处理器、协议编码器、协议解码器、结果处理器

前三个不做过多的介绍,在这里我们看第四个处理器

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
public class RpcClientHandler extends ChannelInboundHandlerAdapter {

private final UnprocessedRequests unprocessedRequests;
private final RpcClient client;

public RpcClientHandler() {
this.unprocessedRequests = SingletonFactory.getInstance(UnprocessedRequests.class);
this.client = SingletonFactory.getInstance(RpcClient.class);
}

@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
try {
log.info("client receive msg: [{}]", msg);
if (msg instanceof RpcMessage) {
RpcMessage tmp = (RpcMessage) msg;
byte messageType = tmp.getMessageType();
if (messageType == RpcConstant.HEARTBEAT_RESPONSE_TYPE) {
log.info("heart [{}]", tmp.getData());
} else if (messageType == RpcConstant.RESPONSE_TYPE) {
RpcResponse<Object> rpcResponse = (RpcResponse<Object>) tmp.getData();
log.info("response: {}", rpcResponse);
unprocessedRequests.complete(rpcResponse);
}
}
} finally {
ReferenceCountUtil.release(msg);
}
}

@Override
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
if (evt instanceof IdleStateEvent) {
IdleState state = ((IdleStateEvent) evt).state();
// 未在指定时间内向服务器发送数据 此时发送心跳包
if (state == IdleState.WRITER_IDLE) {
log.info("write idle happen [{}]", ctx.channel().remoteAddress());
Channel channel = client.getChannel((InetSocketAddress) ctx.channel().remoteAddress());
RpcMessage rpcMessage = new RpcMessage();
rpcMessage.setCodec(SerializationTypeEnum.KYRO.getCode());
rpcMessage.setCompress(CompressTypeEnum.GZIP.getCode());
rpcMessage.setMessageType(RpcConstant.HEARTBEAT_REQUEST_TYPE);
rpcMessage.setData(RpcConstant.PING);
channel.writeAndFlush(rpcMessage).addListener(ChannelFutureListener.CLOSE_ON_FAILURE);
}
} else {
super.userEventTriggered(ctx, evt);
}
}

/**
* Called when an exception occurs in processing a client message
*/
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
log.error("client catch exception:", cause);
cause.printStackTrace();
ctx.close();
}
}

在此主要关注的是 #channelRead() 方法,其主要作用是处理服务端的响应结果

在通过前面解码器的处理后,最后发送到该处理器的内容实际上就是我们自定义的协议

可以看到,在之前发送请求和后面处理服务器的返回结果时,我们都用到了 UnprocessedRequests 这个类,那么它做了什么呢

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
public class UnprocessedRequests {
/**
* key: requestId
* value: 异步处理 response 返回的 future
*/
private static final Map<String, CompletableFuture<RpcResponse<Object>>> UNPROCESSED_RESPONSE_FUTURES = new ConcurrentHashMap<>();

public void put(String requestId, CompletableFuture<RpcResponse<Object>> future) {
UNPROCESSED_RESPONSE_FUTURES.put(requestId, future);
}

public void complete(RpcResponse<Object> response) {
CompletableFuture<RpcResponse<Object>> future = UNPROCESSED_RESPONSE_FUTURES.remove(response.getRequestId());
if (future != null) {
future.complete(response);
} else {
throw new IllegalStateException();
}
}
}

当发送请求时,我们会将 requestId 作为 Map 的 key,异步返回结果 CompletableFuture 作为 Map 的 value 存入。

当客户端接收到返回结果时,会将这对 key-value 组合从 Map 中移除,并设置异步结果完成

为什么要使用 CompletableFuture ?主要原因有二

  1. 我们在 RpcClientProxy#invoke() 中使用 CompletableFuture#get() 时,第一时间会阻塞,直到获取到返回结果。而获取返回结果我们通过 UnprocessedRequests#complete() 中的 CompletableFuture#complete() 来实现。这样可以保证客户端能够在服务端返回结果后获取结果内容
  2. 我们在 RpcClientHandler 中获取的 RpcResponse 实际上是在 NIO 线程中获取的,而我们最后返回结果是需要在主线程中返回结果。因此,CompletableFuture 的作用也可以是主线程和 NIO 线程之间数据的传输。

服务端

接下来我们看看服务端的内容

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
@Component
public class RpcServer {
public static final int PORT = 9999;

//private final ServiceProvider serviceProvider = SingletonFactory.getInstance(ZkServiceProviderImpl.class);
//
//public void registerService(RpcServiceConfig rpcServiceConfig) {
// serviceProvider.publishService(rpcServiceConfig);
//}

@SneakyThrows
public void start() {
// 清除zk服务
CustomShutdownHook.getCustomShutdownHook().clearAll();
// 指定handler所使用的线程池 如不指定则采用IO线程
DefaultEventExecutorGroup serviceHandlerGroup = new DefaultEventExecutorGroup(
RuntimeUtil.cpus() * 2,
ThreadPoolFactoryUtil.createThreadFactory("service-handler-group", false)
);

NioEventLoopGroup boss = new NioEventLoopGroup(1);
NioEventLoopGroup worker = new NioEventLoopGroup();

String host = InetAddress.getLocalHost().getHostAddress();

try {
ServerBootstrap bootstrap = new ServerBootstrap();
bootstrap.group(boss, worker)
.channel(NioServerSocketChannel.class)
.childHandler(new ChannelInitializer<NioSocketChannel>() {
@Override
protected void initChannel(NioSocketChannel ch) throws Exception {
// 心跳
ch.pipeline().addLast(new IdleStateHandler(30, 0, 0, TimeUnit.SECONDS));
// 协议编码器
ch.pipeline().addLast(new RpcMessageEncoder());
// 协议解码器
ch.pipeline().addLast(new RpcMessageDecoder());
// 处理请求方法并返回结果
ch.pipeline().addLast(serviceHandlerGroup, new RpcServerHandler());
}
});
ChannelFuture future = bootstrap.bind(host, PORT).sync();

future.channel().closeFuture().sync();
} catch (InterruptedException e) {
log.error("occur exception when start server:", e);
} finally {
log.error("shutdown bossGroup and workerGroup");
boss.shutdownGracefully();
worker.shutdownGracefully();
}
}
}

和之前一样,我们关注第四个处理器请求处理器

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
@Slf4j
public class RpcServerHandler extends ChannelInboundHandlerAdapter {

private final RpcRequestHandler rpcRequestHandler;

public RpcServerHandler() {
this.rpcRequestHandler = SingletonFactory.getInstance(RpcRequestHandler.class);
}


@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
try {
if (msg instanceof RpcMessage) {
log.info("server receive msg: [{}] ", msg);
byte messageType = ((RpcMessage) msg).getMessageType();
RpcMessage responseMessage = new RpcMessage();
responseMessage.setCodec(SerializationTypeEnum.KYRO.getCode());
responseMessage.setCompress(CompressTypeEnum.GZIP.getCode());

if (messageType == RpcConstant.HEARTBEAT_REQUEST_TYPE) {
// 如果时心跳包请求 则返回一个心跳包响应
responseMessage.setMessageType(RpcConstant.HEARTBEAT_RESPONSE_TYPE);
responseMessage.setData(RpcConstant.PONG);
} else {
// 方法请求
responseMessage.setMessageType(RpcConstant.RESPONSE_TYPE);
RpcRequest rpcRequest = (RpcRequest) ((RpcMessage) msg).getData();
// 获取服务来调用方法
Object result = rpcRequestHandler.handle(rpcRequest);

log.info(String.format("server get result: %s", result.toString()));

// 写回结果
if (ctx.channel().isActive() && ctx.channel().isWritable()) {
RpcResponse<Object> rpcResponse = RpcResponse.success(result, rpcRequest.getRequestId());
responseMessage.setData(rpcResponse);
} else {
RpcResponse<Object> rpcResponse = RpcResponse.fail(RpcResponseCodeEnum.FAIL);
responseMessage.setData(rpcResponse);
log.error("not writable now, message dropped");
}
}
// 发送响应
ctx.writeAndFlush(responseMessage).addListener(ChannelFutureListener.CLOSE_ON_FAILURE);
}
} finally {
// Ensure that ByteBuf is released, otherwise there may be memory leaks
// 释放ByteBuf
ReferenceCountUtil.release(msg);
}
}


@Override
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
if (evt instanceof IdleStateEvent) {
IdleState state = ((IdleStateEvent) evt).state();
if (state == IdleState.READER_IDLE) {
log.info("idle check happen, so close the connection");
ctx.close();
}
} else {
super.userEventTriggered(ctx, evt);
}
}

@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
log.error("server catch exception");
cause.printStackTrace();
ctx.close();
}
}

还是看 channelRead() 方法,主要做了以下几件事

  1. 解析客户端发送过来的协议,判断发送内容是否为心跳包
  2. 如果是心跳包则同样返回一个心跳包回去,否则调用服务来处理方法调用
  3. 将 response 作为消息体写入协议,并发送协议给客户端

看服务端的处理器时,可以看到我们为 RpcServerHandler 指定了一个线程池 serviceHandlerGroup,同样这个线程池在每次服务端启动时 CustomShutdownHook.getCustomShutdownHook().clearAll() 会被清理,网上说这是为了将处理过程交给自定义的线程池来处理,而不是 NIO 线程池,这里有啥好处暂时还不是很清楚。