Browse Source

Fix some bugs with new handlers (#1102)

Motivation:

Manually wiring up the new handlers to run the rest of the test suite
highlighted a few rough edges.

Modifications:

- Make sure compression is correctly set: it must be enabled on the
  server, in the call context, and - if applicable - on the individaul
  message
- Call the error delegate in the right place
- Add a 'protocol violation' error

Result:

Fewer bugs.
George Barnett 5 years ago
parent
commit
40e4e25ea5

+ 1 - 0
Sources/GRPC/CallHandlers/BidirectionalStreamingCallHandler.swift

@@ -139,6 +139,7 @@ public class BidirectionalStreamingCallHandler<
         headers: headers,
         logger: self.logger,
         userInfoRef: self._userInfoRef,
+        compressionIsEnabled: self._callHandlerContext.encoding.isEnabled,
         sendResponse: self.sendResponse(_:metadata:promise:)
       )
       let observer = factory(context)

+ 38 - 27
Sources/GRPC/CallHandlers/BidirectionalStreamingServerHandler.swift

@@ -122,16 +122,12 @@ public final class BidirectionalStreamingServerHandler<
 
   @inlinable
   public func receiveError(_ error: Error) {
-    self._finish(error: error)
+    self.handleError(error)
+    self.finish()
   }
 
   @inlinable
   public func finish() {
-    self._finish(error: nil)
-  }
-
-  @inlinable
-  internal func _finish(error: Error?) {
     switch self.state {
     case .idle:
       self.interceptors = nil
@@ -139,8 +135,7 @@ public final class BidirectionalStreamingServerHandler<
 
     case let .creatingObserver(context),
          let .observing(_, context):
-      let error = error ?? GRPCStatus(code: .unavailable, message: nil)
-      context.statusPromise.fail(error)
+      context.statusPromise.fail(GRPCStatus(code: .unavailable, message: nil))
 
     case .completed:
       self.interceptors = nil
@@ -171,6 +166,7 @@ public final class BidirectionalStreamingServerHandler<
         headers: headers,
         logger: self.context.logger,
         userInfoRef: self.userInfoRef,
+        compressionIsEnabled: self.context.encoding.isEnabled,
         sendResponse: self.interceptResponse(_:metadata:promise:)
       )
 
@@ -187,7 +183,7 @@ public final class BidirectionalStreamingServerHandler<
       self.observerFactory(context).whenComplete(self.userFunctionResolvedWithResult(_:))
 
     case .creatingObserver, .observing:
-      self.handleError(GRPCError.InvalidState("Protocol violation: already received headers"))
+      self.handleError(GRPCError.ProtocolViolation("Multiple header blocks received on RPC"))
 
     case .completed:
       // We may receive headers from the interceptor pipeline if we have already finished (i.e. due
@@ -201,9 +197,7 @@ public final class BidirectionalStreamingServerHandler<
   internal func receiveInterceptedMessage(_ request: Request) {
     switch self.state {
     case .idle:
-      self.handleError(
-        GRPCError.InvalidState("Protocol violation: message received before headers")
-      )
+      self.handleError(GRPCError.ProtocolViolation("Message received before headers"))
     case .creatingObserver:
       self.requestBuffer.append(.message(request))
     case let .observing(observer, _):
@@ -219,9 +213,7 @@ public final class BidirectionalStreamingServerHandler<
   internal func receiveInterceptedEnd() {
     switch self.state {
     case .idle:
-      self.handleError(
-        GRPCError.InvalidState("Protocol violation: end of stream received before headers")
-      )
+      self.handleError(GRPCError.ProtocolViolation("End of stream received before headers"))
     case .creatingObserver:
       self.requestBuffer.append(.end)
     case let .observing(observer, _):
@@ -255,7 +247,7 @@ public final class BidirectionalStreamingServerHandler<
         }
 
       case let .failure(error):
-        self.handleError(error)
+        self.handleError(error, thrownFromHandler: true)
       }
 
     case .completed:
@@ -296,16 +288,12 @@ public final class BidirectionalStreamingServerHandler<
     case let .creatingObserver(context), let .observing(_, context):
       switch result {
       case let .success(status):
+        // We're sending end back, we're done.
+        self.state = .completed
         self.interceptors.send(.end(status, context.trailers), promise: nil)
 
       case let .failure(error):
-        let (status, trailers) = ServerErrorProcessor.processObserverError(
-          error,
-          headers: context.headers,
-          trailers: context.trailers,
-          delegate: self.context.errorDelegate
-        )
-        self.interceptors.send(.end(status, trailers), promise: nil)
+        self.handleError(error, thrownFromHandler: true)
       }
 
     case .completed:
@@ -314,9 +302,11 @@ public final class BidirectionalStreamingServerHandler<
   }
 
   @inlinable
-  internal func handleError(_ error: Error) {
+  internal func handleError(_ error: Error, thrownFromHandler isHandlerError: Bool = false) {
     switch self.state {
     case .idle:
+      assert(!isHandlerError)
+      self.state = .completed
       // We don't have a promise to fail. Just send back end.
       let (status, trailers) = ServerErrorProcessor.processLibraryError(
         error,
@@ -324,10 +314,31 @@ public final class BidirectionalStreamingServerHandler<
       )
       self.interceptors.send(.end(status, trailers), promise: nil)
 
-    case let .creatingObserver(context):
-      context.statusPromise.fail(error)
+    case let .creatingObserver(context),
+         let .observing(_, context):
+      // We don't have a promise to fail. Just send back end.
+      self.state = .completed
+
+      let status: GRPCStatus
+      let trailers: HPACKHeaders
 
-    case let .observing(_, context):
+      if isHandlerError {
+        (status, trailers) = ServerErrorProcessor.processObserverError(
+          error,
+          headers: context.headers,
+          trailers: context.trailers,
+          delegate: self.context.errorDelegate
+        )
+      } else {
+        (status, trailers) = ServerErrorProcessor.processLibraryError(
+          error,
+          delegate: self.context.errorDelegate
+        )
+      }
+
+      self.interceptors.send(.end(status, trailers), promise: nil)
+      // We're already in the 'completed' state so failing the promise will be a no-op in the
+      // callback to 'userFunctionStatusResolved' (but we also need to avoid leaking the promise.)
       context.statusPromise.fail(error)
 
     case .completed:

+ 44 - 26
Sources/GRPC/CallHandlers/ClientStreamingServerHandler.swift

@@ -123,16 +123,12 @@ public final class ClientStreamingServerHandler<
 
   @inlinable
   public func receiveError(_ error: Error) {
-    self._finish(error: error)
+    self.handleError(error)
+    self.finish()
   }
 
   @inlinable
   public func finish() {
-    self._finish(error: nil)
-  }
-
-  @inlinable
-  internal func _finish(error: Error?) {
     switch self.state {
     case .idle:
       self.interceptors = nil
@@ -140,8 +136,7 @@ public final class ClientStreamingServerHandler<
 
     case let .creatingObserver(context),
          let .observing(_, context):
-      let error = error ?? GRPCStatus(code: .unavailable, message: nil)
-      context.responsePromise.fail(error)
+      context.responsePromise.fail(GRPCStatus(code: .unavailable, message: nil))
 
     case .completed:
       self.interceptors = nil
@@ -178,7 +173,7 @@ public final class ClientStreamingServerHandler<
       self.state = .creatingObserver(context)
 
       // Register a callback on the response future.
-      context.responsePromise.futureResult.whenComplete(self.userFunctionCompleted(with:))
+      context.responsePromise.futureResult.whenComplete(self.userFunctionCompletedWithResult(_:))
 
       // Make an observer block and register a completion block.
       self.handlerFactory(context).whenComplete(self.userFunctionResolved(_:))
@@ -187,7 +182,7 @@ public final class ClientStreamingServerHandler<
       self.interceptors.send(.metadata([:]), promise: nil)
 
     case .creatingObserver, .observing:
-      self.handleError(GRPCError.InvalidState("Protocol violation: already received headers"))
+      self.handleError(GRPCError.ProtocolViolation("Multiple header blocks received"))
 
     case .completed:
       // We may receive headers from the interceptor pipeline if we have already finished (i.e. due
@@ -201,8 +196,7 @@ public final class ClientStreamingServerHandler<
   internal func receiveInterceptedMessage(_ request: Request) {
     switch self.state {
     case .idle:
-      self
-        .handleError(GRPCError.InvalidState("Protocol violation: message received before headers"))
+      self.handleError(GRPCError.ProtocolViolation("Message received before headers"))
     case .creatingObserver:
       self.requestBuffer.append(.message(request))
     case let .observing(observer, _):
@@ -218,7 +212,7 @@ public final class ClientStreamingServerHandler<
   internal func receiveInterceptedEnd() {
     switch self.state {
     case .idle:
-      self.handleError(GRPCError.InvalidState("Protocol violation: 'end received before headers'"))
+      self.handleError(GRPCError.ProtocolViolation("end received before headers"))
     case .creatingObserver:
       self.requestBuffer.append(.end)
     case let .observing(observer, _):
@@ -250,7 +244,7 @@ public final class ClientStreamingServerHandler<
         }
 
       case let .failure(error):
-        self.handleError(error)
+        self.handleError(error, thrownFromHandler: true)
       }
 
     case .completed:
@@ -260,7 +254,7 @@ public final class ClientStreamingServerHandler<
   }
 
   @inlinable
-  internal func userFunctionCompleted(with result: Result<Response, Error>) {
+  internal func userFunctionCompletedWithResult(_ result: Result<Response, Error>) {
     switch self.state {
     case .idle:
       // Invalid state: the user function can only complete if it exists..
@@ -268,22 +262,20 @@ public final class ClientStreamingServerHandler<
 
     case let .creatingObserver(context),
          let .observing(_, context):
-      self.state = .completed
-
       switch result {
       case let .success(response):
-        let metadata = MessageMetadata(compress: false, flush: false)
+        // Complete when we send end.
+        self.state = .completed
+
+        // Compression depends on whether it's enabled on the server and the setting in the caller
+        // context.
+        let compress = self.context.encoding.isEnabled && context.compressionEnabled
+        let metadata = MessageMetadata(compress: compress, flush: false)
         self.interceptors.send(.message(response, metadata), promise: nil)
         self.interceptors.send(.end(context.responseStatus, context.trailers), promise: nil)
 
       case let .failure(error):
-        let (status, trailers) = ServerErrorProcessor.processObserverError(
-          error,
-          headers: context.headers,
-          trailers: context.trailers,
-          delegate: self.context.errorDelegate
-        )
-        self.interceptors.send(.end(status, trailers), promise: nil)
+        self.handleError(error, thrownFromHandler: true)
       }
 
     case .completed:
@@ -293,9 +285,11 @@ public final class ClientStreamingServerHandler<
   }
 
   @inlinable
-  internal func handleError(_ error: Error) {
+  internal func handleError(_ error: Error, thrownFromHandler isHandlerError: Bool = false) {
     switch self.state {
     case .idle:
+      assert(!isHandlerError)
+      self.state = .completed
       // We don't have a promise to fail. Just send back end.
       let (status, trailers) = ServerErrorProcessor.processLibraryError(
         error,
@@ -305,6 +299,30 @@ public final class ClientStreamingServerHandler<
 
     case let .creatingObserver(context),
          let .observing(_, context):
+      // We don't have a promise to fail. Just send back end.
+      self.state = .completed
+
+      let status: GRPCStatus
+      let trailers: HPACKHeaders
+
+      if isHandlerError {
+        (status, trailers) = ServerErrorProcessor.processObserverError(
+          error,
+          headers: context.headers,
+          trailers: context.trailers,
+          delegate: self.context.errorDelegate
+        )
+      } else {
+        (status, trailers) = ServerErrorProcessor.processLibraryError(
+          error,
+          delegate: self.context.errorDelegate
+        )
+      }
+
+      self.interceptors.send(.end(status, trailers), promise: nil)
+      // We're already in the 'completed' state so failing the promise will be a no-op in the
+      // callback to 'userFunctionCompletedWithResult' (but we also need to avoid leaking the
+      // promise.)
       context.responsePromise.fail(error)
 
     case .completed:

+ 1 - 0
Sources/GRPC/CallHandlers/ServerStreamingCallHandler.swift

@@ -145,6 +145,7 @@ public final class ServerStreamingCallHandler<
         headers: headers,
         logger: self.logger,
         userInfoRef: self._userInfoRef,
+        compressionIsEnabled: self._callHandlerContext.encoding.isEnabled,
         sendResponse: self.sendResponse(_:metadata:promise:)
       )
       let observer = factory(context)

+ 56 - 24
Sources/GRPC/CallHandlers/ServerStreamingServerHandler.swift

@@ -119,16 +119,12 @@ public final class ServerStreamingServerHandler<
 
   @inlinable
   public func receiveError(_ error: Error) {
-    self._finish(error: error)
+    self.handleError(error)
+    self.finish()
   }
 
   @inlinable
   public func finish() {
-    self._finish(error: nil)
-  }
-
-  @inlinable
-  internal func _finish(error: Error?) {
     switch self.state {
     case .idle:
       self.interceptors = nil
@@ -136,8 +132,7 @@ public final class ServerStreamingServerHandler<
 
     case let .createdContext(context),
          let .invokedFunction(context):
-      let error = error ?? GRPCStatus(code: .unavailable, message: nil)
-      context.statusPromise.fail(error)
+      context.statusPromise.fail(GRPCStatus(code: .unavailable, message: nil))
 
     case .completed:
       self.interceptors = nil
@@ -154,8 +149,7 @@ public final class ServerStreamingServerHandler<
     case let .message(message):
       self.receiveInterceptedMessage(message)
     case .end:
-      // Ignored.
-      ()
+      self.receiveInterceptedEnd()
     }
   }
 
@@ -169,6 +163,7 @@ public final class ServerStreamingServerHandler<
         headers: headers,
         logger: self.context.logger,
         userInfoRef: self.userInfoRef,
+        compressionIsEnabled: self.context.encoding.isEnabled,
         sendResponse: self.interceptResponse(_:metadata:promise:)
       )
 
@@ -196,15 +191,17 @@ public final class ServerStreamingServerHandler<
   internal func receiveInterceptedMessage(_ request: Request) {
     switch self.state {
     case .idle:
-      self.handleError(
-        GRPCError.InvalidState("Protocol violation: message received before headers")
-      )
+      self.handleError(GRPCError.ProtocolViolation("Message received before headers"))
+
     case let .createdContext(context):
       self.state = .invokedFunction(context)
       // Complete the status promise with the function outcome.
       context.statusPromise.completeWith(self.userFunction(request, context))
+
     case .invokedFunction:
-      self.handleError(GRPCError.InvalidState("Protocol violation: already received message"))
+      let error = GRPCError.ProtocolViolation("Multiple messages received on server streaming RPC")
+      self.handleError(error)
+
     case .completed:
       // We received a message but we're already done: this may happen if we terminate the RPC
       // due to a channel error, for example.
@@ -212,6 +209,20 @@ public final class ServerStreamingServerHandler<
     }
   }
 
+  @inlinable
+  internal func receiveInterceptedEnd() {
+    switch self.state {
+    case .idle:
+      self.handleError(GRPCError.ProtocolViolation("End received before headers"))
+
+    case .createdContext:
+      self.handleError(GRPCError.ProtocolViolation("End received before message"))
+
+    case .invokedFunction, .completed:
+      ()
+    }
+  }
+
   // MARK: - User Function To Interceptors
 
   @inlinable
@@ -244,20 +255,15 @@ public final class ServerStreamingServerHandler<
 
     case let .createdContext(context),
          let .invokedFunction(context):
-      self.state = .completed
 
       switch result {
       case let .success(status):
+        // We're sending end back, we're done.
+        self.state = .completed
         self.interceptors.send(.end(status, context.trailers), promise: nil)
 
       case let .failure(error):
-        let (status, trailers) = ServerErrorProcessor.processObserverError(
-          error,
-          headers: context.headers,
-          trailers: context.trailers,
-          delegate: self.context.errorDelegate
-        )
-        self.interceptors.send(.end(status, trailers), promise: nil)
+        self.handleError(error, thrownFromHandler: true)
       }
 
     case .completed:
@@ -277,7 +283,7 @@ public final class ServerStreamingServerHandler<
 
     case let .message(message, metadata):
       do {
-        let bytes = try self.serializer.serialize(message, allocator: ByteBufferAllocator())
+        let bytes = try self.serializer.serialize(message, allocator: self.context.allocator)
         self.context.responseWriter.sendMessage(bytes, metadata: metadata, promise: promise)
       } catch {
         // Serialization failed: fail the promise and send end.
@@ -296,9 +302,11 @@ public final class ServerStreamingServerHandler<
   }
 
   @inlinable
-  internal func handleError(_ error: Error) {
+  internal func handleError(_ error: Error, thrownFromHandler isHandlerError: Bool = false) {
     switch self.state {
     case .idle:
+      assert(!isHandlerError)
+      self.state = .completed
       // We don't have a promise to fail. Just send back end.
       let (status, trailers) = ServerErrorProcessor.processLibraryError(
         error,
@@ -308,6 +316,30 @@ public final class ServerStreamingServerHandler<
 
     case let .createdContext(context),
          let .invokedFunction(context):
+      // We don't have a promise to fail. Just send back end.
+      self.state = .completed
+
+      let status: GRPCStatus
+      let trailers: HPACKHeaders
+
+      if isHandlerError {
+        (status, trailers) = ServerErrorProcessor.processObserverError(
+          error,
+          headers: context.headers,
+          trailers: context.trailers,
+          delegate: self.context.errorDelegate
+        )
+      } else {
+        (status, trailers) = ServerErrorProcessor.processLibraryError(
+          error,
+          delegate: self.context.errorDelegate
+        )
+      }
+
+      self.interceptors.send(.end(status, trailers), promise: nil)
+      // We're already in the 'completed' state so failing the promise will be a no-op in the
+      // callback to 'userFunctionCompletedWithResult' (but we also need to avoid leaking the
+      // promise.)
       context.statusPromise.fail(error)
 
     case .completed:

+ 57 - 25
Sources/GRPC/CallHandlers/UnaryServerHandler.swift

@@ -117,16 +117,12 @@ public final class UnaryServerHandler<
 
   @inlinable
   public func receiveError(_ error: Error) {
-    self._finish(error: error)
+    self.handleError(error)
+    self.finish()
   }
 
   @inlinable
   public func finish() {
-    self._finish(error: nil)
-  }
-
-  @inlinable
-  internal func _finish(error: Error?) {
     switch self.state {
     case .idle:
       self.interceptors = nil
@@ -134,8 +130,7 @@ public final class UnaryServerHandler<
 
     case let .createdContext(context),
          let .invokedFunction(context):
-      let error = error ?? GRPCStatus(code: .unavailable, message: nil)
-      context.responsePromise.fail(error)
+      context.responsePromise.fail(GRPCStatus(code: .unavailable, message: nil))
 
     case .completed:
       self.interceptors = nil
@@ -152,8 +147,7 @@ public final class UnaryServerHandler<
     case let .message(message):
       self.receiveInterceptedMessage(message)
     case .end:
-      // Ignored.
-      ()
+      self.receiveInterceptedEnd()
     }
   }
 
@@ -179,7 +173,7 @@ public final class UnaryServerHandler<
       self.interceptors.send(.metadata([:]), promise: nil)
 
     case .createdContext, .invokedFunction:
-      self.handleError(GRPCError.InvalidState("Protocol violation: already received headers"))
+      self.handleError(GRPCError.ProtocolViolation("Multiple header blocks received on RPC"))
 
     case .completed:
       // We may receive headers from the interceptor pipeline if we have already finished (i.e. due
@@ -193,8 +187,7 @@ public final class UnaryServerHandler<
   internal func receiveInterceptedMessage(_ request: Request) {
     switch self.state {
     case .idle:
-      self
-        .handleError(GRPCError.InvalidState("Protocol violation: message received before headers"))
+      self.handleError(GRPCError.ProtocolViolation("Message received before headers"))
 
     case let .createdContext(context):
       // Happy path: execute the function; complete the promise with the result.
@@ -203,7 +196,7 @@ public final class UnaryServerHandler<
 
     case .invokedFunction:
       // The function's already been invoked with a message.
-      self.handleError(GRPCError.InvalidState("Protocol violation: already received message"))
+      self.handleError(GRPCError.ProtocolViolation("Multiple messages received on unary RPC"))
 
     case .completed:
       // We received a message but we're already done: this may happen if we terminate the RPC
@@ -212,6 +205,20 @@ public final class UnaryServerHandler<
     }
   }
 
+  @inlinable
+  internal func receiveInterceptedEnd() {
+    switch self.state {
+    case .idle:
+      self.handleError(GRPCError.ProtocolViolation("End received before headers"))
+
+    case .createdContext:
+      self.handleError(GRPCError.ProtocolViolation("End received before message"))
+
+    case .invokedFunction, .completed:
+      ()
+    }
+  }
+
   // MARK: - User Function To Interceptors
 
   @inlinable
@@ -225,22 +232,21 @@ public final class UnaryServerHandler<
     // but before receiving a message.
     case let .createdContext(context),
          let .invokedFunction(context):
-      self.state = .completed
 
       switch result {
       case let .success(response):
-        let metadata = MessageMetadata(compress: false, flush: false)
+        // Complete, as we're sending 'end'.
+        self.state = .completed
+
+        // Compression depends on whether it's enabled on the server and the setting in the caller
+        // context.
+        let compress = self.context.encoding.isEnabled && context.compressionEnabled
+        let metadata = MessageMetadata(compress: compress, flush: false)
         self.interceptors.send(.message(response, metadata), promise: nil)
         self.interceptors.send(.end(context.responseStatus, context.trailers), promise: nil)
 
       case let .failure(error):
-        let (status, trailers) = ServerErrorProcessor.processObserverError(
-          error,
-          headers: context.headers,
-          trailers: context.trailers,
-          delegate: self.context.errorDelegate
-        )
-        self.interceptors.send(.end(status, trailers), promise: nil)
+        self.handleError(error, thrownFromHandler: true)
       }
 
     case .completed:
@@ -279,10 +285,12 @@ public final class UnaryServerHandler<
   }
 
   @inlinable
-  internal func handleError(_ error: Error) {
+  internal func handleError(_ error: Error, thrownFromHandler isHandlerError: Bool = false) {
     switch self.state {
     case .idle:
-      // We don't have a promise to fail. Just send end back.
+      assert(!isHandlerError)
+      self.state = .completed
+      // We don't have a promise to fail. Just send back end.
       let (status, trailers) = ServerErrorProcessor.processLibraryError(
         error,
         delegate: self.context.errorDelegate
@@ -291,6 +299,30 @@ public final class UnaryServerHandler<
 
     case let .createdContext(context),
          let .invokedFunction(context):
+      // We don't have a promise to fail. Just send back end.
+      self.state = .completed
+
+      let status: GRPCStatus
+      let trailers: HPACKHeaders
+
+      if isHandlerError {
+        (status, trailers) = ServerErrorProcessor.processObserverError(
+          error,
+          headers: context.headers,
+          trailers: context.trailers,
+          delegate: self.context.errorDelegate
+        )
+      } else {
+        (status, trailers) = ServerErrorProcessor.processLibraryError(
+          error,
+          delegate: self.context.errorDelegate
+        )
+      }
+
+      self.interceptors.send(.end(status, trailers), promise: nil)
+      // We're already in the 'completed' state so failing the promise will be a no-op in the
+      // callback to 'userFunctionCompletedWithResult' (but we also need to avoid leaking the
+      // promise.)
       context.responsePromise.fail(error)
 
     case .completed:

+ 10 - 0
Sources/GRPC/Compression/MessageEncoding.swift

@@ -115,6 +115,16 @@ extension ClientMessageEncoding {
 public enum ServerMessageEncoding {
   case enabled(Configuration)
   case disabled
+
+  @usableFromInline
+  internal var isEnabled: Bool {
+    switch self {
+    case .enabled:
+      return true
+    case .disabled:
+      return false
+    }
+  }
 }
 
 extension ServerMessageEncoding {

+ 18 - 2
Sources/GRPC/GRPCError.swift

@@ -254,7 +254,23 @@ public enum GRPCError {
     public var message: String
 
     public init(_ message: String) {
-      self.message = message
+      self.message = "Invalid state: \(message)"
+    }
+
+    public var description: String {
+      return self.message
+    }
+
+    public func makeGRPCStatus() -> GRPCStatus {
+      return GRPCStatus(code: .internalError, message: self.message)
+    }
+  }
+
+  public struct ProtocolViolation: GRPCErrorProtocol {
+    public var message: String
+
+    public init(_ message: String) {
+      self.message = "Protocol violation: \(message)"
     }
 
     public var description: String {
@@ -262,7 +278,7 @@ public enum GRPCError {
     }
 
     public func makeGRPCStatus() -> GRPCStatus {
-      return GRPCStatus(code: .internalError, message: "Invalid state: \(self.message)")
+      return GRPCStatus(code: .internalError, message: self.message)
     }
   }
 }

+ 16 - 2
Sources/GRPC/ServerCallContexts/StreamingResponseCallContext.swift

@@ -126,29 +126,43 @@ internal final class _StreamingResponseCallContext<Request, Response>:
   @usableFromInline
   internal let _sendResponse: (Response, MessageMetadata, EventLoopPromise<Void>?) -> Void
 
+  @usableFromInline
+  internal let _compressionEnabledOnServer: Bool
+
   @inlinable
   internal init(
     eventLoop: EventLoop,
     headers: HPACKHeaders,
     logger: Logger,
     userInfoRef: Ref<UserInfo>,
+    compressionIsEnabled: Bool,
     sendResponse: @escaping (Response, MessageMetadata, EventLoopPromise<Void>?) -> Void
   ) {
     self._sendResponse = sendResponse
+    self._compressionEnabledOnServer = compressionIsEnabled
     super.init(eventLoop: eventLoop, headers: headers, logger: logger, userInfoRef: userInfoRef)
   }
 
+  @inlinable
+  internal func shouldCompress(_ compression: Compression) -> Bool {
+    guard self._compressionEnabledOnServer else {
+      return false
+    }
+    return compression.isEnabled(callDefault: self.compressionEnabled)
+  }
+
   @inlinable
   override func sendResponse(
     _ message: Response,
     compression: Compression = .deferToCallDefault,
     promise: EventLoopPromise<Void>?
   ) {
-    let compress = compression.isEnabled(callDefault: self.compressionEnabled)
     if self.eventLoop.inEventLoop {
+      let compress = self.shouldCompress(compression)
       self._sendResponse(message, .init(compress: compress, flush: true), promise)
     } else {
       self.eventLoop.execute {
+        let compress = self.shouldCompress(compression)
         self._sendResponse(message, .init(compress: compress, flush: true), promise)
       }
     }
@@ -175,7 +189,7 @@ internal final class _StreamingResponseCallContext<Request, Response>:
     compression: Compression,
     promise: EventLoopPromise<Void>?
   ) where Response == Messages.Element {
-    let compress = compression.isEnabled(callDefault: self.compressionEnabled)
+    let compress = self.shouldCompress(compression)
     var iterator = messages.makeIterator()
     var next = iterator.next()
 

+ 150 - 6
Tests/GRPCTests/UnaryServerHandlerTests.swift

@@ -23,6 +23,7 @@ import XCTest
 final class ResponseRecorder: GRPCServerResponseWriter {
   var metadata: HPACKHeaders?
   var messages: [ByteBuffer] = []
+  var messageMetadata: [MessageMetadata] = []
   var status: GRPCStatus?
   var trailers: HPACKHeaders?
 
@@ -38,6 +39,7 @@ final class ResponseRecorder: GRPCServerResponseWriter {
     promise: EventLoopPromise<Void>?
   ) {
     self.messages.append(bytes)
+    self.messageMetadata.append(metadata)
     promise?.succeed(())
   }
 
@@ -57,11 +59,11 @@ protocol ServerHandlerTestCase: GRPCTestCase {
 }
 
 extension ServerHandlerTestCase {
-  func makeCallHandlerContext() -> CallHandlerContext {
+  func makeCallHandlerContext(encoding: ServerMessageEncoding = .disabled) -> CallHandlerContext {
     return CallHandlerContext(
       errorDelegate: nil,
       logger: self.logger,
-      encoding: .disabled,
+      encoding: encoding,
       eventLoop: self.eventLoop,
       path: "/ignored",
       remoteAddress: nil,
@@ -79,10 +81,11 @@ class UnaryServerHandlerTests: GRPCTestCase, ServerHandlerTestCase {
   let recorder = ResponseRecorder()
 
   private func makeHandler(
+    encoding: ServerMessageEncoding = .disabled,
     function: @escaping (String, StatusOnlyCallContext) -> EventLoopFuture<String>
   ) -> UnaryServerHandler<StringSerializer, StringDeserializer> {
     return UnaryServerHandler(
-      context: self.makeCallHandlerContext(),
+      context: self.makeCallHandlerContext(encoding: encoding),
       requestDeserializer: StringDeserializer(),
       responseSerializer: StringSerializer(),
       interceptors: [],
@@ -124,10 +127,41 @@ class UnaryServerHandlerTests: GRPCTestCase, ServerHandlerTestCase {
     handler.finish()
 
     assertThat(self.recorder.messages.first, .is(buffer))
+    assertThat(self.recorder.messageMetadata.first?.compress, .is(false))
     assertThat(self.recorder.status, .notNil(.hasCode(.ok)))
     assertThat(self.recorder.trailers, .is([:]))
   }
 
+  func testHappyPathWithCompressionEnabled() {
+    let handler = self.makeHandler(
+      encoding: .enabled(.init(decompressionLimit: .absolute(.max))),
+      function: self.echo(_:context:)
+    )
+
+    handler.receiveMetadata([:])
+    let buffer = ByteBuffer(string: "hello")
+    handler.receiveMessage(buffer)
+
+    assertThat(self.recorder.messages.first, .is(buffer))
+    assertThat(self.recorder.messageMetadata.first?.compress, .is(true))
+  }
+
+  func testHappyPathWithCompressionEnabledButDisabledByCaller() {
+    let handler = self.makeHandler(
+      encoding: .enabled(.init(decompressionLimit: .absolute(.max)))
+    ) { request, context in
+      context.compressionEnabled = false
+      return self.echo(request, context: context)
+    }
+
+    handler.receiveMetadata([:])
+    let buffer = ByteBuffer(string: "hello")
+    handler.receiveMessage(buffer)
+
+    assertThat(self.recorder.messages.first, .is(buffer))
+    assertThat(self.recorder.messageMetadata.first?.compress, .is(false))
+  }
+
   func testThrowingDeserializer() {
     let handler = UnaryServerHandler(
       context: self.makeCallHandlerContext(),
@@ -262,11 +296,12 @@ class ClientStreamingServerHandlerTests: GRPCTestCase, ServerHandlerTestCase {
   let recorder = ResponseRecorder()
 
   private func makeHandler(
+    encoding: ServerMessageEncoding = .disabled,
     observerFactory: @escaping (UnaryResponseCallContext<String>)
       -> EventLoopFuture<(StreamEvent<String>) -> Void>
   ) -> ClientStreamingServerHandler<StringSerializer, StringDeserializer> {
     return ClientStreamingServerHandler(
-      context: self.makeCallHandlerContext(),
+      context: self.makeCallHandlerContext(encoding: encoding),
       requestDeserializer: StringDeserializer(),
       responseSerializer: StringSerializer(),
       interceptors: [],
@@ -323,10 +358,45 @@ class ClientStreamingServerHandlerTests: GRPCTestCase, ServerHandlerTestCase {
     handler.finish()
 
     assertThat(self.recorder.messages.first, .is(ByteBuffer(string: "1 2 3")))
+    assertThat(self.recorder.messageMetadata.first?.compress, .is(false))
     assertThat(self.recorder.status, .notNil(.hasCode(.ok)))
     assertThat(self.recorder.trailers, .is([:]))
   }
 
+  func testHappyPathWithCompressionEnabled() {
+    let handler = self.makeHandler(
+      encoding: .enabled(.init(decompressionLimit: .absolute(.max))),
+      observerFactory: self.joinWithSpaces(context:)
+    )
+
+    handler.receiveMetadata([:])
+    handler.receiveMessage(ByteBuffer(string: "1"))
+    handler.receiveMessage(ByteBuffer(string: "2"))
+    handler.receiveMessage(ByteBuffer(string: "3"))
+    handler.receiveEnd()
+
+    assertThat(self.recorder.messages.first, .is(ByteBuffer(string: "1 2 3")))
+    assertThat(self.recorder.messageMetadata.first?.compress, .is(true))
+  }
+
+  func testHappyPathWithCompressionEnabledButDisabledByCaller() {
+    let handler = self.makeHandler(
+      encoding: .enabled(.init(decompressionLimit: .absolute(.max)))
+    ) { context in
+      context.compressionEnabled = false
+      return self.joinWithSpaces(context: context)
+    }
+
+    handler.receiveMetadata([:])
+    handler.receiveMessage(ByteBuffer(string: "1"))
+    handler.receiveMessage(ByteBuffer(string: "2"))
+    handler.receiveMessage(ByteBuffer(string: "3"))
+    handler.receiveEnd()
+
+    assertThat(self.recorder.messages.first, .is(ByteBuffer(string: "1 2 3")))
+    assertThat(self.recorder.messageMetadata.first?.compress, .is(false))
+  }
+
   func testThrowingDeserializer() {
     let handler = ClientStreamingServerHandler(
       context: self.makeCallHandlerContext(),
@@ -483,11 +553,12 @@ class ServerStreamingServerHandlerTests: GRPCTestCase, ServerHandlerTestCase {
   let recorder = ResponseRecorder()
 
   private func makeHandler(
+    encoding: ServerMessageEncoding = .disabled,
     userFunction: @escaping (String, StreamingResponseCallContext<String>)
       -> EventLoopFuture<GRPCStatus>
   ) -> ServerStreamingServerHandler<StringSerializer, StringDeserializer> {
     return ServerStreamingServerHandler(
-      context: self.makeCallHandlerContext(),
+      context: self.makeCallHandlerContext(encoding: encoding),
       requestDeserializer: StringDeserializer(),
       responseSerializer: StringSerializer(),
       interceptors: [],
@@ -535,10 +606,41 @@ class ServerStreamingServerHandlerTests: GRPCTestCase, ServerHandlerTestCase {
       self.recorder.messages,
       .is([ByteBuffer(string: "a"), ByteBuffer(string: "b")])
     )
+    assertThat(self.recorder.messageMetadata.map { $0.compress }, .is([false, false]))
     assertThat(self.recorder.status, .notNil(.hasCode(.ok)))
     assertThat(self.recorder.trailers, .is([:]))
   }
 
+  func testHappyPathWithCompressionEnabled() {
+    let handler = self.makeHandler(
+      encoding: .enabled(.init(decompressionLimit: .absolute(.max))),
+      userFunction: self.breakOnSpaces(_:context:)
+    )
+
+    handler.receiveMetadata([:])
+    handler.receiveMessage(ByteBuffer(string: "a"))
+    handler.receiveEnd()
+
+    assertThat(self.recorder.messages.first, .is(ByteBuffer(string: "a")))
+    assertThat(self.recorder.messageMetadata.first?.compress, .is(true))
+  }
+
+  func testHappyPathWithCompressionEnabledButDisabledByCaller() {
+    let handler = self.makeHandler(
+      encoding: .enabled(.init(decompressionLimit: .absolute(.max)))
+    ) { request, context in
+      context.compressionEnabled = false
+      return self.breakOnSpaces(request, context: context)
+    }
+
+    handler.receiveMetadata([:])
+    handler.receiveMessage(ByteBuffer(string: "a"))
+    handler.receiveEnd()
+
+    assertThat(self.recorder.messages.first, .is(ByteBuffer(string: "a")))
+    assertThat(self.recorder.messageMetadata.first?.compress, .is(false))
+  }
+
   func testThrowingDeserializer() {
     let handler = ServerStreamingServerHandler(
       context: self.makeCallHandlerContext(),
@@ -673,11 +775,12 @@ class BidirectionalStreamingServerHandlerTests: GRPCTestCase, ServerHandlerTestC
   let recorder = ResponseRecorder()
 
   private func makeHandler(
+    encoding: ServerMessageEncoding = .disabled,
     observerFactory: @escaping (StreamingResponseCallContext<String>)
       -> EventLoopFuture<(StreamEvent<String>) -> Void>
   ) -> BidirectionalStreamingServerHandler<StringSerializer, StringDeserializer> {
     return BidirectionalStreamingServerHandler(
-      context: self.makeCallHandlerContext(),
+      context: self.makeCallHandlerContext(encoding: encoding),
       requestDeserializer: StringDeserializer(),
       responseSerializer: StringSerializer(),
       interceptors: [],
@@ -736,10 +839,51 @@ class BidirectionalStreamingServerHandlerTests: GRPCTestCase, ServerHandlerTestC
       self.recorder.messages,
       .is([ByteBuffer(string: "1"), ByteBuffer(string: "2"), ByteBuffer(string: "3")])
     )
+    assertThat(self.recorder.messageMetadata.map { $0.compress }, .is([false, false, false]))
     assertThat(self.recorder.status, .notNil(.hasCode(.ok)))
     assertThat(self.recorder.trailers, .is([:]))
   }
 
+  func testHappyPathWithCompressionEnabled() {
+    let handler = self.makeHandler(
+      encoding: .enabled(.init(decompressionLimit: .absolute(.max))),
+      observerFactory: self.echo(context:)
+    )
+
+    handler.receiveMetadata([:])
+    handler.receiveMessage(ByteBuffer(string: "1"))
+    handler.receiveMessage(ByteBuffer(string: "2"))
+    handler.receiveMessage(ByteBuffer(string: "3"))
+    handler.receiveEnd()
+
+    assertThat(
+      self.recorder.messages,
+      .is([ByteBuffer(string: "1"), ByteBuffer(string: "2"), ByteBuffer(string: "3")])
+    )
+    assertThat(self.recorder.messageMetadata.map { $0.compress }, .is([true, true, true]))
+  }
+
+  func testHappyPathWithCompressionEnabledButDisabledByCaller() {
+    let handler = self.makeHandler(
+      encoding: .enabled(.init(decompressionLimit: .absolute(.max)))
+    ) { context in
+      context.compressionEnabled = false
+      return self.echo(context: context)
+    }
+
+    handler.receiveMetadata([:])
+    handler.receiveMessage(ByteBuffer(string: "1"))
+    handler.receiveMessage(ByteBuffer(string: "2"))
+    handler.receiveMessage(ByteBuffer(string: "3"))
+    handler.receiveEnd()
+
+    assertThat(
+      self.recorder.messages,
+      .is([ByteBuffer(string: "1"), ByteBuffer(string: "2"), ByteBuffer(string: "3")])
+    )
+    assertThat(self.recorder.messageMetadata.map { $0.compress }, .is([false, false, false]))
+  }
+
   func testThrowingDeserializer() {
     let handler = BidirectionalStreamingServerHandler(
       context: self.makeCallHandlerContext(),