InterceptorsTests.swift 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444
  1. /*
  2. * Copyright 2020, gRPC Authors All rights reserved.
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. import Atomics
  17. import EchoImplementation
  18. import EchoModel
  19. import GRPC
  20. import HelloWorldModel
  21. import NIOCore
  22. import NIOHPACK
  23. import NIOPosix
  24. import SwiftProtobuf
  25. import XCTest
  /// End-to-end tests exercising client and server interceptors against a real server.
  ///
  /// The Echo client installs `ReversingInterceptors`, which reverse the text of outbound
  /// requests and inbound responses. The Echo service installs `CountOnCloseInterceptors`,
  /// which increment `onCloseCounter` when an RPC's close future completes. The Greeter
  /// service is guarded by a toy "magic" authorization scheme implemented purely with
  /// interceptors on both sides.
  class InterceptorsTests: GRPCTestCase {
    private var group: EventLoopGroup!
    private var server: Server!
    private var connection: ClientConnection!
    private var echo: Echo_EchoNIOClient!
    /// Incremented by the server-side Echo interceptor each time an RPC's close future completes.
    private let onCloseCounter = ManagedAtomic<Int>(0)

    override func setUp() {
      super.setUp()

      self.group = MultiThreadedEventLoopGroup(numberOfThreads: 1)

      self.server = try! Server.insecure(group: self.group)
        .withServiceProviders([
          EchoProvider(interceptors: CountOnCloseInterceptors(counter: self.onCloseCounter)),
          HelloWorldProvider(interceptors: HelloWorldServerInterceptorFactory()),
        ])
        .withLogger(self.serverLogger)
        .bind(host: "localhost", port: 0)
        .wait()

      self.connection = ClientConnection.insecure(group: self.group)
        .withBackgroundActivityLogger(self.clientLogger)
        .connect(host: "localhost", port: self.server.channel.localAddress!.port!)

      self.echo = Echo_EchoNIOClient(
        channel: self.connection,
        defaultCallOptions: CallOptions(logger: self.clientLogger),
        interceptors: ReversingInterceptors()
      )
    }

    override func tearDown() {
      // Mirror setUp(): release what this test created (in reverse order of creation) and only
      // then hand over to the base class. Previously super.tearDown() ran first, so the base
      // class finished tearing down before these shutdown assertions were evaluated.
      XCTAssertNoThrow(try self.connection.close().wait())
      XCTAssertNoThrow(try self.server.close().wait())
      XCTAssertNoThrow(try self.group.syncShutdownGracefully())
      super.tearDown()
    }

    func testEcho() {
      // The client interceptor reverses "hello" before sending and reverses the reply on receipt.
      let get = self.echo.get(.with { $0.text = "hello" })
      assertThat(try get.response.wait(), .is(.with { $0.text = "hello :teg ohce tfiwS" }))
      assertThat(try get.status.wait(), .hasCode(.ok))
      XCTAssertEqual(self.onCloseCounter.load(ordering: .sequentiallyConsistent), 1)
    }

    func testCollect() {
      let collect = self.echo.collect()
      collect.sendMessage(.with { $0.text = "1 2" }, promise: nil)
      collect.sendMessage(.with { $0.text = "3 4" }, promise: nil)
      collect.sendEnd(promise: nil)
      assertThat(try collect.response.wait(), .is(.with { $0.text = "3 4 1 2 :tcelloc ohce tfiwS" }))
      assertThat(try collect.status.wait(), .hasCode(.ok))
      XCTAssertEqual(self.onCloseCounter.load(ordering: .sequentiallyConsistent), 1)
    }

    func testExpand() {
      let expand = self.echo.expand(.with { $0.text = "hello" }) { response in
        // Expand splits on spaces, so we only expect one response.
        assertThat(response, .is(.with { $0.text = "hello :)0( dnapxe ohce tfiwS" }))
      }
      assertThat(try expand.status.wait(), .hasCode(.ok))
      XCTAssertEqual(self.onCloseCounter.load(ordering: .sequentiallyConsistent), 1)
    }

    func testUpdate() {
      let update = self.echo.update { response in
        // We'll just send the one message, so only expect one response.
        assertThat(response, .is(.with { $0.text = "hello :)0( etadpu ohce tfiwS" }))
      }
      update.sendMessage(.with { $0.text = "hello" }, promise: nil)
      update.sendEnd(promise: nil)
      assertThat(try update.status.wait(), .hasCode(.ok))
      XCTAssertEqual(self.onCloseCounter.load(ordering: .sequentiallyConsistent), 1)
    }

    func testSayHello() {
      var greeter = Helloworld_GreeterNIOClient(
        channel: self.connection,
        defaultCallOptions: CallOptions(logger: self.clientLogger)
      )

      // Make a call without interceptors: the server-side interceptor must reject it.
      let notAuthed = greeter.sayHello(.with { $0.name = "World" })
      assertThat(try notAuthed.response.wait(), .throws())
      assertThat(
        try notAuthed.trailingMetadata.wait(),
        .contains("www-authenticate", ["Magic"])
      )
      assertThat(try notAuthed.status.wait(), .hasCode(.unauthenticated))

      // Add an interceptor factory which answers the auth challenge by retrying.
      greeter.interceptors = HelloWorldClientInterceptorFactory(client: greeter)
      // Make sure we break the reference cycle.
      defer {
        greeter.interceptors = nil
      }

      // Try again with the not-really-auth interceptor:
      let hello = greeter.sayHello(.with { $0.name = "PanCakes" })
      assertThat(
        try hello.response.map { $0.message }.wait(),
        .is(.equalTo("Hello, PanCakes, you're authorized!"))
      )
      assertThat(try hello.status.wait(), .hasCode(.ok))
    }
  }
  // MARK: - Helpers

  /// Greeter service implementation used by the tests.
  ///
  /// It expects the server-side auth interceptor to have stored the "Magic" token in the
  /// call's `userInfo` before the handler runs.
  class HelloWorldProvider: Helloworld_GreeterProvider {
    var interceptors: Helloworld_GreeterServerInterceptorFactoryProtocol?

    init(interceptors: Helloworld_GreeterServerInterceptorFactoryProtocol? = nil) {
      self.interceptors = interceptors
    }

    func sayHello(
      request: Helloworld_HelloRequest,
      context: StatusOnlyCallContext
    ) -> EventLoopFuture<Helloworld_HelloReply> {
      // Only an authorized call reaches here, so the interceptor must have set the magic value.
      assertThat(context.userInfo.magic, .is("Magic"))

      var reply = Helloworld_HelloReply()
      reply.message = "Hello, \(request.name), you're authorized!"
      return context.eventLoop.makeSucceededFuture(reply)
    }
  }
  /// Client-side interceptor factory for the Greeter service.
  ///
  /// Every `SayHello` call gets a fresh `NotReallyAuthClientInterceptor`, which uses `client`
  /// to re-issue the call when the server issues a "Magic" auth challenge.
  private class HelloWorldClientInterceptorFactory:
    Helloworld_GreeterClientInterceptorFactoryProtocol
  {
    /// Client whose channel is used to make the retried call.
    var client: Helloworld_GreeterNIOClient

    init(client: Helloworld_GreeterNIOClient) {
      self.client = client
    }

    func makeSayHelloInterceptors()
      -> [ClientInterceptor<Helloworld_HelloRequest, Helloworld_HelloReply>]
    {
      let authInterceptor =
        NotReallyAuthClientInterceptor<Helloworld_HelloRequest, Helloworld_HelloReply>(
          client: self.client
        )
      return [authInterceptor]
    }
  }

  // The class has a mutable stored property, so the compiler cannot verify `Sendable` itself;
  // the conformance is asserted manually.
  extension HelloWorldClientInterceptorFactory: @unchecked Sendable {}
  /// Server interceptor that checks that each request part it observes has a remote address
  /// available on its context, then passes the part on unchanged.
  class RemoteAddressExistsInterceptor<Request, Response>:
    ServerInterceptor<Request, Response>, @unchecked Sendable
  {
    override func receive(
      _ part: GRPCServerRequestPart<Request>,
      context: ServerInterceptorContext<Request, Response>
    ) {
      let peer = context.remoteAddress
      XCTAssertNotNil(peer)
      super.receive(part, context: context)
    }
  }
  /// Server interceptor implementing a toy authorization check.
  ///
  /// When the request metadata carries `authorization: Magic`, the token is stored in
  /// `userInfo.magic` and the metadata is forwarded. Otherwise the RPC is ended with
  /// `.unauthenticated` and a `www-authenticate: Magic` trailer. Messages and end-of-stream
  /// parts are always forwarded.
  class NotReallyAuthServerInterceptor<Request: Message, Response: Message>:
    ServerInterceptor<Request, Response>,
    @unchecked Sendable
  {
    override func receive(
      _ part: GRPCServerRequestPart<Request>,
      context: ServerInterceptorContext<Request, Response>
    ) {
      guard case let .metadata(headers) = part else {
        // Only the metadata is inspected; every other part passes straight through.
        context.receive(part)
        return
      }

      guard let token = headers.first(name: "authorization"), token == "Magic" else {
        // Not auth'd: fail the RPC and tell the client which scheme to use.
        let status = GRPCStatus(code: .unauthenticated, message: "You need some magic auth!")
        let trailers = HPACKHeaders([("www-authenticate", "Magic")])
        context.send(.end(status, trailers), promise: nil)
        return
      }

      context.userInfo.magic = token
      context.receive(part)
    }
  }
  /// Server interceptor factory for the Greeter service.
  ///
  /// `SayHello` first checks that a remote address is present, then applies the magic auth check.
  final class HelloWorldServerInterceptorFactory: Helloworld_GreeterServerInterceptorFactoryProtocol {
    func makeSayHelloInterceptors()
      -> [ServerInterceptor<Helloworld_HelloRequest, Helloworld_HelloReply>]
    {
      let addressCheck =
        RemoteAddressExistsInterceptor<Helloworld_HelloRequest, Helloworld_HelloReply>()
      let authCheck =
        NotReallyAuthServerInterceptor<Helloworld_HelloRequest, Helloworld_HelloReply>()
      return [addressCheck, authCheck]
    }
  }
  /// Client interceptor that transparently answers the server's "Magic" auth challenge.
  ///
  /// Request parts are recorded while the original call is in flight. If that call ends with
  /// `.unauthenticated` and a `www-authenticate: Magic` trailer, a new call is made on the
  /// client's channel. Every recorded part is then replayed on it, with
  /// `authorization: Magic` added to the leading metadata. From then on, sends and cancellation
  /// are routed to the new call.
  class NotReallyAuthClientInterceptor<Request: Message, Response: Message>:
    ClientInterceptor<Request, Response>, @unchecked Sendable
  {
    private let client: Helloworld_GreeterNIOClient

    private enum State {
      /// The original call is in flight; holds every request part sent on it so far.
      case trying([GRPCClientRequestPart<Request>])
      /// The original call was rejected; this replacement call is now in use.
      case retrying(Call<Request, Response>)
    }

    private var state: State = .trying([])

    init(client: Helloworld_GreeterNIOClient) {
      self.client = client
    }

    override func cancel(
      promise: EventLoopPromise<Void>?,
      context: ClientInterceptorContext<Request, Response>
    ) {
      guard case let .retrying(replacement) = self.state else {
        context.cancel(promise: promise)
        return
      }
      // Cancel the replacement call as well as the original one.
      replacement.cancel(promise: promise)
      context.cancel(promise: nil)
    }

    override func send(
      _ part: GRPCClientRequestPart<Request>,
      promise: EventLoopPromise<Void>?,
      context: ClientInterceptorContext<Request, Response>
    ) {
      switch self.state {
      case let .retrying(replacement):
        replacement.send(part, promise: promise)

      case let .trying(sent):
        // Remember the part in case the call has to be replayed, then forward it.
        self.state = .trying(sent + [part])
        context.send(part, promise: promise)
      }
    }

    override func receive(
      _ part: GRPCClientResponsePart<Response>,
      context: ClientInterceptorContext<Request, Response>
    ) {
      guard case let .trying(sent) = self.state else {
        // Already retrying: anything still arriving on the original call is dropped.
        return
      }

      // Only an unauthenticated end carrying a "Magic" challenge is handled here; every other
      // part (including an unauthenticated end with a different scheme) is forwarded as-is.
      guard case let .end(status, trailers) = part,
        status.code == .unauthenticated,
        trailers.first(name: "www-authenticate") == "Magic"
      else {
        context.receive(part)
        return
      }

      self.reissueWithMagicAuth(replaying: sent, context: context)
    }

    /// Starts the replacement call, then replays `sent` on it with the `authorization` header added.
    private func reissueWithMagicAuth(
      replaying sent: [GRPCClientRequestPart<Request>],
      context: ClientInterceptorContext<Request, Response>
    ) {
      let replacement: Call<Request, Response> = self.client.channel.makeCall(
        path: context.path,
        type: context.type,
        callOptions: context.options,
        // We could grab interceptors from the client, but we don't need to.
        interceptors: []
      )

      // Switch state first so that any later sends are routed to the replacement call.
      self.state = .retrying(replacement)

      // Responses and errors from the replacement call go straight to this context.
      replacement.invoke(onError: context.errorCaught(_:), onResponsePart: context.receive(_:))

      // If we got a response, the first recorded part is the request metadata.
      var replay = sent
      if case var .metadata(headers)? = replay.first {
        headers.replaceOrAdd(name: "authorization", value: "Magic")
        replay[0] = .metadata(headers)
      }

      replay.forEach { replacement.send($0, promise: nil) }
    }
  }
  /// Client interceptor that reverses the `text` of each outbound request message and of each
  /// inbound response message. All other parts pass through untouched.
  final class EchoReverseInterceptor: ClientInterceptor<Echo_EchoRequest, Echo_EchoResponse>,
    @unchecked Sendable
  {
    override func send(
      _ part: GRPCClientRequestPart<Echo_EchoRequest>,
      promise: EventLoopPromise<Void>?,
      context: ClientInterceptorContext<Echo_EchoRequest, Echo_EchoResponse>
    ) {
      guard case .message(var outbound, let messageMetadata) = part else {
        context.send(part, promise: promise)
        return
      }
      outbound.text = String(outbound.text.reversed())
      context.send(.message(outbound, messageMetadata), promise: promise)
    }

    override func receive(
      _ part: GRPCClientResponsePart<Echo_EchoResponse>,
      context: ClientInterceptorContext<Echo_EchoRequest, Echo_EchoResponse>
    ) {
      guard case var .message(inbound) = part else {
        context.receive(part)
        return
      }
      inbound.text = String(inbound.text.reversed())
      context.receive(.message(inbound))
    }
  }
  /// Client interceptor factory for the Echo service which installs the same
  /// `EchoReverseInterceptor` on every method.
  final class ReversingInterceptors: Echo_EchoClientInterceptorFactoryProtocol {
    // EchoReverseInterceptor has no stored state, so a single instance is shared by all RPCs.
    private let shared: [ClientInterceptor<Echo_EchoRequest, Echo_EchoResponse>] = [
      EchoReverseInterceptor()
    ]

    func makeGetInterceptors() -> [ClientInterceptor<Echo_EchoRequest, Echo_EchoResponse>] {
      self.shared
    }

    func makeExpandInterceptors() -> [ClientInterceptor<Echo_EchoRequest, Echo_EchoResponse>] {
      self.shared
    }

    func makeCollectInterceptors() -> [ClientInterceptor<Echo_EchoRequest, Echo_EchoResponse>] {
      self.shared
    }

    func makeUpdateInterceptors() -> [ClientInterceptor<Echo_EchoRequest, Echo_EchoResponse>] {
      self.shared
    }
  }
  /// Server interceptor factory for the Echo service which installs one shared
  /// `CountOnCloseServerInterceptor`, bound to the given counter, on every method.
  final class CountOnCloseInterceptors: Echo_EchoServerInterceptorFactoryProtocol {
    // The interceptor keeps no per-RPC state. Its only state is the atomic counter it was
    // given, so one instance is shared across all RPCs.
    private let shared: [ServerInterceptor<Echo_EchoRequest, Echo_EchoResponse>]

    init(counter: ManagedAtomic<Int>) {
      let closeCounter = CountOnCloseServerInterceptor(counter: counter)
      self.shared = [closeCounter]
    }

    func makeGetInterceptors() -> [ServerInterceptor<Echo_EchoRequest, Echo_EchoResponse>] {
      self.shared
    }

    func makeExpandInterceptors() -> [ServerInterceptor<Echo_EchoRequest, Echo_EchoResponse>] {
      self.shared
    }

    func makeCollectInterceptors() -> [ServerInterceptor<Echo_EchoRequest, Echo_EchoResponse>] {
      self.shared
    }

    func makeUpdateInterceptors() -> [ServerInterceptor<Echo_EchoRequest, Echo_EchoResponse>] {
      self.shared
    }
  }
  /// Server interceptor that increments a shared atomic counter when an RPC's close future completes.
  ///
  /// The callback is registered when the metadata part is received, which presumably happens
  /// once per RPC. All parts are forwarded unchanged.
  final class CountOnCloseServerInterceptor: ServerInterceptor<Echo_EchoRequest, Echo_EchoResponse>,
    @unchecked Sendable
  {
    private let counter: ManagedAtomic<Int>

    init(counter: ManagedAtomic<Int>) {
      self.counter = counter
    }

    override func receive(
      _ part: GRPCServerRequestPart<Echo_EchoRequest>,
      context: ServerInterceptorContext<Echo_EchoRequest, Echo_EchoResponse>
    ) {
      if case .metadata = part {
        context.closeFuture.whenComplete { _ in
          self.counter.wrappingIncrement(ordering: .sequentiallyConsistent)
        }
      }
      context.receive(part)
    }
  }
  /// `UserInfo` key under which the server-side auth interceptor stores the accepted
  /// `authorization` token.
  private enum MagicKey: UserInfo.Key {
    typealias Value = String
  }

  extension UserInfo {
    /// Typed accessor for the value stored under `MagicKey`; `nil` when nothing was stored.
    fileprivate var magic: MagicKey.Value? {
      get { self[MagicKey.self] }
      set { self[MagicKey.self] = newValue }
    }
  }