Browse Source

Allow for more CORS configuration (#1594)

Motivation:

We added some level of CORS configuration support in #1583. This change adds
further flexibility.

Modifications:

- Add an 'originBased' mode where the value of the origin header is
  returned in the response head.
- Add a custom fallback where the user can specify a callback which
  is passed the value of the origin header and returns the value to
  return in the 'access-control-allow-origin' response header (or nil,
  if the origin is not allowed).

Result:

More flexibility for CORS.
George Barnett 2 years ago
parent
commit
ef8ffb937b

+ 58 - 0
Sources/GRPC/Server.swift

@@ -489,7 +489,9 @@ extension Server.Configuration.CORS {
   public struct AllowedOrigins: Hashable, Sendable {
     enum Wrapped: Hashable, Sendable {
       case all
+      case originBased
       case only([String])
+      case custom(AnyCustomCORSAllowedOrigin)
     }
 
     private(set) var wrapped: Wrapped
@@ -500,10 +502,23 @@ extension Server.Configuration.CORS {
     /// Allow all origin values.
     public static let all = Self(.all)
 
+    /// Allow all origin values; similar to `all` but returns the value of the origin header field
+    /// in the 'access-control-allow-origin' response header (rather than "*").
+    public static let originBased = Self(.originBased)
+
     /// Allow only the given origin values.
     public static func only(_ allowed: [String]) -> Self {
       return Self(.only(allowed))
     }
+
+    /// Provide a custom CORS origin check.
+    ///
+    /// - Parameter checkOrigin: A closure which is called with the value of the 'origin' header
+    ///     and returns the value to use in the 'access-control-allow-origin' response header,
+    ///     or `nil` if the origin is not allowed.
+    public static func custom<C: GRPCCustomCORSAllowedOrigin>(_ custom: C) -> Self {
+      return Self(.custom(AnyCustomCORSAllowedOrigin(custom)))
+    }
   }
 }
 
@@ -530,3 +545,46 @@ extension Comparable {
     return min(max(self, range.lowerBound), range.upperBound)
   }
 }
+
+public protocol GRPCCustomCORSAllowedOrigin: Sendable, Hashable {
+  /// Returns the value to use for the 'access-control-allow-origin' response header for the given
+  /// value of the 'origin' request header.
+  ///
+  /// - Parameter origin: The value of the 'origin' request header field.
+  /// - Returns: The value to use for the 'access-control-allow-origin' header field or `nil` if no
+  ///     CORS related headers should be returned.
+  func check(origin: String) -> String?
+}
+
+extension Server.Configuration.CORS.AllowedOrigins {
+  struct AnyCustomCORSAllowedOrigin: GRPCCustomCORSAllowedOrigin {
+    private var checkOrigin: @Sendable (String) -> String?
+    private let hashInto: @Sendable (inout Hasher) -> Void
+    #if swift(>=5.7)
+    private let isEqualTo: @Sendable (any GRPCCustomCORSAllowedOrigin) -> Bool
+    #else
+    private let isEqualTo: @Sendable (Any) -> Bool
+    #endif
+
+    init<W: GRPCCustomCORSAllowedOrigin>(_ wrap: W) {
+      self.checkOrigin = { wrap.check(origin: $0) }
+      self.hashInto = { wrap.hash(into: &$0) }
+      self.isEqualTo = { wrap == ($0 as? W) }
+    }
+
+    func check(origin: String) -> String? {
+      return self.checkOrigin(origin)
+    }
+
+    func hash(into hasher: inout Hasher) {
+      self.hashInto(&hasher)
+    }
+
+    static func == (
+      lhs: Server.Configuration.CORS.AllowedOrigins.AnyCustomCORSAllowedOrigin,
+      rhs: Server.Configuration.CORS.AllowedOrigins.AnyCustomCORSAllowedOrigin
+    ) -> Bool {
+      return lhs.isEqualTo(rhs)
+    }
+  }
+}

+ 4 - 0
Sources/GRPC/WebCORSHandler.swift

@@ -198,8 +198,12 @@ extension Server.Configuration.CORS.AllowedOrigins {
     switch self.wrapped {
     case .all:
       return "*"
+    case .originBased:
+      return origin
     case let .only(allowed):
       return allowed.contains(origin) ? origin : nil
+    case let .custom(custom):
+      return custom.check(origin: origin)
     }
   }
 }

+ 44 - 0
Tests/GRPCTests/WebCORSHandlerTests.swift

@@ -93,6 +93,50 @@ internal final class WebCORSHandlerTests: XCTestCase {
     try self.runPreflightRequestTest(spec: spec)
   }
 
+  func testOptionsPreflightOriginBased() throws {
+    let spec = PreflightRequestSpec(
+      configuration: .init(
+        allowedOrigins: .originBased,
+        allowedHeaders: ["x-grpc-web"],
+        allowCredentialedRequests: false,
+        preflightCacheExpiration: 60
+      ),
+      requestOrigin: "foo",
+      expectOrigin: "foo",
+      expectAllowedHeaders: ["x-grpc-web"],
+      expectAllowCredentials: false,
+      expectMaxAge: "60"
+    )
+    try self.runPreflightRequestTest(spec: spec)
+  }
+
+  func testOptionsPreflightCustom() throws {
+    struct Wrapper: GRPCCustomCORSAllowedOrigin {
+      func check(origin: String) -> String? {
+        if origin == "foo" {
+          return "bar"
+        } else {
+          return nil
+        }
+      }
+    }
+
+    let spec = PreflightRequestSpec(
+      configuration: .init(
+        allowedOrigins: .custom(Wrapper()),
+        allowedHeaders: ["x-grpc-web"],
+        allowCredentialedRequests: false,
+        preflightCacheExpiration: 60
+      ),
+      requestOrigin: "foo",
+      expectOrigin: "bar",
+      expectAllowedHeaders: ["x-grpc-web"],
+      expectAllowCredentials: false,
+      expectMaxAge: "60"
+    )
+    try self.runPreflightRequestTest(spec: spec)
+  }
+
   func testOptionsPreflightAllowSomeOrigins() throws {
     let spec = PreflightRequestSpec(
       configuration: .init(