ServerInterceptorTests.swift 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378
  1. /*
  2. * Copyright 2020, gRPC Authors All rights reserved.
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. import EchoImplementation
  17. import EchoModel
  18. @testable import GRPC
  19. import HelloWorldModel
  20. import NIO
  21. import NIOHTTP1
  22. import SwiftProtobuf
  23. import XCTest
  24. extension GRPCServerHandlerProtocol {
  25. fileprivate func receiveRequest(_ request: Echo_EchoRequest) {
  26. let serializer = ProtobufSerializer<Echo_EchoRequest>()
  27. do {
  28. let buffer = try serializer.serialize(request, allocator: ByteBufferAllocator())
  29. self.receiveMessage(buffer)
  30. } catch {
  31. XCTFail("Unexpected error: \(error)")
  32. }
  33. }
  34. }
  35. class ServerInterceptorTests: GRPCTestCase {
  36. private let eventLoop = EmbeddedEventLoop()
  37. private let recorder = ResponseRecorder()
  38. private func makeRecordingInterceptor()
  39. -> RecordingServerInterceptor<Echo_EchoRequest, Echo_EchoResponse> {
  40. return .init()
  41. }
  42. private func echoProvider(
  43. interceptedBy interceptor: ServerInterceptor<Echo_EchoRequest, Echo_EchoResponse>
  44. ) -> EchoProvider {
  45. return EchoProvider(interceptors: EchoInterceptorFactory(interceptor: interceptor))
  46. }
  47. private func makeHandlerContext(for path: String) -> CallHandlerContext {
  48. return CallHandlerContext(
  49. errorDelegate: nil,
  50. logger: self.serverLogger,
  51. encoding: .disabled,
  52. eventLoop: self.eventLoop,
  53. path: path,
  54. responseWriter: self.recorder,
  55. allocator: ByteBufferAllocator()
  56. )
  57. }
  58. // This is only useful for the type inference.
  59. private func request(
  60. _ request: GRPCServerRequestPart<Echo_EchoRequest>
  61. ) -> GRPCServerRequestPart<Echo_EchoRequest> {
  62. return request
  63. }
  64. private func handleMethod(
  65. _ method: Substring,
  66. using provider: CallHandlerProvider
  67. ) -> GRPCServerHandlerProtocol? {
  68. let path = "/\(provider.serviceName)/\(method)"
  69. let context = self.makeHandlerContext(for: path)
  70. return provider.handle(method: method, context: context)
  71. }
  72. fileprivate typealias ResponsePart = GRPCServerResponsePart<Echo_EchoResponse>
  73. func testPassThroughInterceptor() throws {
  74. let recordingInterceptor = self.makeRecordingInterceptor()
  75. let provider = self.echoProvider(interceptedBy: recordingInterceptor)
  76. let handler = try assertNotNil(self.handleMethod("Get", using: provider))
  77. // Send requests.
  78. handler.receiveMetadata([:])
  79. handler.receiveRequest(.with { $0.text = "" })
  80. handler.receiveEnd()
  81. // Expect responses.
  82. assertThat(self.recorder.metadata, .is(.notNil()))
  83. assertThat(self.recorder.messages.count, .is(1))
  84. assertThat(self.recorder.status, .is(.notNil()))
  85. // We expect 2 request parts: the provider responds before it sees end, that's fine.
  86. assertThat(recordingInterceptor.requestParts, .hasCount(2))
  87. assertThat(recordingInterceptor.requestParts[0], .is(.metadata()))
  88. assertThat(recordingInterceptor.requestParts[1], .is(.message()))
  89. assertThat(recordingInterceptor.responseParts, .hasCount(3))
  90. assertThat(recordingInterceptor.responseParts[0], .is(.metadata()))
  91. assertThat(recordingInterceptor.responseParts[1], .is(.message()))
  92. assertThat(recordingInterceptor.responseParts[2], .is(.end(status: .is(.ok))))
  93. }
  94. func testUnaryFromInterceptor() throws {
  95. let provider = EchoFromInterceptor()
  96. let handler = try assertNotNil(self.handleMethod("Get", using: provider))
  97. // Send the requests.
  98. handler.receiveMetadata([:])
  99. handler.receiveRequest(.with { $0.text = "foo" })
  100. handler.receiveEnd()
  101. // Get the responses.
  102. assertThat(self.recorder.metadata, .is(.notNil()))
  103. assertThat(self.recorder.messages.count, .is(1))
  104. assertThat(self.recorder.status, .is(.notNil()))
  105. }
  106. func testClientStreamingFromInterceptor() throws {
  107. let provider = EchoFromInterceptor()
  108. let handler = try assertNotNil(self.handleMethod("Collect", using: provider))
  109. // Send the requests.
  110. handler.receiveMetadata([:])
  111. for text in ["a", "b", "c"] {
  112. handler.receiveRequest(.with { $0.text = text })
  113. }
  114. handler.receiveEnd()
  115. // Get the responses.
  116. assertThat(self.recorder.metadata, .is(.notNil()))
  117. assertThat(self.recorder.messages.count, .is(1))
  118. assertThat(self.recorder.status, .is(.notNil()))
  119. }
  120. func testServerStreamingFromInterceptor() throws {
  121. let provider = EchoFromInterceptor()
  122. let handler = try assertNotNil(self.handleMethod("Expand", using: provider))
  123. // Send the requests.
  124. handler.receiveMetadata([:])
  125. handler.receiveRequest(.with { $0.text = "a b c" })
  126. handler.receiveEnd()
  127. // Get the responses.
  128. assertThat(self.recorder.metadata, .is(.notNil()))
  129. assertThat(self.recorder.messages.count, .is(3))
  130. assertThat(self.recorder.status, .is(.notNil()))
  131. }
  132. func testBidirectionalStreamingFromInterceptor() throws {
  133. let provider = EchoFromInterceptor()
  134. let handler = try assertNotNil(self.handleMethod("Update", using: provider))
  135. // Send the requests.
  136. handler.receiveMetadata([:])
  137. for text in ["a", "b", "c"] {
  138. handler.receiveRequest(.with { $0.text = text })
  139. }
  140. handler.receiveEnd()
  141. // Get the responses.
  142. assertThat(self.recorder.metadata, .is(.notNil()))
  143. assertThat(self.recorder.messages.count, .is(3))
  144. assertThat(self.recorder.status, .is(.notNil()))
  145. }
  146. }
  147. class EchoInterceptorFactory: Echo_EchoServerInterceptorFactoryProtocol {
  148. private let interceptor: ServerInterceptor<Echo_EchoRequest, Echo_EchoResponse>
  149. init(interceptor: ServerInterceptor<Echo_EchoRequest, Echo_EchoResponse>) {
  150. self.interceptor = interceptor
  151. }
  152. func makeGetInterceptors() -> [ServerInterceptor<Echo_EchoRequest, Echo_EchoResponse>] {
  153. return [self.interceptor]
  154. }
  155. func makeExpandInterceptors() -> [ServerInterceptor<Echo_EchoRequest, Echo_EchoResponse>] {
  156. return [self.interceptor]
  157. }
  158. func makeCollectInterceptors() -> [ServerInterceptor<Echo_EchoRequest, Echo_EchoResponse>] {
  159. return [self.interceptor]
  160. }
  161. func makeUpdateInterceptors() -> [ServerInterceptor<Echo_EchoRequest, Echo_EchoResponse>] {
  162. return [self.interceptor]
  163. }
  164. }
  165. class ExtraRequestPartEmitter: ServerInterceptor<Echo_EchoRequest, Echo_EchoResponse> {
  166. enum Part {
  167. case metadata
  168. case message
  169. case end
  170. }
  171. private let part: Part
  172. private let count: Int
  173. init(repeat part: Part, times count: Int) {
  174. self.part = part
  175. self.count = count
  176. }
  177. override func receive(
  178. _ part: GRPCServerRequestPart<Echo_EchoRequest>,
  179. context: ServerInterceptorContext<Echo_EchoRequest, Echo_EchoResponse>
  180. ) {
  181. let count: Int
  182. switch (self.part, part) {
  183. case (.metadata, .metadata),
  184. (.message, .message),
  185. (.end, .end):
  186. count = self.count
  187. default:
  188. count = 1
  189. }
  190. for _ in 0 ..< count {
  191. context.receive(part)
  192. }
  193. }
  194. }
  195. class EchoFromInterceptor: Echo_EchoProvider {
  196. var interceptors: Echo_EchoServerInterceptorFactoryProtocol? = Interceptors()
  197. func get(
  198. request: Echo_EchoRequest,
  199. context: StatusOnlyCallContext
  200. ) -> EventLoopFuture<Echo_EchoResponse> {
  201. XCTFail("Unexpected call to \(#function)")
  202. return context.eventLoop.makeFailedFuture(GRPCStatus.processingError)
  203. }
  204. func expand(
  205. request: Echo_EchoRequest,
  206. context: StreamingResponseCallContext<Echo_EchoResponse>
  207. ) -> EventLoopFuture<GRPCStatus> {
  208. XCTFail("Unexpected call to \(#function)")
  209. return context.eventLoop.makeFailedFuture(GRPCStatus.processingError)
  210. }
  211. func collect(
  212. context: UnaryResponseCallContext<Echo_EchoResponse>
  213. ) -> EventLoopFuture<(StreamEvent<Echo_EchoRequest>) -> Void> {
  214. XCTFail("Unexpected call to \(#function)")
  215. return context.eventLoop.makeFailedFuture(GRPCStatus.processingError)
  216. }
  217. func update(
  218. context: StreamingResponseCallContext<Echo_EchoResponse>
  219. ) -> EventLoopFuture<(StreamEvent<Echo_EchoRequest>) -> Void> {
  220. XCTFail("Unexpected call to \(#function)")
  221. return context.eventLoop.makeFailedFuture(GRPCStatus.processingError)
  222. }
  223. class Interceptors: Echo_EchoServerInterceptorFactoryProtocol {
  224. func makeGetInterceptors() -> [ServerInterceptor<Echo_EchoRequest, Echo_EchoResponse>] {
  225. return [Interceptor()]
  226. }
  227. func makeExpandInterceptors() -> [ServerInterceptor<Echo_EchoRequest, Echo_EchoResponse>] {
  228. return [Interceptor()]
  229. }
  230. func makeCollectInterceptors() -> [ServerInterceptor<Echo_EchoRequest, Echo_EchoResponse>] {
  231. return [Interceptor()]
  232. }
  233. func makeUpdateInterceptors() -> [ServerInterceptor<Echo_EchoRequest, Echo_EchoResponse>] {
  234. return [Interceptor()]
  235. }
  236. }
  237. // Since all methods use the same request/response types, we can use a single interceptor to
  238. // respond to all of them.
  239. class Interceptor: ServerInterceptor<Echo_EchoRequest, Echo_EchoResponse> {
  240. private var collectedRequests: [Echo_EchoRequest] = []
  241. override func receive(
  242. _ part: GRPCServerRequestPart<Echo_EchoRequest>,
  243. context: ServerInterceptorContext<Echo_EchoRequest, Echo_EchoResponse>
  244. ) {
  245. switch part {
  246. case .metadata:
  247. context.send(.metadata([:]), promise: nil)
  248. case let .message(request):
  249. if context.path.hasSuffix("Get") {
  250. // Unary, just reply.
  251. let response = Echo_EchoResponse.with {
  252. $0.text = "echo: \(request.text)"
  253. }
  254. context.send(.message(response, .init(compress: false, flush: false)), promise: nil)
  255. } else if context.path.hasSuffix("Expand") {
  256. // Server streaming.
  257. let parts = request.text.split(separator: " ")
  258. let metadata = MessageMetadata(compress: false, flush: false)
  259. for part in parts {
  260. context.send(.message(.with { $0.text = "echo: \(part)" }, metadata), promise: nil)
  261. }
  262. } else if context.path.hasSuffix("Collect") {
  263. // Client streaming, store the requests, reply on '.end'
  264. self.collectedRequests.append(request)
  265. } else if context.path.hasSuffix("Update") {
  266. // Bidirectional streaming.
  267. let response = Echo_EchoResponse.with {
  268. $0.text = "echo: \(request.text)"
  269. }
  270. let metadata = MessageMetadata(compress: false, flush: true)
  271. context.send(.message(response, metadata), promise: nil)
  272. } else {
  273. XCTFail("Unexpected path '\(context.path)'")
  274. }
  275. case .end:
  276. if !self.collectedRequests.isEmpty {
  277. let response = Echo_EchoResponse.with {
  278. $0.text = "echo: " + self.collectedRequests.map { $0.text }.joined(separator: " ")
  279. }
  280. context.send(.message(response, .init(compress: false, flush: false)), promise: nil)
  281. }
  282. context.send(.end(.ok, [:]), promise: nil)
  283. }
  284. }
  285. }
  286. }
  287. // Avoid having to serialize/deserialize messages in test cases.
  288. private class Codec: ChannelDuplexHandler {
  289. typealias InboundIn = GRPCServerRequestPart<Echo_EchoRequest>
  290. typealias InboundOut = GRPCServerRequestPart<ByteBuffer>
  291. typealias OutboundIn = GRPCServerResponsePart<ByteBuffer>
  292. typealias OutboundOut = GRPCServerResponsePart<Echo_EchoResponse>
  293. private let serializer = ProtobufSerializer<Echo_EchoRequest>()
  294. private let deserializer = ProtobufDeserializer<Echo_EchoResponse>()
  295. func channelRead(context: ChannelHandlerContext, data: NIOAny) {
  296. switch self.unwrapInboundIn(data) {
  297. case let .metadata(headers):
  298. context.fireChannelRead(self.wrapInboundOut(.metadata(headers)))
  299. case let .message(message):
  300. let serialized = try! self.serializer.serialize(message, allocator: context.channel.allocator)
  301. context.fireChannelRead(self.wrapInboundOut(.message(serialized)))
  302. case .end:
  303. context.fireChannelRead(self.wrapInboundOut(.end))
  304. }
  305. }
  306. func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise<Void>?) {
  307. switch self.unwrapOutboundIn(data) {
  308. case let .metadata(headers):
  309. context.write(self.wrapOutboundOut(.metadata(headers)), promise: promise)
  310. case let .message(message, metadata):
  311. let deserialzed = try! self.deserializer.deserialize(byteBuffer: message)
  312. context.write(self.wrapOutboundOut(.message(deserialzed, metadata)), promise: promise)
  313. case let .end(status, trailers):
  314. context.write(self.wrapOutboundOut(.end(status, trailers)), promise: promise)
  315. }
  316. }
  317. }