Browse Source

Add a 'UserInfo' heterotyped dictionary (#1031)

Motivation:

Interceptors may be used to provide additional information to the
service provider. This could include information about an authenticated
user, for example. However, we don't currently have such a mechanism.

Modifications:

- Add a type-safe 'UserInfo' dictionary
- Add a 'UserInfo' requirement to the 'ServerCallContext' protocol
- Store a 'Ref<UserInfo>' in the base call handler and pipeline,
  exposing the 'UserInfo' in both the server call context and server
  interceptor context.

Result:

Interceptors can share information in a type-safe way with the service
provider.
George Barnett 5 years ago
parent
commit
224e4b3ef9

+ 1 - 0
Sources/GRPC/CallHandlers/BidirectionalStreamingCallHandler.swift

@@ -132,6 +132,7 @@ public class BidirectionalStreamingCallHandler<
         eventLoop: self.eventLoop,
         headers: headers,
         logger: self.logger,
+        userInfoRef: self.userInfoRef,
         sendResponse: self.sendResponse(_:metadata:promise:)
       )
       let observer = factory(context)

+ 2 - 1
Sources/GRPC/CallHandlers/ClientStreamingCallHandler.swift

@@ -132,7 +132,8 @@ public final class ClientStreamingCallHandler<
       let context = UnaryResponseCallContext<ResponsePayload>(
         eventLoop: self.eventLoop,
         headers: headers,
-        logger: self.logger
+        logger: self.logger,
+        userInfoRef: self.userInfoRef
       )
 
       let observer = factory(context)

+ 1 - 0
Sources/GRPC/CallHandlers/ServerStreamingCallHandler.swift

@@ -137,6 +137,7 @@ public final class ServerStreamingCallHandler<
         eventLoop: self.eventLoop,
         headers: headers,
         logger: self.logger,
+        userInfoRef: self.userInfoRef,
         sendResponse: self.sendResponse(_:metadata:promise:)
       )
       let observer = factory(context)

+ 2 - 1
Sources/GRPC/CallHandlers/UnaryCallHandler.swift

@@ -163,7 +163,8 @@ public final class UnaryCallHandler<
       let context = UnaryResponseCallContext<ResponsePayload>(
         eventLoop: self.eventLoop,
         headers: headers,
-        logger: self.logger
+        logger: self.logger,
+        userInfoRef: self.userInfoRef
       )
       let observer = factory(context)
 

+ 7 - 0
Sources/GRPC/CallHandlers/_BaseCallHandler.swift

@@ -56,20 +56,27 @@ public class _BaseCallHandler<Request, Response>: GRPCCallHandler, ChannelInboun
     return self.callHandlerContext.logger
   }
 
+  /// A reference to `UserInfo`.
+  internal var userInfoRef: Ref<UserInfo>
+
   internal init(
     callHandlerContext: CallHandlerContext,
     codec: ChannelHandler,
     callType: GRPCCallType,
     interceptors: [ServerInterceptor<Request, Response>]
   ) {
+    let userInfoRef = Ref(UserInfo())
+
     self.callHandlerContext = callHandlerContext
     self._codec = codec
     self.callType = callType
+    self.userInfoRef = userInfoRef
     self.pipeline = ServerInterceptorPipeline(
       logger: callHandlerContext.logger,
       eventLoop: callHandlerContext.eventLoop,
       path: callHandlerContext.path,
       callType: callType,
+      userInfoRef: userInfoRef,
       interceptors: interceptors,
       onRequestPart: self.receiveRequestPartFromInterceptors(_:),
       onResponsePart: self.sendResponsePartFromInterceptors(_:promise:)

+ 17 - 0
Sources/GRPC/Interceptor/ServerInterceptorContext.swift

@@ -56,6 +56,23 @@ public struct ServerInterceptorContext<Request, Response> {
     return self.pipeline.path
   }
 
+  /// A 'UserInfo' dictionary.
+  ///
+  /// - Important: While `UserInfo` has value-semantics, this property retrieves from, and sets a
+  ///   reference wrapped `UserInfo`. The contexts passed to the service provider share the same
+  ///   reference. As such this may be used as a mechanism to pass information between interceptors
+  ///   and service providers.
+  /// - Important: `userInfo` *must* be accessed from the context's `eventLoop` in order to ensure
+  ///   thread-safety.
+  public var userInfo: UserInfo {
+    get {
+      return self.pipeline.userInfoRef.value
+    }
+    nonmutating set {
+      self.pipeline.userInfoRef.value = newValue
+    }
+  }
+
   /// Construct a `ServerInterceptorContext` for the interceptor at the given index within the
   /// interceptor pipeline.
   internal init(

+ 5 - 0
Sources/GRPC/Interceptor/ServerInterceptorPipeline.swift

@@ -29,6 +29,9 @@ internal final class ServerInterceptorPipeline<Request, Response> {
   /// A logger.
   internal let logger: Logger
 
+  /// A reference to a 'UserInfo'.
+  internal let userInfoRef: Ref<UserInfo>
+
   /// The contexts associated with the interceptors stored in this pipeline. Contexts will be
   /// removed once the RPC has completed. Contexts are ordered from inbound to outbound, that is,
   /// the head is first and the tail is last.
@@ -80,6 +83,7 @@ internal final class ServerInterceptorPipeline<Request, Response> {
     eventLoop: EventLoop,
     path: String,
     callType: GRPCCallType,
+    userInfoRef: Ref<UserInfo>,
     interceptors: [ServerInterceptor<Request, Response>],
     onRequestPart: @escaping (GRPCServerRequestPart<Request>) -> Void,
     onResponsePart: @escaping (GRPCServerResponsePart<Response>, EventLoopPromise<Void>?) -> Void
@@ -88,6 +92,7 @@ internal final class ServerInterceptorPipeline<Request, Response> {
     self.eventLoop = eventLoop
     self.path = path
     self.type = callType
+    self.userInfoRef = userInfoRef
 
     // We need space for the head and tail as well as any user provided interceptors.
     var contexts: [ServerInterceptorContext<Request, Response>] = []

+ 22 - 0
Sources/GRPC/Ref.swift

@@ -0,0 +1,22 @@
+/*
+ * Copyright 2020, gRPC Authors All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+internal final class Ref<Value> {
+  internal var value: Value
+  internal init(_ value: Value) {
+    self.value = value
+  }
+}

+ 37 - 2
Sources/GRPC/ServerCallContexts/ServerCallContext.swift

@@ -28,6 +28,9 @@ public protocol ServerCallContext: AnyObject {
   /// Request headers for this request.
   var headers: HPACKHeaders { get }
 
+  /// A 'UserInfo' dictionary.
+  var userInfo: UserInfo { get set }
+
   /// The logger used for this call.
   var logger: Logger { get }
 
@@ -44,21 +47,53 @@ open class ServerCallContextBase: ServerCallContext {
   public let logger: Logger
   public var compressionEnabled: Bool = true
 
+  /// - Important: While `UserInfo` has value-semantics, this property retrieves from, and sets a
+  ///   reference wrapped `UserInfo`. The contexts passed to interceptors provide the same
+  ///   reference. As such this may be used as a mechanism to pass information between interceptors
+  ///   and service providers.
+  public var userInfo: UserInfo {
+    get {
+      return self.userInfoRef.value
+    }
+    set {
+      self.userInfoRef.value = newValue
+    }
+  }
+
+  /// A reference to an underlying `UserInfo`. We share this with the interceptors.
+  private let userInfoRef: Ref<UserInfo>
+
   /// Metadata to return at the end of the RPC. If this is required it should be updated before
   /// the `responsePromise` or `statusPromise` is fulfilled.
   public var trailers = HPACKHeaders()
 
-  public init(eventLoop: EventLoop, headers: HPACKHeaders, logger: Logger) {
+  public convenience init(
+    eventLoop: EventLoop,
+    headers: HPACKHeaders,
+    logger: Logger,
+    userInfo: UserInfo = UserInfo()
+  ) {
+    self.init(eventLoop: eventLoop, headers: headers, logger: logger, userInfoRef: .init(userInfo))
+  }
+
+  internal init(
+    eventLoop: EventLoop,
+    headers: HPACKHeaders,
+    logger: Logger,
+    userInfoRef: Ref<UserInfo>
+  ) {
     self.eventLoop = eventLoop
     self.headers = headers
+    self.userInfoRef = userInfoRef
     self.logger = logger
   }
 
-  @available(*, deprecated, renamed: "init(eventLoop:headers:logger:)")
+  @available(*, deprecated, renamed: "init(eventLoop:headers:logger:userInfo:)")
   public init(eventLoop: EventLoop, request: HTTPRequestHead, logger: Logger) {
     self.eventLoop = eventLoop
     self.headers = HPACKHeaders(httpHeaders: request.headers, normalizeHTTPHeaders: false)
     self.logger = logger
+    self.userInfoRef = .init(UserInfo())
   }
 
   /// Processes an error, transforming it into a 'GRPCStatus' and any trailers to send to the peer.

+ 25 - 5
Sources/GRPC/ServerCallContexts/StreamingResponseCallContext.swift

@@ -31,12 +31,26 @@ open class StreamingResponseCallContext<ResponsePayload>: ServerCallContextBase
 
   public let statusPromise: EventLoopPromise<GRPCStatus>
 
-  override public init(eventLoop: EventLoop, headers: HPACKHeaders, logger: Logger) {
+  public convenience init(
+    eventLoop: EventLoop,
+    headers: HPACKHeaders,
+    logger: Logger,
+    userInfo: UserInfo = UserInfo()
+  ) {
+    self.init(eventLoop: eventLoop, headers: headers, logger: logger, userInfoRef: .init(userInfo))
+  }
+
+  override internal init(
+    eventLoop: EventLoop,
+    headers: HPACKHeaders,
+    logger: Logger,
+    userInfoRef: Ref<UserInfo>
+  ) {
     self.statusPromise = eventLoop.makePromise()
-    super.init(eventLoop: eventLoop, headers: headers, logger: logger)
+    super.init(eventLoop: eventLoop, headers: headers, logger: logger, userInfoRef: userInfoRef)
   }
 
-  @available(*, deprecated, renamed: "init(eventLoop:path:headers:logger:)")
+  @available(*, deprecated, renamed: "init(eventLoop:path:headers:logger:userInfo:)")
   override public init(eventLoop: EventLoop, request: HTTPRequestHead, logger: Logger) {
     self.statusPromise = eventLoop.makePromise()
     super.init(eventLoop: eventLoop, request: request, logger: logger)
@@ -113,10 +127,11 @@ internal final class _StreamingResponseCallContext<Request, Response>:
     eventLoop: EventLoop,
     headers: HPACKHeaders,
     logger: Logger,
+    userInfoRef: Ref<UserInfo>,
     sendResponse: @escaping (Response, MessageMetadata, EventLoopPromise<Void>?) -> Void
   ) {
     self._sendResponse = sendResponse
-    super.init(eventLoop: eventLoop, headers: headers, logger: logger)
+    super.init(eventLoop: eventLoop, headers: headers, logger: logger, userInfoRef: userInfoRef)
   }
 
   override func sendResponse(
@@ -165,7 +180,12 @@ open class StreamingResponseCallContextImpl<ResponsePayload>: StreamingResponseC
     logger: Logger
   ) {
     self.channel = channel
-    super.init(eventLoop: channel.eventLoop, headers: headers, logger: logger)
+    super.init(
+      eventLoop: channel.eventLoop,
+      headers: headers,
+      logger: logger,
+      userInfoRef: Ref(UserInfo())
+    )
 
     self.statusPromise.futureResult.whenComplete { result in
       switch result {

+ 23 - 4
Sources/GRPC/ServerCallContexts/UnaryResponseCallContext.swift

@@ -35,12 +35,26 @@ open class UnaryResponseCallContext<ResponsePayload>: ServerCallContextBase, Sta
   public let responsePromise: EventLoopPromise<ResponsePayload>
   public var responseStatus: GRPCStatus = .ok
 
-  override public init(eventLoop: EventLoop, headers: HPACKHeaders, logger: Logger) {
+  public convenience init(
+    eventLoop: EventLoop,
+    headers: HPACKHeaders,
+    logger: Logger,
+    userInfo: UserInfo = UserInfo()
+  ) {
+    self.init(eventLoop: eventLoop, headers: headers, logger: logger, userInfoRef: .init(userInfo))
+  }
+
+  override internal init(
+    eventLoop: EventLoop,
+    headers: HPACKHeaders,
+    logger: Logger,
+    userInfoRef: Ref<UserInfo>
+  ) {
     self.responsePromise = eventLoop.makePromise()
-    super.init(eventLoop: eventLoop, headers: headers, logger: logger)
+    super.init(eventLoop: eventLoop, headers: headers, logger: logger, userInfoRef: userInfoRef)
   }
 
-  @available(*, deprecated, renamed: "init(eventLoop:headers:logger:)")
+  @available(*, deprecated, renamed: "init(eventLoop:headers:logger:userInfo:)")
   override public init(eventLoop: EventLoop, request: HTTPRequestHead, logger: Logger) {
     self.responsePromise = eventLoop.makePromise()
     super.init(eventLoop: eventLoop, request: request, logger: logger)
@@ -90,7 +104,12 @@ open class UnaryResponseCallContextImpl<ResponsePayload>: UnaryResponseCallConte
     logger: Logger
   ) {
     self.channel = channel
-    super.init(eventLoop: channel.eventLoop, headers: headers, logger: logger)
+    super.init(
+      eventLoop: channel.eventLoop,
+      headers: headers,
+      logger: logger,
+      userInfoRef: .init(UserInfo())
+    )
 
     self.responsePromise.futureResult.whenComplete { [self, weak errorDelegate] result in
       switch result {

+ 106 - 0
Sources/GRPC/UserInfo.swift

@@ -0,0 +1,106 @@
+/*
+ * Copyright 2020, gRPC Authors All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/// `UserInfo` is a dictionary for heterogeneously typed values with type safe access to the stored
+/// values.
+///
+/// Values are keyed by a type conforming to the `UserInfo.Key` protocol. The protocol requires an
+/// `associatedtype`: the type of the value the key is paired with. A key can be created using a
+/// caseless `enum`, for example:
+///
+/// ```
+/// enum IDKey: UserInfo.Key {
+///   typealias Value = Int
+/// }
+/// ```
+///
+/// Values can be set and retrieved from `UserInfo` by subscripting with the key:
+///
+/// ```
+/// userInfo[IDKey.self] = 42
+/// let id = userInfo[IDKey.self]  // id = 42
+///
+/// userInfo[IDKey.self] = nil
+/// ```
+///
+/// More convenient access can be provided with helper extensions on `UserInfo`:
+///
+/// ```
+/// extension UserInfo {
+///   var id: IDKey.Value? {
+///     get { self[IDKey.self] }
+///     set { self[IDKey.self] = newValue }
+///   }
+/// }
+/// ```
+public struct UserInfo: CustomStringConvertible {
+  private var storage: [AnyUserInfoKey: Any]
+
+  /// A protocol for a key.
+  public typealias Key = UserInfoKey
+
+  /// Create an empty 'UserInfo'.
+  public init() {
+    self.storage = [:]
+  }
+
+  /// Allows values to be set and retrieved in a type safe way.
+  public subscript<Key: UserInfoKey>(key: Key.Type) -> Key.Value? {
+    get {
+      if let anyValue = self.storage[AnyUserInfoKey(key)] {
+        // The types must line up here.
+        return (anyValue as! Key.Value)
+      } else {
+        return nil
+      }
+    }
+    set {
+      self.storage[AnyUserInfoKey(key)] = newValue
+    }
+  }
+
+  public var description: String {
+    return "[" + self.storage.map { key, value in
+      "\(key): \(value)"
+    }.joined(separator: ", ") + "]"
+  }
+
+  /// A `UserInfoKey` wrapper.
+  private struct AnyUserInfoKey: Hashable, CustomStringConvertible {
+    private let keyType: Any.Type
+
+    var description: String {
+      return String(describing: self.keyType.self)
+    }
+
+    init<Key: UserInfoKey>(_ keyType: Key.Type) {
+      self.keyType = keyType
+    }
+
+    static func == (lhs: AnyUserInfoKey, rhs: AnyUserInfoKey) -> Bool {
+      return ObjectIdentifier(lhs.keyType) == ObjectIdentifier(rhs.keyType)
+    }
+
+    func hash(into hasher: inout Hasher) {
+      hasher.combine(ObjectIdentifier(self.keyType))
+    }
+  }
+}
+
+public protocol UserInfoKey {
+  /// The type of identified by this key.
+  associatedtype Value
+}

+ 20 - 1
Tests/GRPCTests/InterceptorsTests.swift

@@ -137,6 +137,9 @@ class HelloWorldProvider: Helloworld_GreeterProvider {
     request: Helloworld_HelloRequest,
     context: StatusOnlyCallContext
   ) -> EventLoopFuture<Helloworld_HelloReply> {
+    // Since we're auth'd, the 'userInfo' should have some magic set.
+    assertThat(context.userInfo.magic, .is("Magic"))
+
     let response = Helloworld_HelloReply.with {
       $0.message = "Hello, \(request.name), you're authorized!"
     }
@@ -166,7 +169,8 @@ class NotReallyAuthServerInterceptor<Request: Message, Response: Message>:
   ) {
     switch part {
     case let .metadata(headers):
-      if headers.first(name: "authorization") == "Magic" {
+      if let auth = headers.first(name: "authorization"), auth == "Magic" {
+        context.userInfo.magic = auth
         context.receive(part)
       } else {
         // Not auth'd. Fail the RPC.
@@ -340,3 +344,18 @@ private class ReversingInterceptors: Echo_EchoClientInterceptorFactoryProtocol {
     return self.interceptors
   }
 }
+
+private enum MagicKey: UserInfo.Key {
+  typealias Value = String
+}
+
+extension UserInfo {
+  fileprivate var magic: MagicKey.Value? {
+    get {
+      return self[MagicKey.self]
+    }
+    set {
+      self[MagicKey.self] = newValue
+    }
+  }
+}

+ 1 - 0
Tests/GRPCTests/ServerInterceptorPipelineTests.swift

@@ -40,6 +40,7 @@ class ServerInterceptorPipelineTests: GRPCTestCase {
       eventLoop: self.embeddedEventLoop,
       path: path,
       callType: callType,
+      userInfoRef: Ref(UserInfo()),
       interceptors: interceptors,
       onRequestPart: onRequestPart,
       onResponsePart: onResponsePart

+ 87 - 0
Tests/GRPCTests/UserInfoTests.swift

@@ -0,0 +1,87 @@
+/*
+ * Copyright 2020, gRPC Authors All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+import GRPC
+
+class UserInfoTests: GRPCTestCase {
+  func testWithSubscript() {
+    var userInfo = UserInfo()
+
+    userInfo[FooKey.self] = "foo"
+    assertThat(userInfo[FooKey.self], .is("foo"))
+
+    userInfo[BarKey.self] = 42
+    assertThat(userInfo[BarKey.self], .is(42))
+
+    userInfo[FooKey.self] = nil
+    assertThat(userInfo[FooKey.self], .is(.nil()))
+
+    userInfo[BarKey.self] = nil
+    assertThat(userInfo[BarKey.self], .is(.nil()))
+  }
+
+  func testWithExtensions() {
+    var userInfo = UserInfo()
+
+    userInfo.foo = "foo"
+    assertThat(userInfo.foo, .is("foo"))
+
+    userInfo.bar = 42
+    assertThat(userInfo.bar, .is(42))
+
+    userInfo.foo = nil
+    assertThat(userInfo.foo, .is(.nil()))
+
+    userInfo.bar = nil
+    assertThat(userInfo.bar, .is(.nil()))
+  }
+
+  func testDescription() {
+    var userInfo = UserInfo()
+    assertThat(String(describing: userInfo), .is("[]"))
+
+    // (We can't test with multiple values since ordering isn't stable.)
+    userInfo.foo = "foo"
+    assertThat(String(describing: userInfo), .is("[FooKey: foo]"))
+  }
+}
+
+private enum FooKey: UserInfoKey {
+  typealias Value = String
+}
+
+private enum BarKey: UserInfoKey {
+  typealias Value = Int
+}
+
+extension UserInfo {
+  fileprivate var foo: FooKey.Value? {
+    get {
+      return self[FooKey.self]
+    }
+    set {
+      self[FooKey.self] = newValue
+    }
+  }
+
+  fileprivate var bar: BarKey.Value? {
+    get {
+      return self[BarKey.self]
+    }
+    set {
+      self[BarKey.self] = newValue
+    }
+  }
+}

+ 12 - 0
Tests/GRPCTests/XCTestManifests.swift

@@ -1039,6 +1039,17 @@ extension TimeLimitTests {
     ]
 }
 
+extension UserInfoTests {
+    // DO NOT MODIFY: This is autogenerated, use:
+    //   `swift test --generate-linuxmain`
+    // to regenerate.
+    static let __allTests__UserInfoTests = [
+        ("testDescription", testDescription),
+        ("testWithExtensions", testWithExtensions),
+        ("testWithSubscript", testWithSubscript),
+    ]
+}
+
 extension ZeroLengthWriteTests {
     // DO NOT MODIFY: This is autogenerated, use:
     //   `swift test --generate-linuxmain`
@@ -1138,6 +1149,7 @@ public func __allTests() -> [XCTestCaseEntry] {
         testCase(StopwatchTests.__allTests__StopwatchTests),
         testCase(StreamingRequestClientCallTests.__allTests__StreamingRequestClientCallTests),
         testCase(TimeLimitTests.__allTests__TimeLimitTests),
+        testCase(UserInfoTests.__allTests__UserInfoTests),
         testCase(ZeroLengthWriteTests.__allTests__ZeroLengthWriteTests),
         testCase(ZlibTests.__allTests__ZlibTests),
     ]