WebCORSHandlerTests.swift 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380
/*
 * Copyright 2023, gRPC Authors All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
  16. import NIOCore
  17. import NIOEmbedded
  18. import NIOHTTP1
  19. import XCTest
  20. @testable import GRPC
  21. internal final class WebCORSHandlerTests: XCTestCase {
  22. struct PreflightRequestSpec {
  23. var configuration: Server.Configuration.CORS
  24. var requestOrigin: Optional<String>
  25. var expectOrigin: Optional<String>
  26. var expectAllowedHeaders: [String]
  27. var expectAllowCredentials: Bool
  28. var expectMaxAge: Optional<String>
  29. var expectStatus: HTTPResponseStatus = .ok
  30. }
  31. func runPreflightRequestTest(spec: PreflightRequestSpec) throws {
  32. let channel = EmbeddedChannel(handler: WebCORSHandler(configuration: spec.configuration))
  33. var request = HTTPRequestHead(version: .http1_1, method: .OPTIONS, uri: "http://foo.example")
  34. if let origin = spec.requestOrigin {
  35. request.headers.add(name: "origin", value: origin)
  36. }
  37. request.headers.add(name: "access-control-request-method", value: "POST")
  38. try channel.writeRequestPart(.head(request))
  39. try channel.writeRequestPart(.end(nil))
  40. switch try channel.readResponsePart() {
  41. case let .head(response):
  42. XCTAssertEqual(response.version, request.version)
  43. if let expected = spec.expectOrigin {
  44. XCTAssertEqual(response.headers["access-control-allow-origin"], [expected])
  45. } else {
  46. XCTAssertFalse(response.headers.contains(name: "access-control-allow-origin"))
  47. }
  48. if spec.expectAllowedHeaders.isEmpty {
  49. XCTAssertFalse(response.headers.contains(name: "access-control-allow-headers"))
  50. } else {
  51. XCTAssertEqual(response.headers["access-control-allow-headers"], spec.expectAllowedHeaders)
  52. }
  53. if spec.expectAllowCredentials {
  54. XCTAssertEqual(response.headers["access-control-allow-credentials"], ["true"])
  55. } else {
  56. XCTAssertFalse(response.headers.contains(name: "access-control-allow-credentials"))
  57. }
  58. if let maxAge = spec.expectMaxAge {
  59. XCTAssertEqual(response.headers["access-control-max-age"], [maxAge])
  60. } else {
  61. XCTAssertFalse(response.headers.contains(name: "access-control-max-age"))
  62. }
  63. XCTAssertEqual(response.status, spec.expectStatus)
  64. case .body, .end, .none:
  65. XCTFail("Unexpected response part")
  66. }
  67. }
  68. func testOptionsPreflightAllowAllOrigins() throws {
  69. let spec = PreflightRequestSpec(
  70. configuration: .init(
  71. allowedOrigins: .all,
  72. allowedHeaders: ["x-grpc-web"],
  73. allowCredentialedRequests: false,
  74. preflightCacheExpiration: 60
  75. ),
  76. requestOrigin: "foo",
  77. expectOrigin: "*",
  78. expectAllowedHeaders: ["x-grpc-web"],
  79. expectAllowCredentials: false,
  80. expectMaxAge: "60"
  81. )
  82. try self.runPreflightRequestTest(spec: spec)
  83. }
  84. func testOptionsPreflightOriginBased() throws {
  85. let spec = PreflightRequestSpec(
  86. configuration: .init(
  87. allowedOrigins: .originBased,
  88. allowedHeaders: ["x-grpc-web"],
  89. allowCredentialedRequests: false,
  90. preflightCacheExpiration: 60
  91. ),
  92. requestOrigin: "foo",
  93. expectOrigin: "foo",
  94. expectAllowedHeaders: ["x-grpc-web"],
  95. expectAllowCredentials: false,
  96. expectMaxAge: "60"
  97. )
  98. try self.runPreflightRequestTest(spec: spec)
  99. }
  100. func testOptionsPreflightCustom() throws {
  101. struct Wrapper: GRPCCustomCORSAllowedOrigin {
  102. func check(origin: String) -> String? {
  103. if origin == "foo" {
  104. return "bar"
  105. } else {
  106. return nil
  107. }
  108. }
  109. }
  110. let spec = PreflightRequestSpec(
  111. configuration: .init(
  112. allowedOrigins: .custom(Wrapper()),
  113. allowedHeaders: ["x-grpc-web"],
  114. allowCredentialedRequests: false,
  115. preflightCacheExpiration: 60
  116. ),
  117. requestOrigin: "foo",
  118. expectOrigin: "bar",
  119. expectAllowedHeaders: ["x-grpc-web"],
  120. expectAllowCredentials: false,
  121. expectMaxAge: "60"
  122. )
  123. try self.runPreflightRequestTest(spec: spec)
  124. }
  125. func testOptionsPreflightAllowSomeOrigins() throws {
  126. let spec = PreflightRequestSpec(
  127. configuration: .init(
  128. allowedOrigins: .only(["bar", "foo"]),
  129. allowedHeaders: ["x-grpc-web"],
  130. allowCredentialedRequests: false,
  131. preflightCacheExpiration: 60
  132. ),
  133. requestOrigin: "foo",
  134. expectOrigin: "foo",
  135. expectAllowedHeaders: ["x-grpc-web"],
  136. expectAllowCredentials: false,
  137. expectMaxAge: "60"
  138. )
  139. try self.runPreflightRequestTest(spec: spec)
  140. }
  141. func testOptionsPreflightAllowNoHeaders() throws {
  142. let spec = PreflightRequestSpec(
  143. configuration: .init(
  144. allowedOrigins: .all,
  145. allowedHeaders: [],
  146. allowCredentialedRequests: false,
  147. preflightCacheExpiration: 60
  148. ),
  149. requestOrigin: "foo",
  150. expectOrigin: "*",
  151. expectAllowedHeaders: [],
  152. expectAllowCredentials: false,
  153. expectMaxAge: "60"
  154. )
  155. try self.runPreflightRequestTest(spec: spec)
  156. }
  157. func testOptionsPreflightNoMaxAge() throws {
  158. let spec = PreflightRequestSpec(
  159. configuration: .init(
  160. allowedOrigins: .all,
  161. allowedHeaders: [],
  162. allowCredentialedRequests: false,
  163. preflightCacheExpiration: 0
  164. ),
  165. requestOrigin: "foo",
  166. expectOrigin: "*",
  167. expectAllowedHeaders: [],
  168. expectAllowCredentials: false,
  169. expectMaxAge: nil
  170. )
  171. try self.runPreflightRequestTest(spec: spec)
  172. }
  173. func testOptionsPreflightNegativeMaxAge() throws {
  174. let spec = PreflightRequestSpec(
  175. configuration: .init(
  176. allowedOrigins: .all,
  177. allowedHeaders: [],
  178. allowCredentialedRequests: false,
  179. preflightCacheExpiration: -1
  180. ),
  181. requestOrigin: "foo",
  182. expectOrigin: "*",
  183. expectAllowedHeaders: [],
  184. expectAllowCredentials: false,
  185. expectMaxAge: nil
  186. )
  187. try self.runPreflightRequestTest(spec: spec)
  188. }
  189. func testOptionsPreflightWithCredentials() throws {
  190. let spec = PreflightRequestSpec(
  191. configuration: .init(
  192. allowedOrigins: .all,
  193. allowedHeaders: [],
  194. allowCredentialedRequests: true,
  195. preflightCacheExpiration: 60
  196. ),
  197. requestOrigin: "foo",
  198. expectOrigin: "*",
  199. expectAllowedHeaders: [],
  200. expectAllowCredentials: true,
  201. expectMaxAge: "60"
  202. )
  203. try self.runPreflightRequestTest(spec: spec)
  204. }
  205. func testOptionsPreflightWithDisallowedOrigin() throws {
  206. let spec = PreflightRequestSpec(
  207. configuration: .init(
  208. allowedOrigins: .only(["foo"]),
  209. allowedHeaders: [],
  210. allowCredentialedRequests: false,
  211. preflightCacheExpiration: 60
  212. ),
  213. requestOrigin: "bar",
  214. expectOrigin: nil,
  215. expectAllowedHeaders: [],
  216. expectAllowCredentials: false,
  217. expectMaxAge: nil,
  218. expectStatus: .forbidden
  219. )
  220. try self.runPreflightRequestTest(spec: spec)
  221. }
  222. }
  223. extension WebCORSHandlerTests {
  224. struct RegularRequestSpec {
  225. var configuration: Server.Configuration.CORS
  226. var requestOrigin: Optional<String>
  227. var expectOrigin: Optional<String>
  228. var expectAllowCredentials: Bool
  229. }
  230. func runRegularRequestTest(
  231. spec: RegularRequestSpec
  232. ) throws {
  233. let channel = EmbeddedChannel(handler: WebCORSHandler(configuration: spec.configuration))
  234. var request = HTTPRequestHead(version: .http1_1, method: .OPTIONS, uri: "http://foo.example")
  235. if let origin = spec.requestOrigin {
  236. request.headers.add(name: "origin", value: origin)
  237. }
  238. try channel.writeRequestPart(.head(request))
  239. try channel.writeRequestPart(.end(nil))
  240. XCTAssertEqual(try channel.readRequestPart(), .head(request))
  241. XCTAssertEqual(try channel.readRequestPart(), .end(nil))
  242. let response = HTTPResponseHead(version: request.version, status: .imATeapot)
  243. try channel.writeResponsePart(.head(response))
  244. try channel.writeResponsePart(.end(nil))
  245. switch try channel.readResponsePart() {
  246. case let .head(head):
  247. // Should not be modified.
  248. XCTAssertEqual(head.version, response.version)
  249. XCTAssertEqual(head.status, response.status)
  250. if let expected = spec.expectOrigin {
  251. XCTAssertEqual(head.headers["access-control-allow-origin"], [expected])
  252. } else {
  253. XCTAssertFalse(head.headers.contains(name: "access-control-allow-origin"))
  254. }
  255. if spec.expectAllowCredentials {
  256. XCTAssertEqual(head.headers["access-control-allow-credentials"], ["true"])
  257. } else {
  258. XCTAssertFalse(head.headers.contains(name: "access-control-allow-credentials"))
  259. }
  260. case .body, .end, .none:
  261. XCTFail("Unexpected response part")
  262. }
  263. XCTAssertEqual(try channel.readResponsePart(), .end(nil))
  264. }
  265. func testRegularRequestWithWildcardOrigin() throws {
  266. let spec = RegularRequestSpec(
  267. configuration: .init(
  268. allowedOrigins: .all,
  269. allowCredentialedRequests: false
  270. ),
  271. requestOrigin: "foo",
  272. expectOrigin: "*",
  273. expectAllowCredentials: false
  274. )
  275. try self.runRegularRequestTest(spec: spec)
  276. }
  277. func testRegularRequestWithLimitedOrigin() throws {
  278. let spec = RegularRequestSpec(
  279. configuration: .init(
  280. allowedOrigins: .only(["foo", "bar"]),
  281. allowCredentialedRequests: false
  282. ),
  283. requestOrigin: "foo",
  284. expectOrigin: "foo",
  285. expectAllowCredentials: false
  286. )
  287. try self.runRegularRequestTest(spec: spec)
  288. }
  289. func testRegularRequestWithNoOrigin() throws {
  290. let spec = RegularRequestSpec(
  291. configuration: .init(
  292. allowedOrigins: .all,
  293. allowCredentialedRequests: false
  294. ),
  295. requestOrigin: nil,
  296. expectOrigin: nil,
  297. expectAllowCredentials: false
  298. )
  299. try self.runRegularRequestTest(spec: spec)
  300. }
  301. func testRegularRequestWithCredentials() throws {
  302. let spec = RegularRequestSpec(
  303. configuration: .init(
  304. allowedOrigins: .all,
  305. allowCredentialedRequests: true
  306. ),
  307. requestOrigin: "foo",
  308. expectOrigin: "*",
  309. expectAllowCredentials: true
  310. )
  311. try self.runRegularRequestTest(spec: spec)
  312. }
  313. func testRegularRequestWithDisallowedOrigin() throws {
  314. let spec = RegularRequestSpec(
  315. configuration: .init(
  316. allowedOrigins: .only(["foo"]),
  317. allowCredentialedRequests: true
  318. ),
  319. requestOrigin: "bar",
  320. expectOrigin: nil,
  321. expectAllowCredentials: false
  322. )
  323. try self.runRegularRequestTest(spec: spec)
  324. }
  325. }
  326. extension EmbeddedChannel {
  327. fileprivate func writeRequestPart(_ part: HTTPServerRequestPart) throws {
  328. try self.writeInbound(part)
  329. }
  330. fileprivate func readRequestPart() throws -> HTTPServerRequestPart? {
  331. try self.readInbound()
  332. }
  333. fileprivate func writeResponsePart(_ part: HTTPServerResponsePart) throws {
  334. try self.writeOutbound(part)
  335. }
  336. fileprivate func readResponsePart() throws -> HTTPServerResponsePart? {
  337. try self.readOutbound()
  338. }
  339. }