// StreamingResponseCallContext.swift (11 KB)

/*
 * Copyright 2019, gRPC Authors All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
  16. import Foundation
  17. import Logging
  18. import NIO
  19. import NIOHPACK
  20. import NIOHTTP1
  21. import SwiftProtobuf
  /// Abstract base class exposing a method to send multiple messages over the wire and a promise for the final RPC status.
  ///
  /// - When `statusPromise` is fulfilled, the call is closed and the provided status transmitted.
  /// - If `statusPromise` is failed and the error is of type `GRPCStatusTransformable`,
  ///   the result of `error.asGRPCStatus()` will be returned to the client.
  /// - If `error.asGRPCStatus()` is not available, `GRPCStatus.processingError` is returned to the client.
  ///
  /// Subclasses must override `sendResponse(_:compression:promise:)` and
  /// `sendResponses(_:compression:promise:)`: the base implementations call `fatalError`.
  open class StreamingResponseCallContext<ResponsePayload>: ServerCallContextBase {
    typealias WrappedResponse = GRPCServerResponsePart<ResponsePayload>

    /// Promise for the final status of the RPC. It is created on `eventLoop` when the
    /// context is initialized.
    public let statusPromise: EventLoopPromise<GRPCStatus>

    /// Creates a context for a call running on `eventLoop`.
    ///
    /// - Parameters:
    ///   - eventLoop: The event loop the call runs on; `statusPromise` is created on it.
    ///   - headers: The headers provided with this call.
    ///   - logger: A logger.
    ///   - userInfo: User info for the call; it is wrapped in a new `Ref`. Defaults to `UserInfo()`.
    public convenience init(
      eventLoop: EventLoop,
      headers: HPACKHeaders,
      logger: Logger,
      userInfo: UserInfo = UserInfo()
    ) {
      self.init(eventLoop: eventLoop, headers: headers, logger: logger, userInfoRef: .init(userInfo))
    }

    /// Designated initializer taking an already-boxed `userInfoRef`, so the reference can be
    /// shared with the caller.
    @inlinable
    override internal init(
      eventLoop: EventLoop,
      headers: HPACKHeaders,
      logger: Logger,
      userInfoRef: Ref<UserInfo>
    ) {
      self.statusPromise = eventLoop.makePromise()
      super.init(eventLoop: eventLoop, headers: headers, logger: logger, userInfoRef: userInfoRef)
    }

    /// Legacy initializer taking an HTTP/1 request head.
    ///
    /// The `renamed:` target must name an initializer that actually exists on this type. The
    /// public replacement is `init(eventLoop:headers:logger:userInfo:)`; there is no `path:`
    /// variant.
    @available(*, deprecated, renamed: "init(eventLoop:headers:logger:userInfo:)")
    override public init(eventLoop: EventLoop, request: HTTPRequestHead, logger: Logger) {
      self.statusPromise = eventLoop.makePromise()
      super.init(eventLoop: eventLoop, request: request, logger: logger)
    }

    /// Send a response to the client.
    ///
    /// The base implementation traps; concrete subclasses must override it.
    ///
    /// - Parameters:
    ///   - message: The message to send to the client.
    ///   - compression: Whether compression should be used for this response. If compression
    ///     is enabled in the call context, the value passed here takes precedence. Defaults to
    ///     deferring to the value set on the call context.
    ///   - promise: A promise to complete once the message has been sent.
    open func sendResponse(
      _ message: ResponsePayload,
      compression: Compression = .deferToCallDefault,
      promise: EventLoopPromise<Void>?
    ) {
      fatalError("needs to be overridden")
    }

    /// Send a response to the client.
    ///
    /// Creates a promise on `eventLoop`, forwards to `sendResponse(_:compression:promise:)`
    /// and returns the promise's future.
    ///
    /// - Parameters:
    ///   - message: The message to send to the client.
    ///   - compression: Whether compression should be used for this response. If compression
    ///     is enabled in the call context, the value passed here takes precedence. Defaults to
    ///     deferring to the value set on the call context.
    /// - Returns: A future completed once the message has been sent.
    open func sendResponse(
      _ message: ResponsePayload,
      compression: Compression = .deferToCallDefault
    ) -> EventLoopFuture<Void> {
      let promise = self.eventLoop.makePromise(of: Void.self)
      self.sendResponse(message, compression: compression, promise: promise)
      return promise.futureResult
    }

    /// Sends a sequence of responses to the client.
    ///
    /// The base implementation traps; concrete subclasses must override it.
    ///
    /// - Parameters:
    ///   - messages: The messages to send to the client.
    ///   - compression: Whether compression should be used for this response. If compression
    ///     is enabled in the call context, the value passed here takes precedence. Defaults to
    ///     deferring to the value set on the call context.
    ///   - promise: A promise to complete once the messages have been sent.
    open func sendResponses<Messages: Sequence>(
      _ messages: Messages,
      compression: Compression = .deferToCallDefault,
      promise: EventLoopPromise<Void>?
    ) where Messages.Element == ResponsePayload {
      fatalError("needs to be overridden")
    }

    /// Sends a sequence of responses to the client.
    ///
    /// Creates a promise on `eventLoop`, forwards to `sendResponses(_:compression:promise:)`
    /// and returns the promise's future.
    ///
    /// - Parameters:
    ///   - messages: The messages to send to the client.
    ///   - compression: Whether compression should be used for this response. If compression
    ///     is enabled in the call context, the value passed here takes precedence. Defaults to
    ///     deferring to the value set on the call context.
    /// - Returns: A future completed once the messages have been sent.
    open func sendResponses<Messages: Sequence>(
      _ messages: Messages,
      compression: Compression = .deferToCallDefault
    ) -> EventLoopFuture<Void> where Messages.Element == ResponsePayload {
      let promise = self.eventLoop.makePromise(of: Void.self)
      self.sendResponses(messages, compression: compression, promise: promise)
      return promise.futureResult
    }
  }
  /// Internal concrete context that forwards every response to a caller-supplied
  /// `_sendResponse` closure. The closure is always invoked on `eventLoop`: calls made from
  /// other threads are hopped onto the loop with `execute`.
  ///
  /// The `Request` type parameter is not referenced by this class itself.
  @usableFromInline
  internal final class _StreamingResponseCallContext<Request, Response>:
    StreamingResponseCallContext<Response> {
    /// Receives each response together with its compression/flush metadata and an
    /// optional promise to complete once it has been written.
    @usableFromInline
    internal let _sendResponse: (Response, MessageMetadata, EventLoopPromise<Void>?) -> Void

    @inlinable
    internal init(
      eventLoop: EventLoop,
      headers: HPACKHeaders,
      logger: Logger,
      userInfoRef: Ref<UserInfo>,
      sendResponse: @escaping (Response, MessageMetadata, EventLoopPromise<Void>?) -> Void
    ) {
      self._sendResponse = sendResponse
      super.init(eventLoop: eventLoop, headers: headers, logger: logger, userInfoRef: userInfoRef)
    }

    /// Sends a single response with `flush: true`, hopping to `eventLoop` if necessary.
    @inlinable
    override func sendResponse(
      _ message: Response,
      compression: Compression = .deferToCallDefault,
      promise: EventLoopPromise<Void>?
    ) {
      // NOTE(review): the compression setting is read on the calling thread here, whereas
      // `sendResponses` reads it after hopping to the event loop. Confirm this is intended.
      let compress = compression.isEnabled(callDefault: self.compressionEnabled)
      if self.eventLoop.inEventLoop {
        self._sendResponse(message, .init(compress: compress, flush: true), promise)
      } else {
        self.eventLoop.execute {
          self._sendResponse(message, .init(compress: compress, flush: true), promise)
        }
      }
    }

    /// Sends a sequence of responses, hopping to `eventLoop` if necessary.
    @inlinable
    override func sendResponses<Messages: Sequence>(
      _ messages: Messages,
      compression: Compression = .deferToCallDefault,
      promise: EventLoopPromise<Void>?
    ) where Response == Messages.Element {
      if self.eventLoop.inEventLoop {
        self._sendResponses(messages, compression: compression, promise: promise)
      } else {
        self.eventLoop.execute {
          self._sendResponses(messages, compression: compression, promise: promise)
        }
      }
    }

    /// Writes `messages` in order. Only the last message is flushed, and `promise` is attached
    /// to that message.
    ///
    /// If `messages` is empty there is nothing to write, so `promise` is succeeded immediately.
    /// Otherwise it would never be completed, and the future handed out by
    /// `sendResponses(_:compression:)` would never resolve.
    @inlinable
    internal func _sendResponses<Messages: Sequence>(
      _ messages: Messages,
      compression: Compression,
      promise: EventLoopPromise<Void>?
    ) where Response == Messages.Element {
      let compress = compression.isEnabled(callDefault: self.compressionEnabled)
      var iterator = messages.makeIterator()

      guard var current = iterator.next() else {
        promise?.succeed(())
        return
      }

      // Look one element ahead so we know which message is the last one.
      while let following = iterator.next() {
        self._sendResponse(current, .init(compress: compress, flush: false), nil)
        current = following
      }

      // `current` is the final message: flush and attach the promise, if present.
      self._sendResponse(current, .init(compress: compress, flush: true), promise)
    }
  }
  /// Concrete implementation of `StreamingResponseCallContext` used by our generated code.
  ///
  /// Responses are written to `channel`. When `statusPromise` completes, the final `.end`
  /// part is written and flushed.
  open class StreamingResponseCallContextImpl<ResponsePayload>: StreamingResponseCallContext<ResponsePayload> {
    /// The NIO channel the call is handled on; all response parts are written to it.
    public let channel: Channel

    /// - Parameters:
    ///   - channel: The NIO channel the call is handled on.
    ///   - headers: The headers provided with this call.
    ///   - errorDelegate: Provides a means for transforming status promise failures to `GRPCStatusTransformable` before
    ///     sending them to the client.
    ///   - logger: A logger.
    ///
    /// Note: `errorDelegate` is not called for status promise that are `succeeded` with a non-OK status.
    public init(
      channel: Channel,
      headers: HPACKHeaders,
      errorDelegate: ServerErrorDelegate?,
      logger: Logger
    ) {
      self.channel = channel
      super.init(
        eventLoop: channel.eventLoop,
        headers: headers,
        logger: logger,
        userInfoRef: Ref(UserInfo())
      )

      // Close the RPC once the status is known. On success, the context's `trailers` are sent.
      // On failure, the status and trailers come from `processObserverError`.
      // The callback holds a strong reference to `self` until `statusPromise` completes.
      self.statusPromise.futureResult.whenComplete { result in
        switch result {
        case let .success(status):
          self.channel.writeAndFlush(
            self.wrap(.end(status, self.trailers)),
            promise: nil
          )
        case let .failure(error):
          let (status, trailers) = self.processObserverError(error, delegate: errorDelegate)
          self.channel.writeAndFlush(self.wrap(.end(status, trailers)), promise: nil)
        }
      }
    }

    /// Wrap the response part in a `NIOAny`. This is useful in order to avoid explicitly spelling
    /// out `NIOAny(WrappedResponse(...))`.
    private func wrap(_ response: WrappedResponse) -> NIOAny {
      return NIOAny(response)
    }

    /// Legacy initializer taking an HTTP/1 request head; its headers are converted to
    /// `HPACKHeaders` without normalization.
    @available(*, deprecated, renamed: "init(channel:headers:errorDelegate:logger:)")
    public convenience init(
      channel: Channel,
      request: HTTPRequestHead,
      errorDelegate: ServerErrorDelegate?,
      logger: Logger
    ) {
      self.init(
        channel: channel,
        headers: HPACKHeaders(httpHeaders: request.headers, normalizeHTTPHeaders: false),
        errorDelegate: errorDelegate,
        logger: logger
      )
    }

    /// Writes a single message to `channel`. The message metadata requests a flush.
    override open func sendResponse(
      _ message: ResponsePayload,
      compression: Compression = .deferToCallDefault,
      promise: EventLoopPromise<Void>?
    ) {
      let compress = compression.isEnabled(callDefault: self.compressionEnabled)
      self.channel.write(
        self.wrap(.message(message, .init(compress: compress, flush: true))),
        promise: promise
      )
    }

    /// Writes `messages` to `channel` in order. Only the last message requests a flush, and
    /// `promise` is attached to that message.
    ///
    /// If `messages` is empty nothing is written and `promise` is succeeded immediately.
    /// Otherwise it would never be completed.
    override open func sendResponses<Messages: Sequence>(
      _ messages: Messages,
      compression: Compression = .deferToCallDefault,
      promise: EventLoopPromise<Void>?
    ) where ResponsePayload == Messages.Element {
      let compress = compression.isEnabled(callDefault: self.compressionEnabled)
      var iterator = messages.makeIterator()

      guard var current = iterator.next() else {
        promise?.succeed(())
        return
      }

      // Look one element ahead so we know which message is the last one.
      while let following = iterator.next() {
        self.channel.write(
          self.wrap(.message(current, .init(compress: compress, flush: false))),
          promise: nil
        )
        current = following
      }

      // `current` is the final message: request a flush and attach the promise, if present.
      self.channel.write(
        self.wrap(.message(current, .init(compress: compress, flush: true))),
        promise: promise
      )
    }
  }
  /// Concrete implementation of `StreamingResponseCallContext` used for testing.
  ///
  /// Nothing is sent over the wire. Every response is appended to `recordedResponses`, and any
  /// supplied promise is succeeded right away. The `compression` argument is ignored.
  open class StreamingResponseCallContextTestStub<ResponsePayload>: StreamingResponseCallContext<ResponsePayload> {
    /// Every response passed to this context, in the order it was received.
    open var recordedResponses: [ResponsePayload] = []

    /// Records `message` and then succeeds `promise`, if one was given.
    override open func sendResponse(
      _ message: ResponsePayload,
      compression: Compression = .deferToCallDefault,
      promise: EventLoopPromise<Void>?
    ) {
      self.recordedResponses.append(message)
      if let completion = promise {
        completion.succeed(())
      }
    }

    /// Records all of `messages` in order and then succeeds `promise`, if one was given.
    override open func sendResponses<Messages: Sequence>(
      _ messages: Messages,
      compression: Compression = .deferToCallDefault,
      promise: EventLoopPromise<Void>?
    ) where ResponsePayload == Messages.Element {
      self.recordedResponses += messages
      if let completion = promise {
        completion.succeed(())
      }
    }
  }