Browse Source

Refactor channel connectivity to avoid multiple spin loops (#380)

The current Swift implementation of the gRPC channel's connectivity observer spins up a new ConnectivityObserver instance which then starts a new "spin loop" thread and continually runs, observing changes to the underlying gRPC Core channel's connectivity and piping those back through a callback closure. This means that there's a new spin loop spun up for each observer for each channel.

We can avoid having to spin up multiple spin loops for each observer (keeping only 0 or 1 per channel) by allowing a single ConnectivityObserver instance to pipe changes back to multiple callbacks.
Michael Rebello 6 years ago
parent
commit
10aff09790

+ 30 - 90
Sources/SwiftGRPC/Core/Channel.swift

@@ -14,18 +14,19 @@
  * limitations under the License.
  */
 #if SWIFT_PACKAGE
-  import CgRPC
-  import Dispatch
+import CgRPC
 #endif
 import Foundation
 
 /// A gRPC Channel
 public class Channel {
+  private let mutex = Mutex()
   /// Pointer to underlying C representation
   private let underlyingChannel: UnsafeMutableRawPointer
-
   /// Completion queue for channel call operations
   private let completionQueue: CompletionQueue
+  /// Observer for connectivity state changes. Created lazily if needed
+  private var connectivityObserver: ConnectivityObserver?
 
   /// Timeout for new calls
   public var timeout: TimeInterval = 600.0
@@ -33,9 +34,6 @@ public class Channel {
   /// Default host to use for new calls
   public var host: String
 
-  /// Connectivity state observers
-  private var connectivityObservers: [ConnectivityObserver] = []
-
   /// Initializes a gRPC channel
   ///
   /// - Parameter address: the address of the server to be called
@@ -47,12 +45,12 @@ public class Channel {
     let argumentWrappers = arguments.map { $0.toCArg() }
 
     underlyingChannel = withExtendedLifetime(argumentWrappers) {
-        var argumentValues = argumentWrappers.map { $0.wrapped }
-        if secure {
-          return cgrpc_channel_create_secure(address, kRootCertificates, nil, nil, &argumentValues, Int32(arguments.count))
-        } else {
-          return cgrpc_channel_create(address, &argumentValues, Int32(arguments.count))
-        }
+      var argumentValues = argumentWrappers.map { $0.wrapped }
+      if secure {
+        return cgrpc_channel_create_secure(address, kRootCertificates, nil, nil, &argumentValues, Int32(arguments.count))
+      } else {
+        return cgrpc_channel_create(address, &argumentValues, Int32(arguments.count))
+      }
     }
     completionQueue = CompletionQueue(underlyingCompletionQueue: cgrpc_channel_completion_queue(underlyingChannel), name: "Client")
     completionQueue.run() // start a loop that watches the channel's completion queue
@@ -66,10 +64,10 @@ public class Channel {
     gRPC.initialize()
     host = googleAddress
     let argumentWrappers = arguments.map { $0.toCArg() }
-    
+
     underlyingChannel = withExtendedLifetime(argumentWrappers) {
-        var argumentValues = argumentWrappers.map { $0.wrapped }
-        return cgrpc_channel_create_google(googleAddress, &argumentValues, Int32(arguments.count))
+      var argumentValues = argumentWrappers.map { $0.wrapped }
+      return cgrpc_channel_create_google(googleAddress, &argumentValues, Int32(arguments.count))
     }
 
     completionQueue = CompletionQueue(underlyingCompletionQueue: cgrpc_channel_completion_queue(underlyingChannel), name: "Client")
@@ -89,17 +87,19 @@ public class Channel {
     let argumentWrappers = arguments.map { $0.toCArg() }
 
     underlyingChannel = withExtendedLifetime(argumentWrappers) {
-        var argumentValues = argumentWrappers.map { $0.wrapped }
-        return cgrpc_channel_create_secure(address, certificates, clientCertificates, clientKey, &argumentValues, Int32(arguments.count))
+      var argumentValues = argumentWrappers.map { $0.wrapped }
+      return cgrpc_channel_create_secure(address, certificates, clientCertificates, clientKey, &argumentValues, Int32(arguments.count))
     }
     completionQueue = CompletionQueue(underlyingCompletionQueue: cgrpc_channel_completion_queue(underlyingChannel), name: "Client")
     completionQueue.run() // start a loop that watches the channel's completion queue
   }
 
   deinit {
-    connectivityObservers.forEach { $0.shutdown() }
-    cgrpc_channel_destroy(underlyingChannel)
-    completionQueue.shutdown()
+    self.mutex.synchronize {
+      self.connectivityObserver?.shutdown()
+    }
+    cgrpc_channel_destroy(self.underlyingChannel)
+    self.completionQueue.shutdown()
   }
 
   /// Constructs a Call object to make a gRPC API call
@@ -109,7 +109,7 @@ public class Channel {
   /// - Parameter timeout: a timeout value in seconds
   /// - Returns: a Call object that can be used to perform the request
   public func makeCall(_ method: String, host: String = "", timeout: TimeInterval? = nil) -> Call {
-    let host = (host == "") ? self.host : host
+    let host = host.isEmpty ? self.host : host
     let timeout = timeout ?? self.timeout
     let underlyingCall = cgrpc_channel_create_call(underlyingChannel, method, host, timeout)!
     return Call(underlyingCall: underlyingCall, owned: true, completionQueue: completionQueue)
@@ -126,77 +126,17 @@ public class Channel {
   /// Subscribe to connectivity state changes
   ///
   /// - Parameter callback: block executed every time a new connectivity state is detected
-  public func subscribe(callback: @escaping (ConnectivityState) -> Void) {
-    connectivityObservers.append(ConnectivityObserver(underlyingChannel: underlyingChannel, currentState: connectivityState(), callback: callback))
-  }
-}
-
-private extension Channel {
-  final class ConnectivityObserver {
-    private let completionQueue: CompletionQueue
-    private let underlyingChannel: UnsafeMutableRawPointer
-    private let underlyingCompletionQueue: UnsafeMutableRawPointer
-    private let callback: (ConnectivityState) -> Void
-    private var lastState: ConnectivityState
-    private var hasBeenShutdown = false
-    private let stateMutex: Mutex = Mutex()
-
-    init(underlyingChannel: UnsafeMutableRawPointer, currentState: ConnectivityState, callback: @escaping (ConnectivityState) -> ()) {
-      self.underlyingChannel = underlyingChannel
-      self.underlyingCompletionQueue = cgrpc_completion_queue_create_for_next()
-      self.completionQueue = CompletionQueue(underlyingCompletionQueue: self.underlyingCompletionQueue, name: "Connectivity State")
-      self.callback = callback
-      self.lastState = currentState
-      run()
-    }
-
-    deinit {
-      shutdown()
-    }
-
-    private func run() {
-      let spinloopThreadQueue = DispatchQueue(label: "SwiftGRPC.ConnectivityObserver.run.spinloopThread")
-
-      spinloopThreadQueue.async {
-        while true  {
-          guard (self.stateMutex.synchronize{ !self.hasBeenShutdown }) else {
-            return
-          }
-            
-          guard let underlyingState = self.lastState.underlyingState else { return }
-
-          let deadline: TimeInterval = 0.2
-          cgrpc_channel_watch_connectivity_state(self.underlyingChannel, self.underlyingCompletionQueue, underlyingState, deadline, nil)
-          let event = self.completionQueue.wait(timeout: deadline)
-          
-          guard (self.stateMutex.synchronize{ !self.hasBeenShutdown }) else {
-            return
-          }
-
-          switch event.type {
-          case .complete:
-            let newState = ConnectivityState(cgrpc_channel_check_connectivity_state(self.underlyingChannel, 0))
-
-            if newState != self.lastState {
-              self.callback(newState)
-            }
-            self.lastState = newState
-
-          case .queueShutdown:
-            return
-
-          default:
-            continue
-          }
-        }
+  public func addConnectivityObserver(callback: @escaping (ConnectivityState) -> Void) {
+    self.mutex.synchronize {
+      let observer: ConnectivityObserver
+      if let existingObserver = self.connectivityObserver {
+        observer = existingObserver
+      } else {
+        observer = ConnectivityObserver(underlyingChannel: self.underlyingChannel)
+        self.connectivityObserver = observer
       }
-    }
 
-    func shutdown() {
-      stateMutex.synchronize {
-        hasBeenShutdown = true
-      }
-      completionQueue.shutdown()
+      observer.addConnectivityObserver(callback: callback)
     }
   }
 }

+ 96 - 0
Sources/SwiftGRPC/Core/ChannelConnectivityObserver.swift

@@ -0,0 +1,96 @@
+/*
+ * Copyright 2016, gRPC Authors All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#if SWIFT_PACKAGE
+import CgRPC
+import Dispatch
+#endif
+import Foundation
+
+extension Channel {
+  /// Provides an interface for observing the connectivity of a given channel.
+  final class ConnectivityObserver {
+    private let mutex = Mutex()
+    private let completionQueue: CompletionQueue
+    private let underlyingChannel: UnsafeMutableRawPointer
+    private let underlyingCompletionQueue: UnsafeMutableRawPointer
+    private var callbacks = [(ConnectivityState) -> Void]()
+    private var hasBeenShutdown = false
+
+    init(underlyingChannel: UnsafeMutableRawPointer) {
+      self.underlyingChannel = underlyingChannel
+      self.underlyingCompletionQueue = cgrpc_completion_queue_create_for_next()
+      self.completionQueue = CompletionQueue(underlyingCompletionQueue: self.underlyingCompletionQueue,
+                                             name: "Connectivity State")
+      self.run()
+    }
+
+    deinit {
+      self.shutdown()
+    }
+
+    func addConnectivityObserver(callback: @escaping (ConnectivityState) -> Void) {
+      self.mutex.synchronize {
+        self.callbacks.append(callback)
+      }
+    }
+
+    func shutdown() {
+      self.mutex.synchronize {
+        guard !self.hasBeenShutdown else { return }
+
+        self.hasBeenShutdown = true
+        self.completionQueue.shutdown()
+      }
+    }
+
+    // MARK: - Private
+
+    private func run() {
+      let spinloopThreadQueue = DispatchQueue(label: "SwiftGRPC.ConnectivityObserver.run.spinloopThread")
+      var lastState = ConnectivityState(cgrpc_channel_check_connectivity_state(self.underlyingChannel, 0))
+      spinloopThreadQueue.async {
+        while (self.mutex.synchronize { !self.hasBeenShutdown }) {
+          guard let underlyingState = lastState.underlyingState else { return }
+
+          let deadline: TimeInterval = 0.2
+          cgrpc_channel_watch_connectivity_state(self.underlyingChannel, self.underlyingCompletionQueue,
+                                                 underlyingState, deadline, nil)
+
+          let event = self.completionQueue.wait(timeout: deadline)
+          guard (self.mutex.synchronize { !self.hasBeenShutdown }) else {
+            return
+          }
+
+          switch event.type {
+          case .complete:
+            let newState = ConnectivityState(cgrpc_channel_check_connectivity_state(self.underlyingChannel, 0))
+            guard newState != lastState else { continue }
+
+            let callbacks = self.mutex.synchronize { Array(self.callbacks) }
+            lastState = newState
+            callbacks.forEach { callback in callback(newState) }
+
+          case .queueShutdown:
+            return
+
+          default:
+            continue
+          }
+        }
+      }
+    }
+  }
+}

+ 22 - 3
Tests/SwiftGRPCTests/ChannelConnectivityTests.swift

@@ -21,7 +21,8 @@ final class ChannelConnectivityTests: BasicEchoTestCase {
 
   static var allTests: [(String, (ChannelConnectivityTests) -> () throws -> Void)] {
     return [
-      ("testDanglingConnectivityObserversDontCrash", testDanglingConnectivityObserversDontCrash)
+      ("testDanglingConnectivityObserversDontCrash", testDanglingConnectivityObserversDontCrash),
+      ("testMultipleConnectivityObserversAreCalled", testMultipleConnectivityObserversAreCalled),
     ]
   }
 }
@@ -30,12 +31,12 @@ extension ChannelConnectivityTests {
   func testDanglingConnectivityObserversDontCrash() {
     let completionHandlerExpectation = expectation(description: "completion handler called")
 
-    client?.channel.subscribe { connectivityState in
+    client.channel.addConnectivityObserver { connectivityState in
       print("ConnectivityState: \(connectivityState)")
     }
 
     let request = Echo_EchoRequest(text: "foo bar baz foo bar baz")
-    _ = try! client!.expand(request) { callResult in
+    _ = try! client.expand(request) { callResult in
       print("callResult.statusCode: \(callResult.statusCode)")
       completionHandlerExpectation.fulfill()
     }
@@ -46,4 +47,22 @@ extension ChannelConnectivityTests {
 
     waitForExpectations(timeout: 0.5)
   }
+
+  func testMultipleConnectivityObserversAreCalled() {
+    // Linux doesn't yet support `assertForOverFulfill` or `expectedFulfillmentCount`, and since these are
+    // called multiple times, we can't use expectations. https://bugs.swift.org/browse/SR-6249
+    var firstObserverCalled = false
+    var secondObserverCalled = false
+    client.channel.addConnectivityObserver { _ in firstObserverCalled = true }
+    client.channel.addConnectivityObserver { _ in secondObserverCalled = true }
+
+    let completionHandlerExpectation = expectation(description: "completion handler called")
+    _ = try! client.expand(Echo_EchoRequest(text: "foo bar baz foo bar baz")) { _ in
+      completionHandlerExpectation.fulfill()
+    }
+
+    waitForExpectations(timeout: 0.5)
+    XCTAssertTrue(firstObserverCalled)
+    XCTAssertTrue(secondObserverCalled)
+  }
 }