ConnectionTests.swift 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256
  1. /*
  2. * Copyright 2024, gRPC Authors All rights reserved.
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. import DequeModule
  17. import GRPCCore
  18. import GRPCHTTP2Core
  19. import NIOConcurrencyHelpers
  20. import NIOCore
  21. import NIOHPACK
  22. import NIOHTTP2
  23. import NIOPosix
  24. import XCTest
  @available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *)
  final class ConnectionTests: XCTestCase {
    /// Closes the connection as soon as it connects and expects a locally initiated close.
    func testConnectThenClose() async throws {
      try await ConnectionTest.run(connector: .posix()) { context, event in
        // Only the successful connect triggers an action; every other event is ignored.
        guard case .connectSucceeded = event else { return }
        context.connection.close()
      } validateEvents: { _, events in
        XCTAssertEqual(events, [.connectSucceeded, .closed(.initiatedLocally)])
      }
    }

    /// Connects with a 50 ms maximum idle time and expects the connection to close with
    /// `.idleTimeout`.
    func testConnectThenIdleTimeout() async throws {
      try await ConnectionTest.run(
        connector: .posix(maxIdleTime: .milliseconds(50))
      ) { _, events in
        XCTAssertEqual(events, [.connectSucceeded, .closed(.idleTimeout)])
      }
    }

    /// Connects with keepalive enabled (also without calls) while ping acks are dropped and
    /// expects the connection to close with `.keepaliveTimeout`.
    func testConnectThenKeepaliveTimeout() async throws {
      try await ConnectionTest.run(
        connector: .posix(
          keepaliveTime: .milliseconds(50),
          keepaliveTimeout: .milliseconds(10),
          keepaliveWithoutCalls: true,
          dropPingAcks: true
        )
      ) { _, events in
        XCTAssertEqual(events, [.connectSucceeded, .closed(.keepaliveTimeout)])
      }
    }

    /// Writes a GOAWAY frame from the server's accepted channel once connected and expects a
    /// `.goingAway` event followed by a remote close.
    func testGoAwayWhenConnected() async throws {
      try await ConnectionTest.run(connector: .posix()) { context, event in
        guard case .connectSucceeded = event else { return }

        let payload = HTTP2Frame.FramePayload.goAway(
          lastStreamID: 0,
          errorCode: .noError,
          opaqueData: ByteBuffer(string: "Hello!")
        )
        let frame = HTTP2Frame(streamID: .rootStream, payload: payload)

        let serverChannel = try context.server.acceptedChannel
        serverChannel.writeAndFlush(frame, promise: nil)
      } validateEvents: { _, events in
        XCTAssertEqual(
          events,
          [.connectSucceeded, .goingAway(.noError, "Hello!"), .closed(.remote)]
        )
      }
    }

    /// Closes the server's accepted channel once connected and expects the connection to report
    /// an `.unavailable` error close while idle.
    func testConnectionDropWhenConnected() async throws {
      try await ConnectionTest.run(connector: .posix()) { context, event in
        guard case .connectSucceeded = event else { return }
        let serverChannel = try context.server.acceptedChannel
        serverChannel.close(mode: .all, promise: nil)
      } validateEvents: { _, events in
        let dropped = RPCError(
          code: .unavailable,
          message: "The TCP connection was dropped unexpectedly."
        )
        let expectedEvents: [Connection.Event] = [
          .connectSucceeded,
          .closed(.error(dropped, wasIdle: true)),
        ]
        XCTAssertEqual(events, expectedEvents)
      }
    }

    /// Uses a connector which throws and expects the thrown error in a single `.connectFailed`
    /// event.
    func testConnectFails() async throws {
      let connectError = RPCError(code: .unimplemented, message: "")
      try await ConnectionTest.run(connector: .throwing(connectError)) { _, events in
        XCTAssertEqual(events, [.connectFailed(connectError)])
      }
    }

    /// Runs against a server configured with `.closeOnAccept` and expects exactly one
    /// `.connectFailed` event carrying an `.unavailable` `RPCError`.
    func testConnectFailsOnAcceptedThenClosedTCPConnection() async throws {
      try await ConnectionTest.run(connector: .posix(), server: .closeOnAccept) { _, events in
        XCTAssertEqual(events.count, 1)
        let event = try XCTUnwrap(events.first)

        guard case .connectFailed(let failure) = event else {
          XCTFail("Expected '.connectFailed', got '\(event)'")
          return
        }

        XCTAssert(failure, as: RPCError.self) { rpcError in
          XCTAssertEqual(rpcError.code, .unavailable)
        }
      }
    }

    /// Opens a stream on a connected connection, sends metadata and one message, and checks the
    /// response parts received before closing the connection.
    func testMakeStreamOnActiveConnection() async throws {
      try await ConnectionTest.run(connector: .posix()) { context, event in
        guard case .connectSucceeded = event else { return }

        let stream = try await context.connection.makeStream(
          descriptor: .echoGet,
          options: .defaults
        )

        try await stream.execute { inbound, outbound in
          try await outbound.write(.metadata(["foo": "bar", "bar": "baz"]))
          try await outbound.write(.message([0, 1, 2]))
          outbound.finish()

          let expectedParts: [RPCResponsePart] = [
            .metadata(["foo": "bar", "bar": "baz"]),
            .message([0, 1, 2]),
            .status(Status(code: .ok, message: ""), [:]),
          ]

          var receivedParts: [RPCResponsePart] = []
          for try await part in inbound {
            if case .metadata(let metadata) = part {
              // Keep only the last two entries, dropping any transport specific metadata.
              receivedParts.append(.metadata(Metadata(metadata.suffix(2))))
            } else {
              receivedParts.append(part)
            }
          }

          XCTAssertEqual(receivedParts, expectedParts)
        }

        context.connection.close()
      } validateEvents: { _, events in
        XCTAssertEqual(events, [.connectSucceeded, .closed(.initiatedLocally)])
      }
    }

    /// Closes the connection after connecting, then checks that making a stream once the
    /// `.closed` event arrives throws an `.unavailable` `RPCError`.
    func testMakeStreamOnClosedConnection() async throws {
      try await ConnectionTest.run(connector: .posix()) { context, event in
        switch event {
        case .connectSucceeded:
          context.connection.close()

        case .closed:
          await XCTAssertThrowsErrorAsync(ofType: RPCError.self) {
            _ = try await context.connection.makeStream(descriptor: .echoGet, options: .defaults)
          } errorHandler: { rpcError in
            XCTAssertEqual(rpcError.code, .unavailable)
          }

        default:
          ()
        }
      } validateEvents: { _, events in
        XCTAssertEqual(events, [.connectSucceeded, .closed(.initiatedLocally)])
      }
    }

    /// Checks that making a stream on a connection which was created but never run throws an
    /// `.unavailable` `RPCError`.
    func testMakeStreamOnNotRunningConnection() async throws {
      let notRunning = Connection(
        address: .ipv4(host: "ignored", port: 0),
        http2Connector: .never,
        defaultCompression: .none,
        enabledCompression: .none
      )

      await XCTAssertThrowsErrorAsync(ofType: RPCError.self) {
        _ = try await notRunning.makeStream(descriptor: .echoGet, options: .defaults)
      } errorHandler: { rpcError in
        XCTAssertEqual(rpcError.code, .unavailable)
      }
    }
  }
  extension ClientBootstrap {
    /// Connects to a gRPC socket address by forwarding to the `connect` overload matching the
    /// kind of address.
    ///
    /// The address kinds are checked in order: IPv4, IPv6, Unix domain socket, virtual socket.
    ///
    /// - Parameters:
    ///   - address: The address to connect to.
    ///   - configure: Passed through as the `channelInitializer` of the underlying `connect` call.
    /// - Returns: The result of the underlying `connect` call.
    /// - Throws: An `RPCError` with code `.unimplemented` if `address` is none of the handled
    ///   kinds, or any error thrown by the underlying `connect` call.
    @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
    func connect<T>(
      to address: GRPCHTTP2Core.SocketAddress,
      _ configure: @Sendable @escaping (Channel) -> EventLoopFuture<T>
    ) async throws -> T {
      if let ipv4 = address.ipv4 {
        return try await self.connect(
          host: ipv4.host,
          port: ipv4.port,
          channelInitializer: configure
        )
      }

      if let ipv6 = address.ipv6 {
        return try await self.connect(
          host: ipv6.host,
          port: ipv6.port,
          channelInitializer: configure
        )
      }

      if let uds = address.unixDomainSocket {
        return try await self.connect(
          unixDomainSocketPath: uds.path,
          channelInitializer: configure
        )
      }

      if let vsock = address.virtualSocket {
        // Translate the gRPC virtual socket address into its NIO counterpart.
        let target = VsockAddress(
          cid: .init(Int(vsock.contextID.rawValue)),
          port: .init(Int(vsock.port.rawValue))
        )
        return try await self.connect(to: target, channelInitializer: configure)
      }

      throw RPCError(code: .unimplemented, message: "Unhandled socket address: \(address)")
    }
  }
  extension Metadata {
    /// Creates metadata containing every key-value pair of `sequence`, added in iteration order.
    ///
    /// - Parameter sequence: The key-value pairs to add; string values are added with
    ///   `addString(_:forKey:)` and binary values with `addBinary(_:forKey:)`.
    init(_ sequence: some Sequence<Element>) {
      self = sequence.reduce(into: Metadata()) { collected, pair in
        let (key, value) = pair
        switch value {
        case .string(let text):
          collected.addString(text, forKey: key)
        case .binary(let bytes):
          collected.addBinary(bytes, forKey: key)
        }
      }
    }
  }