|
|
@@ -18,11 +18,11 @@ import GRPCCore
|
|
|
|
|
|
extension ClientInterceptor where Self == RejectAllClientInterceptor {
|
|
|
static func rejectAll(with error: RPCError) -> Self {
|
|
|
- return RejectAllClientInterceptor(error: error, throw: false)
|
|
|
+ return RejectAllClientInterceptor(reject: error)
|
|
|
}
|
|
|
|
|
|
- static func throwError(_ error: RPCError) -> Self {
|
|
|
- return RejectAllClientInterceptor(error: error, throw: true)
|
|
|
+ static func throwError(_ error: any Error) -> Self {
|
|
|
+ return RejectAllClientInterceptor(throw: error)
|
|
|
}
|
|
|
|
|
|
}
|
|
|
@@ -35,15 +35,21 @@ extension ClientInterceptor where Self == RequestCountingClientInterceptor {
|
|
|
|
|
|
/// Rejects all RPCs with the provided error.
|
|
|
struct RejectAllClientInterceptor: ClientInterceptor {
|
|
|
- /// The error to reject all RPCs with.
|
|
|
- let error: RPCError
|
|
|
- /// Whether the error should be thrown. If `false` then the request is rejected with the error
|
|
|
- /// instead.
|
|
|
- let `throw`: Bool
|
|
|
+ enum Mode: Sendable {
|
|
|
+ /// Throw the error rather.
|
|
|
+ case `throw`(any Error)
|
|
|
+ /// Reject the RPC with a given error.
|
|
|
+ case reject(RPCError)
|
|
|
+ }
|
|
|
+
|
|
|
+ let mode: Mode
|
|
|
+
|
|
|
+ init(throw error: any Error) {
|
|
|
+ self.mode = .throw(error)
|
|
|
+ }
|
|
|
|
|
|
- init(error: RPCError, throw: Bool = false) {
|
|
|
- self.error = error
|
|
|
- self.`throw` = `throw`
|
|
|
+ init(reject error: RPCError) {
|
|
|
+ self.mode = .reject(error)
|
|
|
}
|
|
|
|
|
|
func intercept<Input: Sendable, Output: Sendable>(
|
|
|
@@ -54,10 +60,11 @@ struct RejectAllClientInterceptor: ClientInterceptor {
|
|
|
ClientContext
|
|
|
) async throws -> StreamingClientResponse<Output>
|
|
|
) async throws -> StreamingClientResponse<Output> {
|
|
|
- if self.throw {
|
|
|
- throw self.error
|
|
|
- } else {
|
|
|
- return StreamingClientResponse(error: self.error)
|
|
|
+ switch self.mode {
|
|
|
+ case .throw(let error):
|
|
|
+ throw error
|
|
|
+ case .reject(let error):
|
|
|
+ return StreamingClientResponse(error: error)
|
|
|
}
|
|
|
}
|
|
|
}
|