StreamingResponseCallContext.swift 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303
  1. /*
  2. * Copyright 2019, gRPC Authors All rights reserved.
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. import Foundation
  17. import Logging
  18. import NIO
  19. import NIOHPACK
  20. import NIOHTTP1
  21. import SwiftProtobuf
  /// Abstract base class exposing a method to send multiple messages over the wire and a promise for the final RPC status.
  ///
  /// - When `statusPromise` is fulfilled, the call is closed and the provided status transmitted.
  /// - If `statusPromise` is failed and the error is of type `GRPCStatusTransformable`,
  ///   the result of `error.asGRPCStatus()` will be returned to the client.
  /// - If `error.asGRPCStatus()` is not available, `GRPCStatus.processingError` is returned to the client.
  ///
  /// Subclasses must override `sendResponse(_:compression:promise:)` and
  /// `sendResponses(_:compression:promise:)`; the implementations on this class call `fatalError`.
  open class StreamingResponseCallContext<ResponsePayload>: ServerCallContextBase {
    typealias WrappedResponse = GRPCServerResponsePart<ResponsePayload>

    /// Promise for the final status of the RPC, created on the context's event loop.
    public let statusPromise: EventLoopPromise<GRPCStatus>

    /// Creates a context for a call running on `eventLoop`.
    ///
    /// - Parameters:
    ///   - eventLoop: The event loop the call runs on; `statusPromise` is created on it.
    ///   - headers: The request headers for this call.
    ///   - logger: A logger.
    ///   - userInfo: Initial user info; it is wrapped in a new `Ref`. Defaults to `UserInfo()`.
    public convenience init(
      eventLoop: EventLoop,
      headers: HPACKHeaders,
      logger: Logger,
      userInfo: UserInfo = UserInfo()
    ) {
      self.init(eventLoop: eventLoop, headers: headers, logger: logger, userInfoRef: .init(userInfo))
    }

    /// Designated initializer taking a shared `Ref<UserInfo>` instead of a `UserInfo` value.
    override internal init(
      eventLoop: EventLoop,
      headers: HPACKHeaders,
      logger: Logger,
      userInfoRef: Ref<UserInfo>
    ) {
      self.statusPromise = eventLoop.makePromise()
      super.init(eventLoop: eventLoop, headers: headers, logger: logger, userInfoRef: userInfoRef)
    }

    /// Legacy initializer taking an HTTP/1 request head.
    ///
    /// The `renamed:` target is the convenience `init(eventLoop:headers:logger:userInfo:)` above.
    /// It previously pointed at a non-existent `path:` overload, which broke the deprecation fix-it.
    @available(*, deprecated, renamed: "init(eventLoop:headers:logger:userInfo:)")
    override public init(eventLoop: EventLoop, request: HTTPRequestHead, logger: Logger) {
      self.statusPromise = eventLoop.makePromise()
      super.init(eventLoop: eventLoop, request: request, logger: logger)
    }

    /// Send a response to the client.
    ///
    /// Subclasses must override this method; the base implementation traps.
    ///
    /// - Parameters:
    ///   - message: The message to send to the client.
    ///   - compression: Whether compression should be used for this response. If compression
    ///     is enabled in the call context, the value passed here takes precedence. Defaults to
    ///     deferring to the value set on the call context.
    ///   - promise: A promise to complete once the message has been sent.
    open func sendResponse(
      _ message: ResponsePayload,
      compression: Compression = .deferToCallDefault,
      promise: EventLoopPromise<Void>?
    ) {
      fatalError("needs to be overridden")
    }

    /// Send a response to the client.
    ///
    /// Convenience wrapper that creates a promise on `eventLoop` and forwards to
    /// `sendResponse(_:compression:promise:)`.
    ///
    /// - Parameters:
    ///   - message: The message to send to the client.
    ///   - compression: Whether compression should be used for this response. If compression
    ///     is enabled in the call context, the value passed here takes precedence. Defaults to
    ///     deferring to the value set on the call context.
    /// - Returns: A future completed by the overriding `sendResponse(_:compression:promise:)`.
    open func sendResponse(
      _ message: ResponsePayload,
      compression: Compression = .deferToCallDefault
    ) -> EventLoopFuture<Void> {
      let promise = self.eventLoop.makePromise(of: Void.self)
      self.sendResponse(message, compression: compression, promise: promise)
      return promise.futureResult
    }

    /// Sends a sequence of responses to the client.
    ///
    /// Subclasses must override this method; the base implementation traps.
    ///
    /// - Parameters:
    ///   - messages: The messages to send to the client.
    ///   - compression: Whether compression should be used for this response. If compression
    ///     is enabled in the call context, the value passed here takes precedence. Defaults to
    ///     deferring to the value set on the call context.
    ///   - promise: A promise to complete once the messages have been sent.
    open func sendResponses<Messages: Sequence>(
      _ messages: Messages,
      compression: Compression = .deferToCallDefault,
      promise: EventLoopPromise<Void>?
    ) where Messages.Element == ResponsePayload {
      fatalError("needs to be overridden")
    }

    /// Sends a sequence of responses to the client.
    ///
    /// Convenience wrapper that creates a promise on `eventLoop` and forwards to
    /// `sendResponses(_:compression:promise:)`.
    ///
    /// - Parameters:
    ///   - messages: The messages to send to the client.
    ///   - compression: Whether compression should be used for this response. If compression
    ///     is enabled in the call context, the value passed here takes precedence. Defaults to
    ///     deferring to the value set on the call context.
    /// - Returns: A future completed by the overriding `sendResponses(_:compression:promise:)`.
    open func sendResponses<Messages: Sequence>(
      _ messages: Messages,
      compression: Compression = .deferToCallDefault
    ) -> EventLoopFuture<Void> where Messages.Element == ResponsePayload {
      let promise = self.eventLoop.makePromise(of: Void.self)
      self.sendResponses(messages, compression: compression, promise: promise)
      return promise.futureResult
    }
  }
  /// Concrete `StreamingResponseCallContext` that forwards each response to a caller-supplied
  /// closure, always invoking that closure on the context's event loop.
  internal final class _StreamingResponseCallContext<Request, Response>:
    StreamingResponseCallContext<Response> {
    /// Receives each response, its compression/flush metadata and an optional promise.
    /// NOTE(review): presumably the closure completes the promise once the write finishes — confirm at the call site.
    private let _sendResponse: (Response, MessageMetadata, EventLoopPromise<Void>?) -> Void

    internal init(
      eventLoop: EventLoop,
      headers: HPACKHeaders,
      logger: Logger,
      userInfoRef: Ref<UserInfo>,
      sendResponse: @escaping (Response, MessageMetadata, EventLoopPromise<Void>?) -> Void
    ) {
      self._sendResponse = sendResponse
      super.init(eventLoop: eventLoop, headers: headers, logger: logger, userInfoRef: userInfoRef)
    }

    /// Sends a single response with `flush: true`, hopping onto the event loop if the caller is
    /// not already on it.
    override func sendResponse(
      _ message: Response,
      compression: Compression = .deferToCallDefault,
      promise: EventLoopPromise<Void>?
    ) {
      let compress = compression.isEnabled(callDefault: self.compressionEnabled)
      if self.eventLoop.inEventLoop {
        self._sendResponse(message, .init(compress: compress, flush: true), promise)
      } else {
        self.eventLoop.execute {
          self._sendResponse(message, .init(compress: compress, flush: true), promise)
        }
      }
    }

    /// Sends a sequence of responses, hopping onto the event loop if the caller is not already on it.
    override func sendResponses<Messages: Sequence>(
      _ messages: Messages,
      compression: Compression = .deferToCallDefault,
      promise: EventLoopPromise<Void>?
    ) where Response == Messages.Element {
      if self.eventLoop.inEventLoop {
        self._sendResponses(messages, compression: compression, promise: promise)
      } else {
        self.eventLoop.execute {
          self._sendResponses(messages, compression: compression, promise: promise)
        }
      }
    }

    /// Passes every element of `messages` to `_sendResponse`, requesting a flush only on the last one.
    ///
    /// Only invoked on the event loop (see `sendResponses`). `promise` is attached to the final
    /// message. If `messages` is empty, nothing is sent and `promise` is succeeded immediately.
    /// Otherwise the promise would never be completed and the future returned by
    /// `sendResponses(_:compression:)` would never resolve.
    private func _sendResponses<Messages: Sequence>(
      _ messages: Messages,
      compression: Compression,
      promise: EventLoopPromise<Void>?
    ) where Response == Messages.Element {
      let compress = compression.isEnabled(callDefault: self.compressionEnabled)
      var iterator = messages.makeIterator()
      var next = iterator.next()

      // Empty sequence: complete the promise rather than silently dropping it.
      if next == nil {
        promise?.succeed(())
        return
      }

      while let current = next {
        // Look one element ahead so the last message can be identified.
        next = iterator.next()
        // Attach the promise, if present, to the last message.
        let isLast = next == nil
        self._sendResponse(current, .init(compress: compress, flush: isLast), isLast ? promise : nil)
      }
    }
  }
  /// Concrete implementation of `StreamingResponseCallContext` used by our generated code.
  ///
  /// Responses are written to `channel`. When `statusPromise` completes, an `.end` part carrying
  /// the final status and trailers is written and flushed to `channel`.
  open class StreamingResponseCallContextImpl<ResponsePayload>: StreamingResponseCallContext<ResponsePayload> {
    /// The NIO channel the call is handled on.
    public let channel: Channel

    /// - Parameters:
    ///   - channel: The NIO channel the call is handled on.
    ///   - headers: The headers provided with this call.
    ///   - errorDelegate: Provides a means for transforming status promise failures to `GRPCStatusTransformable` before
    ///     sending them to the client.
    ///   - logger: A logger.
    ///
    /// Note: `errorDelegate` is not called for status promises that are `succeeded` with a non-OK status.
    public init(
      channel: Channel,
      headers: HPACKHeaders,
      errorDelegate: ServerErrorDelegate?,
      logger: Logger
    ) {
      self.channel = channel
      super.init(
        eventLoop: channel.eventLoop,
        headers: headers,
        logger: logger,
        userInfoRef: Ref(UserInfo())
      )

      // Terminate the RPC by writing `.end` once the status promise is completed.
      self.statusPromise.futureResult.whenComplete { result in
        switch result {
        case let .success(status):
          // A succeeded status is sent as-is, together with the context's trailers.
          self.channel.writeAndFlush(
            self.wrap(.end(status, self.trailers)),
            promise: nil
          )
        case let .failure(error):
          // A failure is converted to a status and trailers by `processObserverError`,
          // which is handed `errorDelegate`.
          let (status, trailers) = self.processObserverError(error, delegate: errorDelegate)
          self.channel.writeAndFlush(self.wrap(.end(status, trailers)), promise: nil)
        }
      }
    }

    /// Wrap the response part in a `NIOAny`. This is useful in order to avoid explicitly spelling
    /// out `NIOAny(WrappedResponse(...))`.
    private func wrap(_ response: WrappedResponse) -> NIOAny {
      return NIOAny(response)
    }

    /// Legacy initializer taking an HTTP/1 request head; its headers are converted to `HPACKHeaders`.
    @available(*, deprecated, renamed: "init(channel:headers:errorDelegate:logger:)")
    public convenience init(
      channel: Channel,
      request: HTTPRequestHead,
      errorDelegate: ServerErrorDelegate?,
      logger: Logger
    ) {
      self.init(
        channel: channel,
        headers: HPACKHeaders(httpHeaders: request.headers, normalizeHTTPHeaders: false),
        errorDelegate: errorDelegate,
        logger: logger
      )
    }

    /// Writes a single message to `channel` with `flush: true` in its metadata.
    /// `promise` is passed straight to `channel.write`.
    override open func sendResponse(
      _ message: ResponsePayload,
      compression: Compression = .deferToCallDefault,
      promise: EventLoopPromise<Void>?
    ) {
      let compress = compression.isEnabled(callDefault: self.compressionEnabled)
      self.channel.write(
        self.wrap(.message(message, .init(compress: compress, flush: true))),
        promise: promise
      )
    }

    /// Writes each message to `channel`, setting `flush: true` only on the last one, and attaches
    /// `promise` to that last write.
    ///
    /// If `messages` is empty nothing is written and `promise` is succeeded immediately. Otherwise
    /// the promise would never be completed and the future returned by
    /// `sendResponses(_:compression:)` would never resolve.
    override open func sendResponses<Messages: Sequence>(
      _ messages: Messages,
      compression: Compression = .deferToCallDefault,
      promise: EventLoopPromise<Void>?
    ) where ResponsePayload == Messages.Element {
      let compress = compression.isEnabled(callDefault: self.compressionEnabled)
      var iterator = messages.makeIterator()
      var next = iterator.next()

      // Empty sequence: complete the promise rather than silently dropping it.
      if next == nil {
        promise?.succeed(())
        return
      }

      while let current = next {
        // Look one element ahead so the last message can be identified.
        next = iterator.next()
        // Attach the promise, if present, to the last message.
        let isLast = next == nil
        self.channel.write(
          self.wrap(.message(current, .init(compress: compress, flush: isLast))),
          promise: isLast ? promise : nil
        )
      }
    }
  }
  /// Concrete implementation of `StreamingResponseCallContext` used for testing.
  ///
  /// Nothing is sent anywhere. Each message passed to `sendResponse` or `sendResponses` is
  /// appended to `recordedResponses`, and the supplied promise, if any, is succeeded straight away.
  open class StreamingResponseCallContextTestStub<ResponsePayload>: StreamingResponseCallContext<ResponsePayload> {
    /// Every response handed to this context, in the order it was sent.
    open var recordedResponses: [ResponsePayload] = []

    /// Records `message`; `compression` is ignored.
    override open func sendResponse(
      _ message: ResponsePayload,
      compression: Compression = .deferToCallDefault,
      promise: EventLoopPromise<Void>?
    ) {
      self.record(CollectionOfOne(message), completing: promise)
    }

    /// Records every element of `messages`; `compression` is ignored.
    override open func sendResponses<Messages: Sequence>(
      _ messages: Messages,
      compression: Compression = .deferToCallDefault,
      promise: EventLoopPromise<Void>?
    ) where ResponsePayload == Messages.Element {
      self.record(messages, completing: promise)
    }

    /// Appends `batch` to `recordedResponses`, then succeeds `promise` when one was supplied.
    private func record<Batch: Sequence>(
      _ batch: Batch,
      completing promise: EventLoopPromise<Void>?
    ) where Batch.Element == ResponsePayload {
      self.recordedResponses += batch
      if let promise = promise {
        promise.succeed(())
      }
    }
  }