GRPCClientChannelHandler.swift 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658
  1. /*
  2. * Copyright 2019, gRPC Authors All rights reserved.
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. import Logging
  17. import NIOCore
  18. import NIOHPACK
  19. import NIOHTTP1
  20. import NIOHTTP2
  21. import SwiftProtobuf
  /// One part of an outbound gRPC request, as written by the client.
  ///
  /// - Important: This is **NOT** part of the public API. It is declared as
  ///   `public` because it is used within performance tests.
  public enum _GRPCClientRequestPart<Request> {
    /// The request head: the information used to initiate the RPC.
    case head(_GRPCRequestHead)

    /// A single request message for the server, wrapped in its message context.
    case message(_MessageContext<Request>)

    /// Signals that the client has no further request messages to send.
    case end
  }

  /// The `_GRPCClientRequestPart` variant whose messages are already serialized into bytes.
  ///
  /// - Important: This is **NOT** part of the public API.
  public typealias _RawGRPCClientRequestPart = _GRPCClientRequestPart<ByteBuffer>
  /// One part of an inbound gRPC response, as surfaced to the client.
  ///
  /// - Important: This is **NOT** part of the public API.
  public enum _GRPCClientResponsePart<Response> {
    /// Metadata sent by the server when it acknowledges the RPC.
    case initialMetadata(HPACKHeaders)

    /// A single response message from the server, wrapped in its message context.
    case message(_MessageContext<Response>)

    /// Metadata sent by the server when the RPC ends.
    case trailingMetadata(HPACKHeaders)

    /// The status the RPC finished with.
    case status(GRPCStatus)
  }

  /// The `_GRPCClientResponsePart` variant whose messages are still serialized bytes.
  ///
  /// - Important: This is **NOT** part of the public API.
  public typealias _RawGRPCClientResponsePart = _GRPCClientResponsePart<ByteBuffer>
  /// The head of a gRPC request: the information needed to initiate the RPC.
  ///
  /// Every field except `customMetadata` lives in a class-backed `_Storage` box. Setters replace
  /// the box with a copy before mutating it whenever it is not uniquely referenced, so copies of a
  /// request head never observe each other's mutations.
  ///
  /// - Important: This is **NOT** part of the public API. It is declared as
  ///   `public` because it is used within performance tests.
  public struct _GRPCRequestHead {
    /// Reference-typed box holding the request head's fields.
    private final class _Storage {
      var method: String
      var scheme: String
      var path: String
      var host: String
      var deadline: NIODeadline
      var encoding: ClientMessageEncoding

      init(
        method: String,
        scheme: String,
        path: String,
        host: String,
        deadline: NIODeadline,
        encoding: ClientMessageEncoding
      ) {
        self.method = method
        self.scheme = scheme
        self.path = path
        self.host = host
        self.deadline = deadline
        self.encoding = encoding
      }

      /// Returns a new, independent box holding the same field values.
      func copy() -> _Storage {
        return _Storage(
          method: method,
          scheme: scheme,
          path: path,
          host: host,
          deadline: deadline,
          encoding: encoding
        )
      }
    }

    private var _storage: _Storage

    // Deliberately stored inline rather than in `_storage`: keeping it in the box would route
    // every metadata mutation through the storage copy-on-write path.
    internal var customMetadata: HPACKHeaders

    /// Replaces `_storage` with a private copy when it is shared with another value, so the
    /// caller can mutate it without affecting other copies of this request head.
    private mutating func _uniquifyStorage() {
      if isKnownUniquelyReferenced(&self._storage) {
        return
      }
      self._storage = self._storage.copy()
    }

    internal var method: String {
      get { return self._storage.method }
      set {
        self._uniquifyStorage()
        self._storage.method = newValue
      }
    }

    internal var scheme: String {
      get { return self._storage.scheme }
      set {
        self._uniquifyStorage()
        self._storage.scheme = newValue
      }
    }

    internal var path: String {
      get { return self._storage.path }
      set {
        self._uniquifyStorage()
        self._storage.path = newValue
      }
    }

    internal var host: String {
      get { return self._storage.host }
      set {
        self._uniquifyStorage()
        self._storage.host = newValue
      }
    }

    internal var deadline: NIODeadline {
      get { return self._storage.deadline }
      set {
        self._uniquifyStorage()
        self._storage.deadline = newValue
      }
    }

    internal var encoding: ClientMessageEncoding {
      get { return self._storage.encoding }
      set {
        self._uniquifyStorage()
        self._storage.encoding = newValue
      }
    }

    public init(
      method: String,
      scheme: String,
      path: String,
      host: String,
      deadline: NIODeadline,
      customMetadata: HPACKHeaders,
      encoding: ClientMessageEncoding
    ) {
      self.customMetadata = customMetadata
      self._storage = _Storage(
        method: method,
        scheme: scheme,
        path: path,
        host: host,
        deadline: deadline,
        encoding: encoding
      )
    }
  }
  extension _GRPCRequestHead {
    /// Creates a request head from per-call options.
    ///
    /// - Parameters:
    ///   - scheme: The request scheme.
    ///   - path: The request path.
    ///   - host: The request host.
    ///   - options: Supplies the custom metadata, the request ID header name, whether the call is
    ///     cacheable (`GET` when it is, `POST` otherwise), the time limit used to derive the
    ///     deadline, and the message encoding.
    ///   - requestID: Added to the metadata under `options.requestIDHeader`, but only when both
    ///     the ID and the header name are non-nil.
    internal init(
      scheme: String,
      path: String,
      host: String,
      options: CallOptions,
      requestID: String?
    ) {
      var metadata = options.customMetadata
      if let requestID = requestID, let requestIDHeader = options.requestIDHeader {
        metadata.add(name: requestIDHeader, value: requestID)
      }

      self.init(
        method: options.cacheable ? "GET" : "POST",
        scheme: scheme,
        path: path,
        host: host,
        deadline: options.timeLimit.makeDeadline(),
        customMetadata: metadata,
        encoding: options.messageEncoding
      )
    }
  }
  /// The kind of gRPC call, distinguished by how many messages flow in each direction.
  public enum GRPCCallType: Hashable, Sendable {
    /// Unary: a single request and a single response.
    case unary

    /// Client streaming: many requests and a single response.
    case clientStreaming

    /// Server streaming: a single request and many responses.
    case serverStreaming

    /// Bidirectional streaming: many requests and many responses.
    case bidirectionalStreaming

    /// Whether the client may send many request messages for this kind of call.
    public var isStreamingRequests: Bool {
      switch self {
      case .unary, .serverStreaming:
        return false
      case .clientStreaming, .bidirectionalStreaming:
        return true
      }
    }

    /// Whether the server may send many response messages for this kind of call.
    public var isStreamingResponses: Bool {
      switch self {
      case .unary, .clientStreaming:
        return false
      case .serverStreaming, .bidirectionalStreaming:
        return true
      }
    }
  }
  // MARK: - GRPCClientChannelHandler

  /// A channel handler for gRPC clients which translates HTTP/2 frames into gRPC messages.
  ///
  /// Inbound, it reads `HTTP2Frame.FramePayload`s and emits `_RawGRPCClientResponsePart`s.
  /// Outbound, it accepts `_RawGRPCClientRequestPart`s and writes `HTTP2Frame.FramePayload`s.
  ///
  /// This channel handler should typically be used in conjunction with another handler which
  /// reads the parsed `_RawGRPCClientResponsePart` messages and surfaces them to the caller in
  /// some fashion. Note that for unary and client streaming RPCs this handler will only emit at
  /// most one response message.
  ///
  /// This handler relies heavily on the `GRPCClientStateMachine` to manage the state of the request
  /// and response streams, which share a single HTTP/2 stream for transport.
  ///
  /// Typical usage of this handler is with a `HTTP2StreamMultiplexer` from SwiftNIO HTTP2:
  ///
  /// ```
  /// let multiplexer: HTTP2StreamMultiplexer = // ...
  /// multiplexer.createStreamChannel(promise: nil) { channel in
  ///   let clientChannelHandler = GRPCClientChannelHandler(
  ///     callType: callType,
  ///     maximumReceiveMessageLength: maximumReceiveMessageLength,
  ///     logger: logger
  ///   )
  ///   return channel.pipeline.addHandler(clientChannelHandler)
  /// }
  /// ```
  internal final class GRPCClientChannelHandler {
    private let logger: Logger
    private var stateMachine: GRPCClientStateMachine
    private let maximumReceiveMessageLength: Int

    /// Creates a new gRPC channel handler for clients to translate HTTP/2 frames to gRPC messages.
    ///
    /// - Parameters:
    ///   - callType: Type of RPC call being made. It selects whether the state machine is set up
    ///     for one or many request messages and one or many response messages.
    ///   - maximumReceiveMessageLength: Maximum allowed length in bytes of a received message.
    ///   - logger: Logger used to trace the frames read and written by this handler.
    internal init(
      callType: GRPCCallType,
      maximumReceiveMessageLength: Int,
      logger: Logger
    ) {
      self.logger = logger
      self.maximumReceiveMessageLength = maximumReceiveMessageLength

      switch callType {
      case .unary:
        self.stateMachine = .init(requestArity: .one, responseArity: .one)
      case .clientStreaming:
        self.stateMachine = .init(requestArity: .many, responseArity: .one)
      case .serverStreaming:
        self.stateMachine = .init(requestArity: .one, responseArity: .many)
      case .bidirectionalStreaming:
        self.stateMachine = .init(requestArity: .many, responseArity: .many)
      }
    }
  }
  // MARK: - GRPCClientChannelHandler: Inbound

  extension GRPCClientChannelHandler: ChannelInboundHandler {
    internal typealias InboundIn = HTTP2Frame.FramePayload
    internal typealias InboundOut = _RawGRPCClientResponsePart

    /// Dispatches an inbound HTTP/2 frame payload.
    ///
    /// HEADERS and DATA payloads are decoded; every other payload type is dropped.
    internal func channelRead(context: ChannelHandlerContext, data: NIOAny) {
      switch self.unwrapInboundIn(data) {
      case let .headers(headersPayload):
        self.readHeaders(content: headersPayload, context: context)
      case let .data(dataPayload):
        self.readData(content: dataPayload, context: context)
      default:
        // Other frame types don't need handling, so they're dropped.
        // TODO: synthesise a more precise `GRPCStatus` from RST_STREAM frames in accordance
        // with: https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#errors
        break
      }
    }

    /// Handles the content of a HEADERS frame received from the server.
    ///
    /// Headers arrive either when the server acknowledges the RPC (initial metadata) or when it
    /// terminates the RPC (trailers). One frame may do both; the specification calls this a
    /// "Trailers-Only" response.
    ///
    /// - Parameters:
    ///   - content: The HEADERS frame payload.
    ///   - context: The channel handler context.
    private func readHeaders(
      content: HTTP2Frame.FramePayload.Headers,
      context: ChannelHandlerContext
    ) {
      let headers = content.headers

      self.logger.trace(
        "received HTTP2 frame",
        metadata: [
          MetadataKey.h2Payload: "HEADERS",
          MetadataKey.h2Headers: "\(headers)",
          MetadataKey.h2EndStream: "\(content.endStream)",
        ]
      )

      // End-of-stream isn't guaranteed to be set on the HEADERS frame of a "Trailers-Only"
      // response: it may arrive on a following empty DATA frame instead. So the presence of a gRPC
      // status code also marks these headers as trailers.
      let isEndOfResponse = content.endStream || headers.contains(name: GRPCHeaderName.statusCode)

      guard isEndOfResponse else {
        // Initial metadata: forward it only once the state machine has accepted it.
        switch self.stateMachine.receiveResponseHeaders(headers) {
        case .success:
          context.fireChannelRead(self.wrapInboundOut(.initialMetadata(headers)))
        case let .failure(.invalidContentType(contentType)):
          context.fireErrorCaught(GRPCError.InvalidContentType(contentType).captureContext())
        case let .failure(.invalidHTTPStatus(status)):
          context.fireErrorCaught(GRPCError.InvalidHTTPStatus(status).captureContext())
        case .failure(.unsupportedMessageEncoding):
          context.fireErrorCaught(GRPCError.CompressionUnsupported().captureContext())
        case .failure(.invalidState):
          context.fireErrorCaught(GRPCError.InvalidState("parsing headers").captureContext())
        }
        return
      }

      // Trailers are forwarded before they are validated.
      context.fireChannelRead(self.wrapInboundOut(.trailingMetadata(headers)))

      switch self.stateMachine.receiveEndOfResponseStream(headers) {
      case let .success(status):
        context.fireChannelRead(self.wrapInboundOut(.status(status)))
      case let .failure(.invalidContentType(contentType)):
        context.fireErrorCaught(GRPCError.InvalidContentType(contentType).captureContext())
      case let .failure(.invalidHTTPStatus(status)):
        context.fireErrorCaught(GRPCError.InvalidHTTPStatus(status).captureContext())
      case .failure(.invalidState):
        context.fireErrorCaught(
          GRPCError.InvalidState("parsing end-of-stream trailers").captureContext()
        )
      }
    }

    /// Handles the content of a DATA frame received from the server.
    ///
    /// The frame's bytes are handed to `consumeBytes(from:context:)`. If the frame also carries
    /// end-of-stream, the state machine is asked for a final status, which is forwarded when one
    /// is available.
    ///
    /// - Parameters:
    ///   - content: The DATA frame payload.
    ///   - context: The channel handler context.
    private func readData(content: HTTP2Frame.FramePayload.Data, context: ChannelHandlerContext) {
      // Note: this is replicated from NIO's HTTP2ToHTTP1ClientCodec.
      guard case var .byteBuffer(buffer) = content.data else {
        preconditionFailure("Received DATA frame with non-ByteBuffer IOData")
      }

      self.logger.trace(
        "received HTTP2 frame",
        metadata: [
          MetadataKey.h2Payload: "DATA",
          MetadataKey.h2DataBytes: "\(buffer.readableBytes)",
          MetadataKey.h2EndStream: "\(content.endStream)",
        ]
      )

      self.consumeBytes(from: &buffer, context: context)

      // End stream on a DATA frame isn't usually expected, but can be handled in some situations.
      guard content.endStream else {
        return
      }
      if let status = self.stateMachine.receiveEndOfResponseStream() {
        self.logger.warning("Unexpected end stream set on DATA frame")
        context.fireChannelRead(self.wrapInboundOut(.status(status)))
      }
    }

    /// Feeds `buffer` into the state machine and forwards each message it produces; a state
    /// machine failure is mapped to a `GRPCError` and fired down the pipeline instead.
    ///
    /// - Parameters:
    ///   - buffer: Bytes from a DATA frame; passed `inout` to the state machine.
    ///   - context: The channel handler context.
    private func consumeBytes(from buffer: inout ByteBuffer, context: ChannelHandlerContext) {
      // Nothing to parse in an empty buffer. This happens when the trailing HEADERS frame (the one
      // containing the gRPC status code) doesn't set end-of-stream and the server follows it with
      // an empty DATA frame that does.
      if buffer.readableBytes == 0 {
        return
      }

      let result = self.stateMachine.receiveResponseBuffer(
        &buffer,
        maxMessageLength: self.maximumReceiveMessageLength
      )

      switch result {
      case let .success(messages):
        // The state machine guarantees at most a single message for unary and client-streaming
        // RPCs.
        for message in messages {
          // Note: `compressed: false` is currently just a placeholder. This is fine since the
          // message context is not currently exposed to the user. If client interceptors ever
          // surface this information then it must be extracted from the message reader.
          context.fireChannelRead(self.wrapInboundOut(.message(.init(message, compressed: false))))
        }
      case .failure(.cardinalityViolation):
        context.fireErrorCaught(GRPCError.StreamCardinalityViolation.response.captureContext())
      case .failure(.deserializationFailed), .failure(.leftOverBytes):
        context.fireErrorCaught(GRPCError.DeserializationFailure().captureContext())
      case let .failure(.decompressionLimitExceeded(compressedSize)):
        context.fireErrorCaught(
          GRPCError.DecompressionLimitExceeded(compressedSize: compressedSize).captureContext()
        )
      case let .failure(.lengthExceedsLimit(underlyingError)):
        context.fireErrorCaught(underlyingError.captureContext())
      case .failure(.invalidState):
        context.fireErrorCaught(
          GRPCError.InvalidState("parsing data as a response message").captureContext()
        )
      }
    }
  }
  // MARK: - GRPCClientChannelHandler: Outbound

  extension GRPCClientChannelHandler: ChannelOutboundHandler {
    internal typealias OutboundIn = _RawGRPCClientRequestPart
    internal typealias OutboundOut = HTTP2Frame.FramePayload

    /// Translates an outbound request part into HTTP/2 frame payloads.
    ///
    /// - `.head` is written immediately as a HEADERS frame.
    /// - `.message` is handed to the state machine, together with `promise`. It is written as a
    ///   DATA frame when the channel is next flushed, or when `.end` is written.
    /// - `.end` first writes any messages still held by the state machine. It then writes an empty
    ///   DATA frame with end-stream set. If a held message can't be written, the request stream
    ///   is left open and `promise` is failed with that message's error.
    internal func write(
      context: ChannelHandlerContext,
      data: NIOAny,
      promise: EventLoopPromise<Void>?
    ) {
      switch self.unwrapOutboundIn(data) {
      case let .head(requestHead):
        // Feed the request into the state machine:
        switch self.stateMachine.sendRequestHeaders(
          requestHead: requestHead,
          allocator: context.channel.allocator
        ) {
        case let .success(headers):
          // We're clear to write some headers. Create an appropriate frame and write it.
          let framePayload = HTTP2Frame.FramePayload.headers(.init(headers: headers))
          self.logger.trace(
            "writing HTTP2 frame",
            metadata: [
              MetadataKey.h2Payload: "HEADERS",
              MetadataKey.h2Headers: "\(headers)",
              MetadataKey.h2EndStream: "false",
            ]
          )
          context.write(self.wrapOutboundOut(framePayload), promise: promise)

        case let .failure(sendRequestHeadersError):
          switch sendRequestHeadersError {
          case .invalidState:
            // This is bad: we need to trigger an error and close the channel.
            promise?.fail(sendRequestHeadersError)
            context.fireErrorCaught(GRPCError.InvalidState("unable to initiate RPC").captureContext())
          }
        }

      case let .message(request):
        // Feed the request message into the state machine; it's written on the next flush.
        let result = self.stateMachine.sendRequest(
          request.message,
          compressed: request.compressed,
          promise: promise
        )

        switch result {
        case .success:
          ()

        case let .failure(writeError):
          switch writeError {
          case .cardinalityViolation:
            // This is fine: we can ignore the request. The RPC can continue as if nothing went wrong.
            promise?.fail(writeError)

          case .serializationFailed:
            // This is bad: we need to trigger an error and close the channel.
            promise?.fail(writeError)
            context.fireErrorCaught(GRPCError.SerializationFailure().captureContext())

          case .invalidState:
            promise?.fail(writeError)
            context
              .fireErrorCaught(GRPCError.InvalidState("unable to write message").captureContext())
          }
        }

      case .end:
        // About to send end: write any outbound messages first. If one of them can't be written
        // then don't close the request stream, but still complete the promise for this write;
        // returning without doing so would leave it uncompleted forever.
        if let error = self.writeBufferedRequests(context: context) {
          promise?.fail(error)
          return
        }

        // Okay: can we close the request stream?
        switch self.stateMachine.sendEndOfRequestStream() {
        case .success:
          // We can. Send an empty DATA frame with end-stream set.
          let empty = context.channel.allocator.buffer(capacity: 0)
          let framePayload: HTTP2Frame.FramePayload = .data(
            .init(data: .byteBuffer(empty), endStream: true)
          )
          self.logger.trace(
            "writing HTTP2 frame",
            metadata: [
              MetadataKey.h2Payload: "DATA",
              MetadataKey.h2DataBytes: "0",
              MetadataKey.h2EndStream: "true",
            ]
          )
          context.write(self.wrapOutboundOut(framePayload), promise: promise)

        case let .failure(error):
          // Why can't we close the request stream?
          switch error {
          case .alreadyClosed:
            // This is fine: we can just ignore it. The RPC can continue as if nothing went wrong.
            promise?.fail(error)

          case .invalidState:
            // This is bad: we need to trigger an error and close the channel.
            promise?.fail(error)
            context
              .fireErrorCaught(
                GRPCError.InvalidState("unable to close request stream")
                  .captureContext()
              )
          }
        }
      }
    }

    /// Writes any request messages held by the state machine, then flushes the channel.
    ///
    /// The flush happens even if a held message fails to be written. Frames written before the
    /// failure, including the HEADERS frame written in `write(context:data:promise:)`, must still
    /// be flushed.
    internal func flush(context: ChannelHandlerContext) {
      self.writeBufferedRequests(context: context)
      context.flush()
    }

    /// Writes each request message held by the state machine as a DATA frame without end-stream
    /// set. Each write is given the promise that was stored alongside its message.
    ///
    /// Draining stops at the first message the state machine reports as a failure. That error is
    /// fired down the pipeline, used to fail that message's promise, and returned. Messages after
    /// it are not drained by this call.
    ///
    /// - Parameter context: The channel handler context to write to.
    /// - Returns: The error which stopped the drain, or `nil` if every held message was written.
    @discardableResult
    private func writeBufferedRequests(context: ChannelHandlerContext) -> Error? {
      while let (result, messagePromise) = self.stateMachine.nextRequest() {
        switch result {
        case let .success(buffer):
          let framePayload: HTTP2Frame.FramePayload = .data(
            .init(data: .byteBuffer(buffer), endStream: false)
          )
          self.logger.trace(
            "writing HTTP2 frame",
            metadata: [
              MetadataKey.h2Payload: "DATA",
              MetadataKey.h2DataBytes: "\(buffer.readableBytes)",
              MetadataKey.h2EndStream: "false",
            ]
          )
          context.write(self.wrapOutboundOut(framePayload), promise: messagePromise)

        case let .failure(error):
          context.fireErrorCaught(error)
          messagePromise?.fail(error)
          return error
        }
      }
      return nil
    }
  }