// WebCORSHandlerTests.swift (10 KB)
  1. /*
  2. * Copyright 2023, gRPC Authors All rights reserved.
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. @testable import GRPC
  17. import NIOCore
  18. import NIOEmbedded
  19. import NIOHTTP1
  20. import XCTest
  21. internal final class WebCORSHandlerTests: XCTestCase {
  22. struct PreflightRequestSpec {
  23. var configuration: Server.Configuration.CORS
  24. var requestOrigin: Optional<String>
  25. var expectOrigin: Optional<String>
  26. var expectAllowedHeaders: [String]
  27. var expectAllowCredentials: Bool
  28. var expectMaxAge: Optional<String>
  29. var expectStatus: HTTPResponseStatus = .ok
  30. }
  31. func runPreflightRequestTest(spec: PreflightRequestSpec) throws {
  32. let channel = EmbeddedChannel(handler: WebCORSHandler(configuration: spec.configuration))
  33. var request = HTTPRequestHead(version: .http1_1, method: .OPTIONS, uri: "http://foo.example")
  34. if let origin = spec.requestOrigin {
  35. request.headers.add(name: "origin", value: origin)
  36. }
  37. request.headers.add(name: "access-control-request-method", value: "POST")
  38. try channel.writeRequestPart(.head(request))
  39. try channel.writeRequestPart(.end(nil))
  40. switch try channel.readResponsePart() {
  41. case let .head(response):
  42. XCTAssertEqual(response.version, request.version)
  43. if let expected = spec.expectOrigin {
  44. XCTAssertEqual(response.headers["access-control-allow-origin"], [expected])
  45. } else {
  46. XCTAssertFalse(response.headers.contains(name: "access-control-allow-origin"))
  47. }
  48. if spec.expectAllowedHeaders.isEmpty {
  49. XCTAssertFalse(response.headers.contains(name: "access-control-allow-headers"))
  50. } else {
  51. XCTAssertEqual(response.headers["access-control-allow-headers"], spec.expectAllowedHeaders)
  52. }
  53. if spec.expectAllowCredentials {
  54. XCTAssertEqual(response.headers["access-control-allow-credentials"], ["true"])
  55. } else {
  56. XCTAssertFalse(response.headers.contains(name: "access-control-allow-credentials"))
  57. }
  58. if let maxAge = spec.expectMaxAge {
  59. XCTAssertEqual(response.headers["access-control-max-age"], [maxAge])
  60. } else {
  61. XCTAssertFalse(response.headers.contains(name: "access-control-max-age"))
  62. }
  63. XCTAssertEqual(response.status, spec.expectStatus)
  64. case .body, .end, .none:
  65. XCTFail("Unexpected response part")
  66. }
  67. }
  68. func testOptionsPreflightAllowAllOrigins() throws {
  69. let spec = PreflightRequestSpec(
  70. configuration: .init(
  71. allowedOrigins: .all,
  72. allowedHeaders: ["x-grpc-web"],
  73. allowCredentialedRequests: false,
  74. preflightCacheExpiration: 60
  75. ),
  76. requestOrigin: "foo",
  77. expectOrigin: "*",
  78. expectAllowedHeaders: ["x-grpc-web"],
  79. expectAllowCredentials: false,
  80. expectMaxAge: "60"
  81. )
  82. try self.runPreflightRequestTest(spec: spec)
  83. }
  84. func testOptionsPreflightAllowSomeOrigins() throws {
  85. let spec = PreflightRequestSpec(
  86. configuration: .init(
  87. allowedOrigins: .only(["bar", "foo"]),
  88. allowedHeaders: ["x-grpc-web"],
  89. allowCredentialedRequests: false,
  90. preflightCacheExpiration: 60
  91. ),
  92. requestOrigin: "foo",
  93. expectOrigin: "foo",
  94. expectAllowedHeaders: ["x-grpc-web"],
  95. expectAllowCredentials: false,
  96. expectMaxAge: "60"
  97. )
  98. try self.runPreflightRequestTest(spec: spec)
  99. }
  100. func testOptionsPreflightAllowNoHeaders() throws {
  101. let spec = PreflightRequestSpec(
  102. configuration: .init(
  103. allowedOrigins: .all,
  104. allowedHeaders: [],
  105. allowCredentialedRequests: false,
  106. preflightCacheExpiration: 60
  107. ),
  108. requestOrigin: "foo",
  109. expectOrigin: "*",
  110. expectAllowedHeaders: [],
  111. expectAllowCredentials: false,
  112. expectMaxAge: "60"
  113. )
  114. try self.runPreflightRequestTest(spec: spec)
  115. }
  116. func testOptionsPreflightNoMaxAge() throws {
  117. let spec = PreflightRequestSpec(
  118. configuration: .init(
  119. allowedOrigins: .all,
  120. allowedHeaders: [],
  121. allowCredentialedRequests: false,
  122. preflightCacheExpiration: 0
  123. ),
  124. requestOrigin: "foo",
  125. expectOrigin: "*",
  126. expectAllowedHeaders: [],
  127. expectAllowCredentials: false,
  128. expectMaxAge: nil
  129. )
  130. try self.runPreflightRequestTest(spec: spec)
  131. }
  132. func testOptionsPreflightNegativeMaxAge() throws {
  133. let spec = PreflightRequestSpec(
  134. configuration: .init(
  135. allowedOrigins: .all,
  136. allowedHeaders: [],
  137. allowCredentialedRequests: false,
  138. preflightCacheExpiration: -1
  139. ),
  140. requestOrigin: "foo",
  141. expectOrigin: "*",
  142. expectAllowedHeaders: [],
  143. expectAllowCredentials: false,
  144. expectMaxAge: nil
  145. )
  146. try self.runPreflightRequestTest(spec: spec)
  147. }
  148. func testOptionsPreflightWithCredentials() throws {
  149. let spec = PreflightRequestSpec(
  150. configuration: .init(
  151. allowedOrigins: .all,
  152. allowedHeaders: [],
  153. allowCredentialedRequests: true,
  154. preflightCacheExpiration: 60
  155. ),
  156. requestOrigin: "foo",
  157. expectOrigin: "*",
  158. expectAllowedHeaders: [],
  159. expectAllowCredentials: true,
  160. expectMaxAge: "60"
  161. )
  162. try self.runPreflightRequestTest(spec: spec)
  163. }
  164. func testOptionsPreflightWithDisallowedOrigin() throws {
  165. let spec = PreflightRequestSpec(
  166. configuration: .init(
  167. allowedOrigins: .only(["foo"]),
  168. allowedHeaders: [],
  169. allowCredentialedRequests: false,
  170. preflightCacheExpiration: 60
  171. ),
  172. requestOrigin: "bar",
  173. expectOrigin: nil,
  174. expectAllowedHeaders: [],
  175. expectAllowCredentials: false,
  176. expectMaxAge: nil,
  177. expectStatus: .forbidden
  178. )
  179. try self.runPreflightRequestTest(spec: spec)
  180. }
  181. }
  182. extension WebCORSHandlerTests {
  183. struct RegularRequestSpec {
  184. var configuration: Server.Configuration.CORS
  185. var requestOrigin: Optional<String>
  186. var expectOrigin: Optional<String>
  187. var expectAllowCredentials: Bool
  188. }
  189. func runRegularRequestTest(
  190. spec: RegularRequestSpec
  191. ) throws {
  192. let channel = EmbeddedChannel(handler: WebCORSHandler(configuration: spec.configuration))
  193. var request = HTTPRequestHead(version: .http1_1, method: .OPTIONS, uri: "http://foo.example")
  194. if let origin = spec.requestOrigin {
  195. request.headers.add(name: "origin", value: origin)
  196. }
  197. try channel.writeRequestPart(.head(request))
  198. try channel.writeRequestPart(.end(nil))
  199. XCTAssertEqual(try channel.readRequestPart(), .head(request))
  200. XCTAssertEqual(try channel.readRequestPart(), .end(nil))
  201. let response = HTTPResponseHead(version: request.version, status: .imATeapot)
  202. try channel.writeResponsePart(.head(response))
  203. try channel.writeResponsePart(.end(nil))
  204. switch try channel.readResponsePart() {
  205. case let .head(head):
  206. // Should not be modified.
  207. XCTAssertEqual(head.version, response.version)
  208. XCTAssertEqual(head.status, response.status)
  209. if let expected = spec.expectOrigin {
  210. XCTAssertEqual(head.headers["access-control-allow-origin"], [expected])
  211. } else {
  212. XCTAssertFalse(head.headers.contains(name: "access-control-allow-origin"))
  213. }
  214. if spec.expectAllowCredentials {
  215. XCTAssertEqual(head.headers["access-control-allow-credentials"], ["true"])
  216. } else {
  217. XCTAssertFalse(head.headers.contains(name: "access-control-allow-credentials"))
  218. }
  219. case .body, .end, .none:
  220. XCTFail("Unexpected response part")
  221. }
  222. XCTAssertEqual(try channel.readResponsePart(), .end(nil))
  223. }
  224. func testRegularRequestWithWildcardOrigin() throws {
  225. let spec = RegularRequestSpec(
  226. configuration: .init(
  227. allowedOrigins: .all,
  228. allowCredentialedRequests: false
  229. ),
  230. requestOrigin: "foo",
  231. expectOrigin: "*",
  232. expectAllowCredentials: false
  233. )
  234. try self.runRegularRequestTest(spec: spec)
  235. }
  236. func testRegularRequestWithLimitedOrigin() throws {
  237. let spec = RegularRequestSpec(
  238. configuration: .init(
  239. allowedOrigins: .only(["foo", "bar"]),
  240. allowCredentialedRequests: false
  241. ),
  242. requestOrigin: "foo",
  243. expectOrigin: "foo",
  244. expectAllowCredentials: false
  245. )
  246. try self.runRegularRequestTest(spec: spec)
  247. }
  248. func testRegularRequestWithNoOrigin() throws {
  249. let spec = RegularRequestSpec(
  250. configuration: .init(
  251. allowedOrigins: .all,
  252. allowCredentialedRequests: false
  253. ),
  254. requestOrigin: nil,
  255. expectOrigin: nil,
  256. expectAllowCredentials: false
  257. )
  258. try self.runRegularRequestTest(spec: spec)
  259. }
  260. func testRegularRequestWithCredentials() throws {
  261. let spec = RegularRequestSpec(
  262. configuration: .init(
  263. allowedOrigins: .all,
  264. allowCredentialedRequests: true
  265. ),
  266. requestOrigin: "foo",
  267. expectOrigin: "*",
  268. expectAllowCredentials: true
  269. )
  270. try self.runRegularRequestTest(spec: spec)
  271. }
  272. func testRegularRequestWithDisallowedOrigin() throws {
  273. let spec = RegularRequestSpec(
  274. configuration: .init(
  275. allowedOrigins: .only(["foo"]),
  276. allowCredentialedRequests: true
  277. ),
  278. requestOrigin: "bar",
  279. expectOrigin: nil,
  280. expectAllowCredentials: false
  281. )
  282. try self.runRegularRequestTest(spec: spec)
  283. }
  284. }
  285. extension EmbeddedChannel {
  286. fileprivate func writeRequestPart(_ part: HTTPServerRequestPart) throws {
  287. try self.writeInbound(part)
  288. }
  289. fileprivate func readRequestPart() throws -> HTTPServerRequestPart? {
  290. try self.readInbound()
  291. }
  292. fileprivate func writeResponsePart(_ part: HTTPServerResponsePart) throws {
  293. try self.writeOutbound(part)
  294. }
  295. fileprivate func readResponsePart() throws -> HTTPServerResponsePart? {
  296. try self.readOutbound()
  297. }
  298. }