WebCORSHandler.swift 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205
  1. /*
  2. * Copyright 2019, gRPC Authors All rights reserved.
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. import NIOCore
  17. import NIOHTTP1
  18. /// Handler that manages the CORS protocol for requests incoming from the browser.
  19. internal final class WebCORSHandler {
  20. let configuration: Server.Configuration.CORS
  21. private var state: State = .idle
  22. private enum State: Equatable {
  23. /// Starting state.
  24. case idle
  25. /// CORS preflight request is in progress.
  26. case processingPreflightRequest
  27. /// "Real" request is in progress.
  28. case processingRequest(origin: String?)
  29. }
  30. init(configuration: Server.Configuration.CORS) {
  31. self.configuration = configuration
  32. }
  33. }
  34. extension WebCORSHandler: ChannelInboundHandler {
  35. typealias InboundIn = HTTPServerRequestPart
  36. typealias InboundOut = HTTPServerRequestPart
  37. typealias OutboundOut = HTTPServerResponsePart
  38. func channelRead(context: ChannelHandlerContext, data: NIOAny) {
  39. switch self.unwrapInboundIn(data) {
  40. case let .head(head):
  41. self.receivedRequestHead(context: context, head)
  42. case let .body(body):
  43. self.receivedRequestBody(context: context, body)
  44. case let .end(trailers):
  45. self.receivedRequestEnd(context: context, trailers)
  46. }
  47. }
  48. private func receivedRequestHead(context: ChannelHandlerContext, _ head: HTTPRequestHead) {
  49. if head.method == .OPTIONS,
  50. head.headers.contains(.accessControlRequestMethod),
  51. let origin = head.headers.first(name: "origin") {
  52. // If the request is OPTIONS with a access-control-request-method header it's a CORS
  53. // preflight request and is not propagated further.
  54. self.state = .processingPreflightRequest
  55. self.handlePreflightRequest(context: context, head: head, origin: origin)
  56. } else {
  57. self.state = .processingRequest(origin: head.headers.first(name: "origin"))
  58. context.fireChannelRead(self.wrapInboundOut(.head(head)))
  59. }
  60. }
  61. private func receivedRequestBody(context: ChannelHandlerContext, _ body: ByteBuffer) {
  62. // OPTIONS requests do not have a body, but still handle this case to be
  63. // cautious.
  64. if self.state == .processingPreflightRequest {
  65. return
  66. }
  67. context.fireChannelRead(self.wrapInboundOut(.body(body)))
  68. }
  69. private func receivedRequestEnd(context: ChannelHandlerContext, _ trailers: HTTPHeaders?) {
  70. if self.state == .processingPreflightRequest {
  71. // End of OPTIONS request; reset state and finish the response.
  72. self.state = .idle
  73. context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil)
  74. } else {
  75. context.fireChannelRead(self.wrapInboundOut(.end(trailers)))
  76. }
  77. }
  78. private func handlePreflightRequest(
  79. context: ChannelHandlerContext,
  80. head: HTTPRequestHead,
  81. origin: String
  82. ) {
  83. let responseHead: HTTPResponseHead
  84. if let allowedOrigin = self.configuration.allowedOrigins.header(origin) {
  85. var headers = HTTPHeaders()
  86. headers.reserveCapacity(4 + self.configuration.allowedHeaders.count)
  87. headers.add(name: .accessControlAllowOrigin, value: allowedOrigin)
  88. headers.add(name: .accessControlAllowMethods, value: "POST")
  89. for value in self.configuration.allowedHeaders {
  90. headers.add(name: .accessControlAllowHeaders, value: value)
  91. }
  92. if self.configuration.allowCredentialedRequests {
  93. headers.add(name: .accessControlAllowCredentials, value: "true")
  94. }
  95. if self.configuration.preflightCacheExpiration > 0 {
  96. headers.add(
  97. name: .accessControlMaxAge,
  98. value: "\(self.configuration.preflightCacheExpiration)"
  99. )
  100. }
  101. responseHead = HTTPResponseHead(version: head.version, status: .ok, headers: headers)
  102. } else {
  103. // Not allowed; respond with 403. This is okay in a pre-flight request.
  104. responseHead = HTTPResponseHead(version: head.version, status: .forbidden)
  105. }
  106. context.write(self.wrapOutboundOut(.head(responseHead)), promise: nil)
  107. }
  108. }
  109. extension WebCORSHandler: ChannelOutboundHandler {
  110. typealias OutboundIn = HTTPServerResponsePart
  111. func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise<Void>?) {
  112. let responsePart = self.unwrapOutboundIn(data)
  113. switch responsePart {
  114. case var .head(responseHead):
  115. switch self.state {
  116. case let .processingRequest(origin):
  117. self.prepareCORSResponseHead(&responseHead, origin: origin)
  118. context.write(self.wrapOutboundOut(.head(responseHead)), promise: promise)
  119. case .idle, .processingPreflightRequest:
  120. assertionFailure("Writing response head when no request is in progress")
  121. context.close(promise: nil)
  122. }
  123. case .body:
  124. context.write(data, promise: promise)
  125. case .end:
  126. self.state = .idle
  127. context.write(data, promise: promise)
  128. }
  129. }
  130. private func prepareCORSResponseHead(_ head: inout HTTPResponseHead, origin: String?) {
  131. guard let header = origin.flatMap({ self.configuration.allowedOrigins.header($0) }) else {
  132. // No origin or the origin is not allowed; don't treat it as a CORS request.
  133. return
  134. }
  135. head.headers.replaceOrAdd(name: .accessControlAllowOrigin, value: header)
  136. if self.configuration.allowCredentialedRequests {
  137. head.headers.add(name: .accessControlAllowCredentials, value: "true")
  138. }
  139. //! FIXME: Check whether we can let browsers keep connections alive. It's not possible
  140. // now as the channel has a state that can't be reused since the pipeline is modified to
  141. // inject the gRPC call handler.
  142. head.headers.replaceOrAdd(name: "Connection", value: "close")
  143. }
  144. }
  145. extension HTTPHeaders {
  146. fileprivate enum CORSHeader: String {
  147. case accessControlRequestMethod = "access-control-request-method"
  148. case accessControlRequestHeaders = "access-control-request-headers"
  149. case accessControlAllowOrigin = "access-control-allow-origin"
  150. case accessControlAllowMethods = "access-control-allow-methods"
  151. case accessControlAllowHeaders = "access-control-allow-headers"
  152. case accessControlAllowCredentials = "access-control-allow-credentials"
  153. case accessControlMaxAge = "access-control-max-age"
  154. }
  155. fileprivate func contains(_ name: CORSHeader) -> Bool {
  156. return self.contains(name: name.rawValue)
  157. }
  158. fileprivate mutating func add(name: CORSHeader, value: String) {
  159. self.add(name: name.rawValue, value: value)
  160. }
  161. fileprivate mutating func replaceOrAdd(name: CORSHeader, value: String) {
  162. self.replaceOrAdd(name: name.rawValue, value: value)
  163. }
  164. }
  165. extension Server.Configuration.CORS.AllowedOrigins {
  166. internal func header(_ origin: String) -> String? {
  167. switch self.wrapped {
  168. case .all:
  169. return "*"
  170. case let .only(allowed):
  171. return allowed.contains(origin) ? origin : nil
  172. }
  173. }
  174. }