// InterceptorsTests.swift (9.7 KB)

  1. /*
  2. * Copyright 2020, gRPC Authors All rights reserved.
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. import EchoImplementation
  17. import EchoModel
  18. import GRPC
  19. import HelloWorldModel
  20. import NIO
  21. import SwiftProtobuf
  22. import XCTest
  /// End-to-end tests for client interceptors.
  ///
  /// Each test talks to an in-process server hosting `EchoProvider` and
  /// `HelloWorldAuthProvider`. The Echo client is created with `ReversingInterceptors`, which
  /// reverses request text on the way out and response text on the way back, so the expected
  /// responses below are reversed forms of the server's echo.
  class InterceptorsTests: GRPCTestCase {
    private var group: EventLoopGroup!
    private var server: Server!
    private var connection: ClientConnection!
    private var echo: Echo_EchoClient!

    override func setUp() {
      super.setUp()
      self.group = MultiThreadedEventLoopGroup(numberOfThreads: 1)
      // Bind to port 0 so the OS picks a free port; the client connects to whatever was bound.
      self.server = try! Server.insecure(group: self.group)
        .withServiceProviders([EchoProvider(), HelloWorldAuthProvider()])
        .withLogger(self.serverLogger)
        .bind(host: "localhost", port: 0)
        .wait()
      self.connection = ClientConnection.insecure(group: self.group)
        .withBackgroundActivityLogger(self.clientLogger)
        .connect(host: "localhost", port: self.server.channel.localAddress!.port!)
      self.echo = Echo_EchoClient(
        channel: self.connection,
        defaultCallOptions: CallOptions(logger: self.clientLogger),
        interceptors: ReversingInterceptors()
      )
    }

    override func tearDown() {
      // Tear down in the reverse order of `setUp`: shut down everything this test case
      // created first, and only then run the superclass teardown.
      XCTAssertNoThrow(try self.connection.close().wait())
      XCTAssertNoThrow(try self.server.close().wait())
      XCTAssertNoThrow(try self.group.syncShutdownGracefully())
      // XCTest may keep test case instances alive for the whole run; drop our references so
      // the closed resources can be released.
      self.echo = nil
      self.connection = nil
      self.server = nil
      self.group = nil
      super.tearDown()
    }

    /// Unary RPC through the reversing interceptor.
    func testEcho() {
      let get = self.echo.get(.with { $0.text = "hello" })
      assertThat(try get.response.wait(), .is(.with { $0.text = "hello :teg ohce tfiwS" }))
      assertThat(try get.status.wait(), .hasCode(.ok))
    }

    /// Client-streaming RPC through the reversing interceptor.
    func testCollect() {
      let collect = self.echo.collect()
      collect.sendMessage(.with { $0.text = "1 2" }, promise: nil)
      collect.sendMessage(.with { $0.text = "3 4" }, promise: nil)
      collect.sendEnd(promise: nil)
      assertThat(try collect.response.wait(), .is(.with { $0.text = "3 4 1 2 :tcelloc ohce tfiwS" }))
      assertThat(try collect.status.wait(), .hasCode(.ok))
    }

    /// Server-streaming RPC through the reversing interceptor.
    func testExpand() {
      let expand = self.echo.expand(.with { $0.text = "hello" }) { response in
        // Expand splits on spaces, so we only expect one response.
        assertThat(response, .is(.with { $0.text = "hello :)0( dnapxe ohce tfiwS" }))
      }
      assertThat(try expand.status.wait(), .hasCode(.ok))
    }

    /// Bidirectional-streaming RPC through the reversing interceptor.
    func testUpdate() {
      let update = self.echo.update { response in
        // We'll just send the one message, so only expect one response.
        assertThat(response, .is(.with { $0.text = "hello :)0( etadpu ohce tfiwS" }))
      }
      update.sendMessage(.with { $0.text = "hello" }, promise: nil)
      update.sendEnd(promise: nil)
      assertThat(try update.status.wait(), .hasCode(.ok))
    }

    /// A call without credentials is rejected. The same call succeeds once the retrying
    /// `NotReallyAuth` interceptor is installed.
    func testSayHello() {
      let greeter = Helloworld_GreeterClient(
        channel: self.connection,
        defaultCallOptions: CallOptions(logger: self.clientLogger)
      )

      // Make a call without interceptors.
      let notAuthed = greeter.sayHello(.with { $0.name = "World" })
      assertThat(try notAuthed.response.wait(), .throws())
      assertThat(
        try notAuthed.trailingMetadata.wait(),
        .contains("www-authenticate", .equalTo(["Magic"]))
      )
      assertThat(try notAuthed.status.wait(), .hasCode(.unauthenticated))

      // Add an interceptor factory.
      greeter.interceptors = HelloWorldInterceptorFactory(client: greeter)
      // The factory holds `greeter`, which holds the factory: break the reference cycle.
      defer {
        greeter.interceptors = nil
      }

      // Try again with the not-really-auth interceptor:
      let hello = greeter.sayHello(.with { $0.name = "PanCakes" })
      assertThat(
        try hello.response.map { $0.message }.wait(),
        .is(.equalTo("Hello, PanCakes, you're authorized!"))
      )
      assertThat(try hello.status.wait(), .hasCode(.ok))
    }
  }
  108. // MARK: - Helpers
  /// Greeter service that only greets callers sending the `authorization: Magic` header.
  ///
  /// Any other caller gets an `.unauthenticated` status plus a `www-authenticate: Magic`
  /// trailer describing how to authenticate.
  class HelloWorldAuthProvider: Helloworld_GreeterProvider {
    func sayHello(
      request: Helloworld_HelloRequest,
      context: StatusOnlyCallContext
    ) -> EventLoopFuture<Helloworld_HelloReply> {
      // TODO: do this in a server interceptor, when we have one.
      let isAuthorized = context.headers.first(name: "authorization") == "Magic"
      guard isAuthorized else {
        context.trailers.add(name: "www-authenticate", value: "Magic")
        let rejection = GRPCStatus(code: .unauthenticated, message: nil)
        return context.eventLoop.makeFailedFuture(rejection)
      }

      var reply = Helloworld_HelloReply()
      reply.message = "Hello, \(request.name), you're authorized!"
      return context.eventLoop.makeSucceededFuture(reply)
    }
  }
  /// Interceptor factory that gives every Greeter RPC its own `NotReallyAuth` interceptor.
  ///
  /// The factory keeps a strong reference to `client`. When it is installed as that same
  /// client's `interceptors`, the two form a reference cycle. The owner must break it, for
  /// example by setting `client.interceptors` back to `nil`.
  private class HelloWorldInterceptorFactory: Helloworld_GreeterClientInterceptorFactoryProtocol {
    /// Client whose channel the created interceptors use for their retry calls.
    var client: Helloworld_GreeterClient

    init(client: Helloworld_GreeterClient) {
      self.client = client
    }

    func makeInterceptors<Request: Message, Response: Message>(
    ) -> [ClientInterceptor<Request, Response>] {
      let authenticator = NotReallyAuth<Request, Response>(client: self.client)
      return [authenticator]
    }
  }
  /// A toy client interceptor that retries a rejected call with "Magic" authorization.
  ///
  /// Outbound request parts are forwarded normally and also recorded. The interceptor may see
  /// the call end with `.unauthenticated` and a `www-authenticate: Magic` trailer. In that case
  /// it starts a new call on `client`'s channel, with no interceptors. It replays the recorded
  /// parts on that call, adding `authorization: Magic` to the initial metadata. Responses from
  /// the new call are delivered in place of the original call's.
  class NotReallyAuth<Request: Message, Response: Message>: ClientInterceptor<Request, Response> {
    /// Client whose channel is used to create the retry call.
    private let client: Helloworld_GreeterClient

    private enum State {
      /// The original call is in flight; holds the parts sent on it so far.
      case trying([ClientRequestPart<Request>])
      /// The original call was rejected; this replacement call is used instead.
      case retrying(Call<Request, Response>)
    }

    private var state: State = .trying([])

    init(client: Helloworld_GreeterClient) {
      self.client = client
    }

    override func cancel(
      promise: EventLoopPromise<Void>?,
      context: ClientInterceptorContext<Request, Response>
    ) {
      guard case let .retrying(retryCall) = self.state else {
        context.cancel(promise: promise)
        return
      }
      // The caller's promise follows the retry call; the original call is cancelled too.
      retryCall.cancel(promise: promise)
      context.cancel(promise: nil)
    }

    override func send(
      _ part: ClientRequestPart<Request>,
      promise: EventLoopPromise<Void>?,
      context: ClientInterceptorContext<Request, Response>
    ) {
      switch self.state {
      case let .retrying(retryCall):
        // Once retrying, every request part goes to the retry call.
        retryCall.send(part, promise: promise)
      case let .trying(sentParts):
        // Record the part in case the call has to be replayed, then forward it as usual.
        self.state = .trying(sentParts + [part])
        context.send(part, promise: promise)
      }
    }

    override func receive(
      _ part: ClientResponsePart<Response>,
      context: ClientInterceptorContext<Request, Response>
    ) {
      guard case let .trying(sentParts) = self.state else {
        // Anything still arriving on the original call is dropped once we're retrying.
        return
      }

      // The only part handled here is an unauthenticated end that carries the "Magic"
      // challenge. Everything else (including other failures) is forwarded untouched.
      guard case let .end(status, trailers) = part,
        status.code == .unauthenticated,
        trailers.first(name: "www-authenticate") == "Magic"
      else {
        context.receive(part)
        return
      }

      self.retryWithMagic(replaying: sentParts, context: context)
    }

    /// Starts a replacement call and replays `sentParts` on it. If the first recorded part is
    /// the request metadata, `authorization: Magic` is set on it before replaying.
    private func retryWithMagic(
      replaying sentParts: [ClientRequestPart<Request>],
      context: ClientInterceptorContext<Request, Response>
    ) {
      // Interceptors could be taken from the client, but the retry doesn't need any.
      let retryCall: Call<Request, Response> = self.client.channel.makeCall(
        path: context.path,
        type: context.type,
        callOptions: context.options,
        interceptors: []
      )
      // From here on, sends and cancellation are routed to the retry call.
      self.state = .retrying(retryCall)
      // Responses on the retry call are handed straight to this call's context.
      retryCall.invoke(context.receive(_:))

      var replay = sentParts
      // If metadata was sent at all, it is the first recorded part.
      if case var .some(.metadata(headers)) = replay.first {
        headers.replaceOrAdd(name: "authorization", value: "Magic")
        replay[0] = .metadata(headers)
      }
      replay.forEach { retryCall.send($0, promise: nil) }
    }
  }
  /// Reverses the `text` of each outbound request message and each inbound response message.
  ///
  /// All other request and response parts pass through unchanged. The interceptor has no
  /// stored state of its own.
  class EchoReverseInterceptor: ClientInterceptor<Echo_EchoRequest, Echo_EchoResponse> {
    override func send(
      _ part: ClientRequestPart<Echo_EchoRequest>,
      promise: EventLoopPromise<Void>?,
      context: ClientInterceptorContext<Echo_EchoRequest, Echo_EchoResponse>
    ) {
      guard case .message(var request, let metadata) = part else {
        context.send(part, promise: promise)
        return
      }
      request.text = String(request.text.reversed())
      context.send(.message(request, metadata), promise: promise)
    }

    override func receive(
      _ part: ClientResponsePart<Echo_EchoResponse>,
      context: ClientInterceptorContext<Echo_EchoRequest, Echo_EchoResponse>
    ) {
      if case var .message(response) = part {
        response.text = String(response.text.reversed())
        context.receive(.message(response))
      } else {
        context.receive(part)
      }
    }
  }
  /// Interceptor factory that installs the same `EchoReverseInterceptor` on every Echo RPC.
  private class ReversingInterceptors: Echo_EchoClientInterceptorFactoryProtocol {
    // `EchoReverseInterceptor` has no stored state, so one instance is shared by all calls.
    private let reverser = EchoReverseInterceptor()

    /// The interceptor list handed out for every RPC type.
    private func sharedInterceptors() -> [ClientInterceptor<Echo_EchoRequest, Echo_EchoResponse>] {
      return [self.reverser]
    }

    func makeGetInterceptors() -> [ClientInterceptor<Echo_EchoRequest, Echo_EchoResponse>] {
      return self.sharedInterceptors()
    }

    func makeExpandInterceptors() -> [ClientInterceptor<Echo_EchoRequest, Echo_EchoResponse>] {
      return self.sharedInterceptors()
    }

    func makeCollectInterceptors() -> [ClientInterceptor<Echo_EchoRequest, Echo_EchoResponse>] {
      return self.sharedInterceptors()
    }

    func makeUpdateInterceptors() -> [ClientInterceptor<Echo_EchoRequest, Echo_EchoResponse>] {
      return self.sharedInterceptors()
    }
  }