GRPCCustomPayloadTests.swift 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330
  1. /*
  2. * Copyright 2020, gRPC Authors All rights reserved.
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. import GRPC
  17. import NIOCore
  18. import NIOPosix
  19. import XCTest
  /// These tests demonstrate how to use gRPC to create a service provider using your own payload
  /// type, or alternatively, how to avoid deserialization and just extract the raw bytes from a
  /// payload.
  ///
  /// Every test talks to a `CustomPayloadProvider` server started in `setUp()` through an
  /// `AnyServiceClient`, so each RPC is addressed by its path (e.g. "/CustomPayload/Reverse").
  class GRPCCustomPayloadTests: GRPCTestCase {
    var group: EventLoopGroup!
    var server: Server!
    var client: AnyServiceClient!

    /// Starts a server hosting `CustomPayloadProvider` and connects a client to it.
    override func setUp() {
      super.setUp()
      self.group = MultiThreadedEventLoopGroup(numberOfThreads: 1)

      // Bind to port 0 and read the actual port back from the server's channel below.
      self.server = try! Server.insecure(group: self.group)
        .withServiceProviders([CustomPayloadProvider()])
        .withLogger(self.serverLogger)
        .bind(host: "localhost", port: 0)
        .wait()

      let channel = ClientConnection.insecure(group: self.group)
        .withBackgroundActivityLogger(self.clientLogger)
        .connect(host: "localhost", port: self.server.channel.localAddress!.port!)

      self.client = AnyServiceClient(channel: channel, defaultCallOptions: self.callOptionsWithLogger)
    }

    /// Closes the server, then the client channel, then shuts down the event loop group.
    override func tearDown() {
      XCTAssertNoThrow(try self.server.close().wait())
      XCTAssertNoThrow(try self.client.channel.close().wait())
      XCTAssertNoThrow(try self.group.syncShutdownGracefully())
      super.tearDown()
    }

    func testCustomPayload() throws {
      // This test demonstrates how to call a manually created bidirectional RPC with custom payloads.
      let statusExpectation = self.expectation(description: "status received")

      var responses: [CustomPayload] = []

      // Make a bidirectional stream using `CustomPayload` as the request and response type.
      // The service defined below is called "CustomPayload", and the method we call on it
      // is "AddOneAndReverseMessage"
      let rpc: BidirectionalStreamingCall<CustomPayload, CustomPayload> = self.client
        .makeBidirectionalStreamingCall(
          path: "/CustomPayload/AddOneAndReverseMessage",
          handler: { responses.append($0) }
        )

      // Make and send some requests. The half-open range excludes `Int64.max`, so the
      // `number + 1` computed by the server (and below) cannot overflow.
      let requests: [CustomPayload] = [
        CustomPayload(message: "one", number: .random(in: Int64.min ..< Int64.max)),
        CustomPayload(message: "two", number: .random(in: Int64.min ..< Int64.max)),
        CustomPayload(message: "three", number: .random(in: Int64.min ..< Int64.max)),
      ]
      rpc.sendMessages(requests, promise: nil)
      rpc.sendEnd(promise: nil)

      // Wait for the RPC to finish before comparing responses.
      rpc.status.map { $0.code }.assertEqual(.ok, fulfill: statusExpectation)
      self.wait(for: [statusExpectation], timeout: 1.0)

      // Are the responses as expected?
      let expected = requests.map { request in
        CustomPayload(message: String(request.message.reversed()), number: request.number + 1)
      }
      XCTAssertEqual(responses, expected)
    }

    func testNoDeserializationOnTheClient() throws {
      // This test demonstrates how to skip the deserialization step on the client. It isn't necessary
      // to use a custom service provider to do this, although we do here.
      let statusExpectation = self.expectation(description: "status received")

      var responses: [IdentityPayload] = []

      // Here we use `IdentityPayload` for our response type: we define it below such that it does
      // not deserialize the bytes provided to it by gRPC.
      let rpc: BidirectionalStreamingCall<CustomPayload, IdentityPayload> = self.client
        .makeBidirectionalStreamingCall(
          path: "/CustomPayload/AddOneAndReverseMessage",
          handler: { responses.append($0) }
        )

      let request = CustomPayload(message: "message", number: 42)
      rpc.sendMessage(request, promise: nil)
      rpc.sendEnd(promise: nil)

      // Wait for the RPC to finish before comparing responses.
      rpc.status.map { $0.code }.assertEqual(.ok, fulfill: statusExpectation)
      self.wait(for: [statusExpectation], timeout: 1.0)

      guard var response = responses.first?.buffer else {
        XCTFail("RPC completed without a response")
        return
      }

      // We just took the raw bytes from the payload: we can still decode it because we know the
      // server returned a serialized `CustomPayload`.
      let actual = try CustomPayload(serializedByteBuffer: &response)
      XCTAssertEqual(actual.message, "egassem")
      XCTAssertEqual(actual.number, 43)
    }

    /// Unary RPC: the "Reverse" method returns the request's message reversed.
    func testCustomPayloadUnary() throws {
      let rpc: UnaryCall<StringPayload, StringPayload> = self.client.makeUnaryCall(
        path: "/CustomPayload/Reverse",
        request: StringPayload(message: "foobarbaz")
      )

      XCTAssertEqual(try rpc.response.map { $0.message }.wait(), "zabraboof")
      XCTAssertEqual(try rpc.status.map { $0.code }.wait(), .ok)
    }

    /// Client streaming RPC: "ReverseThenJoin" returns the received messages in reverse order,
    /// joined by single spaces.
    func testCustomPayloadClientStreaming() throws {
      let rpc: ClientStreamingCall<StringPayload, StringPayload> = self.client
        .makeClientStreamingCall(path: "/CustomPayload/ReverseThenJoin")
      rpc.sendMessages(["foo", "bar", "baz"].map(StringPayload.init(message:)), promise: nil)
      rpc.sendEnd(promise: nil)

      XCTAssertEqual(try rpc.response.map { $0.message }.wait(), "baz bar foo")
      XCTAssertEqual(try rpc.status.map { $0.code }.wait(), .ok)
    }

    /// Server streaming RPC: "ReverseThenSplit" returns one response per character of the
    /// request's message, in reverse order.
    func testCustomPayloadServerStreaming() throws {
      let message = "abc"
      var expectedIterator = message.reversed().makeIterator()

      let rpc: ServerStreamingCall<StringPayload, StringPayload> = self.client
        .makeServerStreamingCall(
          path: "/CustomPayload/ReverseThenSplit",
          request: StringPayload(message: message)
        ) { response in
          if let next = expectedIterator.next() {
            XCTAssertEqual(String(next), response.message)
          } else {
            XCTFail("Unexpected message: \(response.message)")
          }
        }

      XCTAssertEqual(try rpc.status.map { $0.code }.wait(), .ok)

      // The handler above only validates responses which actually arrive. Once the RPC has
      // completed, the iterator must be exhausted; otherwise the server sent too few responses
      // (or none at all) and the test would have passed vacuously.
      let unconsumed = expectedIterator.next()
      XCTAssertNil(unconsumed, "RPC completed before all expected responses were received")
    }
  }
  135. // MARK: Custom Payload Service
  /// Call handler provider for the "CustomPayload" service.
  ///
  /// `handle(method:context:)` routes four method names to the functions below:
  /// "Reverse" (unary), "ReverseThenJoin" (client streaming), "ReverseThenSplit" (server
  /// streaming) and "AddOneAndReverseMessage" (bidirectional streaming).
  private class CustomPayloadProvider: CallHandlerProvider {
    var serviceName: Substring = "CustomPayload"

    /// Unary RPC which responds with the request's message reversed.
    fileprivate func reverseString(
      request: StringPayload,
      context: StatusOnlyCallContext
    ) -> EventLoopFuture<StringPayload> {
      let flipped = String(request.message.reversed())
      return context.eventLoop.makeSucceededFuture(StringPayload(message: flipped))
    }

    /// Client streaming RPC which collects every received message and, when the request stream
    /// ends, responds with those messages in reverse order joined by a single space.
    fileprivate func reverseThenJoin(
      context: UnaryResponseCallContext<StringPayload>
    ) -> EventLoopFuture<(StreamEvent<StringPayload>) -> Void> {
      // Accumulates request messages in arrival order; captured by the observer below.
      var received: [String] = []

      let observer: (StreamEvent<StringPayload>) -> Void = { event in
        switch event {
        case .message(let request):
          received.append(request.message)

        case .end:
          let joined = received.reversed().joined(separator: " ")
          context.responsePromise.succeed(StringPayload(message: joined))
        }
      }

      return context.eventLoop.makeSucceededFuture(observer)
    }

    /// Server streaming RPC which sends one response per character of the request's message,
    /// in reverse order, and completes with `.ok` once every send has succeeded.
    fileprivate func reverseThenSplit(
      request: StringPayload,
      context: StreamingResponseCallContext<StringPayload>
    ) -> EventLoopFuture<GRPCStatus> {
      var pendingSends: [EventLoopFuture<Void>] = []
      for character in request.message.reversed() {
        let part = StringPayload(message: String(character))
        pendingSends.append(context.sendResponse(part))
      }

      return EventLoopFuture
        .andAllSucceed(pendingSends, on: context.eventLoop)
        .map { GRPCStatus.ok }
    }

    // Bidirectional RPC which returns a new `CustomPayload` for each `CustomPayload` received.
    // The returned payloads have their `message` reversed and their `number` incremented by one.
    fileprivate func addOneAndReverseMessage(
      context: StreamingResponseCallContext<CustomPayload>
    ) -> EventLoopFuture<(StreamEvent<CustomPayload>) -> Void> {
      let observer: (StreamEvent<CustomPayload>) -> Void = { event in
        switch event {
        case .message(let payload):
          let flippedMessage = String(payload.message.reversed())
          let reply = CustomPayload(message: flippedMessage, number: payload.number + 1)
          // The future returned by `sendResponse` is intentionally discarded.
          _ = context.sendResponse(reply)

        case .end:
          context.statusPromise.succeed(.ok)
        }
      }

      return context.eventLoop.makeSucceededFuture(observer)
    }

    /// Returns the server handler for the method called `name`, or `nil` if this service has no
    /// such method. All handlers are created without interceptors.
    func handle(method name: Substring, context: CallHandlerContext) -> GRPCServerHandlerProtocol? {
      switch name {
      case "Reverse":
        return UnaryServerHandler(
          context: context,
          requestDeserializer: GRPCPayloadDeserializer<StringPayload>(),
          responseSerializer: GRPCPayloadSerializer<StringPayload>(),
          interceptors: [],
          userFunction: self.reverseString(request:context:)
        )

      case "ReverseThenJoin":
        return ClientStreamingServerHandler(
          context: context,
          requestDeserializer: GRPCPayloadDeserializer<StringPayload>(),
          responseSerializer: GRPCPayloadSerializer<StringPayload>(),
          interceptors: [],
          observerFactory: self.reverseThenJoin(context:)
        )

      case "ReverseThenSplit":
        return ServerStreamingServerHandler(
          context: context,
          requestDeserializer: GRPCPayloadDeserializer<StringPayload>(),
          responseSerializer: GRPCPayloadSerializer<StringPayload>(),
          interceptors: [],
          userFunction: self.reverseThenSplit(request:context:)
        )

      case "AddOneAndReverseMessage":
        return BidirectionalStreamingServerHandler(
          context: context,
          requestDeserializer: GRPCPayloadDeserializer<CustomPayload>(),
          responseSerializer: GRPCPayloadSerializer<CustomPayload>(),
          interceptors: [],
          observerFactory: self.addOneAndReverseMessage(context:)
        )

      default:
        // Not a method of this service.
        return nil
      }
    }
  }
  /// A payload which performs no deserialization: it simply stores the buffer it was given so
  /// that the raw bytes can be inspected by the caller.
  private struct IdentityPayload: GRPCPayload {
    var buffer: ByteBuffer

    /// Stores `source` as-is; no bytes are decoded.
    init(serializedByteBuffer source: inout ByteBuffer) throws {
      self.buffer = source
    }

    /// Not implemented: this type is only used as a response type on the client in this file.
    func serialize(into destination: inout ByteBuffer) throws {
      // This will never be called, however, it could be implemented as a direct copy of the bytes
      // we hold, e.g.:
      //
      //   var copy = self.buffer
      //   destination.writeBuffer(&copy)
      fatalError("Unimplemented")
    }
  }
  /// A toy custom payload which holds a `String` and an `Int64`.
  ///
  /// The payload is serialized as:
  /// - the `UInt32` encoded length, in bytes, of the UTF-8 encoded message,
  /// - the UTF-8 encoded bytes of the message, and
  /// - the `Int64` bytes of the number.
  private struct CustomPayload: GRPCPayload, Equatable {
    var message: String
    var number: Int64

    init(message: String, number: Int64) {
      self.message = message
      self.number = number
    }

    /// Decodes a payload from the layout described on the type.
    ///
    /// - Throws: `GRPCError.DeserializationFailure` if the buffer does not hold the length
    ///   prefix, that many message bytes, and the trailing `Int64`.
    init(serializedByteBuffer: inout ByteBuffer) throws {
      guard let messageLength = serializedByteBuffer.readInteger(as: UInt32.self),
        let message = serializedByteBuffer.readString(length: Int(messageLength)),
        let number = serializedByteBuffer.readInteger(as: Int64.self) else {
        throw GRPCError.DeserializationFailure()
      }

      self.message = message
      self.number = number
    }

    /// Encodes the payload into `buffer` using the layout described on the type.
    func serialize(into buffer: inout ByteBuffer) throws {
      // The prefix must be the number of UTF-8 *bytes* because the initializer above reads
      // exactly that many bytes. `String.count` counts `Character`s, which is smaller than the
      // byte count for any non-ASCII message and would corrupt the round trip.
      buffer.writeInteger(UInt32(self.message.utf8.count))
      buffer.writeString(self.message)
      buffer.writeInteger(self.number)
    }
  }
  /// A payload holding a single `String`, serialized as the message's UTF-8 bytes with no
  /// framing: the entire readable portion of a buffer is treated as the message.
  private struct StringPayload: GRPCPayload {
    var message: String

    init(message: String) {
      self.message = message
    }

    /// Reads all readable bytes of the buffer as the message.
    ///
    /// - Throws: `GRPCError.DeserializationFailure` if the string cannot be read, rather than
    ///   crashing on a force-unwrap.
    init(serializedByteBuffer: inout ByteBuffer) throws {
      let length = serializedByteBuffer.readableBytes
      guard let message = serializedByteBuffer.readString(length: length) else {
        throw GRPCError.DeserializationFailure()
      }
      self.message = message
    }

    /// Writes the message's UTF-8 bytes into `buffer`.
    func serialize(into buffer: inout ByteBuffer) throws {
      buffer.writeString(self.message)
    }
  }