WorkerService.swift 17 KB


  1. /*
  2. * Copyright 2024, gRPC Authors All rights reserved.
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. import GRPCCore
  17. import GRPCNIOTransportHTTP2
  18. import NIOConcurrencyHelpers
  19. import NIOCore
  20. import NIOPosix
  /// Implements the benchmark `WorkerService` used by the gRPC performance test driver.
  ///
  /// A worker plays at most one role at a time: it either runs a benchmark server or a set of
  /// benchmark clients. The active role and the resources it owns live in `State`, which is
  /// only accessed while holding the lock in `state`.
  final class WorkerService: Sendable {
    private let state: NIOLockedValueBox<State>

    init() {
      self.state = NIOLockedValueBox(State())
    }

    /// State machine tracking which role (if any) this worker is playing.
    ///
    /// Each transition returns an action describing the side effect the caller must perform
    /// after the lock has been released.
    private struct State {
      private var role: Role

      enum Role {
        /// Neither a server nor clients are set up.
        case none
        /// Benchmark clients are set up.
        case client(Client)
        /// A benchmark server is set up.
        case server(Server)
      }

      struct Server {
        var server: GRPCServer
        /// Resource usage stats captured at setup or at the most recent reset.
        var stats: ServerStats
        /// The event loop group the server's transport runs on.
        var eventLoopGroup: MultiThreadedEventLoopGroup
        /// Whether a graceful shutdown of `server` has already been requested. The role is
        /// kept until `serverShutdown()` so the event loop group is still shut down afterwards.
        var isShuttingDown = false
      }

      struct Client {
        var clients: [BenchmarkClient]
        /// Resource usage stats captured at setup or at the most recent reset.
        var stats: ClientStats
        /// RPC stats accumulated from all clients.
        var rpcStats: RPCStats
        /// Whether `clients` have already been shut down by `quit()`. The role is kept until
        /// `closeClients()` so a later setup can't be torn down by the old `runClient` RPC.
        var isShuttingDown = false
      }

      init() {
        self.role = .none
      }

      /// Returns the server's usage stats from setup (or the last reset), optionally replacing
      /// them with `newStats`. Returns `nil` if no server is set up.
      mutating func collectServerStats(replaceWith newStats: ServerStats? = nil) -> ServerStats? {
        switch self.role {
        case var .server(serverState):
          let stats = serverState.stats
          if let newStats = newStats {
            serverState.stats = newStats
            self.role = .server(serverState)
          }
          return stats

        case .client, .none:
          return nil
        }
      }

      /// Returns the clients' usage stats from setup (or the last reset) together with the RPC
      /// stats accumulated from every client, optionally replacing the usage stats with
      /// `newStats`. Returns `nil` if no clients are set up.
      ///
      /// Only the usage stats are replaced; the accumulated RPC stats are not reset here.
      mutating func collectClientStats(
        replaceWith newStats: ClientStats? = nil
      ) -> (ClientStats, RPCStats)? {
        switch self.role {
        case var .client(state):
          // Grab the existing stats and update if necessary.
          let stats = state.stats
          if let newStats = newStats {
            state.stats = newStats
          }

          // Merge in RPC stats from each client; merge failures are ignored.
          // NOTE(review): `currentStats` is merged on every call. If it is cumulative rather
          // than per-interval, repeated calls over-count — confirm against `BenchmarkClient`.
          for client in state.clients {
            try? state.rpcStats.merge(client.currentStats)
          }

          self.role = .client(state)
          return (stats, state.rpcStats)

        case .server, .none:
          return nil
        }
      }

      enum OnStartedServer {
        case runServer
        case invalidState(RPCError)
      }

      /// Records a newly created server as the worker's role.
      ///
      /// Fails with `alreadyExists` if a server is already set up, or `failedPrecondition` if
      /// the worker is running clients.
      mutating func startedServer(
        _ server: GRPCServer,
        stats: ServerStats,
        eventLoopGroup: MultiThreadedEventLoopGroup
      ) -> OnStartedServer {
        let action: OnStartedServer

        switch self.role {
        case .none:
          let state = State.Server(server: server, stats: stats, eventLoopGroup: eventLoopGroup)
          self.role = .server(state)
          action = .runServer

        case .server:
          let error = RPCError(code: .alreadyExists, message: "A server has already been set up.")
          action = .invalidState(error)

        case .client:
          let error = RPCError(code: .failedPrecondition, message: "This worker has a client setup.")
          action = .invalidState(error)
        }

        return action
      }

      enum OnStartedClients {
        case runClients
        case invalidState(RPCError)
      }

      /// Records newly created clients as the worker's role.
      ///
      /// Fails with `alreadyExists` if clients are already set up, or `failedPrecondition` if
      /// the worker is running a server (matching `startedServer`).
      mutating func startedClients(
        _ clients: [BenchmarkClient],
        stats: ClientStats,
        rpcStats: RPCStats
      ) -> OnStartedClients {
        let action: OnStartedClients

        switch self.role {
        case .none:
          let state = State.Client(clients: clients, stats: stats, rpcStats: rpcStats)
          self.role = .client(state)
          action = .runClients

        case .server:
          // The worker is in the *other* role: a precondition failure, as in `startedServer`.
          let error = RPCError(code: .failedPrecondition, message: "This worker has a server setup.")
          action = .invalidState(error)

        case .client:
          // The same role already exists, as in `startedServer`.
          let error = RPCError(
            code: .alreadyExists,
            message: "Clients have already been set up."
          )
          action = .invalidState(error)
        }

        return action
      }

      enum OnServerShutDown {
        case shutdown(MultiThreadedEventLoopGroup)
        case nothing
      }

      /// Called once the server has stopped serving: clears the role and hands back the event
      /// loop group so the caller can shut it down.
      mutating func serverShutdown() -> OnServerShutDown {
        switch self.role {
        case .client:
          preconditionFailure("Invalid state")

        case .server(let state):
          self.role = .none
          return .shutdown(state.eventLoopGroup)

        case .none:
          return .nothing
        }
      }

      enum OnStopListening {
        case stopListening(GRPCServer)
        case nothing
      }

      /// Asks for the server to stop listening, unless a shutdown was already requested.
      ///
      /// The role is kept: it's released by `serverShutdown()` once `serve()` returns.
      mutating func stopListening() -> OnStopListening {
        switch self.role {
        case .client:
          preconditionFailure("Invalid state")

        case .server(var state):
          if state.isShuttingDown {
            return .nothing
          }
          state.isShuttingDown = true
          self.role = .server(state)
          return .stopListening(state.server)

        case .none:
          return .nothing
        }
      }

      enum OnCloseClient {
        case close([BenchmarkClient])
        case nothing
      }

      /// Clears the client role, returning the clients to shut down unless `quit()` already
      /// shut them down.
      mutating func closeClients() -> OnCloseClient {
        switch self.role {
        case .client(let state):
          self.role = .none
          return state.isShuttingDown ? .nothing : .close(state.clients)

        case .server:
          preconditionFailure("Invalid state")

        case .none:
          return .nothing
        }
      }

      enum OnQuitWorker {
        case shutDownServer(GRPCServer)
        case shutDownClients([BenchmarkClient])
        case nothing
      }

      /// Requests shutdown of whatever role is active, at most once per role.
      ///
      /// The role itself is not cleared here. It is cleared by `serverShutdown()` or
      /// `closeClients()`, so the server's event loop group is still released. It also means
      /// a later setup can't be torn down by the RPC that owned the old role.
      mutating func quit() -> OnQuitWorker {
        switch self.role {
        case .none:
          return .nothing

        case .client(var state):
          guard !state.isShuttingDown else { return .nothing }
          state.isShuttingDown = true
          self.role = .client(state)
          return .shutDownClients(state.clients)

        case .server(var state):
          guard !state.isShuttingDown else { return .nothing }
          state.isShuttingDown = true
          self.role = .server(state)
          return .shutDownServer(state.server)
        }
      }
    }
  }
  extension WorkerService: Grpc_Testing_WorkerService.ServiceProtocol {
    /// Tears down whatever the worker is currently running.
    ///
    /// Benchmark clients are shut down directly; a benchmark server is asked to begin a
    /// graceful shutdown. Always responds with an empty message.
    func quitWorker(
      request: ServerRequest<Grpc_Testing_Void>,
      context: ServerContext
    ) async throws -> ServerResponse<Grpc_Testing_Void> {
      let onQuit = self.state.withLockedValue { $0.quit() }

      switch onQuit {
      case .nothing:
        ()

      case .shutDownClients(let clients):
        for client in clients {
          client.shutdown()
        }

      case .shutDownServer(let server):
        server.beginGracefulShutdown()
      }

      return ServerResponse(message: Grpc_Testing_Void())
    }

    /// Reports the number of cores on this machine as given by `System.coreCount`.
    func coreCount(
      request: ServerRequest<Grpc_Testing_CoreRequest>,
      context: ServerContext
    ) async throws -> ServerResponse<Grpc_Testing_CoreResponse> {
      let coreCount = System.coreCount
      return ServerResponse(
        message: Grpc_Testing_WorkerService.Method.CoreCount.Output.with {
          $0.cores = Int32(coreCount)
        }
      )
    }

    /// Runs a benchmark server for the lifetime of the request stream.
    ///
    /// - A `setup` message starts a server and replies with the port it is listening on.
    /// - A `mark` message replies with server usage stats, optionally resetting the baseline.
    ///
    /// When the request stream ends, the server is asked to stop listening. The RPC then waits
    /// for the serving task, which also shuts down the event loop group. Any error that task
    /// throws is rethrown.
    func runServer(
      request: StreamingServerRequest<Grpc_Testing_ServerArgs>,
      context: ServerContext
    ) async throws -> StreamingServerResponse<Grpc_Testing_ServerStatus> {
      return StreamingServerResponse { writer in
        try await withThrowingTaskGroup(of: Void.self) { group in
          for try await message in request.messages {
            switch message.argtype {
            case let .some(.setup(serverConfig)):
              let (server, transport) = try await self.startServer(serverConfig)

              group.addTask {
                let result: Result<Void, any Error>
                do {
                  try await server.serve()
                  result = .success(())
                } catch {
                  result = .failure(error)
                }

                // Release the ELG whether or not `serve()` succeeded, then report its outcome.
                switch self.state.withLockedValue({ $0.serverShutdown() }) {
                case .shutdown(let eventLoopGroup):
                  try await eventLoopGroup.shutdownGracefully()
                case .nothing:
                  ()
                }

                try result.get()
              }

              // Wait for the server to bind.
              let address = try await transport.listeningAddress
              let port: Int
              if let ipv4 = address.ipv4 {
                port = ipv4.port
              } else if let ipv6 = address.ipv6 {
                port = ipv6.port
              } else {
                throw RPCError(
                  code: .internalError,
                  message: "Server listening on unsupported address '\(address)'"
                )
              }

              // Tell the client what port the server is listening on.
              let message = Grpc_Testing_ServerStatus.with { $0.port = Int32(port) }
              try await writer.write(message)

            case let .some(.mark(mark)):
              let response = try await self.makeServerStatsResponse(reset: mark.reset)
              try await writer.write(response)

            case .none:
              ()
            }
          }

          // Request stream ended, tell the server to stop listening. Once it's finished it will
          // shutdown its ELG.
          switch self.state.withLockedValue({ $0.stopListening() }) {
          case .stopListening(let server):
            server.beginGracefulShutdown()
          case .nothing:
            ()
          }

          // Explicitly collect the serving task's result. The group's implicit await at the
          // end of the body would otherwise discard any error it threw (this mirrors `runClient`).
          try await group.waitForAll()
        }

        return [:]
      }
    }

    /// Runs benchmark clients for the lifetime of the request stream.
    ///
    /// - A `setup` message creates and starts the clients, then replies with initial stats.
    /// - A `mark` message replies with client stats, optionally resetting the usage baseline.
    ///
    /// When the request stream ends, the clients are shut down and the RPC waits for every
    /// client's `run()` to finish, rethrowing the first error.
    func runClient(
      request: StreamingServerRequest<Grpc_Testing_ClientArgs>,
      context: ServerContext
    ) async throws -> StreamingServerResponse<Grpc_Testing_ClientStatus> {
      return StreamingServerResponse { writer in
        try await withThrowingTaskGroup(of: Void.self) { group in
          for try await message in request.messages {
            switch message.argtype {
            case let .setup(config):
              // Create the clients with the initial stats.
              let clients = try await self.setupClients(config)

              for client in clients {
                group.addTask {
                  try await client.run()
                }
              }

              let message = try await self.makeClientStatsResponse(reset: false)
              try await writer.write(message)

            case let .mark(mark):
              let response = try await self.makeClientStatsResponse(reset: mark.reset)
              try await writer.write(response)

            case .none:
              ()
            }
          }

          switch self.state.withLockedValue({ $0.closeClients() }) {
          case .close(let clients):
            for client in clients {
              client.shutdown()
            }
          case .nothing:
            ()
          }

          try await group.waitForAll()
          return [:]
        }
      }
    }
  }
  extension WorkerService {
    /// Creates a benchmark server bound to `127.0.0.1` on the configured port and registers
    /// it (with its event loop group and baseline stats) in the state machine.
    ///
    /// Throws if baseline stats can't be collected, or if the worker already has a role. In
    /// both cases the server and its event loop group are shut down first.
    private func startServer(
      _ serverConfig: Grpc_Testing_ServerConfig
    ) async throws -> (GRPCServer, HTTP2ServerTransport.Posix) {
      // Prepare an ELG, the test might require more than the default of one.
      let numberOfThreads: Int
      if serverConfig.asyncServerThreads > 0 {
        numberOfThreads = Int(serverConfig.asyncServerThreads)
      } else {
        numberOfThreads = System.coreCount
      }
      let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: numberOfThreads)

      // Don't restrict the max payload size, the client is always trusted.
      var config = HTTP2ServerTransport.Posix.Config.defaults
      config.rpc.maxRequestPayloadSize = .max

      let transport = HTTP2ServerTransport.Posix(
        address: .ipv4(host: "127.0.0.1", port: Int(serverConfig.port)),
        transportSecurity: .plaintext,
        config: config,
        eventLoopGroup: eventLoopGroup
      )

      let server = GRPCServer(transport: transport, services: [BenchmarkService()])

      let stats: ServerStats
      do {
        stats = try await ServerStats()
      } catch {
        // The ELG isn't owned by the state machine yet, so release it here rather than leak
        // its threads. Rethrow the original error, not a shutdown error.
        server.beginGracefulShutdown()
        try? await eventLoopGroup.shutdownGracefully()
        throw error
      }

      // Hold on to the server and ELG in the state machine.
      let action = self.state.withLockedValue {
        $0.startedServer(server, stats: stats, eventLoopGroup: eventLoopGroup)
      }

      switch action {
      case .runServer:
        return (server, transport)

      case .invalidState(let error):
        server.beginGracefulShutdown()
        try await eventLoopGroup.shutdownGracefully()
        throw error
      }
    }

    /// Builds a `RunServer` response with resource usage since setup or the last reset.
    ///
    /// If `reset` is true, the current stats become the new baseline. Throws `notFound` if no
    /// server is set up.
    private func makeServerStatsResponse(
      reset: Bool
    ) async throws -> Grpc_Testing_WorkerService.Method.RunServer.Output {
      let currentStats = try await ServerStats()
      let initialStats = self.state.withLockedValue { state in
        return state.collectServerStats(replaceWith: reset ? currentStats : nil)
      }

      guard let initialStats = initialStats else {
        throw RPCError(
          code: .notFound,
          message: "There are no initial server stats. A server must be setup before calling 'mark'."
        )
      }

      let differences = currentStats.difference(to: initialStats)
      return Grpc_Testing_WorkerService.Method.RunServer.Output.with {
        $0.stats = Grpc_Testing_ServerStats.with {
          $0.idleCpuTime = differences.idleCPUTime
          $0.timeElapsed = differences.time
          $0.timeSystem = differences.systemTime
          $0.timeUser = differences.userTime
          $0.totalCpuTime = differences.totalCPUTime
        }
      }
    }

    /// Creates one benchmark client per configured channel and registers them, with baseline
    /// stats, in the state machine.
    ///
    /// Throws `invalidArgument` for an unsupported RPC type or an unparseable target. It also
    /// throws if a client transport can't be created or the worker already has a role. Any
    /// clients created before the failure are shut down.
    private func setupClients(_ config: Grpc_Testing_ClientConfig) async throws -> [BenchmarkClient] {
      guard let rpcType = BenchmarkClient.RPCType(config.rpcType) else {
        throw RPCError(code: .invalidArgument, message: "Unknown RPC type")
      }

      // Parse the server targets into resolvable targets.
      let ipv4Addresses = try self.parseServerTargets(config.serverTargets)
      let target = ResolvableTargets.IPv4(addresses: ipv4Addresses)

      var clients = [BenchmarkClient]()
      do {
        for _ in 0 ..< config.clientChannels {
          let client = BenchmarkClient(
            client: GRPCClient(
              transport: try .http2NIOPosix(
                target: target,
                transportSecurity: .plaintext
              )
            ),
            concurrentRPCs: Int(config.outstandingRpcsPerChannel),
            rpcType: rpcType,
            messagesPerStream: Int(config.messagesPerStream),
            protoParams: config.payloadConfig.simpleParams,
            histogramParams: config.histogramParams
          )

          clients.append(client)
        }
      } catch {
        // Don't leak the clients which were created before the failure.
        for client in clients {
          client.shutdown()
        }
        throw error
      }

      let stats = ClientStats()
      let histogram = RPCStats.LatencyHistogram(
        resolution: config.histogramParams.resolution,
        maxBucketStart: config.histogramParams.maxPossible
      )
      let rpcStats = RPCStats(latencyHistogram: histogram)

      let action = self.state.withLockedValue { state in
        state.startedClients(clients, stats: stats, rpcStats: rpcStats)
      }

      switch action {
      case .runClients:
        return clients

      case .invalidState(let error):
        for client in clients {
          client.shutdown()
        }
        throw error
      }
    }

    /// Parses a `<host>:<port>` target into an IPv4 address.
    ///
    /// Returns `nil` if there's no `:`, the host is empty, or the port isn't an integer in
    /// `0...65535`. IPv6 targets are not supported.
    private func parseServerTarget(_ target: String) -> GRPCNIOTransportCore.SocketAddress.IPv4? {
      guard let index = target.firstIndex(of: ":") else { return nil }

      let host = target[..<index]
      guard !host.isEmpty else { return nil }

      guard let port = Int(target[target.index(after: index)...]), (0 ... 65535).contains(port)
      else {
        return nil
      }

      return SocketAddress.IPv4(host: String(host), port: port)
    }

    /// Parses every target with `parseServerTarget(_:)`, throwing `invalidArgument` for the
    /// first one that can't be parsed.
    private func parseServerTargets(
      _ targets: [String]
    ) throws -> [GRPCNIOTransportCore.SocketAddress.IPv4] {
      try targets.map { target in
        if let ipv4 = self.parseServerTarget(target) {
          return ipv4
        } else {
          // Only IPv4 targets are parsed, so don't advertise an IPv6 format here.
          throw RPCError(
            code: .invalidArgument,
            message: """
              Couldn't parse target '\(target)'. Must be in the format '<host>:<port>' \
              (only IPv4 targets are supported).
              """
          )
        }
      }
    }

    /// Builds a `RunClient` response with resource usage since setup or the last reset and the
    /// RPC stats accumulated from all clients.
    ///
    /// If `reset` is true, the current usage stats become the new baseline. Throws `notFound`
    /// if no clients are set up.
    private func makeClientStatsResponse(
      reset: Bool
    ) async throws -> Grpc_Testing_WorkerService.Method.RunClient.Output {
      let currentUsageStats = ClientStats()
      let stats = self.state.withLockedValue { state in
        state.collectClientStats(replaceWith: reset ? currentUsageStats : nil)
      }

      guard let (initialUsageStats, rpcStats) = stats else {
        throw RPCError(
          code: .notFound,
          message: "There are no initial client stats. Clients must be setup before calling 'mark'."
        )
      }

      let differences = currentUsageStats.difference(to: initialUsageStats)

      let requestResults = rpcStats.requestResultCount.map { (key, value) in
        return Grpc_Testing_RequestResultCount.with {
          $0.statusCode = Int32(key.rawValue)
          $0.count = value
        }
      }

      return Grpc_Testing_WorkerService.Method.RunClient.Output.with {
        $0.stats = Grpc_Testing_ClientStats.with {
          $0.timeElapsed = differences.time
          $0.timeSystem = differences.systemTime
          $0.timeUser = differences.userTime
          $0.requestResults = requestResults
          $0.latencies = Grpc_Testing_HistogramData.with {
            $0.bucket = rpcStats.latencyHistogram.buckets
            $0.minSeen = rpcStats.latencyHistogram.minSeen
            $0.maxSeen = rpcStats.latencyHistogram.maxSeen
            $0.sum = rpcStats.latencyHistogram.sum
            $0.sumOfSquares = rpcStats.latencyHistogram.sumOfSquares
            $0.count = rpcStats.latencyHistogram.countOfValuesSeen
          }
        }
      }
    }
  }
  extension BenchmarkClient.RPCType {
    /// Converts a protobuf `Grpc_Testing_RpcType` into the benchmark client's RPC type.
    ///
    /// Only unary and streaming RPCs map to a value; every other case produces `nil`.
    init?(_ rpcType: Grpc_Testing_RpcType) {
      let converted: Self?

      switch rpcType {
      case .unary:
        converted = .unary
      case .streaming:
        converted = .streaming
      default:
        converted = nil
      }

      guard let converted = converted else { return nil }
      self = converted
    }
  }