GRPCCustomPayloadTests.swift 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207
  1. /*
  2. * Copyright 2020, gRPC Authors All rights reserved.
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. import GRPC
  17. import NIO
  18. import XCTest
  19. // These tests demonstrate how to use gRPC to create a service provider using your own payload type,
  20. // or alternatively, how to avoid deserialization and just extract the raw bytes from a payload.
  class GRPCCustomPayloadTests: GRPCTestCase {
    var group: EventLoopGroup!
    var server: Server!
    var client: AnyServiceClient!

    override func setUp() {
      super.setUp()
      self.group = MultiThreadedEventLoopGroup(numberOfThreads: 1)

      // Bind to port 0 so a free port is chosen; the client reads the bound port back from the
      // server's channel below.
      let serverConfig: Server.Configuration = .init(
        target: .hostAndPort("localhost", 0),
        eventLoopGroup: self.group,
        serviceProviders: [CustomPayloadProvider()]
      )
      self.server = try! Server.start(configuration: serverConfig).wait()

      let clientConfig: ClientConnection.Configuration = .init(
        target: .hostAndPort("localhost", self.server.channel.localAddress!.port!),
        eventLoopGroup: self.group
      )
      self.client = AnyServiceClient(channel: ClientConnection(configuration: clientConfig))
    }

    override func tearDown() {
      // Tear down in the reverse order of `setUp()`: client, then server, then the event loop
      // group that both of them were created on.
      XCTAssertNoThrow(try self.client.channel.close().wait())
      XCTAssertNoThrow(try self.server.close().wait())
      XCTAssertNoThrow(try self.group.syncShutdownGracefully())
      // Balance the `super.setUp()` call so `GRPCTestCase` can run its own teardown.
      super.tearDown()
    }

    func testCustomPayload() throws {
      // This test demonstrates how to call a manually created bidirectional RPC with custom payloads.
      let statusExpectation = self.expectation(description: "status received")
      var responses: [CustomPayload] = []

      // Make a bidirectional stream using `CustomPayload` as the request and response type.
      // The service defined below is called "CustomPayload", and the method we call on it
      // is "AddOneAndReverseMessage".
      let rpc: BidirectionalStreamingCall<CustomPayload, CustomPayload> = self.client.makeBidirectionalStreamingCall(
        path: "/CustomPayload/AddOneAndReverseMessage",
        handler: { responses.append($0) }
      )

      // Make and send some requests. The half-open range excludes `Int64.max`, so the `+ 1`
      // applied by the service (and in `expected` below) cannot overflow.
      let requests: [CustomPayload] = [
        CustomPayload(message: "one", number: .random(in: Int64.min..<Int64.max)),
        CustomPayload(message: "two", number: .random(in: Int64.min..<Int64.max)),
        CustomPayload(message: "three", number: .random(in: Int64.min..<Int64.max))
      ]
      rpc.sendMessages(requests, promise: nil)
      rpc.sendEnd(promise: nil)

      // Wait for the RPC to finish before comparing responses.
      rpc.status.map { $0.code }.assertEqual(.ok, fulfill: statusExpectation)
      self.wait(for: [statusExpectation], timeout: 1.0)

      // Are the responses as expected?
      let expected = requests.map { request in
        CustomPayload(message: String(request.message.reversed()), number: request.number + 1)
      }
      XCTAssertEqual(responses, expected)
    }

    func testNoDeserializationOnTheClient() throws {
      // This test demonstrates how to skip the deserialization step on the client. It isn't necessary
      // to use a custom service provider to do this, although we do here.
      let statusExpectation = self.expectation(description: "status received")
      var responses: [IdentityPayload] = []

      // Here we use `IdentityPayload` for our response type: we define it below such that it does
      // not deserialize the bytes provided to it by gRPC.
      let rpc: BidirectionalStreamingCall<CustomPayload, IdentityPayload> = self.client.makeBidirectionalStreamingCall(
        path: "/CustomPayload/AddOneAndReverseMessage",
        handler: { responses.append($0) }
      )

      let request = CustomPayload(message: "message", number: 42)
      rpc.sendMessage(request, promise: nil)
      rpc.sendEnd(promise: nil)

      // Wait for the RPC to finish before comparing responses.
      rpc.status.map { $0.code }.assertEqual(.ok, fulfill: statusExpectation)
      self.wait(for: [statusExpectation], timeout: 1.0)

      // Exactly one request was sent, so exactly one response is expected.
      XCTAssertEqual(responses.count, 1)
      guard var response = responses.first?.buffer else {
        XCTFail("RPC completed without a response")
        return
      }

      // We just took the raw bytes from the payload: we can still decode it because we know the
      // server returned a serialized `CustomPayload`.
      let actual = try CustomPayload(serializedByteBuffer: &response)
      XCTAssertEqual(actual.message, "egassem")
      XCTAssertEqual(actual.number, 43)
    }
  }
  101. // MARK: Custom Payload Service
  /// Hand-written provider for the "CustomPayload" service. It exposes a single bidirectional
  /// streaming method, "AddOneAndReverseMessage", which uses `CustomPayload` for both requests
  /// and responses.
  fileprivate class CustomPayloadProvider: CallHandlerProvider {
    var serviceName: String = "CustomPayload"

    /// Builds the stream-event observer for "AddOneAndReverseMessage".
    ///
    /// For every `CustomPayload` received, a reply is sent whose `message` is the reversed
    /// request message and whose `number` is the request number plus one. When the client ends
    /// its stream, the call's status promise is completed with `.ok`.
    ///
    /// Note: `number + 1` uses Swift's trapping arithmetic, so a request carrying `Int64.max`
    /// would trap on the server.
    fileprivate func addOneAndReverseMessage(
      context: StreamingResponseCallContext<CustomPayload>
    ) -> EventLoopFuture<(StreamEvent<CustomPayload>) -> Void> {
      let observer: (StreamEvent<CustomPayload>) -> Void = { event in
        switch event {
        case let .message(request):
          let reply = CustomPayload(
            message: String(request.message.reversed()),
            number: request.number + 1
          )
          // The future returned by `sendResponse` is intentionally not observed.
          _ = context.sendResponse(reply)

        case .end:
          context.statusPromise.succeed(.ok)
        }
      }
      return context.eventLoop.makeSucceededFuture(observer)
    }

    /// Returns a call handler for `methodName`, or `nil` when the service has no such method.
    func handleMethod(_ methodName: String, callHandlerContext: CallHandlerContext) -> GRPCCallHandler? {
      guard methodName == "AddOneAndReverseMessage" else {
        return nil
      }
      return BidirectionalStreamingCallHandler<CustomPayload, CustomPayload>(
        callHandlerContext: callHandlerContext
      ) { context in
        return self.addOneAndReverseMessage(context: context)
      }
    }
  }
  /// A payload that skips deserialization entirely: it keeps the raw bytes gRPC hands to it.
  ///
  /// Used as a response type when the caller wants the serialized bytes rather than a decoded
  /// message. It can also be sent, in which case the held bytes are written out unchanged.
  fileprivate struct IdentityPayload: GRPCPayload {
    /// The bytes this payload was created from.
    var buffer: ByteBuffer

    init(serializedByteBuffer: inout ByteBuffer) throws {
      // Store the incoming buffer as-is; nothing is decoded here, so this cannot fail.
      self.buffer = serializedByteBuffer
    }

    func serialize(into buffer: inout ByteBuffer) throws {
      // Write the held bytes back out verbatim. `writeBuffer(_:)` takes its argument `inout`
      // (and reads from it), so operate on a local copy to leave `self.buffer` untouched.
      var copy = self.buffer
      buffer.writeBuffer(&copy)
    }
  }
  /// A toy custom payload which holds a `String` and an `Int64`.
  ///
  /// The payload is serialized as:
  /// - the `UInt32` encoded length of the message, measured in UTF-8 bytes,
  /// - the UTF-8 encoded bytes of the message, and
  /// - the `Int64` bytes of the number.
  fileprivate struct CustomPayload: GRPCPayload, Equatable {
    var message: String
    var number: Int64

    init(message: String, number: Int64) {
      self.message = message
      self.number = number
    }

    /// Decodes a payload in the format described above.
    ///
    /// - Throws: `GRPCError.DeserializationFailure` if the buffer is too short for the length
    ///   prefix, the message bytes, or the number.
    init(serializedByteBuffer: inout ByteBuffer) throws {
      guard let messageLength = serializedByteBuffer.readInteger(as: UInt32.self),
        let message = serializedByteBuffer.readString(length: Int(messageLength)),
        let number = serializedByteBuffer.readInteger(as: Int64.self) else {
        throw GRPCError.DeserializationFailure()
      }
      self.message = message
      self.number = number
    }

    /// Encodes the payload in the format described above.
    func serialize(into buffer: inout ByteBuffer) throws {
      // The length prefix must count UTF-8 bytes, not `Character`s: the decoder's
      // `readString(length:)` consumes bytes, so `message.count` would mis-frame any
      // non-ASCII message (e.g. "é" is 1 `Character` but 2 UTF-8 bytes).
      buffer.writeInteger(UInt32(self.message.utf8.count))
      buffer.writeString(self.message)
      buffer.writeInteger(self.number)
    }
  }