ServerInterceptorTests.swift 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390
  1. /*
  2. * Copyright 2020, gRPC Authors All rights reserved.
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. import EchoImplementation
  17. import EchoModel
  18. import HelloWorldModel
  19. import NIOCore
  20. import NIOEmbedded
  21. import NIOHTTP1
  22. import SwiftProtobuf
  23. import XCTest
  24. @testable import GRPC
  extension GRPCServerHandlerProtocol {
    /// Serializes `request` as protobuf and delivers the bytes to this handler as an inbound
    /// message.
    ///
    /// If serialization throws, the current test is failed and nothing is delivered.
    fileprivate func receiveRequest(_ request: Echo_EchoRequest) {
      let serializer = ProtobufSerializer<Echo_EchoRequest>()
      let outcome = Result {
        try serializer.serialize(request, allocator: ByteBufferAllocator())
      }

      switch outcome {
      case let .success(bytes):
        self.receiveMessage(bytes)
      case let .failure(error):
        XCTFail("Unexpected error: \(error)")
      }
    }
  }
  /// Server interceptor tests which drive the Echo service's call handlers directly on an
  /// `EmbeddedEventLoop` and inspect what they write back through a `ResponseRecorder`.
  class ServerInterceptorTests: GRPCTestCase {
    fileprivate typealias ResponsePart = GRPCServerResponsePart<Echo_EchoResponse>

    private let eventLoop = EmbeddedEventLoop()
    private var recorder: ResponseRecorder!

    override func setUp() {
      super.setUp()
      // Recreated before every test, bound to this test's event loop.
      self.recorder = ResponseRecorder(eventLoop: self.eventLoop)
    }

    // MARK: - Helpers

    private func makeRecordingInterceptor()
      -> RecordingServerInterceptor<Echo_EchoRequest, Echo_EchoResponse>
    {
      .init()
    }

    /// An `EchoProvider` whose methods are all intercepted by `interceptor`.
    private func echoProvider(
      interceptedBy interceptor: ServerInterceptor<Echo_EchoRequest, Echo_EchoResponse>
    ) -> EchoProvider {
      let factory = EchoInterceptorFactory(interceptor: interceptor)
      return EchoProvider(interceptors: factory)
    }

    /// A call handler context for `path` which writes its responses into `recorder`.
    private func makeHandlerContext(for path: String) -> CallHandlerContext {
      CallHandlerContext(
        errorDelegate: nil,
        logger: self.serverLogger,
        encoding: .disabled,
        eventLoop: self.eventLoop,
        path: path,
        responseWriter: self.recorder,
        allocator: ByteBufferAllocator(),
        closeFuture: self.eventLoop.makeSucceededVoidFuture()
      )
    }

    // This is only useful for the type inference.
    private func request(
      _ request: GRPCServerRequestPart<Echo_EchoRequest>
    ) -> GRPCServerRequestPart<Echo_EchoRequest> {
      request
    }

    /// Asks `provider` for a handler for `method`, using the path "/<service>/<method>".
    private func handleMethod(
      _ method: Substring,
      using provider: CallHandlerProvider
    ) -> GRPCServerHandlerProtocol? {
      let context = self.makeHandlerContext(for: "/\(provider.serviceName)/\(method)")
      return provider.handle(method: method, context: context)
    }

    /// Sends a complete request stream to `handler`: empty metadata, one request per element of
    /// `texts` (in order), then end-of-stream.
    private func sendRequests(withTexts texts: [String], to handler: GRPCServerHandlerProtocol) {
      handler.receiveMetadata([:])
      texts.forEach { text in
        handler.receiveRequest(.with { $0.text = text })
      }
      handler.receiveEnd()
    }

    // MARK: - Tests

    func testPassThroughInterceptor() throws {
      let recording = self.makeRecordingInterceptor()
      let echo = self.echoProvider(interceptedBy: recording)
      let handler = try assertNotNil(self.handleMethod("Get", using: echo))

      self.sendRequests(withTexts: [""], to: handler)

      // What the handler wrote back.
      assertThat(self.recorder.metadata, .is(.some()))
      assertThat(self.recorder.messages.count, .is(1))
      assertThat(self.recorder.status, .is(.some()))

      // We expect 2 request parts: the provider responds before it sees end, that's fine.
      assertThat(recording.requestParts, .hasCount(2))
      assertThat(recording.requestParts[0], .is(.metadata()))
      assertThat(recording.requestParts[1], .is(.message()))

      // The interceptor saw every response part: metadata, one message and an OK status.
      assertThat(recording.responseParts, .hasCount(3))
      assertThat(recording.responseParts[0], .is(.metadata()))
      assertThat(recording.responseParts[1], .is(.message()))
      assertThat(recording.responseParts[2], .is(.end(status: .is(.ok))))
    }

    func testUnaryFromInterceptor() throws {
      let provider = EchoFromInterceptor()
      let handler = try assertNotNil(self.handleMethod("Get", using: provider))

      self.sendRequests(withTexts: ["foo"], to: handler)

      assertThat(self.recorder.metadata, .is(.some()))
      assertThat(self.recorder.messages.count, .is(1))
      assertThat(self.recorder.status, .is(.some()))
    }

    func testClientStreamingFromInterceptor() throws {
      let provider = EchoFromInterceptor()
      let handler = try assertNotNil(self.handleMethod("Collect", using: provider))

      self.sendRequests(withTexts: ["a", "b", "c"], to: handler)

      assertThat(self.recorder.metadata, .is(.some()))
      assertThat(self.recorder.messages.count, .is(1))
      assertThat(self.recorder.status, .is(.some()))
    }

    func testServerStreamingFromInterceptor() throws {
      let provider = EchoFromInterceptor()
      let handler = try assertNotNil(self.handleMethod("Expand", using: provider))

      self.sendRequests(withTexts: ["a b c"], to: handler)

      assertThat(self.recorder.metadata, .is(.some()))
      assertThat(self.recorder.messages.count, .is(3))
      assertThat(self.recorder.status, .is(.some()))
    }

    func testBidirectionalStreamingFromInterceptor() throws {
      let provider = EchoFromInterceptor()
      let handler = try assertNotNil(self.handleMethod("Update", using: provider))

      self.sendRequests(withTexts: ["a", "b", "c"], to: handler)

      assertThat(self.recorder.metadata, .is(.some()))
      assertThat(self.recorder.messages.count, .is(3))
      assertThat(self.recorder.status, .is(.some()))
    }
  }
  /// Interceptor factory which installs the same single interceptor instance on every Echo method.
  final class EchoInterceptorFactory: Echo_EchoServerInterceptorFactoryProtocol {
    private let interceptor: ServerInterceptor<Echo_EchoRequest, Echo_EchoResponse>

    init(interceptor: ServerInterceptor<Echo_EchoRequest, Echo_EchoResponse>) {
      self.interceptor = interceptor
    }

    /// The chain handed out for every method: just the shared interceptor.
    private var chain: [ServerInterceptor<Echo_EchoRequest, Echo_EchoResponse>] {
      [self.interceptor]
    }

    func makeGetInterceptors() -> [ServerInterceptor<Echo_EchoRequest, Echo_EchoResponse>] {
      self.chain
    }

    func makeExpandInterceptors() -> [ServerInterceptor<Echo_EchoRequest, Echo_EchoResponse>] {
      self.chain
    }

    func makeCollectInterceptors() -> [ServerInterceptor<Echo_EchoRequest, Echo_EchoResponse>] {
      self.chain
    }

    func makeUpdateInterceptors() -> [ServerInterceptor<Echo_EchoRequest, Echo_EchoResponse>] {
      self.chain
    }
  }
  /// A pass-through server interceptor which forwards one chosen kind of request part `count`
  /// times. Every other request part is forwarded exactly once.
  class ExtraRequestPartEmitter:
    ServerInterceptor<Echo_EchoRequest, Echo_EchoResponse>,
    @unchecked Sendable
  {
    /// The kind of request part to repeat.
    enum Part {
      case metadata
      case message
      case end
    }

    private let part: Part
    private let count: Int

    /// - Parameters:
    ///   - part: The kind of request part to repeat.
    ///   - count: How many times each matching part is forwarded. `0` drops matching parts;
    ///     a negative value traps when a matching part arrives.
    init(repeat part: Part, times count: Int) {
      self.part = part
      self.count = count
    }

    override func receive(
      _ part: GRPCServerRequestPart<Echo_EchoRequest>,
      context: ServerInterceptorContext<Echo_EchoRequest, Echo_EchoResponse>
    ) {
      let repetitions = self.isRepeated(part) ? self.count : 1
      (0 ..< repetitions).forEach { _ in
        context.receive(part)
      }
    }

    /// Whether `part` is the kind of request part this emitter was configured to repeat.
    private func isRepeated(_ part: GRPCServerRequestPart<Echo_EchoRequest>) -> Bool {
      switch (self.part, part) {
      case (.metadata, .metadata), (.message, .message), (.end, .end):
        return true
      default:
        return false
      }
    }
  }
  /// An Echo provider whose RPCs are answered entirely by a server interceptor.
  ///
  /// The interceptors returned by `Interceptors` reply on their own and never forward request
  /// parts on to the provider, so reaching any of the provider methods below fails the test.
  class EchoFromInterceptor: Echo_EchoProvider {
    var interceptors: Echo_EchoServerInterceptorFactoryProtocol? = Interceptors()

    func get(
      request: Echo_EchoRequest,
      context: StatusOnlyCallContext
    ) -> EventLoopFuture<Echo_EchoResponse> {
      XCTFail("Unexpected call to \(#function)")
      return context.eventLoop.makeFailedFuture(GRPCStatus.processingError)
    }

    func expand(
      request: Echo_EchoRequest,
      context: StreamingResponseCallContext<Echo_EchoResponse>
    ) -> EventLoopFuture<GRPCStatus> {
      XCTFail("Unexpected call to \(#function)")
      return context.eventLoop.makeFailedFuture(GRPCStatus.processingError)
    }

    func collect(
      context: UnaryResponseCallContext<Echo_EchoResponse>
    ) -> EventLoopFuture<(StreamEvent<Echo_EchoRequest>) -> Void> {
      XCTFail("Unexpected call to \(#function)")
      return context.eventLoop.makeFailedFuture(GRPCStatus.processingError)
    }

    func update(
      context: StreamingResponseCallContext<Echo_EchoResponse>
    ) -> EventLoopFuture<(StreamEvent<Echo_EchoRequest>) -> Void> {
      XCTFail("Unexpected call to \(#function)")
      return context.eventLoop.makeFailedFuture(GRPCStatus.processingError)
    }

    /// Interceptor factory whose every method returns a new `Interceptor`.
    final class Interceptors: Echo_EchoServerInterceptorFactoryProtocol {
      func makeGetInterceptors() -> [ServerInterceptor<Echo_EchoRequest, Echo_EchoResponse>] {
        return [Interceptor()]
      }

      func makeExpandInterceptors() -> [ServerInterceptor<Echo_EchoRequest, Echo_EchoResponse>] {
        return [Interceptor()]
      }

      func makeCollectInterceptors() -> [ServerInterceptor<Echo_EchoRequest, Echo_EchoResponse>] {
        return [Interceptor()]
      }

      func makeUpdateInterceptors() -> [ServerInterceptor<Echo_EchoRequest, Echo_EchoResponse>] {
        return [Interceptor()]
      }
    }

    // Since all methods use the same request/response types, we can use a single interceptor to
    // respond to all of them.
    class Interceptor: ServerInterceptor<Echo_EchoRequest, Echo_EchoResponse>, @unchecked Sendable {
      /// Requests received by 'Collect' (client streaming); answered with one response on '.end'.
      private var collectedRequests: [Echo_EchoRequest] = []

      override func receive(
        _ part: GRPCServerRequestPart<Echo_EchoRequest>,
        context: ServerInterceptorContext<Echo_EchoRequest, Echo_EchoResponse>
      ) {
        switch part {
        case .metadata:
          context.send(.metadata([:]), promise: nil)

        case let .message(request):
          if context.path.hasSuffix("Get") {
            // Unary, just reply.
            let response = Echo_EchoResponse.with {
              $0.text = "echo: \(request.text)"
            }
            context.send(.message(response, .init(compress: false, flush: false)), promise: nil)
          } else if context.path.hasSuffix("Expand") {
            // Server streaming: one response per space-separated word.
            let parts = request.text.split(separator: " ")
            let metadata = MessageMetadata(compress: false, flush: false)
            for part in parts {
              context.send(.message(.with { $0.text = "echo: \(part)" }, metadata), promise: nil)
            }
          } else if context.path.hasSuffix("Collect") {
            // Client streaming, store the requests, reply on '.end'
            self.collectedRequests.append(request)
          } else if context.path.hasSuffix("Update") {
            // Bidirectional streaming.
            let response = Echo_EchoResponse.with {
              $0.text = "echo: \(request.text)"
            }
            let metadata = MessageMetadata(compress: false, flush: true)
            context.send(.message(response, metadata), promise: nil)
          } else {
            XCTFail("Unexpected path '\(context.path)'")
          }

        case .end:
          // 'Collect' is client streaming and must always produce exactly one response. Keying
          // this on the path (not on `collectedRequests` being non-empty) means an empty request
          // stream still gets its single response instead of a bare status.
          if context.path.hasSuffix("Collect") {
            let response = Echo_EchoResponse.with {
              $0.text = "echo: " + self.collectedRequests.map { $0.text }.joined(separator: " ")
            }
            context.send(.message(response, .init(compress: false, flush: false)), promise: nil)
          }
          context.send(.end(.ok, [:]), promise: nil)
        }
      }
    }
  }
  // Avoid having to serialize/deserialize messages in test cases.
  //
  // Serializes inbound typed requests into bytes and deserializes outbound response bytes back into
  // typed responses. Metadata and end parts pass through unchanged.
  // NOTE(review): not referenced elsewhere in the visible file; confirm it is still needed.
  private class Codec: ChannelDuplexHandler {
    typealias InboundIn = GRPCServerRequestPart<Echo_EchoRequest>
    typealias InboundOut = GRPCServerRequestPart<ByteBuffer>
    typealias OutboundIn = GRPCServerResponsePart<ByteBuffer>
    typealias OutboundOut = GRPCServerResponsePart<Echo_EchoResponse>

    private let serializer = ProtobufSerializer<Echo_EchoRequest>()
    private let deserializer = ProtobufDeserializer<Echo_EchoResponse>()

    func channelRead(context: ChannelHandlerContext, data: NIOAny) {
      switch self.unwrapInboundIn(data) {
      case let .metadata(headers):
        context.fireChannelRead(self.wrapInboundOut(.metadata(headers)))

      case let .message(message):
        do {
          let serialized = try self.serializer.serialize(
            message,
            allocator: context.channel.allocator
          )
          context.fireChannelRead(self.wrapInboundOut(.message(serialized)))
        } catch {
          // Surface the failure through the pipeline rather than crashing the test process.
          context.fireErrorCaught(error)
        }

      case .end:
        context.fireChannelRead(self.wrapInboundOut(.end))
      }
    }

    func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise<Void>?) {
      switch self.unwrapOutboundIn(data) {
      case let .metadata(headers):
        context.write(self.wrapOutboundOut(.metadata(headers)), promise: promise)

      case let .message(message, metadata):
        do {
          let deserialized = try self.deserializer.deserialize(byteBuffer: message)
          context.write(self.wrapOutboundOut(.message(deserialized, metadata)), promise: promise)
        } catch {
          // Fail this write instead of trapping. Also fire the error so it is not lost when the
          // caller passed no promise.
          promise?.fail(error)
          context.fireErrorCaught(error)
        }

      case let .end(status, trailers):
        context.write(self.wrapOutboundOut(.end(status, trailers)), promise: promise)
      }
    }
  }