// ServerInterceptorTests.swift
  1. /*
  2. * Copyright 2020, gRPC Authors All rights reserved.
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. import EchoImplementation
  17. import EchoModel
  18. @testable import GRPC
  19. import HelloWorldModel
  20. import NIOCore
  21. import NIOEmbedded
  22. import NIOHTTP1
  23. import SwiftProtobuf
  24. import XCTest
  extension GRPCServerHandlerProtocol {
    /// Serializes `request` with a protobuf serializer and hands the resulting bytes to
    /// `receiveMessage(_:)`.
    ///
    /// A serialization failure is reported as a test failure rather than thrown, so call sites
    /// in tests do not need `try`.
    ///
    /// - Parameter request: The Echo request to deliver to this handler.
    fileprivate func receiveRequest(_ request: Echo_EchoRequest) {
      let encoded = Result {
        try ProtobufSerializer<Echo_EchoRequest>().serialize(
          request,
          allocator: ByteBufferAllocator()
        )
      }

      switch encoded {
      case let .success(bytes):
        self.receiveMessage(bytes)
      case let .failure(error):
        XCTFail("Unexpected error: \(error)")
      }
    }
  }
  /// Tests which drive Echo call handlers directly (using an embedded event loop) and inspect
  /// what the interceptors saw and what was written to the `ResponseRecorder`.
  class ServerInterceptorTests: GRPCTestCase {
    fileprivate typealias ResponsePart = GRPCServerResponsePart<Echo_EchoResponse>

    private let eventLoop = EmbeddedEventLoop()
    private let recorder = ResponseRecorder()

    // MARK: - Helpers

    /// Creates a recording interceptor for the Echo request/response types. The tests read its
    /// `requestParts` and `responseParts` afterwards.
    private func makeRecordingInterceptor()
      -> RecordingServerInterceptor<Echo_EchoRequest, Echo_EchoResponse> {
      return RecordingServerInterceptor<Echo_EchoRequest, Echo_EchoResponse>()
    }

    /// Creates an `EchoProvider` whose interceptor factory returns `interceptor` for every RPC.
    private func echoProvider(
      interceptedBy interceptor: ServerInterceptor<Echo_EchoRequest, Echo_EchoResponse>
    ) -> EchoProvider {
      let factory = EchoInterceptorFactory(interceptor: interceptor)
      return EchoProvider(interceptors: factory)
    }

    /// Builds a handler context for `path` which writes responses to `self.recorder` and runs on
    /// `self.eventLoop`, with compression disabled and no error delegate.
    private func makeHandlerContext(for path: String) -> CallHandlerContext {
      let loop = self.eventLoop
      return CallHandlerContext(
        errorDelegate: nil,
        logger: self.serverLogger,
        encoding: .disabled,
        eventLoop: loop,
        path: path,
        responseWriter: self.recorder,
        allocator: ByteBufferAllocator(),
        closeFuture: loop.makeSucceededVoidFuture()
      )
    }

    // Identity function; it exists purely to help the compiler infer the request part type.
    private func request(
      _ request: GRPCServerRequestPart<Echo_EchoRequest>
    ) -> GRPCServerRequestPart<Echo_EchoRequest> {
      return request
    }

    /// Asks `provider` for a handler for `method`, using the path "/<serviceName>/<method>".
    ///
    /// - Returns: Whatever `provider.handle(method:context:)` returns, which may be `nil`.
    private func handleMethod(
      _ method: Substring,
      using provider: CallHandlerProvider
    ) -> GRPCServerHandlerProtocol? {
      let context = self.makeHandlerContext(for: "/\(provider.serviceName)/\(method)")
      return provider.handle(method: method, context: context)
    }

    /// Feeds a complete request stream into `handler`: empty metadata, one request message per
    /// entry in `texts` (in order), then end-of-stream.
    private func sendRequests(_ texts: [String], to handler: GRPCServerHandlerProtocol) {
      handler.receiveMetadata([:])
      texts.forEach { text in
        handler.receiveRequest(.with { $0.text = text })
      }
      handler.receiveEnd()
    }

    // MARK: - Tests

    func testPassThroughInterceptor() throws {
      let recording = self.makeRecordingInterceptor()
      let echo = self.echoProvider(interceptedBy: recording)
      let handler = try assertNotNil(self.handleMethod("Get", using: echo))

      // Send a single request with empty text.
      self.sendRequests([""], to: handler)

      // The recorder should have seen a full response stream.
      assertThat(self.recorder.metadata, .is(.notNil()))
      assertThat(self.recorder.messages.count, .is(1))
      assertThat(self.recorder.status, .is(.notNil()))

      // Only 2 request parts are expected: the provider responds before it sees end, and that
      // is fine.
      let seenRequests = recording.requestParts
      assertThat(seenRequests, .hasCount(2))
      assertThat(seenRequests[0], .is(.metadata()))
      assertThat(seenRequests[1], .is(.message()))

      let seenResponses = recording.responseParts
      assertThat(seenResponses, .hasCount(3))
      assertThat(seenResponses[0], .is(.metadata()))
      assertThat(seenResponses[1], .is(.message()))
      assertThat(seenResponses[2], .is(.end(status: .is(.ok))))
    }

    func testUnaryFromInterceptor() throws {
      let handler = try assertNotNil(self.handleMethod("Get", using: EchoFromInterceptor()))

      // One request in ...
      self.sendRequests(["foo"], to: handler)

      // ... one response out.
      assertThat(self.recorder.metadata, .is(.notNil()))
      assertThat(self.recorder.messages.count, .is(1))
      assertThat(self.recorder.status, .is(.notNil()))
    }

    func testClientStreamingFromInterceptor() throws {
      let handler = try assertNotNil(self.handleMethod("Collect", using: EchoFromInterceptor()))

      // Three requests in ...
      self.sendRequests(["a", "b", "c"], to: handler)

      // ... a single response out.
      assertThat(self.recorder.metadata, .is(.notNil()))
      assertThat(self.recorder.messages.count, .is(1))
      assertThat(self.recorder.status, .is(.notNil()))
    }

    func testServerStreamingFromInterceptor() throws {
      let handler = try assertNotNil(self.handleMethod("Expand", using: EchoFromInterceptor()))

      // One request containing three space separated parts in ...
      self.sendRequests(["a b c"], to: handler)

      // ... three responses out.
      assertThat(self.recorder.metadata, .is(.notNil()))
      assertThat(self.recorder.messages.count, .is(3))
      assertThat(self.recorder.status, .is(.notNil()))
    }

    func testBidirectionalStreamingFromInterceptor() throws {
      let handler = try assertNotNil(self.handleMethod("Update", using: EchoFromInterceptor()))

      // Three requests in ...
      self.sendRequests(["a", "b", "c"], to: handler)

      // ... three responses out.
      assertThat(self.recorder.metadata, .is(.notNil()))
      assertThat(self.recorder.messages.count, .is(3))
      assertThat(self.recorder.status, .is(.notNil()))
    }
  }
  /// An Echo interceptor factory which returns the same, single interceptor instance for every
  /// RPC (Get, Expand, Collect and Update).
  class EchoInterceptorFactory: Echo_EchoServerInterceptorFactoryProtocol {
    private let interceptor: ServerInterceptor<Echo_EchoRequest, Echo_EchoResponse>

    /// The one-element interceptor list shared by all of the factory methods.
    private var interceptors: [ServerInterceptor<Echo_EchoRequest, Echo_EchoResponse>] {
      return [self.interceptor]
    }

    /// - Parameter interceptor: The interceptor to return from every `make…Interceptors()` call.
    init(interceptor: ServerInterceptor<Echo_EchoRequest, Echo_EchoResponse>) {
      self.interceptor = interceptor
    }

    func makeGetInterceptors() -> [ServerInterceptor<Echo_EchoRequest, Echo_EchoResponse>] {
      return self.interceptors
    }

    func makeExpandInterceptors() -> [ServerInterceptor<Echo_EchoRequest, Echo_EchoResponse>] {
      return self.interceptors
    }

    func makeCollectInterceptors() -> [ServerInterceptor<Echo_EchoRequest, Echo_EchoResponse>] {
      return self.interceptors
    }

    func makeUpdateInterceptors() -> [ServerInterceptor<Echo_EchoRequest, Echo_EchoResponse>] {
      return self.interceptors
    }
  }
  /// An interceptor which forwards one kind of request part `count` times and every other kind
  /// of request part exactly once.
  class ExtraRequestPartEmitter: ServerInterceptor<Echo_EchoRequest, Echo_EchoResponse> {
    /// The kind of request part to repeat.
    enum Part {
      case metadata
      case message
      case end
    }

    private let part: Part
    private let count: Int

    /// - Parameters:
    ///   - part: The kind of request part to repeat.
    ///   - count: How many times a matching part is forwarded.
    init(repeat part: Part, times count: Int) {
      self.part = part
      self.count = count
    }

    override func receive(
      _ part: GRPCServerRequestPart<Echo_EchoRequest>,
      context: ServerInterceptorContext<Echo_EchoRequest, Echo_EchoResponse>
    ) {
      // Does the incoming part have the kind this emitter was configured to repeat?
      let isRepeatedKind: Bool
      switch (self.part, part) {
      case (.metadata, .metadata), (.message, .message), (.end, .end):
        isRepeatedKind = true
      default:
        isRepeatedKind = false
      }

      let repetitions = isRepeatedKind ? self.count : 1
      (0 ..< repetitions).forEach { _ in
        context.receive(part)
      }
    }
  }
  /// An Echo provider whose RPCs are answered entirely by its interceptors.
  ///
  /// Every service method records a test failure and returns a failed future if it is reached:
  /// the nested `Interceptor` is expected to respond before a request gets that far.
  class EchoFromInterceptor: Echo_EchoProvider {
    var interceptors: Echo_EchoServerInterceptorFactoryProtocol? = Interceptors()

    func get(
      request: Echo_EchoRequest,
      context: StatusOnlyCallContext
    ) -> EventLoopFuture<Echo_EchoResponse> {
      XCTFail("Unexpected call to \(#function)")
      return context.eventLoop.makeFailedFuture(GRPCStatus.processingError)
    }

    func expand(
      request: Echo_EchoRequest,
      context: StreamingResponseCallContext<Echo_EchoResponse>
    ) -> EventLoopFuture<GRPCStatus> {
      XCTFail("Unexpected call to \(#function)")
      return context.eventLoop.makeFailedFuture(GRPCStatus.processingError)
    }

    func collect(
      context: UnaryResponseCallContext<Echo_EchoResponse>
    ) -> EventLoopFuture<(StreamEvent<Echo_EchoRequest>) -> Void> {
      XCTFail("Unexpected call to \(#function)")
      return context.eventLoop.makeFailedFuture(GRPCStatus.processingError)
    }

    func update(
      context: StreamingResponseCallContext<Echo_EchoResponse>
    ) -> EventLoopFuture<(StreamEvent<Echo_EchoRequest>) -> Void> {
      XCTFail("Unexpected call to \(#function)")
      return context.eventLoop.makeFailedFuture(GRPCStatus.processingError)
    }

    /// Interceptor factory which creates a fresh `Interceptor` for each RPC.
    class Interceptors: Echo_EchoServerInterceptorFactoryProtocol {
      func makeGetInterceptors() -> [ServerInterceptor<Echo_EchoRequest, Echo_EchoResponse>] {
        return [Interceptor()]
      }

      func makeExpandInterceptors() -> [ServerInterceptor<Echo_EchoRequest, Echo_EchoResponse>] {
        return [Interceptor()]
      }

      func makeCollectInterceptors() -> [ServerInterceptor<Echo_EchoRequest, Echo_EchoResponse>] {
        return [Interceptor()]
      }

      func makeUpdateInterceptors() -> [ServerInterceptor<Echo_EchoRequest, Echo_EchoResponse>] {
        return [Interceptor()]
      }
    }

    // Since all methods use the same request/response types, we can use a single interceptor to
    // respond to all of them.
    class Interceptor: ServerInterceptor<Echo_EchoRequest, Echo_EchoResponse> {
      /// Requests received so far on a "Collect" call; echoed back as one response on `.end`.
      private var collectedRequests: [Echo_EchoRequest] = []

      /// Responds to each request part directly from the interceptor. The RPC being served is
      /// determined from the suffix of `context.path`.
      override func receive(
        _ part: GRPCServerRequestPart<Echo_EchoRequest>,
        context: ServerInterceptorContext<Echo_EchoRequest, Echo_EchoResponse>
      ) {
        switch part {
        case .metadata:
          context.send(.metadata([:]), promise: nil)

        case let .message(request):
          if context.path.hasSuffix("Get") {
            // Unary, just reply.
            let response = Echo_EchoResponse.with {
              $0.text = "echo: \(request.text)"
            }
            context.send(.message(response, .init(compress: false, flush: false)), promise: nil)
          } else if context.path.hasSuffix("Expand") {
            // Server streaming: one response per space separated part of the request text.
            let parts = request.text.split(separator: " ")
            let metadata = MessageMetadata(compress: false, flush: false)
            for part in parts {
              context.send(.message(.with { $0.text = "echo: \(part)" }, metadata), promise: nil)
            }
          } else if context.path.hasSuffix("Collect") {
            // Client streaming, store the requests, reply on '.end'
            self.collectedRequests.append(request)
          } else if context.path.hasSuffix("Update") {
            // Bidirectional streaming: one response per request.
            let response = Echo_EchoResponse.with {
              $0.text = "echo: \(request.text)"
            }
            let metadata = MessageMetadata(compress: false, flush: true)
            context.send(.message(response, metadata), promise: nil)
          } else {
            XCTFail("Unexpected path '\(context.path)'")
          }

        case .end:
          // Decide on the path rather than on whether anything was collected: a "Collect" call
          // whose client sent no messages must still get its single response before the
          // status, otherwise the call would end with `.ok` and no response message.
          if context.path.hasSuffix("Collect") {
            let collectedText = self.collectedRequests.map { $0.text }.joined(separator: " ")
            let response = Echo_EchoResponse.with {
              $0.text = "echo: " + collectedText
            }
            context.send(.message(response, .init(compress: false, flush: false)), promise: nil)
          }
          context.send(.end(.ok, [:]), promise: nil)
        }
      }
    }
  }
  // Avoid having to serialize/deserialize messages in test cases.
  //
  // Inbound: typed Echo request parts are serialized to `ByteBuffer` request parts.
  // Outbound: `ByteBuffer` response parts are deserialized to typed Echo response parts.
  private class Codec: ChannelDuplexHandler {
    typealias InboundIn = GRPCServerRequestPart<Echo_EchoRequest>
    typealias InboundOut = GRPCServerRequestPart<ByteBuffer>

    typealias OutboundIn = GRPCServerResponsePart<ByteBuffer>
    typealias OutboundOut = GRPCServerResponsePart<Echo_EchoResponse>

    private let serializer = ProtobufSerializer<Echo_EchoRequest>()
    private let deserializer = ProtobufDeserializer<Echo_EchoResponse>()

    /// Forwards request parts inbound, serializing message payloads on the way.
    ///
    /// If serialization fails the test is failed and the error is fired down the pipeline
    /// instead of crashing the test process.
    func channelRead(context: ChannelHandlerContext, data: NIOAny) {
      switch self.unwrapInboundIn(data) {
      case let .metadata(headers):
        context.fireChannelRead(self.wrapInboundOut(.metadata(headers)))

      case let .message(message):
        do {
          let serialized = try self.serializer.serialize(
            message,
            allocator: context.channel.allocator
          )
          context.fireChannelRead(self.wrapInboundOut(.message(serialized)))
        } catch {
          XCTFail("Unexpected error: \(error)")
          context.fireErrorCaught(error)
        }

      case .end:
        context.fireChannelRead(self.wrapInboundOut(.end))
      }
    }

    /// Forwards response parts outbound, deserializing message payloads on the way.
    ///
    /// If deserialization fails the test is failed and `promise` (if any) is failed with the
    /// error, so the writer is not left waiting on a write which never happened.
    func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise<Void>?) {
      switch self.unwrapOutboundIn(data) {
      case let .metadata(headers):
        context.write(self.wrapOutboundOut(.metadata(headers)), promise: promise)

      case let .message(message, metadata):
        do {
          let deserialized = try self.deserializer.deserialize(byteBuffer: message)
          context.write(self.wrapOutboundOut(.message(deserialized, metadata)), promise: promise)
        } catch {
          XCTFail("Unexpected error: \(error)")
          promise?.fail(error)
        }

      case let .end(status, trailers):
        context.write(self.wrapOutboundOut(.end(status, trailers)), promise: promise)
      }
    }
  }