GRPCCustomPayloadTests.swift 12 KB


  1. /*
  2. * Copyright 2020, gRPC Authors All rights reserved.
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. import GRPC
  17. import NIOCore
  18. import NIOPosix
  19. import XCTest
  20. // These tests demonstrate how to use gRPC to create a service provider using your own payload type,
  21. // or alternatively, how to avoid deserialization and just extract the raw bytes from a payload.
  class GRPCCustomPayloadTests: GRPCTestCase {
    var group: EventLoopGroup!
    var server: Server!
    var client: GRPCAnyServiceClient!

    /// Starts a server hosting `CustomPayloadProvider` and connects a `GRPCAnyServiceClient` to it.
    override func setUp() {
      super.setUp()
      self.group = MultiThreadedEventLoopGroup(numberOfThreads: 1)

      // Bind to port 0 and read the actual bound port back when connecting the client below.
      self.server = try! Server.insecure(group: self.group)
        .withServiceProviders([CustomPayloadProvider()])
        .withLogger(self.serverLogger)
        .bind(host: "localhost", port: 0)
        .wait()

      let channel = ClientConnection.insecure(group: self.group)
        .withBackgroundActivityLogger(self.clientLogger)
        .connect(host: "localhost", port: self.server.channel.localAddress!.port!)

      self.client = GRPCAnyServiceClient(
        channel: channel,
        defaultCallOptions: self.callOptionsWithLogger
      )
    }

    /// Closes the server, then the client channel, then shuts down the event loop group.
    override func tearDown() {
      XCTAssertNoThrow(try self.server.close().wait())
      XCTAssertNoThrow(try self.client.channel.close().wait())
      XCTAssertNoThrow(try self.group.syncShutdownGracefully())
      super.tearDown()
    }

    func testCustomPayload() throws {
      // This test demonstrates how to call a manually created bidirectional RPC with custom payloads.
      let statusExpectation = self.expectation(description: "status received")

      var responses: [CustomPayload] = []

      // Make a bidirectional stream using `CustomPayload` as the request and response type.
      // The service defined below is called "CustomPayload", and the method we call on it
      // is "AddOneAndReverseMessage".
      let rpc: BidirectionalStreamingCall<CustomPayload, CustomPayload> = self.client
        .makeBidirectionalStreamingCall(
          path: "/CustomPayload/AddOneAndReverseMessage",
          handler: { responses.append($0) }
        )

      // Make and send some requests. The upper bound is exclusive so the server's `+ 1` cannot
      // overflow.
      let requests: [CustomPayload] = [
        CustomPayload(message: "one", number: .random(in: Int64.min ..< Int64.max)),
        CustomPayload(message: "two", number: .random(in: Int64.min ..< Int64.max)),
        CustomPayload(message: "three", number: .random(in: Int64.min ..< Int64.max)),
      ]
      rpc.sendMessages(requests, promise: nil)
      rpc.sendEnd(promise: nil)

      // Wait for the RPC to finish before comparing responses.
      rpc.status.map { $0.code }.assertEqual(.ok, fulfill: statusExpectation)
      self.wait(for: [statusExpectation], timeout: 1.0)

      // Are the responses as expected?
      let expected = requests.map { request in
        CustomPayload(message: String(request.message.reversed()), number: request.number + 1)
      }
      XCTAssertEqual(responses, expected)
    }

    func testNoDeserializationOnTheClient() throws {
      // This test demonstrates how to skip the deserialization step on the client. It isn't necessary
      // to use a custom service provider to do this, although we do here.
      let statusExpectation = self.expectation(description: "status received")

      var responses: [IdentityPayload] = []
      // Here we use `IdentityPayload` for our response type: we define it below such that it does
      // not deserialize the bytes provided to it by gRPC.
      let rpc: BidirectionalStreamingCall<CustomPayload, IdentityPayload> = self.client
        .makeBidirectionalStreamingCall(
          path: "/CustomPayload/AddOneAndReverseMessage",
          handler: { responses.append($0) }
        )

      let request = CustomPayload(message: "message", number: 42)
      rpc.sendMessage(request, promise: nil)
      rpc.sendEnd(promise: nil)

      // Wait for the RPC to finish before comparing responses.
      rpc.status.map { $0.code }.assertEqual(.ok, fulfill: statusExpectation)
      self.wait(for: [statusExpectation], timeout: 1.0)

      // Exactly one request was sent, so exactly one response is expected.
      XCTAssertEqual(responses.count, 1)
      guard var response = responses.first?.buffer else {
        XCTFail("RPC completed without a response")
        return
      }

      // We just took the raw bytes from the payload: we can still decode it because we know the
      // server returned a serialized `CustomPayload`.
      let actual = try CustomPayload(serializedByteBuffer: &response)
      XCTAssertEqual(actual.message, "egassem")
      XCTAssertEqual(actual.number, 43)
    }

    func testCustomPayloadUnary() throws {
      let rpc: UnaryCall<StringPayload, StringPayload> = self.client.makeUnaryCall(
        path: "/CustomPayload/Reverse",
        request: StringPayload(message: "foobarbaz")
      )

      XCTAssertEqual(try rpc.response.map { $0.message }.wait(), "zabraboof")
      XCTAssertEqual(try rpc.status.map { $0.code }.wait(), .ok)
    }

    func testCustomPayloadClientStreaming() throws {
      let rpc: ClientStreamingCall<StringPayload, StringPayload> = self.client
        .makeClientStreamingCall(path: "/CustomPayload/ReverseThenJoin")

      rpc.sendMessages(["foo", "bar", "baz"].map(StringPayload.init(message:)), promise: nil)
      rpc.sendEnd(promise: nil)

      XCTAssertEqual(try rpc.response.map { $0.message }.wait(), "baz bar foo")
      XCTAssertEqual(try rpc.status.map { $0.code }.wait(), .ok)
    }

    func testCustomPayloadServerStreaming() throws {
      let message = "abc"
      var expectedIterator = message.reversed().makeIterator()

      let rpc: ServerStreamingCall<StringPayload, StringPayload> = self.client
        .makeServerStreamingCall(
          path: "/CustomPayload/ReverseThenSplit",
          request: StringPayload(message: message)
        ) { response in
          if let next = expectedIterator.next() {
            XCTAssertEqual(String(next), response.message)
          } else {
            XCTFail("Unexpected message: \(response.message)")
          }
        }

      XCTAssertEqual(try rpc.status.map { $0.code }.wait(), .ok)
      // The response handler only validates messages that actually arrived; also make sure the
      // server did not send fewer messages than expected.
      XCTAssertNil(expectedIterator.next(), "Received fewer responses than expected")
    }
  }
  138. // MARK: Custom Payload Service
  /// Hand-written `CallHandlerProvider` for the "CustomPayload" service.
  ///
  /// Every method is wired directly to a server handler using `GRPCPayloadDeserializer` /
  /// `GRPCPayloadSerializer` for the custom payload types declared in this file.
  private class CustomPayloadProvider: CallHandlerProvider {
    var serviceName: Substring = "CustomPayload"

    /// Unary: responds with the request's message reversed.
    fileprivate func reverseString(
      request: StringPayload,
      context: StatusOnlyCallContext
    ) -> EventLoopFuture<StringPayload> {
      let reversedText = String(request.message.reversed())
      return context.eventLoop.makeSucceededFuture(StringPayload(message: reversedText))
    }

    /// Client streaming: collects every request message and, once the client ends the stream,
    /// responds with the collected messages in reverse order, joined by single spaces.
    fileprivate func reverseThenJoin(
      context: UnaryResponseCallContext<StringPayload>
    ) -> EventLoopFuture<(StreamEvent<StringPayload>) -> Void> {
      var collected: [String] = []

      let observer: (StreamEvent<StringPayload>) -> Void = { streamEvent in
        switch streamEvent {
        case .message(let incoming):
          collected.append(incoming.message)
        case .end:
          let joinedText = collected.reversed().joined(separator: " ")
          context.responsePromise.succeed(StringPayload(message: joinedText))
        }
      }

      return context.eventLoop.makeSucceededFuture(observer)
    }

    /// Server streaming: sends one response per character of the reversed request message, and
    /// completes with `.ok` once every send has succeeded.
    fileprivate func reverseThenSplit(
      request: StringPayload,
      context: StreamingResponseCallContext<StringPayload>
    ) -> EventLoopFuture<GRPCStatus> {
      var pendingSends: [EventLoopFuture<Void>] = []
      for character in request.message.reversed() {
        pendingSends.append(context.sendResponse(StringPayload(message: String(character))))
      }
      return EventLoopFuture.andAllSucceed(pendingSends, on: context.eventLoop).map { GRPCStatus.ok }
    }

    /// Bidirectional streaming: answers each `CustomPayload` with a new one whose `message` is
    /// reversed and whose `number` is incremented by one; finishes with `.ok` when the client ends.
    fileprivate func addOneAndReverseMessage(
      context: StreamingResponseCallContext<CustomPayload>
    ) -> EventLoopFuture<(StreamEvent<CustomPayload>) -> Void> {
      let observer: (StreamEvent<CustomPayload>) -> Void = { streamEvent in
        switch streamEvent {
        case .message(let incoming):
          let reply = CustomPayload(
            message: String(incoming.message.reversed()),
            number: incoming.number + 1
          )
          // The send result is deliberately not observed.
          _ = context.sendResponse(reply)
        case .end:
          context.statusPromise.succeed(.ok)
        }
      }

      return context.eventLoop.makeSucceededFuture(observer)
    }

    /// Maps a method name to its server handler, or `nil` if the method is unknown.
    func handle(method name: Substring, context: CallHandlerContext) -> GRPCServerHandlerProtocol? {
      switch name {
      case "Reverse":
        return UnaryServerHandler(
          context: context,
          requestDeserializer: GRPCPayloadDeserializer<StringPayload>(),
          responseSerializer: GRPCPayloadSerializer<StringPayload>(),
          interceptors: [],
          userFunction: self.reverseString(request:context:)
        )

      case "ReverseThenJoin":
        return ClientStreamingServerHandler(
          context: context,
          requestDeserializer: GRPCPayloadDeserializer<StringPayload>(),
          responseSerializer: GRPCPayloadSerializer<StringPayload>(),
          interceptors: [],
          observerFactory: self.reverseThenJoin(context:)
        )

      case "ReverseThenSplit":
        return ServerStreamingServerHandler(
          context: context,
          requestDeserializer: GRPCPayloadDeserializer<StringPayload>(),
          responseSerializer: GRPCPayloadSerializer<StringPayload>(),
          interceptors: [],
          userFunction: self.reverseThenSplit(request:context:)
        )

      case "AddOneAndReverseMessage":
        return BidirectionalStreamingServerHandler(
          context: context,
          requestDeserializer: GRPCPayloadDeserializer<CustomPayload>(),
          responseSerializer: GRPCPayloadSerializer<CustomPayload>(),
          interceptors: [],
          observerFactory: self.addOneAndReverseMessage(context:)
        )

      default:
        return nil
      }
    }
  }
  /// A payload which performs no deserialization: it keeps the serialized bytes it is given so
  /// the caller can inspect or decode them later.
  private struct IdentityPayload: GRPCPayload {
    /// The serialized bytes exactly as handed over by gRPC.
    var buffer: ByteBuffer

    init(serializedByteBuffer: inout ByteBuffer) throws {
      // Store a copy; nothing is read from the input buffer, so its reader index is unchanged.
      self.buffer = serializedByteBuffer
    }

    /// Identity serialization: writes back exactly the bytes held in `self.buffer`.
    func serialize(into buffer: inout ByteBuffer) throws {
      // `writeBuffer` consumes the readable bytes of its argument, so write from a copy to
      // leave `self.buffer` untouched.
      var copy = self.buffer
      buffer.writeBuffer(&copy)
    }
  }
  /// A toy custom payload which holds a `String` and an `Int64`.
  ///
  /// The payload is serialized as:
  /// - the length of the message in UTF-8 bytes, encoded as a `UInt32`,
  /// - the UTF-8 encoded bytes of the message, and
  /// - the `Int64` bytes of the number.
  ///
  /// Both integers use `ByteBuffer`'s default byte order for `writeInteger` / `readInteger`.
  private struct CustomPayload: GRPCPayload, Equatable {
    var message: String
    var number: Int64

    init(message: String, number: Int64) {
      self.message = message
      self.number = number
    }

    /// Decodes a payload in the format written by `serialize(into:)`.
    ///
    /// - Throws: `GRPCError.DeserializationFailure` if the buffer is too short for the length
    ///   prefix, the declared number of message bytes, or the trailing `Int64`.
    init(serializedByteBuffer: inout ByteBuffer) throws {
      guard let messageLength = serializedByteBuffer.readInteger(as: UInt32.self),
            let message = serializedByteBuffer.readString(length: Int(messageLength)),
            let number = serializedByteBuffer.readInteger(as: Int64.self)
      else {
        throw GRPCError.DeserializationFailure()
      }

      self.message = message
      self.number = number
    }

    func serialize(into buffer: inout ByteBuffer) throws {
      // The length prefix must be a byte count because the decoder reads it back with
      // `readString(length:)`. `String.count` counts Characters, which under-reports the
      // length of any non-ASCII message and corrupts the decoding of the fields that follow.
      buffer.writeInteger(UInt32(self.message.utf8.count))
      buffer.writeString(self.message)
      buffer.writeInteger(self.number)
    }
  }
  /// A payload whose serialized form is just the bytes written by `writeString(message)`, with
  /// no length prefix or other framing.
  private struct StringPayload: GRPCPayload {
    var message: String

    init(message: String) {
      self.message = message
    }

    /// Decodes every readable byte of the buffer as the message.
    ///
    /// - Throws: `GRPCError.DeserializationFailure` if the bytes cannot be read as a string.
    init(serializedByteBuffer: inout ByteBuffer) throws {
      guard let message = serializedByteBuffer.readString(length: serializedByteBuffer.readableBytes)
      else {
        throw GRPCError.DeserializationFailure()
      }
      self.message = message
    }

    func serialize(into buffer: inout ByteBuffer) throws {
      buffer.writeString(self.message)
    }
  }
  282. }