GRPCCustomPayloadTests.swift 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310
  1. /*
  2. * Copyright 2020, gRPC Authors All rights reserved.
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. import GRPC
  17. import NIO
  18. import XCTest
  19. // These tests demonstrate how to use gRPC to create a service provider using your own payload type,
  20. // or alternatively, how to avoid deserialization and just extract the raw bytes from a payload.
  21. class GRPCCustomPayloadTests: GRPCTestCase {
  22. var group: EventLoopGroup!
  23. var server: Server!
  24. var client: AnyServiceClient!
  25. override func setUp() {
  26. super.setUp()
  27. self.group = MultiThreadedEventLoopGroup(numberOfThreads: 1)
  28. self.server = try! Server.insecure(group: self.group)
  29. .withServiceProviders([CustomPayloadProvider()])
  30. .bind(host: "localhost", port: 0)
  31. .wait()
  32. let channel = ClientConnection.insecure(group: self.group)
  33. .connect(host: "localhost", port: server.channel.localAddress!.port!)
  34. self.client = AnyServiceClient(channel: channel)
  35. }
  36. override func tearDown() {
  37. XCTAssertNoThrow(try self.server.close().wait())
  38. XCTAssertNoThrow(try self.client.channel.close().wait())
  39. XCTAssertNoThrow(try self.group.syncShutdownGracefully())
  40. }
  41. func testCustomPayload() throws {
  42. // This test demonstrates how to call a manually created bidirectional RPC with custom payloads.
  43. let statusExpectation = self.expectation(description: "status received")
  44. var responses: [CustomPayload] = []
  45. // Make a bidirectional stream using `CustomPayload` as the request and response type.
  46. // The service defined below is called "CustomPayload", and the method we call on it
  47. // is "AddOneAndReverseMessage"
  48. let rpc: BidirectionalStreamingCall<CustomPayload, CustomPayload> = self.client.makeBidirectionalStreamingCall(
  49. path: "/CustomPayload/AddOneAndReverseMessage",
  50. handler: { responses.append($0) }
  51. )
  52. // Make and send some requests:
  53. let requests: [CustomPayload] = [
  54. CustomPayload(message: "one", number: .random(in: Int64.min..<Int64.max)),
  55. CustomPayload(message: "two", number: .random(in: Int64.min..<Int64.max)),
  56. CustomPayload(message: "three", number: .random(in: Int64.min..<Int64.max))
  57. ]
  58. rpc.sendMessages(requests, promise: nil)
  59. rpc.sendEnd(promise: nil)
  60. // Wait for the RPC to finish before comparing responses.
  61. rpc.status.map { $0.code }.assertEqual(.ok, fulfill: statusExpectation)
  62. self.wait(for: [statusExpectation], timeout: 1.0)
  63. // Are the responses as expected?
  64. let expected = requests.map { request in
  65. CustomPayload(message: String(request.message.reversed()), number: request.number + 1)
  66. }
  67. XCTAssertEqual(responses, expected)
  68. }
  69. func testNoDeserializationOnTheClient() throws {
  70. // This test demonstrates how to skip the deserialization step on the client. It isn't necessary
  71. // to use a custom service provider to do this, although we do here.
  72. let statusExpectation = self.expectation(description: "status received")
  73. var responses: [IdentityPayload] = []
  74. // Here we use `IdentityPayload` for our response type: we define it below such that it does
  75. // not deserialize the bytes provided to it by gRPC.
  76. let rpc: BidirectionalStreamingCall<CustomPayload, IdentityPayload> = self.client.makeBidirectionalStreamingCall(
  77. path: "/CustomPayload/AddOneAndReverseMessage",
  78. handler: { responses.append($0) }
  79. )
  80. let request = CustomPayload(message: "message", number: 42)
  81. rpc.sendMessage(request, promise: nil)
  82. rpc.sendEnd(promise: nil)
  83. // Wait for the RPC to finish before comparing responses.
  84. rpc.status.map { $0.code }.assertEqual(.ok, fulfill: statusExpectation)
  85. self.wait(for: [statusExpectation], timeout: 1.0)
  86. guard var response = responses.first?.buffer else {
  87. XCTFail("RPC completed without a response")
  88. return
  89. }
  90. // We just took the raw bytes from the payload: we can still decode it because we know the
  91. // server returned a serialized `CustomPayload`.
  92. let actual = try CustomPayload(serializedByteBuffer: &response)
  93. XCTAssertEqual(actual.message, "egassem")
  94. XCTAssertEqual(actual.number, 43)
  95. }
  96. func testCustomPayloadUnary() throws {
  97. let rpc: UnaryCall<StringPayload, StringPayload> = self.client.makeUnaryCall(
  98. path: "/CustomPayload/Reverse",
  99. request: StringPayload(message: "foobarbaz")
  100. )
  101. XCTAssertEqual(try rpc.response.map { $0.message }.wait(), "zabraboof")
  102. XCTAssertEqual(try rpc.status.map { $0.code }.wait(), .ok)
  103. }
  104. func testCustomPayloadClientStreaming() throws {
  105. let rpc: ClientStreamingCall<StringPayload, StringPayload> = self.client.makeClientStreamingCall(path: "/CustomPayload/ReverseThenJoin")
  106. rpc.sendMessages(["foo", "bar", "baz"].map(StringPayload.init(message:)), promise: nil)
  107. rpc.sendEnd(promise: nil)
  108. XCTAssertEqual(try rpc.response.map { $0.message }.wait(), "baz bar foo")
  109. XCTAssertEqual(try rpc.status.map { $0.code }.wait(), .ok)
  110. }
  111. func testCustomPayloadServerStreaming() throws {
  112. let message = "abc"
  113. var expectedIterator = message.reversed().makeIterator()
  114. let rpc: ServerStreamingCall<StringPayload, StringPayload> = self.client.makeServerStreamingCall(
  115. path: "/CustomPayload/ReverseThenSplit",
  116. request: StringPayload(message: message)
  117. ) { response in
  118. if let next = expectedIterator.next() {
  119. XCTAssertEqual(String(next), response.message)
  120. } else {
  121. XCTFail("Unexpected message: \(response.message)")
  122. }
  123. }
  124. XCTAssertEqual(try rpc.status.map { $0.code }.wait(), .ok)
  125. }
  126. }
  127. // MARK: Custom Payload Service
  128. fileprivate class CustomPayloadProvider: CallHandlerProvider {
  129. var serviceName: String = "CustomPayload"
  130. fileprivate func reverseString(
  131. request: StringPayload,
  132. context: UnaryResponseCallContext<StringPayload>
  133. ) -> EventLoopFuture<StringPayload> {
  134. let reversed = StringPayload(message: String(request.message.reversed()))
  135. return context.eventLoop.makeSucceededFuture(reversed)
  136. }
  137. fileprivate func reverseThenJoin(
  138. context: UnaryResponseCallContext<StringPayload>
  139. ) -> EventLoopFuture<(StreamEvent<StringPayload>) -> Void> {
  140. var messages: [String] = []
  141. return context.eventLoop.makeSucceededFuture({ event in
  142. switch event {
  143. case .message(let request):
  144. messages.append(request.message)
  145. case .end:
  146. let response = messages.reversed().joined(separator: " ")
  147. context.responsePromise.succeed(StringPayload(message: response))
  148. }
  149. })
  150. }
  151. fileprivate func reverseThenSplit(
  152. request: StringPayload,
  153. context: StreamingResponseCallContext<StringPayload>
  154. ) -> EventLoopFuture<GRPCStatus> {
  155. let responses = request.message.reversed().map {
  156. context.sendResponse(StringPayload(message: String($0)))
  157. }
  158. return EventLoopFuture.andAllSucceed(responses, on: context.eventLoop).map { .ok }
  159. }
  160. // Bidirectional RPC which returns a new `CustomPayload` for each `CustomPayload` received.
  161. // The returned payloads have their `message` reversed and their `number` incremented by one.
  162. fileprivate func addOneAndReverseMessage(
  163. context: StreamingResponseCallContext<CustomPayload>
  164. ) -> EventLoopFuture<(StreamEvent<CustomPayload>) -> Void> {
  165. return context.eventLoop.makeSucceededFuture({ event in
  166. switch event {
  167. case .message(let payload):
  168. let response = CustomPayload(
  169. message: String(payload.message.reversed()),
  170. number: payload.number + 1
  171. )
  172. _ = context.sendResponse(response)
  173. case .end:
  174. context.statusPromise.succeed(.ok)
  175. }
  176. })
  177. }
  178. func handleMethod(_ methodName: String, callHandlerContext: CallHandlerContext) -> GRPCCallHandler? {
  179. switch methodName {
  180. case "Reverse":
  181. return CallHandlerFactory.makeUnary(callHandlerContext: callHandlerContext) { context in
  182. return { request in
  183. return self.reverseString(request: request, context: context)
  184. }
  185. }
  186. case "ReverseThenJoin":
  187. return CallHandlerFactory.makeClientStreaming(callHandlerContext: callHandlerContext) { context in
  188. return self.reverseThenJoin(context: context)
  189. }
  190. case "ReverseThenSplit":
  191. return CallHandlerFactory.makeServerStreaming(callHandlerContext: callHandlerContext) { context in
  192. return { request in
  193. return self.reverseThenSplit(request: request, context: context)
  194. }
  195. }
  196. case "AddOneAndReverseMessage":
  197. return CallHandlerFactory.makeBidirectionalStreaming(callHandlerContext: callHandlerContext) { context in
  198. return self.addOneAndReverseMessage(context: context)
  199. }
  200. default:
  201. return nil
  202. }
  203. }
  204. }
  205. fileprivate struct IdentityPayload: GRPCPayload {
  206. var buffer: ByteBuffer
  207. init(serializedByteBuffer: inout ByteBuffer) throws {
  208. self.buffer = serializedByteBuffer
  209. }
  210. func serialize(into buffer: inout ByteBuffer) throws {
  211. // This will never be called, however, it could be implemented as a direct copy of the bytes
  212. // we hold, e.g.:
  213. //
  214. // var copy = self.buffer
  215. // buffer.writeBuffer(&copy)
  216. fatalError("Unimplemented")
  217. }
  218. }
  219. /// A toy custom payload which holds a `String` and an `Int64`.
  220. ///
  221. /// The payload is serialized as:
  222. /// - the `UInt32` encoded length of the message,
  223. /// - the UTF-8 encoded bytes of the message, and
  224. /// - the `Int64` bytes of the number.
  225. fileprivate struct CustomPayload: GRPCPayload, Equatable {
  226. var message: String
  227. var number: Int64
  228. init(message: String, number: Int64) {
  229. self.message = message
  230. self.number = number
  231. }
  232. init(serializedByteBuffer: inout ByteBuffer) throws {
  233. guard let messageLength = serializedByteBuffer.readInteger(as: UInt32.self),
  234. let message = serializedByteBuffer.readString(length: Int(messageLength)),
  235. let number = serializedByteBuffer.readInteger(as: Int64.self) else {
  236. throw GRPCError.DeserializationFailure()
  237. }
  238. self.message = message
  239. self.number = number
  240. }
  241. func serialize(into buffer: inout ByteBuffer) throws {
  242. buffer.writeInteger(UInt32(self.message.count))
  243. buffer.writeString(self.message)
  244. buffer.writeInteger(self.number)
  245. }
  246. }
  247. fileprivate struct StringPayload: GRPCPayload {
  248. var message: String
  249. init(message: String) {
  250. self.message = message
  251. }
  252. init(serializedByteBuffer: inout ByteBuffer) throws {
  253. self.message = serializedByteBuffer.readString(length: serializedByteBuffer.readableBytes)!
  254. }
  255. func serialize(into buffer: inout ByteBuffer) throws {
  256. buffer.writeString(self.message)
  257. }
  258. }