| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210 |
- /*
- * Copyright 2019, gRPC Authors All rights reserved.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
- import NIOCore
- import NIOHTTP1
- /// Handler that manages the CORS protocol for requests incoming from the browser.
- internal final class WebCORSHandler {
- let configuration: Server.Configuration.CORS
- private var state: State = .idle
- private enum State: Equatable {
- /// Starting state.
- case idle
- /// CORS preflight request is in progress.
- case processingPreflightRequest
- /// "Real" request is in progress.
- case processingRequest(origin: String?)
- }
- init(configuration: Server.Configuration.CORS) {
- self.configuration = configuration
- }
- }
- extension WebCORSHandler: ChannelInboundHandler {
- typealias InboundIn = HTTPServerRequestPart
- typealias InboundOut = HTTPServerRequestPart
- typealias OutboundOut = HTTPServerResponsePart
- func channelRead(context: ChannelHandlerContext, data: NIOAny) {
- switch self.unwrapInboundIn(data) {
- case let .head(head):
- self.receivedRequestHead(context: context, head)
- case let .body(body):
- self.receivedRequestBody(context: context, body)
- case let .end(trailers):
- self.receivedRequestEnd(context: context, trailers)
- }
- }
- private func receivedRequestHead(context: ChannelHandlerContext, _ head: HTTPRequestHead) {
- if head.method == .OPTIONS,
- head.headers.contains(.accessControlRequestMethod),
- let origin = head.headers.first(name: "origin")
- {
- // If the request is OPTIONS with a access-control-request-method header it's a CORS
- // preflight request and is not propagated further.
- self.state = .processingPreflightRequest
- self.handlePreflightRequest(context: context, head: head, origin: origin)
- } else {
- self.state = .processingRequest(origin: head.headers.first(name: "origin"))
- context.fireChannelRead(self.wrapInboundOut(.head(head)))
- }
- }
- private func receivedRequestBody(context: ChannelHandlerContext, _ body: ByteBuffer) {
- // OPTIONS requests do not have a body, but still handle this case to be
- // cautious.
- if self.state == .processingPreflightRequest {
- return
- }
- context.fireChannelRead(self.wrapInboundOut(.body(body)))
- }
- private func receivedRequestEnd(context: ChannelHandlerContext, _ trailers: HTTPHeaders?) {
- if self.state == .processingPreflightRequest {
- // End of OPTIONS request; reset state and finish the response.
- self.state = .idle
- context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil)
- } else {
- context.fireChannelRead(self.wrapInboundOut(.end(trailers)))
- }
- }
- private func handlePreflightRequest(
- context: ChannelHandlerContext,
- head: HTTPRequestHead,
- origin: String
- ) {
- let responseHead: HTTPResponseHead
- if let allowedOrigin = self.configuration.allowedOrigins.header(origin) {
- var headers = HTTPHeaders()
- headers.reserveCapacity(4 + self.configuration.allowedHeaders.count)
- headers.add(name: .accessControlAllowOrigin, value: allowedOrigin)
- headers.add(name: .accessControlAllowMethods, value: "POST")
- for value in self.configuration.allowedHeaders {
- headers.add(name: .accessControlAllowHeaders, value: value)
- }
- if self.configuration.allowCredentialedRequests {
- headers.add(name: .accessControlAllowCredentials, value: "true")
- }
- if self.configuration.preflightCacheExpiration > 0 {
- headers.add(
- name: .accessControlMaxAge,
- value: "\(self.configuration.preflightCacheExpiration)"
- )
- }
- responseHead = HTTPResponseHead(version: head.version, status: .ok, headers: headers)
- } else {
- // Not allowed; respond with 403. This is okay in a pre-flight request.
- responseHead = HTTPResponseHead(version: head.version, status: .forbidden)
- }
- context.write(self.wrapOutboundOut(.head(responseHead)), promise: nil)
- }
- }
- extension WebCORSHandler: ChannelOutboundHandler {
- typealias OutboundIn = HTTPServerResponsePart
- func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise<Void>?) {
- let responsePart = self.unwrapOutboundIn(data)
- switch responsePart {
- case var .head(responseHead):
- switch self.state {
- case let .processingRequest(origin):
- self.prepareCORSResponseHead(&responseHead, origin: origin)
- context.write(self.wrapOutboundOut(.head(responseHead)), promise: promise)
- case .idle, .processingPreflightRequest:
- assertionFailure("Writing response head when no request is in progress")
- context.close(promise: nil)
- }
- case .body:
- context.write(data, promise: promise)
- case .end:
- self.state = .idle
- context.write(data, promise: promise)
- }
- }
- private func prepareCORSResponseHead(_ head: inout HTTPResponseHead, origin: String?) {
- guard let header = origin.flatMap({ self.configuration.allowedOrigins.header($0) }) else {
- // No origin or the origin is not allowed; don't treat it as a CORS request.
- return
- }
- head.headers.replaceOrAdd(name: .accessControlAllowOrigin, value: header)
- if self.configuration.allowCredentialedRequests {
- head.headers.add(name: .accessControlAllowCredentials, value: "true")
- }
- //! FIXME: Check whether we can let browsers keep connections alive. It's not possible
- // now as the channel has a state that can't be reused since the pipeline is modified to
- // inject the gRPC call handler.
- head.headers.replaceOrAdd(name: "Connection", value: "close")
- }
- }
- extension HTTPHeaders {
- fileprivate enum CORSHeader: String {
- case accessControlRequestMethod = "access-control-request-method"
- case accessControlRequestHeaders = "access-control-request-headers"
- case accessControlAllowOrigin = "access-control-allow-origin"
- case accessControlAllowMethods = "access-control-allow-methods"
- case accessControlAllowHeaders = "access-control-allow-headers"
- case accessControlAllowCredentials = "access-control-allow-credentials"
- case accessControlMaxAge = "access-control-max-age"
- }
- fileprivate func contains(_ name: CORSHeader) -> Bool {
- return self.contains(name: name.rawValue)
- }
- fileprivate mutating func add(name: CORSHeader, value: String) {
- self.add(name: name.rawValue, value: value)
- }
- fileprivate mutating func replaceOrAdd(name: CORSHeader, value: String) {
- self.replaceOrAdd(name: name.rawValue, value: value)
- }
- }
- extension Server.Configuration.CORS.AllowedOrigins {
- internal func header(_ origin: String) -> String? {
- switch self.wrapped {
- case .all:
- return "*"
- case .originBased:
- return origin
- case let .only(allowed):
- return allowed.contains(origin) ? origin : nil
- case let .custom(custom):
- return custom.check(origin: origin)
- }
- }
- }
|