Browse Source

Convert errors thrown from interceptors (#2209)

Motivation:

gRPC checks whether errors thrown from interceptors are `RPCError` and
otherwise treats them as `unknown` (to avoid leaking internal
information). There is a third possibility: the error is explicitly
marked as being convertible to an `RPCError`. This check is currently
missing when thrown from client/server interceptors.

Modifications:

- Catch `RPCErrorConvertible` in the client/server executors when thrown
from interceptors
- Add tests

Result:

Error information isn't dropped
George Barnett 10 months ago
parent
commit
c4d6281784

+ 2 - 0
Sources/GRPCCore/Call/Client/Internal/ClientRPCExecutor.swift

@@ -186,6 +186,8 @@ extension ClientRPCExecutor {
         }
       } catch let error as RPCError {
         return StreamingClientResponse(error: error)
+      } catch let error as RPCErrorConvertible {
+        return StreamingClientResponse(error: RPCError(error))
       } catch let other {
         let error = RPCError(code: .unknown, message: "", cause: other)
         return StreamingClientResponse(error: error)

+ 2 - 0
Sources/GRPCCore/Call/Server/Internal/ServerRPCExecutor.swift

@@ -330,6 +330,8 @@ extension ServerRPCExecutor {
         }
       } catch let error as RPCError {
         return StreamingServerResponse(error: error)
+      } catch let error as RPCErrorConvertible {
+        return StreamingServerResponse(error: RPCError(error))
       } catch let other {
         let error = RPCError(code: .unknown, message: "", cause: other)
         return StreamingServerResponse(error: error)

+ 8 - 2
Tests/GRPCCoreTests/Call/Client/Internal/ClientRPCExecutorTestSupport/ClientRPCExecutorTestHarness.swift

@@ -29,6 +29,7 @@ struct ClientRPCExecutorTestHarness {
   private let server: ServerStreamHandler
   private let clientTransport: StreamCountingClientTransport
   private let serverTransport: StreamCountingServerTransport
+  private let interceptors: [any ClientInterceptor]
 
   var clientStreamsOpened: Int {
     self.clientTransport.streamsOpened
@@ -42,8 +43,13 @@ struct ClientRPCExecutorTestHarness {
     self.serverTransport.acceptedStreamsCount
   }
 
-  init(transport: Transport = .inProcess, server: ServerStreamHandler) {
+  init(
+    transport: Transport = .inProcess,
+    server: ServerStreamHandler,
+    interceptors: [any ClientInterceptor] = []
+  ) {
     self.server = server
+    self.interceptors = interceptors
 
     switch transport {
     case .inProcess:
@@ -141,7 +147,7 @@ struct ClientRPCExecutorTestHarness {
         serializer: serializer,
         deserializer: deserializer,
         transport: self.clientTransport,
-        interceptors: [],
+        interceptors: self.interceptors,
         handler: handler
       )
 

+ 21 - 0
Tests/GRPCCoreTests/Call/Client/Internal/ClientRPCExecutorTests.swift

@@ -268,4 +268,25 @@ final class ClientRPCExecutorTests: XCTestCase {
       }
     }
   }
+
+  func testInterceptorErrorConversion() async throws {
+    struct CustomError: RPCErrorConvertible, Error {
+      var rpcErrorCode: RPCError.Code { .alreadyExists }
+      var rpcErrorMessage: String { "foobar" }
+      var rpcErrorMetadata: Metadata { ["error": "yes"] }
+    }
+
+    let tester = ClientRPCExecutorTestHarness(
+      server: .echo,
+      interceptors: [.throwError(CustomError())]
+    )
+
+    try await tester.unary(request: ClientRequest(message: [])) { response in
+      XCTAssertThrowsError(ofType: RPCError.self, try response.message) { error in
+        XCTAssertEqual(error.code, .alreadyExists)
+        XCTAssertEqual(error.message, "foobar")
+        XCTAssertEqual(error.metadata, ["error": "yes"])
+      }
+    }
+  }
 }

+ 19 - 0
Tests/GRPCCoreTests/Call/Server/Internal/ServerRPCExecutorTests.swift

@@ -374,4 +374,23 @@ final class ServerRPCExecutorTests: XCTestCase {
       )
     }
   }
+
+  func testInterceptorErrorConversion() async throws {
+    struct CustomError: RPCErrorConvertible, Error {
+      var rpcErrorCode: RPCError.Code { .alreadyExists }
+      var rpcErrorMessage: String { "foobar" }
+      var rpcErrorMetadata: Metadata { ["error": "yes"] }
+    }
+
+    let harness = ServerRPCExecutorTestHarness(interceptors: [.throwError(CustomError())])
+    try await harness.execute(handler: .throwing(CustomError())) { inbound in
+      try await inbound.write(.metadata(["foo": "bar"]))
+      await inbound.finish()
+    } consumer: { outbound in
+      let parts = try await outbound.collect()
+      let status = Status(code: .alreadyExists, message: "foobar")
+      let metadata: Metadata = ["error": "yes"]
+      XCTAssertEqual(parts, [.status(status, metadata)])
+    }
+  }
 }

+ 22 - 15
Tests/GRPCCoreTests/Test Utilities/Call/Client/ClientInterceptors.swift

@@ -18,11 +18,11 @@ import GRPCCore
 
 extension ClientInterceptor where Self == RejectAllClientInterceptor {
   static func rejectAll(with error: RPCError) -> Self {
-    return RejectAllClientInterceptor(error: error, throw: false)
+    return RejectAllClientInterceptor(reject: error)
   }
 
-  static func throwError(_ error: RPCError) -> Self {
-    return RejectAllClientInterceptor(error: error, throw: true)
+  static func throwError(_ error: any Error) -> Self {
+    return RejectAllClientInterceptor(throw: error)
   }
 
 }
@@ -35,15 +35,21 @@ extension ClientInterceptor where Self == RequestCountingClientInterceptor {
 
 /// Rejects all RPCs with the provided error.
 struct RejectAllClientInterceptor: ClientInterceptor {
-  /// The error to reject all RPCs with.
-  let error: RPCError
-  /// Whether the error should be thrown. If `false` then the request is rejected with the error
-  /// instead.
-  let `throw`: Bool
+  enum Mode: Sendable {
+    /// Throw the error rather.
+    case `throw`(any Error)
+    /// Reject the RPC with a given error.
+    case reject(RPCError)
+  }
+
+  let mode: Mode
+
+  init(throw error: any Error) {
+    self.mode = .throw(error)
+  }
 
-  init(error: RPCError, throw: Bool = false) {
-    self.error = error
-    self.`throw` = `throw`
+  init(reject error: RPCError) {
+    self.mode = .reject(error)
   }
 
   func intercept<Input: Sendable, Output: Sendable>(
@@ -54,10 +60,11 @@ struct RejectAllClientInterceptor: ClientInterceptor {
       ClientContext
     ) async throws -> StreamingClientResponse<Output>
   ) async throws -> StreamingClientResponse<Output> {
-    if self.throw {
-      throw self.error
-    } else {
-      return StreamingClientResponse(error: self.error)
+    switch self.mode {
+    case .throw(let error):
+      throw error
+    case .reject(let error):
+      return StreamingClientResponse(error: error)
     }
   }
 }

+ 22 - 15
Tests/GRPCCoreTests/Test Utilities/Call/Server/ServerInterceptors.swift

@@ -18,11 +18,11 @@ import GRPCCore
 
 extension ServerInterceptor where Self == RejectAllServerInterceptor {
   static func rejectAll(with error: RPCError) -> Self {
-    return RejectAllServerInterceptor(error: error, throw: false)
+    return RejectAllServerInterceptor(reject: error)
   }
 
-  static func throwError(_ error: RPCError) -> Self {
-    RejectAllServerInterceptor(error: error, throw: true)
+  static func throwError(_ error: any Error) -> Self {
+    RejectAllServerInterceptor(throw: error)
   }
 }
 
@@ -34,15 +34,21 @@ extension ServerInterceptor where Self == RequestCountingServerInterceptor {
 
 /// Rejects all RPCs with the provided error.
 struct RejectAllServerInterceptor: ServerInterceptor {
-  /// The error to reject all RPCs with.
-  let error: RPCError
-  /// Whether the error should be thrown. If `false` then the request is rejected with the error
-  /// instead.
-  let `throw`: Bool
+  enum Mode: Sendable {
+    /// Throw the error rather.
+    case `throw`(any Error)
+    /// Reject the RPC with a given error.
+    case reject(RPCError)
+  }
+
+  let mode: Mode
+
+  init(throw error: any Error) {
+    self.mode = .throw(error)
+  }
 
-  init(error: RPCError, throw: Bool = false) {
-    self.error = error
-    self.`throw` = `throw`
+  init(reject error: RPCError) {
+    self.mode = .reject(error)
   }
 
   func intercept<Input: Sendable, Output: Sendable>(
@@ -53,10 +59,11 @@ struct RejectAllServerInterceptor: ServerInterceptor {
       ServerContext
     ) async throws -> StreamingServerResponse<Output>
   ) async throws -> StreamingServerResponse<Output> {
-    if self.throw {
-      throw self.error
-    } else {
-      return StreamingServerResponse(error: self.error)
+    switch self.mode {
+    case .throw(let error):
+      throw error
+    case .reject(let error):
+      return StreamingServerResponse(error: error)
     }
   }
 }