Browse Source

Make the remote address available on the server interceptor context (#1081)

Motivation:

Sometimes, if you're an interceptor, it's useful to know who you're
talking too.

Modifications:

- Make the 'remoteAddress' available on the 'ServerInterceptorContext'

Result:

Users can access the remote address from server interceptors.
George Barnett 5 years ago
parent
commit
976a14859e

+ 1 - 0
Sources/GRPC/CallHandlers/_BaseCallHandler.swift

@@ -102,6 +102,7 @@ public class _BaseCallHandler<
       eventLoop: callHandlerContext.eventLoop,
       path: callHandlerContext.path,
       callType: callType,
+      remoteAddress: callHandlerContext.remoteAddress,
       userInfoRef: userInfoRef,
       interceptors: interceptors,
       onRequestPart: self._receiveRequestPartFromInterceptors(_:),

+ 2 - 0
Sources/GRPC/GRPCServerRequestRoutingHandler.swift

@@ -50,6 +50,8 @@ public struct CallHandlerContext {
   internal var eventLoop: EventLoop
   @usableFromInline
   internal var path: String
+  @usableFromInline
+  internal var remoteAddress: SocketAddress?
 }
 
 /// A call URI split into components.

+ 1 - 0
Sources/GRPC/HTTP2ToRawGRPCServerCodec.swift

@@ -111,6 +111,7 @@ internal final class HTTP2ToRawGRPCServerCodec: ChannelDuplexHandler {
         headers: payload.headers,
         eventLoop: context.eventLoop,
         errorDelegate: self.errorDelegate,
+        remoteAddress: context.channel.remoteAddress,
         logger: self.logger
       )
       self.act(on: action, with: context)

+ 7 - 1
Sources/GRPC/HTTP2ToRawGRPCStateMachine.swift

@@ -278,6 +278,7 @@ extension HTTP2ToRawGRPCStateMachine.RequestIdleResponseIdleState {
     headers: HPACKHeaders,
     eventLoop: EventLoop,
     errorDelegate: ServerErrorDelegate?,
+    remoteAddress: SocketAddress?,
     logger: Logger
   ) -> HTTP2ToRawGRPCStateMachine.StateAndAction {
     // Extract and validate the content type. If it's nil we need to close.
@@ -323,7 +324,8 @@ extension HTTP2ToRawGRPCStateMachine.RequestIdleResponseIdleState {
       logger: logger,
       encoding: self.encoding,
       eventLoop: eventLoop,
-      path: path
+      path: path,
+      remoteAddress: remoteAddress
     )
 
     // We have a matching service, hopefully we have a provider for the method too.
@@ -883,6 +885,7 @@ extension HTTP2ToRawGRPCStateMachine {
     headers: HPACKHeaders,
     eventLoop: EventLoop,
     errorDelegate: ServerErrorDelegate?,
+    remoteAddress: SocketAddress?,
     logger: Logger
   ) -> Action {
     return self.withStateAvoidingCoWs { state in
@@ -890,6 +893,7 @@ extension HTTP2ToRawGRPCStateMachine {
         headers: headers,
         eventLoop: eventLoop,
         errorDelegate: errorDelegate,
+        remoteAddress: remoteAddress,
         logger: logger
       )
     }
@@ -977,6 +981,7 @@ extension HTTP2ToRawGRPCStateMachine.State {
     headers: HPACKHeaders,
     eventLoop: EventLoop,
     errorDelegate: ServerErrorDelegate?,
+    remoteAddress: SocketAddress?,
     logger: Logger
   ) -> HTTP2ToRawGRPCStateMachine.Action {
     switch self {
@@ -986,6 +991,7 @@ extension HTTP2ToRawGRPCStateMachine.State {
         headers: headers,
         eventLoop: eventLoop,
         errorDelegate: errorDelegate,
+        remoteAddress: remoteAddress,
         logger: logger
       )
       self = stateAndAction.state

+ 5 - 0
Sources/GRPC/Interceptor/ServerInterceptorContext.swift

@@ -61,6 +61,11 @@ public struct ServerInterceptorContext<Request, Response> {
     return self._pipeline.path
   }
 
+  /// The address of the remote peer.
+  public var remoteAddress: SocketAddress? {
+    return self._pipeline.remoteAddress
+  }
+
   /// A 'UserInfo' dictionary.
   ///
   /// - Important: While `UserInfo` has value-semantics, this property retrieves from, and sets a

+ 6 - 0
Sources/GRPC/Interceptor/ServerInterceptorPipeline.swift

@@ -30,6 +30,10 @@ internal final class ServerInterceptorPipeline<Request, Response> {
   @usableFromInline
   internal let type: GRPCCallType
 
+  /// The remote peer's address.
+  @usableFromInline
+  internal let remoteAddress: SocketAddress?
+
   /// A logger.
   @usableFromInline
   internal let logger: Logger
@@ -96,6 +100,7 @@ internal final class ServerInterceptorPipeline<Request, Response> {
     eventLoop: EventLoop,
     path: String,
     callType: GRPCCallType,
+    remoteAddress: SocketAddress?,
     userInfoRef: Ref<UserInfo>,
     interceptors: [ServerInterceptor<Request, Response>],
     onRequestPart: @escaping (GRPCServerRequestPart<Request>) -> Void,
@@ -105,6 +110,7 @@ internal final class ServerInterceptorPipeline<Request, Response> {
     self.eventLoop = eventLoop
     self.path = path
     self.type = callType
+    self.remoteAddress = remoteAddress
     self.userInfoRef = userInfoRef
 
     // We need space for the head and tail as well as any user provided interceptors.

+ 14 - 1
Tests/GRPCTests/HTTP2ToRawGRPCStateMachineTests.swift

@@ -100,6 +100,7 @@ class HTTP2ToRawGRPCStateMachineTests: GRPCTestCase {
       headers: self.viableHeaders,
       eventLoop: self.eventLoop,
       errorDelegate: nil,
+      remoteAddress: nil,
       logger: self.logger
     )
 
@@ -169,6 +170,7 @@ class HTTP2ToRawGRPCStateMachineTests: GRPCTestCase {
       headers: self.viableHeaders,
       eventLoop: self.eventLoop,
       errorDelegate: nil,
+      remoteAddress: nil,
       logger: self.logger
     )
     assertThat(action, .is(.configure()))
@@ -179,7 +181,9 @@ class HTTP2ToRawGRPCStateMachineTests: GRPCTestCase {
     let action = machine.receive(
       headers: self.makeHeaders(contentType: "application/json"),
       eventLoop: self.eventLoop,
-      errorDelegate: nil, logger: self.logger
+      errorDelegate: nil,
+      remoteAddress: nil,
+      logger: self.logger
     )
     assertThat(
       action,
@@ -193,6 +197,7 @@ class HTTP2ToRawGRPCStateMachineTests: GRPCTestCase {
       headers: self.makeHeaders(path: "/foo.Foo/Get"),
       eventLoop: self.eventLoop,
       errorDelegate: nil,
+      remoteAddress: nil,
       logger: self.logger
     )
     assertThat(action, .is(.write(.trailersOnly(code: .unimplemented), flush: true)))
@@ -204,6 +209,7 @@ class HTTP2ToRawGRPCStateMachineTests: GRPCTestCase {
       headers: self.makeHeaders(path: "/echo.Echo/Foo"),
       eventLoop: self.eventLoop,
       errorDelegate: nil,
+      remoteAddress: nil,
       logger: self.logger
     )
     assertThat(action, .is(.write(.trailersOnly(code: .unimplemented), flush: true)))
@@ -215,6 +221,7 @@ class HTTP2ToRawGRPCStateMachineTests: GRPCTestCase {
       headers: self.makeHeaders(path: "nope"),
       eventLoop: self.eventLoop,
       errorDelegate: nil,
+      remoteAddress: nil,
       logger: self.logger
     )
     assertThat(action, .is(.write(.trailersOnly(code: .unimplemented), flush: true)))
@@ -226,6 +233,7 @@ class HTTP2ToRawGRPCStateMachineTests: GRPCTestCase {
       headers: self.makeHeaders(encoding: .gzip),
       eventLoop: self.eventLoop,
       errorDelegate: nil,
+      remoteAddress: nil,
       logger: self.logger
     )
     assertThat(action, .is(.write(.trailersOnly(code: .unimplemented), flush: true)))
@@ -238,6 +246,7 @@ class HTTP2ToRawGRPCStateMachineTests: GRPCTestCase {
       headers: self.makeHeaders(contentType: "application/grpc", encoding: "gzip,identity"),
       eventLoop: self.eventLoop,
       errorDelegate: nil,
+      remoteAddress: nil,
       logger: self.logger
     )
     assertThat(action, .is(.write(.trailersOnly(code: .invalidArgument), flush: true)))
@@ -250,6 +259,7 @@ class HTTP2ToRawGRPCStateMachineTests: GRPCTestCase {
       headers: self.makeHeaders(contentType: "application/grpc", encoding: "foozip"),
       eventLoop: self.eventLoop,
       errorDelegate: nil,
+      remoteAddress: nil,
       logger: self.logger
     )
 
@@ -268,6 +278,7 @@ class HTTP2ToRawGRPCStateMachineTests: GRPCTestCase {
       headers: self.makeHeaders(encoding: .gzip),
       eventLoop: self.eventLoop,
       errorDelegate: nil,
+      remoteAddress: nil,
       logger: self.logger
     )
 
@@ -289,6 +300,7 @@ class HTTP2ToRawGRPCStateMachineTests: GRPCTestCase {
       headers: self.makeHeaders(encoding: .identity),
       eventLoop: self.eventLoop,
       errorDelegate: nil,
+      remoteAddress: nil,
       logger: self.logger
     )
 
@@ -302,6 +314,7 @@ class HTTP2ToRawGRPCStateMachineTests: GRPCTestCase {
       headers: self.makeHeaders(acceptEncoding: [.deflate]),
       eventLoop: self.eventLoop,
       errorDelegate: nil,
+      remoteAddress: nil,
       logger: self.logger
     )
 

+ 11 - 1
Tests/GRPCTests/InterceptorsTests.swift

@@ -161,6 +161,16 @@ private class HelloWorldClientInterceptorFactory:
   }
 }
 
+class RemoteAddressExistsInterceptor<Request, Response>: ServerInterceptor<Request, Response> {
+  override func receive(
+    _ part: GRPCServerRequestPart<Request>,
+    context: ServerInterceptorContext<Request, Response>
+  ) {
+    XCTAssertNotNil(context.remoteAddress)
+    super.receive(part, context: context)
+  }
+}
+
 class NotReallyAuthServerInterceptor<Request: Message, Response: Message>:
   ServerInterceptor<Request, Response> {
   override func receive(
@@ -188,7 +198,7 @@ class NotReallyAuthServerInterceptor<Request: Message, Response: Message>:
 class HelloWorldServerInterceptorFactory: Helloworld_GreeterServerInterceptorFactoryProtocol {
   func makeSayHelloInterceptors(
   ) -> [ServerInterceptor<Helloworld_HelloRequest, Helloworld_HelloReply>] {
-    return [NotReallyAuthServerInterceptor()]
+    return [RemoteAddressExistsInterceptor(), NotReallyAuthServerInterceptor()]
   }
 }
 

+ 1 - 0
Tests/GRPCTests/ServerInterceptorPipelineTests.swift

@@ -40,6 +40,7 @@ class ServerInterceptorPipelineTests: GRPCTestCase {
       eventLoop: self.embeddedEventLoop,
       path: path,
       callType: callType,
+      remoteAddress: nil,
       userInfoRef: Ref(UserInfo()),
       interceptors: interceptors,
       onRequestPart: onRequestPart,