|
|
@@ -489,7 +489,9 @@ extension Server.Configuration.CORS {
|
|
|
public struct AllowedOrigins: Hashable, Sendable {
|
|
|
enum Wrapped: Hashable, Sendable {
|
|
|
case all
|
|
|
+ case originBased
|
|
|
case only([String])
|
|
|
+ case custom(AnyCustomCORSAllowedOrigin)
|
|
|
}
|
|
|
|
|
|
private(set) var wrapped: Wrapped
|
|
|
@@ -500,10 +502,23 @@ extension Server.Configuration.CORS {
|
|
|
/// Allow all origin values.
|
|
|
public static let all = Self(.all)
|
|
|
|
|
|
+ /// Allow all origin values; similar to `all` but returns the value of the origin header field
|
|
|
+ /// in the 'access-control-allow-origin' response header (rather than "*").
|
|
|
+ public static let originBased = Self(.originBased)
|
|
|
+
|
|
|
/// Allow only the given origin values.
|
|
|
public static func only(_ allowed: [String]) -> Self {
|
|
|
return Self(.only(allowed))
|
|
|
}
|
|
|
+
|
|
|
+ /// Provide a custom CORS origin check.
|
|
|
+ ///
|
|
|
+ /// - Parameter checkOrigin: A closure which is called with the value of the 'origin' header
|
|
|
+ /// and returns the value to use in the 'access-control-allow-origin' response header,
|
|
|
+ /// or `nil` if the origin is not allowed.
|
|
|
+ public static func custom<C: GRPCCustomCORSAllowedOrigin>(_ custom: C) -> Self {
|
|
|
+ return Self(.custom(AnyCustomCORSAllowedOrigin(custom)))
|
|
|
+ }
|
|
|
}
|
|
|
}
|
|
|
|
|
|
@@ -530,3 +545,46 @@ extension Comparable {
|
|
|
return min(max(self, range.lowerBound), range.upperBound)
|
|
|
}
|
|
|
}
|
|
|
+
|
|
|
+public protocol GRPCCustomCORSAllowedOrigin: Sendable, Hashable {
|
|
|
+ /// Returns the value to use for the 'access-control-allow-origin' response header for the given
|
|
|
+ /// value of the 'origin' request header.
|
|
|
+ ///
|
|
|
+ /// - Parameter origin: The value of the 'origin' request header field.
|
|
|
+ /// - Returns: The value to use for the 'access-control-allow-origin' header field or `nil` if no
|
|
|
+ /// CORS related headers should be returned.
|
|
|
+ func check(origin: String) -> String?
|
|
|
+}
|
|
|
+
|
|
|
+extension Server.Configuration.CORS.AllowedOrigins {
|
|
|
+ struct AnyCustomCORSAllowedOrigin: GRPCCustomCORSAllowedOrigin {
|
|
|
+ private var checkOrigin: @Sendable (String) -> String?
|
|
|
+ private let hashInto: @Sendable (inout Hasher) -> Void
|
|
|
+ #if swift(>=5.7)
|
|
|
+ private let isEqualTo: @Sendable (any GRPCCustomCORSAllowedOrigin) -> Bool
|
|
|
+ #else
|
|
|
+ private let isEqualTo: @Sendable (Any) -> Bool
|
|
|
+ #endif
|
|
|
+
|
|
|
+ init<W: GRPCCustomCORSAllowedOrigin>(_ wrap: W) {
|
|
|
+ self.checkOrigin = { wrap.check(origin: $0) }
|
|
|
+ self.hashInto = { wrap.hash(into: &$0) }
|
|
|
+ self.isEqualTo = { wrap == ($0 as? W) }
|
|
|
+ }
|
|
|
+
|
|
|
+ func check(origin: String) -> String? {
|
|
|
+ return self.checkOrigin(origin)
|
|
|
+ }
|
|
|
+
|
|
|
+ func hash(into hasher: inout Hasher) {
|
|
|
+ self.hashInto(&hasher)
|
|
|
+ }
|
|
|
+
|
|
|
+ static func == (
|
|
|
+ lhs: Server.Configuration.CORS.AllowedOrigins.AnyCustomCORSAllowedOrigin,
|
|
|
+ rhs: Server.Configuration.CORS.AllowedOrigins.AnyCustomCORSAllowedOrigin
|
|
|
+ ) -> Bool {
|
|
|
+ return lhs.isEqualTo(rhs)
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|