ServerInterceptorTests.swift 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379
  1. /*
  2. * Copyright 2020, gRPC Authors All rights reserved.
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. import EchoImplementation
  17. import EchoModel
  18. @testable import GRPC
  19. import HelloWorldModel
  20. import NIO
  21. import NIOHTTP1
  22. import SwiftProtobuf
  23. import XCTest
  extension GRPCServerHandlerProtocol {
    /// Serializes `request` with a `ProtobufSerializer` and hands the bytes to this handler
    /// through `receiveMessage(_:)`, mimicking a request message arriving from the wire.
    ///
    /// If serialization throws, the failure is reported via `XCTFail` and nothing is delivered
    /// to the handler.
    fileprivate func receiveRequest(_ request: Echo_EchoRequest) {
      let outcome = Result {
        try ProtobufSerializer<Echo_EchoRequest>().serialize(
          request,
          allocator: ByteBufferAllocator()
        )
      }

      switch outcome {
      case let .success(bytes):
        self.receiveMessage(bytes)
      case let .failure(error):
        XCTFail("Unexpected error: \(error)")
      }
    }
  }
  class ServerInterceptorTests: GRPCTestCase {
    private let eventLoop = EmbeddedEventLoop()
    private let recorder = ResponseRecorder()

    /// Makes a `RecordingServerInterceptor` for the Echo request/response types; the tests
    /// inspect its `requestParts` and `responseParts` afterwards.
    private func makeRecordingInterceptor()
      -> RecordingServerInterceptor<Echo_EchoRequest, Echo_EchoResponse> {
      return RecordingServerInterceptor<Echo_EchoRequest, Echo_EchoResponse>()
    }

    /// Builds an `EchoProvider` whose interceptor factory hands out `interceptor` for every
    /// Echo method.
    private func echoProvider(
      interceptedBy interceptor: ServerInterceptor<Echo_EchoRequest, Echo_EchoResponse>
    ) -> EchoProvider {
      let factory = EchoInterceptorFactory(interceptor: interceptor)
      return EchoProvider(interceptors: factory)
    }

    /// Creates a call handler context for `path` that runs on this test's embedded event loop
    /// and writes responses into `recorder`. Its close future is already succeeded.
    private func makeHandlerContext(for path: String) -> CallHandlerContext {
      let loop = self.eventLoop
      return CallHandlerContext(
        errorDelegate: nil,
        logger: self.serverLogger,
        encoding: .disabled,
        eventLoop: loop,
        path: path,
        responseWriter: self.recorder,
        allocator: ByteBufferAllocator(),
        closeFuture: loop.makeSucceededVoidFuture()
      )
    }

    // Identity function: this is only useful for the type inference.
    private func request(
      _ part: GRPCServerRequestPart<Echo_EchoRequest>
    ) -> GRPCServerRequestPart<Echo_EchoRequest> {
      return part
    }

    /// Asks `provider` for a handler for `method`, using a context whose path is
    /// "/<service name>/<method>". Returns `nil` when the provider does not handle `method`.
    private func handleMethod(
      _ method: Substring,
      using provider: CallHandlerProvider
    ) -> GRPCServerHandlerProtocol? {
      let context = self.makeHandlerContext(for: "/\(provider.serviceName)/\(method)")
      return provider.handle(method: method, context: context)
    }

    fileprivate typealias ResponsePart = GRPCServerResponsePart<Echo_EchoResponse>

    /// Drives a complete request stream into `handler`: empty metadata, then one request
    /// message per element of `texts` (in order), then end-of-stream.
    private func sendRequestStream(_ texts: [String], to handler: GRPCServerHandlerProtocol) {
      handler.receiveMetadata([:])
      texts.forEach { text in
        handler.receiveRequest(.with { $0.text = text })
      }
      handler.receiveEnd()
    }

    func testPassThroughInterceptor() throws {
      let interceptor = self.makeRecordingInterceptor()
      let provider = self.echoProvider(interceptedBy: interceptor)
      let handler = try assertNotNil(self.handleMethod("Get", using: provider))

      self.sendRequestStream([""], to: handler)

      // The call completes with metadata, exactly one message, and a status.
      assertThat(self.recorder.metadata, .is(.notNil()))
      assertThat(self.recorder.messages.count, .is(1))
      assertThat(self.recorder.status, .is(.notNil()))

      // Only 2 request parts reach the interceptor: the provider responds before it sees the
      // end of the request stream, which is fine.
      let requestParts = interceptor.requestParts
      assertThat(requestParts, .hasCount(2))
      assertThat(requestParts[0], .is(.metadata()))
      assertThat(requestParts[1], .is(.message()))

      let responseParts = interceptor.responseParts
      assertThat(responseParts, .hasCount(3))
      assertThat(responseParts[0], .is(.metadata()))
      assertThat(responseParts[1], .is(.message()))
      assertThat(responseParts[2], .is(.end(status: .is(.ok))))
    }

    func testUnaryFromInterceptor() throws {
      let handler = try assertNotNil(self.handleMethod("Get", using: EchoFromInterceptor()))

      self.sendRequestStream(["foo"], to: handler)

      assertThat(self.recorder.metadata, .is(.notNil()))
      assertThat(self.recorder.messages.count, .is(1))
      assertThat(self.recorder.status, .is(.notNil()))
    }

    func testClientStreamingFromInterceptor() throws {
      let handler = try assertNotNil(self.handleMethod("Collect", using: EchoFromInterceptor()))

      self.sendRequestStream(["a", "b", "c"], to: handler)

      // Three requests are collected into a single response.
      assertThat(self.recorder.metadata, .is(.notNil()))
      assertThat(self.recorder.messages.count, .is(1))
      assertThat(self.recorder.status, .is(.notNil()))
    }

    func testServerStreamingFromInterceptor() throws {
      let handler = try assertNotNil(self.handleMethod("Expand", using: EchoFromInterceptor()))

      self.sendRequestStream(["a b c"], to: handler)

      // One response per space-separated word in the request.
      assertThat(self.recorder.metadata, .is(.notNil()))
      assertThat(self.recorder.messages.count, .is(3))
      assertThat(self.recorder.status, .is(.notNil()))
    }

    func testBidirectionalStreamingFromInterceptor() throws {
      let handler = try assertNotNil(self.handleMethod("Update", using: EchoFromInterceptor()))

      self.sendRequestStream(["a", "b", "c"], to: handler)

      // One response per request.
      assertThat(self.recorder.metadata, .is(.notNil()))
      assertThat(self.recorder.messages.count, .is(3))
      assertThat(self.recorder.status, .is(.notNil()))
    }
  }
  /// An Echo interceptor factory that returns the same single interceptor for every method.
  ///
  /// Because one instance is shared across all methods (and across every call to the factory),
  /// any state that interceptor keeps is shared as well.
  class EchoInterceptorFactory: Echo_EchoServerInterceptorFactoryProtocol {
    /// One-element interceptor chain handed out by every `make*Interceptors()` method.
    private let chain: [ServerInterceptor<Echo_EchoRequest, Echo_EchoResponse>]

    init(interceptor: ServerInterceptor<Echo_EchoRequest, Echo_EchoResponse>) {
      self.chain = [interceptor]
    }

    func makeGetInterceptors() -> [ServerInterceptor<Echo_EchoRequest, Echo_EchoResponse>] {
      return self.chain
    }

    func makeExpandInterceptors() -> [ServerInterceptor<Echo_EchoRequest, Echo_EchoResponse>] {
      return self.chain
    }

    func makeCollectInterceptors() -> [ServerInterceptor<Echo_EchoRequest, Echo_EchoResponse>] {
      return self.chain
    }

    func makeUpdateInterceptors() -> [ServerInterceptor<Echo_EchoRequest, Echo_EchoResponse>] {
      return self.chain
    }
  }
  /// A server interceptor that forwards every request part it receives to the next stage.
  ///
  /// Parts of the configured kind are forwarded `count` times; all other parts are forwarded
  /// exactly once.
  class ExtraRequestPartEmitter: ServerInterceptor<Echo_EchoRequest, Echo_EchoResponse> {
    /// The kind of request part to repeat.
    enum Part {
      case metadata
      case message
      case end
    }

    private let part: Part
    private let count: Int

    /// - Parameters:
    ///   - part: Which kind of request part to repeat.
    ///   - count: How many times each matching part is forwarded. It is used as the upper bound
    ///     of a `0 ..< count` range, so a negative value traps.
    init(repeat part: Part, times count: Int) {
      self.part = part
      self.count = count
    }

    override func receive(
      _ part: GRPCServerRequestPart<Echo_EchoRequest>,
      context: ServerInterceptorContext<Echo_EchoRequest, Echo_EchoResponse>
    ) {
      // Work out whether this incoming part is the kind we were configured to repeat.
      let isTargetPart: Bool
      switch part {
      case .metadata:
        isTargetPart = self.part == .metadata
      case .message:
        isTargetPart = self.part == .message
      case .end:
        isTargetPart = self.part == .end
      }

      let repetitions = isTargetPart ? self.count : 1
      for _ in 0 ..< repetitions {
        context.receive(part)
      }
    }
  }
  /// An `Echo_EchoProvider` whose handler methods must never run: every RPC is answered
  /// entirely by the interceptor returned from `interceptors`.
  ///
  /// Each provider method reports an `XCTFail` and returns a failed future if it is invoked.
  class EchoFromInterceptor: Echo_EchoProvider {
    var interceptors: Echo_EchoServerInterceptorFactoryProtocol? = Interceptors()

    func get(
      request: Echo_EchoRequest,
      context: StatusOnlyCallContext
    ) -> EventLoopFuture<Echo_EchoResponse> {
      XCTFail("Unexpected call to \(#function)")
      return context.eventLoop.makeFailedFuture(GRPCStatus.processingError)
    }

    func expand(
      request: Echo_EchoRequest,
      context: StreamingResponseCallContext<Echo_EchoResponse>
    ) -> EventLoopFuture<GRPCStatus> {
      XCTFail("Unexpected call to \(#function)")
      return context.eventLoop.makeFailedFuture(GRPCStatus.processingError)
    }

    func collect(
      context: UnaryResponseCallContext<Echo_EchoResponse>
    ) -> EventLoopFuture<(StreamEvent<Echo_EchoRequest>) -> Void> {
      XCTFail("Unexpected call to \(#function)")
      return context.eventLoop.makeFailedFuture(GRPCStatus.processingError)
    }

    func update(
      context: StreamingResponseCallContext<Echo_EchoResponse>
    ) -> EventLoopFuture<(StreamEvent<Echo_EchoRequest>) -> Void> {
      XCTFail("Unexpected call to \(#function)")
      return context.eventLoop.makeFailedFuture(GRPCStatus.processingError)
    }

    /// Interceptor factory returning a brand-new `Interceptor` each time a method is asked for,
    /// so state held by one `Interceptor` (e.g. requests collected for `Collect`) is not shared
    /// with interceptors produced by other calls to the factory.
    class Interceptors: Echo_EchoServerInterceptorFactoryProtocol {
      func makeGetInterceptors() -> [ServerInterceptor<Echo_EchoRequest, Echo_EchoResponse>] {
        return [Interceptor()]
      }

      func makeExpandInterceptors() -> [ServerInterceptor<Echo_EchoRequest, Echo_EchoResponse>] {
        return [Interceptor()]
      }

      func makeCollectInterceptors() -> [ServerInterceptor<Echo_EchoRequest, Echo_EchoResponse>] {
        return [Interceptor()]
      }

      func makeUpdateInterceptors() -> [ServerInterceptor<Echo_EchoRequest, Echo_EchoResponse>] {
        return [Interceptor()]
      }
    }

    // Since all methods use the same request/response types, we can use a single interceptor
    // type to respond to all of them. The method is identified by the suffix of the call path.
    class Interceptor: ServerInterceptor<Echo_EchoRequest, Echo_EchoResponse> {
      /// Requests received on a `Collect` call; answered with one combined response on `.end`.
      private var collectedRequests: [Echo_EchoRequest] = []

      override func receive(
        _ part: GRPCServerRequestPart<Echo_EchoRequest>,
        context: ServerInterceptorContext<Echo_EchoRequest, Echo_EchoResponse>
      ) {
        switch part {
        case .metadata:
          context.send(.metadata([:]), promise: nil)

        case let .message(request):
          if context.path.hasSuffix("Get") {
            // Unary, just reply.
            let response = Echo_EchoResponse.with {
              $0.text = "echo: \(request.text)"
            }
            context.send(.message(response, .init(compress: false, flush: false)), promise: nil)
          } else if context.path.hasSuffix("Expand") {
            // Server streaming: one response per space-separated word.
            let parts = request.text.split(separator: " ")
            let metadata = MessageMetadata(compress: false, flush: false)
            for part in parts {
              context.send(.message(.with { $0.text = "echo: \(part)" }, metadata), promise: nil)
            }
          } else if context.path.hasSuffix("Collect") {
            // Client streaming, store the requests, reply on '.end'.
            self.collectedRequests.append(request)
          } else if context.path.hasSuffix("Update") {
            // Bidirectional streaming: reply to each request immediately.
            let response = Echo_EchoResponse.with {
              $0.text = "echo: \(request.text)"
            }
            let metadata = MessageMetadata(compress: false, flush: true)
            context.send(.message(response, metadata), promise: nil)
          } else {
            XCTFail("Unexpected path '\(context.path)'")
          }

        case .end:
          // Decide by method rather than by whether anything was collected: a client streaming
          // call must produce exactly one response message even when no requests were sent.
          if context.path.hasSuffix("Collect") {
            let joined = self.collectedRequests.map { $0.text }.joined(separator: " ")
            let response = Echo_EchoResponse.with {
              $0.text = "echo: " + joined
            }
            context.send(.message(response, .init(compress: false, flush: false)), promise: nil)
          }
          context.send(.end(.ok, [:]), promise: nil)
        }
      }
    }
  }
  // Avoid having to serialize/deserialize messages in test cases.
  /// Translates between typed Echo messages and serialized `ByteBuffer`s:
  /// - inbound: `GRPCServerRequestPart<Echo_EchoRequest>` -> `GRPCServerRequestPart<ByteBuffer>`
  /// - outbound: `GRPCServerResponsePart<ByteBuffer>` -> `GRPCServerResponsePart<Echo_EchoResponse>`
  ///
  /// Serialization failures are surfaced through the pipeline (`fireErrorCaught`, and failing
  /// the write promise for outbound data) instead of trapping the test process.
  private class Codec: ChannelDuplexHandler {
    typealias InboundIn = GRPCServerRequestPart<Echo_EchoRequest>
    typealias InboundOut = GRPCServerRequestPart<ByteBuffer>
    typealias OutboundIn = GRPCServerResponsePart<ByteBuffer>
    typealias OutboundOut = GRPCServerResponsePart<Echo_EchoResponse>

    private let serializer = ProtobufSerializer<Echo_EchoRequest>()
    private let deserializer = ProtobufDeserializer<Echo_EchoResponse>()

    func channelRead(context: ChannelHandlerContext, data: NIOAny) {
      switch self.unwrapInboundIn(data) {
      case let .metadata(headers):
        context.fireChannelRead(self.wrapInboundOut(.metadata(headers)))

      case let .message(message):
        do {
          let serialized = try self.serializer.serialize(
            message,
            allocator: context.channel.allocator
          )
          context.fireChannelRead(self.wrapInboundOut(.message(serialized)))
        } catch {
          // Report the failure to later handlers rather than crashing the test run.
          context.fireErrorCaught(error)
        }

      case .end:
        context.fireChannelRead(self.wrapInboundOut(.end))
      }
    }

    func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise<Void>?) {
      switch self.unwrapOutboundIn(data) {
      case let .metadata(headers):
        context.write(self.wrapOutboundOut(.metadata(headers)), promise: promise)

      case let .message(message, metadata):
        do {
          let deserialized = try self.deserializer.deserialize(byteBuffer: message)
          context.write(self.wrapOutboundOut(.message(deserialized, metadata)), promise: promise)
        } catch {
          // The write never happens, so complete its promise with the error and also let the
          // pipeline know (mirrors how NIO's MessageToByteHandler handles encode failures).
          promise?.fail(error)
          context.fireErrorCaught(error)
        }

      case let .end(status, trailers):
        context.write(self.wrapOutboundOut(.end(status, trailers)), promise: promise)
      }
    }
  }