GRPCCustomPayloadTests.swift 11 KB


  1. /*
  2. * Copyright 2020, gRPC Authors All rights reserved.
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. import GRPC
  17. import NIO
  18. import XCTest
  19. // These tests demonstrate how to use gRPC to create a service provider using your own payload type,
  20. // or alternatively, how to avoid deserialization and just extract the raw bytes from a payload.
  /// Tests which run RPCs against `CustomPayloadProvider` using hand-written `GRPCPayload`
  /// types for the requests and responses.
  ///
  /// Each test gets a fresh server and client which are created in `setUp()` and shut down in
  /// `tearDown()`.
  class GRPCCustomPayloadTests: GRPCTestCase {
    var group: EventLoopGroup!
    var server: Server!
    var client: AnyServiceClient!

    /// Starts an insecure server hosting `CustomPayloadProvider` on "localhost" and connects an
    /// `AnyServiceClient` to it over the same event loop group.
    override func setUp() {
      super.setUp()
      self.group = MultiThreadedEventLoopGroup(numberOfThreads: 1)

      // The server is bound to port 0; the port it actually ended up on is read back from the
      // server channel's local address when connecting the client below.
      self.server = try! Server.insecure(group: self.group)
        .withServiceProviders([CustomPayloadProvider()])
        .withLogger(self.serverLogger)
        .bind(host: "localhost", port: 0)
        .wait()

      let channel = ClientConnection.insecure(group: self.group)
        .withBackgroundActivityLogger(self.clientLogger)
        .connect(host: "localhost", port: self.server.channel.localAddress!.port!)

      self.client = AnyServiceClient(channel: channel, defaultCallOptions: self.callOptionsWithLogger)
    }

    /// Closes the server, then the client's channel, then shuts down the event loop group.
    /// A failure in any step is recorded as a test failure rather than thrown.
    override func tearDown() {
      XCTAssertNoThrow(try self.server.close().wait())
      XCTAssertNoThrow(try self.client.channel.close().wait())
      XCTAssertNoThrow(try self.group.syncShutdownGracefully())
      super.tearDown()
    }

    func testCustomPayload() throws {
      // This test demonstrates how to call a manually created bidirectional RPC with custom payloads.
      let statusExpectation = self.expectation(description: "status received")
      var responses: [CustomPayload] = []

      // Make a bidirectional stream using `CustomPayload` as the request and response type.
      // The service defined below is called "CustomPayload", and the method we call on it
      // is "AddOneAndReverseMessage"
      let rpc: BidirectionalStreamingCall<CustomPayload, CustomPayload> = self.client.makeBidirectionalStreamingCall(
        path: "/CustomPayload/AddOneAndReverseMessage",
        handler: { responses.append($0) }
      )

      // Make and send some requests. The half-open range excludes `Int64.max` so the server can
      // add one to the number without overflowing.
      let requests: [CustomPayload] = [
        CustomPayload(message: "one", number: .random(in: Int64.min..<Int64.max)),
        CustomPayload(message: "two", number: .random(in: Int64.min..<Int64.max)),
        CustomPayload(message: "three", number: .random(in: Int64.min..<Int64.max))
      ]
      rpc.sendMessages(requests, promise: nil)
      rpc.sendEnd(promise: nil)

      // Wait for the RPC to finish before comparing responses.
      rpc.status.map { $0.code }.assertEqual(.ok, fulfill: statusExpectation)
      self.wait(for: [statusExpectation], timeout: 1.0)

      // Are the responses as expected?
      let expected = requests.map { request in
        CustomPayload(message: String(request.message.reversed()), number: request.number + 1)
      }
      XCTAssertEqual(responses, expected)
    }

    func testNoDeserializationOnTheClient() throws {
      // This test demonstrates how to skip the deserialization step on the client. It isn't necessary
      // to use a custom service provider to do this, although we do here.
      let statusExpectation = self.expectation(description: "status received")
      var responses: [IdentityPayload] = []

      // Here we use `IdentityPayload` for our response type: we define it below such that it does
      // not deserialize the bytes provided to it by gRPC.
      let rpc: BidirectionalStreamingCall<CustomPayload, IdentityPayload> = self.client.makeBidirectionalStreamingCall(
        path: "/CustomPayload/AddOneAndReverseMessage",
        handler: { responses.append($0) }
      )

      let request = CustomPayload(message: "message", number: 42)
      rpc.sendMessage(request, promise: nil)
      rpc.sendEnd(promise: nil)

      // Wait for the RPC to finish before comparing responses.
      rpc.status.map { $0.code }.assertEqual(.ok, fulfill: statusExpectation)
      self.wait(for: [statusExpectation], timeout: 1.0)

      guard var response = responses.first?.buffer else {
        XCTFail("RPC completed without a response")
        return
      }

      // We just took the raw bytes from the payload: we can still decode it because we know the
      // server returned a serialized `CustomPayload`.
      let actual = try CustomPayload(serializedByteBuffer: &response)
      XCTAssertEqual(actual.message, "egassem")
      XCTAssertEqual(actual.number, 43)
    }

    func testCustomPayloadUnary() throws {
      // "Reverse" replies with the request message reversed.
      let rpc: UnaryCall<StringPayload, StringPayload> = self.client.makeUnaryCall(
        path: "/CustomPayload/Reverse",
        request: StringPayload(message: "foobarbaz")
      )

      XCTAssertEqual(try rpc.response.map { $0.message }.wait(), "zabraboof")
      XCTAssertEqual(try rpc.status.map { $0.code }.wait(), .ok)
    }

    func testCustomPayloadClientStreaming() throws {
      // "ReverseThenJoin" reverses the order of the received messages and joins them with spaces.
      let rpc: ClientStreamingCall<StringPayload, StringPayload> = self.client.makeClientStreamingCall(path: "/CustomPayload/ReverseThenJoin")
      rpc.sendMessages(["foo", "bar", "baz"].map(StringPayload.init(message:)), promise: nil)
      rpc.sendEnd(promise: nil)

      XCTAssertEqual(try rpc.response.map { $0.message }.wait(), "baz bar foo")
      XCTAssertEqual(try rpc.status.map { $0.code }.wait(), .ok)
    }

    func testCustomPayloadServerStreaming() throws {
      // "ReverseThenSplit" sends back one response per character of the reversed request message.
      let message = "abc"
      var expectedIterator = message.reversed().makeIterator()

      let rpc: ServerStreamingCall<StringPayload, StringPayload> = self.client.makeServerStreamingCall(
        path: "/CustomPayload/ReverseThenSplit",
        request: StringPayload(message: message)
      ) { response in
        if let next = expectedIterator.next() {
          XCTAssertEqual(String(next), response.message)
        } else {
          XCTFail("Unexpected message: \(response.message)")
        }
      }

      XCTAssertEqual(try rpc.status.map { $0.code }.wait(), .ok)

      // The handler above only validates the responses which actually arrive, so a server which
      // sent too few (or no) responses would otherwise go unnoticed. Check nothing is left over.
      let missing = expectedIterator.next()
      XCTAssertNil(missing, "RPC completed before all expected responses were received")
    }
  }
  130. // MARK: Custom Payload Service
  /// A hand-written service provider for a service named "CustomPayload".
  ///
  /// It routes four methods in `handleMethod(_:callHandlerContext:)`:
  /// - "Reverse" (unary),
  /// - "ReverseThenJoin" (client streaming),
  /// - "ReverseThenSplit" (server streaming), and
  /// - "AddOneAndReverseMessage" (bidirectional streaming).
  fileprivate class CustomPayloadProvider: CallHandlerProvider {
    var serviceName: String = "CustomPayload"

    /// Unary RPC: responds with the request's message reversed.
    fileprivate func reverseString(
      request: StringPayload,
      context: UnaryResponseCallContext<StringPayload>
    ) -> EventLoopFuture<StringPayload> {
      let backwards = String(request.message.reversed())
      return context.eventLoop.makeSucceededFuture(StringPayload(message: backwards))
    }

    /// Client streaming RPC: collects every received message and, when the stream ends, responds
    /// with the messages in reverse order of arrival joined by a single space.
    fileprivate func reverseThenJoin(
      context: UnaryResponseCallContext<StringPayload>
    ) -> EventLoopFuture<(StreamEvent<StringPayload>) -> Void> {
      // Captured by the observer so it accumulates across events of this one call.
      var received: [String] = []

      let observer: (StreamEvent<StringPayload>) -> Void = { event in
        switch event {
        case let .message(request):
          received.append(request.message)

        case .end:
          let joined = received.reversed().joined(separator: " ")
          context.responsePromise.succeed(StringPayload(message: joined))
        }
      }

      return context.eventLoop.makeSucceededFuture(observer)
    }

    /// Server streaming RPC: sends one response per character of the reversed request message and
    /// completes with `.ok` once every send has succeeded.
    fileprivate func reverseThenSplit(
      request: StringPayload,
      context: StreamingResponseCallContext<StringPayload>
    ) -> EventLoopFuture<GRPCStatus> {
      let pendingSends = request.message.reversed().map { character in
        context.sendResponse(StringPayload(message: String(character)))
      }

      return EventLoopFuture
        .andAllSucceed(pendingSends, on: context.eventLoop)
        .map { .ok }
    }

    /// Bidirectional streaming RPC: for each `CustomPayload` received, sends back a new one whose
    /// `message` is reversed and whose `number` is incremented by one. The call is completed
    /// with `.ok` when the request stream ends.
    fileprivate func addOneAndReverseMessage(
      context: StreamingResponseCallContext<CustomPayload>
    ) -> EventLoopFuture<(StreamEvent<CustomPayload>) -> Void> {
      let observer: (StreamEvent<CustomPayload>) -> Void = { event in
        switch event {
        case let .message(payload):
          let reversedMessage = String(payload.message.reversed())
          let reply = CustomPayload(message: reversedMessage, number: payload.number + 1)
          // The future returned for the send is deliberately discarded.
          _ = context.sendResponse(reply)

        case .end:
          context.statusPromise.succeed(.ok)
        }
      }

      return context.eventLoop.makeSucceededFuture(observer)
    }

    /// Returns the call handler for `methodName`, or `nil` if this service has no such method.
    func handleMethod(_ methodName: String, callHandlerContext: CallHandlerContext) -> GRPCCallHandler? {
      switch methodName {
      case "Reverse":
        return CallHandlerFactory.makeUnary(callHandlerContext: callHandlerContext) { context in
          { request in self.reverseString(request: request, context: context) }
        }

      case "ReverseThenJoin":
        return CallHandlerFactory.makeClientStreaming(callHandlerContext: callHandlerContext) { context in
          self.reverseThenJoin(context: context)
        }

      case "ReverseThenSplit":
        return CallHandlerFactory.makeServerStreaming(callHandlerContext: callHandlerContext) { context in
          { request in self.reverseThenSplit(request: request, context: context) }
        }

      case "AddOneAndReverseMessage":
        return CallHandlerFactory.makeBidirectionalStreaming(callHandlerContext: callHandlerContext) { context in
          self.addOneAndReverseMessage(context: context)
        }

      default:
        return nil
      }
    }
  }
  /// A payload which does no deserialization: it simply holds on to the buffer it was created
  /// from so that the raw bytes can be inspected (or decoded later) by the caller.
  fileprivate struct IdentityPayload: GRPCPayload {
    /// The buffer exactly as it was passed to `init(serializedByteBuffer:)`.
    var buffer: ByteBuffer

    /// Stores `rawBytes` as-is; nothing is read from it here.
    init(serializedByteBuffer rawBytes: inout ByteBuffer) throws {
      self.buffer = rawBytes
    }

    /// Not supported: this type is only ever used to receive messages.
    func serialize(into destination: inout ByteBuffer) throws {
      // This will never be called, however, it could be implemented as a direct copy of the bytes
      // we hold, e.g.:
      //
      //   var copy = self.buffer
      //   destination.writeBuffer(&copy)
      fatalError("Unimplemented")
    }
  }
  /// A toy custom payload which holds a `String` and an `Int64`.
  ///
  /// The payload is serialized as:
  /// - the `UInt32` encoded length of the message in UTF-8 bytes,
  /// - the UTF-8 encoded bytes of the message, and
  /// - the `Int64` bytes of the number.
  fileprivate struct CustomPayload: GRPCPayload, Equatable {
    var message: String
    var number: Int64

    init(message: String, number: Int64) {
      self.message = message
      self.number = number
    }

    /// Decodes a payload from the layout described on the type.
    ///
    /// - Throws: `GRPCError.DeserializationFailure` if the length prefix, the message bytes or
    ///   the number could not be read from the buffer.
    init(serializedByteBuffer: inout ByteBuffer) throws {
      guard let messageLength = serializedByteBuffer.readInteger(as: UInt32.self),
        let message = serializedByteBuffer.readString(length: Int(messageLength)),
        let number = serializedByteBuffer.readInteger(as: Int64.self) else {
          throw GRPCError.DeserializationFailure()
      }

      self.message = message
      self.number = number
    }

    /// Encodes the payload using the layout described on the type.
    func serialize(into buffer: inout ByteBuffer) throws {
      // The decoder reads `messageLength` *bytes*, so the prefix must be the UTF-8 byte count.
      // `String.count` counts `Character`s and is smaller than the byte count for any message
      // containing non-ASCII text, which would corrupt the encoding.
      buffer.writeInteger(UInt32(self.message.utf8.count))
      buffer.writeString(self.message)
      buffer.writeInteger(self.number)
    }
  }
  /// A payload which is just a `String`: the serialized form is the message's bytes as written
  /// by `ByteBuffer.writeString`, with no length prefix or other framing.
  fileprivate struct StringPayload: GRPCPayload {
    var message: String

    init(message: String) {
      self.message = message
    }

    /// Decodes the message from all of the readable bytes in the buffer.
    ///
    /// - Throws: `GRPCError.DeserializationFailure` if the bytes could not be read as a string.
    init(serializedByteBuffer: inout ByteBuffer) throws {
      // Throw rather than force-unwrap so a decoding failure surfaces as an error, matching
      // how `CustomPayload` reports failures.
      let length = serializedByteBuffer.readableBytes
      guard let message = serializedByteBuffer.readString(length: length) else {
        throw GRPCError.DeserializationFailure()
      }
      self.message = message
    }

    /// Writes the message's bytes into `buffer`.
    func serialize(into buffer: inout ByteBuffer) throws {
      buffer.writeString(self.message)
    }
  }