浏览代码

Release stream callback, once the stream has finished (#1363)

Fabian Fett 3 年之前
父节点
当前提交
be02b34b53

+ 12 - 2
Sources/GRPC/ClientCalls/ResponseContainers.swift

@@ -98,7 +98,7 @@ internal class StreamingResponseParts<Response> {
   private let eventLoop: EventLoop
 
   /// A callback for response messages.
-  private let responseCallback: (Response) -> Void
+  private var responseCallback: Optional<(Response) -> Void>
 
   /// Lazy promises for the status, initial-, and trailing-metadata.
   private var initialMetadataPromise: LazyEventLoopPromise<HPACKHeaders>
@@ -139,9 +139,13 @@ internal class StreamingResponseParts<Response> {
       self.initialMetadataPromise.succeed(metadata)
 
     case let .message(response):
-      self.responseCallback(response)
+      self.responseCallback?(response)
 
     case let .end(status, trailers):
+      // Once the stream has finished, we must release the callback, to make sure don't
+      // break potential retain cycles (the callback may reference other object's that in
+      // turn reference `StreamingResponseParts`).
+      self.responseCallback = nil
       self.initialMetadataPromise.fail(status)
       self.trailingMetadataPromise.succeed(trailers)
       self.statusPromise.succeed(status)
@@ -149,6 +153,12 @@ internal class StreamingResponseParts<Response> {
   }
 
   internal func handleError(_ error: Error) {
+    self.eventLoop.assertInEventLoop()
+
+    // Once the stream has finished, we must release the callback, to make sure don't
+    // break potential retain cycles (the callback may reference other object's that in
+    // turn reference `StreamingResponseParts`).
+    self.responseCallback = nil
     let withoutContext = error.removingContext()
     let status = withoutContext.makeGRPCStatus()
     self.initialMetadataPromise.fail(withoutContext)

+ 13 - 4
Tests/GRPCTests/FakeChannelTests.swift

@@ -81,6 +81,10 @@ class FakeChannelTests: GRPCTestCase {
   }
 
   func testBidirectional() {
+    final class ResponseCollector {
+      private(set) var responses = [Response]()
+      func collect(_ response: Response) { self.responses.append(response) }
+    }
     var requests: [Request] = []
     let response = self.makeStreamingResponse { part in
       switch part {
@@ -91,10 +95,12 @@ class FakeChannelTests: GRPCTestCase {
       }
     }
 
-    var responses: [Response] = []
-    let call = self.makeBidirectionalStreamingCall {
-      responses.append($0)
+    var collector = ResponseCollector()
+    XCTAssertTrue(isKnownUniquelyReferenced(&collector))
+    let call = self.makeBidirectionalStreamingCall { [collector] in
+      collector.collect($0)
     }
+    XCTAssertFalse(isKnownUniquelyReferenced(&collector))
 
     XCTAssertNoThrow(try call.sendMessage(.with { $0.text = "1" }).wait())
     XCTAssertNoThrow(try call.sendMessage(.with { $0.text = "2" }).wait())
@@ -106,9 +112,12 @@ class FakeChannelTests: GRPCTestCase {
     XCTAssertNoThrow(try response.sendMessage(.with { $0.text = "4" }))
     XCTAssertNoThrow(try response.sendMessage(.with { $0.text = "5" }))
     XCTAssertNoThrow(try response.sendMessage(.with { $0.text = "6" }))
+    XCTAssertEqual(collector.responses.count, 3)
+    XCTAssertFalse(isKnownUniquelyReferenced(&collector))
     XCTAssertNoThrow(try response.sendEnd())
+    XCTAssertTrue(isKnownUniquelyReferenced(&collector))
 
-    XCTAssertEqual(responses, (4 ... 6).map { number in .with { $0.text = "\(number)" } })
+    XCTAssertEqual(collector.responses, (4 ... 6).map { number in .with { $0.text = "\(number)" } })
     XCTAssertTrue(try call.status.map { $0.isOk }.wait())
   }
 

+ 81 - 0
Tests/GRPCTests/StreamResponseHandlerRetainCycleTests.swift

@@ -0,0 +1,81 @@
+/*
+ * Copyright 2022, gRPC Authors All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+import EchoImplementation
+import EchoModel
+import GRPC
+import NIOConcurrencyHelpers
+import NIOCore
+import NIOPosix
+import XCTest
+
+final class StreamResponseHandlerRetainCycleTests: GRPCTestCase {
+  var group: EventLoopGroup!
+  var server: Server!
+  var client: ClientConnection!
+
+  var echo: Echo_EchoClient!
+
+  override func setUp() {
+    super.setUp()
+    self.group = MultiThreadedEventLoopGroup(numberOfThreads: 1)
+
+    self.server = try! Server.insecure(group: self.group)
+      .withServiceProviders([EchoProvider()])
+      .withLogger(self.serverLogger)
+      .bind(host: "localhost", port: 0)
+      .wait()
+
+    self.client = ClientConnection.insecure(group: self.group)
+      .withBackgroundActivityLogger(self.clientLogger)
+      .connect(host: "localhost", port: self.server.channel.localAddress!.port!)
+
+    self.echo = Echo_EchoClient(
+      channel: self.client,
+      defaultCallOptions: CallOptions(logger: self.clientLogger)
+    )
+  }
+
+  override func tearDown() {
+    XCTAssertNoThrow(try self.client.close().wait())
+    XCTAssertNoThrow(try self.server.close().wait())
+    XCTAssertNoThrow(try self.group.syncShutdownGracefully())
+    super.tearDown()
+  }
+
+  func testHandlerClosureIsReleasedOnceStreamEnds() {
+    final class Counter {
+      private let atomic = NIOAtomic.makeAtomic(value: 0)
+      func increment() { self.atomic.add(1) }
+      var value: Int {
+        self.atomic.load()
+      }
+    }
+
+    var counter = Counter()
+    XCTAssertTrue(isKnownUniquelyReferenced(&counter))
+    let get = self.echo.update { [capturedCounter = counter] _ in
+      capturedCounter.increment()
+    }
+    XCTAssertFalse(isKnownUniquelyReferenced(&counter))
+
+    get.sendMessage(.init(text: "hello world"), promise: nil)
+    XCTAssertFalse(isKnownUniquelyReferenced(&counter))
+    XCTAssertNoThrow(try get.sendEnd().wait())
+    XCTAssertNoThrow(try get.status.wait())
+    XCTAssertEqual(counter.value, 1)
+    XCTAssertTrue(isKnownUniquelyReferenced(&counter))
+  }
+}