UnaryCallHandler.swift 3.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
  1. /*
  2. * Copyright 2019, gRPC Authors All rights reserved.
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. import Foundation
  17. import Logging
  18. import NIO
  19. import NIOHTTP1
  20. import SwiftProtobuf
  /// Handles unary calls. Calls the observer block with the request message.
  ///
  /// - The observer block is implemented by the framework user and returns a future containing the call result.
  /// - To return a response to the client, the framework user should complete that future
  ///   (similar to e.g. serving regular HTTP requests in frameworks such as Vapor).
  /// - A fresh observer is built (via `eventObserverFactory`) when the request head is processed and is
  ///   invoked with at most one request message. A second request message, or end-of-stream before any
  ///   message was processed, throws `GRPCError.StreamCardinalityViolation.request`.
  public final class UnaryCallHandler<
    RequestPayload,
    ResponsePayload
  >: _BaseCallHandler<RequestPayload, ResponsePayload> {
    /// Receives the single request message and returns a future for the call's response.
    public typealias EventObserver = (RequestPayload) -> EventLoopFuture<ResponsePayload>

    /// Observer for the current call. It is `nil` before the head is processed, after the request
    /// message has been handed to it, and after the response promise completes.
    private var eventObserver: EventObserver?

    /// Context for the current call. It is `nil` before the head is processed and after the
    /// response promise completes.
    private var callContext: UnaryResponseCallContext<ResponsePayload>?

    /// Builds the observer for each call from that call's context.
    private let eventObserverFactory: (UnaryResponseCallContext<ResponsePayload>) -> EventObserver

    internal init<Serializer: MessageSerializer, Deserializer: MessageDeserializer>(
      serializer: Serializer,
      deserializer: Deserializer,
      callHandlerContext: CallHandlerContext,
      eventObserverFactory: @escaping (UnaryResponseCallContext<ResponsePayload>) -> EventObserver
    ) where Serializer.Input == ResponsePayload, Deserializer.Output == RequestPayload {
      self.eventObserverFactory = eventObserverFactory
      let codec = GRPCServerCodecHandler(serializer: serializer, deserializer: deserializer)
      super.init(callHandlerContext: callHandlerContext, codec: codec)
    }

    /// Creates the per-call context and observer for a new request head.
    /// It then writes and flushes an empty headers response part.
    override internal func processHead(_ head: HTTPRequestHead, context: ChannelHandlerContext) {
      let unaryContext = UnaryResponseCallContextImpl<ResponsePayload>(
        channel: context.channel,
        request: head,
        errorDelegate: self.errorDelegate,
        logger: self.logger
      )
      self.callContext = unaryContext
      self.eventObserver = self.eventObserverFactory(unaryContext)

      // This callback captures `self` strongly. Dropping the per-call state once the promise
      // completes (successfully or not) releases those references and avoids retain cycles.
      unaryContext.responsePromise.futureResult.whenComplete { _ in
        self.eventObserver = nil
        self.callContext = nil
      }

      context.writeAndFlush(self.wrapOutboundOut(.headers([:])), promise: nil)
    }

    /// Hands the (only) request message to the observer.
    /// The observer's result is forwarded to the response promise.
    ///
    /// - Throws: `GRPCError.StreamCardinalityViolation.request` when there is no active observer/context.
    ///   That happens before the head was processed, after a message was already handled, or after the
    ///   call completed.
    override internal func processMessage(_ message: RequestPayload) throws {
      guard let observer = self.eventObserver, let unaryContext = self.callContext else {
        self.logger.error(
          "processMessage(_:) called before the call started or after the call completed",
          source: "GRPC"
        )
        throw GRPCError.StreamCardinalityViolation.request.captureContext()
      }

      // Fulfil the response promise with whatever response (or error) the framework user has provided.
      observer(message).cascade(to: unaryContext.responsePromise)

      // Unary calls take exactly one message: clearing the observer makes any further message
      // fail the guard above.
      self.eventObserver = nil
    }

    /// Rejects an end-of-stream that arrives while the observer is still set.
    /// In that case the head was processed but no request message has been handled yet.
    override internal func endOfStreamReceived() throws {
      guard self.eventObserver == nil else {
        throw GRPCError.StreamCardinalityViolation.request.captureContext()
      }
    }

    /// Fails the current call's response promise with the given status.
    /// Any provided metadata is first added to the trailing metadata. Does nothing if there is
    /// no active call context.
    override internal func sendErrorStatusAndMetadata(_ statusAndMetadata: GRPCStatusAndMetadata) {
      guard let unaryContext = self.callContext else {
        return
      }
      if let extraTrailers = statusAndMetadata.metadata {
        unaryContext.trailingMetadata.add(contentsOf: extraTrailers)
      }
      unaryContext.responsePromise.fail(statusAndMetadata.status)
    }
  }