GRPCChannelHandlerResponseCapturingTestCase.swift 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990
  1. /*
  2. * Copyright 2019, gRPC Authors All rights reserved.
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. import Foundation
  17. import NIO
  18. import NIOHTTP1
  19. @testable import GRPC
  20. import EchoModel
  21. import EchoImplementation
  22. import XCTest
  23. import Logging
  /// An outbound handler that captures every value written through it.
  ///
  /// Each write's promise (when one is supplied) is completed successfully, and the
  /// unwrapped value is appended to `responses`. The write is not passed on via `context`.
  class CollectingChannelHandler<OutboundIn>: ChannelOutboundHandler {
    /// Values captured by `write(context:data:promise:)`, in the order they were written.
    var responses: [OutboundIn] = []

    func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise<Void>?) {
      // Complete the write before recording the value (same ordering as before: any
      // promise callback runs before the new value is visible in `responses`).
      if let writePromise = promise {
        writePromise.succeed(())
      }
      let captured = self.unwrapOutboundIn(data)
      self.responses.append(captured)
    }
  }
  /// A `ServerErrorDelegate` that records every library error it is told about so that
  /// tests can inspect them afterwards.
  class CollectingServerErrorDelegate: ServerErrorDelegate {
    /// Every error passed to `observeLibraryError(_:)`, in the order observed.
    var errors: [Error] = []

    /// `errors` conditionally cast to `[GRPCError]`; `nil` when the array cast fails
    /// (a conditional array cast succeeds only if every element casts).
    var asGRPCErrors: [GRPCError]? {
      let observed = self.errors
      return observed as? [GRPCError]
    }

    /// The `wrappedError` of every element of `asGRPCErrors`, cast to `[GRPCServerError]`;
    /// `nil` when `asGRPCErrors` is `nil` or the cast fails.
    var asGRPCServerErrors: [GRPCServerError]? {
      return self.wrappedErrors(castTo: GRPCServerError.self)
    }

    /// The `wrappedError` of every element of `asGRPCErrors`, cast to `[GRPCCommonError]`;
    /// `nil` when `asGRPCErrors` is `nil` or the cast fails.
    var asGRPCCommonErrors: [GRPCCommonError]? {
      return self.wrappedErrors(castTo: GRPCCommonError.self)
    }

    func observeLibraryError(_ error: Error) {
      self.errors.append(error)
    }

    /// Unwraps each `GRPCError` in `asGRPCErrors` and casts the whole resulting array to `[Target]`.
    private func wrappedErrors<Target>(castTo _: Target.Type) -> [Target]? {
      guard let grpcErrors = self.asGRPCErrors else {
        return nil
      }
      let wrapped = grpcErrors.map { $0.wrappedError }
      return wrapped as? [Target]
    }
  }
  /// Base test case that drives a `GRPCChannelHandler` inside an `EmbeddedChannel` and captures
  /// the `RawGRPCServerResponsePart`s it writes.
  class GRPCChannelHandlerResponseCapturingTestCase: GRPCTestCase {
    /// A single Echo service provider, registered under the name "echo.Echo".
    static let echoProvider: [String: CallHandlerProvider] = ["echo.Echo": EchoProvider()]

    /// Service providers used by `waitForGRPCChannelHandlerResponses` when the caller does not
    /// pass `servicesByName`. Subclasses may override this to change that default.
    class var defaultServiceProvider: [String: CallHandlerProvider] {
      return echoProvider
    }

    /// Creates an `EmbeddedChannel` and adds `handlers` to its pipeline at the `.first` position.
    ///
    /// - Parameter handlers: the handlers to add.
    /// - Returns: A future completed with the channel once the handlers have been added.
    func configureChannel(withHandlers handlers: [ChannelHandler]) -> EventLoopFuture<EmbeddedChannel> {
      let channel = EmbeddedChannel()
      return channel.pipeline.addHandlers(handlers, position: .first)
        .map { _ in channel }
    }

    /// Receives the errors reported by the `GRPCChannelHandler` that
    /// `waitForGRPCChannelHandlerResponses` creates.
    var errorCollector: CollectingServerErrorDelegate = CollectingServerErrorDelegate()

    /// Runs `callback` against a freshly configured channel and returns every response part
    /// that reached the collecting handler. The test fails if the number of collected responses
    /// does not equal `count`.
    ///
    /// Nothing here waits asynchronously: the future chain is `wait()`ed, so `callback` has
    /// finished before the responses are counted.
    ///
    /// - Parameters:
    ///   - count: expected number of responses.
    ///   - servicesByName: service providers keyed by their service name. When `nil` (the default),
    ///     `defaultServiceProvider` of the dynamic test-case type is used, so subclass overrides
    ///     take effect.
    ///   - file: file reported if the count assertion fails; defaults to the caller's file.
    ///   - line: line reported if the count assertion fails; defaults to the caller's line.
    ///   - callback: a callback called after the channel has been set up, intended to "fill" the
    ///     channel with messages. The callback is called before this function returns.
    /// - Returns: The responses collected from the pipeline.
    /// - Throws: Any error from configuring the channel or thrown by `callback`.
    func waitForGRPCChannelHandlerResponses(
      count: Int,
      servicesByName: [String: CallHandlerProvider]? = nil,
      file: StaticString = #file,
      line: UInt = #line,
      callback: @escaping (EmbeddedChannel) throws -> Void
    ) throws -> [RawGRPCServerResponsePart] {
      // A default-argument expression cannot see the dynamic type of `self`: writing
      // `= defaultServiceProvider` always resolves to this class's value and silently ignores
      // subclass overrides. Resolve the default at call time instead.
      let services = servicesByName ?? type(of: self).defaultServiceProvider

      let collector = CollectingChannelHandler<RawGRPCServerResponsePart>()
      let grpcHandler = GRPCChannelHandler(
        servicesByName: services,
        errorDelegate: self.errorCollector,
        logger: Logger(label: "io.grpc.testing")
      )

      try self.configureChannel(withHandlers: [collector, grpcHandler])
        .flatMapThrowing(callback)
        .wait()

      // Forward file/line so a mismatch is reported at the test's call site, not inside this helper.
      XCTAssertEqual(count, collector.responses.count, file: file, line: line)
      return collector.responses
    }
  }