Browse Source

Allow timeouts to be rounded so they may comply with the spec (#569)

Motivation:

The gRPC spec dictates that timeouts may be no longer than 8 digits
long. As such methods for creating `GRPCTimeout`s all throw if the given
timeout is not valid. In the vast majority of cases this is not actually
an issue and having to handle the exception can be a pain.

Modifications:

Add non-throwing factory methods to `GRPCTimeout` which round the
timeout so that it may be encoded over the wire.

Result:

Timeouts are easier to create.
George Barnett 6 years ago
parent
commit
113c24797e

+ 150 - 24
Sources/GRPC/GRPCTimeout.swift

@@ -16,7 +16,7 @@
 import Foundation
 import NIO
 
-public enum GRPCTimeoutError: String, Error {
+public enum GRPCTimeoutError: String, Error, Equatable {
   case negative = "GRPCTimeout must be non-negative"
   case tooManyDigits = "GRPCTimeout must be at most 8 digits"
 }
@@ -27,28 +27,89 @@ public enum GRPCTimeoutError: String, Error {
 public struct GRPCTimeout: CustomStringConvertible, Equatable {
   public static let `default`: GRPCTimeout = try! .minutes(1)
   /// Creates an infinite timeout. This is a sentinel value which must __not__ be sent to a gRPC service.
-  public static let infinite: GRPCTimeout = GRPCTimeout(nanoseconds: Int64.max, description: "infinite")
+  public static let infinite: GRPCTimeout = GRPCTimeout(nanoseconds: Int64.max, wireEncoding: "infinite")
 
-  /// A description of the timeout in the format described in the
-  /// [gRPC protocol](https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md).
-  public let description: String
+  /// The largest amount of any unit of time which may be represented by a gRPC timeout.
+  private static let maxAmount: Int64 = 99_999_999
+
+  /// The wire encoding of this timeout as described in the gRPC protocol.
+  /// See: https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md.
+  public let wireEncoding: String
   public let nanoseconds: Int64
 
-  private init(nanoseconds: Int64, description: String) {
-    self.nanoseconds = nanoseconds
-    self.description = description
+  public var description: String {
+    return wireEncoding
   }
 
-  private static func makeTimeout(_ amount: Int, _ unit: GRPCTimeoutUnit) throws -> GRPCTimeout {
-    // Timeouts must be positive and at most 8-digits.
-    if amount < 0 { throw GRPCTimeoutError.negative }
-    if amount >= 100_000_000  { throw GRPCTimeoutError.tooManyDigits }
+  private init(nanoseconds: Int64, wireEncoding: String) {
+    self.nanoseconds = nanoseconds
+    self.wireEncoding = wireEncoding
+  }
 
+  /// Creates a `GRPCTimeout`.
+  ///
+  /// - Precondition: The amount should be greater than or equal to zero and less than or equal
+  ///   to `GRPCTimeout.maxAmount`.
+  private init(amount: Int64, unit: GRPCTimeoutUnit) {
+    precondition(amount >= 0 && amount <= GRPCTimeout.maxAmount)
     // See "Timeout" in https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#requests
-    let description = "\(amount)\(unit.rawValue)"
-    let nanoseconds = Int64(amount) * unit.asNanoseconds
 
-    return GRPCTimeout(nanoseconds: nanoseconds, description: description)
+    // If we overflow at this point, which is certainly possible if `amount` is sufficiently large
+    // and `unit` is `.hours`, clamp the nanosecond timeout to `Int64.max`. It's about 292 years so
+    // it should be long enough for the user not to notice the difference should the rpc time out.
+    let (partial, overflow) = amount.multipliedReportingOverflow(by: unit.asNanoseconds)
+
+    self.init(
+      nanoseconds: overflow ? Int64.max : partial,
+      wireEncoding: "\(amount)\(unit.rawValue)"
+    )
+  }
+
+  /// Create a timeout by rounding up the timeout so that it may be represented in the gRPC
+  /// wire format.
+  private init(rounding amount: Int64, unit: GRPCTimeoutUnit) {
+    var roundedAmount = amount
+    var roundedUnit = unit
+
+    if roundedAmount <= 0 {
+      roundedAmount = 0
+    } else {
+      while roundedAmount > GRPCTimeout.maxAmount {
+        switch roundedUnit {
+        case .nanoseconds:
+          roundedAmount = roundedAmount.quotientRoundedUp(dividingBy: 1_000)
+          roundedUnit = .microseconds
+        case .microseconds:
+          roundedAmount = roundedAmount.quotientRoundedUp(dividingBy: 1_000)
+          roundedUnit = .milliseconds
+        case .milliseconds:
+          roundedAmount = roundedAmount.quotientRoundedUp(dividingBy: 1_000)
+          roundedUnit = .seconds
+        case .seconds:
+          roundedAmount = roundedAmount.quotientRoundedUp(dividingBy: 60)
+          roundedUnit = .minutes
+        case .minutes:
+          roundedAmount = roundedAmount.quotientRoundedUp(dividingBy: 60)
+          roundedUnit = .hours
+        case .hours:
+          roundedAmount = GRPCTimeout.maxAmount
+          roundedUnit = .hours
+        }
+      }
+    }
+
+    self.init(amount: roundedAmount, unit: roundedUnit)
+  }
+
+  private static func makeTimeout(_ amount: Int64, _ unit: GRPCTimeoutUnit) throws -> GRPCTimeout {
+    // Timeouts must be positive and at most 8-digits.
+    if amount < 0 {
+      throw GRPCTimeoutError.negative
+    }
+    if amount > GRPCTimeout.maxAmount {
+      throw GRPCTimeoutError.tooManyDigits
+    }
+    return .init(amount: amount, unit: unit)
   }
 
   /// Creates a new GRPCTimeout for the given amount of hours.
@@ -59,7 +120,16 @@ public struct GRPCTimeout: CustomStringConvertible, Equatable {
   /// - Returns: A `GRPCTimeout` representing the given number of hours.
   /// - Throws: `GRPCTimeoutError` if the amount was negative or more than 8 digits long.
   public static func hours(_ amount: Int) throws -> GRPCTimeout {
-    return try makeTimeout(amount, .hours)
+    return try makeTimeout(Int64(amount), .hours)
+  }
+
+  /// Creates a new GRPCTimeout for the given amount of hours.
+  ///
+  /// The timeout will be rounded up if it may not be represented in the wire format.
+  ///
+  /// - Parameter amount: The number of hours to represent.
+  public static func hours(rounding amount: Int) -> GRPCTimeout {
+    return .init(rounding: Int64(amount), unit: .hours)
   }
 
   /// Creates a new GRPCTimeout for the given amount of minutes.
@@ -70,7 +140,16 @@ public struct GRPCTimeout: CustomStringConvertible, Equatable {
   /// - Returns: A `GRPCTimeout` representing the given number of minutes.
   /// - Throws: `GRPCTimeoutError` if the amount was negative or more than 8 digits long.
   public static func minutes(_ amount: Int) throws -> GRPCTimeout {
-    return try makeTimeout(amount, .minutes)
+    return try makeTimeout(Int64(amount), .minutes)
+  }
+
+  /// Creates a new GRPCTimeout for the given amount of minutes.
+  ///
+  /// The timeout will be rounded up if it may not be represented in the wire format.
+  ///
+  /// - Parameter amount: The number of minutes to represent.
+  public static func minutes(rounding amount: Int) -> GRPCTimeout {
+    return .init(rounding: Int64(amount), unit: .minutes)
   }
 
   /// Creates a new GRPCTimeout for the given amount of seconds.
@@ -81,7 +160,16 @@ public struct GRPCTimeout: CustomStringConvertible, Equatable {
   /// - Returns: A `GRPCTimeout` representing the given number of seconds.
   /// - Throws: `GRPCTimeoutError` if the amount was negative or more than 8 digits long.
   public static func seconds(_ amount: Int) throws -> GRPCTimeout {
-    return try makeTimeout(amount, .seconds)
+    return try makeTimeout(Int64(amount), .seconds)
+  }
+
+  /// Creates a new GRPCTimeout for the given amount of seconds.
+  ///
+  /// The timeout will be rounded up if it may not be represented in the wire format.
+  ///
+  /// - Parameter amount: The number of seconds to represent.
+  public static func seconds(rounding amount: Int) -> GRPCTimeout {
+    return .init(rounding: Int64(amount), unit: .seconds)
   }
 
   /// Creates a new GRPCTimeout for the given amount of milliseconds.
@@ -92,7 +180,16 @@ public struct GRPCTimeout: CustomStringConvertible, Equatable {
   /// - Returns: A `GRPCTimeout` representing the given number of milliseconds.
   /// - Throws: `GRPCTimeoutError` if the amount was negative or more than 8 digits long.
   public static func milliseconds(_ amount: Int) throws -> GRPCTimeout {
-    return try makeTimeout(amount, .milliseconds)
+    return try makeTimeout(Int64(amount), .milliseconds)
+  }
+
+  /// Creates a new GRPCTimeout for the given amount of milliseconds.
+  ///
+  /// The timeout will be rounded up if it may not be represented in the wire format.
+  ///
+  /// - Parameter amount: The number of milliseconds to represent.
+  public static func milliseconds(rounding amount: Int) -> GRPCTimeout {
+    return .init(rounding: Int64(amount), unit: .milliseconds)
   }
 
   /// Creates a new GRPCTimeout for the given amount of microseconds.
@@ -103,7 +200,16 @@ public struct GRPCTimeout: CustomStringConvertible, Equatable {
   /// - Returns: A `GRPCTimeout` representing the given number of microseconds.
   /// - Throws: `GRPCTimeoutError` if the amount was negative or more than 8 digits long.
   public static func microseconds(_ amount: Int) throws -> GRPCTimeout {
-    return try makeTimeout(amount, .microseconds)
+    return try makeTimeout(Int64(amount), .microseconds)
+  }
+
+  /// Creates a new GRPCTimeout for the given amount of microseconds.
+  ///
+  /// The timeout will be rounded up if it may not be represented in the wire format.
+  ///
+  /// - Parameter amount: The number of microseconds to represent.
+  public static func microseconds(rounding amount: Int) -> GRPCTimeout {
+    return .init(rounding: Int64(amount), unit: .microseconds)
   }
 
   /// Creates a new GRPCTimeout for the given amount of nanoseconds.
@@ -114,18 +220,38 @@ public struct GRPCTimeout: CustomStringConvertible, Equatable {
   /// - Returns: A `GRPCTimeout` representing the given number of nanoseconds.
   /// - Throws: `GRPCTimeoutError` if the amount was negative or more than 8 digits long.
   public static func nanoseconds(_ amount: Int) throws -> GRPCTimeout {
-    return try makeTimeout(amount, .nanoseconds)
+    return try makeTimeout(Int64(amount), .nanoseconds)
+  }
+
+  /// Creates a new GRPCTimeout for the given amount of nanoseconds.
+  ///
+  /// The timeout will be rounded up if it may not be represented in the wire format.
+  ///
+  /// - Parameter amount: The number of nanoseconds to represent.
+  public static func nanoseconds(rounding amount: Int) -> GRPCTimeout {
+    return .init(rounding: Int64(amount), unit: .nanoseconds)
   }
 }
 
-extension GRPCTimeout {
+public extension GRPCTimeout {
   /// Returns a NIO `TimeAmount` representing the amount of time as this timeout.
-  public var asNIOTimeAmount: TimeAmount {
+  var asNIOTimeAmount: TimeAmount {
     return TimeAmount.nanoseconds(numericCast(nanoseconds))
   }
 }
 
-private enum GRPCTimeoutUnit: String {
+fileprivate extension Int64 {
+  /// Returns the quotient of this value when divided by `divisor` rounded up to the nearest
+  /// multiple of `divisor` if the remainder is non-zero.
+  ///
+  /// - Parameter divisor: The value to divide this value by.
+  func quotientRoundedUp(dividingBy divisor: Int64) -> Int64 {
+    let (quotient, remainder) = self.quotientAndRemainder(dividingBy: divisor)
+    return quotient + (remainder != 0 ? 1 : 0)
+  }
+}
+
+fileprivate enum GRPCTimeoutUnit: String {
   case hours = "H"
   case minutes = "M"
   case seconds = "S"

+ 15 - 0
Tests/GRPCTests/ClientTLSTests.swift

@@ -1,3 +1,18 @@
+/*
+ * Copyright 2019, gRPC Authors All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
 import Foundation
 import GRPC
 import GRPCSampleData

+ 120 - 0
Tests/GRPCTests/GRPCTimeoutTests.swift

@@ -0,0 +1,120 @@
+/*
+ * Copyright 2019, gRPC Authors All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+import Foundation
+import GRPC
+import XCTest
+
+class GRPCTimeoutTests: GRPCTestCase {
+  func testNegativeTimeoutThrows() throws {
+    XCTAssertThrowsError(try GRPCTimeout.seconds(-10)) { error in
+      XCTAssertEqual(error as? GRPCTimeoutError, GRPCTimeoutError.negative)
+    }
+  }
+
+  func testTooLargeTimeout() throws {
+    XCTAssertThrowsError(try GRPCTimeout.seconds(100_000_000)) { error in
+      XCTAssertEqual(error as? GRPCTimeoutError, GRPCTimeoutError.tooManyDigits)
+    }
+  }
+
+  func testRoundingNegativeTimeout() {
+    let timeout: GRPCTimeout = .seconds(rounding: -10)
+    XCTAssertEqual(String(describing: timeout), "0S")
+    XCTAssertEqual(timeout.nanoseconds, 0)
+  }
+
+  func testRoundingNanosecondsTimeout() throws {
+    let timeout: GRPCTimeout = .nanoseconds(rounding: 123_456_789)
+    XCTAssertEqual(timeout, try .microseconds(123457))
+
+    // 123_456_789 (nanoseconds) / 1_000
+    //   = 123_456.789
+    //   = 123_457 (microseconds, rounded up)
+    XCTAssertEqual(String(describing: timeout), "123457u")
+
+    // 123_457 (microseconds) * 1_000
+    //   = 123_457_000 (nanoseconds)
+    XCTAssertEqual(timeout.nanoseconds, 123_457_000)
+  }
+
+  func testRoundingMicrosecondsTimeout() throws {
+    let timeout: GRPCTimeout = .microseconds(rounding: 123_456_789)
+    XCTAssertEqual(timeout, try .milliseconds(123457))
+
+    // 123_456_789 (microseconds) / 1_000
+    //   = 123_456.789
+    //   = 123_457 (milliseconds, rounded up)
+    XCTAssertEqual(String(describing: timeout), "123457m")
+
+    // 123_457 (milliseconds) * 1_000 * 1_000
+    //   = 123_457_000_000 (nanoseconds)
+    XCTAssertEqual(timeout.nanoseconds, 123_457_000_000)
+  }
+
+  func testRoundingMillisecondsTimeout() throws {
+    let timeout: GRPCTimeout = .milliseconds(rounding: 123_456_789)
+    XCTAssertEqual(timeout, try .seconds(123457))
+
+    // 123_456_789 (milliseconds) / 1_000
+    //   = 123_456.789
+    //   = 123_457 (seconds, rounded up)
+    XCTAssertEqual(String(describing: timeout), "123457S")
+
+    // 123_457 (milliseconds) * 1_000 * 1_000 * 1_000
+    //   = 123_457_000_000_000 (nanoseconds)
+    XCTAssertEqual(timeout.nanoseconds, 123_457_000_000_000)
+  }
+
+  func testRoundingSecondsTimeout() throws {
+    let timeout: GRPCTimeout = .seconds(rounding: 123_456_789)
+    XCTAssertEqual(timeout, try .minutes(2057614))
+
+    // 123_456_789 (seconds) / 60
+    //   = 2_057_613.15
+    //   = 2_057_614 (minutes, rounded up)
+    XCTAssertEqual(String(describing: timeout), "2057614M")
+
+    // 2_057_614 (minutes) * 60 * 1_000 * 1_000 * 1_000
+    //   = 123_456_840_000_000_000 (nanoseconds)
+    XCTAssertEqual(timeout.nanoseconds, 123_456_840_000_000_000)
+  }
+
+  func testRoundingMinutesTimeout() throws {
+    let timeout: GRPCTimeout = .minutes(rounding: 123_456_789)
+    XCTAssertEqual(timeout, try .hours(2057614))
+
+    // 123_456_789 (minutes) / 60
+    //   = 2_057_613.15
+    //   = 2_057_614 (hours, rounded up)
+    XCTAssertEqual(String(describing: timeout), "2057614H")
+
+    // 123_457 (minutes) * 60 * 60 * 1_000 * 1_000 * 1_000
+    //   = 7_407_410_400_000_000_000 (nanoseconds)
+    XCTAssertEqual(timeout.nanoseconds, 7_407_410_400_000_000_000)
+  }
+
+  func testRoundingHoursTimeout() throws {
+    let timeout: GRPCTimeout = .hours(rounding: 123_456_789)
+    XCTAssertEqual(timeout, try .hours(99_999_999))
+
+    // Hours are the largest unit of time we have (as per the gRPC spec) so we can't round to a
+    // different unit. In this case we clamp to the largest value.
+    XCTAssertEqual(String(describing: timeout), "99999999H")
+    // Unfortunately the largest value representable by the specification is too long to represent
+    // in nanoseconds within 64 bits, again the value is clamped.
+    XCTAssertEqual(timeout.nanoseconds, Int64.max)
+  }
+}

+ 18 - 0
Tests/GRPCTests/XCTestManifests.swift

@@ -325,6 +325,23 @@ extension GRPCStatusMessageMarshallerTests {
     ]
 }
 
+extension GRPCTimeoutTests {
+    // DO NOT MODIFY: This is autogenerated, use:
+    //   `swift test --generate-linuxmain`
+    // to regenerate.
+    static let __allTests__GRPCTimeoutTests = [
+        ("testNegativeTimeoutThrows", testNegativeTimeoutThrows),
+        ("testRoundingHoursTimeout", testRoundingHoursTimeout),
+        ("testRoundingMicrosecondsTimeout", testRoundingMicrosecondsTimeout),
+        ("testRoundingMillisecondsTimeout", testRoundingMillisecondsTimeout),
+        ("testRoundingMinutesTimeout", testRoundingMinutesTimeout),
+        ("testRoundingNanosecondsTimeout", testRoundingNanosecondsTimeout),
+        ("testRoundingNegativeTimeout", testRoundingNegativeTimeout),
+        ("testRoundingSecondsTimeout", testRoundingSecondsTimeout),
+        ("testTooLargeTimeout", testTooLargeTimeout),
+    ]
+}
+
 extension GRPCTypeSizeTests {
     // DO NOT MODIFY: This is autogenerated, use:
     //   `swift test --generate-linuxmain`
@@ -491,6 +508,7 @@ public func __allTests() -> [XCTestCaseEntry] {
         testCase(GRPCSecureInteroperabilityTests.__allTests__GRPCSecureInteroperabilityTests),
         testCase(GRPCStatusCodeTests.__allTests__GRPCStatusCodeTests),
         testCase(GRPCStatusMessageMarshallerTests.__allTests__GRPCStatusMessageMarshallerTests),
+        testCase(GRPCTimeoutTests.__allTests__GRPCTimeoutTests),
         testCase(GRPCTypeSizeTests.__allTests__GRPCTypeSizeTests),
         testCase(HTTP1ToRawGRPCServerCodecTests.__allTests__HTTP1ToRawGRPCServerCodecTests),
         testCase(ImmediatelyFailingProviderTests.__allTests__ImmediatelyFailingProviderTests),