2
0

WebCORSHandler.swift 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210
  1. /*
  2. * Copyright 2019, gRPC Authors All rights reserved.
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. import NIOCore
  17. import NIOHTTP1
  /// Handler that implements the CORS protocol for requests arriving from a browser.
  ///
  /// Preflight (`OPTIONS`) requests are answered directly by this handler; the response head
  /// of every other request is decorated with `access-control-*` headers when its origin is
  /// allowed by `configuration`.
  internal final class WebCORSHandler {
    /// Lifecycle of the request currently flowing through the handler.
    private enum State: Equatable {
      /// Nothing in flight: either no request has been seen yet, or the last one finished.
      case idle
      /// A CORS preflight request is being answered by this handler.
      case processingPreflightRequest
      /// A "real" request is being forwarded; carries its `origin` header value, if any.
      case processingRequest(origin: String?)
    }

    /// CORS policy consulted for both preflight and regular requests.
    let configuration: Server.Configuration.CORS

    /// Where the handler currently is in the request lifecycle.
    private var state: State = .idle

    init(configuration: Server.Configuration.CORS) {
      self.configuration = configuration
    }
  }
  extension WebCORSHandler: ChannelInboundHandler {
    typealias InboundIn = HTTPServerRequestPart
    typealias InboundOut = HTTPServerRequestPart
    typealias OutboundOut = HTTPServerResponsePart

    /// Routes each inbound HTTP request part to the matching handler below.
    func channelRead(context: ChannelHandlerContext, data: NIOAny) {
      let part = self.unwrapInboundIn(data)
      switch part {
      case .head(let requestHead):
        self.receivedRequestHead(context: context, requestHead)
      case .body(let buffer):
        self.receivedRequestBody(context: context, buffer)
      case .end(let trailers):
        self.receivedRequestEnd(context: context, trailers)
      }
    }

    /// Classifies a request head as either a CORS preflight or a regular request.
    ///
    /// An `OPTIONS` request that carries both an `access-control-request-method` header and an
    /// `origin` header is a preflight: it is answered here and not propagated further. Anything
    /// else is forwarded down the pipeline, and its origin (if any) is remembered so that the
    /// response head can be decorated later.
    private func receivedRequestHead(context: ChannelHandlerContext, _ head: HTTPRequestHead) {
      let requestOrigin = head.headers.first(name: "origin")
      let looksLikePreflight =
        head.method == .OPTIONS && head.headers.contains(.accessControlRequestMethod)

      if looksLikePreflight, let preflightOrigin = requestOrigin {
        self.state = .processingPreflightRequest
        self.handlePreflightRequest(context: context, head: head, origin: preflightOrigin)
        return
      }

      self.state = .processingRequest(origin: requestOrigin)
      context.fireChannelRead(self.wrapInboundOut(.head(head)))
    }

    /// Forwards body bytes of regular requests. Bytes that belong to a preflight request are
    /// dropped; `OPTIONS` requests shouldn't carry a body, but the case is handled defensively.
    private func receivedRequestBody(context: ChannelHandlerContext, _ body: ByteBuffer) {
      guard self.state != .processingPreflightRequest else { return }
      context.fireChannelRead(self.wrapInboundOut(.body(body)))
    }

    /// Finishes the preflight response when a preflight request ends; otherwise forwards the
    /// end-of-request part down the pipeline.
    private func receivedRequestEnd(context: ChannelHandlerContext, _ trailers: HTTPHeaders?) {
      guard self.state == .processingPreflightRequest else {
        context.fireChannelRead(self.wrapInboundOut(.end(trailers)))
        return
      }

      // The preflight exchange is complete: reset, then terminate and flush the response whose
      // head was written by `handlePreflightRequest`.
      self.state = .idle
      context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil)
    }

    /// Writes (without flushing) the response head for a preflight request.
    ///
    /// If `origin` is permitted by the configuration, the response is `200 OK` with the
    /// `access-control-*` headers. Otherwise it is a bare `403 Forbidden`, which is an acceptable
    /// answer to a preflight request.
    private func handlePreflightRequest(
      context: ChannelHandlerContext,
      head: HTTPRequestHead,
      origin: String
    ) {
      guard let allowedOrigin = self.configuration.allowedOrigins.header(origin) else {
        // Not allowed; respond with 403.
        let forbidden = HTTPResponseHead(version: head.version, status: .forbidden)
        context.write(self.wrapOutboundOut(.head(forbidden)), promise: nil)
        return
      }

      let headers = self.makePreflightHeaders(allowedOrigin: allowedOrigin)
      let accepted = HTTPResponseHead(version: head.version, status: .ok, headers: headers)
      context.write(self.wrapOutboundOut(.head(accepted)), promise: nil)
    }

    /// Builds the `access-control-*` headers sent back for an allowed preflight request.
    private func makePreflightHeaders(allowedOrigin: String) -> HTTPHeaders {
      let cors = self.configuration
      var headers = HTTPHeaders()
      // At most four fixed headers, plus one per allowed request header.
      headers.reserveCapacity(4 + cors.allowedHeaders.count)
      headers.add(name: .accessControlAllowOrigin, value: allowedOrigin)
      headers.add(name: .accessControlAllowMethods, value: "POST")
      cors.allowedHeaders.forEach { headers.add(name: .accessControlAllowHeaders, value: $0) }

      if cors.allowCredentialedRequests {
        headers.add(name: .accessControlAllowCredentials, value: "true")
      }

      if cors.preflightCacheExpiration > 0 {
        headers.add(name: .accessControlMaxAge, value: "\(cors.preflightCacheExpiration)")
      }

      return headers
    }
  }
  extension WebCORSHandler: ChannelOutboundHandler {
    typealias OutboundIn = HTTPServerResponsePart

    /// Passes response parts through, decorating the head of a regular request with CORS
    /// headers. Writing `.end` returns the handler to `.idle`.
    ///
    /// A response head written while no regular request is in progress is a programming error.
    /// In that case, the head is not written, its promise is failed, and the channel is closed.
    func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise<Void>?) {
      let responsePart = self.unwrapOutboundIn(data)
      switch responsePart {
      case var .head(responseHead):
        switch self.state {
        case let .processingRequest(origin):
          self.prepareCORSResponseHead(&responseHead, origin: origin)
          context.write(self.wrapOutboundOut(.head(responseHead)), promise: promise)

        case .idle, .processingPreflightRequest:
          assertionFailure("Writing response head when no request is in progress")
          // The head is dropped, so its promise must still be completed. `assertionFailure` is a
          // no-op in release builds, and a dropped promise would never complete its future.
          promise?.fail(ChannelError.inappropriateOperationForState)
          context.close(promise: nil)
        }

      case .body:
        context.write(data, promise: promise)

      case .end:
        self.state = .idle
        context.write(data, promise: promise)
      }
    }

    /// Adds CORS response headers to `head` when `origin` is present and allowed.
    ///
    /// Responses to requests without an origin, or from a disallowed origin, are left untouched
    /// and are therefore not treated as CORS responses.
    private func prepareCORSResponseHead(_ head: inout HTTPResponseHead, origin: String?) {
      guard let header = origin.flatMap({ self.configuration.allowedOrigins.header($0) }) else {
        // No origin or the origin is not allowed; don't treat it as a CORS request.
        return
      }

      head.headers.replaceOrAdd(name: .accessControlAllowOrigin, value: header)
      if self.configuration.allowCredentialedRequests {
        head.headers.add(name: .accessControlAllowCredentials, value: "true")
      }

      // FIXME: Check whether we can let browsers keep connections alive. This isn't possible
      // yet because the channel has state that can't be reused: the pipeline is modified to
      // inject the gRPC call handler.
      head.headers.replaceOrAdd(name: "Connection", value: "close")
    }
  }
  extension HTTPHeaders {
    /// Typed names for the CORS-related headers used in this file; each raw value is the
    /// header name as sent on the wire.
    fileprivate enum CORSHeader: String {
      case accessControlRequestMethod = "access-control-request-method"
      case accessControlRequestHeaders = "access-control-request-headers"
      case accessControlAllowOrigin = "access-control-allow-origin"
      case accessControlAllowMethods = "access-control-allow-methods"
      case accessControlAllowHeaders = "access-control-allow-headers"
      case accessControlAllowCredentials = "access-control-allow-credentials"
      case accessControlMaxAge = "access-control-max-age"
    }

    /// Returns `true` when a header with the given CORS name is present.
    fileprivate func contains(_ header: CORSHeader) -> Bool {
      let headerName = header.rawValue
      return self.contains(name: headerName)
    }

    /// Appends `value` under the given CORS header name.
    fileprivate mutating func add(name header: CORSHeader, value: String) {
      let headerName = header.rawValue
      self.add(name: headerName, value: value)
    }

    /// Sets the given CORS header to `value`, replacing any values already present for it.
    fileprivate mutating func replaceOrAdd(name header: CORSHeader, value: String) {
      let headerName = header.rawValue
      self.replaceOrAdd(name: headerName, value: value)
    }
  }
  extension Server.Configuration.CORS.AllowedOrigins {
    /// Returns the `access-control-allow-origin` value for a request from `origin`, or `nil`
    /// when that origin is not allowed.
    ///
    /// - `.all` yields the wildcard `"*"`.
    /// - `.originBased` echoes `origin` back unconditionally.
    /// - `.only` echoes `origin` back only if it appears in the allow-list.
    /// - `.custom` defers entirely to the custom checker.
    internal func header(_ origin: String) -> String? {
      let value: String?
      switch self.wrapped {
      case .all:
        value = "*"
      case .originBased:
        value = origin
      case .only(let allowList):
        value = allowList.contains(origin) ? origin : nil
      case .custom(let checker):
        value = checker.check(origin: origin)
      }
      return value
    }
  }