| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380 |
- /*
- * Copyright 2023, gRPC Authors All rights reserved.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
- import NIOCore
- import NIOEmbedded
- import NIOHTTP1
- import XCTest
- @testable import GRPC
internal final class WebCORSHandlerTests: XCTestCase {
  /// A CORS preflight scenario: how the handler is configured, which `origin` the client sends,
  /// and which CORS response headers and status the handler is expected to produce.
  struct PreflightRequestSpec {
    /// Configuration used to build the `WebCORSHandler` under test.
    var configuration: Server.Configuration.CORS
    /// Value of the request's `origin` header; `nil` omits the header.
    var requestOrigin: String?
    /// Expected `access-control-allow-origin` value; `nil` means the header must be absent.
    var expectOrigin: String?
    /// Expected `access-control-allow-headers` values; empty means the header must be absent.
    var expectAllowedHeaders: [String]
    /// Whether `access-control-allow-credentials: true` is expected; otherwise it must be absent.
    var expectAllowCredentials: Bool
    /// Expected `access-control-max-age` value; `nil` means the header must be absent.
    var expectMaxAge: String?
    /// Expected status of the preflight response.
    var expectStatus: HTTPResponseStatus = .ok
  }

  /// Drives a CORS preflight request through a `WebCORSHandler` built from `spec.configuration`
  /// and checks the response it writes back against `spec`.
  ///
  /// The request is an HTTP/1.1 `OPTIONS` request carrying `access-control-request-method: POST`
  /// and, when `spec.requestOrigin` is set, an `origin` header.
  ///
  /// - Parameters:
  ///   - spec: The scenario to run.
  ///   - file: File that assertion failures are reported against; defaults to the caller's.
  ///   - line: Line that assertion failures are reported against; defaults to the caller's.
  func runPreflightRequestTest(
    spec: PreflightRequestSpec,
    file: StaticString = #filePath,
    line: UInt = #line
  ) throws {
    let channel = EmbeddedChannel(handler: WebCORSHandler(configuration: spec.configuration))

    var request = HTTPRequestHead(version: .http1_1, method: .OPTIONS, uri: "http://foo.example")
    if let origin = spec.requestOrigin {
      request.headers.add(name: "origin", value: origin)
    }
    request.headers.add(name: "access-control-request-method", value: "POST")

    try channel.writeRequestPart(.head(request))
    try channel.writeRequestPart(.end(nil))

    // The handler answers the preflight itself; none of it should be forwarded inbound.
    XCTAssertNil(try channel.readRequestPart(), file: file, line: line)

    switch try channel.readResponsePart() {
    case let .head(response):
      XCTAssertEqual(response.version, request.version, file: file, line: line)
      XCTAssertEqual(response.status, spec.expectStatus, file: file, line: line)

      if let expected = spec.expectOrigin {
        XCTAssertEqual(
          response.headers["access-control-allow-origin"],
          [expected],
          file: file,
          line: line
        )
      } else {
        XCTAssertFalse(
          response.headers.contains(name: "access-control-allow-origin"),
          file: file,
          line: line
        )
      }

      if spec.expectAllowedHeaders.isEmpty {
        XCTAssertFalse(
          response.headers.contains(name: "access-control-allow-headers"),
          file: file,
          line: line
        )
      } else {
        XCTAssertEqual(
          response.headers["access-control-allow-headers"],
          spec.expectAllowedHeaders,
          file: file,
          line: line
        )
      }

      if spec.expectAllowCredentials {
        XCTAssertEqual(
          response.headers["access-control-allow-credentials"],
          ["true"],
          file: file,
          line: line
        )
      } else {
        XCTAssertFalse(
          response.headers.contains(name: "access-control-allow-credentials"),
          file: file,
          line: line
        )
      }

      if let maxAge = spec.expectMaxAge {
        XCTAssertEqual(
          response.headers["access-control-max-age"],
          [maxAge],
          file: file,
          line: line
        )
      } else {
        XCTAssertFalse(
          response.headers.contains(name: "access-control-max-age"),
          file: file,
          line: line
        )
      }

    case .body, .end, .none:
      XCTFail("Unexpected response part", file: file, line: line)
    }

    // The preflight response should be terminated with an end part.
    XCTAssertEqual(try channel.readResponsePart(), .end(nil), file: file, line: line)
  }

  func testOptionsPreflightAllowAllOrigins() throws {
    let spec = PreflightRequestSpec(
      configuration: .init(
        allowedOrigins: .all,
        allowedHeaders: ["x-grpc-web"],
        allowCredentialedRequests: false,
        preflightCacheExpiration: 60
      ),
      requestOrigin: "foo",
      expectOrigin: "*",
      expectAllowedHeaders: ["x-grpc-web"],
      expectAllowCredentials: false,
      expectMaxAge: "60"
    )
    try self.runPreflightRequestTest(spec: spec)
  }

  func testOptionsPreflightOriginBased() throws {
    let spec = PreflightRequestSpec(
      configuration: .init(
        allowedOrigins: .originBased,
        allowedHeaders: ["x-grpc-web"],
        allowCredentialedRequests: false,
        preflightCacheExpiration: 60
      ),
      requestOrigin: "foo",
      expectOrigin: "foo",
      expectAllowedHeaders: ["x-grpc-web"],
      expectAllowCredentials: false,
      expectMaxAge: "60"
    )
    try self.runPreflightRequestTest(spec: spec)
  }

  func testOptionsPreflightCustom() throws {
    /// Maps the origin "foo" to "bar" and rejects every other origin.
    struct Wrapper: GRPCCustomCORSAllowedOrigin {
      func check(origin: String) -> String? {
        return origin == "foo" ? "bar" : nil
      }
    }

    let spec = PreflightRequestSpec(
      configuration: .init(
        allowedOrigins: .custom(Wrapper()),
        allowedHeaders: ["x-grpc-web"],
        allowCredentialedRequests: false,
        preflightCacheExpiration: 60
      ),
      requestOrigin: "foo",
      expectOrigin: "bar",
      expectAllowedHeaders: ["x-grpc-web"],
      expectAllowCredentials: false,
      expectMaxAge: "60"
    )
    try self.runPreflightRequestTest(spec: spec)
  }

  func testOptionsPreflightAllowSomeOrigins() throws {
    let spec = PreflightRequestSpec(
      configuration: .init(
        allowedOrigins: .only(["bar", "foo"]),
        allowedHeaders: ["x-grpc-web"],
        allowCredentialedRequests: false,
        preflightCacheExpiration: 60
      ),
      requestOrigin: "foo",
      expectOrigin: "foo",
      expectAllowedHeaders: ["x-grpc-web"],
      expectAllowCredentials: false,
      expectMaxAge: "60"
    )
    try self.runPreflightRequestTest(spec: spec)
  }

  func testOptionsPreflightAllowNoHeaders() throws {
    let spec = PreflightRequestSpec(
      configuration: .init(
        allowedOrigins: .all,
        allowedHeaders: [],
        allowCredentialedRequests: false,
        preflightCacheExpiration: 60
      ),
      requestOrigin: "foo",
      expectOrigin: "*",
      expectAllowedHeaders: [],
      expectAllowCredentials: false,
      expectMaxAge: "60"
    )
    try self.runPreflightRequestTest(spec: spec)
  }

  func testOptionsPreflightNoMaxAge() throws {
    let spec = PreflightRequestSpec(
      configuration: .init(
        allowedOrigins: .all,
        allowedHeaders: [],
        allowCredentialedRequests: false,
        preflightCacheExpiration: 0
      ),
      requestOrigin: "foo",
      expectOrigin: "*",
      expectAllowedHeaders: [],
      expectAllowCredentials: false,
      expectMaxAge: nil
    )
    try self.runPreflightRequestTest(spec: spec)
  }

  func testOptionsPreflightNegativeMaxAge() throws {
    let spec = PreflightRequestSpec(
      configuration: .init(
        allowedOrigins: .all,
        allowedHeaders: [],
        allowCredentialedRequests: false,
        preflightCacheExpiration: -1
      ),
      requestOrigin: "foo",
      expectOrigin: "*",
      expectAllowedHeaders: [],
      expectAllowCredentials: false,
      expectMaxAge: nil
    )
    try self.runPreflightRequestTest(spec: spec)
  }

  func testOptionsPreflightWithCredentials() throws {
    let spec = PreflightRequestSpec(
      configuration: .init(
        allowedOrigins: .all,
        allowedHeaders: [],
        allowCredentialedRequests: true,
        preflightCacheExpiration: 60
      ),
      requestOrigin: "foo",
      expectOrigin: "*",
      expectAllowedHeaders: [],
      expectAllowCredentials: true,
      expectMaxAge: "60"
    )
    try self.runPreflightRequestTest(spec: spec)
  }

  func testOptionsPreflightWithDisallowedOrigin() throws {
    let spec = PreflightRequestSpec(
      configuration: .init(
        allowedOrigins: .only(["foo"]),
        allowedHeaders: [],
        allowCredentialedRequests: false,
        preflightCacheExpiration: 60
      ),
      requestOrigin: "bar",
      expectOrigin: nil,
      expectAllowedHeaders: [],
      expectAllowCredentials: false,
      expectMaxAge: nil,
      expectStatus: .forbidden
    )
    try self.runPreflightRequestTest(spec: spec)
  }
}
extension WebCORSHandlerTests {
  /// A non-preflight request scenario: how the handler is configured, which `origin` the client
  /// sends, and which CORS headers are expected on the (otherwise untouched) response.
  struct RegularRequestSpec {
    /// Configuration used to build the `WebCORSHandler` under test.
    var configuration: Server.Configuration.CORS
    /// Value of the request's `origin` header; `nil` omits the header.
    var requestOrigin: String?
    /// Expected `access-control-allow-origin` value; `nil` means the header must be absent.
    var expectOrigin: String?
    /// Whether `access-control-allow-credentials: true` is expected; otherwise it must be absent.
    var expectAllowCredentials: Bool
  }

  /// Drives a regular (non-preflight) request and a `418 I'm a teapot` response through a
  /// `WebCORSHandler` built from `spec.configuration`.
  ///
  /// Checks that the request parts are forwarded unchanged, that the response keeps its version
  /// and status, and that the CORS headers on the response match `spec`.
  ///
  /// - Parameters:
  ///   - spec: The scenario to run.
  ///   - file: File that assertion failures are reported against; defaults to the caller's.
  ///   - line: Line that assertion failures are reported against; defaults to the caller's.
  func runRegularRequestTest(
    spec: RegularRequestSpec,
    file: StaticString = #filePath,
    line: UInt = #line
  ) throws {
    let channel = EmbeddedChannel(handler: WebCORSHandler(configuration: spec.configuration))

    // This uses `OPTIONS` but, unlike the preflight runner, omits `access-control-request-method`.
    // The assertions below expect the request to be forwarded as-is, i.e. not treated as a
    // preflight.
    var request = HTTPRequestHead(version: .http1_1, method: .OPTIONS, uri: "http://foo.example")
    if let origin = spec.requestOrigin {
      request.headers.add(name: "origin", value: origin)
    }

    try channel.writeRequestPart(.head(request))
    try channel.writeRequestPart(.end(nil))
    XCTAssertEqual(try channel.readRequestPart(), .head(request), file: file, line: line)
    XCTAssertEqual(try channel.readRequestPart(), .end(nil), file: file, line: line)

    let response = HTTPResponseHead(version: request.version, status: .imATeapot)
    try channel.writeResponsePart(.head(response))
    try channel.writeResponsePart(.end(nil))

    switch try channel.readResponsePart() {
    case let .head(written):
      // The handler may only add CORS headers; version and status must pass through unchanged.
      XCTAssertEqual(written.version, response.version, file: file, line: line)
      XCTAssertEqual(written.status, response.status, file: file, line: line)

      if let expected = spec.expectOrigin {
        XCTAssertEqual(
          written.headers["access-control-allow-origin"],
          [expected],
          file: file,
          line: line
        )
      } else {
        XCTAssertFalse(
          written.headers.contains(name: "access-control-allow-origin"),
          file: file,
          line: line
        )
      }

      if spec.expectAllowCredentials {
        XCTAssertEqual(
          written.headers["access-control-allow-credentials"],
          ["true"],
          file: file,
          line: line
        )
      } else {
        XCTAssertFalse(
          written.headers.contains(name: "access-control-allow-credentials"),
          file: file,
          line: line
        )
      }

    case .body, .end, .none:
      XCTFail("Unexpected response part", file: file, line: line)
    }

    XCTAssertEqual(try channel.readResponsePart(), .end(nil), file: file, line: line)
  }

  func testRegularRequestWithWildcardOrigin() throws {
    let spec = RegularRequestSpec(
      configuration: .init(
        allowedOrigins: .all,
        allowCredentialedRequests: false
      ),
      requestOrigin: "foo",
      expectOrigin: "*",
      expectAllowCredentials: false
    )
    try self.runRegularRequestTest(spec: spec)
  }

  func testRegularRequestWithLimitedOrigin() throws {
    let spec = RegularRequestSpec(
      configuration: .init(
        allowedOrigins: .only(["foo", "bar"]),
        allowCredentialedRequests: false
      ),
      requestOrigin: "foo",
      expectOrigin: "foo",
      expectAllowCredentials: false
    )
    try self.runRegularRequestTest(spec: spec)
  }

  func testRegularRequestWithNoOrigin() throws {
    let spec = RegularRequestSpec(
      configuration: .init(
        allowedOrigins: .all,
        allowCredentialedRequests: false
      ),
      requestOrigin: nil,
      expectOrigin: nil,
      expectAllowCredentials: false
    )
    try self.runRegularRequestTest(spec: spec)
  }

  func testRegularRequestWithCredentials() throws {
    let spec = RegularRequestSpec(
      configuration: .init(
        allowedOrigins: .all,
        allowCredentialedRequests: true
      ),
      requestOrigin: "foo",
      expectOrigin: "*",
      expectAllowCredentials: true
    )
    try self.runRegularRequestTest(spec: spec)
  }

  func testRegularRequestWithDisallowedOrigin() throws {
    let spec = RegularRequestSpec(
      configuration: .init(
        allowedOrigins: .only(["foo"]),
        allowCredentialedRequests: true
      ),
      requestOrigin: "bar",
      expectOrigin: nil,
      expectAllowCredentials: false
    )
    try self.runRegularRequestTest(spec: spec)
  }
}
extension EmbeddedChannel {
  // MARK: - Request side

  /// Writes an HTTP server request part into the channel as inbound data.
  fileprivate func writeRequestPart(_ part: HTTPServerRequestPart) throws {
    _ = try self.writeInbound(part)
  }

  /// Reads the next inbound item as an HTTP server request part, or `nil` if none is available.
  fileprivate func readRequestPart() throws -> HTTPServerRequestPart? {
    return try self.readInbound(as: HTTPServerRequestPart.self)
  }

  // MARK: - Response side

  /// Writes an HTTP server response part into the channel as outbound data.
  fileprivate func writeResponsePart(_ part: HTTPServerResponsePart) throws {
    _ = try self.writeOutbound(part)
  }

  /// Reads the next outbound item as an HTTP server response part, or `nil` if none is available.
  fileprivate func readResponsePart() throws -> HTTPServerResponsePart? {
    return try self.readOutbound(as: HTTPServerResponsePart.self)
  }
}
|