InterceptedRPCCancellationTests.swift 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188
  1. /*
  2. * Copyright 2021, gRPC Authors All rights reserved.
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. import EchoImplementation
  17. import EchoModel
  18. import Logging
  19. import NIOCore
  20. import NIOPosix
  21. import XCTest
  22. import protocol SwiftProtobuf.Message
  23. @testable import GRPC
  24. final class InterceptedRPCCancellationTests: GRPCTestCase {
  25. func testCancellationWithinInterceptedRPC() throws {
  26. // This test validates that when using interceptors to replay an RPC that the lifecycle of
  27. // the interceptor pipeline is correctly managed. That is, the transport maintains a reference
  28. // to the pipeline for as long as the call is alive (rather than dropping the reference when
  29. // the RPC ends).
  30. let group = MultiThreadedEventLoopGroup(numberOfThreads: System.coreCount)
  31. defer {
  32. XCTAssertNoThrow(try group.syncShutdownGracefully())
  33. }
  34. // Interceptor checks that a "magic" header is present.
  35. let serverInterceptors = EchoServerInterceptors({ MagicRequiredServerInterceptor() })
  36. let server = try Server.insecure(group: group)
  37. .withLogger(self.serverLogger)
  38. .withServiceProviders([EchoProvider(interceptors: serverInterceptors)])
  39. .bind(host: "127.0.0.1", port: 0)
  40. .wait()
  41. defer {
  42. XCTAssertNoThrow(try server.close().wait())
  43. }
  44. let connection = ClientConnection.insecure(group: group)
  45. .withBackgroundActivityLogger(self.clientLogger)
  46. .connect(host: "127.0.0.1", port: server.channel.localAddress!.port!)
  47. defer {
  48. XCTAssertNoThrow(try connection.close().wait())
  49. }
  50. // Retries an RPC with a "magic" header if it fails with the permission denied status code.
  51. let clientInterceptors = EchoClientInterceptors {
  52. return MagicAddingClientInterceptor(channel: connection)
  53. }
  54. let echo = Echo_EchoNIOClient(channel: connection, interceptors: clientInterceptors)
  55. let receivedFirstResponse = connection.eventLoop.makePromise(of: Void.self)
  56. let update = echo.update { _ in
  57. receivedFirstResponse.succeed(())
  58. }
  59. XCTAssertNoThrow(try update.sendMessage(.with { $0.text = "ping" }).wait())
  60. // Wait for the pong: it means the second RPC is up and running and the first should have
  61. // completed.
  62. XCTAssertNoThrow(try receivedFirstResponse.futureResult.wait())
  63. XCTAssertNoThrow(try update.cancel().wait())
  64. let status = try update.status.wait()
  65. XCTAssertEqual(status.code, .cancelled)
  66. }
  67. }
  68. final class MagicRequiredServerInterceptor<
  69. Request: Message,
  70. Response: Message
  71. >: ServerInterceptor<Request, Response>, @unchecked Sendable {
  72. override func receive(
  73. _ part: GRPCServerRequestPart<Request>,
  74. context: ServerInterceptorContext<Request, Response>
  75. ) {
  76. switch part {
  77. case let .metadata(metadata):
  78. if metadata.contains(name: "magic") {
  79. context.logger.debug("metadata contains magic; accepting rpc")
  80. context.receive(part)
  81. } else {
  82. context.logger.debug("metadata does not contains magic; rejecting rpc")
  83. let status = GRPCStatus(code: .permissionDenied, message: nil)
  84. context.send(.end(status, [:]), promise: nil)
  85. }
  86. case .message, .end:
  87. context.receive(part)
  88. }
  89. }
  90. }
  91. final class MagicAddingClientInterceptor<
  92. Request: Message,
  93. Response: Message
  94. >: ClientInterceptor<Request, Response>, @unchecked Sendable {
  95. private let channel: GRPCChannel
  96. private var requestParts = CircularBuffer<GRPCClientRequestPart<Request>>()
  97. private var retry: Call<Request, Response>?
  98. init(channel: GRPCChannel) {
  99. self.channel = channel
  100. }
  101. override func cancel(
  102. promise: EventLoopPromise<Void>?,
  103. context: ClientInterceptorContext<Request, Response>
  104. ) {
  105. if let retry = self.retry {
  106. context.logger.debug("cancelling retry RPC")
  107. retry.cancel(promise: promise)
  108. } else {
  109. context.cancel(promise: promise)
  110. }
  111. }
  112. override func send(
  113. _ part: GRPCClientRequestPart<Request>,
  114. promise: EventLoopPromise<Void>?,
  115. context: ClientInterceptorContext<Request, Response>
  116. ) {
  117. if let retry = self.retry {
  118. context.logger.debug("retrying part \(part)")
  119. retry.send(part, promise: promise)
  120. } else {
  121. switch part {
  122. case .metadata:
  123. // Replace the metadata with the magic words.
  124. self.requestParts.append(.metadata(["magic": "it's real!"]))
  125. case .message, .end:
  126. self.requestParts.append(part)
  127. }
  128. context.send(part, promise: promise)
  129. }
  130. }
  131. override func receive(
  132. _ part: GRPCClientResponsePart<Response>,
  133. context: ClientInterceptorContext<Request, Response>
  134. ) {
  135. switch part {
  136. case .metadata, .message:
  137. XCTFail("Unexpected response part \(part)")
  138. context.receive(part)
  139. case let .end(status, _):
  140. guard status.code == .permissionDenied else {
  141. XCTFail("Unexpected status code \(status)")
  142. context.receive(part)
  143. return
  144. }
  145. XCTAssertNil(self.retry)
  146. context.logger.debug("initial rpc failed, retrying")
  147. self.retry = self.channel.makeCall(
  148. path: context.path,
  149. type: context.type,
  150. callOptions: CallOptions(logger: context.logger),
  151. interceptors: []
  152. )
  153. self.retry!.invoke {
  154. context.logger.debug("intercepting error from retried rpc")
  155. context.errorCaught($0)
  156. } onResponsePart: { responsePart in
  157. context.logger.debug("intercepting response part from retried rpc")
  158. context.receive(responsePart)
  159. }
  160. while let requestPart = self.requestParts.popFirst() {
  161. context.logger.debug("replaying \(requestPart) on new rpc")
  162. self.retry!.send(requestPart, promise: nil)
  163. }
  164. }
  165. }
  166. }