Kaynağa Gözat

Flatten `RPCError` causes if they're also `RPCError`s with the same code (#2083)

## Motivation
For errors happening deep in the task tree, we'd wrap them in many
layers of `RPCError`s. This isn't particularly nice.

## Modifications
This PR changes the behaviour of the `RPCError` initialiser to flatten
the cause as long as it's an `RPCError` with the same status code as the
wrapping error.

## Result
Friendlier errors.
Gus Cairo 1 yıl önce
ebeveyn
işleme
ce25e2ce30

+ 50 - 6
Sources/GRPCCore/RPCError.swift

@@ -35,18 +35,62 @@ public struct RPCError: Sendable, Hashable, Error {
   /// The original error which led to this error being thrown.
   public var cause: (any Error)?
 
-  /// Create a new RPC error.
+  /// Create a new RPC error. If the given `cause` is also an ``RPCError`` sharing the same `code`,
+  /// then they will be flattened into a single error, by merging the messages and metadata.
   ///
   /// - Parameters:
   ///   - code: The status code.
   ///   - message: A message providing additional context about the code.
   ///   - metadata: Any metadata to attach to the error.
   ///   - cause: An underlying error which led to this error being thrown.
-  public init(code: Code, message: String, metadata: Metadata = [:], cause: (any Error)? = nil) {
-    self.code = code
-    self.message = message
-    self.metadata = metadata
-    self.cause = cause
+  public init(
+    code: Code,
+    message: String,
+    metadata: Metadata = [:],
+    cause: (any Error)? = nil
+  ) {
+    if let rpcErrorCause = cause as? RPCError {
+      self = .init(
+        code: code,
+        message: message,
+        metadata: metadata,
+        cause: rpcErrorCause
+      )
+    } else {
+      self.code = code
+      self.message = message
+      self.metadata = metadata
+      self.cause = cause
+    }
+  }
+
+  /// Create a new RPC error. If the given `cause` shares the same `code`, then it will be flattened
+  /// into a single error, by merging the messages and metadata.
+  ///
+  /// - Parameters:
+  ///   - code: The status code.
+  ///   - message: A message providing additional context about the code.
+  ///   - metadata: Any metadata to attach to the error.
+  ///   - cause: An underlying ``RPCError`` which led to this error being thrown.
+  public init(
+    code: Code,
+    message: String,
+    metadata: Metadata = [:],
+    cause: RPCError
+  ) {
+    if cause.code == code {
+      self.code = code
+      self.message = message + " \(cause.message)"
+      var mergedMetadata = metadata
+      mergedMetadata.add(contentsOf: cause.metadata)
+      self.metadata = mergedMetadata
+      self.cause = cause.cause
+    } else {
+      self.code = code
+      self.message = message
+      self.metadata = metadata
+      self.cause = cause
+    }
   }
 
   /// Create a new RPC error from the provided ``Status``.

+ 137 - 72
Tests/GRPCCoreTests/RPCErrorTests.swift

@@ -14,114 +14,179 @@
  * limitations under the License.
  */
 import GRPCCore
-import XCTest
-
-final class RPCErrorTests: XCTestCase {
-  private static let statusCodeRawValue: [(RPCError.Code, Int)] = [
-    (.cancelled, 1),
-    (.unknown, 2),
-    (.invalidArgument, 3),
-    (.deadlineExceeded, 4),
-    (.notFound, 5),
-    (.alreadyExists, 6),
-    (.permissionDenied, 7),
-    (.resourceExhausted, 8),
-    (.failedPrecondition, 9),
-    (.aborted, 10),
-    (.outOfRange, 11),
-    (.unimplemented, 12),
-    (.internalError, 13),
-    (.unavailable, 14),
-    (.dataLoss, 15),
-    (.unauthenticated, 16),
-  ]
+import Testing
 
+@Suite("RPCError Tests")
+struct RPCErrorTests {
+  @Test("Custom String Convertible")
   func testCustomStringConvertible() {
-    XCTAssertDescription(RPCError(code: .dataLoss, message: ""), #"dataLoss: """#)
-    XCTAssertDescription(RPCError(code: .unknown, message: "message"), #"unknown: "message""#)
-    XCTAssertDescription(RPCError(code: .aborted, message: "message"), #"aborted: "message""#)
+    #expect(String(describing: RPCError(code: .dataLoss, message: "")) == #"dataLoss: """#)
+    #expect(
+      String(describing: RPCError(code: .unknown, message: "message")) == #"unknown: "message""#
+    )
+    #expect(
+      String(describing: RPCError(code: .aborted, message: "message")) == #"aborted: "message""#
+    )
 
     struct TestError: Error {}
-    XCTAssertDescription(
-      RPCError(code: .aborted, message: "message", cause: TestError()),
-      #"aborted: "message" (cause: "TestError()")"#
+    #expect(
+      String(describing: RPCError(code: .aborted, message: "message", cause: TestError()))
+        == #"aborted: "message" (cause: "TestError()")"#
     )
   }
 
+  @Test("Error from Status")
   func testErrorFromStatus() throws {
     var status = Status(code: .ok, message: "")
     // ok isn't an error
-    XCTAssertNil(RPCError(status: status))
+    #expect(RPCError(status: status) == nil)
 
     status.code = .invalidArgument
-    var error = try XCTUnwrap(RPCError(status: status))
-    XCTAssertEqual(error.code, .invalidArgument)
-    XCTAssertEqual(error.message, "")
-    XCTAssertEqual(error.metadata, [:])
+    var error = try #require(RPCError(status: status))
+    #expect(error.code == .invalidArgument)
+    #expect(error.message == "")
+    #expect(error.metadata == [:])
 
     status.code = .cancelled
     status.message = "an error message"
-    error = try XCTUnwrap(RPCError(status: status))
-    XCTAssertEqual(error.code, .cancelled)
-    XCTAssertEqual(error.message, "an error message")
-    XCTAssertEqual(error.metadata, [:])
+    error = try #require(RPCError(status: status))
+    #expect(error.code == .cancelled)
+    #expect(error.message == "an error message")
+    #expect(error.metadata == [:])
   }
 
-  func testErrorCodeFromStatusCode() throws {
-    XCTAssertNil(RPCError.Code(Status.Code.ok))
-    XCTAssertEqual(RPCError.Code(Status.Code.cancelled), .cancelled)
-    XCTAssertEqual(RPCError.Code(Status.Code.unknown), .unknown)
-    XCTAssertEqual(RPCError.Code(Status.Code.invalidArgument), .invalidArgument)
-    XCTAssertEqual(RPCError.Code(Status.Code.deadlineExceeded), .deadlineExceeded)
-    XCTAssertEqual(RPCError.Code(Status.Code.notFound), .notFound)
-    XCTAssertEqual(RPCError.Code(Status.Code.alreadyExists), .alreadyExists)
-    XCTAssertEqual(RPCError.Code(Status.Code.permissionDenied), .permissionDenied)
-    XCTAssertEqual(RPCError.Code(Status.Code.resourceExhausted), .resourceExhausted)
-    XCTAssertEqual(RPCError.Code(Status.Code.failedPrecondition), .failedPrecondition)
-    XCTAssertEqual(RPCError.Code(Status.Code.aborted), .aborted)
-    XCTAssertEqual(RPCError.Code(Status.Code.outOfRange), .outOfRange)
-    XCTAssertEqual(RPCError.Code(Status.Code.unimplemented), .unimplemented)
-    XCTAssertEqual(RPCError.Code(Status.Code.internalError), .internalError)
-    XCTAssertEqual(RPCError.Code(Status.Code.unavailable), .unavailable)
-    XCTAssertEqual(RPCError.Code(Status.Code.dataLoss), .dataLoss)
-    XCTAssertEqual(RPCError.Code(Status.Code.unauthenticated), .unauthenticated)
+  @Test(
+    "Error Code from Status Code",
+    arguments: [
+      (Status.Code.ok, nil),
+      (Status.Code.cancelled, RPCError.Code.cancelled),
+      (Status.Code.unknown, RPCError.Code.unknown),
+      (Status.Code.invalidArgument, RPCError.Code.invalidArgument),
+      (Status.Code.deadlineExceeded, RPCError.Code.deadlineExceeded),
+      (Status.Code.notFound, RPCError.Code.notFound),
+      (Status.Code.alreadyExists, RPCError.Code.alreadyExists),
+      (Status.Code.permissionDenied, RPCError.Code.permissionDenied),
+      (Status.Code.resourceExhausted, RPCError.Code.resourceExhausted),
+      (Status.Code.failedPrecondition, RPCError.Code.failedPrecondition),
+      (Status.Code.aborted, RPCError.Code.aborted),
+      (Status.Code.outOfRange, RPCError.Code.outOfRange),
+      (Status.Code.unimplemented, RPCError.Code.unimplemented),
+      (Status.Code.internalError, RPCError.Code.internalError),
+      (Status.Code.unavailable, RPCError.Code.unavailable),
+      (Status.Code.dataLoss, RPCError.Code.dataLoss),
+      (Status.Code.unauthenticated, RPCError.Code.unauthenticated),
+    ]
+  )
+  func testErrorCodeFromStatusCode(statusCode: Status.Code, rpcErrorCode: RPCError.Code?) throws {
+    #expect(RPCError.Code(statusCode) == rpcErrorCode)
   }
 
+  @Test("Equatable Conformance")
   func testEquatableConformance() {
-    XCTAssertEqual(
-      RPCError(code: .cancelled, message: ""),
+    #expect(
       RPCError(code: .cancelled, message: "")
+        == RPCError(code: .cancelled, message: "")
     )
 
-    XCTAssertEqual(
-      RPCError(code: .cancelled, message: "message"),
+    #expect(
       RPCError(code: .cancelled, message: "message")
+        == RPCError(code: .cancelled, message: "message")
     )
 
-    XCTAssertEqual(
-      RPCError(code: .cancelled, message: "message", metadata: ["foo": "bar"]),
+    #expect(
       RPCError(code: .cancelled, message: "message", metadata: ["foo": "bar"])
+        == RPCError(code: .cancelled, message: "message", metadata: ["foo": "bar"])
     )
 
-    XCTAssertNotEqual(
-      RPCError(code: .cancelled, message: ""),
+    #expect(
+      RPCError(code: .cancelled, message: "")
+        != RPCError(code: .cancelled, message: "message")
+    )
+
+    #expect(
       RPCError(code: .cancelled, message: "message")
+        != RPCError(code: .unknown, message: "message")
     )
 
-    XCTAssertNotEqual(
-      RPCError(code: .cancelled, message: "message"),
-      RPCError(code: .unknown, message: "message")
+    #expect(
+      RPCError(code: .cancelled, message: "message", metadata: ["foo": "bar"])
+        != RPCError(code: .cancelled, message: "message", metadata: ["foo": "baz"])
     )
+  }
 
-    XCTAssertNotEqual(
-      RPCError(code: .cancelled, message: "message", metadata: ["foo": "bar"]),
-      RPCError(code: .cancelled, message: "message", metadata: ["foo": "baz"])
+  @Test(
+    "Status Code Raw Values",
+    arguments: [
+      (RPCError.Code.cancelled, 1),
+      (.unknown, 2),
+      (.invalidArgument, 3),
+      (.deadlineExceeded, 4),
+      (.notFound, 5),
+      (.alreadyExists, 6),
+      (.permissionDenied, 7),
+      (.resourceExhausted, 8),
+      (.failedPrecondition, 9),
+      (.aborted, 10),
+      (.outOfRange, 11),
+      (.unimplemented, 12),
+      (.internalError, 13),
+      (.unavailable, 14),
+      (.dataLoss, 15),
+      (.unauthenticated, 16),
+    ]
+  )
+  func testStatusCodeRawValues(statusCode: RPCError.Code, rawValue: Int) {
+    #expect(statusCode.rawValue == rawValue, "\(statusCode) had unexpected raw value")
+  }
+
+  @Test("Flatten causes with same status code")
+  func testFlattenCausesWithSameStatusCode() {
+    let error1 = RPCError(code: .unknown, message: "Error 1.")
+    let error2 = RPCError(code: .unknown, message: "Error 2.", cause: error1)
+    let error3 = RPCError(code: .dataLoss, message: "Error 3.", cause: error2)
+    let error4 = RPCError(code: .aborted, message: "Error 4.", cause: error3)
+    let error5 = RPCError(
+      code: .aborted,
+      message: "Error 5.",
+      cause: error4
+    )
+
+    let unknownMerged = RPCError(code: .unknown, message: "Error 2. Error 1.")
+    let dataLossMerged = RPCError(code: .dataLoss, message: "Error 3.", cause: unknownMerged)
+    let abortedMerged = RPCError(
+      code: .aborted,
+      message: "Error 5. Error 4.",
+      cause: dataLossMerged
     )
+    #expect(error5 == abortedMerged)
   }
 
-  func testStatusCodeRawValues() {
-    for (code, expected) in Self.statusCodeRawValue {
-      XCTAssertEqual(code.rawValue, expected, "\(code) had unexpected raw value")
-    }
+  @Test("Causes of errors with different status codes aren't flattened")
+  func testDifferentStatusCodeAreNotFlattened() throws {
+    let error1 = RPCError(code: .unknown, message: "Error 1.")
+    let error2 = RPCError(code: .dataLoss, message: "Error 2.", cause: error1)
+    let error3 = RPCError(code: .alreadyExists, message: "Error 3.", cause: error2)
+    let error4 = RPCError(code: .aborted, message: "Error 4.", cause: error3)
+    let error5 = RPCError(
+      code: .deadlineExceeded,
+      message: "Error 5.",
+      cause: error4
+    )
+
+    #expect(error5.code == .deadlineExceeded)
+    #expect(error5.message == "Error 5.")
+    let wrappedError4 = try #require(error5.cause as? RPCError)
+    #expect(wrappedError4.code == .aborted)
+    #expect(wrappedError4.message == "Error 4.")
+    let wrappedError3 = try #require(wrappedError4.cause as? RPCError)
+    #expect(wrappedError3.code == .alreadyExists)
+    #expect(wrappedError3.message == "Error 3.")
+    let wrappedError2 = try #require(wrappedError3.cause as? RPCError)
+    #expect(wrappedError2.code == .dataLoss)
+    #expect(wrappedError2.message == "Error 2.")
+    let wrappedError1 = try #require(wrappedError2.cause as? RPCError)
+    #expect(wrappedError1.code == .unknown)
+    #expect(wrappedError1.message == "Error 1.")
+    #expect(wrappedError1.cause == nil)
   }
 }