Browse Source

Normalize user-provider headers (#730)

Motivation:

HTTP/2 headers must be lowercased, however, HPACKHeaders does not
enforce this or normalize them for us. Metadata, i.e. headers provided
by the user should be normalized to respect this.

Motivated by: #722

Modifications:

- On the server: normalize all headers in the 2to1 server codec
- On the client: normalize custom metadata provided via call options
- Add tests

Result:

User provided headers are normalized to be lowercased.
George Barnett 5 years ago
parent
commit
b573b3bb12

+ 4 - 1
Sources/GRPC/GRPCClientStateMachine.swift

@@ -544,7 +544,10 @@ extension GRPCClientStateMachine.State {
     }
 
     // Add user-defined custom metadata: this should come after the call definition headers.
-    headers.add(contentsOf: customMetadata)
+    // TODO: make header normalization user-configurable.
+    headers.add(contentsOf: customMetadata.map { (name, value, indexing) in
+      return (name.lowercased(), value, indexing)
+    })
 
     return headers
   }

+ 1 - 1
Sources/GRPC/HTTPProtocolSwitcher.swift

@@ -133,7 +133,7 @@ extension HTTPProtocolSwitcher: ChannelInboundHandler, RemovableChannelHandler {
 
       case .http2:
         context.channel.configureHTTP2Pipeline(mode: .server) { (streamChannel, streamID) in
-            streamChannel.pipeline.addHandler(HTTP2ToHTTP1ServerCodec(streamID: streamID))
+            streamChannel.pipeline.addHandler(HTTP2ToHTTP1ServerCodec(streamID: streamID, normalizeHTTPHeaders: true))
               .flatMap { self.handlersInitializer(streamChannel) }
           }
           .map { _ in }

+ 37 - 0
Tests/GRPCTests/GRPCClientStateMachineTests.swift

@@ -678,6 +678,43 @@ extension GRPCClientStateMachineTests {
     }
   }
 
+  func testSendRequestHeadersNormalizesCustomMetadata() throws {
+    // `HPACKHeaders` uses case-insensitive lookup for header names so we can't check the equality
+    // for individual headers. We'll pull out the entries we care about by matching a sentinel value
+    // and then compare `HPACKHeaders` instances (since the equality check *is* case sensitive).
+    let filterKey = "a-key-for-filtering"
+    let customMetadata: HPACKHeaders = [
+      "partiallyLower": filterKey,
+      "ALLUPPER": filterKey
+    ]
+
+    var stateMachine = self.makeStateMachine(.clientIdleServerIdle(pendingWriteState: .one(), readArity: .one))
+    stateMachine.sendRequestHeaders(requestHead: .init(
+      method: "POST",
+      scheme: "http",
+      path: "/echo/Get",
+      host: "localhost",
+      timeout: .infinite,
+      customMetadata: customMetadata,
+      encoding: .disabled
+    )).assertSuccess { headers in
+      // Pull out the entries we care about by matching values
+      let filtered = headers.filter { (name, value, indexing) in
+        return value == filterKey
+      }.map { name, value, indexing in
+        return (name, value)
+      }
+
+      let justCustomMetadata = HPACKHeaders(filtered)
+      let expected: HPACKHeaders = [
+        "partiallylower": filterKey,
+        "allupper": filterKey
+      ]
+
+      XCTAssertEqual(justCustomMetadata, expected)
+    }
+  }
+
   func testSendRequestHeadersWithNoCompressionInEitherDirection() throws {
     var stateMachine = self.makeStateMachine(.clientIdleServerIdle(pendingWriteState: .one(), readArity: .one))
     stateMachine.sendRequestHeaders(requestHead: .init(

+ 174 - 0
Tests/GRPCTests/HeaderNormalizationTests.swift

@@ -0,0 +1,174 @@
+/*
+ * Copyright 2020, gRPC Authors All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+@testable import GRPC
+import EchoModel
+import EchoImplementation
+import NIO
+import NIOHTTP1
+import NIOHPACK
+import XCTest
+
+class EchoMetadataValidator: Echo_EchoProvider {
+  private func assertCustomMetadataIsLowercased(
+    _ headers: HTTPHeaders,
+    file: StaticString = #file,
+    line: UInt = #line
+  ) {
+    // Header lookup is case-insensitive so we need to pull out the values we know the client sent
+    // as custom-metadata and then compare a new set of headers.
+    let customMetadata = HTTPHeaders(headers.filter { name, value in value == "client" })
+    XCTAssertEqual(customMetadata, ["client": "client"], file: file, line: line)
+  }
+
+  func get(
+    request: Echo_EchoRequest,
+    context: StatusOnlyCallContext
+  ) -> EventLoopFuture<Echo_EchoResponse> {
+    self.assertCustomMetadataIsLowercased(context.request.headers)
+    context.trailingMetadata.add(name: "SERVER", value: "server")
+    return context.eventLoop.makeSucceededFuture(.with { $0.text = request.text })
+  }
+
+  func expand(
+    request: Echo_EchoRequest,
+    context: StreamingResponseCallContext<Echo_EchoResponse>
+  ) -> EventLoopFuture<GRPCStatus> {
+    self.assertCustomMetadataIsLowercased(context.request.headers)
+    context.trailingMetadata.add(name: "SERVER", value: "server")
+    return context.eventLoop.makeSucceededFuture(.ok)
+  }
+
+  func collect(
+    context: UnaryResponseCallContext<Echo_EchoResponse>
+  ) -> EventLoopFuture<(StreamEvent<Echo_EchoRequest>) -> Void> {
+    self.assertCustomMetadataIsLowercased(context.request.headers)
+    context.trailingMetadata.add(name: "SERVER", value: "server")
+    return context.eventLoop.makeSucceededFuture({ event in
+      switch event {
+      case .message:
+        ()
+      case .end:
+        context.responsePromise.succeed(.with { $0.text = "foo" })
+      }
+    })
+  }
+
+  func update(
+    context: StreamingResponseCallContext<Echo_EchoResponse>
+  ) -> EventLoopFuture<(StreamEvent<Echo_EchoRequest>) -> Void> {
+    self.assertCustomMetadataIsLowercased(context.request.headers)
+    context.trailingMetadata.add(name: "SERVER", value: "server")
+    return context.eventLoop.makeSucceededFuture({ event in
+      switch event {
+      case .message:
+        ()
+      case .end:
+        context.statusPromise.succeed(.ok)
+      }
+    })
+  }
+}
+
+class HeaderNormalizationTests: GRPCTestCase {
+  var group: EventLoopGroup!
+  var server: Server!
+  var channel: GRPCChannel!
+  var client: Echo_EchoServiceClient!
+
+  override func setUp() {
+    super.setUp()
+
+    self.group = MultiThreadedEventLoopGroup(numberOfThreads: 1)
+
+    let serverConfig = Server.Configuration(
+      target: .hostAndPort("localhost", 0),
+      eventLoopGroup: self.group,
+      serviceProviders: [EchoMetadataValidator()]
+    )
+
+    self.server = try! Server.start(configuration: serverConfig).wait()
+
+    let clientConfig = ClientConnection.Configuration(
+      target: .hostAndPort("localhost", self.server.channel.localAddress!.port!),
+      eventLoopGroup: self.group
+    )
+
+    self.channel = ClientConnection(configuration: clientConfig)
+    self.client = Echo_EchoServiceClient(channel: self.channel)
+  }
+
+  override func tearDown() {
+    XCTAssertNoThrow(try self.channel.close().wait())
+    XCTAssertNoThrow(try self.server.close().wait())
+    XCTAssertNoThrow(try self.group.syncShutdownGracefully())
+  }
+
+  private func assertCustomMetadataIsLowercased(
+    _ headers: EventLoopFuture<HPACKHeaders>,
+    expectation: XCTestExpectation,
+    file: StaticString = #file,
+    line: UInt = #line
+  ) {
+    // Header lookup is case-insensitive so we need to pull out the values we know the server sent
+    // us as trailing-metadata and then compare a new set of headers.
+    headers.map { trailers -> HPACKHeaders in
+      let filtered = trailers.filter {
+        $0.value == "server"
+      }.map { (name, value, indexing) in
+        return (name, value)
+      }
+      return HPACKHeaders(filtered)
+    }.assertEqual(["server": "server"], fulfill: expectation, file: file, line: line)
+  }
+
+  func testHeadersAreNormalizedForUnary() throws {
+    let trailingMetadata = self.expectation(description: "received trailing metadata")
+    let options = CallOptions(customMetadata: ["CLIENT": "client"])
+    let rpc = self.client.get(.with { $0.text = "foo" }, callOptions: options)
+    self.assertCustomMetadataIsLowercased(rpc.trailingMetadata, expectation: trailingMetadata)
+    self.wait(for: [trailingMetadata], timeout: 1.0)
+  }
+
+  func testHeadersAreNormalizedForClientStreaming() throws {
+    let trailingMetadata = self.expectation(description: "received trailing metadata")
+    let options = CallOptions(customMetadata: ["CLIENT": "client"])
+    let rpc = self.client.collect(callOptions: options)
+    rpc.sendEnd(promise: nil)
+    self.assertCustomMetadataIsLowercased(rpc.trailingMetadata, expectation: trailingMetadata)
+    self.wait(for: [trailingMetadata], timeout: 1.0)
+  }
+
+  func testHeadersAreNormalizedForServerStreaming() throws {
+    let trailingMetadata = self.expectation(description: "received trailing metadata")
+    let options = CallOptions(customMetadata: ["CLIENT": "client"])
+    let rpc = self.client.expand(.with { $0.text = "foo" }, callOptions: options) {
+      XCTFail("unexpected response: \($0)")
+    }
+    self.assertCustomMetadataIsLowercased(rpc.trailingMetadata, expectation: trailingMetadata)
+    self.wait(for: [trailingMetadata], timeout: 1.0)
+  }
+
+  func testHeadersAreNormalizedForBidirectionalStreaming() throws {
+    let trailingMetadata = self.expectation(description: "received trailing metadata")
+    let options = CallOptions(customMetadata: ["CLIENT": "client"])
+    let rpc = self.client.update(callOptions: options) {
+      XCTFail("unexpected response: \($0)")
+    }
+    rpc.sendEnd(promise: nil)
+    self.assertCustomMetadataIsLowercased(rpc.trailingMetadata, expectation: trailingMetadata)
+    self.wait(for: [trailingMetadata], timeout: 1.0)
+  }
+}

+ 14 - 0
Tests/GRPCTests/XCTestManifests.swift

@@ -303,6 +303,7 @@ extension GRPCClientStateMachineTests {
         ("testSendRequestHeadersFromClientClosedServerIdle", testSendRequestHeadersFromClientClosedServerIdle),
         ("testSendRequestHeadersFromClosed", testSendRequestHeadersFromClosed),
         ("testSendRequestHeadersFromIdle", testSendRequestHeadersFromIdle),
+        ("testSendRequestHeadersNormalizesCustomMetadata", testSendRequestHeadersNormalizesCustomMetadata),
         ("testSendRequestHeadersWithNoCompressionForRequests", testSendRequestHeadersWithNoCompressionForRequests),
         ("testSendRequestHeadersWithNoCompressionForResponses", testSendRequestHeadersWithNoCompressionForResponses),
         ("testSendRequestHeadersWithNoCompressionInEitherDirection", testSendRequestHeadersWithNoCompressionInEitherDirection),
@@ -468,6 +469,18 @@ extension HTTP1ToGRPCServerCodecTests {
     ]
 }
 
+extension HeaderNormalizationTests {
+    // DO NOT MODIFY: This is autogenerated, use:
+    //   `swift test --generate-linuxmain`
+    // to regenerate.
+    static let __allTests__HeaderNormalizationTests = [
+        ("testHeadersAreNormalizedForBidirectionalStreaming", testHeadersAreNormalizedForBidirectionalStreaming),
+        ("testHeadersAreNormalizedForClientStreaming", testHeadersAreNormalizedForClientStreaming),
+        ("testHeadersAreNormalizedForServerStreaming", testHeadersAreNormalizedForServerStreaming),
+        ("testHeadersAreNormalizedForUnary", testHeadersAreNormalizedForUnary),
+    ]
+}
+
 extension ImmediatelyFailingProviderTests {
     // DO NOT MODIFY: This is autogenerated, use:
     //   `swift test --generate-linuxmain`
@@ -689,6 +702,7 @@ public func __allTests() -> [XCTestCaseEntry] {
         testCase(GRPCTimeoutTests.__allTests__GRPCTimeoutTests),
         testCase(GRPCTypeSizeTests.__allTests__GRPCTypeSizeTests),
         testCase(HTTP1ToGRPCServerCodecTests.__allTests__HTTP1ToGRPCServerCodecTests),
+        testCase(HeaderNormalizationTests.__allTests__HeaderNormalizationTests),
         testCase(ImmediatelyFailingProviderTests.__allTests__ImmediatelyFailingProviderTests),
         testCase(LengthPrefixedMessageReaderTests.__allTests__LengthPrefixedMessageReaderTests),
         testCase(MessageCompressionTests.__allTests__MessageCompressionTests),