ServerInterceptorTests.swift 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385
  1. /*
  2. * Copyright 2020, gRPC Authors All rights reserved.
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. import EchoImplementation
  17. import EchoModel
  18. @testable import GRPC
  19. import HelloWorldModel
  20. import NIOCore
  21. import NIOEmbedded
  22. import NIOHTTP1
  23. import SwiftProtobuf
  24. import XCTest
  extension GRPCServerHandlerProtocol {
    /// Serializes `request` as protobuf and delivers the resulting bytes to the handler as an
    /// inbound request message.
    ///
    /// A serialization failure is reported as a test failure (via `XCTFail`) and nothing is
    /// delivered to the handler in that case.
    fileprivate func receiveRequest(_ request: Echo_EchoRequest) {
      let bytes: ByteBuffer
      do {
        bytes = try ProtobufSerializer<Echo_EchoRequest>().serialize(
          request,
          allocator: ByteBufferAllocator()
        )
      } catch {
        XCTFail("Unexpected error: \(error)")
        return
      }
      self.receiveMessage(bytes)
    }
  }
  /// Drives server call handlers directly on an `EmbeddedEventLoop` and checks, via a
  /// `ResponseRecorder`, what the interceptors in front of each handler write back.
  class ServerInterceptorTests: GRPCTestCase {
    private let eventLoop = EmbeddedEventLoop()
    private var recorder: ResponseRecorder!

    override func setUp() {
      super.setUp()
      // setUp runs before every test method, so each test gets its own recorder.
      self.recorder = ResponseRecorder(eventLoop: self.eventLoop)
    }

    /// Makes a `RecordingServerInterceptor`; tests read back its `requestParts` and
    /// `responseParts`.
    private func makeRecordingInterceptor()
      -> RecordingServerInterceptor<Echo_EchoRequest, Echo_EchoResponse> {
      return RecordingServerInterceptor()
    }

    /// Builds an `EchoProvider` whose interceptor factory hands out `interceptor` for every method.
    private func echoProvider(
      interceptedBy interceptor: ServerInterceptor<Echo_EchoRequest, Echo_EchoResponse>
    ) -> EchoProvider {
      let factory = EchoInterceptorFactory(interceptor: interceptor)
      return EchoProvider(interceptors: factory)
    }

    /// Creates a handler context for `path` that writes responses into `self.recorder`.
    private func makeHandlerContext(for path: String) -> CallHandlerContext {
      let closeFuture = self.eventLoop.makeSucceededVoidFuture()
      return CallHandlerContext(
        errorDelegate: nil,
        logger: self.serverLogger,
        encoding: .disabled,
        eventLoop: self.eventLoop,
        path: path,
        responseWriter: self.recorder,
        allocator: ByteBufferAllocator(),
        closeFuture: closeFuture
      )
    }

    // Identity function: it exists only to help the compiler infer request part types.
    private func request(
      _ request: GRPCServerRequestPart<Echo_EchoRequest>
    ) -> GRPCServerRequestPart<Echo_EchoRequest> {
      return request
    }

    /// Asks `provider` for a handler for `method`, using a context whose path is
    /// "/<service>/<method>".
    private func handleMethod(
      _ method: Substring,
      using provider: CallHandlerProvider
    ) -> GRPCServerHandlerProtocol? {
      let path = "/\(provider.serviceName)/\(method)"
      return provider.handle(method: method, context: self.makeHandlerContext(for: path))
    }

    fileprivate typealias ResponsePart = GRPCServerResponsePart<Echo_EchoResponse>

    /// Feeds `handler` empty request metadata, one request message per element of `texts`
    /// (in order), and finally end-of-stream.
    private func sendRequests(withTexts texts: [String], to handler: GRPCServerHandlerProtocol) {
      handler.receiveMetadata([:])
      texts.forEach { text in
        handler.receiveRequest(.with { $0.text = text })
      }
      handler.receiveEnd()
    }

    func testPassThroughInterceptor() throws {
      let interceptor = self.makeRecordingInterceptor()
      let service = self.echoProvider(interceptedBy: interceptor)
      let callHandler = try assertNotNil(self.handleMethod("Get", using: service))

      self.sendRequests(withTexts: [""], to: callHandler)

      // Metadata, one message and a status must have been written back.
      assertThat(self.recorder.metadata, .is(.notNil()))
      assertThat(self.recorder.messages.count, .is(1))
      assertThat(self.recorder.status, .is(.notNil()))

      // Only two request parts are seen: the provider has already responded before the end
      // part arrives, which is acceptable.
      assertThat(interceptor.requestParts, .hasCount(2))
      assertThat(interceptor.requestParts[0], .is(.metadata()))
      assertThat(interceptor.requestParts[1], .is(.message()))

      assertThat(interceptor.responseParts, .hasCount(3))
      assertThat(interceptor.responseParts[0], .is(.metadata()))
      assertThat(interceptor.responseParts[1], .is(.message()))
      assertThat(interceptor.responseParts[2], .is(.end(status: .is(.ok))))
    }

    func testUnaryFromInterceptor() throws {
      let service = EchoFromInterceptor()
      let callHandler = try assertNotNil(self.handleMethod("Get", using: service))

      self.sendRequests(withTexts: ["foo"], to: callHandler)

      assertThat(self.recorder.metadata, .is(.notNil()))
      assertThat(self.recorder.messages.count, .is(1))
      assertThat(self.recorder.status, .is(.notNil()))
    }

    func testClientStreamingFromInterceptor() throws {
      let service = EchoFromInterceptor()
      let callHandler = try assertNotNil(self.handleMethod("Collect", using: service))

      self.sendRequests(withTexts: ["a", "b", "c"], to: callHandler)

      assertThat(self.recorder.metadata, .is(.notNil()))
      assertThat(self.recorder.messages.count, .is(1))
      assertThat(self.recorder.status, .is(.notNil()))
    }

    func testServerStreamingFromInterceptor() throws {
      let service = EchoFromInterceptor()
      let callHandler = try assertNotNil(self.handleMethod("Expand", using: service))

      self.sendRequests(withTexts: ["a b c"], to: callHandler)

      assertThat(self.recorder.metadata, .is(.notNil()))
      assertThat(self.recorder.messages.count, .is(3))
      assertThat(self.recorder.status, .is(.notNil()))
    }

    func testBidirectionalStreamingFromInterceptor() throws {
      let service = EchoFromInterceptor()
      let callHandler = try assertNotNil(self.handleMethod("Update", using: service))

      self.sendRequests(withTexts: ["a", "b", "c"], to: callHandler)

      assertThat(self.recorder.metadata, .is(.notNil()))
      assertThat(self.recorder.messages.count, .is(3))
      assertThat(self.recorder.status, .is(.notNil()))
    }
  }
  /// An Echo interceptor factory that returns the same single interceptor instance for every
  /// method of the service.
  final class EchoInterceptorFactory: Echo_EchoServerInterceptorFactoryProtocol {
    private let sharedInterceptor: ServerInterceptor<Echo_EchoRequest, Echo_EchoResponse>

    init(interceptor: ServerInterceptor<Echo_EchoRequest, Echo_EchoResponse>) {
      self.sharedInterceptor = interceptor
    }

    /// The chain handed out for every method: a one-element array holding the shared interceptor.
    private var chain: [ServerInterceptor<Echo_EchoRequest, Echo_EchoResponse>] {
      return [self.sharedInterceptor]
    }

    func makeGetInterceptors() -> [ServerInterceptor<Echo_EchoRequest, Echo_EchoResponse>] {
      return self.chain
    }

    func makeExpandInterceptors() -> [ServerInterceptor<Echo_EchoRequest, Echo_EchoResponse>] {
      return self.chain
    }

    func makeCollectInterceptors() -> [ServerInterceptor<Echo_EchoRequest, Echo_EchoResponse>] {
      return self.chain
    }

    func makeUpdateInterceptors() -> [ServerInterceptor<Echo_EchoRequest, Echo_EchoResponse>] {
      return self.chain
    }
  }
  /// Forwards every inbound request part via `context.receive`.
  ///
  /// Parts of the kind chosen at init are forwarded `count` times; all other parts are
  /// forwarded exactly once.
  class ExtraRequestPartEmitter: ServerInterceptor<Echo_EchoRequest, Echo_EchoResponse> {
    /// The kind of request part to duplicate.
    enum Part {
      case metadata
      case message
      case end
    }

    private let repeatedPart: Part
    private let repetitions: Int

    init(repeat part: Part, times count: Int) {
      self.repeatedPart = part
      self.repetitions = count
    }

    override func receive(
      _ part: GRPCServerRequestPart<Echo_EchoRequest>,
      context: ServerInterceptorContext<Echo_EchoRequest, Echo_EchoResponse>
    ) {
      let times = self.isRepeated(part) ? self.repetitions : 1
      for _ in 0 ..< times {
        context.receive(part)
      }
    }

    /// Returns `true` when `part` is of the kind this emitter was configured to duplicate.
    private func isRepeated(_ part: GRPCServerRequestPart<Echo_EchoRequest>) -> Bool {
      switch part {
      case .metadata:
        return self.repeatedPart == .metadata
      case .message:
        return self.repeatedPart == .message
      case .end:
        return self.repeatedPart == .end
      }
    }
  }
  /// An `Echo_EchoProvider` whose RPCs are answered entirely by its server interceptor.
  ///
  /// The provider's own methods are never expected to run; each one fails the test if it is
  /// reached.
  class EchoFromInterceptor: Echo_EchoProvider {
    var interceptors: Echo_EchoServerInterceptorFactoryProtocol? = Interceptors()

    func get(
      request: Echo_EchoRequest,
      context: StatusOnlyCallContext
    ) -> EventLoopFuture<Echo_EchoResponse> {
      XCTFail("Unexpected call to \(#function)")
      return context.eventLoop.makeFailedFuture(GRPCStatus.processingError)
    }

    func expand(
      request: Echo_EchoRequest,
      context: StreamingResponseCallContext<Echo_EchoResponse>
    ) -> EventLoopFuture<GRPCStatus> {
      XCTFail("Unexpected call to \(#function)")
      return context.eventLoop.makeFailedFuture(GRPCStatus.processingError)
    }

    func collect(
      context: UnaryResponseCallContext<Echo_EchoResponse>
    ) -> EventLoopFuture<(StreamEvent<Echo_EchoRequest>) -> Void> {
      XCTFail("Unexpected call to \(#function)")
      return context.eventLoop.makeFailedFuture(GRPCStatus.processingError)
    }

    func update(
      context: StreamingResponseCallContext<Echo_EchoResponse>
    ) -> EventLoopFuture<(StreamEvent<Echo_EchoRequest>) -> Void> {
      XCTFail("Unexpected call to \(#function)")
      return context.eventLoop.makeFailedFuture(GRPCStatus.processingError)
    }

    /// Returns a newly created `Interceptor` every time a method's interceptors are requested,
    /// so the interceptor's mutable state is not shared between those requests.
    final class Interceptors: Echo_EchoServerInterceptorFactoryProtocol {
      func makeGetInterceptors() -> [ServerInterceptor<Echo_EchoRequest, Echo_EchoResponse>] {
        return [Interceptor()]
      }

      func makeExpandInterceptors() -> [ServerInterceptor<Echo_EchoRequest, Echo_EchoResponse>] {
        return [Interceptor()]
      }

      func makeCollectInterceptors() -> [ServerInterceptor<Echo_EchoRequest, Echo_EchoResponse>] {
        return [Interceptor()]
      }

      func makeUpdateInterceptors() -> [ServerInterceptor<Echo_EchoRequest, Echo_EchoResponse>] {
        return [Interceptor()]
      }
    }

    // Since all methods use the same request/response types, a single interceptor can respond
    // to all of them. It picks the behaviour from the suffix of the call's path.
    class Interceptor: ServerInterceptor<Echo_EchoRequest, Echo_EchoResponse> {
      /// Requests received on a client-streaming ("Collect") call; they are answered on `.end`.
      private var collectedRequests: [Echo_EchoRequest] = []

      override func receive(
        _ part: GRPCServerRequestPart<Echo_EchoRequest>,
        context: ServerInterceptorContext<Echo_EchoRequest, Echo_EchoResponse>
      ) {
        switch part {
        case .metadata:
          context.send(.metadata([:]), promise: nil)

        case let .message(request):
          if context.path.hasSuffix("Get") {
            // Unary, just reply.
            let response = Echo_EchoResponse.with {
              $0.text = "echo: \(request.text)"
            }
            context.send(.message(response, .init(compress: false, flush: false)), promise: nil)
          } else if context.path.hasSuffix("Expand") {
            // Server streaming: one response per space-separated word.
            let parts = request.text.split(separator: " ")
            let metadata = MessageMetadata(compress: false, flush: false)
            for part in parts {
              context.send(.message(.with { $0.text = "echo: \(part)" }, metadata), promise: nil)
            }
          } else if context.path.hasSuffix("Collect") {
            // Client streaming: store the requests and reply on '.end'.
            self.collectedRequests.append(request)
          } else if context.path.hasSuffix("Update") {
            // Bidirectional streaming: echo each request as it arrives.
            let response = Echo_EchoResponse.with {
              $0.text = "echo: \(request.text)"
            }
            let metadata = MessageMetadata(compress: false, flush: true)
            context.send(.message(response, metadata), promise: nil)
          } else {
            XCTFail("Unexpected path '\(context.path)'")
          }

        case .end:
          // A client-streaming call must produce exactly one response message, even when the
          // client sent no requests. So decide by method, not by whether anything was collected;
          // otherwise an empty "Collect" would end with .ok and no message.
          if context.path.hasSuffix("Collect") {
            let response = Echo_EchoResponse.with {
              $0.text = "echo: " + self.collectedRequests.map { $0.text }.joined(separator: " ")
            }
            context.send(.message(response, .init(compress: false, flush: false)), promise: nil)
          }
          context.send(.end(.ok, [:]), promise: nil)
        }
      }
    }
  }
  // Avoid having to serialize/deserialize messages in test cases.
  /// Converts between typed Echo parts and their `ByteBuffer` forms.
  ///
  /// - Inbound request messages are serialized to protobuf bytes.
  /// - Outbound response messages are deserialized from protobuf bytes.
  /// - Metadata and end parts pass through unchanged.
  ///
  /// (De)serialization failures are reported through the pipeline (`fireErrorCaught` /
  /// failing the write promise) instead of trapping the test process.
  private final class Codec: ChannelDuplexHandler {
    typealias InboundIn = GRPCServerRequestPart<Echo_EchoRequest>
    typealias InboundOut = GRPCServerRequestPart<ByteBuffer>
    typealias OutboundIn = GRPCServerResponsePart<ByteBuffer>
    typealias OutboundOut = GRPCServerResponsePart<Echo_EchoResponse>

    private let serializer = ProtobufSerializer<Echo_EchoRequest>()
    private let deserializer = ProtobufDeserializer<Echo_EchoResponse>()

    func channelRead(context: ChannelHandlerContext, data: NIOAny) {
      switch self.unwrapInboundIn(data) {
      case let .metadata(headers):
        context.fireChannelRead(self.wrapInboundOut(.metadata(headers)))

      case let .message(message):
        let serialized: ByteBuffer
        do {
          serialized = try self.serializer.serialize(message, allocator: context.channel.allocator)
        } catch {
          // Surface the failure to later handlers rather than crashing the whole test run.
          context.fireErrorCaught(error)
          return
        }
        context.fireChannelRead(self.wrapInboundOut(.message(serialized)))

      case .end:
        context.fireChannelRead(self.wrapInboundOut(.end))
      }
    }

    func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise<Void>?) {
      switch self.unwrapOutboundIn(data) {
      case let .metadata(headers):
        context.write(self.wrapOutboundOut(.metadata(headers)), promise: promise)

      case let .message(message, metadata):
        let deserialized: Echo_EchoResponse
        do {
          deserialized = try self.deserializer.deserialize(byteBuffer: message)
        } catch {
          // Fail the write so the writer observes the error instead of the process trapping.
          promise?.fail(error)
          return
        }
        context.write(self.wrapOutboundOut(.message(deserialized, metadata)), promise: promise)

      case let .end(status, trailers):
        context.write(self.wrapOutboundOut(.end(status, trailers)), promise: promise)
      }
    }
  }