Przeglądaj źródła

Add tests for custom payloads (#719)

Motivation:

We recently added support for arbitrary payloads which conform to
`GRPCPayload`, however, we didn't have any tests demonstrating this.

Modifications:

- Add a test for custom payloads types
- Add a test for payload which skips deserialization
- Update Protobuf to 1.8.0 and deserialize from ContiguousBytes

Result:

Better tests.
George Barnett 6 lat temu
rodzic
commit
4602c5c6ea

+ 2 - 2
Package.resolved

@@ -51,8 +51,8 @@
         "repositoryURL": "https://github.com/apple/swift-protobuf.git",
         "state": {
           "branch": null,
-          "revision": "da75a93ac017534e0028e83c0e4fc4610d2acf48",
-          "version": "1.7.0"
+          "revision": "7790acf0a81d08429cb20375bf42a8c7d279c5fe",
+          "version": "1.8.0"
         }
       }
     ]

+ 1 - 1
Package.swift

@@ -39,7 +39,7 @@ let package = Package(
     .package(url: "https://github.com/apple/swift-nio-transport-services.git", from: "1.3.0"),
 
     // Official SwiftProtobuf library, for [de]serializing data to send on the wire.
-    .package(url: "https://github.com/apple/swift-protobuf.git", from: "1.7.0"),
+    .package(url: "https://github.com/apple/swift-protobuf.git", from: "1.8.0"),
 
     // Logging API.
     .package(url: "https://github.com/apple/swift-log", from: "1.0.0"),

+ 11 - 11
Sources/GRPC/GRPCPayload.swift

@@ -15,19 +15,19 @@
  */
 import NIO
 
-/// Data passed through the library is required to conform to this GRPCPayload protocol
+/// A data type which may be serialized into and out from a `ByteBuffer` in order to be sent between
+/// gRPC peers.
 public protocol GRPCPayload {
-  
-  /// Initializes a new payload object from a given `NIO.ByteBuffer`
+  /// Initializes a new payload by deserializing the bytes from the given `ByteBuffer`.
   ///
-  /// - Parameter serializedByteBuffer: A buffer containing the serialised bytes of this payload.
-  /// - Throws: Should throw an error if the data wasn't serialized properly
-  init(serializedByteBuffer: inout NIO.ByteBuffer) throws
+  /// - Parameter serializedByteBuffer: A buffer containing the serialized bytes of this payload.
+  /// - Throws: If the payload could not be deserialized from the buffer.
+  init(serializedByteBuffer: inout ByteBuffer) throws
 
-  /// Serializes the payload into a `ByteBuffer`.
+  /// Serializes the payload into the given `ByteBuffer`.
   ///
-  /// - Parameters:
-  ///   - buffer: The buffer to write the payload into.
-  /// - Note: Implementers must *NOT* clear or read bytes from `buffer`.
-  func serialize(into buffer: inout NIO.ByteBuffer) throws
+  /// - Parameter buffer: The buffer to write the serialized payload into.
+  /// - Throws: If the payload could not be serialized.
+  /// - Important: Implementers must *NOT* clear or read bytes from `buffer`.
+  func serialize(into buffer: inout ByteBuffer) throws
 }

+ 2 - 3
Sources/GRPC/GRPCProtobufPayload.swift

@@ -16,13 +16,12 @@
 import NIO
 import SwiftProtobuf
 
-/// GRPCProtobufPayload which allows Protobuf Messages to be passed into the library
+/// Provides default implementations of `GRPCPayload` for `SwiftProtobuf.Message`s.
 public protocol GRPCProtobufPayload: GRPCPayload, Message {}
 
 public extension GRPCProtobufPayload {
-  
   init(serializedByteBuffer: inout NIO.ByteBuffer) throws {
-    try self.init(serializedData: serializedByteBuffer.readData(length: serializedByteBuffer.readableBytes)!)
+    try self.init(contiguousBytes: serializedByteBuffer.readableBytesView)
   }
 
   func serialize(into buffer: inout NIO.ByteBuffer) throws {

+ 207 - 0
Tests/GRPCTests/GRPCCustomPayloadTests.swift

@@ -0,0 +1,207 @@
+/*
+ * Copyright 2020, gRPC Authors All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+import GRPC
+import NIO
+import XCTest
+
+// These tests demonstrate how to use gRPC to create a service provider using your own payload type,
+// or alternatively, how to avoid deserialization and just extract the raw bytes from a payload.
+class GRPCCustomPayloadTests: GRPCTestCase {
+  var group: EventLoopGroup!
+  var server: Server!
+  var client: AnyServiceClient!
+
+  override func setUp() {
+    super.setUp()
+    self.group = MultiThreadedEventLoopGroup(numberOfThreads: 1)
+
+    let serverConfig: Server.Configuration = .init(
+      target: .hostAndPort("localhost", 0),
+      eventLoopGroup: self.group,
+      serviceProviders: [CustomPayloadProvider()]
+    )
+
+    self.server = try! Server.start(configuration: serverConfig).wait()
+
+    let clientConfig: ClientConnection.Configuration = .init(
+      target: .hostAndPort("localhost", server.channel.localAddress!.port!),
+      eventLoopGroup: self.group
+    )
+
+    self.client = AnyServiceClient(connection: .init(configuration: clientConfig))
+  }
+
+  override func tearDown() {
+    XCTAssertNoThrow(try self.server.close().wait())
+    XCTAssertNoThrow(try self.client.connection.close().wait())
+    XCTAssertNoThrow(try self.group.syncShutdownGracefully())
+  }
+
+  func testCustomPayload() throws {
+    // This test demonstrates how to call a manually created bidirectional RPC with custom payloads.
+    let statusExpectation = self.expectation(description: "status received")
+
+    var responses: [CustomPayload] = []
+
+    // Make a bidirectional stream using `CustomPayload` as the request and response type.
+    // The service defined below is called "CustomPayload", and the method we call on it
+    // is "AddOneAndReverseMessage"
+    let rpc: BidirectionalStreamingCall<CustomPayload, CustomPayload> = self.client.makeBidirectionalStreamingCall(
+      path: "/CustomPayload/AddOneAndReverseMessage",
+      handler: { responses.append($0) }
+    )
+
+    // Make and send some requests:
+    let requests: [CustomPayload] = [
+      CustomPayload(message: "one", number: .random(in: Int64.min..<Int64.max)),
+      CustomPayload(message: "two", number: .random(in: Int64.min..<Int64.max)),
+      CustomPayload(message: "three", number: .random(in: Int64.min..<Int64.max))
+    ]
+    rpc.sendMessages(requests, promise: nil)
+    rpc.sendEnd(promise: nil)
+
+    // Wait for the RPC to finish before comparing responses.
+    rpc.status.map { $0.code }.assertEqual(.ok, fulfill: statusExpectation)
+    self.wait(for: [statusExpectation], timeout: 1.0)
+
+    // Are the responses as expected?
+    let expected = requests.map { request in
+      CustomPayload(message: String(request.message.reversed()), number: request.number + 1)
+    }
+    XCTAssertEqual(responses, expected)
+  }
+
+  func testNoDeserializationOnTheClient() throws {
+    // This test demonstrates how to skip the deserialization step on the client. It isn't necessary
+    // to use a custom service provider to do this, although we do here.
+    let statusExpectation = self.expectation(description: "status received")
+
+    var responses: [IdentityPayload] = []
+    // Here we use `IdentityPayload` for our response type: we define it below such that it does
+    // not deserialize the bytes provided to it by gRPC.
+    let rpc: BidirectionalStreamingCall<CustomPayload, IdentityPayload> = self.client.makeBidirectionalStreamingCall(
+      path: "/CustomPayload/AddOneAndReverseMessage",
+      handler: { responses.append($0) }
+    )
+
+    let request = CustomPayload(message: "message", number: 42)
+    rpc.sendMessage(request, promise: nil)
+    rpc.sendEnd(promise: nil)
+
+    // Wait for the RPC to finish before comparing responses.
+    rpc.status.map { $0.code }.assertEqual(.ok, fulfill: statusExpectation)
+    self.wait(for: [statusExpectation], timeout: 1.0)
+
+    guard var response = responses.first?.buffer else {
+      XCTFail("RPC completed without a response")
+      return
+    }
+
+    // We just took the raw bytes from the payload: we can still decode it because we know the
+    // server returned a serialized `CustomPayload`.
+    let actual = try CustomPayload(serializedByteBuffer: &response)
+    XCTAssertEqual(actual.message, "egassem")
+    XCTAssertEqual(actual.number, 43)
+  }
+}
+
+// MARK: Custom Payload Service
+
+fileprivate class CustomPayloadProvider: CallHandlerProvider {
+  var serviceName: String = "CustomPayload"
+
+  // Bidirectional RPC which returns a new `CustomPayload` for each `CustomPayload` received.
+  // The returned payloads have their `message` reversed and their `number` incremented by one.
+  fileprivate func addOneAndReverseMessage(
+    context: StreamingResponseCallContext<CustomPayload>
+  ) -> EventLoopFuture<(StreamEvent<CustomPayload>) -> Void> {
+    return context.eventLoop.makeSucceededFuture({ event in
+      switch event {
+      case .message(let payload):
+        let response = CustomPayload(
+          message: String(payload.message.reversed()),
+          number: payload.number + 1
+        )
+        _ = context.sendResponse(response)
+
+      case .end:
+        context.statusPromise.succeed(.ok)
+      }
+    })
+  }
+
+  func handleMethod(_ methodName: String, callHandlerContext: CallHandlerContext) -> GRPCCallHandler? {
+    switch methodName {
+    case "AddOneAndReverseMessage":
+      return BidirectionalStreamingCallHandler<CustomPayload, CustomPayload>(callHandlerContext: callHandlerContext) { context in
+        return self.addOneAndReverseMessage(context: context)
+      }
+
+    default:
+      return nil
+    }
+  }
+}
+
+fileprivate struct IdentityPayload: GRPCPayload {
+  var buffer: ByteBuffer
+
+  init(serializedByteBuffer: inout ByteBuffer) throws {
+    self.buffer = serializedByteBuffer
+  }
+
+  func serialize(into buffer: inout ByteBuffer) throws {
+    // This will never be called, however, it could be implemented as a direct copy of the bytes
+    // we hold, e.g.:
+    //
+    //   var copy = self.buffer
+    //   buffer.writeBuffer(&copy)
+    fatalError("Unimplemented")
+  }
+}
+
+/// A toy custom payload which holds a `String` and an `Int64`.
+///
+/// The payload is serialized as:
+/// - the `UInt32` encoded length of the message,
+/// - the UTF-8 encoded bytes of the message, and
+/// - the `Int64` bytes of the number.
+fileprivate struct CustomPayload: GRPCPayload, Equatable {
+  var message: String
+  var number: Int64
+
+  init(message: String, number: Int64) {
+    self.message = message
+    self.number = number
+  }
+
+  init(serializedByteBuffer: inout ByteBuffer) throws {
+    guard let messageLength = serializedByteBuffer.readInteger(as: UInt32.self),
+      let message = serializedByteBuffer.readString(length: Int(messageLength)),
+      let number = serializedByteBuffer.readInteger(as: Int64.self) else {
+        throw GRPCError.DeserializationFailure()
+    }
+
+    self.message = message
+    self.number = number
+  }
+
+  func serialize(into buffer: inout ByteBuffer) throws {
+    buffer.writeInteger(UInt32(self.message.count))
+    buffer.writeString(self.message)
+    buffer.writeInteger(self.number)
+  }
+}

+ 11 - 0
Tests/GRPCTests/XCTestManifests.swift

@@ -316,6 +316,16 @@ extension GRPCClientStateMachineTests {
     ]
 }
 
+extension GRPCCustomPayloadTests {
+    // DO NOT MODIFY: This is autogenerated, use:
+    //   `swift test --generate-linuxmain`
+    // to regenerate.
+    static let __allTests__GRPCCustomPayloadTests = [
+        ("testCustomPayload", testCustomPayload),
+        ("testNoDeserializationOnTheClient", testNoDeserializationOnTheClient),
+    ]
+}
+
 extension GRPCInsecureInteroperabilityTests {
     // DO NOT MODIFY: This is autogenerated, use:
     //   `swift test --generate-linuxmain`
@@ -625,6 +635,7 @@ public func __allTests() -> [XCTestCaseEntry] {
         testCase(FunctionalTestsMutualAuthentication.__allTests__FunctionalTestsMutualAuthentication),
         testCase(FunctionalTestsMutualAuthenticationNIOTS.__allTests__FunctionalTestsMutualAuthenticationNIOTS),
         testCase(GRPCClientStateMachineTests.__allTests__GRPCClientStateMachineTests),
+        testCase(GRPCCustomPayloadTests.__allTests__GRPCCustomPayloadTests),
         testCase(GRPCInsecureInteroperabilityTests.__allTests__GRPCInsecureInteroperabilityTests),
         testCase(GRPCSecureInteroperabilityTests.__allTests__GRPCSecureInteroperabilityTests),
         testCase(GRPCServerRequestRoutingHandlerTests.__allTests__GRPCServerRequestRoutingHandlerTests),