ServerInterceptorTests.swift 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387
  1. /*
  2. * Copyright 2020, gRPC Authors All rights reserved.
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. import EchoImplementation
  17. import EchoModel
  18. import HelloWorldModel
  19. import NIOCore
  20. import NIOEmbedded
  21. import NIOHTTP1
  22. import SwiftProtobuf
  23. import XCTest
  24. @testable import GRPC
  extension GRPCServerHandlerProtocol {
    /// Serializes `request` with the protobuf serializer and hands the resulting bytes to the
    /// handler as an inbound message. A serialization failure fails the running test and the
    /// handler receives nothing.
    fileprivate func receiveRequest(_ request: Echo_EchoRequest) {
      let bytes: ByteBuffer
      do {
        bytes = try ProtobufSerializer<Echo_EchoRequest>().serialize(
          request,
          allocator: ByteBufferAllocator()
        )
      } catch {
        XCTFail("Unexpected error: \(error)")
        return
      }
      self.receiveMessage(bytes)
    }
  }
  /// Exercises server interceptors on the Echo service. In the first test an interceptor passes
  /// traffic through to the provider and records what it sees. In the others the interceptor
  /// itself produces every response.
  class ServerInterceptorTests: GRPCTestCase {
    private let eventLoop = EmbeddedEventLoop()
    private var recorder: ResponseRecorder!

    fileprivate typealias ResponsePart = GRPCServerResponsePart<Echo_EchoResponse>

    override func setUp() {
      super.setUp()
      self.recorder = ResponseRecorder(eventLoop: self.eventLoop)
    }

    // MARK: - Helpers

    /// Returns a fresh interceptor that records every request and response part it observes.
    private func makeRecordingInterceptor()
      -> RecordingServerInterceptor<Echo_EchoRequest, Echo_EchoResponse>
    {
      return .init()
    }

    /// Returns the regular Echo provider with `interceptor` installed on every method.
    private func echoProvider(
      interceptedBy interceptor: ServerInterceptor<Echo_EchoRequest, Echo_EchoResponse>
    ) -> EchoProvider {
      let factory = EchoInterceptorFactory(interceptor: interceptor)
      return EchoProvider(interceptors: factory)
    }

    /// Builds a handler context for `path` on the embedded loop. Responses are written to
    /// `self.recorder`.
    private func makeHandlerContext(for path: String) -> CallHandlerContext {
      return CallHandlerContext(
        errorDelegate: nil,
        logger: self.serverLogger,
        encoding: .disabled,
        eventLoop: self.eventLoop,
        path: path,
        responseWriter: self.recorder,
        allocator: ByteBufferAllocator(),
        closeFuture: self.eventLoop.makeSucceededVoidFuture()
      )
    }

    /// Identity function; exists only to pin the generic type for inference.
    private func request(
      _ request: GRPCServerRequestPart<Echo_EchoRequest>
    ) -> GRPCServerRequestPart<Echo_EchoRequest> {
      return request
    }

    /// Asks `provider` for a handler for `method`, using the path "/<service>/<method>".
    private func handleMethod(
      _ method: Substring,
      using provider: CallHandlerProvider
    ) -> GRPCServerHandlerProtocol? {
      let context = self.makeHandlerContext(for: "/\(provider.serviceName)/\(method)")
      return provider.handle(method: method, context: context)
    }

    /// Drives a whole call on `handler`.
    /// The sequence is empty request headers, then one request per element of `texts` (in
    /// order), then end-of-stream.
    private func sendRequests(_ texts: [String], to handler: GRPCServerHandlerProtocol) {
      handler.receiveMetadata([:])
      for text in texts {
        handler.receiveRequest(.with { $0.text = text })
      }
      handler.receiveEnd()
    }

    // MARK: - Tests

    func testPassThroughInterceptor() throws {
      let recording = self.makeRecordingInterceptor()
      let echo = self.echoProvider(interceptedBy: recording)
      let handler = try assertNotNil(self.handleMethod("Get", using: echo))

      self.sendRequests([""], to: handler)

      // The provider's responses reached the writer.
      assertThat(self.recorder.metadata, .is(.some()))
      assertThat(self.recorder.messages.count, .is(1))
      assertThat(self.recorder.status, .is(.some()))

      // Only two request parts are recorded: the provider responds before it sees the end of
      // the request stream, which is fine.
      assertThat(recording.requestParts, .hasCount(2))
      assertThat(recording.requestParts[0], .is(.metadata()))
      assertThat(recording.requestParts[1], .is(.message()))

      assertThat(recording.responseParts, .hasCount(3))
      assertThat(recording.responseParts[0], .is(.metadata()))
      assertThat(recording.responseParts[1], .is(.message()))
      assertThat(recording.responseParts[2], .is(.end(status: .is(.ok))))
    }

    func testUnaryFromInterceptor() throws {
      let handler = try assertNotNil(self.handleMethod("Get", using: EchoFromInterceptor()))

      self.sendRequests(["foo"], to: handler)

      assertThat(self.recorder.metadata, .is(.some()))
      assertThat(self.recorder.messages.count, .is(1))
      assertThat(self.recorder.status, .is(.some()))
    }

    func testClientStreamingFromInterceptor() throws {
      let handler = try assertNotNil(self.handleMethod("Collect", using: EchoFromInterceptor()))

      self.sendRequests(["a", "b", "c"], to: handler)

      assertThat(self.recorder.metadata, .is(.some()))
      assertThat(self.recorder.messages.count, .is(1))
      assertThat(self.recorder.status, .is(.some()))
    }

    func testServerStreamingFromInterceptor() throws {
      let handler = try assertNotNil(self.handleMethod("Expand", using: EchoFromInterceptor()))

      self.sendRequests(["a b c"], to: handler)

      assertThat(self.recorder.metadata, .is(.some()))
      assertThat(self.recorder.messages.count, .is(3))
      assertThat(self.recorder.status, .is(.some()))
    }

    func testBidirectionalStreamingFromInterceptor() throws {
      let handler = try assertNotNil(self.handleMethod("Update", using: EchoFromInterceptor()))

      self.sendRequests(["a", "b", "c"], to: handler)

      assertThat(self.recorder.metadata, .is(.some()))
      assertThat(self.recorder.messages.count, .is(3))
      assertThat(self.recorder.status, .is(.some()))
    }
  }
  /// Interceptor factory that installs one shared interceptor instance on every Echo method.
  final class EchoInterceptorFactory: Echo_EchoServerInterceptorFactoryProtocol {
    private let interceptor: ServerInterceptor<Echo_EchoRequest, Echo_EchoResponse>

    init(interceptor: ServerInterceptor<Echo_EchoRequest, Echo_EchoResponse>) {
      self.interceptor = interceptor
    }

    /// The single-element pipeline handed out for every method.
    private var pipeline: [ServerInterceptor<Echo_EchoRequest, Echo_EchoResponse>] {
      return [self.interceptor]
    }

    func makeGetInterceptors() -> [ServerInterceptor<Echo_EchoRequest, Echo_EchoResponse>] {
      return self.pipeline
    }

    func makeExpandInterceptors() -> [ServerInterceptor<Echo_EchoRequest, Echo_EchoResponse>] {
      return self.pipeline
    }

    func makeCollectInterceptors() -> [ServerInterceptor<Echo_EchoRequest, Echo_EchoResponse>] {
      return self.pipeline
    }

    func makeUpdateInterceptors() -> [ServerInterceptor<Echo_EchoRequest, Echo_EchoResponse>] {
      return self.pipeline
    }
  }
  /// An interceptor that forwards every inbound request part.
  /// Parts of the configured kind are forwarded `count` times; all other parts are forwarded
  /// once.
  class ExtraRequestPartEmitter: ServerInterceptor<Echo_EchoRequest, Echo_EchoResponse> {
    /// The kind of request part to repeat.
    enum Part {
      case metadata
      case message
      case end
    }

    private let part: Part
    private let count: Int

    init(repeat part: Part, times count: Int) {
      self.part = part
      self.count = count
    }

    override func receive(
      _ part: GRPCServerRequestPart<Echo_EchoRequest>,
      context: ServerInterceptorContext<Echo_EchoRequest, Echo_EchoResponse>
    ) {
      // Work out whether this inbound part is the kind selected for repetition.
      let isRepeatedKind: Bool
      switch part {
      case .metadata:
        isRepeatedKind = self.part == .metadata
      case .message:
        isRepeatedKind = self.part == .message
      case .end:
        isRepeatedKind = self.part == .end
      }

      let repetitions = isRepeatedKind ? self.count : 1
      for _ in 0 ..< repetitions {
        context.receive(part)
      }
    }
  }
  /// An Echo provider whose RPC methods must never run. The nested `Interceptor` answers every
  /// call, sending headers, messages and the final status from inside the interceptor pipeline.
  /// Reaching any provider method fails the current test.
  class EchoFromInterceptor: Echo_EchoProvider {
    var interceptors: Echo_EchoServerInterceptorFactoryProtocol? = Interceptors()

    func get(
      request: Echo_EchoRequest,
      context: StatusOnlyCallContext
    ) -> EventLoopFuture<Echo_EchoResponse> {
      XCTFail("Unexpected call to \(#function)")
      return context.eventLoop.makeFailedFuture(GRPCStatus.processingError)
    }

    func expand(
      request: Echo_EchoRequest,
      context: StreamingResponseCallContext<Echo_EchoResponse>
    ) -> EventLoopFuture<GRPCStatus> {
      XCTFail("Unexpected call to \(#function)")
      return context.eventLoop.makeFailedFuture(GRPCStatus.processingError)
    }

    func collect(
      context: UnaryResponseCallContext<Echo_EchoResponse>
    ) -> EventLoopFuture<(StreamEvent<Echo_EchoRequest>) -> Void> {
      XCTFail("Unexpected call to \(#function)")
      return context.eventLoop.makeFailedFuture(GRPCStatus.processingError)
    }

    func update(
      context: StreamingResponseCallContext<Echo_EchoResponse>
    ) -> EventLoopFuture<(StreamEvent<Echo_EchoRequest>) -> Void> {
      XCTFail("Unexpected call to \(#function)")
      return context.eventLoop.makeFailedFuture(GRPCStatus.processingError)
    }

    /// Hands out a new `Interceptor` each time a pipeline is requested, so the interceptor's
    /// stored state is not shared between calls.
    final class Interceptors: Echo_EchoServerInterceptorFactoryProtocol {
      func makeGetInterceptors() -> [ServerInterceptor<Echo_EchoRequest, Echo_EchoResponse>] {
        return [Interceptor()]
      }

      func makeExpandInterceptors() -> [ServerInterceptor<Echo_EchoRequest, Echo_EchoResponse>] {
        return [Interceptor()]
      }

      func makeCollectInterceptors() -> [ServerInterceptor<Echo_EchoRequest, Echo_EchoResponse>] {
        return [Interceptor()]
      }

      func makeUpdateInterceptors() -> [ServerInterceptor<Echo_EchoRequest, Echo_EchoResponse>] {
        return [Interceptor()]
      }
    }

    // Since all methods use the same request/response types, we can use a single interceptor to
    // respond to all of them. The method is identified from the suffix of `context.path`.
    class Interceptor: ServerInterceptor<Echo_EchoRequest, Echo_EchoResponse> {
      /// Requests received on a client-streaming ("Collect") call; answered once `.end` arrives.
      private var collectedRequests: [Echo_EchoRequest] = []

      override func receive(
        _ part: GRPCServerRequestPart<Echo_EchoRequest>,
        context: ServerInterceptorContext<Echo_EchoRequest, Echo_EchoResponse>
      ) {
        switch part {
        case .metadata:
          context.send(.metadata([:]), promise: nil)

        case let .message(request):
          if context.path.hasSuffix("Get") {
            // Unary, just reply.
            let response = Echo_EchoResponse.with {
              $0.text = "echo: \(request.text)"
            }
            context.send(.message(response, .init(compress: false, flush: false)), promise: nil)
          } else if context.path.hasSuffix("Expand") {
            // Server streaming: one response per space-separated word.
            let parts = request.text.split(separator: " ")
            let metadata = MessageMetadata(compress: false, flush: false)
            for part in parts {
              context.send(.message(.with { $0.text = "echo: \(part)" }, metadata), promise: nil)
            }
          } else if context.path.hasSuffix("Collect") {
            // Client streaming, store the requests, reply on '.end'
            self.collectedRequests.append(request)
          } else if context.path.hasSuffix("Update") {
            // Bidirectional streaming: echo each request as it arrives.
            let response = Echo_EchoResponse.with {
              $0.text = "echo: \(request.text)"
            }
            let metadata = MessageMetadata(compress: false, flush: true)
            context.send(.message(response, metadata), promise: nil)
          } else {
            XCTFail("Unexpected path '\(context.path)'")
          }

        case .end:
          if context.path.hasSuffix("Collect") {
            // A client-streaming RPC must produce exactly one response message. Reply even when
            // the client sent no requests; the aggregated text is then just "echo: ".
            // Previously an empty stream ended with status OK and no response at all.
            let joined = self.collectedRequests.map { $0.text }.joined(separator: " ")
            let response = Echo_EchoResponse.with {
              $0.text = "echo: " + joined
            }
            context.send(.message(response, .init(compress: false, flush: false)), promise: nil)
          }
          context.send(.end(.ok, [:]), promise: nil)
        }
      }
    }
  }
  // Avoid having to serialize/deserialize messages in test cases: typed request parts are turned
  // into bytes on the way in, and byte response parts are decoded back into typed parts on the
  // way out.
  private class Codec: ChannelDuplexHandler {
    typealias InboundIn = GRPCServerRequestPart<Echo_EchoRequest>
    typealias InboundOut = GRPCServerRequestPart<ByteBuffer>
    typealias OutboundIn = GRPCServerResponsePart<ByteBuffer>
    typealias OutboundOut = GRPCServerResponsePart<Echo_EchoResponse>

    private let serializer = ProtobufSerializer<Echo_EchoRequest>()
    private let deserializer = ProtobufDeserializer<Echo_EchoResponse>()

    /// Serializes message parts; metadata and end pass through unchanged. Serialization
    /// failures trap.
    func channelRead(context: ChannelHandlerContext, data: NIOAny) {
      let encodedPart: InboundOut
      switch self.unwrapInboundIn(data) {
      case let .metadata(headers):
        encodedPart = .metadata(headers)
      case let .message(request):
        let bytes = try! self.serializer.serialize(request, allocator: context.channel.allocator)
        encodedPart = .message(bytes)
      case .end:
        encodedPart = .end
      }
      context.fireChannelRead(self.wrapInboundOut(encodedPart))
    }

    /// Deserializes message parts; metadata and end pass through unchanged. Deserialization
    /// failures trap.
    func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise<Void>?) {
      let decodedPart: OutboundOut
      switch self.unwrapOutboundIn(data) {
      case let .metadata(headers):
        decodedPart = .metadata(headers)
      case let .message(bytes, metadata):
        let response = try! self.deserializer.deserialize(byteBuffer: bytes)
        decodedPart = .message(response, metadata)
      case let .end(status, trailers):
        decodedPart = .end(status, trailers)
      }
      context.write(self.wrapOutboundOut(decodedPart), promise: promise)
    }
  }