diff --git a/.github/lock.yml b/.github/lock.yml deleted file mode 100644 index 78f7b19b71d..00000000000 --- a/.github/lock.yml +++ /dev/null @@ -1,2 +0,0 @@ -daysUntilLock: 180 -lockComment: false diff --git a/.github/mergeable.yml b/.github/mergeable.yml index d647dafb7ab..187de98277b 100644 --- a/.github/mergeable.yml +++ b/.github/mergeable.yml @@ -5,32 +5,17 @@ mergeable: - do: label must_include: regex: '^Type:' - fail: - - do: checks - status: 'failure' - payload: - title: 'Need an appropriate "Type:" label' - summary: 'Need an appropriate "Type:" label' - - when: pull_request.* - # This validator requires either the "no release notes" label OR a "Release" milestone - # to be considered successful. However, validators "pass" in mergeable only if all - # checks pass. So it is implemented in reverse. - # I.e.: !(!no_relnotes && !release_milestone) ==> no_relnotes || release_milestone - # If both validators pass, then it is considered a failure, and if either fails, it is - # considered a success. 
- validate: - - do: label - must_exclude: - regex: '^no release notes$' + - do: description + must_include: + # Allow: + # RELEASE NOTES: none (case insensitive) + # + # RELEASE NOTES: N/A (case insensitive) + # + # RELEASE NOTES: + # * + regex: '^RELEASE NOTES:\s*([Nn][Oo][Nn][Ee]|[Nn]/[Aa]|\n(\*|-)\s*.+)$' + regex_flag: 'm' - do: milestone - must_exclude: + must_include: regex: 'Release$' - pass: - - do: checks - status: 'failure' # fail on pass - payload: - title: 'Need Release milestone or "no release notes" label' - summary: 'Need Release milestone or "no release notes" label' - fail: - - do: checks - status: 'success' # pass on fail diff --git a/.github/stale.yml b/.github/stale.yml deleted file mode 100644 index 8f69dbc4fe8..00000000000 --- a/.github/stale.yml +++ /dev/null @@ -1,58 +0,0 @@ -# Configuration for probot-stale - https://github.com/probot/stale - -# Number of days of inactivity before an Issue or Pull Request becomes stale -daysUntilStale: 6 - -# Number of days of inactivity before an Issue or Pull Request with the stale label is closed. -# Set to false to disable. If disabled, issues still need to be closed manually, but will remain marked as stale. -daysUntilClose: 7 - -# Only issues or pull requests with all of these labels are check if stale. Defaults to `[]` (disabled) -onlyLabels: - - "Status: Requires Reporter Clarification" - -# Issues or Pull Requests with these labels will never be considered stale. Set to `[]` to disable -exemptLabels: [] - -# Set to true to ignore issues in a project (defaults to false) -exemptProjects: false - -# Set to true to ignore issues in a milestone (defaults to false) -exemptMilestones: false - -# Set to true to ignore issues with an assignee (defaults to false) -exemptAssignees: false - -# Label to use when marking as stale -staleLabel: "stale" - -# Comment to post when marking as stale. 
Set to `false` to disable -markComment: > - This issue is labeled as requiring an update from the reporter, and no update has been received - after 6 days. If no update is provided in the next 7 days, this issue will be automatically closed. - -# Comment to post when removing the stale label. -# unmarkComment: > -# Your comment here. - -# Comment to post when closing a stale Issue or Pull Request. -# closeComment: > -# Your comment here. - -# Limit the number of actions per hour, from 1-30. Default is 30 -limitPerRun: 1 - -# Limit to only `issues` or `pulls` -# only: issues - -# Optionally, specify configuration settings that are specific to just 'issues' or 'pulls': -# pulls: -# daysUntilStale: 30 -# markComment: > -# This pull request has been automatically marked as stale because it has not had -# recent activity. It will be closed if no further activity occurs. Thank you -# for your contributions. - -# issues: -# exemptLabels: -# - confirmed diff --git a/.github/workflows/codeql-analysis.yml b/.github/workflows/codeql-analysis.yml index 0c3806bdc23..2a73b94079c 100644 --- a/.github/workflows/codeql-analysis.yml +++ b/.github/workflows/codeql-analysis.yml @@ -3,16 +3,20 @@ name: "CodeQL" on: push: branches: [ master ] - pull_request: - # The branches below must be a subset of the branches above - branches: [ master ] schedule: - cron: '24 20 * * 3' +permissions: + contents: read + security-events: write + pull-requests: read + actions: read + jobs: analyze: name: Analyze runs-on: ubuntu-latest + timeout-minutes: 30 strategy: fail-fast: false diff --git a/.github/workflows/lock.yml b/.github/workflows/lock.yml new file mode 100644 index 00000000000..5f49c7900a3 --- /dev/null +++ b/.github/workflows/lock.yml @@ -0,0 +1,20 @@ +name: 'Lock Threads' + +on: + workflow_dispatch: + schedule: + - cron: '22 1 * * *' + +permissions: + issues: write + pull-requests: write + +jobs: + lock: + runs-on: ubuntu-latest + steps: + - uses: dessant/lock-threads@v2 + with: + 
github-token: ${{ github.token }} + issue-lock-inactive-days: 180 + pr-lock-inactive-days: 180 diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml new file mode 100644 index 00000000000..5e01a1e70c4 --- /dev/null +++ b/.github/workflows/stale.yml @@ -0,0 +1,30 @@ +name: Stale bot + +on: + workflow_dispatch: + schedule: + - cron: "44 */2 * * *" + +jobs: + stale: + runs-on: ubuntu-latest + permissions: + issues: write + pull-requests: write + + steps: + - uses: actions/stale@v4 + with: + repo-token: ${{ secrets.GITHUB_TOKEN }} + days-before-stale: 6 + days-before-close: 7 + only-labels: 'Status: Requires Reporter Clarification' + stale-issue-label: 'stale' + stale-pr-label: 'stale' + operations-per-run: 999 + stale-issue-message: > + This issue is labeled as requiring an update from the reporter, and no update has been received + after 6 days. If no update is provided in the next 7 days, this issue will be automatically closed. + stale-pr-message: > + This PR is labeled as requiring an update from the reporter, and no update has been received + after 6 days. If no update is provided in the next 7 days, this PR will be automatically closed. diff --git a/.github/workflows/testing.yml b/.github/workflows/testing.yml index 378e2846676..b687fdb3dc3 100644 --- a/.github/workflows/testing.yml +++ b/.github/workflows/testing.yml @@ -7,6 +7,9 @@ on: schedule: - cron: 0 0 * * * # daily at 00:00 +permissions: + contents: read + # Always force the use of Go modules env: GO111MODULE: on @@ -15,6 +18,7 @@ jobs: # Check generated protos match their source repos (optional for PRs). vet-proto: runs-on: ubuntu-latest + timeout-minutes: 20 steps: # Setup the environment.
- name: Setup Go @@ -34,77 +38,80 @@ jobs: env: VET_SKIP_PROTO: 1 runs-on: ubuntu-latest + timeout-minutes: 20 strategy: matrix: include: - - type: vet - goversion: 1.15 - - type: race - goversion: 1.15 - - type: 386 - goversion: 1.15 - - type: retry - goversion: 1.15 + - type: vet+tests + goversion: 1.17 + + - type: tests + goversion: 1.17 + testflags: -race + - type: extras - goversion: 1.15 + goversion: 1.17 + - type: tests - goversion: 1.14 + goversion: 1.17 + goarch: 386 + + - type: tests + goversion: 1.17 + goarch: arm64 + - type: tests - goversion: 1.13 - - type: tests111 - goversion: 1.11 # Keep until interop tests no longer require Go1.11 + goversion: 1.16 + + - type: tests + goversion: 1.15 steps: # Setup the environment. - - name: Setup GOARCH=386 - if: ${{ matrix.type == '386' }} - run: echo "GOARCH=386" >> $GITHUB_ENV - - name: Setup RETRY - if: ${{ matrix.type == 'retry' }} - run: echo "GRPC_GO_RETRY=on" >> $GITHUB_ENV + - name: Setup GOARCH + if: matrix.goarch != '' + run: echo "GOARCH=${{ matrix.goarch }}" >> $GITHUB_ENV + + - name: Setup qemu emulator + if: matrix.goarch == 'arm64' + # setup qemu-user-static emulator and register it with binfmt_misc so that aarch64 binaries + # are automatically executed using qemu. + run: docker run --rm --privileged multiarch/qemu-user-static:5.2.0-2 --reset --credential yes --persistent yes + + - name: Setup GRPC environment + if: matrix.grpcenv != '' + run: echo "${{ matrix.grpcenv }}" >> $GITHUB_ENV + - name: Setup Go uses: actions/setup-go@v2 with: go-version: ${{ matrix.goversion }} + - name: Checkout repo uses: actions/checkout@v2 # Only run vet for 'vet' runs. - name: Run vet.sh - if: ${{ matrix.type == 'vet' }} + if: startsWith(matrix.type, 'vet') run: ./vet.sh -install && ./vet.sh - # Main tests run for everything except when testing "extras", the race - # detector and Go1.11 (where we run a reduced set of tests). 
+ # Main tests run for everything except when testing "extras" + # (where we run a reduced set of tests). - name: Run tests - if: ${{ matrix.type != 'extras' && matrix.type != 'race' && matrix.type != 'tests111' }} + if: contains(matrix.type, 'tests') run: | go version - go test -cpu 1,4 -timeout 7m google.golang.org/grpc/... + go test ${{ matrix.testflags }} -cpu 1,4 -timeout 7m google.golang.org/grpc/... + cd ${GITHUB_WORKSPACE}/security/advancedtls && go test ${{ matrix.testflags }} -timeout 2m google.golang.org/grpc/security/advancedtls/... + cd ${GITHUB_WORKSPACE}/security/authorization && go test ${{ matrix.testflags }} -timeout 2m google.golang.org/grpc/security/authorization/... - # Race detector tests - - name: Run test race - if: ${{ matrix.TYPE == 'race' }} - run: | - go version - go test -race -cpu 1,4 -timeout 7m google.golang.org/grpc/... # Non-core gRPC tests (examples, interop, etc) - name: Run extras tests - if: ${{ matrix.TYPE == 'extras' }} + if: matrix.type == 'extras' run: | go version examples/examples_test.sh security/advancedtls/examples/examples_test.sh interop/interop_test.sh - cd ${GITHUB_WORKSPACE}/security/advancedtls && go test -cpu 1,4 -timeout 7m google.golang.org/grpc/security/advancedtls/... - cd ${GITHUB_WORKSPACE}/security/authorization && go test -cpu 1,4 -timeout 7m google.golang.org/grpc/security/authorization/... 
- - # Reduced set of tests for Go 1.11 - - name: Run Go1.11 tests - if: ${{ matrix.type == 'tests111' }} - run: | - go version - tests=$(find ${GITHUB_WORKSPACE} -name '*_test.go' | xargs -n1 dirname | sort -u | sed "s:^${GITHUB_WORKSPACE}:.:" | sed "s:\/$::" | grep -v ^./security | grep -v ^./credentials/sts | grep -v ^./credentials/tls/certprovider | grep -v ^./credentials/xds | grep -v ^./xds ) - echo "Running tests for " ${tests} - go test -cpu 1,4 -timeout 7m ${tests} + xds/internal/test/e2e/run.sh diff --git a/.travis.yml b/.travis.yml deleted file mode 100644 index 5847d94e551..00000000000 --- a/.travis.yml +++ /dev/null @@ -1,42 +0,0 @@ -language: go - -matrix: - include: - - go: 1.14.x - env: VET=1 GO111MODULE=on - - go: 1.14.x - env: RACE=1 GO111MODULE=on - - go: 1.14.x - env: RUN386=1 - - go: 1.14.x - env: GRPC_GO_RETRY=on - - go: 1.14.x - env: TESTEXTRAS=1 - - go: 1.13.x - env: GO111MODULE=on - - go: 1.12.x - env: GO111MODULE=on - - go: 1.11.x # Keep until interop tests no longer require Go1.11 - env: GO111MODULE=on - -go_import_path: google.golang.org/grpc - -before_install: - - if [[ "${GO111MODULE}" = "on" ]]; then mkdir "${HOME}/go"; export GOPATH="${HOME}/go"; fi - - if [[ -n "${RUN386}" ]]; then export GOARCH=386; fi - - if [[ "${TRAVIS_EVENT_TYPE}" = "cron" && -z "${RUN386}" ]]; then RACE=1; fi - - if [[ "${TRAVIS_EVENT_TYPE}" != "cron" ]]; then export VET_SKIP_PROTO=1; fi - -install: - - try3() { eval "$*" || eval "$*" || eval "$*"; } - - try3 'if [[ "${GO111MODULE}" = "on" ]]; then go mod download; else make testdeps; fi' - - if [[ -n "${GAE}" ]]; then source ./install_gae.sh; make testappenginedeps; fi - - if [[ -n "${VET}" ]]; then ./vet.sh -install; fi - -script: - - set -e - - if [[ -n "${TESTEXTRAS}" ]]; then examples/examples_test.sh; security/advancedtls/examples/examples_test.sh; interop/interop_test.sh; make testsubmodule; exit 0; fi - - if [[ -n "${VET}" ]]; then ./vet.sh; fi - - if [[ -n "${GAE}" ]]; then make testappengine; exit 0; 
fi - - if [[ -n "${RACE}" ]]; then make testrace; exit 0; fi - - make test diff --git a/Documentation/server-reflection-tutorial.md b/Documentation/server-reflection-tutorial.md index b1781fa68dc..9f26656f22b 100644 --- a/Documentation/server-reflection-tutorial.md +++ b/Documentation/server-reflection-tutorial.md @@ -58,7 +58,7 @@ $ go run examples/features/reflection/server/main.go Open a new terminal and make sure you are in the directory where grpc_cli lives: ```sh -$ cd /bins/opt +$ cd /bins/opt ``` ### List services diff --git a/MAINTAINERS.md b/MAINTAINERS.md index 093c82b3afe..c6672c0a3ef 100644 --- a/MAINTAINERS.md +++ b/MAINTAINERS.md @@ -8,17 +8,18 @@ See [CONTRIBUTING.md](https://github.com/grpc/grpc-community/blob/master/CONTRIB for general contribution guidelines. ## Maintainers (in alphabetical order) -- [canguler](https://github.com/canguler), Google LLC + - [cesarghali](https://github.com/cesarghali), Google LLC - [dfawley](https://github.com/dfawley), Google LLC - [easwars](https://github.com/easwars), Google LLC -- [jadekler](https://github.com/jadekler), Google LLC - [menghanl](https://github.com/menghanl), Google LLC - [srini100](https://github.com/srini100), Google LLC ## Emeritus Maintainers (in alphabetical order) - [adelez](https://github.com/adelez), Google LLC +- [canguler](https://github.com/canguler), Google LLC - [iamqizhao](https://github.com/iamqizhao), Google LLC +- [jadekler](https://github.com/jadekler), Google LLC - [jtattermusch](https://github.com/jtattermusch), Google LLC - [lyuxuan](https://github.com/lyuxuan), Google LLC - [makmukhi](https://github.com/makmukhi), Google LLC diff --git a/Makefile b/Makefile index 1f0722f1624..1f8960922b3 100644 --- a/Makefile +++ b/Makefile @@ -41,8 +41,6 @@ vetdeps: clean \ proto \ test \ - testappengine \ - testappenginedeps \ testrace \ vet \ vetdeps diff --git a/NOTICE.txt b/NOTICE.txt new file mode 100644 index 00000000000..530197749e9 --- /dev/null +++ b/NOTICE.txt @@ -0,0 +1,13 @@ 
+Copyright 2014 gRPC authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. diff --git a/README.md b/README.md index 3949a683fb5..0e6ae69a584 100644 --- a/README.md +++ b/README.md @@ -136,6 +136,6 @@ errors. [Go module]: https://github.com/golang/go/wiki/Modules [gRPC]: https://grpc.io [Go gRPC docs]: https://grpc.io/docs/languages/go -[Performance benchmark]: https://performance-dot-grpc-testing.appspot.com/explore?dashboard=5652536396611584&widget=490377658&container=1286539696 +[Performance benchmark]: https://performance-dot-grpc-testing.appspot.com/explore?dashboard=5180705743044608 [quick start]: https://grpc.io/docs/languages/go/quickstart [go-releases]: https://golang.org/doc/devel/release.html diff --git a/admin/admin.go b/admin/admin.go new file mode 100644 index 00000000000..803a4b93534 --- /dev/null +++ b/admin/admin.go @@ -0,0 +1,58 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +// Package admin provides a convenient method for registering a collection of +// administration services to a gRPC server. The services registered are: +// +// - Channelz: https://github.com/grpc/proposal/blob/master/A14-channelz.md +// +// - CSDS: https://github.com/grpc/proposal/blob/master/A40-csds-support.md +// +// Experimental +// +// Notice: All APIs in this package are experimental and may be removed in a +// later release. +package admin + +import ( + "google.golang.org/grpc" + channelzservice "google.golang.org/grpc/channelz/service" + internaladmin "google.golang.org/grpc/internal/admin" +) + +func init() { + // Add a list of default services to admin here. Optional services, like + // CSDS, will be added by other packages. + internaladmin.AddService(func(registrar grpc.ServiceRegistrar) (func(), error) { + channelzservice.RegisterChannelzServiceToServer(registrar) + return nil, nil + }) +} + +// Register registers the set of admin services to the given server. +// +// The returned cleanup function should be called to clean up the resources +// allocated for the service handlers after the server is stopped. +// +// Note that if `s` is not a *grpc.Server or a *xds.GRPCServer, CSDS will not be +// registered because CSDS generated code is old and doesn't support interface +// `grpc.ServiceRegistrar`. +// https://github.com/envoyproxy/go-control-plane/issues/403 +func Register(s grpc.ServiceRegistrar) (cleanup func(), _ error) { + return internaladmin.Register(s) +} diff --git a/security/advancedtls/sni_appengine.go b/admin/admin_test.go similarity index 63% rename from security/advancedtls/sni_appengine.go rename to admin/admin_test.go index fffbb0107dd..0ee4aade0f3 100644 --- a/security/advancedtls/sni_appengine.go +++ b/admin/admin_test.go @@ -1,8 +1,6 @@ -// +build appengine - /* * - * Copyright 2020 gRPC authors. + * Copyright 2021 gRPC authors. 
* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,13 +16,19 @@ * */ -package advancedtls +package admin_test import ( - "crypto/tls" + "testing" + + "google.golang.org/grpc/admin/test" + "google.golang.org/grpc/codes" ) -// buildGetCertificates is a no-op for appengine builds. -func buildGetCertificates(clientHello *tls.ClientHelloInfo, o *ServerOptions) (*tls.Certificate, error) { - return nil, nil +func TestRegisterNoCSDS(t *testing.T) { + test.RunRegisterTests(t, test.ExpectedStatusCodes{ + ChannelzCode: codes.OK, + // CSDS is not registered because xDS isn't imported. + CSDSCode: codes.Unimplemented, + }) } diff --git a/admin/test/admin_test.go b/admin/test/admin_test.go new file mode 100644 index 00000000000..f0f784bfdf3 --- /dev/null +++ b/admin/test/admin_test.go @@ -0,0 +1,38 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// This file has the same content as admin_test.go, difference is that this is +// in another package, and it imports "xds", so we can test that csds is +// registered when xds is imported. 
+ +package test_test + +import ( + "testing" + + "google.golang.org/grpc/admin/test" + "google.golang.org/grpc/codes" + _ "google.golang.org/grpc/xds" +) + +func TestRegisterWithCSDS(t *testing.T) { + test.RunRegisterTests(t, test.ExpectedStatusCodes{ + ChannelzCode: codes.OK, + CSDSCode: codes.OK, + }) +} diff --git a/admin/test/utils.go b/admin/test/utils.go new file mode 100644 index 00000000000..1add8afa824 --- /dev/null +++ b/admin/test/utils.go @@ -0,0 +1,114 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// Package test contains test only functions for package admin. It's used by +// admin/admin_test.go and admin/test/admin_test.go. +package test + +import ( + "context" + "net" + "testing" + "time" + + v3statusgrpc "github.com/envoyproxy/go-control-plane/envoy/service/status/v3" + v3statuspb "github.com/envoyproxy/go-control-plane/envoy/service/status/v3" + "github.com/google/uuid" + "google.golang.org/grpc" + "google.golang.org/grpc/admin" + channelzpb "google.golang.org/grpc/channelz/grpc_channelz_v1" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/internal/xds" + "google.golang.org/grpc/status" +) + +const ( + defaultTestTimeout = 10 * time.Second +) + +// ExpectedStatusCodes contains the expected status code for each RPC (can be +// OK). 
+type ExpectedStatusCodes struct { + ChannelzCode codes.Code + CSDSCode codes.Code +} + +// RunRegisterTests makes a client, runs the RPCs, and compares the status +// codes. +func RunRegisterTests(t *testing.T, ec ExpectedStatusCodes) { + nodeID := uuid.New().String() + bootstrapCleanup, err := xds.SetupBootstrapFile(xds.BootstrapOptions{ + Version: xds.TransportV3, + NodeID: nodeID, + ServerURI: "no.need.for.a.server", + }) + if err != nil { + t.Fatal(err) + } + defer bootstrapCleanup() + + lis, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatalf("cannot create listener: %v", err) + } + + server := grpc.NewServer() + defer server.Stop() + cleanup, err := admin.Register(server) + if err != nil { + t.Fatalf("failed to register admin: %v", err) + } + defer cleanup() + go func() { + server.Serve(lis) + }() + + conn, err := grpc.Dial(lis.Addr().String(), grpc.WithInsecure()) + if err != nil { + t.Fatalf("cannot connect to server: %v", err) + } + + t.Run("channelz", func(t *testing.T) { + if err := RunChannelz(conn); status.Code(err) != ec.ChannelzCode { + t.Fatalf("%s RPC failed with error %v, want code %v", "channelz", err, ec.ChannelzCode) + } + }) + t.Run("csds", func(t *testing.T) { + if err := RunCSDS(conn); status.Code(err) != ec.CSDSCode { + t.Fatalf("%s RPC failed with error %v, want code %v", "CSDS", err, ec.CSDSCode) + } + }) +} + +// RunChannelz makes a channelz RPC. +func RunChannelz(conn *grpc.ClientConn) error { + c := channelzpb.NewChannelzClient(conn) + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + _, err := c.GetTopChannels(ctx, &channelzpb.GetTopChannelsRequest{}, grpc.WaitForReady(true)) + return err +} + +// RunCSDS makes a CSDS RPC. 
+func RunCSDS(conn *grpc.ClientConn) error { + c := v3statusgrpc.NewClientStatusDiscoveryServiceClient(conn) + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + _, err := c.FetchClientStatus(ctx, &v3statuspb.ClientStatusRequest{}, grpc.WaitForReady(true)) + return err +} diff --git a/authz/rbac_translator.go b/authz/rbac_translator.go new file mode 100644 index 00000000000..039d76bc99d --- /dev/null +++ b/authz/rbac_translator.go @@ -0,0 +1,306 @@ +/* + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package authz exposes methods to manage authorization within gRPC. +// +// Experimental +// +// Notice: This package is EXPERIMENTAL and may be changed or removed +// in a later release. +package authz + +import ( + "bytes" + "encoding/json" + "fmt" + "strings" + + v3rbacpb "github.com/envoyproxy/go-control-plane/envoy/config/rbac/v3" + v3routepb "github.com/envoyproxy/go-control-plane/envoy/config/route/v3" + v3matcherpb "github.com/envoyproxy/go-control-plane/envoy/type/matcher/v3" +) + +type header struct { + Key string + Values []string +} + +type peer struct { + Principals []string +} + +type request struct { + Paths []string + Headers []header +} + +type rule struct { + Name string + Source peer + Request request +} + +// Represents the SDK authorization policy provided by user. 
+type authorizationPolicy struct { + Name string + DenyRules []rule `json:"deny_rules"` + AllowRules []rule `json:"allow_rules"` +} + +func principalOr(principals []*v3rbacpb.Principal) *v3rbacpb.Principal { + return &v3rbacpb.Principal{ + Identifier: &v3rbacpb.Principal_OrIds{ + OrIds: &v3rbacpb.Principal_Set{ + Ids: principals, + }, + }, + } +} + +func permissionOr(permission []*v3rbacpb.Permission) *v3rbacpb.Permission { + return &v3rbacpb.Permission{ + Rule: &v3rbacpb.Permission_OrRules{ + OrRules: &v3rbacpb.Permission_Set{ + Rules: permission, + }, + }, + } +} + +func permissionAnd(permission []*v3rbacpb.Permission) *v3rbacpb.Permission { + return &v3rbacpb.Permission{ + Rule: &v3rbacpb.Permission_AndRules{ + AndRules: &v3rbacpb.Permission_Set{ + Rules: permission, + }, + }, + } +} + +func getStringMatcher(value string) *v3matcherpb.StringMatcher { + switch { + case value == "*": + return &v3matcherpb.StringMatcher{ + MatchPattern: &v3matcherpb.StringMatcher_SafeRegex{}, + } + case strings.HasSuffix(value, "*"): + prefix := strings.TrimSuffix(value, "*") + return &v3matcherpb.StringMatcher{ + MatchPattern: &v3matcherpb.StringMatcher_Prefix{Prefix: prefix}, + } + case strings.HasPrefix(value, "*"): + suffix := strings.TrimPrefix(value, "*") + return &v3matcherpb.StringMatcher{ + MatchPattern: &v3matcherpb.StringMatcher_Suffix{Suffix: suffix}, + } + default: + return &v3matcherpb.StringMatcher{ + MatchPattern: &v3matcherpb.StringMatcher_Exact{Exact: value}, + } + } +} + +func getHeaderMatcher(key, value string) *v3routepb.HeaderMatcher { + switch { + case value == "*": + return &v3routepb.HeaderMatcher{ + Name: key, + HeaderMatchSpecifier: &v3routepb.HeaderMatcher_SafeRegexMatch{}, + } + case strings.HasSuffix(value, "*"): + prefix := strings.TrimSuffix(value, "*") + return &v3routepb.HeaderMatcher{ + Name: key, + HeaderMatchSpecifier: &v3routepb.HeaderMatcher_PrefixMatch{PrefixMatch: prefix}, + } + case strings.HasPrefix(value, "*"): + suffix := 
strings.TrimPrefix(value, "*") + return &v3routepb.HeaderMatcher{ + Name: key, + HeaderMatchSpecifier: &v3routepb.HeaderMatcher_SuffixMatch{SuffixMatch: suffix}, + } + default: + return &v3routepb.HeaderMatcher{ + Name: key, + HeaderMatchSpecifier: &v3routepb.HeaderMatcher_ExactMatch{ExactMatch: value}, + } + } +} + +func parsePrincipalNames(principalNames []string) []*v3rbacpb.Principal { + ps := make([]*v3rbacpb.Principal, 0, len(principalNames)) + for _, principalName := range principalNames { + newPrincipalName := &v3rbacpb.Principal{ + Identifier: &v3rbacpb.Principal_Authenticated_{ + Authenticated: &v3rbacpb.Principal_Authenticated{ + PrincipalName: getStringMatcher(principalName), + }, + }} + ps = append(ps, newPrincipalName) + } + return ps +} + +func parsePeer(source peer) (*v3rbacpb.Principal, error) { + if len(source.Principals) > 0 { + return principalOr(parsePrincipalNames(source.Principals)), nil + } + return &v3rbacpb.Principal{ + Identifier: &v3rbacpb.Principal_Any{ + Any: true, + }, + }, nil +} + +func parsePaths(paths []string) []*v3rbacpb.Permission { + ps := make([]*v3rbacpb.Permission, 0, len(paths)) + for _, path := range paths { + newPath := &v3rbacpb.Permission{ + Rule: &v3rbacpb.Permission_UrlPath{ + UrlPath: &v3matcherpb.PathMatcher{ + Rule: &v3matcherpb.PathMatcher_Path{Path: getStringMatcher(path)}}}} + ps = append(ps, newPath) + } + return ps +} + +func parseHeaderValues(key string, values []string) []*v3rbacpb.Permission { + vs := make([]*v3rbacpb.Permission, 0, len(values)) + for _, value := range values { + newHeader := &v3rbacpb.Permission{ + Rule: &v3rbacpb.Permission_Header{ + Header: getHeaderMatcher(key, value)}} + vs = append(vs, newHeader) + } + return vs +} + +var unsupportedHeaders = map[string]bool{ + "host": true, + "connection": true, + "keep-alive": true, + "proxy-authenticate": true, + "proxy-authorization": true, + "te": true, + "trailer": true, + "transfer-encoding": true, + "upgrade": true, +} + +func 
unsupportedHeader(key string) bool { + return key[0] == ':' || strings.HasPrefix(key, "grpc-") || unsupportedHeaders[key] +} + +func parseHeaders(headers []header) ([]*v3rbacpb.Permission, error) { + hs := make([]*v3rbacpb.Permission, 0, len(headers)) + for i, header := range headers { + if header.Key == "" { + return nil, fmt.Errorf(`"headers" %d: "key" is not present`, i) + } + header.Key = strings.ToLower(header.Key) + if unsupportedHeader(header.Key) { + return nil, fmt.Errorf(`"headers" %d: unsupported "key" %s`, i, header.Key) + } + if len(header.Values) == 0 { + return nil, fmt.Errorf(`"headers" %d: "values" is not present`, i) + } + values := parseHeaderValues(header.Key, header.Values) + hs = append(hs, permissionOr(values)) + } + return hs, nil +} + +func parseRequest(request request) (*v3rbacpb.Permission, error) { + var and []*v3rbacpb.Permission + if len(request.Paths) > 0 { + and = append(and, permissionOr(parsePaths(request.Paths))) + } + if len(request.Headers) > 0 { + headers, err := parseHeaders(request.Headers) + if err != nil { + return nil, err + } + and = append(and, permissionAnd(headers)) + } + if len(and) > 0 { + return permissionAnd(and), nil + } + return &v3rbacpb.Permission{ + Rule: &v3rbacpb.Permission_Any{ + Any: true, + }, + }, nil +} + +func parseRules(rules []rule, prefixName string) (map[string]*v3rbacpb.Policy, error) { + policies := make(map[string]*v3rbacpb.Policy) + for i, rule := range rules { + if rule.Name == "" { + return policies, fmt.Errorf(`%d: "name" is not present`, i) + } + principal, err := parsePeer(rule.Source) + if err != nil { + return nil, fmt.Errorf("%d: %v", i, err) + } + permission, err := parseRequest(rule.Request) + if err != nil { + return nil, fmt.Errorf("%d: %v", i, err) + } + policyName := prefixName + "_" + rule.Name + policies[policyName] = &v3rbacpb.Policy{ + Principals: []*v3rbacpb.Principal{principal}, + Permissions: []*v3rbacpb.Permission{permission}, + } + } + return policies, nil +} + +// 
translatePolicy translates SDK authorization policy in JSON format to two +// Envoy RBAC polices (deny followed by allow policy) or only one Envoy RBAC +// allow policy. If the input policy cannot be parsed or is invalid, an error +// will be returned. +func translatePolicy(policyStr string) ([]*v3rbacpb.RBAC, error) { + policy := &authorizationPolicy{} + d := json.NewDecoder(bytes.NewReader([]byte(policyStr))) + d.DisallowUnknownFields() + if err := d.Decode(policy); err != nil { + return nil, fmt.Errorf("failed to unmarshal policy: %v", err) + } + if policy.Name == "" { + return nil, fmt.Errorf(`"name" is not present`) + } + if len(policy.AllowRules) == 0 { + return nil, fmt.Errorf(`"allow_rules" is not present`) + } + rbacs := make([]*v3rbacpb.RBAC, 0, 2) + if len(policy.DenyRules) > 0 { + denyPolicies, err := parseRules(policy.DenyRules, policy.Name) + if err != nil { + return nil, fmt.Errorf(`"deny_rules" %v`, err) + } + denyRBAC := &v3rbacpb.RBAC{ + Action: v3rbacpb.RBAC_DENY, + Policies: denyPolicies, + } + rbacs = append(rbacs, denyRBAC) + } + allowPolicies, err := parseRules(policy.AllowRules, policy.Name) + if err != nil { + return nil, fmt.Errorf(`"allow_rules" %v`, err) + } + allowRBAC := &v3rbacpb.RBAC{Action: v3rbacpb.RBAC_ALLOW, Policies: allowPolicies} + return append(rbacs, allowRBAC), nil +} diff --git a/authz/rbac_translator_test.go b/authz/rbac_translator_test.go new file mode 100644 index 00000000000..9a883e9d78d --- /dev/null +++ b/authz/rbac_translator_test.go @@ -0,0 +1,273 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package authz + +import ( + "strings" + "testing" + + "github.com/google/go-cmp/cmp" + "google.golang.org/protobuf/testing/protocmp" + + v3rbacpb "github.com/envoyproxy/go-control-plane/envoy/config/rbac/v3" + v3routepb "github.com/envoyproxy/go-control-plane/envoy/config/route/v3" + v3matcherpb "github.com/envoyproxy/go-control-plane/envoy/type/matcher/v3" +) + +func TestTranslatePolicy(t *testing.T) { + tests := map[string]struct { + authzPolicy string + wantErr string + wantPolicies []*v3rbacpb.RBAC + }{ + "valid policy": { + authzPolicy: `{ + "name": "authz", + "deny_rules": [ + { + "name": "deny_policy_1", + "source": { + "principals":[ + "spiffe://foo.abc", + "spiffe://bar*", + "*baz", + "spiffe://abc.*.com" + ] + } + }], + "allow_rules": [ + { + "name": "allow_policy_1", + "source": { + "principals":["*"] + }, + "request": { + "paths": ["path-foo*"] + } + }, + { + "name": "allow_policy_2", + "request": { + "paths": [ + "path-bar", + "*baz" + ], + "headers": [ + { + "key": "key-1", + "values": ["foo", "*bar"] + }, + { + "key": "key-2", + "values": ["baz*"] + } + ] + } + }] + }`, + wantPolicies: []*v3rbacpb.RBAC{ + { + Action: v3rbacpb.RBAC_DENY, + Policies: map[string]*v3rbacpb.Policy{ + "authz_deny_policy_1": { + Principals: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_OrIds{OrIds: &v3rbacpb.Principal_Set{ + Ids: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_Authenticated_{ + Authenticated: &v3rbacpb.Principal_Authenticated{PrincipalName: &v3matcherpb.StringMatcher{ + MatchPattern: 
&v3matcherpb.StringMatcher_Exact{Exact: "spiffe://foo.abc"}, + }}, + }}, + {Identifier: &v3rbacpb.Principal_Authenticated_{ + Authenticated: &v3rbacpb.Principal_Authenticated{PrincipalName: &v3matcherpb.StringMatcher{ + MatchPattern: &v3matcherpb.StringMatcher_Prefix{Prefix: "spiffe://bar"}, + }}, + }}, + {Identifier: &v3rbacpb.Principal_Authenticated_{ + Authenticated: &v3rbacpb.Principal_Authenticated{PrincipalName: &v3matcherpb.StringMatcher{ + MatchPattern: &v3matcherpb.StringMatcher_Suffix{Suffix: "baz"}, + }}, + }}, + {Identifier: &v3rbacpb.Principal_Authenticated_{ + Authenticated: &v3rbacpb.Principal_Authenticated{PrincipalName: &v3matcherpb.StringMatcher{ + MatchPattern: &v3matcherpb.StringMatcher_Exact{Exact: "spiffe://abc.*.com"}, + }}, + }}, + }, + }}}, + }, + Permissions: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_Any{Any: true}}, + }, + }, + }, + }, + { + Action: v3rbacpb.RBAC_ALLOW, + Policies: map[string]*v3rbacpb.Policy{ + "authz_allow_policy_1": { + Principals: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_OrIds{OrIds: &v3rbacpb.Principal_Set{ + Ids: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_Authenticated_{ + Authenticated: &v3rbacpb.Principal_Authenticated{PrincipalName: &v3matcherpb.StringMatcher{ + MatchPattern: &v3matcherpb.StringMatcher_SafeRegex{}, + }}, + }}, + }, + }}}, + }, + Permissions: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_AndRules{AndRules: &v3rbacpb.Permission_Set{ + Rules: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_OrRules{OrRules: &v3rbacpb.Permission_Set{ + Rules: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_UrlPath{ + UrlPath: &v3matcherpb.PathMatcher{Rule: &v3matcherpb.PathMatcher_Path{Path: &v3matcherpb.StringMatcher{ + MatchPattern: &v3matcherpb.StringMatcher_Prefix{Prefix: "path-foo"}, + }}}, + }}, + }, + }}}, + }, + }}}, + }, + }, + "authz_allow_policy_2": { + Principals: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_Any{Any: true}}, + }, 
+ Permissions: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_AndRules{AndRules: &v3rbacpb.Permission_Set{ + Rules: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_OrRules{OrRules: &v3rbacpb.Permission_Set{ + Rules: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_UrlPath{ + UrlPath: &v3matcherpb.PathMatcher{Rule: &v3matcherpb.PathMatcher_Path{Path: &v3matcherpb.StringMatcher{ + MatchPattern: &v3matcherpb.StringMatcher_Exact{Exact: "path-bar"}, + }}}, + }}, + {Rule: &v3rbacpb.Permission_UrlPath{ + UrlPath: &v3matcherpb.PathMatcher{Rule: &v3matcherpb.PathMatcher_Path{Path: &v3matcherpb.StringMatcher{ + MatchPattern: &v3matcherpb.StringMatcher_Suffix{Suffix: "baz"}, + }}}, + }}, + }, + }}}, + {Rule: &v3rbacpb.Permission_AndRules{AndRules: &v3rbacpb.Permission_Set{ + Rules: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_OrRules{OrRules: &v3rbacpb.Permission_Set{ + Rules: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_Header{ + Header: &v3routepb.HeaderMatcher{ + Name: "key-1", HeaderMatchSpecifier: &v3routepb.HeaderMatcher_ExactMatch{ExactMatch: "foo"}, + }, + }}, + {Rule: &v3rbacpb.Permission_Header{ + Header: &v3routepb.HeaderMatcher{ + Name: "key-1", HeaderMatchSpecifier: &v3routepb.HeaderMatcher_SuffixMatch{SuffixMatch: "bar"}, + }, + }}, + }, + }}}, + {Rule: &v3rbacpb.Permission_OrRules{OrRules: &v3rbacpb.Permission_Set{ + Rules: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_Header{ + Header: &v3routepb.HeaderMatcher{ + Name: "key-2", HeaderMatchSpecifier: &v3routepb.HeaderMatcher_PrefixMatch{PrefixMatch: "baz"}, + }, + }}, + }, + }}}, + }, + }}}, + }, + }}}, + }, + }, + }, + }, + }, + }, + "unknown field": { + authzPolicy: `{"random": 123}`, + wantErr: "failed to unmarshal policy", + }, + "missing name field": { + authzPolicy: `{}`, + wantErr: `"name" is not present`, + }, + "invalid field type": { + authzPolicy: `{"name": 123}`, + wantErr: "failed to unmarshal policy", + }, + "missing allow rules field": { + authzPolicy: 
`{"name": "authz-foo"}`, + wantErr: `"allow_rules" is not present`, + }, + "missing rule name field": { + authzPolicy: `{ + "name": "authz-foo", + "allow_rules": [{}] + }`, + wantErr: `"allow_rules" 0: "name" is not present`, + }, + "missing header key": { + authzPolicy: `{ + "name": "authz", + "allow_rules": [{ + "name": "allow_policy_1", + "request": {"headers":[{"key":"key-a", "values": ["value-a"]}, {}]} + }] + }`, + wantErr: `"allow_rules" 0: "headers" 1: "key" is not present`, + }, + "missing header values": { + authzPolicy: `{ + "name": "authz", + "allow_rules": [{ + "name": "allow_policy_1", + "request": {"headers":[{"key":"key-a"}]} + }] + }`, + wantErr: `"allow_rules" 0: "headers" 0: "values" is not present`, + }, + "unsupported header": { + authzPolicy: `{ + "name": "authz", + "allow_rules": [{ + "name": "allow_policy_1", + "request": {"headers":[{"key":":method", "values":["GET"]}]} + }] + }`, + wantErr: `"allow_rules" 0: "headers" 0: unsupported "key" :method`, + }, + } + for name, test := range tests { + t.Run(name, func(t *testing.T) { + gotPolicies, gotErr := translatePolicy(test.authzPolicy) + if gotErr != nil && !strings.HasPrefix(gotErr.Error(), test.wantErr) { + t.Fatalf("unexpected error\nwant:%v\ngot:%v", test.wantErr, gotErr) + } + if diff := cmp.Diff(gotPolicies, test.wantPolicies, protocmp.Transform()); diff != "" { + t.Fatalf("unexpected policy\ndiff (-want +got):\n%s", diff) + } + }) + } +} diff --git a/authz/sdk_end2end_test.go b/authz/sdk_end2end_test.go new file mode 100644 index 00000000000..093b2bb437d --- /dev/null +++ b/authz/sdk_end2end_test.go @@ -0,0 +1,548 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package authz_test + +import ( + "context" + "io" + "io/ioutil" + "net" + "os" + "testing" + "time" + + "google.golang.org/grpc" + "google.golang.org/grpc/authz" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/internal/grpctest" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" + pb "google.golang.org/grpc/test/grpc_testing" +) + +type testServer struct { + pb.UnimplementedTestServiceServer +} + +func (s *testServer) UnaryCall(ctx context.Context, req *pb.SimpleRequest) (*pb.SimpleResponse, error) { + return &pb.SimpleResponse{}, nil +} + +func (s *testServer) StreamingInputCall(stream pb.TestService_StreamingInputCallServer) error { + for { + _, err := stream.Recv() + if err == io.EOF { + return stream.SendAndClose(&pb.StreamingInputCallResponse{}) + } + if err != nil { + return err + } + } +} + +type s struct { + grpctest.Tester +} + +func Test(t *testing.T) { + grpctest.RunSubTests(t, s{}) +} + +var sdkTests = map[string]struct { + authzPolicy string + md metadata.MD + wantStatus *status.Status +}{ + "DeniesRpcMatchInDenyNoMatchInAllow": { + authzPolicy: `{ + "name": "authz", + "allow_rules": + [ + { + "name": "allow_StreamingOutputCall", + "request": { + "paths": + [ + "/grpc.testing.TestService/StreamingOutputCall" + ] + } + } + ], + "deny_rules": + [ + { + "name": "deny_TestServiceCalls", + "request": { + "paths": + [ + "/grpc.testing.TestService/UnaryCall", + "/grpc.testing.TestService/StreamingInputCall" + ], + "headers": + [ + { + "key": "key-abc", + "values": + [ + "val-abc", + "val-def" + ] + } 
+ ] + } + } + ] + }`, + md: metadata.Pairs("key-abc", "val-abc"), + wantStatus: status.New(codes.PermissionDenied, "unauthorized RPC request rejected"), + }, + "DeniesRpcMatchInDenyAndAllow": { + authzPolicy: `{ + "name": "authz", + "allow_rules": + [ + { + "name": "allow_TestServiceCalls", + "request": { + "paths": + [ + "/grpc.testing.TestService/*" + ] + } + } + ], + "deny_rules": + [ + { + "name": "deny_TestServiceCalls", + "request": { + "paths": + [ + "/grpc.testing.TestService/*" + ] + } + } + ] + }`, + wantStatus: status.New(codes.PermissionDenied, "unauthorized RPC request rejected"), + }, + "AllowsRpcNoMatchInDenyMatchInAllow": { + authzPolicy: `{ + "name": "authz", + "allow_rules": + [ + { + "name": "allow_all" + } + ], + "deny_rules": + [ + { + "name": "deny_TestServiceCalls", + "request": { + "paths": + [ + "/grpc.testing.TestService/UnaryCall", + "/grpc.testing.TestService/StreamingInputCall" + ], + "headers": + [ + { + "key": "key-abc", + "values": + [ + "val-abc", + "val-def" + ] + } + ] + } + } + ] + }`, + md: metadata.Pairs("key-xyz", "val-xyz"), + wantStatus: status.New(codes.OK, ""), + }, + "DeniesRpcNoMatchInDenyAndAllow": { + authzPolicy: `{ + "name": "authz", + "allow_rules": + [ + { + "name": "allow_some_user", + "source": { + "principals": + [ + "some_user" + ] + } + } + ], + "deny_rules": + [ + { + "name": "deny_StreamingOutputCall", + "request": { + "paths": + [ + "/grpc.testing.TestService/StreamingOutputCall" + ] + } + } + ] + }`, + wantStatus: status.New(codes.PermissionDenied, "unauthorized RPC request rejected"), + }, + "AllowsRpcEmptyDenyMatchInAllow": { + authzPolicy: `{ + "name": "authz", + "allow_rules": + [ + { + "name": "allow_UnaryCall", + "request": + { + "paths": + [ + "/grpc.testing.TestService/UnaryCall" + ] + } + }, + { + "name": "allow_StreamingInputCall", + "request": + { + "paths": + [ + "/grpc.testing.TestService/StreamingInputCall" + ] + } + } + ] + }`, + wantStatus: status.New(codes.OK, ""), + }, + 
"DeniesRpcEmptyDenyNoMatchInAllow": { + authzPolicy: `{ + "name": "authz", + "allow_rules": + [ + { + "name": "allow_StreamingOutputCall", + "request": + { + "paths": + [ + "/grpc.testing.TestService/StreamingOutputCall" + ] + } + } + ] + }`, + wantStatus: status.New(codes.PermissionDenied, "unauthorized RPC request rejected"), + }, +} + +func (s) TestSDKStaticPolicyEnd2End(t *testing.T) { + for name, test := range sdkTests { + t.Run(name, func(t *testing.T) { + // Start a gRPC server with SDK unary and stream server interceptors. + i, _ := authz.NewStatic(test.authzPolicy) + s := grpc.NewServer( + grpc.ChainUnaryInterceptor(i.UnaryInterceptor), + grpc.ChainStreamInterceptor(i.StreamInterceptor)) + defer s.Stop() + pb.RegisterTestServiceServer(s, &testServer{}) + + lis, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatalf("error listening: %v", err) + } + go s.Serve(lis) + + // Establish a connection to the server. + clientConn, err := grpc.Dial(lis.Addr().String(), grpc.WithInsecure()) + if err != nil { + t.Fatalf("grpc.Dial(%v) failed: %v", lis.Addr().String(), err) + } + defer clientConn.Close() + client := pb.NewTestServiceClient(clientConn) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + ctx = metadata.NewOutgoingContext(ctx, test.md) + + // Verifying authorization decision for Unary RPC. + _, err = client.UnaryCall(ctx, &pb.SimpleRequest{}) + if got := status.Convert(err); got.Code() != test.wantStatus.Code() || got.Message() != test.wantStatus.Message() { + t.Fatalf("[UnaryCall] error want:{%v} got:{%v}", test.wantStatus.Err(), got.Err()) + } + + // Verifying authorization decision for Streaming RPC. 
+ stream, err := client.StreamingInputCall(ctx) + if err != nil { + t.Fatalf("failed StreamingInputCall err: %v", err) + } + req := &pb.StreamingInputCallRequest{ + Payload: &pb.Payload{ + Body: []byte("hi"), + }, + } + if err := stream.Send(req); err != nil && err != io.EOF { + t.Fatalf("failed stream.Send err: %v", err) + } + _, err = stream.CloseAndRecv() + if got := status.Convert(err); got.Code() != test.wantStatus.Code() || got.Message() != test.wantStatus.Message() { + t.Fatalf("[StreamingCall] error want:{%v} got:{%v}", test.wantStatus.Err(), got.Err()) + } + }) + } +} + +func (s) TestSDKFileWatcherEnd2End(t *testing.T) { + for name, test := range sdkTests { + t.Run(name, func(t *testing.T) { + file := createTmpPolicyFile(t, name, []byte(test.authzPolicy)) + i, _ := authz.NewFileWatcher(file, 1*time.Second) + defer i.Close() + + // Start a gRPC server with SDK unary and stream server interceptors. + s := grpc.NewServer( + grpc.ChainUnaryInterceptor(i.UnaryInterceptor), + grpc.ChainStreamInterceptor(i.StreamInterceptor)) + defer s.Stop() + pb.RegisterTestServiceServer(s, &testServer{}) + + lis, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatalf("error listening: %v", err) + } + defer lis.Close() + go s.Serve(lis) + + // Establish a connection to the server. + clientConn, err := grpc.Dial(lis.Addr().String(), grpc.WithInsecure()) + if err != nil { + t.Fatalf("grpc.Dial(%v) failed: %v", lis.Addr().String(), err) + } + defer clientConn.Close() + client := pb.NewTestServiceClient(clientConn) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + ctx = metadata.NewOutgoingContext(ctx, test.md) + + // Verifying authorization decision for Unary RPC. 
+ _, err = client.UnaryCall(ctx, &pb.SimpleRequest{}) + if got := status.Convert(err); got.Code() != test.wantStatus.Code() || got.Message() != test.wantStatus.Message() { + t.Fatalf("[UnaryCall] error want:{%v} got:{%v}", test.wantStatus.Err(), got.Err()) + } + + // Verifying authorization decision for Streaming RPC. + stream, err := client.StreamingInputCall(ctx) + if err != nil { + t.Fatalf("failed StreamingInputCall err: %v", err) + } + req := &pb.StreamingInputCallRequest{ + Payload: &pb.Payload{ + Body: []byte("hi"), + }, + } + if err := stream.Send(req); err != nil && err != io.EOF { + t.Fatalf("failed stream.Send err: %v", err) + } + _, err = stream.CloseAndRecv() + if got := status.Convert(err); got.Code() != test.wantStatus.Code() || got.Message() != test.wantStatus.Message() { + t.Fatalf("[StreamingCall] error want:{%v} got:{%v}", test.wantStatus.Err(), got.Err()) + } + }) + } +} + +func retryUntil(ctx context.Context, tsc pb.TestServiceClient, want *status.Status) (lastErr error) { + for ctx.Err() == nil { + _, lastErr = tsc.UnaryCall(ctx, &pb.SimpleRequest{}) + if s := status.Convert(lastErr); s.Code() == want.Code() && s.Message() == want.Message() { + return nil + } + time.Sleep(20 * time.Millisecond) + } + return lastErr +} + +func (s) TestSDKFileWatcher_ValidPolicyRefresh(t *testing.T) { + valid1 := sdkTests["DeniesRpcMatchInDenyAndAllow"] + file := createTmpPolicyFile(t, "valid_policy_refresh", []byte(valid1.authzPolicy)) + i, _ := authz.NewFileWatcher(file, 100*time.Millisecond) + defer i.Close() + + // Start a gRPC server with SDK unary server interceptor. + s := grpc.NewServer( + grpc.ChainUnaryInterceptor(i.UnaryInterceptor)) + defer s.Stop() + pb.RegisterTestServiceServer(s, &testServer{}) + + lis, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatalf("error listening: %v", err) + } + defer lis.Close() + go s.Serve(lis) + + // Establish a connection to the server. 
+ clientConn, err := grpc.Dial(lis.Addr().String(), grpc.WithInsecure()) + if err != nil { + t.Fatalf("grpc.Dial(%v) failed: %v", lis.Addr().String(), err) + } + defer clientConn.Close() + client := pb.NewTestServiceClient(clientConn) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + // Verifying authorization decision. + _, err = client.UnaryCall(ctx, &pb.SimpleRequest{}) + if got := status.Convert(err); got.Code() != valid1.wantStatus.Code() || got.Message() != valid1.wantStatus.Message() { + t.Fatalf("error want:{%v} got:{%v}", valid1.wantStatus.Err(), got.Err()) + } + + // Rewrite the file with a different valid authorization policy. + valid2 := sdkTests["AllowsRpcEmptyDenyMatchInAllow"] + if err := ioutil.WriteFile(file, []byte(valid2.authzPolicy), os.ModePerm); err != nil { + t.Fatalf("ioutil.WriteFile(%q) failed: %v", file, err) + } + + // Verifying authorization decision. + if got := retryUntil(ctx, client, valid2.wantStatus); got != nil { + t.Fatalf("error want:{%v} got:{%v}", valid2.wantStatus.Err(), got) + } +} + +func (s) TestSDKFileWatcher_InvalidPolicySkipReload(t *testing.T) { + valid := sdkTests["DeniesRpcMatchInDenyAndAllow"] + file := createTmpPolicyFile(t, "invalid_policy_skip_reload", []byte(valid.authzPolicy)) + i, _ := authz.NewFileWatcher(file, 20*time.Millisecond) + defer i.Close() + + // Start a gRPC server with SDK unary server interceptors. + s := grpc.NewServer( + grpc.ChainUnaryInterceptor(i.UnaryInterceptor)) + defer s.Stop() + pb.RegisterTestServiceServer(s, &testServer{}) + + lis, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatalf("error listening: %v", err) + } + defer lis.Close() + go s.Serve(lis) + + // Establish a connection to the server. 
+ clientConn, err := grpc.Dial(lis.Addr().String(), grpc.WithInsecure()) + if err != nil { + t.Fatalf("grpc.Dial(%v) failed: %v", lis.Addr().String(), err) + } + defer clientConn.Close() + client := pb.NewTestServiceClient(clientConn) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + // Verifying authorization decision. + _, err = client.UnaryCall(ctx, &pb.SimpleRequest{}) + if got := status.Convert(err); got.Code() != valid.wantStatus.Code() || got.Message() != valid.wantStatus.Message() { + t.Fatalf("error want:{%v} got:{%v}", valid.wantStatus.Err(), got.Err()) + } + + // Skips the invalid policy update, and continues to use the valid policy. + if err := ioutil.WriteFile(file, []byte("{}"), os.ModePerm); err != nil { + t.Fatalf("ioutil.WriteFile(%q) failed: %v", file, err) + } + + // Wait 40 ms for background go routine to read updated files. + time.Sleep(40 * time.Millisecond) + + // Verifying authorization decision. + _, err = client.UnaryCall(ctx, &pb.SimpleRequest{}) + if got := status.Convert(err); got.Code() != valid.wantStatus.Code() || got.Message() != valid.wantStatus.Message() { + t.Fatalf("error want:{%v} got:{%v}", valid.wantStatus.Err(), got.Err()) + } +} + +func (s) TestSDKFileWatcher_RecoversFromReloadFailure(t *testing.T) { + valid1 := sdkTests["DeniesRpcMatchInDenyAndAllow"] + file := createTmpPolicyFile(t, "recovers_from_reload_failure", []byte(valid1.authzPolicy)) + i, _ := authz.NewFileWatcher(file, 100*time.Millisecond) + defer i.Close() + + // Start a gRPC server with SDK unary server interceptors. + s := grpc.NewServer( + grpc.ChainUnaryInterceptor(i.UnaryInterceptor)) + defer s.Stop() + pb.RegisterTestServiceServer(s, &testServer{}) + + lis, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatalf("error listening: %v", err) + } + defer lis.Close() + go s.Serve(lis) + + // Establish a connection to the server. 
+ clientConn, err := grpc.Dial(lis.Addr().String(), grpc.WithInsecure()) + if err != nil { + t.Fatalf("grpc.Dial(%v) failed: %v", lis.Addr().String(), err) + } + defer clientConn.Close() + client := pb.NewTestServiceClient(clientConn) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + // Verifying authorization decision. + _, err = client.UnaryCall(ctx, &pb.SimpleRequest{}) + if got := status.Convert(err); got.Code() != valid1.wantStatus.Code() || got.Message() != valid1.wantStatus.Message() { + t.Fatalf("error want:{%v} got:{%v}", valid1.wantStatus.Err(), got.Err()) + } + + // Skips the invalid policy update, and continues to use the valid policy. + if err := ioutil.WriteFile(file, []byte("{}"), os.ModePerm); err != nil { + t.Fatalf("ioutil.WriteFile(%q) failed: %v", file, err) + } + + // Wait 120 ms for background go routine to read updated files. + time.Sleep(120 * time.Millisecond) + + // Verifying authorization decision. + _, err = client.UnaryCall(ctx, &pb.SimpleRequest{}) + if got := status.Convert(err); got.Code() != valid1.wantStatus.Code() || got.Message() != valid1.wantStatus.Message() { + t.Fatalf("error want:{%v} got:{%v}", valid1.wantStatus.Err(), got.Err()) + } + + // Rewrite the file with a different valid authorization policy. + valid2 := sdkTests["AllowsRpcEmptyDenyMatchInAllow"] + if err := ioutil.WriteFile(file, []byte(valid2.authzPolicy), os.ModePerm); err != nil { + t.Fatalf("ioutil.WriteFile(%q) failed: %v", file, err) + } + + // Verifying authorization decision. + if got := retryUntil(ctx, client, valid2.wantStatus); got != nil { + t.Fatalf("error want:{%v} got:{%v}", valid2.wantStatus.Err(), got) + } +} diff --git a/authz/sdk_server_interceptors.go b/authz/sdk_server_interceptors.go new file mode 100644 index 00000000000..72dc14ed85e --- /dev/null +++ b/authz/sdk_server_interceptors.go @@ -0,0 +1,172 @@ +/* + * Copyright 2021 gRPC authors. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package authz + +import ( + "bytes" + "context" + "fmt" + "io/ioutil" + "sync/atomic" + "time" + "unsafe" + + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/grpclog" + "google.golang.org/grpc/internal/xds/rbac" + "google.golang.org/grpc/status" +) + +var logger = grpclog.Component("authz") + +// StaticInterceptor contains engines used to make authorization decisions. It +// either contains two engines deny engine followed by an allow engine or only +// one allow engine. +type StaticInterceptor struct { + engines rbac.ChainEngine +} + +// NewStatic returns a new StaticInterceptor from a static authorization policy +// JSON string. +func NewStatic(authzPolicy string) (*StaticInterceptor, error) { + rbacs, err := translatePolicy(authzPolicy) + if err != nil { + return nil, err + } + chainEngine, err := rbac.NewChainEngine(rbacs) + if err != nil { + return nil, err + } + return &StaticInterceptor{*chainEngine}, nil +} + +// UnaryInterceptor intercepts incoming Unary RPC requests. +// Only authorized requests are allowed to pass. Otherwise, an unauthorized +// error is returned to the client. 
+func (i *StaticInterceptor) UnaryInterceptor(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { + err := i.engines.IsAuthorized(ctx) + if err != nil { + if status.Code(err) == codes.PermissionDenied { + return nil, status.Errorf(codes.PermissionDenied, "unauthorized RPC request rejected") + } + return nil, err + } + return handler(ctx, req) +} + +// StreamInterceptor intercepts incoming Stream RPC requests. +// Only authorized requests are allowed to pass. Otherwise, an unauthorized +// error is returned to the client. +func (i *StaticInterceptor) StreamInterceptor(srv interface{}, ss grpc.ServerStream, _ *grpc.StreamServerInfo, handler grpc.StreamHandler) error { + err := i.engines.IsAuthorized(ss.Context()) + if err != nil { + if status.Code(err) == codes.PermissionDenied { + return status.Errorf(codes.PermissionDenied, "unauthorized RPC request rejected") + } + return err + } + return handler(srv, ss) +} + +// FileWatcherInterceptor contains details used to make authorization decisions +// by watching a file path that contains authorization policy in JSON format. +type FileWatcherInterceptor struct { + internalInterceptor unsafe.Pointer // *StaticInterceptor + policyFile string + policyContents []byte + refreshDuration time.Duration + cancel context.CancelFunc +} + +// NewFileWatcher returns a new FileWatcherInterceptor from a policy file +// that contains JSON string of authorization policy and a refresh duration to +// specify the amount of time between policy refreshes. 
+func NewFileWatcher(file string, duration time.Duration) (*FileWatcherInterceptor, error) { + if file == "" { + return nil, fmt.Errorf("authorization policy file path is empty") + } + if duration <= time.Duration(0) { + return nil, fmt.Errorf("requires refresh interval(%v) greater than 0s", duration) + } + i := &FileWatcherInterceptor{policyFile: file, refreshDuration: duration} + if err := i.updateInternalInterceptor(); err != nil { + return nil, err + } + ctx, cancel := context.WithCancel(context.Background()) + i.cancel = cancel + // Create a background go routine for policy refresh. + go i.run(ctx) + return i, nil +} + +func (i *FileWatcherInterceptor) run(ctx context.Context) { + ticker := time.NewTicker(i.refreshDuration) + for { + if err := i.updateInternalInterceptor(); err != nil { + logger.Warningf("authorization policy reload status err: %v", err) + } + select { + case <-ctx.Done(): + ticker.Stop() + return + case <-ticker.C: + } + } +} + +// updateInternalInterceptor checks if the policy file that is watching has changed, +// and if so, updates the internalInterceptor with the policy. Unlike the +// constructor, if there is an error in reading the file or parsing the policy, the +// previous internalInterceptors will not be replaced. +func (i *FileWatcherInterceptor) updateInternalInterceptor() error { + policyContents, err := ioutil.ReadFile(i.policyFile) + if err != nil { + return fmt.Errorf("policyFile(%s) read failed: %v", i.policyFile, err) + } + if bytes.Equal(i.policyContents, policyContents) { + return nil + } + i.policyContents = policyContents + policyContentsString := string(policyContents) + interceptor, err := NewStatic(policyContentsString) + if err != nil { + return err + } + atomic.StorePointer(&i.internalInterceptor, unsafe.Pointer(interceptor)) + logger.Infof("authorization policy reload status: successfully loaded new policy %v", policyContentsString) + return nil +} + +// Close cleans up resources allocated by the interceptor. 
+func (i *FileWatcherInterceptor) Close() { + i.cancel() +} + +// UnaryInterceptor intercepts incoming Unary RPC requests. +// Only authorized requests are allowed to pass. Otherwise, an unauthorized +// error is returned to the client. +func (i *FileWatcherInterceptor) UnaryInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { + return ((*StaticInterceptor)(atomic.LoadPointer(&i.internalInterceptor))).UnaryInterceptor(ctx, req, info, handler) +} + +// StreamInterceptor intercepts incoming Stream RPC requests. +// Only authorized requests are allowed to pass. Otherwise, an unauthorized +// error is returned to the client. +func (i *FileWatcherInterceptor) StreamInterceptor(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { + return ((*StaticInterceptor)(atomic.LoadPointer(&i.internalInterceptor))).StreamInterceptor(srv, ss, info, handler) +} diff --git a/authz/sdk_server_interceptors_test.go b/authz/sdk_server_interceptors_test.go new file mode 100644 index 00000000000..f43f9807612 --- /dev/null +++ b/authz/sdk_server_interceptors_test.go @@ -0,0 +1,121 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +package authz_test + +import ( + "fmt" + "io/ioutil" + "os" + "path" + "testing" + "time" + + "google.golang.org/grpc/authz" +) + +func createTmpPolicyFile(t *testing.T, dirSuffix string, policy []byte) string { + t.Helper() + + // Create a temp directory. Passing an empty string for the first argument + // uses the system temp directory. + dir, err := ioutil.TempDir("", dirSuffix) + if err != nil { + t.Fatalf("ioutil.TempDir() failed: %v", err) + } + t.Logf("Using tmpdir: %s", dir) + // Write policy into file. + filename := path.Join(dir, "policy.json") + if err := ioutil.WriteFile(filename, policy, os.ModePerm); err != nil { + t.Fatalf("ioutil.WriteFile(%q) failed: %v", filename, err) + } + t.Logf("Wrote policy %s to file at %s", string(policy), filename) + return filename +} + +func (s) TestNewStatic(t *testing.T) { + tests := map[string]struct { + authzPolicy string + wantErr error + }{ + "InvalidPolicyFailsToCreateInterceptor": { + authzPolicy: `{}`, + wantErr: fmt.Errorf(`"name" is not present`), + }, + "ValidPolicyCreatesInterceptor": { + authzPolicy: `{ + "name": "authz", + "allow_rules": + [ + { + "name": "allow_all" + } + ] + }`, + }, + } + for name, test := range tests { + t.Run(name, func(t *testing.T) { + if _, err := authz.NewStatic(test.authzPolicy); fmt.Sprint(err) != fmt.Sprint(test.wantErr) { + t.Fatalf("NewStatic(%v) returned err: %v, want err: %v", test.authzPolicy, err, test.wantErr) + } + }) + } +} + +func (s) TestNewFileWatcher(t *testing.T) { + tests := map[string]struct { + authzPolicy string + refreshDuration time.Duration + wantErr error + }{ + "InvalidRefreshDurationFailsToCreateInterceptor": { + refreshDuration: time.Duration(0), + wantErr: fmt.Errorf("requires refresh interval(0s) greater than 0s"), + }, + "InvalidPolicyFailsToCreateInterceptor": { + authzPolicy: `{}`, + refreshDuration: time.Duration(1), + wantErr: fmt.Errorf(`"name" is not present`), + }, + "ValidPolicyCreatesInterceptor": { + authzPolicy: `{ + "name": 
"authz", + "allow_rules": + [ + { + "name": "allow_all" + } + ] + }`, + refreshDuration: time.Duration(1), + }, + } + for name, test := range tests { + t.Run(name, func(t *testing.T) { + file := createTmpPolicyFile(t, name, []byte(test.authzPolicy)) + i, err := authz.NewFileWatcher(file, test.refreshDuration) + if fmt.Sprint(err) != fmt.Sprint(test.wantErr) { + t.Fatalf("NewFileWatcher(%v) returned err: %v, want err: %v", test.authzPolicy, err, test.wantErr) + } + if i != nil { + i.Close() + } + }) + } +} diff --git a/balancer/balancer.go b/balancer/balancer.go index ab531f4c0b8..178de0898aa 100644 --- a/balancer/balancer.go +++ b/balancer/balancer.go @@ -75,24 +75,26 @@ func Get(name string) Builder { return nil } -// SubConn represents a gRPC sub connection. -// Each sub connection contains a list of addresses. gRPC will -// try to connect to them (in sequence), and stop trying the -// remainder once one connection is successful. +// A SubConn represents a single connection to a gRPC backend service. // -// The reconnect backoff will be applied on the list, not a single address. -// For example, try_on_all_addresses -> backoff -> try_on_all_addresses. +// Each SubConn contains a list of addresses. // -// All SubConns start in IDLE, and will not try to connect. To trigger -// the connecting, Balancers must call Connect. -// When the connection encounters an error, it will reconnect immediately. -// When the connection becomes IDLE, it will not reconnect unless Connect is -// called. +// All SubConns start in IDLE, and will not try to connect. To trigger the +// connecting, Balancers must call Connect. If a connection re-enters IDLE, +// Balancers must call Connect again to trigger a new connection attempt. // -// This interface is to be implemented by gRPC. Users should not need a -// brand new implementation of this interface. For the situations like -// testing, the new implementation should embed this interface. 
This allows -// gRPC to add new methods to this interface. +// gRPC will try to connect to the addresses in sequence, and stop trying the +// remainder once the first connection is successful. If an attempt to connect +// to all addresses encounters an error, the SubConn will enter +// TRANSIENT_FAILURE for a backoff period, and then transition to IDLE. +// +// Once established, if a connection is lost, the SubConn will transition +// directly to IDLE. +// +// This interface is to be implemented by gRPC. Users should not need their own +// implementation of this interface. For situations like testing, any +// implementations should embed this interface. This allows gRPC to add new +// methods to this interface. type SubConn interface { // UpdateAddresses updates the addresses used in this SubConn. // gRPC checks if currently-connected address is still in the new list. @@ -326,6 +328,20 @@ type Balancer interface { Close() } +// ExitIdler is an optional interface for balancers to implement. If +// implemented, ExitIdle will be called when ClientConn.Connect is called, if +// the ClientConn is idle. If unimplemented, ClientConn.Connect will cause +// all SubConns to connect. +// +// Notice: it will be required for all balancers to implement this in a future +// release. +type ExitIdler interface { + // ExitIdle instructs the LB policy to reconnect to backends / exit the + // IDLE state, if appropriate and possible. Note that SubConns that enter + // the IDLE state will not reconnect until SubConn.Connect is called. + ExitIdle() +} + // SubConnState describes the state of a SubConn. type SubConnState struct { // ConnectivityState is the connectivity state of the SubConn. @@ -353,8 +369,10 @@ var ErrBadResolverState = errors.New("bad resolver state") // // It's not thread safe. type ConnectivityStateEvaluator struct { - numReady uint64 // Number of addrConns in ready state. - numConnecting uint64 // Number of addrConns in connecting state. 
+ numReady uint64 // Number of addrConns in ready state. + numConnecting uint64 // Number of addrConns in connecting state. + numTransientFailure uint64 // Number of addrConns in transient failure state. + numIdle uint64 // Number of addrConns in idle state. } // RecordTransition records state change happening in subConn and based on that @@ -362,9 +380,11 @@ type ConnectivityStateEvaluator struct { // // - If at least one SubConn in Ready, the aggregated state is Ready; // - Else if at least one SubConn in Connecting, the aggregated state is Connecting; -// - Else the aggregated state is TransientFailure. +// - Else if at least one SubConn is TransientFailure, the aggregated state is Transient Failure; +// - Else if at least one SubConn is Idle, the aggregated state is Idle; +// - Else there are no subconns and the aggregated state is Transient Failure // -// Idle and Shutdown are not considered. +// Shutdown is not considered. func (cse *ConnectivityStateEvaluator) RecordTransition(oldState, newState connectivity.State) connectivity.State { // Update counters. 
for idx, state := range []connectivity.State{oldState, newState} { @@ -374,6 +394,10 @@ func (cse *ConnectivityStateEvaluator) RecordTransition(oldState, newState conne cse.numReady += updateVal case connectivity.Connecting: cse.numConnecting += updateVal + case connectivity.TransientFailure: + cse.numTransientFailure += updateVal + case connectivity.Idle: + cse.numIdle += updateVal } } @@ -384,5 +408,11 @@ func (cse *ConnectivityStateEvaluator) RecordTransition(oldState, newState conne if cse.numConnecting > 0 { return connectivity.Connecting } + if cse.numTransientFailure > 0 { + return connectivity.TransientFailure + } + if cse.numIdle > 0 { + return connectivity.Idle + } return connectivity.TransientFailure } diff --git a/balancer/base/balancer.go b/balancer/base/balancer.go index 383d02ec2bf..8dd504299fe 100644 --- a/balancer/base/balancer.go +++ b/balancer/base/balancer.go @@ -22,6 +22,7 @@ import ( "errors" "fmt" + "google.golang.org/grpc/attributes" "google.golang.org/grpc/balancer" "google.golang.org/grpc/connectivity" "google.golang.org/grpc/grpclog" @@ -41,7 +42,7 @@ func (bb *baseBuilder) Build(cc balancer.ClientConn, opt balancer.BuildOptions) cc: cc, pickerBuilder: bb.pickerBuilder, - subConns: make(map[resolver.Address]balancer.SubConn), + subConns: make(map[resolver.Address]subConnInfo), scStates: make(map[balancer.SubConn]connectivity.State), csEvltr: &balancer.ConnectivityStateEvaluator{}, config: bb.config, @@ -57,6 +58,11 @@ func (bb *baseBuilder) Name() string { return bb.name } +type subConnInfo struct { + subConn balancer.SubConn + attrs *attributes.Attributes +} + type baseBalancer struct { cc balancer.ClientConn pickerBuilder PickerBuilder @@ -64,7 +70,7 @@ type baseBalancer struct { csEvltr *balancer.ConnectivityStateEvaluator state connectivity.State - subConns map[resolver.Address]balancer.SubConn // `attributes` is stripped from the keys of this map (the addresses) + subConns map[resolver.Address]subConnInfo // `attributes` is stripped 
from the keys of this map (the addresses) scStates map[balancer.SubConn]connectivity.State picker balancer.Picker config Config @@ -114,7 +120,7 @@ func (b *baseBalancer) UpdateClientConnState(s balancer.ClientConnState) error { aNoAttrs := a aNoAttrs.Attributes = nil addrsSet[aNoAttrs] = struct{}{} - if sc, ok := b.subConns[aNoAttrs]; !ok { + if scInfo, ok := b.subConns[aNoAttrs]; !ok { // a is a new address (not existing in b.subConns). // // When creating SubConn, the original address with attributes is @@ -125,8 +131,9 @@ func (b *baseBalancer) UpdateClientConnState(s balancer.ClientConnState) error { logger.Warningf("base.baseBalancer: failed to create new SubConn: %v", err) continue } - b.subConns[aNoAttrs] = sc + b.subConns[aNoAttrs] = subConnInfo{subConn: sc, attrs: a.Attributes} b.scStates[sc] = connectivity.Idle + b.csEvltr.RecordTransition(connectivity.Shutdown, connectivity.Idle) sc.Connect() } else { // Always update the subconn's address in case the attributes @@ -135,13 +142,15 @@ func (b *baseBalancer) UpdateClientConnState(s balancer.ClientConnState) error { // The SubConn does a reflect.DeepEqual of the new and old // addresses. So this is a noop if the current address is the same // as the old one (including attributes). - b.cc.UpdateAddresses(sc, []resolver.Address{a}) + scInfo.attrs = a.Attributes + b.subConns[aNoAttrs] = scInfo + b.cc.UpdateAddresses(scInfo.subConn, []resolver.Address{a}) } } - for a, sc := range b.subConns { + for a, scInfo := range b.subConns { // a was removed by resolver. if _, ok := addrsSet[a]; !ok { - b.cc.RemoveSubConn(sc) + b.cc.RemoveSubConn(scInfo.subConn) delete(b.subConns, a) // Keep the state of this sc in b.scStates until sc's state becomes Shutdown. // The entry will be deleted in UpdateSubConnState. @@ -184,9 +193,10 @@ func (b *baseBalancer) regeneratePicker() { readySCs := make(map[balancer.SubConn]SubConnInfo) // Filter out all ready SCs from full subConn map. 
- for addr, sc := range b.subConns { - if st, ok := b.scStates[sc]; ok && st == connectivity.Ready { - readySCs[sc] = SubConnInfo{Address: addr} + for addr, scInfo := range b.subConns { + if st, ok := b.scStates[scInfo.subConn]; ok && st == connectivity.Ready { + addr.Attributes = scInfo.attrs + readySCs[scInfo.subConn] = SubConnInfo{Address: addr} } } b.picker = b.pickerBuilder.Build(PickerBuildInfo{ReadySCs: readySCs}) @@ -204,10 +214,14 @@ func (b *baseBalancer) UpdateSubConnState(sc balancer.SubConn, state balancer.Su } return } - if oldS == connectivity.TransientFailure && s == connectivity.Connecting { - // Once a subconn enters TRANSIENT_FAILURE, ignore subsequent + if oldS == connectivity.TransientFailure && + (s == connectivity.Connecting || s == connectivity.Idle) { + // Once a subconn enters TRANSIENT_FAILURE, ignore subsequent IDLE or // CONNECTING transitions to prevent the aggregated state from being // always CONNECTING when many backends exist but are all down. + if s == connectivity.Idle { + sc.Connect() + } return } b.scStates[sc] = s @@ -233,7 +247,6 @@ func (b *baseBalancer) UpdateSubConnState(sc balancer.SubConn, state balancer.Su b.state == connectivity.TransientFailure { b.regeneratePicker() } - b.cc.UpdateState(balancer.State{ConnectivityState: b.state, Picker: b.picker}) } @@ -242,6 +255,11 @@ func (b *baseBalancer) UpdateSubConnState(sc balancer.SubConn, state balancer.Su func (b *baseBalancer) Close() { } +// ExitIdle is a nop because the base balancer attempts to stay connected to +// all SubConns at all times. +func (b *baseBalancer) ExitIdle() { +} + // NewErrPicker returns a Picker that always returns err on Pick(). 
func NewErrPicker(err error) balancer.Picker { return &errPicker{err: err} diff --git a/balancer/base/balancer_test.go b/balancer/base/balancer_test.go index 03114251a04..f8ff8cf9844 100644 --- a/balancer/base/balancer_test.go +++ b/balancer/base/balancer_test.go @@ -23,6 +23,7 @@ import ( "google.golang.org/grpc/attributes" "google.golang.org/grpc/balancer" + "google.golang.org/grpc/connectivity" "google.golang.org/grpc/resolver" ) @@ -35,12 +36,24 @@ func (c *testClientConn) NewSubConn(addrs []resolver.Address, opts balancer.NewS return c.newSubConn(addrs, opts) } +func (c *testClientConn) UpdateState(balancer.State) {} + type testSubConn struct{} func (sc *testSubConn) UpdateAddresses(addresses []resolver.Address) {} func (sc *testSubConn) Connect() {} +// testPickBuilder creates balancer.Picker for test. +type testPickBuilder struct { + validate func(info PickerBuildInfo) +} + +func (p *testPickBuilder) Build(info PickerBuildInfo) balancer.Picker { + p.validate(info) + return nil +} + func TestBaseBalancerStripAttributes(t *testing.T) { b := (&baseBuilder{}).Build(&testClientConn{ newSubConn: func(addrs []resolver.Address, _ balancer.NewSubConnOptions) (balancer.SubConn, error) { @@ -64,7 +77,46 @@ func TestBaseBalancerStripAttributes(t *testing.T) { for addr := range b.subConns { if addr.Attributes != nil { - t.Errorf("in b.subConns, got address %+v with nil attributes, want not nil", addr) + t.Errorf("in b.subConns, got address %+v with not nil attributes, want nil", addr) + } + } +} + +func TestBaseBalancerReserveAttributes(t *testing.T) { + var v = func(info PickerBuildInfo) { + for _, sc := range info.ReadySCs { + if sc.Address.Addr == "1.1.1.1" { + if sc.Address.Attributes == nil { + t.Errorf("in picker.validate, got address %+v with nil attributes, want not nil", sc.Address) + } + foo, ok := sc.Address.Attributes.Value("foo").(string) + if !ok || foo != "2233niang" { + t.Errorf("in picker.validate, got address[1.1.1.1] with invalid attributes value %v, 
want 2233niang", sc.Address.Attributes.Value("foo")) + } + } else if sc.Address.Addr == "2.2.2.2" { + if sc.Address.Attributes != nil { + t.Error("in b.subConns, got address[2.2.2.2] with not nil attributes, want nil") + } + } } } + pickBuilder := &testPickBuilder{validate: v} + b := (&baseBuilder{pickerBuilder: pickBuilder}).Build(&testClientConn{ + newSubConn: func(addrs []resolver.Address, _ balancer.NewSubConnOptions) (balancer.SubConn, error) { + return &testSubConn{}, nil + }, + }, balancer.BuildOptions{}).(*baseBalancer) + + b.UpdateClientConnState(balancer.ClientConnState{ + ResolverState: resolver.State{ + Addresses: []resolver.Address{ + {Addr: "1.1.1.1", Attributes: attributes.New("foo", "2233niang")}, + {Addr: "2.2.2.2", Attributes: nil}, + }, + }, + }) + + for sc := range b.scStates { + b.UpdateSubConnState(sc, balancer.SubConnState{ConnectivityState: connectivity.Ready, ConnectionError: nil}) + } } diff --git a/balancer/grpclb/grpc_lb_v1/load_balancer_grpc.pb.go b/balancer/grpclb/grpc_lb_v1/load_balancer_grpc.pb.go index d56b77cca63..50cc9da4a90 100644 --- a/balancer/grpclb/grpc_lb_v1/load_balancer_grpc.pb.go +++ b/balancer/grpclb/grpc_lb_v1/load_balancer_grpc.pb.go @@ -1,4 +1,8 @@ // Code generated by protoc-gen-go-grpc. DO NOT EDIT. 
+// versions: +// - protoc-gen-go-grpc v1.1.0 +// - protoc v3.14.0 +// source: grpc/lb/v1/load_balancer.proto package grpc_lb_v1 diff --git a/balancer/grpclb/grpclb.go b/balancer/grpclb/grpclb.go index a43d8964119..fe423af182a 100644 --- a/balancer/grpclb/grpclb.go +++ b/balancer/grpclb/grpclb.go @@ -25,6 +25,7 @@ package grpclb import ( "context" "errors" + "fmt" "sync" "time" @@ -134,6 +135,7 @@ func (b *lbBuilder) Build(cc balancer.ClientConn, opt balancer.BuildOptions) bal lb := &lbBalancer{ cc: newLBCacheClientConn(cc), + dialTarget: opt.Target.Endpoint, target: opt.Target.Endpoint, opt: opt, fallbackTimeout: b.fallbackTimeout, @@ -163,9 +165,10 @@ func (b *lbBuilder) Build(cc balancer.ClientConn, opt balancer.BuildOptions) bal } type lbBalancer struct { - cc *lbCacheClientConn - target string - opt balancer.BuildOptions + cc *lbCacheClientConn + dialTarget string // user's dial target + target string // same as dialTarget unless overridden in service config + opt balancer.BuildOptions usePickFirst bool @@ -221,6 +224,7 @@ type lbBalancer struct { // when resolved address updates are received, and read in the goroutine // handling fallback. resolvedBackendAddrs []resolver.Address + connErr error // the last connection error } // regeneratePicker takes a snapshot of the balancer, and generates a picker from @@ -230,7 +234,7 @@ type lbBalancer struct { // Caller must hold lb.mu. func (lb *lbBalancer) regeneratePicker(resetDrop bool) { if lb.state == connectivity.TransientFailure { - lb.picker = &errPicker{err: balancer.ErrTransientFailure} + lb.picker = &errPicker{err: fmt.Errorf("all SubConns are in TransientFailure, last connection error: %v", lb.connErr)} return } @@ -336,6 +340,8 @@ func (lb *lbBalancer) UpdateSubConnState(sc balancer.SubConn, scs balancer.SubCo // When an address was removed by resolver, b called RemoveSubConn but // kept the sc's state in scStates. Remove state for this sc here. 
delete(lb.scStates, sc) + case connectivity.TransientFailure: + lb.connErr = scs.ConnectionError } // Force regenerate picker if // - this sc became ready from not-ready @@ -394,6 +400,30 @@ func (lb *lbBalancer) handleServiceConfig(gc *grpclbServiceConfig) { lb.mu.Lock() defer lb.mu.Unlock() + // grpclb uses the user's dial target to populate the `Name` field of the + // `InitialLoadBalanceRequest` message sent to the remote balancer. But when + // grpclb is used a child policy in the context of RLS, we want the `Name` + // field to be populated with the value received from the RLS server. To + // support this use case, an optional "target_name" field has been added to + // the grpclb LB policy's config. If specified, it overrides the name of + // the target to be sent to the remote balancer; if not, the target to be + // sent to the balancer will continue to be obtained from the target URI + // passed to the gRPC client channel. Whenever that target to be sent to the + // balancer is updated, we need to restart the stream to the balancer as + // this target is sent in the first message on the stream. 
+ if gc != nil { + target := lb.dialTarget + if gc.TargetName != "" { + target = gc.TargetName + } + if target != lb.target { + lb.target = target + if lb.ccRemoteLB != nil { + lb.ccRemoteLB.cancelRemoteBalancerCall() + } + } + } + newUsePickFirst := childIsPickFirst(gc) if lb.usePickFirst == newUsePickFirst { return @@ -484,3 +514,5 @@ func (lb *lbBalancer) Close() { } lb.cc.close() } + +func (lb *lbBalancer) ExitIdle() {} diff --git a/balancer/grpclb/grpclb_config.go b/balancer/grpclb/grpclb_config.go index aac3719631b..b4e23dee017 100644 --- a/balancer/grpclb/grpclb_config.go +++ b/balancer/grpclb/grpclb_config.go @@ -34,6 +34,7 @@ const ( type grpclbServiceConfig struct { serviceconfig.LoadBalancingConfig ChildPolicy *[]map[string]json.RawMessage + TargetName string } func (b *lbBuilder) ParseConfig(lbConfig json.RawMessage) (serviceconfig.LoadBalancingConfig, error) { diff --git a/balancer/grpclb/grpclb_config_test.go b/balancer/grpclb/grpclb_config_test.go index 5a45de90494..0db2299157e 100644 --- a/balancer/grpclb/grpclb_config_test.go +++ b/balancer/grpclb/grpclb_config_test.go @@ -20,52 +20,68 @@ package grpclb import ( "encoding/json" - "errors" - "fmt" - "reflect" - "strings" "testing" + "github.com/google/go-cmp/cmp" "google.golang.org/grpc/serviceconfig" ) func (s) TestParse(t *testing.T) { tests := []struct { name string - s string + sc string want serviceconfig.LoadBalancingConfig - wantErr error + wantErr bool }{ { name: "empty", - s: "", + sc: "", want: nil, - wantErr: errors.New("unexpected end of JSON input"), + wantErr: true, }, { name: "success1", - s: `{"childPolicy":[{"pick_first":{}}]}`, + sc: ` +{ + "childPolicy": [ + {"pick_first":{}} + ], + "targetName": "foo-service" +}`, want: &grpclbServiceConfig{ ChildPolicy: &[]map[string]json.RawMessage{ {"pick_first": json.RawMessage("{}")}, }, + TargetName: "foo-service", }, }, { name: "success2", - s: `{"childPolicy":[{"round_robin":{}},{"pick_first":{}}]}`, + sc: ` +{ + "childPolicy": [ + 
{"round_robin":{}}, + {"pick_first":{}} + ], + "targetName": "foo-service" +}`, want: &grpclbServiceConfig{ ChildPolicy: &[]map[string]json.RawMessage{ {"round_robin": json.RawMessage("{}")}, {"pick_first": json.RawMessage("{}")}, }, + TargetName: "foo-service", }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if got, err := (&lbBuilder{}).ParseConfig(json.RawMessage(tt.s)); !reflect.DeepEqual(got, tt.want) || !strings.Contains(fmt.Sprint(err), fmt.Sprint(tt.wantErr)) { - t.Errorf("parseFullServiceConfig() = %+v, %+v, want %+v, ", got, err, tt.want, tt.wantErr) + got, err := (&lbBuilder{}).ParseConfig(json.RawMessage(tt.sc)) + if (err != nil) != (tt.wantErr) { + t.Fatalf("ParseConfig(%q) returned error: %v, wantErr: %v", tt.sc, err, tt.wantErr) + } + if diff := cmp.Diff(tt.want, got); diff != "" { + t.Fatalf("ParseConfig(%q) returned unexpected difference (-want +got):\n%s", tt.sc, diff) } }) } diff --git a/balancer/grpclb/grpclb_remote_balancer.go b/balancer/grpclb/grpclb_remote_balancer.go index 5ac8d86bd57..0210c012d7b 100644 --- a/balancer/grpclb/grpclb_remote_balancer.go +++ b/balancer/grpclb/grpclb_remote_balancer.go @@ -206,6 +206,9 @@ type remoteBalancerCCWrapper struct { backoff backoff.Strategy done chan struct{} + streamMu sync.Mutex + streamCancel func() + // waitgroup to wait for all goroutines to exit. 
wg sync.WaitGroup } @@ -319,10 +322,8 @@ func (ccw *remoteBalancerCCWrapper) sendLoadReport(s *balanceLoadClientStream, i } } -func (ccw *remoteBalancerCCWrapper) callRemoteBalancer() (backoff bool, _ error) { +func (ccw *remoteBalancerCCWrapper) callRemoteBalancer(ctx context.Context) (backoff bool, _ error) { lbClient := &loadBalancerClient{cc: ccw.cc} - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() stream, err := lbClient.BalanceLoad(ctx, grpc.WaitForReady(true)) if err != nil { return true, fmt.Errorf("grpclb: failed to perform RPC to the remote balancer %v", err) @@ -362,11 +363,43 @@ func (ccw *remoteBalancerCCWrapper) callRemoteBalancer() (backoff bool, _ error) return false, ccw.readServerList(stream) } +// cancelRemoteBalancerCall cancels the context used by the stream to the remote +// balancer. watchRemoteBalancer() takes care of restarting this call after the +// stream fails. +func (ccw *remoteBalancerCCWrapper) cancelRemoteBalancerCall() { + ccw.streamMu.Lock() + if ccw.streamCancel != nil { + ccw.streamCancel() + ccw.streamCancel = nil + } + ccw.streamMu.Unlock() +} + func (ccw *remoteBalancerCCWrapper) watchRemoteBalancer() { - defer ccw.wg.Done() + defer func() { + ccw.wg.Done() + ccw.streamMu.Lock() + if ccw.streamCancel != nil { + // This is to make sure that we don't leak the context when we are + // directly returning from inside of the below `for` loop. 
+ ccw.streamCancel() + ccw.streamCancel = nil + } + ccw.streamMu.Unlock() + }() + var retryCount int + var ctx context.Context for { - doBackoff, err := ccw.callRemoteBalancer() + ccw.streamMu.Lock() + if ccw.streamCancel != nil { + ccw.streamCancel() + ccw.streamCancel = nil + } + ctx, ccw.streamCancel = context.WithCancel(context.Background()) + ccw.streamMu.Unlock() + + doBackoff, err := ccw.callRemoteBalancer(ctx) select { case <-ccw.done: return diff --git a/balancer/grpclb/grpclb_test.go b/balancer/grpclb/grpclb_test.go index 9cbb338c241..3b666764728 100644 --- a/balancer/grpclb/grpclb_test.go +++ b/balancer/grpclb/grpclb_test.go @@ -31,12 +31,16 @@ import ( "testing" "time" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "google.golang.org/grpc" "google.golang.org/grpc/balancer" grpclbstate "google.golang.org/grpc/balancer/grpclb/state" "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials" "google.golang.org/grpc/internal/grpctest" + "google.golang.org/grpc/internal/testutils" "google.golang.org/grpc/metadata" "google.golang.org/grpc/peer" "google.golang.org/grpc/resolver" @@ -60,6 +64,13 @@ var ( fakeName = "fake.Name" ) +const ( + defaultTestTimeout = 10 * time.Second + defaultTestShortTimeout = 10 * time.Millisecond + testUserAgent = "test-user-agent" + grpclbConfig = `{"loadBalancingConfig": [{"grpclb": {}}]}` +) + type s struct { grpctest.Tester } @@ -136,18 +147,6 @@ func (s *rpcStats) merge(cs *lbpb.ClientStats) { s.mu.Unlock() } -func mapsEqual(a, b map[string]int64) bool { - if len(a) != len(b) { - return false - } - for k, v1 := range a { - if v2, ok := b[k]; !ok || v1 != v2 { - return false - } - } - return true -} - func atomicEqual(a, b *int64) bool { return atomic.LoadInt64(a) == atomic.LoadInt64(b) } @@ -172,7 +171,7 @@ func (s *rpcStats) equal(o *rpcStats) bool { defer s.mu.Unlock() o.mu.Lock() defer o.mu.Unlock() - return mapsEqual(s.numCallsDropped, o.numCallsDropped) + return 
cmp.Equal(s.numCallsDropped, o.numCallsDropped, cmpopts.EquateEmpty()) } func (s *rpcStats) String() string { @@ -188,24 +187,28 @@ func (s *rpcStats) String() string { type remoteBalancer struct { lbgrpc.UnimplementedLoadBalancerServer - sls chan *lbpb.ServerList - statsDura time.Duration - done chan struct{} - stats *rpcStats - statsChan chan *lbpb.ClientStats - fbChan chan struct{} - - customUserAgent string + sls chan *lbpb.ServerList + statsDura time.Duration + done chan struct{} + stats *rpcStats + statsChan chan *lbpb.ClientStats + fbChan chan struct{} + balanceLoadCh chan struct{} // notify successful invocation of BalanceLoad + + wantUserAgent string // expected user-agent in metadata of BalancerLoad + wantServerName string // expected server name in InitialLoadBalanceRequest } -func newRemoteBalancer(customUserAgent string, statsChan chan *lbpb.ClientStats) *remoteBalancer { +func newRemoteBalancer(wantUserAgent, wantServerName string, statsChan chan *lbpb.ClientStats) *remoteBalancer { return &remoteBalancer{ - sls: make(chan *lbpb.ServerList, 1), - done: make(chan struct{}), - stats: newRPCStats(), - statsChan: statsChan, - fbChan: make(chan struct{}), - customUserAgent: customUserAgent, + sls: make(chan *lbpb.ServerList, 1), + done: make(chan struct{}), + stats: newRPCStats(), + statsChan: statsChan, + fbChan: make(chan struct{}), + balanceLoadCh: make(chan struct{}, 1), + wantUserAgent: wantUserAgent, + wantServerName: wantServerName, } } @@ -218,15 +221,18 @@ func (b *remoteBalancer) fallbackNow() { b.fbChan <- struct{}{} } +func (b *remoteBalancer) updateServerName(name string) { + b.wantServerName = name +} + func (b *remoteBalancer) BalanceLoad(stream lbgrpc.LoadBalancer_BalanceLoadServer) error { md, ok := metadata.FromIncomingContext(stream.Context()) if !ok { return status.Error(codes.Internal, "failed to receive metadata") } - if b.customUserAgent != "" { - ua := md["user-agent"] - if len(ua) == 0 || !strings.HasPrefix(ua[0], 
b.customUserAgent) { - return status.Errorf(codes.InvalidArgument, "received unexpected user-agent: %v, want prefix %q", ua, b.customUserAgent) + if b.wantUserAgent != "" { + if ua := md["user-agent"]; len(ua) == 0 || !strings.HasPrefix(ua[0], b.wantUserAgent) { + return status.Errorf(codes.InvalidArgument, "received unexpected user-agent: %v, want prefix %q", ua, b.wantUserAgent) } } @@ -235,9 +241,10 @@ func (b *remoteBalancer) BalanceLoad(stream lbgrpc.LoadBalancer_BalanceLoadServe return err } initReq := req.GetInitialRequest() - if initReq.Name != beServerName { - return status.Errorf(codes.InvalidArgument, "invalid service name: %v", initReq.Name) + if initReq.Name != b.wantServerName { + return status.Errorf(codes.InvalidArgument, "invalid service name: %q, want: %q", initReq.Name, b.wantServerName) } + b.balanceLoadCh <- struct{}{} resp := &lbpb.LoadBalanceResponse{ LoadBalanceResponseType: &lbpb.LoadBalanceResponse_InitialResponse{ InitialResponse: &lbpb.InitialLoadBalanceResponse{ @@ -253,11 +260,8 @@ func (b *remoteBalancer) BalanceLoad(stream lbgrpc.LoadBalancer_BalanceLoadServe } go func() { for { - var ( - req *lbpb.LoadBalanceRequest - err error - ) - if req, err = stream.Recv(); err != nil { + req, err := stream.Recv() + if err != nil { return } b.stats.merge(req.GetClientStats()) @@ -347,7 +351,7 @@ type testServers struct { beListeners []net.Listener } -func newLoadBalancer(numberOfBackends int, customUserAgent string, statsChan chan *lbpb.ClientStats) (tss *testServers, cleanup func(), err error) { +func startBackendsAndRemoteLoadBalancer(numberOfBackends int, customUserAgent string, statsChan chan *lbpb.ClientStats) (tss *testServers, cleanup func(), err error) { var ( beListeners []net.Listener ls *remoteBalancer @@ -380,7 +384,7 @@ func newLoadBalancer(numberOfBackends int, customUserAgent string, statsChan cha sn: lbServerName, } lb = grpc.NewServer(grpc.Creds(lbCreds)) - ls = newRemoteBalancer(customUserAgent, statsChan) + ls = 
newRemoteBalancer(customUserAgent, beServerName, statsChan) lbgrpc.RegisterLoadBalancerServer(lb, ls) go func() { lb.Serve(lbLis) @@ -407,34 +411,29 @@ func newLoadBalancer(numberOfBackends int, customUserAgent string, statsChan cha return } -var grpclbConfig = `{"loadBalancingConfig": [{"grpclb": {}}]}` - func (s) TestGRPCLB(t *testing.T) { r := manual.NewBuilderWithScheme("whatever") - const testUserAgent = "test-user-agent" - tss, cleanup, err := newLoadBalancer(1, testUserAgent, nil) + tss, cleanup, err := startBackendsAndRemoteLoadBalancer(1, testUserAgent, nil) if err != nil { t.Fatalf("failed to create new load balancer: %v", err) } defer cleanup() - be := &lbpb.Server{ - IpAddress: tss.beIPs[0], - Port: int32(tss.bePorts[0]), - LoadBalanceToken: lbToken, - } - var bes []*lbpb.Server - bes = append(bes, be) - sl := &lbpb.ServerList{ - Servers: bes, + tss.ls.sls <- &lbpb.ServerList{ + Servers: []*lbpb.Server{ + { + IpAddress: tss.beIPs[0], + Port: int32(tss.bePorts[0]), + LoadBalanceToken: lbToken, + }, + }, } - tss.ls.sls <- sl - creds := serverNameCheckCreds{} - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - cc, err := grpc.DialContext(ctx, r.Scheme()+":///"+beServerName, grpc.WithResolvers(r), - grpc.WithTransportCredentials(&creds), grpc.WithContextDialer(fakeNameDialer), + + cc, err := grpc.Dial(r.Scheme()+":///"+beServerName, + grpc.WithResolvers(r), + grpc.WithTransportCredentials(&serverNameCheckCreds{}), + grpc.WithContextDialer(fakeNameDialer), grpc.WithUserAgent(testUserAgent)) if err != nil { t.Fatalf("Failed to dial to the backend %v", err) @@ -445,12 +444,11 @@ func (s) TestGRPCLB(t *testing.T) { rs := grpclbstate.Set(resolver.State{ServiceConfig: r.CC.ParseServiceConfig(grpclbConfig)}, &grpclbstate.State{BalancerAddresses: []resolver.Address{{ Addr: tss.lbAddr, - Type: resolver.Backend, ServerName: lbServerName, }}}) r.UpdateState(rs) - ctx, cancel = context.WithTimeout(context.Background(), 
5*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() if _, err := testC.EmptyCall(ctx, &testpb.Empty{}); err != nil { t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, ", testC, err) @@ -461,7 +459,7 @@ func (s) TestGRPCLB(t *testing.T) { func (s) TestGRPCLBWeighted(t *testing.T) { r := manual.NewBuilderWithScheme("whatever") - tss, cleanup, err := newLoadBalancer(2, "", nil) + tss, cleanup, err := startBackendsAndRemoteLoadBalancer(2, "", nil) if err != nil { t.Fatalf("failed to create new load balancer: %v", err) } @@ -481,23 +479,25 @@ func (s) TestGRPCLBWeighted(t *testing.T) { portsToIndex[tss.bePorts[i]] = i } - creds := serverNameCheckCreds{} - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - cc, err := grpc.DialContext(ctx, r.Scheme()+":///"+beServerName, grpc.WithResolvers(r), - grpc.WithTransportCredentials(&creds), grpc.WithContextDialer(fakeNameDialer)) + cc, err := grpc.Dial(r.Scheme()+":///"+beServerName, + grpc.WithResolvers(r), + grpc.WithTransportCredentials(&serverNameCheckCreds{}), + grpc.WithContextDialer(fakeNameDialer)) if err != nil { t.Fatalf("Failed to dial to the backend %v", err) } defer cc.Close() testC := testpb.NewTestServiceClient(cc) - r.UpdateState(resolver.State{Addresses: []resolver.Address{{ - Addr: tss.lbAddr, - Type: resolver.GRPCLB, - ServerName: lbServerName, - }}}) + rs := grpclbstate.Set(resolver.State{ServiceConfig: r.CC.ParseServiceConfig(grpclbConfig)}, + &grpclbstate.State{BalancerAddresses: []resolver.Address{{ + Addr: tss.lbAddr, + ServerName: lbServerName, + }}}) + r.UpdateState(rs) + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() sequences := []string{"00101", "00011"} for _, seq := range sequences { var ( @@ -526,7 +526,7 @@ func (s) TestGRPCLBWeighted(t *testing.T) { func (s) TestDropRequest(t *testing.T) { r := manual.NewBuilderWithScheme("whatever") - tss, cleanup, err := 
newLoadBalancer(2, "", nil) + tss, cleanup, err := startBackendsAndRemoteLoadBalancer(2, "", nil) if err != nil { t.Fatalf("failed to create new load balancer: %v", err) } @@ -546,22 +546,23 @@ func (s) TestDropRequest(t *testing.T) { Drop: true, }}, } - creds := serverNameCheckCreds{} - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - cc, err := grpc.DialContext(ctx, r.Scheme()+":///"+beServerName, grpc.WithResolvers(r), - grpc.WithTransportCredentials(&creds), grpc.WithContextDialer(fakeNameDialer)) + + cc, err := grpc.Dial(r.Scheme()+":///"+beServerName, + grpc.WithResolvers(r), + grpc.WithTransportCredentials(&serverNameCheckCreds{}), + grpc.WithContextDialer(fakeNameDialer)) if err != nil { t.Fatalf("Failed to dial to the backend %v", err) } defer cc.Close() testC := testpb.NewTestServiceClient(cc) - r.UpdateState(resolver.State{Addresses: []resolver.Address{{ - Addr: tss.lbAddr, - Type: resolver.GRPCLB, - ServerName: lbServerName, - }}}) + rs := grpclbstate.Set(resolver.State{ServiceConfig: r.CC.ParseServiceConfig(grpclbConfig)}, + &grpclbstate.State{BalancerAddresses: []resolver.Address{{ + Addr: tss.lbAddr, + ServerName: lbServerName, + }}}) + r.UpdateState(rs) var ( i int @@ -573,6 +574,8 @@ func (s) TestDropRequest(t *testing.T) { sleepEachLoop = time.Millisecond loopCount = int(time.Second / sleepEachLoop) ) + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() // Make a non-fail-fast RPC and wait for it to succeed. 
for i = 0; i < loopCount; i++ { if _, err := testC.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(true), grpc.Peer(&p)); err == nil { @@ -681,49 +684,51 @@ func (s) TestBalancerDisconnects(t *testing.T) { lbs []*grpc.Server ) for i := 0; i < 2; i++ { - tss, cleanup, err := newLoadBalancer(1, "", nil) + tss, cleanup, err := startBackendsAndRemoteLoadBalancer(1, "", nil) if err != nil { t.Fatalf("failed to create new load balancer: %v", err) } defer cleanup() - be := &lbpb.Server{ - IpAddress: tss.beIPs[0], - Port: int32(tss.bePorts[0]), - LoadBalanceToken: lbToken, - } - var bes []*lbpb.Server - bes = append(bes, be) - sl := &lbpb.ServerList{ - Servers: bes, + tss.ls.sls <- &lbpb.ServerList{ + Servers: []*lbpb.Server{ + { + IpAddress: tss.beIPs[0], + Port: int32(tss.bePorts[0]), + LoadBalanceToken: lbToken, + }, + }, } - tss.ls.sls <- sl tests = append(tests, tss) lbs = append(lbs, tss.lb) } - creds := serverNameCheckCreds{} - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - cc, err := grpc.DialContext(ctx, r.Scheme()+":///"+beServerName, grpc.WithResolvers(r), - grpc.WithTransportCredentials(&creds), grpc.WithContextDialer(fakeNameDialer)) + cc, err := grpc.Dial(r.Scheme()+":///"+beServerName, + grpc.WithResolvers(r), + grpc.WithTransportCredentials(&serverNameCheckCreds{}), + grpc.WithContextDialer(fakeNameDialer)) if err != nil { t.Fatalf("Failed to dial to the backend %v", err) } defer cc.Close() testC := testpb.NewTestServiceClient(cc) - r.UpdateState(resolver.State{Addresses: []resolver.Address{{ - Addr: tests[0].lbAddr, - Type: resolver.GRPCLB, - ServerName: lbServerName, - }, { - Addr: tests[1].lbAddr, - Type: resolver.GRPCLB, - ServerName: lbServerName, - }}}) + rs := grpclbstate.Set(resolver.State{ServiceConfig: r.CC.ParseServiceConfig(grpclbConfig)}, + &grpclbstate.State{BalancerAddresses: []resolver.Address{ + { + Addr: tests[0].lbAddr, + ServerName: lbServerName, + }, + { + Addr: tests[1].lbAddr, + 
ServerName: lbServerName, + }, + }}) + r.UpdateState(rs) + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() var p peer.Peer if _, err := testC.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(true), grpc.Peer(&p)); err != nil { t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, ", testC, err) @@ -750,16 +755,27 @@ func (s) TestBalancerDisconnects(t *testing.T) { func (s) TestFallback(t *testing.T) { balancer.Register(newLBBuilderWithFallbackTimeout(100 * time.Millisecond)) defer balancer.Register(newLBBuilder()) - r := manual.NewBuilderWithScheme("whatever") - tss, cleanup, err := newLoadBalancer(1, "", nil) + // Start a remote balancer and a backend. Push the backend address to the + // remote balancer. + tss, cleanup, err := startBackendsAndRemoteLoadBalancer(1, "", nil) if err != nil { t.Fatalf("failed to create new load balancer: %v", err) } defer cleanup() + sl := &lbpb.ServerList{ + Servers: []*lbpb.Server{ + { + IpAddress: tss.beIPs[0], + Port: int32(tss.bePorts[0]), + LoadBalanceToken: lbToken, + }, + }, + } + tss.ls.sls <- sl - // Start a standalone backend. + // Start a standalone backend for fallback. 
beLis, err := net.Listen("tcp", "localhost:0") if err != nil { t.Fatalf("Failed to listen %v", err) @@ -768,37 +784,29 @@ func (s) TestFallback(t *testing.T) { standaloneBEs := startBackends(beServerName, true, beLis) defer stopBackends(standaloneBEs) - be := &lbpb.Server{ - IpAddress: tss.beIPs[0], - Port: int32(tss.bePorts[0]), - LoadBalanceToken: lbToken, - } - var bes []*lbpb.Server - bes = append(bes, be) - sl := &lbpb.ServerList{ - Servers: bes, - } - tss.ls.sls <- sl - creds := serverNameCheckCreds{} - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - cc, err := grpc.DialContext(ctx, r.Scheme()+":///"+beServerName, grpc.WithResolvers(r), - grpc.WithTransportCredentials(&creds), grpc.WithContextDialer(fakeNameDialer)) + cc, err := grpc.Dial(r.Scheme()+":///"+beServerName, + grpc.WithResolvers(r), + grpc.WithTransportCredentials(&serverNameCheckCreds{}), + grpc.WithContextDialer(fakeNameDialer)) if err != nil { t.Fatalf("Failed to dial to the backend %v", err) } defer cc.Close() testC := testpb.NewTestServiceClient(cc) - r.UpdateState(resolver.State{Addresses: []resolver.Address{{ - Addr: "invalid.address", - Type: resolver.GRPCLB, - ServerName: lbServerName, - }, { - Addr: beLis.Addr().String(), - Type: resolver.Backend, - }}}) + // Push an update to the resolver with fallback backend address stored in + // the `Addresses` field and an invalid remote balancer address stored in + // attributes, which will cause fallback behavior to be invoked. + rs := resolver.State{ + Addresses: []resolver.Address{{Addr: beLis.Addr().String()}}, + ServiceConfig: r.CC.ParseServiceConfig(grpclbConfig), + } + rs = grpclbstate.Set(rs, &grpclbstate.State{BalancerAddresses: []resolver.Address{{Addr: "invalid.address", ServerName: lbServerName}}}) + r.UpdateState(rs) + // Make an RPC and verify that it got routed to the fallback backend. 
+ ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() var p peer.Peer if _, err := testC.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(true), grpc.Peer(&p)); err != nil { t.Fatalf("_.EmptyCall(_, _) = _, %v, want _, ", err) @@ -807,15 +815,21 @@ func (s) TestFallback(t *testing.T) { t.Fatalf("got peer: %v, want peer: %v", p.Addr, beLis.Addr()) } - r.UpdateState(resolver.State{Addresses: []resolver.Address{{ - Addr: tss.lbAddr, - Type: resolver.GRPCLB, - ServerName: lbServerName, - }, { - Addr: beLis.Addr().String(), - Type: resolver.Backend, - }}}) + // Push another update to the resolver, this time with a valid balancer + // address in the attributes field. + rs = resolver.State{ + ServiceConfig: r.CC.ParseServiceConfig(grpclbConfig), + Addresses: []resolver.Address{{Addr: beLis.Addr().String()}}, + } + rs = grpclbstate.Set(rs, &grpclbstate.State{BalancerAddresses: []resolver.Address{{Addr: tss.lbAddr, ServerName: lbServerName}}}) + r.UpdateState(rs) + select { + case <-ctx.Done(): + t.Fatalf("timeout when waiting for BalanceLoad RPC to be called on the remote balancer") + case <-tss.ls.balanceLoadCh: + } + // Wait for RPCs to get routed to the backend behind the remote balancer. var backendUsed bool for i := 0; i < 1000; i++ { if _, err := testC.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(true), grpc.Peer(&p)); err != nil { @@ -856,7 +870,7 @@ func (s) TestFallback(t *testing.T) { t.Fatalf("No RPC sent to fallback after 2 seconds") } - // Restart backend and remote balancer, should not use backends. + // Restart backend and remote balancer, should not use fallback backend. 
tss.beListeners[0].(*restartableListener).restart() tss.lbListener.(*restartableListener).restart() tss.ls.sls <- sl @@ -880,13 +894,25 @@ func (s) TestFallback(t *testing.T) { func (s) TestExplicitFallback(t *testing.T) { r := manual.NewBuilderWithScheme("whatever") - tss, cleanup, err := newLoadBalancer(1, "", nil) + // Start a remote balancer and a backend. Push the backend address to the + // remote balancer. + tss, cleanup, err := startBackendsAndRemoteLoadBalancer(1, "", nil) if err != nil { t.Fatalf("failed to create new load balancer: %v", err) } defer cleanup() + sl := &lbpb.ServerList{ + Servers: []*lbpb.Server{ + { + IpAddress: tss.beIPs[0], + Port: int32(tss.bePorts[0]), + LoadBalanceToken: lbToken, + }, + }, + } + tss.ls.sls <- sl - // Start a standalone backend. + // Start a standalone backend for fallback. beLis, err := net.Listen("tcp", "localhost:0") if err != nil { t.Fatalf("Failed to listen %v", err) @@ -895,37 +921,25 @@ func (s) TestExplicitFallback(t *testing.T) { standaloneBEs := startBackends(beServerName, true, beLis) defer stopBackends(standaloneBEs) - be := &lbpb.Server{ - IpAddress: tss.beIPs[0], - Port: int32(tss.bePorts[0]), - LoadBalanceToken: lbToken, - } - var bes []*lbpb.Server - bes = append(bes, be) - sl := &lbpb.ServerList{ - Servers: bes, - } - tss.ls.sls <- sl - creds := serverNameCheckCreds{} - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - cc, err := grpc.DialContext(ctx, r.Scheme()+":///"+beServerName, grpc.WithResolvers(r), - grpc.WithTransportCredentials(&creds), grpc.WithContextDialer(fakeNameDialer)) + cc, err := grpc.Dial(r.Scheme()+":///"+beServerName, + grpc.WithResolvers(r), + grpc.WithTransportCredentials(&serverNameCheckCreds{}), + grpc.WithContextDialer(fakeNameDialer)) if err != nil { t.Fatalf("Failed to dial to the backend %v", err) } defer cc.Close() testC := testpb.NewTestServiceClient(cc) - r.UpdateState(resolver.State{Addresses: []resolver.Address{{ - Addr: 
tss.lbAddr, - Type: resolver.GRPCLB, - ServerName: lbServerName, - }, { - Addr: beLis.Addr().String(), - Type: resolver.Backend, - }}}) + rs := resolver.State{ + Addresses: []resolver.Address{{Addr: beLis.Addr().String()}}, + ServiceConfig: r.CC.ParseServiceConfig(grpclbConfig), + } + rs = grpclbstate.Set(rs, &grpclbstate.State{BalancerAddresses: []resolver.Address{{Addr: tss.lbAddr, ServerName: lbServerName}}}) + r.UpdateState(rs) + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() var p peer.Peer var backendUsed bool for i := 0; i < 2000; i++ { @@ -980,23 +994,34 @@ func (s) TestExplicitFallback(t *testing.T) { } func (s) TestFallBackWithNoServerAddress(t *testing.T) { - resolveNowCh := make(chan struct{}, 1) + resolveNowCh := testutils.NewChannel() r := manual.NewBuilderWithScheme("whatever") r.ResolveNowCallback = func(resolver.ResolveNowOptions) { - select { - case <-resolveNowCh: - default: + ctx, cancel := context.WithTimeout(context.Background(), defaultTestShortTimeout) + defer cancel() + if err := resolveNowCh.SendContext(ctx, nil); err != nil { + t.Error("timeout when attemping to send on resolverNowCh") } - resolveNowCh <- struct{}{} } - tss, cleanup, err := newLoadBalancer(1, "", nil) + // Start a remote balancer and a backend. Don't push the backend address to + // the remote balancer yet. + tss, cleanup, err := startBackendsAndRemoteLoadBalancer(1, "", nil) if err != nil { t.Fatalf("failed to create new load balancer: %v", err) } defer cleanup() + sl := &lbpb.ServerList{ + Servers: []*lbpb.Server{ + { + IpAddress: tss.beIPs[0], + Port: int32(tss.bePorts[0]), + LoadBalanceToken: lbToken, + }, + }, + } - // Start a standalone backend. + // Start a standalone backend for fallback.
beLis, err := net.Listen("tcp", "localhost:0") if err != nil { t.Fatalf("Failed to listen %v", err) @@ -1005,81 +1030,61 @@ func (s) TestFallBackWithNoServerAddress(t *testing.T) { standaloneBEs := startBackends(beServerName, true, beLis) defer stopBackends(standaloneBEs) - be := &lbpb.Server{ - IpAddress: tss.beIPs[0], - Port: int32(tss.bePorts[0]), - LoadBalanceToken: lbToken, - } - var bes []*lbpb.Server - bes = append(bes, be) - sl := &lbpb.ServerList{ - Servers: bes, - } - creds := serverNameCheckCreds{} - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - cc, err := grpc.DialContext(ctx, r.Scheme()+":///"+beServerName, grpc.WithResolvers(r), - grpc.WithTransportCredentials(&creds), grpc.WithContextDialer(fakeNameDialer)) + cc, err := grpc.Dial(r.Scheme()+":///"+beServerName, + grpc.WithResolvers(r), + grpc.WithTransportCredentials(&serverNameCheckCreds{}), + grpc.WithContextDialer(fakeNameDialer)) if err != nil { t.Fatalf("Failed to dial to the backend %v", err) } defer cc.Close() testC := testpb.NewTestServiceClient(cc) - // Select grpclb with service config. - const pfc = `{"loadBalancingConfig":[{"grpclb":{"childPolicy":[{"round_robin":{}}]}}]}` - scpr := r.CC.ParseServiceConfig(pfc) - if scpr.Err != nil { - t.Fatalf("Error parsing config %q: %v", pfc, scpr.Err) - } - + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() for i := 0; i < 2; i++ { - // Send an update with only backend address. grpclb should enter fallback - // and use the fallback backend. + // Send an update with only backend address. grpclb should enter + // fallback and use the fallback backend. 
r.UpdateState(resolver.State{ - Addresses: []resolver.Address{{ - Addr: beLis.Addr().String(), - Type: resolver.Backend, - }}, - ServiceConfig: scpr, + Addresses: []resolver.Address{{Addr: beLis.Addr().String()}}, + ServiceConfig: r.CC.ParseServiceConfig(grpclbConfig), }) - select { - case <-resolveNowCh: - t.Errorf("unexpected resolveNow when grpclb gets no balancer address 1111, %d", i) - case <-time.After(time.Second): + sCtx, sCancel := context.WithTimeout(context.Background(), defaultTestShortTimeout) + defer sCancel() + if _, err := resolveNowCh.Receive(sCtx); err != context.DeadlineExceeded { + t.Fatalf("unexpected resolveNow when grpclb gets no balancer address 1111, %d", i) } var p peer.Peer - rpcCtx, rpcCancel := context.WithTimeout(context.Background(), time.Second) - defer rpcCancel() - if _, err := testC.EmptyCall(rpcCtx, &testpb.Empty{}, grpc.WaitForReady(true), grpc.Peer(&p)); err != nil { + if _, err := testC.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(true), grpc.Peer(&p)); err != nil { t.Fatalf("_.EmptyCall(_, _) = _, %v, want _, ", err) } if p.Addr.String() != beLis.Addr().String() { t.Fatalf("got peer: %v, want peer: %v", p.Addr, beLis.Addr()) } - select { - case <-resolveNowCh: + sCtx, sCancel = context.WithTimeout(context.Background(), defaultTestShortTimeout) + defer sCancel() + if _, err := resolveNowCh.Receive(sCtx); err != context.DeadlineExceeded { t.Errorf("unexpected resolveNow when grpclb gets no balancer address 2222, %d", i) - case <-time.After(time.Second): } tss.ls.sls <- sl // Send an update with balancer address. The backends behind grpclb should // be used. 
- r.UpdateState(resolver.State{ - Addresses: []resolver.Address{{ - Addr: tss.lbAddr, - Type: resolver.GRPCLB, - ServerName: lbServerName, - }, { - Addr: beLis.Addr().String(), - Type: resolver.Backend, - }}, - ServiceConfig: scpr, - }) + rs := resolver.State{ + Addresses: []resolver.Address{{Addr: beLis.Addr().String()}}, + ServiceConfig: r.CC.ParseServiceConfig(grpclbConfig), + } + rs = grpclbstate.Set(rs, &grpclbstate.State{BalancerAddresses: []resolver.Address{{Addr: tss.lbAddr, ServerName: lbServerName}}}) + r.UpdateState(rs) + + select { + case <-ctx.Done(): + t.Fatalf("timeout when waiting for BalanceLoad RPC to be called on the remote balancer") + case <-tss.ls.balanceLoadCh: + } var backendUsed bool for i := 0; i < 1000; i++ { @@ -1101,7 +1106,7 @@ func (s) TestFallBackWithNoServerAddress(t *testing.T) { func (s) TestGRPCLBPickFirst(t *testing.T) { r := manual.NewBuilderWithScheme("whatever") - tss, cleanup, err := newLoadBalancer(3, "", nil) + tss, cleanup, err := startBackendsAndRemoteLoadBalancer(3, "", nil) if err != nil { t.Fatalf("failed to create new load balancer: %v", err) } @@ -1125,11 +1130,10 @@ func (s) TestGRPCLBPickFirst(t *testing.T) { portsToIndex[tss.bePorts[i]] = i } - creds := serverNameCheckCreds{} - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - cc, err := grpc.DialContext(ctx, r.Scheme()+":///"+beServerName, grpc.WithResolvers(r), - grpc.WithTransportCredentials(&creds), grpc.WithContextDialer(fakeNameDialer)) + cc, err := grpc.Dial(r.Scheme()+":///"+beServerName, + grpc.WithResolvers(r), + grpc.WithTransportCredentials(&serverNameCheckCreds{}), + grpc.WithContextDialer(fakeNameDialer)) if err != nil { t.Fatalf("Failed to dial to the backend %v", err) } @@ -1143,21 +1147,11 @@ func (s) TestGRPCLBPickFirst(t *testing.T) { tss.ls.sls <- &lbpb.ServerList{Servers: beServers[0:3]} // Start with sub policy pick_first. 
- const pfc = `{"loadBalancingConfig":[{"grpclb":{"childPolicy":[{"pick_first":{}}]}}]}` - scpr := r.CC.ParseServiceConfig(pfc) - if scpr.Err != nil { - t.Fatalf("Error parsing config %q: %v", pfc, scpr.Err) - } - - r.UpdateState(resolver.State{ - Addresses: []resolver.Address{{ - Addr: tss.lbAddr, - Type: resolver.GRPCLB, - ServerName: lbServerName, - }}, - ServiceConfig: scpr, - }) + rs := resolver.State{ServiceConfig: r.CC.ParseServiceConfig(`{"loadBalancingConfig":[{"grpclb":{"childPolicy":[{"pick_first":{}}]}}]}`)} + r.UpdateState(grpclbstate.Set(rs, &grpclbstate.State{BalancerAddresses: []resolver.Address{{Addr: tss.lbAddr, ServerName: lbServerName}}})) + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() result = "" for i := 0; i < 1000; i++ { if _, err := testC.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(true), grpc.Peer(&p)); err != nil { @@ -1194,19 +1188,12 @@ func (s) TestGRPCLBPickFirst(t *testing.T) { } // Switch sub policy to roundrobin. - grpclbServiceConfigEmpty := r.CC.ParseServiceConfig(`{}`) - if grpclbServiceConfigEmpty.Err != nil { - t.Fatalf("Error parsing config %q: %v", `{}`, grpclbServiceConfigEmpty.Err) - } - - r.UpdateState(resolver.State{ - Addresses: []resolver.Address{{ + rs = grpclbstate.Set(resolver.State{ServiceConfig: r.CC.ParseServiceConfig(grpclbConfig)}, + &grpclbstate.State{BalancerAddresses: []resolver.Address{{ Addr: tss.lbAddr, - Type: resolver.GRPCLB, ServerName: lbServerName, - }}, - ServiceConfig: grpclbServiceConfigEmpty, - }) + }}}) + r.UpdateState(rs) result = "" for i := 0; i < 1000; i++ { @@ -1232,6 +1219,142 @@ func (s) TestGRPCLBPickFirst(t *testing.T) { } } +func (s) TestGRPCLBBackendConnectionErrorPropagation(t *testing.T) { + r := manual.NewBuilderWithScheme("whatever") + + // Start up an LB which will tell the client to fall back right away.
+ tss, cleanup, err := startBackendsAndRemoteLoadBalancer(0, "", nil) + if err != nil { + t.Fatalf("failed to create new load balancer: %v", err) + } + defer cleanup() + + // Start a standalone backend, to be used during fallback. The creds + // are intentionally misconfigured in order to simulate failure of a + // security handshake. + beLis, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatalf("Failed to listen %v", err) + } + defer beLis.Close() + standaloneBEs := startBackends("arbitrary.invalid.name", true, beLis) + defer stopBackends(standaloneBEs) + + cc, err := grpc.Dial(r.Scheme()+":///"+beServerName, + grpc.WithResolvers(r), + grpc.WithTransportCredentials(&serverNameCheckCreds{}), + grpc.WithContextDialer(fakeNameDialer)) + if err != nil { + t.Fatalf("Failed to dial to the backend %v", err) + } + defer cc.Close() + testC := testpb.NewTestServiceClient(cc) + + rs := resolver.State{ + Addresses: []resolver.Address{{Addr: beLis.Addr().String()}}, + ServiceConfig: r.CC.ParseServiceConfig(grpclbConfig), + } + rs = grpclbstate.Set(rs, &grpclbstate.State{BalancerAddresses: []resolver.Address{{Addr: tss.lbAddr, ServerName: lbServerName}}}) + r.UpdateState(rs) + + // If https://github.com/grpc/grpc-go/blob/65cabd74d8e18d7347fecd414fa8d83a00035f5f/balancer/grpclb/grpclb_test.go#L103 + // changes, then expectedErrMsg may need to be updated. 
+ const expectedErrMsg = "received unexpected server name" + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + var wg sync.WaitGroup + wg.Add(1) + go func() { + tss.ls.fallbackNow() + wg.Done() + }() + if _, err := testC.EmptyCall(ctx, &testpb.Empty{}); err == nil || !strings.Contains(err.Error(), expectedErrMsg) { + t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, rpc error containing substring: %q", testC, err, expectedErrMsg) + } + wg.Wait() +} + +func (s) TestGRPCLBWithTargetNameFieldInConfig(t *testing.T) { + r := manual.NewBuilderWithScheme("whatever") + + // Start a remote balancer and a backend. Push the backend address to the + // remote balancer. + tss, cleanup, err := startBackendsAndRemoteLoadBalancer(1, "", nil) + if err != nil { + t.Fatalf("failed to create new load balancer: %v", err) + } + defer cleanup() + sl := &lbpb.ServerList{ + Servers: []*lbpb.Server{ + { + IpAddress: tss.beIPs[0], + Port: int32(tss.bePorts[0]), + LoadBalanceToken: lbToken, + }, + }, + } + tss.ls.sls <- sl + + cc, err := grpc.Dial(r.Scheme()+":///"+beServerName, + grpc.WithResolvers(r), + grpc.WithTransportCredentials(&serverNameCheckCreds{}), + grpc.WithContextDialer(fakeNameDialer), + grpc.WithUserAgent(testUserAgent)) + if err != nil { + t.Fatalf("Failed to dial to the backend %v", err) + } + defer cc.Close() + testC := testpb.NewTestServiceClient(cc) + + // Push a resolver update with grpclb configuration which does not contain the + // target_name field. Our fake remote balancer is configured to always + // expect `beServerName` as the server name in the initial request. 
+ rs := grpclbstate.Set(resolver.State{ServiceConfig: r.CC.ParseServiceConfig(grpclbConfig)}, + &grpclbstate.State{BalancerAddresses: []resolver.Address{{ + Addr: tss.lbAddr, + ServerName: lbServerName, + }}}) + r.UpdateState(rs) + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + select { + case <-ctx.Done(): + t.Fatalf("timeout when waiting for BalanceLoad RPC to be called on the remote balancer") + case <-tss.ls.balanceLoadCh: + } + if _, err := testC.EmptyCall(ctx, &testpb.Empty{}); err != nil { + t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, ", testC, err) + } + + // When the value of target_field changes, grpclb will recreate the stream + // to the remote balancer. So, we need to update the fake remote balancer to + // expect a new server name in the initial request. + const newServerName = "new-server-name" + tss.ls.updateServerName(newServerName) + tss.ls.sls <- sl + + // Push a resolver update with a grpclb configuration containing the + // target_name field set to newServerName. Our fake remote balancer has + // been updated above to expect newServerName in the initial request.
+ lbCfg := fmt.Sprintf(`{"loadBalancingConfig": [{"grpclb": {"targetName": "%s"}}]}`, newServerName) + rs = grpclbstate.Set(resolver.State{ServiceConfig: r.CC.ParseServiceConfig(lbCfg)}, + &grpclbstate.State{BalancerAddresses: []resolver.Address{{ + Addr: tss.lbAddr, + ServerName: lbServerName, + }}}) + r.UpdateState(rs) + select { + case <-ctx.Done(): + t.Fatalf("timeout when waiting for BalanceLoad RPC to be called on the remote balancer") + case <-tss.ls.balanceLoadCh: + } + + if _, err := testC.EmptyCall(ctx, &testpb.Empty{}); err != nil { + t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, ", testC, err) + } +} + type failPreRPCCred struct{} func (failPreRPCCred) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) { @@ -1255,7 +1378,7 @@ func checkStats(stats, expected *rpcStats) error { func runAndCheckStats(t *testing.T, drop bool, statsChan chan *lbpb.ClientStats, runRPCs func(*grpc.ClientConn), statsWant *rpcStats) error { r := manual.NewBuilderWithScheme("whatever") - tss, cleanup, err := newLoadBalancer(1, "", statsChan) + tss, cleanup, err := startBackendsAndRemoteLoadBalancer(1, "", statsChan) if err != nil { t.Fatalf("failed to create new load balancer: %v", err) } diff --git a/balancer/grpclb/grpclb_test_util_test.go b/balancer/grpclb/grpclb_test_util_test.go index 5d3e6ba7fed..c143e961754 100644 --- a/balancer/grpclb/grpclb_test_util_test.go +++ b/balancer/grpclb/grpclb_test_util_test.go @@ -48,19 +48,20 @@ func newRestartableListener(l net.Listener) *restartableListener { } } -func (l *restartableListener) Accept() (conn net.Conn, err error) { - conn, err = l.Listener.Accept() - if err == nil { - l.mu.Lock() - if l.closed { - conn.Close() - l.mu.Unlock() - return nil, &tempError{} - } - l.conns = append(l.conns, conn) - l.mu.Unlock() +func (l *restartableListener) Accept() (net.Conn, error) { + conn, err := l.Listener.Accept() + if err != nil { + return nil, err } - return + + l.mu.Lock() + defer l.mu.Unlock() + if 
l.closed { + conn.Close() + return nil, &tempError{} + } + l.conns = append(l.conns, conn) + return conn, nil } func (l *restartableListener) Close() error { diff --git a/balancer/rls/internal/balancer.go b/balancer/rls/internal/balancer.go index 7af97b76faf..b23783bf9da 100644 --- a/balancer/rls/internal/balancer.go +++ b/balancer/rls/internal/balancer.go @@ -129,6 +129,10 @@ func (lb *rlsBalancer) Close() { } } +func (lb *rlsBalancer) ExitIdle() { + // TODO: are we 100% sure this should be a nop? +} + // updateControlChannel updates the RLS client if required. // Caller must hold lb.mu. func (lb *rlsBalancer) updateControlChannel(newCfg *lbConfig) { diff --git a/balancer/rls/internal/config.go b/balancer/rls/internal/config.go index a3deb8906c9..b27a2970f08 100644 --- a/balancer/rls/internal/config.go +++ b/balancer/rls/internal/config.go @@ -22,15 +22,16 @@ import ( "bytes" "encoding/json" "fmt" + "net/url" "time" "github.com/golang/protobuf/jsonpb" "github.com/golang/protobuf/ptypes" durationpb "github.com/golang/protobuf/ptypes/duration" + "google.golang.org/grpc/balancer" "google.golang.org/grpc/balancer/rls/internal/keys" rlspb "google.golang.org/grpc/balancer/rls/internal/proto/grpc_lookup_v1" - "google.golang.org/grpc/internal/grpcutil" "google.golang.org/grpc/resolver" "google.golang.org/grpc/serviceconfig" ) @@ -62,9 +63,10 @@ type lbConfig struct { staleAge time.Duration cacheSizeBytes int64 defaultTarget string - cpName string - cpTargetField string - cpConfig map[string]json.RawMessage + + childPolicyName string + childPolicyConfig map[string]json.RawMessage + childPolicyTargetField string } func (lbCfg *lbConfig) Equal(other *lbConfig) bool { @@ -75,21 +77,21 @@ func (lbCfg *lbConfig) Equal(other *lbConfig) bool { lbCfg.staleAge == other.staleAge && lbCfg.cacheSizeBytes == other.cacheSizeBytes && lbCfg.defaultTarget == other.defaultTarget && - lbCfg.cpName == other.cpName && - lbCfg.cpTargetField == other.cpTargetField && - 
cpConfigEqual(lbCfg.cpConfig, other.cpConfig) + lbCfg.childPolicyName == other.childPolicyName && + lbCfg.childPolicyTargetField == other.childPolicyTargetField && + childPolicyConfigEqual(lbCfg.childPolicyConfig, other.childPolicyConfig) } -func cpConfigEqual(am, bm map[string]json.RawMessage) bool { - if (bm == nil) != (am == nil) { +func childPolicyConfigEqual(a, b map[string]json.RawMessage) bool { + if (b == nil) != (a == nil) { return false } - if len(bm) != len(am) { + if len(b) != len(a) { return false } - for k, jsonA := range am { - jsonB, ok := bm[k] + for k, jsonA := range a { + jsonB, ok := b[k] if !ok { return false } @@ -100,71 +102,18 @@ func cpConfigEqual(am, bm map[string]json.RawMessage) bool { return true } -// This struct resembles the JSON respresentation of the loadBalancing config +// This struct resembles the JSON representation of the loadBalancing config // and makes it easier to unmarshal. type lbConfigJSON struct { RouteLookupConfig json.RawMessage - ChildPolicy []*loadBalancingConfig + ChildPolicy []map[string]json.RawMessage ChildPolicyConfigTargetFieldName string } -// loadBalancingConfig represents a single load balancing config, -// stored in JSON format. -// -// TODO(easwars): This code seems to be repeated in a few places -// (service_config.go and in the xds code as well). Refactor and re-use. -type loadBalancingConfig struct { - Name string - Config json.RawMessage -} - -// MarshalJSON returns a JSON encoding of l. -func (l *loadBalancingConfig) MarshalJSON() ([]byte, error) { - return nil, fmt.Errorf("rls: loadBalancingConfig.MarshalJSON() is unimplemented") -} - -// UnmarshalJSON parses the JSON-encoded byte slice in data and stores it in l. 
-func (l *loadBalancingConfig) UnmarshalJSON(data []byte) error { - var cfg map[string]json.RawMessage - if err := json.Unmarshal(data, &cfg); err != nil { - return err - } - for name, config := range cfg { - l.Name = name - l.Config = config - } - return nil -} - // ParseConfig parses and validates the JSON representation of the service // config and returns the loadBalancingConfig to be used by the RLS LB policy. // // Helps implement the balancer.ConfigParser interface. -// -// The following validation checks are performed: -// * routeLookupConfig: -// ** grpc_keybuilders field: -// - must have at least one entry -// - must not have two entries with the same Name -// - must not have any entry with a Name with the service field unset or -// empty -// - must not have any entries without a Name -// - must not have a headers entry that has required_match set -// - must not have two headers entries with the same key within one entry -// ** lookup_service field: -// - must be set and non-empty and must parse as a target URI -// ** max_age field: -// - if not specified or is greater than maxMaxAge, it will be reset to -// maxMaxAge -// ** stale_age field: -// - if the value is greater than or equal to max_age, it is ignored -// - if set, then max_age must also be set -// ** valid_targets field: -// - will be ignored -// ** cache_size_bytes field: -// - must be greater than zero -// - TODO(easwars): Define a minimum value for this field, to be used when -// left unspecified // * childPolicy field: // - must find a valid child policy with a valid config (the child policy must // be able to parse the provided config successfully when we pass it a dummy @@ -178,20 +127,58 @@ func (*rlsBB) ParseConfig(c json.RawMessage) (serviceconfig.LoadBalancingConfig, return nil, fmt.Errorf("rls: json unmarshal failed for service config {%+v}: %v", string(c), err) } + // Unmarshal and validate contents of the RLS proto. 
m := jsonpb.Unmarshaler{AllowUnknownFields: true} rlsProto := &rlspb.RouteLookupConfig{} if err := m.Unmarshal(bytes.NewReader(cfgJSON.RouteLookupConfig), rlsProto); err != nil { return nil, fmt.Errorf("rls: bad RouteLookupConfig proto {%+v}: %v", string(cfgJSON.RouteLookupConfig), err) } + lbCfg, err := parseRLSProto(rlsProto) + if err != nil { + return nil, err + } - var childPolicy *loadBalancingConfig - for _, lbcfg := range cfgJSON.ChildPolicy { - if balancer.Get(lbcfg.Name) != nil { - childPolicy = lbcfg - break - } + // Unmarshal and validate child policy configs. + if cfgJSON.ChildPolicyConfigTargetFieldName == "" { + return nil, fmt.Errorf("rls: childPolicyConfigTargetFieldName field is not set in service config {%+v}", string(c)) } + name, config, err := parseChildPolicyConfigs(cfgJSON.ChildPolicy, cfgJSON.ChildPolicyConfigTargetFieldName) + if err != nil { + return nil, err + } + lbCfg.childPolicyName = name + lbCfg.childPolicyConfig = config + lbCfg.childPolicyTargetField = cfgJSON.ChildPolicyConfigTargetFieldName + return lbCfg, nil +} +// parseRLSProto fetches relevant information from the RouteLookupConfig proto +// and validates the values in the process. 
+// +// The following validation checks are performed: +// ** grpc_keybuilders field: +// - must have at least one entry +// - must not have two entries with the same Name +// - must not have any entry with a Name with the service field unset or +// empty +// - must not have any entries without a Name +// - must not have a headers entry that has required_match set +// - must not have two headers entries with the same key within one entry +// ** lookup_service field: +// - must be set and non-empty and must parse as a target URI +// ** max_age field: +// - if not specified or is greater than maxMaxAge, it will be reset to +// maxMaxAge +// ** stale_age field: +// - if the value is greater than or equal to max_age, it is ignored +// - if set, then max_age must also be set +// ** valid_targets field: +// - will be ignored +// ** cache_size_bytes field: +// - must be greater than zero +// - TODO(easwars): Define a minimum value for this field, to be used when +// left unspecified +func parseRLSProto(rlsProto *rlspb.RouteLookupConfig) (*lbConfig, error) { kbMap, err := keys.MakeBuilderMap(rlsProto) if err != nil { return nil, err @@ -199,64 +186,54 @@ func (*rlsBB) ParseConfig(c json.RawMessage) (serviceconfig.LoadBalancingConfig, lookupService := rlsProto.GetLookupService() if lookupService == "" { - return nil, fmt.Errorf("rls: empty lookup_service in service config {%+v}", string(c)) + return nil, fmt.Errorf("rls: empty lookup_service in route lookup config {%+v}", rlsProto) + } + parsedTarget, err := url.Parse(lookupService) + if err != nil { + // If the first attempt failed because of a missing scheme, try again + // with the default scheme. 
+ parsedTarget, err = url.Parse(resolver.GetDefaultScheme() + ":///" + lookupService) + if err != nil { + return nil, fmt.Errorf("rls: invalid target URI in lookup_service {%s}", lookupService) + } } - parsedTarget := grpcutil.ParseTarget(lookupService, false) if parsedTarget.Scheme == "" { parsedTarget.Scheme = resolver.GetDefaultScheme() } if resolver.Get(parsedTarget.Scheme) == nil { - return nil, fmt.Errorf("rls: invalid target URI in lookup_service {%s}", lookupService) + return nil, fmt.Errorf("rls: unregistered scheme in lookup_service {%s}", lookupService) } lookupServiceTimeout, err := convertDuration(rlsProto.GetLookupServiceTimeout()) if err != nil { - return nil, fmt.Errorf("rls: failed to parse lookup_service_timeout in service config {%+v}: %v", string(c), err) + return nil, fmt.Errorf("rls: failed to parse lookup_service_timeout in route lookup config {%+v}: %v", rlsProto, err) } if lookupServiceTimeout == 0 { lookupServiceTimeout = defaultLookupServiceTimeout } maxAge, err := convertDuration(rlsProto.GetMaxAge()) if err != nil { - return nil, fmt.Errorf("rls: failed to parse max_age in service config {%+v}: %v", string(c), err) + return nil, fmt.Errorf("rls: failed to parse max_age in route lookup config {%+v}: %v", rlsProto, err) } staleAge, err := convertDuration(rlsProto.GetStaleAge()) if err != nil { - return nil, fmt.Errorf("rls: failed to parse staleAge in service config {%+v}: %v", string(c), err) + return nil, fmt.Errorf("rls: failed to parse staleAge in route lookup config {%+v}: %v", rlsProto, err) } if staleAge != 0 && maxAge == 0 { - return nil, fmt.Errorf("rls: stale_age is set, but max_age is not in service config {%+v}", string(c)) + return nil, fmt.Errorf("rls: stale_age is set, but max_age is not in route lookup config {%+v}", rlsProto) } if staleAge >= maxAge { logger.Info("rls: stale_age {%v} is greater than max_age {%v}, ignoring it", staleAge, maxAge) staleAge = 0 } if maxAge == 0 || maxAge > maxMaxAge { - logger.Infof("rls: 
max_age in service config is %v, using %v", maxAge, maxMaxAge) + logger.Infof("rls: max_age in route lookup config is %v, using %v", maxAge, maxMaxAge) maxAge = maxMaxAge } cacheSizeBytes := rlsProto.GetCacheSizeBytes() if cacheSizeBytes <= 0 { - return nil, fmt.Errorf("rls: cache_size_bytes must be greater than 0 in service config {%+v}", string(c)) + return nil, fmt.Errorf("rls: cache_size_bytes must be greater than 0 in route lookup config {%+v}", rlsProto) } - if childPolicy == nil { - return nil, fmt.Errorf("rls: childPolicy is invalid in service config {%+v}", string(c)) - } - if cfgJSON.ChildPolicyConfigTargetFieldName == "" { - return nil, fmt.Errorf("rls: childPolicyConfigTargetFieldName field is not set in service config {%+v}", string(c)) - } - // TODO(easwars): When we start instantiating the child policy from the - // parent RLS LB policy, we could make this function a method on the - // lbConfig object and share the code. We would be parsing the child policy - // config again during that time. The only difference betweeen now and then - // would be that we would be using real targetField name instead of the - // dummy. So, we could make the targetName field a parameter to this - // function during the refactor. - cpCfg, err := validateChildPolicyConfig(childPolicy, cfgJSON.ChildPolicyConfigTargetFieldName) - if err != nil { - return nil, err - } - return &lbConfig{ kbMap: kbMap, lookupService: lookupService, @@ -265,57 +242,50 @@ func (*rlsBB) ParseConfig(c json.RawMessage) (serviceconfig.LoadBalancingConfig, staleAge: staleAge, cacheSizeBytes: cacheSizeBytes, defaultTarget: rlsProto.GetDefaultTarget(), - // TODO(easwars): Once we refactor validateChildPolicyConfig and make - // it a method on the lbConfig object, we could directly store the - // balancer.Builder and/or balancer.ConfigParser here instead of the - // Name. 
That would mean that we would have to create the lbConfig - // object here first before validating the childPolicy config, but - // that's a minor detail. - cpName: childPolicy.Name, - cpTargetField: cfgJSON.ChildPolicyConfigTargetFieldName, - cpConfig: cpCfg, }, nil } -// validateChildPolicyConfig validates the child policy config received in the -// service config. This makes it possible for us to reject service configs -// which contain invalid child policy configs which we know will fail for sure. -// -// It does the following: -// * Unmarshals the provided child policy config into a map of string to -// json.RawMessage. This allows us to add an entry to the map corresponding -// to the targetFieldName that we received in the service config. -// * Marshals the map back into JSON, finds the config parser associated with -// the child policy and asks it to validate the config. -// * If the validation succeeded, removes the dummy entry from the map and -// returns it. If any of the above steps failed, it returns an error. -func validateChildPolicyConfig(cp *loadBalancingConfig, cpTargetField string) (map[string]json.RawMessage, error) { - var childConfig map[string]json.RawMessage - if err := json.Unmarshal(cp.Config, &childConfig); err != nil { - return nil, fmt.Errorf("rls: json unmarshal failed for child policy config {%+v}: %v", cp.Config, err) - } - childConfig[cpTargetField], _ = json.Marshal(dummyChildPolicyTarget) +// parseChildPolicyConfigs iterates through the list of child policies and picks +// the first registered policy and validates its config. 
+func parseChildPolicyConfigs(childPolicies []map[string]json.RawMessage, targetFieldName string) (string, map[string]json.RawMessage, error) { + for i, config := range childPolicies { + if len(config) != 1 { + return "", nil, fmt.Errorf("rls: invalid childPolicy: entry %v does not contain exactly 1 policy/config pair: %q", i, config) + } - jsonCfg, err := json.Marshal(childConfig) - if err != nil { - return nil, fmt.Errorf("rls: json marshal failed for child policy config {%+v}: %v", childConfig, err) - } - builder := balancer.Get(cp.Name) - if builder == nil { - // This should never happen since we already made sure that the child - // policy name mentioned in the service config is a valid one. - return nil, fmt.Errorf("rls: balancer builder not found for child_policy %v", cp.Name) - } - parser, ok := builder.(balancer.ConfigParser) - if !ok { - return nil, fmt.Errorf("rls: balancer builder for child_policy does not implement balancer.ConfigParser: %v", cp.Name) - } - _, err = parser.ParseConfig(jsonCfg) - if err != nil { - return nil, fmt.Errorf("rls: childPolicy config validation failed: %v", err) + var name string + var rawCfg json.RawMessage + for name, rawCfg = range config { + } + builder := balancer.Get(name) + if builder == nil { + continue + } + parser, ok := builder.(balancer.ConfigParser) + if !ok { + return "", nil, fmt.Errorf("rls: childPolicy %q with config %q does not support config parsing", name, string(rawCfg)) + } + + // To validate child policy configs we do the following: + // - unmarshal the raw JSON bytes of the child policy config into a map + // - add an entry with key set to `target_field_name` and a dummy value + // - marshal the map back to JSON and parse the config using the parser + // retrieved previously + var childConfig map[string]json.RawMessage + if err := json.Unmarshal(rawCfg, &childConfig); err != nil { + return "", nil, fmt.Errorf("rls: json unmarshal failed for child policy config %q: %v", string(rawCfg), err) + } + 
childConfig[targetFieldName], _ = json.Marshal(dummyChildPolicyTarget) + jsonCfg, err := json.Marshal(childConfig) + if err != nil { + return "", nil, fmt.Errorf("rls: json marshal failed for child policy config {%+v}: %v", childConfig, err) + } + if _, err := parser.ParseConfig(jsonCfg); err != nil { + return "", nil, fmt.Errorf("rls: childPolicy config validation failed: %v", err) + } + return name, childConfig, nil } - delete(childConfig, cpTargetField) - return childConfig, nil + return "", nil, fmt.Errorf("rls: invalid childPolicy config: no supported policies found in %+v", childPolicies) } func convertDuration(d *durationpb.Duration) (time.Duration, error) { diff --git a/balancer/rls/internal/config_test.go b/balancer/rls/internal/config_test.go index 1efd054512b..41d330c604e 100644 --- a/balancer/rls/internal/config_test.go +++ b/balancer/rls/internal/config_test.go @@ -25,8 +25,6 @@ import ( "testing" "time" - "github.com/google/go-cmp/cmp" - "google.golang.org/grpc/balancer" _ "google.golang.org/grpc/balancer/grpclb" // grpclb for config parsing. _ "google.golang.org/grpc/internal/resolver/passthrough" // passthrough resolver. 
@@ -58,12 +56,13 @@ func testEqual(a, b *lbConfig) bool { a.staleAge == b.staleAge && a.cacheSizeBytes == b.cacheSizeBytes && a.defaultTarget == b.defaultTarget && - a.cpName == b.cpName && - a.cpTargetField == b.cpTargetField && - cmp.Equal(a.cpConfig, b.cpConfig) + a.childPolicyName == b.childPolicyName && + a.childPolicyTargetField == b.childPolicyTargetField && + childPolicyConfigEqual(a.childPolicyConfig, b.childPolicyConfig) } func TestParseConfig(t *testing.T) { + childPolicyTargetFieldVal, _ := json.Marshal(dummyChildPolicyTarget) tests := []struct { desc string input []byte @@ -85,7 +84,7 @@ func TestParseConfig(t *testing.T) { "names": [{"service": "service", "method": "method"}], "headers": [{"key": "k1", "names": ["v1"]}] }], - "lookupService": "passthrough:///target", + "lookupService": ":///target", "maxAge" : "500s", "staleAge": "600s", "cacheSizeBytes": 1000, @@ -99,15 +98,18 @@ func TestParseConfig(t *testing.T) { "childPolicyConfigTargetFieldName": "service_name" }`), wantCfg: &lbConfig{ - lookupService: "passthrough:///target", - lookupServiceTimeout: 10 * time.Second, // This is the default value. - maxAge: 5 * time.Minute, // This is max maxAge. - staleAge: time.Duration(0), // StaleAge is ignore because it was higher than maxAge. - cacheSizeBytes: 1000, - defaultTarget: "passthrough:///default", - cpName: "grpclb", - cpTargetField: "service_name", - cpConfig: map[string]json.RawMessage{"childPolicy": json.RawMessage(`[{"pickfirst": {}}]`)}, + lookupService: ":///target", + lookupServiceTimeout: 10 * time.Second, // This is the default value. + maxAge: 5 * time.Minute, // This is max maxAge. + staleAge: time.Duration(0), // StaleAge is ignore because it was higher than maxAge. 
+ cacheSizeBytes: 1000, + defaultTarget: "passthrough:///default", + childPolicyName: "grpclb", + childPolicyTargetField: "service_name", + childPolicyConfig: map[string]json.RawMessage{ + "childPolicy": json.RawMessage(`[{"pickfirst": {}}]`), + "service_name": json.RawMessage(childPolicyTargetFieldVal), + }, }, }, { @@ -118,7 +120,7 @@ func TestParseConfig(t *testing.T) { "names": [{"service": "service", "method": "method"}], "headers": [{"key": "k1", "names": ["v1"]}] }], - "lookupService": "passthrough:///target", + "lookupService": "target", "lookupServiceTimeout" : "100s", "maxAge": "60s", "staleAge" : "50s", @@ -129,15 +131,18 @@ func TestParseConfig(t *testing.T) { "childPolicyConfigTargetFieldName": "service_name" }`), wantCfg: &lbConfig{ - lookupService: "passthrough:///target", - lookupServiceTimeout: 100 * time.Second, - maxAge: 60 * time.Second, - staleAge: 50 * time.Second, - cacheSizeBytes: 1000, - defaultTarget: "passthrough:///default", - cpName: "grpclb", - cpTargetField: "service_name", - cpConfig: map[string]json.RawMessage{"childPolicy": json.RawMessage(`[{"pickfirst": {}}]`)}, + lookupService: "target", + lookupServiceTimeout: 100 * time.Second, + maxAge: 60 * time.Second, + staleAge: 50 * time.Second, + cacheSizeBytes: 1000, + defaultTarget: "passthrough:///default", + childPolicyName: "grpclb", + childPolicyTargetField: "service_name", + childPolicyConfig: map[string]json.RawMessage{ + "childPolicy": json.RawMessage(`[{"pickfirst": {}}]`), + "service_name": json.RawMessage(childPolicyTargetFieldVal), + }, }, }, } @@ -191,10 +196,10 @@ func TestParseConfigErrors(t *testing.T) { }] } }`), - wantErr: "rls: empty lookup_service in service config", + wantErr: "rls: empty lookup_service in route lookup config", }, { - desc: "invalid lookup service URI", + desc: "unregistered scheme in lookup service URI", input: []byte(`{ "routeLookupConfig": { "grpcKeybuilders": [{ @@ -204,7 +209,7 @@ func TestParseConfigErrors(t *testing.T) { "lookupService": 
"badScheme:///target" } }`), - wantErr: "rls: invalid target URI in lookup_service", + wantErr: "rls: unregistered scheme in lookup_service", }, { desc: "invalid lookup service timeout", @@ -264,7 +269,7 @@ func TestParseConfigErrors(t *testing.T) { "staleAge" : "10s" } }`), - wantErr: "rls: stale_age is set, but max_age is not in service config", + wantErr: "rls: stale_age is set, but max_age is not in route lookup config", }, { desc: "invalid cache size", @@ -280,7 +285,7 @@ func TestParseConfigErrors(t *testing.T) { "staleAge" : "25s" } }`), - wantErr: "rls: cache_size_bytes must be greater than 0 in service config", + wantErr: "rls: cache_size_bytes must be greater than 0 in route lookup config", }, { desc: "no child policy", @@ -296,9 +301,10 @@ func TestParseConfigErrors(t *testing.T) { "staleAge" : "25s", "cacheSizeBytes": 1000, "defaultTarget": "passthrough:///default" - } + }, + "childPolicyConfigTargetFieldName": "service_name" }`), - wantErr: "rls: childPolicy is invalid in service config", + wantErr: "rls: invalid childPolicy config: no supported policies found", }, { desc: "no known child policy", @@ -318,9 +324,35 @@ func TestParseConfigErrors(t *testing.T) { "childPolicy": [ {"cds_experimental": {"Cluster": "my-fav-cluster"}}, {"unknown-policy": {"unknown-field": "unknown-value"}} - ] + ], + "childPolicyConfigTargetFieldName": "service_name" + }`), + wantErr: "rls: invalid childPolicy config: no supported policies found", + }, + { + desc: "invalid child policy config - more than one entry in map", + input: []byte(`{ + "routeLookupConfig": { + "grpcKeybuilders": [{ + "names": [{"service": "service", "method": "method"}], + "headers": [{"key": "k1", "names": ["v1"]}] + }], + "lookupService": "passthrough:///target", + "lookupServiceTimeout" : "10s", + "maxAge": "30s", + "staleAge" : "25s", + "cacheSizeBytes": 1000, + "defaultTarget": "passthrough:///default" + }, + "childPolicy": [ + { + "cds_experimental": {"Cluster": "my-fav-cluster"}, + 
"unknown-policy": {"unknown-field": "unknown-value"} + } + ], + "childPolicyConfigTargetFieldName": "service_name" }`), - wantErr: "rls: childPolicy is invalid in service config", + wantErr: "does not contain exactly 1 policy/config pair", }, { desc: "no childPolicyConfigTargetFieldName", @@ -381,60 +413,3 @@ func TestParseConfigErrors(t *testing.T) { }) } } - -func TestValidateChildPolicyConfig(t *testing.T) { - jsonCfg := json.RawMessage(`[{"round_robin" : {}}, {"pick_first" : {}}]`) - wantChildConfig := map[string]json.RawMessage{"childPolicy": jsonCfg} - cp := &loadBalancingConfig{ - Name: "grpclb", - Config: []byte(`{"childPolicy": [{"round_robin" : {}}, {"pick_first" : {}}]}`), - } - cpTargetField := "serviceName" - - gotChildConfig, err := validateChildPolicyConfig(cp, cpTargetField) - if err != nil || !cmp.Equal(gotChildConfig, wantChildConfig) { - t.Errorf("validateChildPolicyConfig(%v, %v) = {%v, %v}, want {%v, nil}", cp, cpTargetField, gotChildConfig, err, wantChildConfig) - } -} - -func TestValidateChildPolicyConfigErrors(t *testing.T) { - tests := []struct { - desc string - cp *loadBalancingConfig - wantErrPrefix string - }{ - { - desc: "unknown child policy", - cp: &loadBalancingConfig{ - Name: "unknown", - Config: []byte(`{}`), - }, - wantErrPrefix: "rls: balancer builder not found for child_policy", - }, - { - desc: "balancer builder does not implement ConfigParser", - cp: &loadBalancingConfig{ - Name: balancerWithoutConfigParserName, - Config: []byte(`{}`), - }, - wantErrPrefix: "rls: balancer builder for child_policy does not implement balancer.ConfigParser", - }, - { - desc: "child policy config parsing failure", - cp: &loadBalancingConfig{ - Name: "grpclb", - Config: []byte(`{"childPolicy": "not-an-array"}`), - }, - wantErrPrefix: "rls: childPolicy config validation failed", - }, - } - - for _, test := range tests { - t.Run(test.desc, func(t *testing.T) { - gotChildConfig, gotErr := validateChildPolicyConfig(test.cp, "") - if gotChildConfig != 
nil || !strings.HasPrefix(fmt.Sprint(gotErr), test.wantErrPrefix) { - t.Errorf("validateChildPolicyConfig(%v) = {%v, %v}, want {nil, %v}", test.cp, gotChildConfig, gotErr, test.wantErrPrefix) - } - }) - } -} diff --git a/balancer/rls/internal/keys/builder.go b/balancer/rls/internal/keys/builder.go index 5ce5a9da508..24767b405f0 100644 --- a/balancer/rls/internal/keys/builder.go +++ b/balancer/rls/internal/keys/builder.go @@ -218,7 +218,7 @@ func (b builder) keys(md metadata.MD) KeyMap { } func mapToString(kv map[string]string) string { - var keys []string + keys := make([]string, 0, len(kv)) for k := range kv { keys = append(keys, k) } diff --git a/balancer/rls/internal/proto/grpc_lookup_v1/rls.pb.go b/balancer/rls/internal/proto/grpc_lookup_v1/rls.pb.go index d48a1a6de84..9f063fd3d62 100644 --- a/balancer/rls/internal/proto/grpc_lookup_v1/rls.pb.go +++ b/balancer/rls/internal/proto/grpc_lookup_v1/rls.pb.go @@ -39,6 +39,56 @@ const ( // of the legacy proto package is being used. const _ = proto.ProtoPackageIsVersion4 +// Possible reasons for making a request. +type RouteLookupRequest_Reason int32 + +const ( + RouteLookupRequest_REASON_UNKNOWN RouteLookupRequest_Reason = 0 // Unused + RouteLookupRequest_REASON_MISS RouteLookupRequest_Reason = 1 // No data available in local cache + RouteLookupRequest_REASON_STALE RouteLookupRequest_Reason = 2 // Data in local cache is stale +) + +// Enum value maps for RouteLookupRequest_Reason. 
+var ( + RouteLookupRequest_Reason_name = map[int32]string{ + 0: "REASON_UNKNOWN", + 1: "REASON_MISS", + 2: "REASON_STALE", + } + RouteLookupRequest_Reason_value = map[string]int32{ + "REASON_UNKNOWN": 0, + "REASON_MISS": 1, + "REASON_STALE": 2, + } +) + +func (x RouteLookupRequest_Reason) Enum() *RouteLookupRequest_Reason { + p := new(RouteLookupRequest_Reason) + *p = x + return p +} + +func (x RouteLookupRequest_Reason) String() string { + return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x)) +} + +func (RouteLookupRequest_Reason) Descriptor() protoreflect.EnumDescriptor { + return file_grpc_lookup_v1_rls_proto_enumTypes[0].Descriptor() +} + +func (RouteLookupRequest_Reason) Type() protoreflect.EnumType { + return &file_grpc_lookup_v1_rls_proto_enumTypes[0] +} + +func (x RouteLookupRequest_Reason) Number() protoreflect.EnumNumber { + return protoreflect.EnumNumber(x) +} + +// Deprecated: Use RouteLookupRequest_Reason.Descriptor instead. +func (RouteLookupRequest_Reason) EnumDescriptor() ([]byte, []int) { + return file_grpc_lookup_v1_rls_proto_rawDescGZIP(), []int{0, 0} +} + type RouteLookupRequest struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache @@ -46,13 +96,23 @@ type RouteLookupRequest struct { // Full host name of the target server, e.g. firestore.googleapis.com. // Only set for gRPC requests; HTTP requests must use key_map explicitly. + // Deprecated in favor of setting key_map keys with GrpcKeyBuilder.extra_keys. + // + // Deprecated: Do not use. Server string `protobuf:"bytes,1,opt,name=server,proto3" json:"server,omitempty"` // Full path of the request, i.e. "/service/method". // Only set for gRPC requests; HTTP requests must use key_map explicitly. + // Deprecated in favor of setting key_map keys with GrpcKeyBuilder.extra_keys. + // + // Deprecated: Do not use. 
Path string `protobuf:"bytes,2,opt,name=path,proto3" json:"path,omitempty"` // Target type allows the client to specify what kind of target format it // would like from RLS to allow it to find the regional server, e.g. "grpc". TargetType string `protobuf:"bytes,3,opt,name=target_type,json=targetType,proto3" json:"target_type,omitempty"` + // Reason for making this request. + Reason RouteLookupRequest_Reason `protobuf:"varint,5,opt,name=reason,proto3,enum=grpc.lookup.v1.RouteLookupRequest_Reason" json:"reason,omitempty"` + // For REASON_STALE, the header_data from the stale response, if any. + StaleHeaderData string `protobuf:"bytes,6,opt,name=stale_header_data,json=staleHeaderData,proto3" json:"stale_header_data,omitempty"` // Map of key values extracted via key builders for the gRPC or HTTP request. KeyMap map[string]string `protobuf:"bytes,4,rep,name=key_map,json=keyMap,proto3" json:"key_map,omitempty" protobuf_key:"bytes,1,opt,name=key,proto3" protobuf_val:"bytes,2,opt,name=value,proto3"` } @@ -89,6 +149,7 @@ func (*RouteLookupRequest) Descriptor() ([]byte, []int) { return file_grpc_lookup_v1_rls_proto_rawDescGZIP(), []int{0} } +// Deprecated: Do not use. func (x *RouteLookupRequest) GetServer() string { if x != nil { return x.Server @@ -96,6 +157,7 @@ func (x *RouteLookupRequest) GetServer() string { return "" } +// Deprecated: Do not use. 
func (x *RouteLookupRequest) GetPath() string { if x != nil { return x.Path @@ -110,6 +172,20 @@ func (x *RouteLookupRequest) GetTargetType() string { return "" } +func (x *RouteLookupRequest) GetReason() RouteLookupRequest_Reason { + if x != nil { + return x.Reason + } + return RouteLookupRequest_REASON_UNKNOWN +} + +func (x *RouteLookupRequest) GetStaleHeaderData() string { + if x != nil { + return x.StaleHeaderData + } + return "" +} + func (x *RouteLookupRequest) GetKeyMap() map[string]string { if x != nil { return x.KeyMap @@ -183,40 +259,52 @@ var File_grpc_lookup_v1_rls_proto protoreflect.FileDescriptor var file_grpc_lookup_v1_rls_proto_rawDesc = []byte{ 0x0a, 0x18, 0x67, 0x72, 0x70, 0x63, 0x2f, 0x6c, 0x6f, 0x6f, 0x6b, 0x75, 0x70, 0x2f, 0x76, 0x31, 0x2f, 0x72, 0x6c, 0x73, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x0e, 0x67, 0x72, 0x70, 0x63, - 0x2e, 0x6c, 0x6f, 0x6f, 0x6b, 0x75, 0x70, 0x2e, 0x76, 0x31, 0x22, 0xe5, 0x01, 0x0a, 0x12, 0x52, + 0x2e, 0x6c, 0x6f, 0x6f, 0x6b, 0x75, 0x70, 0x2e, 0x76, 0x31, 0x22, 0x9d, 0x03, 0x0a, 0x12, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x4c, 0x6f, 0x6f, 0x6b, 0x75, 0x70, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, - 0x74, 0x12, 0x16, 0x0a, 0x06, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x18, 0x01, 0x20, 0x01, 0x28, - 0x09, 0x52, 0x06, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x12, 0x12, 0x0a, 0x04, 0x70, 0x61, 0x74, - 0x68, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x70, 0x61, 0x74, 0x68, 0x12, 0x1f, 0x0a, - 0x0b, 0x74, 0x61, 0x72, 0x67, 0x65, 0x74, 0x5f, 0x74, 0x79, 0x70, 0x65, 0x18, 0x03, 0x20, 0x01, - 0x28, 0x09, 0x52, 0x0a, 0x74, 0x61, 0x72, 0x67, 0x65, 0x74, 0x54, 0x79, 0x70, 0x65, 0x12, 0x47, - 0x0a, 0x07, 0x6b, 0x65, 0x79, 0x5f, 0x6d, 0x61, 0x70, 0x18, 0x04, 0x20, 0x03, 0x28, 0x0b, 0x32, - 0x2e, 0x2e, 0x67, 0x72, 0x70, 0x63, 0x2e, 0x6c, 0x6f, 0x6f, 0x6b, 0x75, 0x70, 0x2e, 0x76, 0x31, + 0x74, 0x12, 0x1a, 0x0a, 0x06, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x18, 0x01, 0x20, 0x01, 0x28, + 0x09, 0x42, 0x02, 0x18, 0x01, 0x52, 0x06, 0x73, 0x65, 0x72, 
0x76, 0x65, 0x72, 0x12, 0x16, 0x0a, + 0x04, 0x70, 0x61, 0x74, 0x68, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x42, 0x02, 0x18, 0x01, 0x52, + 0x04, 0x70, 0x61, 0x74, 0x68, 0x12, 0x1f, 0x0a, 0x0b, 0x74, 0x61, 0x72, 0x67, 0x65, 0x74, 0x5f, + 0x74, 0x79, 0x70, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a, 0x74, 0x61, 0x72, 0x67, + 0x65, 0x74, 0x54, 0x79, 0x70, 0x65, 0x12, 0x41, 0x0a, 0x06, 0x72, 0x65, 0x61, 0x73, 0x6f, 0x6e, + 0x18, 0x05, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x29, 0x2e, 0x67, 0x72, 0x70, 0x63, 0x2e, 0x6c, 0x6f, + 0x6f, 0x6b, 0x75, 0x70, 0x2e, 0x76, 0x31, 0x2e, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x4c, 0x6f, 0x6f, + 0x6b, 0x75, 0x70, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x2e, 0x52, 0x65, 0x61, 0x73, 0x6f, + 0x6e, 0x52, 0x06, 0x72, 0x65, 0x61, 0x73, 0x6f, 0x6e, 0x12, 0x2a, 0x0a, 0x11, 0x73, 0x74, 0x61, + 0x6c, 0x65, 0x5f, 0x68, 0x65, 0x61, 0x64, 0x65, 0x72, 0x5f, 0x64, 0x61, 0x74, 0x61, 0x18, 0x06, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x0f, 0x73, 0x74, 0x61, 0x6c, 0x65, 0x48, 0x65, 0x61, 0x64, 0x65, + 0x72, 0x44, 0x61, 0x74, 0x61, 0x12, 0x47, 0x0a, 0x07, 0x6b, 0x65, 0x79, 0x5f, 0x6d, 0x61, 0x70, + 0x18, 0x04, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x2e, 0x2e, 0x67, 0x72, 0x70, 0x63, 0x2e, 0x6c, 0x6f, + 0x6f, 0x6b, 0x75, 0x70, 0x2e, 0x76, 0x31, 0x2e, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x4c, 0x6f, 0x6f, + 0x6b, 0x75, 0x70, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x2e, 0x4b, 0x65, 0x79, 0x4d, 0x61, + 0x70, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x52, 0x06, 0x6b, 0x65, 0x79, 0x4d, 0x61, 0x70, 0x1a, 0x39, + 0x0a, 0x0b, 0x4b, 0x65, 0x79, 0x4d, 0x61, 0x70, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, + 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, + 0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, + 0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, 0x02, 0x38, 0x01, 0x22, 0x3f, 0x0a, 0x06, 0x52, 0x65, 0x61, + 0x73, 0x6f, 0x6e, 0x12, 0x12, 0x0a, 0x0e, 0x52, 0x45, 0x41, 0x53, 0x4f, 0x4e, 0x5f, 0x55, 0x4e, + 
0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, 0x00, 0x12, 0x0f, 0x0a, 0x0b, 0x52, 0x45, 0x41, 0x53, 0x4f, + 0x4e, 0x5f, 0x4d, 0x49, 0x53, 0x53, 0x10, 0x01, 0x12, 0x10, 0x0a, 0x0c, 0x52, 0x45, 0x41, 0x53, + 0x4f, 0x4e, 0x5f, 0x53, 0x54, 0x41, 0x4c, 0x45, 0x10, 0x02, 0x22, 0x5e, 0x0a, 0x13, 0x52, 0x6f, + 0x75, 0x74, 0x65, 0x4c, 0x6f, 0x6f, 0x6b, 0x75, 0x70, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, + 0x65, 0x12, 0x18, 0x0a, 0x07, 0x74, 0x61, 0x72, 0x67, 0x65, 0x74, 0x73, 0x18, 0x03, 0x20, 0x03, + 0x28, 0x09, 0x52, 0x07, 0x74, 0x61, 0x72, 0x67, 0x65, 0x74, 0x73, 0x12, 0x1f, 0x0a, 0x0b, 0x68, + 0x65, 0x61, 0x64, 0x65, 0x72, 0x5f, 0x64, 0x61, 0x74, 0x61, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x0a, 0x68, 0x65, 0x61, 0x64, 0x65, 0x72, 0x44, 0x61, 0x74, 0x61, 0x4a, 0x04, 0x08, 0x01, + 0x10, 0x02, 0x52, 0x06, 0x74, 0x61, 0x72, 0x67, 0x65, 0x74, 0x32, 0x6e, 0x0a, 0x12, 0x52, 0x6f, + 0x75, 0x74, 0x65, 0x4c, 0x6f, 0x6f, 0x6b, 0x75, 0x70, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, + 0x12, 0x58, 0x0a, 0x0b, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x4c, 0x6f, 0x6f, 0x6b, 0x75, 0x70, 0x12, + 0x22, 0x2e, 0x67, 0x72, 0x70, 0x63, 0x2e, 0x6c, 0x6f, 0x6f, 0x6b, 0x75, 0x70, 0x2e, 0x76, 0x31, 0x2e, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x4c, 0x6f, 0x6f, 0x6b, 0x75, 0x70, 0x52, 0x65, 0x71, 0x75, - 0x65, 0x73, 0x74, 0x2e, 0x4b, 0x65, 0x79, 0x4d, 0x61, 0x70, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x52, - 0x06, 0x6b, 0x65, 0x79, 0x4d, 0x61, 0x70, 0x1a, 0x39, 0x0a, 0x0b, 0x4b, 0x65, 0x79, 0x4d, 0x61, - 0x70, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, - 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, 0x02, - 0x38, 0x01, 0x22, 0x5e, 0x0a, 0x13, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x4c, 0x6f, 0x6f, 0x6b, 0x75, - 0x70, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x74, 0x61, 0x72, - 0x67, 0x65, 0x74, 0x73, 0x18, 0x03, 0x20, 
0x03, 0x28, 0x09, 0x52, 0x07, 0x74, 0x61, 0x72, 0x67, - 0x65, 0x74, 0x73, 0x12, 0x1f, 0x0a, 0x0b, 0x68, 0x65, 0x61, 0x64, 0x65, 0x72, 0x5f, 0x64, 0x61, - 0x74, 0x61, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a, 0x68, 0x65, 0x61, 0x64, 0x65, 0x72, - 0x44, 0x61, 0x74, 0x61, 0x4a, 0x04, 0x08, 0x01, 0x10, 0x02, 0x52, 0x06, 0x74, 0x61, 0x72, 0x67, - 0x65, 0x74, 0x32, 0x6e, 0x0a, 0x12, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x4c, 0x6f, 0x6f, 0x6b, 0x75, - 0x70, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x58, 0x0a, 0x0b, 0x52, 0x6f, 0x75, 0x74, - 0x65, 0x4c, 0x6f, 0x6f, 0x6b, 0x75, 0x70, 0x12, 0x22, 0x2e, 0x67, 0x72, 0x70, 0x63, 0x2e, 0x6c, - 0x6f, 0x6f, 0x6b, 0x75, 0x70, 0x2e, 0x76, 0x31, 0x2e, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x4c, 0x6f, - 0x6f, 0x6b, 0x75, 0x70, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x23, 0x2e, 0x67, 0x72, - 0x70, 0x63, 0x2e, 0x6c, 0x6f, 0x6f, 0x6b, 0x75, 0x70, 0x2e, 0x76, 0x31, 0x2e, 0x52, 0x6f, 0x75, - 0x74, 0x65, 0x4c, 0x6f, 0x6f, 0x6b, 0x75, 0x70, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, - 0x22, 0x00, 0x42, 0x4d, 0x0a, 0x11, 0x69, 0x6f, 0x2e, 0x67, 0x72, 0x70, 0x63, 0x2e, 0x6c, 0x6f, - 0x6f, 0x6b, 0x75, 0x70, 0x2e, 0x76, 0x31, 0x42, 0x08, 0x52, 0x6c, 0x73, 0x50, 0x72, 0x6f, 0x74, - 0x6f, 0x50, 0x01, 0x5a, 0x2c, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x67, 0x6f, 0x6c, 0x61, - 0x6e, 0x67, 0x2e, 0x6f, 0x72, 0x67, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x2f, 0x6c, 0x6f, 0x6f, 0x6b, - 0x75, 0x70, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x5f, 0x6c, 0x6f, 0x6f, 0x6b, 0x75, 0x70, 0x5f, 0x76, - 0x31, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x65, 0x73, 0x74, 0x1a, 0x23, 0x2e, 0x67, 0x72, 0x70, 0x63, 0x2e, 0x6c, 0x6f, 0x6f, 0x6b, 0x75, + 0x70, 0x2e, 0x76, 0x31, 0x2e, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x4c, 0x6f, 0x6f, 0x6b, 0x75, 0x70, + 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x42, 0x4d, 0x0a, 0x11, 0x69, 0x6f, + 0x2e, 0x67, 0x72, 0x70, 0x63, 0x2e, 0x6c, 0x6f, 0x6f, 0x6b, 0x75, 0x70, 0x2e, 0x76, 0x31, 0x42, + 0x08, 0x52, 0x6c, 0x73, 
0x50, 0x72, 0x6f, 0x74, 0x6f, 0x50, 0x01, 0x5a, 0x2c, 0x67, 0x6f, 0x6f, + 0x67, 0x6c, 0x65, 0x2e, 0x67, 0x6f, 0x6c, 0x61, 0x6e, 0x67, 0x2e, 0x6f, 0x72, 0x67, 0x2f, 0x67, + 0x72, 0x70, 0x63, 0x2f, 0x6c, 0x6f, 0x6f, 0x6b, 0x75, 0x70, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x5f, + 0x6c, 0x6f, 0x6f, 0x6b, 0x75, 0x70, 0x5f, 0x76, 0x31, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, + 0x33, } var ( @@ -231,21 +319,24 @@ func file_grpc_lookup_v1_rls_proto_rawDescGZIP() []byte { return file_grpc_lookup_v1_rls_proto_rawDescData } +var file_grpc_lookup_v1_rls_proto_enumTypes = make([]protoimpl.EnumInfo, 1) var file_grpc_lookup_v1_rls_proto_msgTypes = make([]protoimpl.MessageInfo, 3) var file_grpc_lookup_v1_rls_proto_goTypes = []interface{}{ - (*RouteLookupRequest)(nil), // 0: grpc.lookup.v1.RouteLookupRequest - (*RouteLookupResponse)(nil), // 1: grpc.lookup.v1.RouteLookupResponse - nil, // 2: grpc.lookup.v1.RouteLookupRequest.KeyMapEntry + (RouteLookupRequest_Reason)(0), // 0: grpc.lookup.v1.RouteLookupRequest.Reason + (*RouteLookupRequest)(nil), // 1: grpc.lookup.v1.RouteLookupRequest + (*RouteLookupResponse)(nil), // 2: grpc.lookup.v1.RouteLookupResponse + nil, // 3: grpc.lookup.v1.RouteLookupRequest.KeyMapEntry } var file_grpc_lookup_v1_rls_proto_depIdxs = []int32{ - 2, // 0: grpc.lookup.v1.RouteLookupRequest.key_map:type_name -> grpc.lookup.v1.RouteLookupRequest.KeyMapEntry - 0, // 1: grpc.lookup.v1.RouteLookupService.RouteLookup:input_type -> grpc.lookup.v1.RouteLookupRequest - 1, // 2: grpc.lookup.v1.RouteLookupService.RouteLookup:output_type -> grpc.lookup.v1.RouteLookupResponse - 2, // [2:3] is the sub-list for method output_type - 1, // [1:2] is the sub-list for method input_type - 1, // [1:1] is the sub-list for extension type_name - 1, // [1:1] is the sub-list for extension extendee - 0, // [0:1] is the sub-list for field type_name + 0, // 0: grpc.lookup.v1.RouteLookupRequest.reason:type_name -> grpc.lookup.v1.RouteLookupRequest.Reason + 3, // 1: 
grpc.lookup.v1.RouteLookupRequest.key_map:type_name -> grpc.lookup.v1.RouteLookupRequest.KeyMapEntry + 1, // 2: grpc.lookup.v1.RouteLookupService.RouteLookup:input_type -> grpc.lookup.v1.RouteLookupRequest + 2, // 3: grpc.lookup.v1.RouteLookupService.RouteLookup:output_type -> grpc.lookup.v1.RouteLookupResponse + 3, // [3:4] is the sub-list for method output_type + 2, // [2:3] is the sub-list for method input_type + 2, // [2:2] is the sub-list for extension type_name + 2, // [2:2] is the sub-list for extension extendee + 0, // [0:2] is the sub-list for field type_name } func init() { file_grpc_lookup_v1_rls_proto_init() } @@ -284,13 +375,14 @@ func file_grpc_lookup_v1_rls_proto_init() { File: protoimpl.DescBuilder{ GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: file_grpc_lookup_v1_rls_proto_rawDesc, - NumEnums: 0, + NumEnums: 1, NumMessages: 3, NumExtensions: 0, NumServices: 1, }, GoTypes: file_grpc_lookup_v1_rls_proto_goTypes, DependencyIndexes: file_grpc_lookup_v1_rls_proto_depIdxs, + EnumInfos: file_grpc_lookup_v1_rls_proto_enumTypes, MessageInfos: file_grpc_lookup_v1_rls_proto_msgTypes, }.Build() File_grpc_lookup_v1_rls_proto = out.File diff --git a/balancer/rls/internal/proto/grpc_lookup_v1/rls_config.pb.go b/balancer/rls/internal/proto/grpc_lookup_v1/rls_config.pb.go index 6b0924b335f..414b74cdb3b 100644 --- a/balancer/rls/internal/proto/grpc_lookup_v1/rls_config.pb.go +++ b/balancer/rls/internal/proto/grpc_lookup_v1/rls_config.pb.go @@ -50,6 +50,9 @@ type NameMatcher struct { unknownFields protoimpl.UnknownFields // The name that will be used in the RLS key_map to refer to this value. + // If required_match is true, you may omit this field or set it to an empty + // string, in which case the matcher will require a match, but won't update + // the key_map. 
Key string `protobuf:"bytes,1,opt,name=key,proto3" json:"key,omitempty"` // Ordered list of names (headers or query parameter names) that can supply // this value; the first one with a non-empty value is used. @@ -118,11 +121,17 @@ type GrpcKeyBuilder struct { sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - Names []*GrpcKeyBuilder_Name `protobuf:"bytes,1,rep,name=names,proto3" json:"names,omitempty"` + Names []*GrpcKeyBuilder_Name `protobuf:"bytes,1,rep,name=names,proto3" json:"names,omitempty"` + ExtraKeys *GrpcKeyBuilder_ExtraKeys `protobuf:"bytes,3,opt,name=extra_keys,json=extraKeys,proto3" json:"extra_keys,omitempty"` // Extract keys from all listed headers. // For gRPC, it is an error to specify "required_match" on the NameMatcher // protos. Headers []*NameMatcher `protobuf:"bytes,2,rep,name=headers,proto3" json:"headers,omitempty"` + // You can optionally set one or more specific key/value pairs to be added to + // the key_map. This can be useful to identify which builder built the key, + // for example if you are suppressing the actual method, but need to + // separately cache and request all the matched methods. 
+ ConstantKeys map[string]string `protobuf:"bytes,4,rep,name=constant_keys,json=constantKeys,proto3" json:"constant_keys,omitempty" protobuf_key:"bytes,1,opt,name=key,proto3" protobuf_val:"bytes,2,opt,name=value,proto3"` } func (x *GrpcKeyBuilder) Reset() { @@ -164,6 +173,13 @@ func (x *GrpcKeyBuilder) GetNames() []*GrpcKeyBuilder_Name { return nil } +func (x *GrpcKeyBuilder) GetExtraKeys() *GrpcKeyBuilder_ExtraKeys { + if x != nil { + return x.ExtraKeys + } + return nil +} + func (x *GrpcKeyBuilder) GetHeaders() []*NameMatcher { if x != nil { return x.Headers @@ -171,6 +187,13 @@ func (x *GrpcKeyBuilder) GetHeaders() []*NameMatcher { return nil } +func (x *GrpcKeyBuilder) GetConstantKeys() map[string]string { + if x != nil { + return x.ConstantKeys + } + return nil +} + // An HttpKeyBuilder applies to a given HTTP URL and headers. // // Path and host patterns use the matching syntax from gRPC transcoding to @@ -245,6 +268,11 @@ type HttpKeyBuilder struct { // to match. If a given header appears multiple times in the request we will // report it as a comma-separated string, in standard HTTP fashion. Headers []*NameMatcher `protobuf:"bytes,4,rep,name=headers,proto3" json:"headers,omitempty"` + // You can optionally set one or more specific key/value pairs to be added to + // the key_map. This can be useful to identify which builder built the key, + // for example if you are suppressing a lot of information from the URL, but + // need to separately cache and request URLs with that content. 
+ ConstantKeys map[string]string `protobuf:"bytes,5,rep,name=constant_keys,json=constantKeys,proto3" json:"constant_keys,omitempty" protobuf_key:"bytes,1,opt,name=key,proto3" protobuf_val:"bytes,2,opt,name=value,proto3"` } func (x *HttpKeyBuilder) Reset() { @@ -307,6 +335,13 @@ func (x *HttpKeyBuilder) GetHeaders() []*NameMatcher { return nil } +func (x *HttpKeyBuilder) GetConstantKeys() map[string]string { + if x != nil { + return x.ConstantKeys + } + return nil +} + type RouteLookupConfig struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache @@ -510,6 +545,75 @@ func (x *GrpcKeyBuilder_Name) GetMethod() string { return "" } +// If you wish to include the host, service, or method names as keys in the +// generated RouteLookupRequest, specify key names to use in the extra_keys +// submessage. If a key name is empty, no key will be set for that value. +// If this submessage is specified, the normal host/path fields will be left +// unset in the RouteLookupRequest. We are deprecating host/path in the +// RouteLookupRequest, so services should migrate to the ExtraKeys approach. 
+type GrpcKeyBuilder_ExtraKeys struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Host string `protobuf:"bytes,1,opt,name=host,proto3" json:"host,omitempty"` + Service string `protobuf:"bytes,2,opt,name=service,proto3" json:"service,omitempty"` + Method string `protobuf:"bytes,3,opt,name=method,proto3" json:"method,omitempty"` +} + +func (x *GrpcKeyBuilder_ExtraKeys) Reset() { + *x = GrpcKeyBuilder_ExtraKeys{} + if protoimpl.UnsafeEnabled { + mi := &file_grpc_lookup_v1_rls_config_proto_msgTypes[5] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *GrpcKeyBuilder_ExtraKeys) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*GrpcKeyBuilder_ExtraKeys) ProtoMessage() {} + +func (x *GrpcKeyBuilder_ExtraKeys) ProtoReflect() protoreflect.Message { + mi := &file_grpc_lookup_v1_rls_config_proto_msgTypes[5] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use GrpcKeyBuilder_ExtraKeys.ProtoReflect.Descriptor instead. 
+func (*GrpcKeyBuilder_ExtraKeys) Descriptor() ([]byte, []int) { + return file_grpc_lookup_v1_rls_config_proto_rawDescGZIP(), []int{1, 1} +} + +func (x *GrpcKeyBuilder_ExtraKeys) GetHost() string { + if x != nil { + return x.Host + } + return "" +} + +func (x *GrpcKeyBuilder_ExtraKeys) GetService() string { + if x != nil { + return x.Service + } + return "" +} + +func (x *GrpcKeyBuilder_ExtraKeys) GetMethod() string { + if x != nil { + return x.Method + } + return "" +} + var File_grpc_lookup_v1_rls_config_proto protoreflect.FileDescriptor var file_grpc_lookup_v1_rls_config_proto_rawDesc = []byte{ @@ -524,72 +628,101 @@ var file_grpc_lookup_v1_rls_config_proto_rawDesc = []byte{ 0x09, 0x52, 0x05, 0x6e, 0x61, 0x6d, 0x65, 0x73, 0x12, 0x25, 0x0a, 0x0e, 0x72, 0x65, 0x71, 0x75, 0x69, 0x72, 0x65, 0x64, 0x5f, 0x6d, 0x61, 0x74, 0x63, 0x68, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x72, 0x65, 0x71, 0x75, 0x69, 0x72, 0x65, 0x64, 0x4d, 0x61, 0x74, 0x63, 0x68, 0x22, - 0xbc, 0x01, 0x0a, 0x0e, 0x47, 0x72, 0x70, 0x63, 0x4b, 0x65, 0x79, 0x42, 0x75, 0x69, 0x6c, 0x64, + 0xf0, 0x03, 0x0a, 0x0e, 0x47, 0x72, 0x70, 0x63, 0x4b, 0x65, 0x79, 0x42, 0x75, 0x69, 0x6c, 0x64, 0x65, 0x72, 0x12, 0x39, 0x0a, 0x05, 0x6e, 0x61, 0x6d, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x23, 0x2e, 0x67, 0x72, 0x70, 0x63, 0x2e, 0x6c, 0x6f, 0x6f, 0x6b, 0x75, 0x70, 0x2e, 0x76, 0x31, 0x2e, 0x47, 0x72, 0x70, 0x63, 0x4b, 0x65, 0x79, 0x42, 0x75, 0x69, 0x6c, 0x64, 0x65, - 0x72, 0x2e, 0x4e, 0x61, 0x6d, 0x65, 0x52, 0x05, 0x6e, 0x61, 0x6d, 0x65, 0x73, 0x12, 0x35, 0x0a, - 0x07, 0x68, 0x65, 0x61, 0x64, 0x65, 0x72, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1b, - 0x2e, 0x67, 0x72, 0x70, 0x63, 0x2e, 0x6c, 0x6f, 0x6f, 0x6b, 0x75, 0x70, 0x2e, 0x76, 0x31, 0x2e, - 0x4e, 0x61, 0x6d, 0x65, 0x4d, 0x61, 0x74, 0x63, 0x68, 0x65, 0x72, 0x52, 0x07, 0x68, 0x65, 0x61, - 0x64, 0x65, 0x72, 0x73, 0x1a, 0x38, 0x0a, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x18, 0x0a, 0x07, + 0x72, 0x2e, 0x4e, 0x61, 0x6d, 0x65, 0x52, 0x05, 
0x6e, 0x61, 0x6d, 0x65, 0x73, 0x12, 0x47, 0x0a, + 0x0a, 0x65, 0x78, 0x74, 0x72, 0x61, 0x5f, 0x6b, 0x65, 0x79, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, + 0x0b, 0x32, 0x28, 0x2e, 0x67, 0x72, 0x70, 0x63, 0x2e, 0x6c, 0x6f, 0x6f, 0x6b, 0x75, 0x70, 0x2e, + 0x76, 0x31, 0x2e, 0x47, 0x72, 0x70, 0x63, 0x4b, 0x65, 0x79, 0x42, 0x75, 0x69, 0x6c, 0x64, 0x65, + 0x72, 0x2e, 0x45, 0x78, 0x74, 0x72, 0x61, 0x4b, 0x65, 0x79, 0x73, 0x52, 0x09, 0x65, 0x78, 0x74, + 0x72, 0x61, 0x4b, 0x65, 0x79, 0x73, 0x12, 0x35, 0x0a, 0x07, 0x68, 0x65, 0x61, 0x64, 0x65, 0x72, + 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1b, 0x2e, 0x67, 0x72, 0x70, 0x63, 0x2e, 0x6c, + 0x6f, 0x6f, 0x6b, 0x75, 0x70, 0x2e, 0x76, 0x31, 0x2e, 0x4e, 0x61, 0x6d, 0x65, 0x4d, 0x61, 0x74, + 0x63, 0x68, 0x65, 0x72, 0x52, 0x07, 0x68, 0x65, 0x61, 0x64, 0x65, 0x72, 0x73, 0x12, 0x55, 0x0a, + 0x0d, 0x63, 0x6f, 0x6e, 0x73, 0x74, 0x61, 0x6e, 0x74, 0x5f, 0x6b, 0x65, 0x79, 0x73, 0x18, 0x04, + 0x20, 0x03, 0x28, 0x0b, 0x32, 0x30, 0x2e, 0x67, 0x72, 0x70, 0x63, 0x2e, 0x6c, 0x6f, 0x6f, 0x6b, + 0x75, 0x70, 0x2e, 0x76, 0x31, 0x2e, 0x47, 0x72, 0x70, 0x63, 0x4b, 0x65, 0x79, 0x42, 0x75, 0x69, + 0x6c, 0x64, 0x65, 0x72, 0x2e, 0x43, 0x6f, 0x6e, 0x73, 0x74, 0x61, 0x6e, 0x74, 0x4b, 0x65, 0x79, + 0x73, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x52, 0x0c, 0x63, 0x6f, 0x6e, 0x73, 0x74, 0x61, 0x6e, 0x74, + 0x4b, 0x65, 0x79, 0x73, 0x1a, 0x38, 0x0a, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x6d, 0x65, 0x74, 0x68, 0x6f, 0x64, - 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x6d, 0x65, 0x74, 0x68, 0x6f, 0x64, 0x22, 0xd9, - 0x01, 0x0a, 0x0e, 0x48, 0x74, 0x74, 0x70, 0x4b, 0x65, 0x79, 0x42, 0x75, 0x69, 0x6c, 0x64, 0x65, - 0x72, 0x12, 0x23, 0x0a, 0x0d, 0x68, 0x6f, 0x73, 0x74, 0x5f, 0x70, 0x61, 0x74, 0x74, 0x65, 0x72, - 0x6e, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0c, 0x68, 0x6f, 0x73, 0x74, 0x50, 
0x61, - 0x74, 0x74, 0x65, 0x72, 0x6e, 0x73, 0x12, 0x23, 0x0a, 0x0d, 0x70, 0x61, 0x74, 0x68, 0x5f, 0x70, - 0x61, 0x74, 0x74, 0x65, 0x72, 0x6e, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0c, 0x70, - 0x61, 0x74, 0x68, 0x50, 0x61, 0x74, 0x74, 0x65, 0x72, 0x6e, 0x73, 0x12, 0x46, 0x0a, 0x10, 0x71, - 0x75, 0x65, 0x72, 0x79, 0x5f, 0x70, 0x61, 0x72, 0x61, 0x6d, 0x65, 0x74, 0x65, 0x72, 0x73, 0x18, - 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1b, 0x2e, 0x67, 0x72, 0x70, 0x63, 0x2e, 0x6c, 0x6f, 0x6f, - 0x6b, 0x75, 0x70, 0x2e, 0x76, 0x31, 0x2e, 0x4e, 0x61, 0x6d, 0x65, 0x4d, 0x61, 0x74, 0x63, 0x68, - 0x65, 0x72, 0x52, 0x0f, 0x71, 0x75, 0x65, 0x72, 0x79, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x65, 0x74, - 0x65, 0x72, 0x73, 0x12, 0x35, 0x0a, 0x07, 0x68, 0x65, 0x61, 0x64, 0x65, 0x72, 0x73, 0x18, 0x04, - 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1b, 0x2e, 0x67, 0x72, 0x70, 0x63, 0x2e, 0x6c, 0x6f, 0x6f, 0x6b, - 0x75, 0x70, 0x2e, 0x76, 0x31, 0x2e, 0x4e, 0x61, 0x6d, 0x65, 0x4d, 0x61, 0x74, 0x63, 0x68, 0x65, - 0x72, 0x52, 0x07, 0x68, 0x65, 0x61, 0x64, 0x65, 0x72, 0x73, 0x22, 0xa6, 0x04, 0x0a, 0x11, 0x52, - 0x6f, 0x75, 0x74, 0x65, 0x4c, 0x6f, 0x6f, 0x6b, 0x75, 0x70, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, - 0x12, 0x49, 0x0a, 0x10, 0x68, 0x74, 0x74, 0x70, 0x5f, 0x6b, 0x65, 0x79, 0x62, 0x75, 0x69, 0x6c, - 0x64, 0x65, 0x72, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1e, 0x2e, 0x67, 0x72, 0x70, - 0x63, 0x2e, 0x6c, 0x6f, 0x6f, 0x6b, 0x75, 0x70, 0x2e, 0x76, 0x31, 0x2e, 0x48, 0x74, 0x74, 0x70, - 0x4b, 0x65, 0x79, 0x42, 0x75, 0x69, 0x6c, 0x64, 0x65, 0x72, 0x52, 0x0f, 0x68, 0x74, 0x74, 0x70, - 0x4b, 0x65, 0x79, 0x62, 0x75, 0x69, 0x6c, 0x64, 0x65, 0x72, 0x73, 0x12, 0x49, 0x0a, 0x10, 0x67, - 0x72, 0x70, 0x63, 0x5f, 0x6b, 0x65, 0x79, 0x62, 0x75, 0x69, 0x6c, 0x64, 0x65, 0x72, 0x73, 0x18, - 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1e, 0x2e, 0x67, 0x72, 0x70, 0x63, 0x2e, 0x6c, 0x6f, 0x6f, - 0x6b, 0x75, 0x70, 0x2e, 0x76, 0x31, 0x2e, 0x47, 0x72, 0x70, 0x63, 0x4b, 0x65, 0x79, 0x42, 0x75, - 0x69, 0x6c, 0x64, 0x65, 0x72, 
0x52, 0x0f, 0x67, 0x72, 0x70, 0x63, 0x4b, 0x65, 0x79, 0x62, 0x75, - 0x69, 0x6c, 0x64, 0x65, 0x72, 0x73, 0x12, 0x25, 0x0a, 0x0e, 0x6c, 0x6f, 0x6f, 0x6b, 0x75, 0x70, - 0x5f, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, - 0x6c, 0x6f, 0x6f, 0x6b, 0x75, 0x70, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x4f, 0x0a, - 0x16, 0x6c, 0x6f, 0x6f, 0x6b, 0x75, 0x70, 0x5f, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x5f, - 0x74, 0x69, 0x6d, 0x65, 0x6f, 0x75, 0x74, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x19, 0x2e, - 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, - 0x44, 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x14, 0x6c, 0x6f, 0x6f, 0x6b, 0x75, 0x70, - 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x54, 0x69, 0x6d, 0x65, 0x6f, 0x75, 0x74, 0x12, 0x32, - 0x0a, 0x07, 0x6d, 0x61, 0x78, 0x5f, 0x61, 0x67, 0x65, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0b, 0x32, - 0x19, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, - 0x66, 0x2e, 0x44, 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x06, 0x6d, 0x61, 0x78, 0x41, - 0x67, 0x65, 0x12, 0x36, 0x0a, 0x09, 0x73, 0x74, 0x61, 0x6c, 0x65, 0x5f, 0x61, 0x67, 0x65, 0x18, - 0x06, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x19, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, - 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x44, 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, - 0x52, 0x08, 0x73, 0x74, 0x61, 0x6c, 0x65, 0x41, 0x67, 0x65, 0x12, 0x28, 0x0a, 0x10, 0x63, 0x61, - 0x63, 0x68, 0x65, 0x5f, 0x73, 0x69, 0x7a, 0x65, 0x5f, 0x62, 0x79, 0x74, 0x65, 0x73, 0x18, 0x07, - 0x20, 0x01, 0x28, 0x03, 0x52, 0x0e, 0x63, 0x61, 0x63, 0x68, 0x65, 0x53, 0x69, 0x7a, 0x65, 0x42, - 0x79, 0x74, 0x65, 0x73, 0x12, 0x23, 0x0a, 0x0d, 0x76, 0x61, 0x6c, 0x69, 0x64, 0x5f, 0x74, 0x61, - 0x72, 0x67, 0x65, 0x74, 0x73, 0x18, 0x08, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0c, 0x76, 0x61, 0x6c, - 0x69, 0x64, 0x54, 0x61, 0x72, 0x67, 0x65, 0x74, 0x73, 0x12, 0x25, 
0x0a, 0x0e, 0x64, 0x65, 0x66, - 0x61, 0x75, 0x6c, 0x74, 0x5f, 0x74, 0x61, 0x72, 0x67, 0x65, 0x74, 0x18, 0x09, 0x20, 0x01, 0x28, - 0x09, 0x52, 0x0d, 0x64, 0x65, 0x66, 0x61, 0x75, 0x6c, 0x74, 0x54, 0x61, 0x72, 0x67, 0x65, 0x74, - 0x4a, 0x04, 0x08, 0x0a, 0x10, 0x0b, 0x52, 0x1b, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x5f, - 0x70, 0x72, 0x6f, 0x63, 0x65, 0x73, 0x73, 0x69, 0x6e, 0x67, 0x5f, 0x73, 0x74, 0x72, 0x61, 0x74, - 0x65, 0x67, 0x79, 0x42, 0x53, 0x0a, 0x11, 0x69, 0x6f, 0x2e, 0x67, 0x72, 0x70, 0x63, 0x2e, 0x6c, - 0x6f, 0x6f, 0x6b, 0x75, 0x70, 0x2e, 0x76, 0x31, 0x42, 0x0e, 0x52, 0x6c, 0x73, 0x43, 0x6f, 0x6e, - 0x66, 0x69, 0x67, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x50, 0x01, 0x5a, 0x2c, 0x67, 0x6f, 0x6f, 0x67, - 0x6c, 0x65, 0x2e, 0x67, 0x6f, 0x6c, 0x61, 0x6e, 0x67, 0x2e, 0x6f, 0x72, 0x67, 0x2f, 0x67, 0x72, - 0x70, 0x63, 0x2f, 0x6c, 0x6f, 0x6f, 0x6b, 0x75, 0x70, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x5f, 0x6c, - 0x6f, 0x6f, 0x6b, 0x75, 0x70, 0x5f, 0x76, 0x31, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x6d, 0x65, 0x74, 0x68, 0x6f, 0x64, 0x1a, 0x51, + 0x0a, 0x09, 0x45, 0x78, 0x74, 0x72, 0x61, 0x4b, 0x65, 0x79, 0x73, 0x12, 0x12, 0x0a, 0x04, 0x68, + 0x6f, 0x73, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x68, 0x6f, 0x73, 0x74, 0x12, + 0x18, 0x0a, 0x07, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x07, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x6d, 0x65, 0x74, + 0x68, 0x6f, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x6d, 0x65, 0x74, 0x68, 0x6f, + 0x64, 0x1a, 0x3f, 0x0a, 0x11, 0x43, 0x6f, 0x6e, 0x73, 0x74, 0x61, 0x6e, 0x74, 0x4b, 0x65, 0x79, + 0x73, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, + 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, 0x02, + 0x38, 
0x01, 0x22, 0xf1, 0x02, 0x0a, 0x0e, 0x48, 0x74, 0x74, 0x70, 0x4b, 0x65, 0x79, 0x42, 0x75, + 0x69, 0x6c, 0x64, 0x65, 0x72, 0x12, 0x23, 0x0a, 0x0d, 0x68, 0x6f, 0x73, 0x74, 0x5f, 0x70, 0x61, + 0x74, 0x74, 0x65, 0x72, 0x6e, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0c, 0x68, 0x6f, + 0x73, 0x74, 0x50, 0x61, 0x74, 0x74, 0x65, 0x72, 0x6e, 0x73, 0x12, 0x23, 0x0a, 0x0d, 0x70, 0x61, + 0x74, 0x68, 0x5f, 0x70, 0x61, 0x74, 0x74, 0x65, 0x72, 0x6e, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, + 0x09, 0x52, 0x0c, 0x70, 0x61, 0x74, 0x68, 0x50, 0x61, 0x74, 0x74, 0x65, 0x72, 0x6e, 0x73, 0x12, + 0x46, 0x0a, 0x10, 0x71, 0x75, 0x65, 0x72, 0x79, 0x5f, 0x70, 0x61, 0x72, 0x61, 0x6d, 0x65, 0x74, + 0x65, 0x72, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1b, 0x2e, 0x67, 0x72, 0x70, 0x63, + 0x2e, 0x6c, 0x6f, 0x6f, 0x6b, 0x75, 0x70, 0x2e, 0x76, 0x31, 0x2e, 0x4e, 0x61, 0x6d, 0x65, 0x4d, + 0x61, 0x74, 0x63, 0x68, 0x65, 0x72, 0x52, 0x0f, 0x71, 0x75, 0x65, 0x72, 0x79, 0x50, 0x61, 0x72, + 0x61, 0x6d, 0x65, 0x74, 0x65, 0x72, 0x73, 0x12, 0x35, 0x0a, 0x07, 0x68, 0x65, 0x61, 0x64, 0x65, + 0x72, 0x73, 0x18, 0x04, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1b, 0x2e, 0x67, 0x72, 0x70, 0x63, 0x2e, + 0x6c, 0x6f, 0x6f, 0x6b, 0x75, 0x70, 0x2e, 0x76, 0x31, 0x2e, 0x4e, 0x61, 0x6d, 0x65, 0x4d, 0x61, + 0x74, 0x63, 0x68, 0x65, 0x72, 0x52, 0x07, 0x68, 0x65, 0x61, 0x64, 0x65, 0x72, 0x73, 0x12, 0x55, + 0x0a, 0x0d, 0x63, 0x6f, 0x6e, 0x73, 0x74, 0x61, 0x6e, 0x74, 0x5f, 0x6b, 0x65, 0x79, 0x73, 0x18, + 0x05, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x30, 0x2e, 0x67, 0x72, 0x70, 0x63, 0x2e, 0x6c, 0x6f, 0x6f, + 0x6b, 0x75, 0x70, 0x2e, 0x76, 0x31, 0x2e, 0x48, 0x74, 0x74, 0x70, 0x4b, 0x65, 0x79, 0x42, 0x75, + 0x69, 0x6c, 0x64, 0x65, 0x72, 0x2e, 0x43, 0x6f, 0x6e, 0x73, 0x74, 0x61, 0x6e, 0x74, 0x4b, 0x65, + 0x79, 0x73, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x52, 0x0c, 0x63, 0x6f, 0x6e, 0x73, 0x74, 0x61, 0x6e, + 0x74, 0x4b, 0x65, 0x79, 0x73, 0x1a, 0x3f, 0x0a, 0x11, 0x43, 0x6f, 0x6e, 0x73, 0x74, 0x61, 0x6e, + 0x74, 0x4b, 0x65, 0x79, 0x73, 0x45, 0x6e, 
0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, + 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x14, 0x0a, 0x05, + 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x76, 0x61, 0x6c, + 0x75, 0x65, 0x3a, 0x02, 0x38, 0x01, 0x22, 0xa6, 0x04, 0x0a, 0x11, 0x52, 0x6f, 0x75, 0x74, 0x65, + 0x4c, 0x6f, 0x6f, 0x6b, 0x75, 0x70, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x49, 0x0a, 0x10, + 0x68, 0x74, 0x74, 0x70, 0x5f, 0x6b, 0x65, 0x79, 0x62, 0x75, 0x69, 0x6c, 0x64, 0x65, 0x72, 0x73, + 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1e, 0x2e, 0x67, 0x72, 0x70, 0x63, 0x2e, 0x6c, 0x6f, + 0x6f, 0x6b, 0x75, 0x70, 0x2e, 0x76, 0x31, 0x2e, 0x48, 0x74, 0x74, 0x70, 0x4b, 0x65, 0x79, 0x42, + 0x75, 0x69, 0x6c, 0x64, 0x65, 0x72, 0x52, 0x0f, 0x68, 0x74, 0x74, 0x70, 0x4b, 0x65, 0x79, 0x62, + 0x75, 0x69, 0x6c, 0x64, 0x65, 0x72, 0x73, 0x12, 0x49, 0x0a, 0x10, 0x67, 0x72, 0x70, 0x63, 0x5f, + 0x6b, 0x65, 0x79, 0x62, 0x75, 0x69, 0x6c, 0x64, 0x65, 0x72, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, + 0x0b, 0x32, 0x1e, 0x2e, 0x67, 0x72, 0x70, 0x63, 0x2e, 0x6c, 0x6f, 0x6f, 0x6b, 0x75, 0x70, 0x2e, + 0x76, 0x31, 0x2e, 0x47, 0x72, 0x70, 0x63, 0x4b, 0x65, 0x79, 0x42, 0x75, 0x69, 0x6c, 0x64, 0x65, + 0x72, 0x52, 0x0f, 0x67, 0x72, 0x70, 0x63, 0x4b, 0x65, 0x79, 0x62, 0x75, 0x69, 0x6c, 0x64, 0x65, + 0x72, 0x73, 0x12, 0x25, 0x0a, 0x0e, 0x6c, 0x6f, 0x6f, 0x6b, 0x75, 0x70, 0x5f, 0x73, 0x65, 0x72, + 0x76, 0x69, 0x63, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x6c, 0x6f, 0x6f, 0x6b, + 0x75, 0x70, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x4f, 0x0a, 0x16, 0x6c, 0x6f, 0x6f, + 0x6b, 0x75, 0x70, 0x5f, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x5f, 0x74, 0x69, 0x6d, 0x65, + 0x6f, 0x75, 0x74, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x19, 0x2e, 0x67, 0x6f, 0x6f, 0x67, + 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x44, 0x75, 0x72, 0x61, + 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x14, 0x6c, 0x6f, 0x6f, 0x6b, 0x75, 0x70, 0x53, 
0x65, 0x72, 0x76, + 0x69, 0x63, 0x65, 0x54, 0x69, 0x6d, 0x65, 0x6f, 0x75, 0x74, 0x12, 0x32, 0x0a, 0x07, 0x6d, 0x61, + 0x78, 0x5f, 0x61, 0x67, 0x65, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x19, 0x2e, 0x67, 0x6f, + 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x44, 0x75, + 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x06, 0x6d, 0x61, 0x78, 0x41, 0x67, 0x65, 0x12, 0x36, + 0x0a, 0x09, 0x73, 0x74, 0x61, 0x6c, 0x65, 0x5f, 0x61, 0x67, 0x65, 0x18, 0x06, 0x20, 0x01, 0x28, + 0x0b, 0x32, 0x19, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, + 0x62, 0x75, 0x66, 0x2e, 0x44, 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x08, 0x73, 0x74, + 0x61, 0x6c, 0x65, 0x41, 0x67, 0x65, 0x12, 0x28, 0x0a, 0x10, 0x63, 0x61, 0x63, 0x68, 0x65, 0x5f, + 0x73, 0x69, 0x7a, 0x65, 0x5f, 0x62, 0x79, 0x74, 0x65, 0x73, 0x18, 0x07, 0x20, 0x01, 0x28, 0x03, + 0x52, 0x0e, 0x63, 0x61, 0x63, 0x68, 0x65, 0x53, 0x69, 0x7a, 0x65, 0x42, 0x79, 0x74, 0x65, 0x73, + 0x12, 0x23, 0x0a, 0x0d, 0x76, 0x61, 0x6c, 0x69, 0x64, 0x5f, 0x74, 0x61, 0x72, 0x67, 0x65, 0x74, + 0x73, 0x18, 0x08, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0c, 0x76, 0x61, 0x6c, 0x69, 0x64, 0x54, 0x61, + 0x72, 0x67, 0x65, 0x74, 0x73, 0x12, 0x25, 0x0a, 0x0e, 0x64, 0x65, 0x66, 0x61, 0x75, 0x6c, 0x74, + 0x5f, 0x74, 0x61, 0x72, 0x67, 0x65, 0x74, 0x18, 0x09, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x64, + 0x65, 0x66, 0x61, 0x75, 0x6c, 0x74, 0x54, 0x61, 0x72, 0x67, 0x65, 0x74, 0x4a, 0x04, 0x08, 0x0a, + 0x10, 0x0b, 0x52, 0x1b, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x5f, 0x70, 0x72, 0x6f, 0x63, + 0x65, 0x73, 0x73, 0x69, 0x6e, 0x67, 0x5f, 0x73, 0x74, 0x72, 0x61, 0x74, 0x65, 0x67, 0x79, 0x42, + 0x53, 0x0a, 0x11, 0x69, 0x6f, 0x2e, 0x67, 0x72, 0x70, 0x63, 0x2e, 0x6c, 0x6f, 0x6f, 0x6b, 0x75, + 0x70, 0x2e, 0x76, 0x31, 0x42, 0x0e, 0x52, 0x6c, 0x73, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x50, + 0x72, 0x6f, 0x74, 0x6f, 0x50, 0x01, 0x5a, 0x2c, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x67, + 0x6f, 0x6c, 0x61, 
0x6e, 0x67, 0x2e, 0x6f, 0x72, 0x67, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x2f, 0x6c, + 0x6f, 0x6f, 0x6b, 0x75, 0x70, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x5f, 0x6c, 0x6f, 0x6f, 0x6b, 0x75, + 0x70, 0x5f, 0x76, 0x31, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( @@ -604,30 +737,36 @@ func file_grpc_lookup_v1_rls_config_proto_rawDescGZIP() []byte { return file_grpc_lookup_v1_rls_config_proto_rawDescData } -var file_grpc_lookup_v1_rls_config_proto_msgTypes = make([]protoimpl.MessageInfo, 5) +var file_grpc_lookup_v1_rls_config_proto_msgTypes = make([]protoimpl.MessageInfo, 8) var file_grpc_lookup_v1_rls_config_proto_goTypes = []interface{}{ - (*NameMatcher)(nil), // 0: grpc.lookup.v1.NameMatcher - (*GrpcKeyBuilder)(nil), // 1: grpc.lookup.v1.GrpcKeyBuilder - (*HttpKeyBuilder)(nil), // 2: grpc.lookup.v1.HttpKeyBuilder - (*RouteLookupConfig)(nil), // 3: grpc.lookup.v1.RouteLookupConfig - (*GrpcKeyBuilder_Name)(nil), // 4: grpc.lookup.v1.GrpcKeyBuilder.Name - (*durationpb.Duration)(nil), // 5: google.protobuf.Duration + (*NameMatcher)(nil), // 0: grpc.lookup.v1.NameMatcher + (*GrpcKeyBuilder)(nil), // 1: grpc.lookup.v1.GrpcKeyBuilder + (*HttpKeyBuilder)(nil), // 2: grpc.lookup.v1.HttpKeyBuilder + (*RouteLookupConfig)(nil), // 3: grpc.lookup.v1.RouteLookupConfig + (*GrpcKeyBuilder_Name)(nil), // 4: grpc.lookup.v1.GrpcKeyBuilder.Name + (*GrpcKeyBuilder_ExtraKeys)(nil), // 5: grpc.lookup.v1.GrpcKeyBuilder.ExtraKeys + nil, // 6: grpc.lookup.v1.GrpcKeyBuilder.ConstantKeysEntry + nil, // 7: grpc.lookup.v1.HttpKeyBuilder.ConstantKeysEntry + (*durationpb.Duration)(nil), // 8: google.protobuf.Duration } var file_grpc_lookup_v1_rls_config_proto_depIdxs = []int32{ - 4, // 0: grpc.lookup.v1.GrpcKeyBuilder.names:type_name -> grpc.lookup.v1.GrpcKeyBuilder.Name - 0, // 1: grpc.lookup.v1.GrpcKeyBuilder.headers:type_name -> grpc.lookup.v1.NameMatcher - 0, // 2: grpc.lookup.v1.HttpKeyBuilder.query_parameters:type_name -> grpc.lookup.v1.NameMatcher - 0, // 3: 
grpc.lookup.v1.HttpKeyBuilder.headers:type_name -> grpc.lookup.v1.NameMatcher - 2, // 4: grpc.lookup.v1.RouteLookupConfig.http_keybuilders:type_name -> grpc.lookup.v1.HttpKeyBuilder - 1, // 5: grpc.lookup.v1.RouteLookupConfig.grpc_keybuilders:type_name -> grpc.lookup.v1.GrpcKeyBuilder - 5, // 6: grpc.lookup.v1.RouteLookupConfig.lookup_service_timeout:type_name -> google.protobuf.Duration - 5, // 7: grpc.lookup.v1.RouteLookupConfig.max_age:type_name -> google.protobuf.Duration - 5, // 8: grpc.lookup.v1.RouteLookupConfig.stale_age:type_name -> google.protobuf.Duration - 9, // [9:9] is the sub-list for method output_type - 9, // [9:9] is the sub-list for method input_type - 9, // [9:9] is the sub-list for extension type_name - 9, // [9:9] is the sub-list for extension extendee - 0, // [0:9] is the sub-list for field type_name + 4, // 0: grpc.lookup.v1.GrpcKeyBuilder.names:type_name -> grpc.lookup.v1.GrpcKeyBuilder.Name + 5, // 1: grpc.lookup.v1.GrpcKeyBuilder.extra_keys:type_name -> grpc.lookup.v1.GrpcKeyBuilder.ExtraKeys + 0, // 2: grpc.lookup.v1.GrpcKeyBuilder.headers:type_name -> grpc.lookup.v1.NameMatcher + 6, // 3: grpc.lookup.v1.GrpcKeyBuilder.constant_keys:type_name -> grpc.lookup.v1.GrpcKeyBuilder.ConstantKeysEntry + 0, // 4: grpc.lookup.v1.HttpKeyBuilder.query_parameters:type_name -> grpc.lookup.v1.NameMatcher + 0, // 5: grpc.lookup.v1.HttpKeyBuilder.headers:type_name -> grpc.lookup.v1.NameMatcher + 7, // 6: grpc.lookup.v1.HttpKeyBuilder.constant_keys:type_name -> grpc.lookup.v1.HttpKeyBuilder.ConstantKeysEntry + 2, // 7: grpc.lookup.v1.RouteLookupConfig.http_keybuilders:type_name -> grpc.lookup.v1.HttpKeyBuilder + 1, // 8: grpc.lookup.v1.RouteLookupConfig.grpc_keybuilders:type_name -> grpc.lookup.v1.GrpcKeyBuilder + 8, // 9: grpc.lookup.v1.RouteLookupConfig.lookup_service_timeout:type_name -> google.protobuf.Duration + 8, // 10: grpc.lookup.v1.RouteLookupConfig.max_age:type_name -> google.protobuf.Duration + 8, // 11: 
grpc.lookup.v1.RouteLookupConfig.stale_age:type_name -> google.protobuf.Duration + 12, // [12:12] is the sub-list for method output_type + 12, // [12:12] is the sub-list for method input_type + 12, // [12:12] is the sub-list for extension type_name + 12, // [12:12] is the sub-list for extension extendee + 0, // [0:12] is the sub-list for field type_name } func init() { file_grpc_lookup_v1_rls_config_proto_init() } @@ -696,6 +835,18 @@ func file_grpc_lookup_v1_rls_config_proto_init() { return nil } } + file_grpc_lookup_v1_rls_config_proto_msgTypes[5].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*GrpcKeyBuilder_ExtraKeys); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } } type x struct{} out := protoimpl.TypeBuilder{ @@ -703,7 +854,7 @@ func file_grpc_lookup_v1_rls_config_proto_init() { GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: file_grpc_lookup_v1_rls_config_proto_rawDesc, NumEnums: 0, - NumMessages: 5, + NumMessages: 8, NumExtensions: 0, NumServices: 0, }, diff --git a/balancer/rls/internal/proto/grpc_lookup_v1/rls_grpc.pb.go b/balancer/rls/internal/proto/grpc_lookup_v1/rls_grpc.pb.go index b469089ed57..39d79e13343 100644 --- a/balancer/rls/internal/proto/grpc_lookup_v1/rls_grpc.pb.go +++ b/balancer/rls/internal/proto/grpc_lookup_v1/rls_grpc.pb.go @@ -1,4 +1,8 @@ // Code generated by protoc-gen-go-grpc. DO NOT EDIT. 
+// versions: +// - protoc-gen-go-grpc v1.1.0 +// - protoc v3.14.0 +// source: grpc/lookup/v1/rls.proto package grpc_lookup_v1 diff --git a/balancer/roundrobin/roundrobin.go b/balancer/roundrobin/roundrobin.go index 43c2a15373a..274eb2f8580 100644 --- a/balancer/roundrobin/roundrobin.go +++ b/balancer/roundrobin/roundrobin.go @@ -47,11 +47,11 @@ func init() { type rrPickerBuilder struct{} func (*rrPickerBuilder) Build(info base.PickerBuildInfo) balancer.Picker { - logger.Infof("roundrobinPicker: newPicker called with info: %v", info) + logger.Infof("roundrobinPicker: Build called with info: %v", info) if len(info.ReadySCs) == 0 { return base.NewErrPicker(balancer.ErrNoSubConnAvailable) } - var scs []balancer.SubConn + scs := make([]balancer.SubConn, 0, len(info.ReadySCs)) for sc := range info.ReadySCs { scs = append(scs, sc) } diff --git a/balancer_conn_wrappers.go b/balancer_conn_wrappers.go index 41061d6d3dc..f4ea6174682 100644 --- a/balancer_conn_wrappers.go +++ b/balancer_conn_wrappers.go @@ -37,14 +37,20 @@ type scStateUpdate struct { err error } +// exitIdle contains no data and is just a signal sent on the updateCh in +// ccBalancerWrapper to instruct the balancer to exit idle. +type exitIdle struct{} + // ccBalancerWrapper is a wrapper on top of cc for balancers. // It implements balancer.ClientConn interface. 
type ccBalancerWrapper struct { - cc *ClientConn - balancerMu sync.Mutex // synchronizes calls to the balancer - balancer balancer.Balancer - scBuffer *buffer.Unbounded - done *grpcsync.Event + cc *ClientConn + balancerMu sync.Mutex // synchronizes calls to the balancer + balancer balancer.Balancer + hasExitIdle bool + updateCh *buffer.Unbounded + closed *grpcsync.Event + done *grpcsync.Event mu sync.Mutex subConns map[*acBalancerWrapper]struct{} @@ -53,12 +59,14 @@ type ccBalancerWrapper struct { func newCCBalancerWrapper(cc *ClientConn, b balancer.Builder, bopts balancer.BuildOptions) *ccBalancerWrapper { ccb := &ccBalancerWrapper{ cc: cc, - scBuffer: buffer.NewUnbounded(), + updateCh: buffer.NewUnbounded(), + closed: grpcsync.NewEvent(), done: grpcsync.NewEvent(), subConns: make(map[*acBalancerWrapper]struct{}), } go ccb.watcher() ccb.balancer = b.Build(ccb, bopts) + _, ccb.hasExitIdle = ccb.balancer.(balancer.ExitIdler) return ccb } @@ -67,35 +75,72 @@ func newCCBalancerWrapper(cc *ClientConn, b balancer.Builder, bopts balancer.Bui func (ccb *ccBalancerWrapper) watcher() { for { select { - case t := <-ccb.scBuffer.Get(): - ccb.scBuffer.Load() - if ccb.done.HasFired() { + case t := <-ccb.updateCh.Get(): + ccb.updateCh.Load() + if ccb.closed.HasFired() { break } - ccb.balancerMu.Lock() - su := t.(*scStateUpdate) - ccb.balancer.UpdateSubConnState(su.sc, balancer.SubConnState{ConnectivityState: su.state, ConnectionError: su.err}) - ccb.balancerMu.Unlock() - case <-ccb.done.Done(): + switch u := t.(type) { + case *scStateUpdate: + ccb.balancerMu.Lock() + ccb.balancer.UpdateSubConnState(u.sc, balancer.SubConnState{ConnectivityState: u.state, ConnectionError: u.err}) + ccb.balancerMu.Unlock() + case *acBalancerWrapper: + ccb.mu.Lock() + if ccb.subConns != nil { + delete(ccb.subConns, u) + ccb.cc.removeAddrConn(u.getAddrConn(), errConnDrain) + } + ccb.mu.Unlock() + case exitIdle: + if ccb.cc.GetState() == connectivity.Idle { + if ei, ok := 
ccb.balancer.(balancer.ExitIdler); ok { + // We already checked that the balancer implements + // ExitIdle before pushing the event to updateCh, but + // check conditionally again as defensive programming. + ccb.balancerMu.Lock() + ei.ExitIdle() + ccb.balancerMu.Unlock() + } + } + default: + logger.Errorf("ccBalancerWrapper.watcher: unknown update %+v, type %T", t, t) + } + case <-ccb.closed.Done(): } - if ccb.done.HasFired() { + if ccb.closed.HasFired() { + ccb.balancerMu.Lock() ccb.balancer.Close() + ccb.balancerMu.Unlock() ccb.mu.Lock() scs := ccb.subConns ccb.subConns = nil ccb.mu.Unlock() + ccb.UpdateState(balancer.State{ConnectivityState: connectivity.Connecting, Picker: nil}) + ccb.done.Fire() + // Fire done before removing the addr conns. We can safely unblock + // ccb.close and allow the removeAddrConns to happen + // asynchronously. for acbw := range scs { ccb.cc.removeAddrConn(acbw.getAddrConn(), errConnDrain) } - ccb.UpdateState(balancer.State{ConnectivityState: connectivity.Connecting, Picker: nil}) return } } } func (ccb *ccBalancerWrapper) close() { - ccb.done.Fire() + ccb.closed.Fire() + <-ccb.done.Done() +} + +func (ccb *ccBalancerWrapper) exitIdle() bool { + if !ccb.hasExitIdle { + return false + } + ccb.updateCh.Put(exitIdle{}) + return true } func (ccb *ccBalancerWrapper) handleSubConnStateChange(sc balancer.SubConn, s connectivity.State, err error) { @@ -109,7 +154,7 @@ func (ccb *ccBalancerWrapper) handleSubConnStateChange(sc balancer.SubConn, s co if sc == nil { return } - ccb.scBuffer.Put(&scStateUpdate{ + ccb.updateCh.Put(&scStateUpdate{ sc: sc, state: s, err: err, @@ -124,8 +169,8 @@ func (ccb *ccBalancerWrapper) updateClientConnState(ccs *balancer.ClientConnStat func (ccb *ccBalancerWrapper) resolverError(err error) { ccb.balancerMu.Lock() + defer ccb.balancerMu.Unlock() ccb.balancer.ResolverError(err) - ccb.balancerMu.Unlock() } func (ccb *ccBalancerWrapper) NewSubConn(addrs []resolver.Address, opts balancer.NewSubConnOptions) 
(balancer.SubConn, error) { @@ -150,17 +195,10 @@ func (ccb *ccBalancerWrapper) NewSubConn(addrs []resolver.Address, opts balancer } func (ccb *ccBalancerWrapper) RemoveSubConn(sc balancer.SubConn) { - acbw, ok := sc.(*acBalancerWrapper) - if !ok { - return - } - ccb.mu.Lock() - defer ccb.mu.Unlock() - if ccb.subConns == nil { - return - } - delete(ccb.subConns, acbw) - ccb.cc.removeAddrConn(acbw.getAddrConn(), errConnDrain) + // The RemoveSubConn() is handled in the run() goroutine, to avoid deadlock + // during switchBalancer() if the old balancer calls RemoveSubConn() in its + // Close(). + ccb.updateCh.Put(sc) } func (ccb *ccBalancerWrapper) UpdateAddresses(sc balancer.SubConn, addrs []resolver.Address) { @@ -205,7 +243,7 @@ func (acbw *acBalancerWrapper) UpdateAddresses(addrs []resolver.Address) { acbw.mu.Lock() defer acbw.mu.Unlock() if len(addrs) <= 0 { - acbw.ac.tearDown(errConnDrain) + acbw.ac.cc.removeAddrConn(acbw.ac, errConnDrain) return } if !acbw.ac.tryUpdateAddrs(addrs) { @@ -220,23 +258,23 @@ func (acbw *acBalancerWrapper) UpdateAddresses(addrs []resolver.Address) { acbw.ac.acbw = nil acbw.ac.mu.Unlock() acState := acbw.ac.getState() - acbw.ac.tearDown(errConnDrain) + acbw.ac.cc.removeAddrConn(acbw.ac, errConnDrain) if acState == connectivity.Shutdown { return } - ac, err := cc.newAddrConn(addrs, opts) + newAC, err := cc.newAddrConn(addrs, opts) if err != nil { channelz.Warningf(logger, acbw.ac.channelzID, "acBalancerWrapper: UpdateAddresses: failed to newAddrConn: %v", err) return } - acbw.ac = ac - ac.mu.Lock() - ac.acbw = acbw - ac.mu.Unlock() + acbw.ac = newAC + newAC.mu.Lock() + newAC.acbw = acbw + newAC.mu.Unlock() if acState != connectivity.Idle { - ac.connect() + go newAC.connect() } } } @@ -244,7 +282,7 @@ func (acbw *acBalancerWrapper) UpdateAddresses(addrs []resolver.Address) { func (acbw *acBalancerWrapper) Connect() { acbw.mu.Lock() defer acbw.mu.Unlock() - acbw.ac.connect() + go acbw.ac.connect() } func (acbw *acBalancerWrapper) 
getAddrConn() *addrConn { diff --git a/balancer_conn_wrappers_test.go b/balancer_conn_wrappers_test.go deleted file mode 100644 index 935d11d1d39..00000000000 --- a/balancer_conn_wrappers_test.go +++ /dev/null @@ -1,90 +0,0 @@ -/* - * - * Copyright 2019 gRPC authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - */ - -package grpc - -import ( - "fmt" - "net" - "testing" - - "google.golang.org/grpc/balancer" - "google.golang.org/grpc/balancer/roundrobin" - "google.golang.org/grpc/internal/balancer/stub" - "google.golang.org/grpc/resolver" - "google.golang.org/grpc/resolver/manual" -) - -// TestBalancerErrorResolverPolling injects balancer errors and verifies -// ResolveNow is called on the resolver with the appropriate backoff strategy -// being consulted between ResolveNow calls. -func (s) TestBalancerErrorResolverPolling(t *testing.T) { - // The test balancer will return ErrBadResolverState iff the - // ClientConnState contains no addresses. - bf := stub.BalancerFuncs{ - UpdateClientConnState: func(_ *stub.BalancerData, s balancer.ClientConnState) error { - if len(s.ResolverState.Addresses) == 0 { - return balancer.ErrBadResolverState - } - return nil - }, - } - const balName = "BalancerErrorResolverPolling" - stub.Register(balName, bf) - - testResolverErrorPolling(t, - func(r *manual.Resolver) { - // No addresses so the balancer will fail. 
- r.CC.UpdateState(resolver.State{}) - }, func(r *manual.Resolver) { - // UpdateState will block if ResolveNow is being called (which blocks on - // rn), so call it in a goroutine. Include some address so the balancer - // will be happy. - go r.CC.UpdateState(resolver.State{Addresses: []resolver.Address{{Addr: "x"}}}) - }, - WithDefaultServiceConfig(fmt.Sprintf(`{ "loadBalancingConfig": [{"%v": {}}] }`, balName))) -} - -// TestRoundRobinZeroAddressesResolverPolling reports no addresses to the round -// robin balancer and verifies ResolveNow is called on the resolver with the -// appropriate backoff strategy being consulted between ResolveNow calls. -func (s) TestRoundRobinZeroAddressesResolverPolling(t *testing.T) { - // We need to start a real server or else the connecting loop will call - // ResolveNow after every iteration, even after a valid resolver result is - // returned. - lis, err := net.Listen("tcp", "localhost:0") - if err != nil { - t.Fatalf("Error while listening. Err: %v", err) - } - defer lis.Close() - s := NewServer() - defer s.Stop() - go s.Serve(lis) - - testResolverErrorPolling(t, - func(r *manual.Resolver) { - // No addresses so the balancer will fail. - r.CC.UpdateState(resolver.State{}) - }, func(r *manual.Resolver) { - // UpdateState will block if ResolveNow is being called (which - // blocks on rn), so call it in a goroutine. Include a valid - // address so the balancer will be happy. 
- go r.CC.UpdateState(resolver.State{Addresses: []resolver.Address{{Addr: lis.Addr().String()}}}) - }, - WithDefaultServiceConfig(fmt.Sprintf(`{ "loadBalancingConfig": [{"%v": {}}] }`, roundrobin.Name))) -} diff --git a/balancer_switching_test.go b/balancer_switching_test.go index 2c6ed576620..5d9a1f9fffc 100644 --- a/balancer_switching_test.go +++ b/balancer_switching_test.go @@ -28,6 +28,7 @@ import ( "google.golang.org/grpc/balancer" "google.golang.org/grpc/balancer/roundrobin" "google.golang.org/grpc/internal" + "google.golang.org/grpc/internal/balancer/stub" "google.golang.org/grpc/resolver" "google.golang.org/grpc/resolver/manual" "google.golang.org/grpc/serviceconfig" @@ -57,6 +58,8 @@ func (b *magicalLB) UpdateClientConnState(balancer.ClientConnState) error { func (b *magicalLB) Close() {} +func (b *magicalLB) ExitIdle() {} + func init() { balancer.Register(&magicalLB{}) } @@ -531,6 +534,51 @@ func (s) TestSwitchBalancerGRPCLBWithGRPCLBNotRegistered(t *testing.T) { } } +const inlineRemoveSubConnBalancerName = "test-inline-remove-subconn-balancer" + +func init() { + stub.Register(inlineRemoveSubConnBalancerName, stub.BalancerFuncs{ + Close: func(data *stub.BalancerData) { + data.ClientConn.RemoveSubConn(&acBalancerWrapper{}) + }, + }) +} + +// Test that when switching to balancers, the old balancer calls RemoveSubConn +// in Close. +// +// This test is to make sure this close doesn't cause a deadlock. +func (s) TestSwitchBalancerOldRemoveSubConn(t *testing.T) { + r := manual.NewBuilderWithScheme("whatever") + cc, err := Dial(r.Scheme()+":///test.server", WithInsecure(), WithResolvers(r)) + if err != nil { + t.Fatalf("failed to dial: %v", err) + } + defer cc.Close() + cc.updateResolverState(resolver.State{ServiceConfig: parseCfg(r, fmt.Sprintf(`{"loadBalancingPolicy": "%v"}`, inlineRemoveSubConnBalancerName))}, nil) + // This service config update will switch balancer from + // "test-inline-remove-subconn-balancer" to "pick_first". 
The test balancer + // will be closed, which will call cc.RemoveSubConn() inline (this + // RemoveSubConn is not required by the API, but some balancers might do + // it). + // + // This is to make sure the cc.RemoveSubConn() from Close() doesn't cause a + // deadlock (e.g. trying to grab a mutex while it's already locked). + // + // Do it in a goroutine so this test will fail with a helpful message + // (though the goroutine will still leak). + done := make(chan struct{}) + go func() { + cc.updateResolverState(resolver.State{ServiceConfig: parseCfg(r, `{"loadBalancingPolicy": "pick_first"}`)}, nil) + close(done) + }() + select { + case <-time.After(defaultTestTimeout): + t.Fatalf("timeout waiting for updateResolverState to finish") + case <-done: + } +} + func parseCfg(r *manual.Resolver, s string) *serviceconfig.ParseResult { scpr := r.CC.ParseServiceConfig(s) if scpr.Err != nil { diff --git a/benchmark/stats/histogram.go b/benchmark/stats/histogram.go index f038d26ed0a..461135f0125 100644 --- a/benchmark/stats/histogram.go +++ b/benchmark/stats/histogram.go @@ -118,10 +118,6 @@ func (h *Histogram) PrintWithUnit(w io.Writer, unit float64) { } maxBucketDigitLen := len(strconv.FormatFloat(h.Buckets[len(h.Buckets)-1].LowBound, 'f', 6, 64)) - if maxBucketDigitLen < 3 { - // For "inf". 
- maxBucketDigitLen = 3 - } maxCountDigitLen := len(strconv.FormatInt(h.Count, 10)) percentMulti := 100 / float64(h.Count) @@ -131,9 +127,9 @@ func (h *Histogram) PrintWithUnit(w io.Writer, unit float64) { if i+1 < len(h.Buckets) { fmt.Fprintf(w, "%*f)", maxBucketDigitLen, h.Buckets[i+1].LowBound/unit) } else { - fmt.Fprintf(w, "%*s)", maxBucketDigitLen, "inf") + upperBound := float64(h.opts.MinValue) + (b.LowBound-float64(h.opts.MinValue))*(1.0+h.opts.GrowthFactor) + fmt.Fprintf(w, "%*f)", maxBucketDigitLen, upperBound/unit) } - accCount += b.Count fmt.Fprintf(w, " %*d %5.1f%% %5.1f%%", maxCountDigitLen, b.Count, float64(b.Count)*percentMulti, float64(accCount)*percentMulti) @@ -188,6 +184,9 @@ func (h *Histogram) Add(value int64) error { func (h *Histogram) findBucket(value int64) (int, error) { delta := float64(value - h.opts.MinValue) + if delta < 0 { + return 0, fmt.Errorf("no bucket for value: %d", value) + } var b int if delta >= h.opts.BaseBucketSize { // b = log_{1+growthFactor} (delta / baseBucketSize) + 1 diff --git a/call_test.go b/call_test.go index abc4537ddb7..8fdbc9b7eb7 100644 --- a/call_test.go +++ b/call_test.go @@ -160,7 +160,7 @@ func (s *server) start(t *testing.T, port int, maxStreams uint32) { config := &transport.ServerConfig{ MaxStreams: maxStreams, } - st, err := transport.NewServerTransport("http2", conn, config) + st, err := transport.NewServerTransport(conn, config) if err != nil { continue } diff --git a/channelz/grpc_channelz_v1/channelz_grpc.pb.go b/channelz/grpc_channelz_v1/channelz_grpc.pb.go index 051d1ac440c..ee425c21994 100644 --- a/channelz/grpc_channelz_v1/channelz_grpc.pb.go +++ b/channelz/grpc_channelz_v1/channelz_grpc.pb.go @@ -1,4 +1,8 @@ // Code generated by protoc-gen-go-grpc. DO NOT EDIT. 
+// versions: +// - protoc-gen-go-grpc v1.1.0 +// - protoc v3.14.0 +// source: grpc/channelz/v1/channelz.proto package grpc_channelz_v1 diff --git a/channelz/service/func_linux.go b/channelz/service/func_linux.go index ce38a921b97..2e52d5f5a98 100644 --- a/channelz/service/func_linux.go +++ b/channelz/service/func_linux.go @@ -25,6 +25,7 @@ import ( durpb "github.com/golang/protobuf/ptypes/duration" channelzpb "google.golang.org/grpc/channelz/grpc_channelz_v1" "google.golang.org/grpc/internal/channelz" + "google.golang.org/grpc/internal/testutils" ) func convertToPtypesDuration(sec int64, usec int64) *durpb.Duration { @@ -34,41 +35,32 @@ func convertToPtypesDuration(sec int64, usec int64) *durpb.Duration { func sockoptToProto(skopts *channelz.SocketOptionData) []*channelzpb.SocketOption { var opts []*channelzpb.SocketOption if skopts.Linger != nil { - additional, err := ptypes.MarshalAny(&channelzpb.SocketOptionLinger{ - Active: skopts.Linger.Onoff != 0, - Duration: convertToPtypesDuration(int64(skopts.Linger.Linger), 0), + opts = append(opts, &channelzpb.SocketOption{ + Name: "SO_LINGER", + Additional: testutils.MarshalAny(&channelzpb.SocketOptionLinger{ + Active: skopts.Linger.Onoff != 0, + Duration: convertToPtypesDuration(int64(skopts.Linger.Linger), 0), + }), }) - if err == nil { - opts = append(opts, &channelzpb.SocketOption{ - Name: "SO_LINGER", - Additional: additional, - }) - } } if skopts.RecvTimeout != nil { - additional, err := ptypes.MarshalAny(&channelzpb.SocketOptionTimeout{ - Duration: convertToPtypesDuration(int64(skopts.RecvTimeout.Sec), int64(skopts.RecvTimeout.Usec)), + opts = append(opts, &channelzpb.SocketOption{ + Name: "SO_RCVTIMEO", + Additional: testutils.MarshalAny(&channelzpb.SocketOptionTimeout{ + Duration: convertToPtypesDuration(int64(skopts.RecvTimeout.Sec), int64(skopts.RecvTimeout.Usec)), + }), }) - if err == nil { - opts = append(opts, &channelzpb.SocketOption{ - Name: "SO_RCVTIMEO", - Additional: additional, - }) - } } if 
skopts.SendTimeout != nil { - additional, err := ptypes.MarshalAny(&channelzpb.SocketOptionTimeout{ - Duration: convertToPtypesDuration(int64(skopts.SendTimeout.Sec), int64(skopts.SendTimeout.Usec)), + opts = append(opts, &channelzpb.SocketOption{ + Name: "SO_SNDTIMEO", + Additional: testutils.MarshalAny(&channelzpb.SocketOptionTimeout{ + Duration: convertToPtypesDuration(int64(skopts.SendTimeout.Sec), int64(skopts.SendTimeout.Usec)), + }), }) - if err == nil { - opts = append(opts, &channelzpb.SocketOption{ - Name: "SO_SNDTIMEO", - Additional: additional, - }) - } } if skopts.TCPInfo != nil { - additional, err := ptypes.MarshalAny(&channelzpb.SocketOptionTcpInfo{ + additional := testutils.MarshalAny(&channelzpb.SocketOptionTcpInfo{ TcpiState: uint32(skopts.TCPInfo.State), TcpiCaState: uint32(skopts.TCPInfo.Ca_state), TcpiRetransmits: uint32(skopts.TCPInfo.Retransmits), @@ -99,12 +91,10 @@ func sockoptToProto(skopts *channelz.SocketOptionData) []*channelzpb.SocketOptio TcpiAdvmss: skopts.TCPInfo.Advmss, TcpiReordering: skopts.TCPInfo.Reordering, }) - if err == nil { - opts = append(opts, &channelzpb.SocketOption{ - Name: "TCP_INFO", - Additional: additional, - }) - } + opts = append(opts, &channelzpb.SocketOption{ + Name: "TCP_INFO", + Additional: additional, + }) } return opts } diff --git a/channelz/service/func_nonlinux.go b/channelz/service/func_nonlinux.go index eb53334ed0d..473495d6655 100644 --- a/channelz/service/func_nonlinux.go +++ b/channelz/service/func_nonlinux.go @@ -1,4 +1,5 @@ -// +build !linux appengine +//go:build !linux +// +build !linux /* * diff --git a/channelz/service/service.go b/channelz/service/service.go index 4d175fef823..9e325376f6c 100644 --- a/channelz/service/service.go +++ b/channelz/service/service.go @@ -43,7 +43,11 @@ func init() { var logger = grpclog.Component("channelz") // RegisterChannelzServiceToServer registers the channelz service to the given server. 
-func RegisterChannelzServiceToServer(s *grpc.Server) { +// +// Note: it is preferred to use the admin API +// (https://pkg.go.dev/google.golang.org/grpc/admin#Register) instead to +// register Channelz and other administrative services. +func RegisterChannelzServiceToServer(s grpc.ServiceRegistrar) { channelzgrpc.RegisterChannelzServer(s, newCZServer()) } @@ -78,7 +82,7 @@ func channelTraceToProto(ct *channelz.ChannelTrace) *channelzpb.ChannelTrace { if ts, err := ptypes.TimestampProto(ct.CreationTime); err == nil { pbt.CreationTimestamp = ts } - var events []*channelzpb.ChannelTraceEvent + events := make([]*channelzpb.ChannelTraceEvent, 0, len(ct.Events)) for _, e := range ct.Events { cte := &channelzpb.ChannelTraceEvent{ Description: e.Desc, diff --git a/channelz/service/service_sktopt_test.go b/channelz/service/service_sktopt_test.go index ecd4a2ad05f..4ea6b20cd6a 100644 --- a/channelz/service/service_sktopt_test.go +++ b/channelz/service/service_sktopt_test.go @@ -1,3 +1,4 @@ +//go:build linux && (386 || amd64) // +build linux // +build 386 amd64 diff --git a/channelz/service/util_sktopt_386_test.go b/channelz/service/util_sktopt_386_test.go index d9c98127136..3ba3dc96e7c 100644 --- a/channelz/service/util_sktopt_386_test.go +++ b/channelz/service/util_sktopt_386_test.go @@ -1,3 +1,4 @@ +//go:build 386 && linux // +build 386,linux /* diff --git a/channelz/service/util_sktopt_amd64_test.go b/channelz/service/util_sktopt_amd64_test.go index 0ff06d12833..124d7b75819 100644 --- a/channelz/service/util_sktopt_amd64_test.go +++ b/channelz/service/util_sktopt_amd64_test.go @@ -1,3 +1,4 @@ +//go:build amd64 && linux // +build amd64,linux /* diff --git a/clientconn.go b/clientconn.go index 77a08fd33bf..34cc4c948db 100644 --- a/clientconn.go +++ b/clientconn.go @@ -143,6 +143,7 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn * firstResolveEvent: grpcsync.NewEvent(), } cc.retryThrottler.Store((*retryThrottler)(nil)) + 
cc.safeConfigSelector.UpdateConfigSelector(&defaultConfigSelector{nil}) cc.ctx, cc.cancel = context.WithCancel(context.Background()) for _, opt := range opts { @@ -321,6 +322,7 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn * // A blocking dial blocks until the clientConn is ready. if cc.dopts.block { for { + cc.Connect() s := cc.GetState() if s == connectivity.Ready { break @@ -538,12 +540,31 @@ func (cc *ClientConn) WaitForStateChange(ctx context.Context, sourceState connec // // Experimental // -// Notice: This API is EXPERIMENTAL and may be changed or removed in a -// later release. +// Notice: This API is EXPERIMENTAL and may be changed or removed in a later +// release. func (cc *ClientConn) GetState() connectivity.State { return cc.csMgr.getState() } +// Connect causes all subchannels in the ClientConn to attempt to connect if +// the channel is idle. Does not wait for the connection attempts to begin +// before returning. +// +// Experimental +// +// Notice: This API is EXPERIMENTAL and may be changed or removed in a later +// release. +func (cc *ClientConn) Connect() { + cc.mu.Lock() + defer cc.mu.Unlock() + if cc.balancerWrapper != nil && cc.balancerWrapper.exitIdle() { + return + } + for ac := range cc.conns { + go ac.connect() + } +} + func (cc *ClientConn) scWatcher() { for { select { @@ -710,7 +731,12 @@ func (cc *ClientConn) switchBalancer(name string) { return } if cc.balancerWrapper != nil { + // Don't hold cc.mu while closing the balancers. The balancers may call + // methods that require cc.mu (e.g. cc.NewSubConn()). Holding the mutex + // would cause a deadlock in that case. + cc.mu.Unlock() cc.balancerWrapper.close() + cc.mu.Lock() } builder := balancer.Get(name) @@ -839,8 +865,7 @@ func (ac *addrConn) connect() error { ac.updateConnectivityState(connectivity.Connecting, nil) ac.mu.Unlock() - // Start a goroutine connecting to the server asynchronously. 
- go ac.resetTransport() + ac.resetTransport() return nil } @@ -877,6 +902,10 @@ func (ac *addrConn) tryUpdateAddrs(addrs []resolver.Address) bool { // ac.state is Ready, try to find the connected address. var curAddrFound bool for _, a := range addrs { + // a.ServerName takes precedent over ClientConn authority, if present. + if a.ServerName == "" { + a.ServerName = ac.cc.authority + } if reflect.DeepEqual(ac.curAddr, a) { curAddrFound = true break @@ -1045,12 +1074,12 @@ func (cc *ClientConn) Close() error { cc.blockingpicker.close() - if rWrapper != nil { - rWrapper.close() - } if bWrapper != nil { bWrapper.close() } + if rWrapper != nil { + rWrapper.close() + } for ac := range conns { ac.tearDown(ErrClientConnClosing) @@ -1129,112 +1158,86 @@ func (ac *addrConn) adjustParams(r transport.GoAwayReason) { } func (ac *addrConn) resetTransport() { - for i := 0; ; i++ { - if i > 0 { - ac.cc.resolveNow(resolver.ResolveNowOptions{}) - } + ac.mu.Lock() + if ac.state == connectivity.Shutdown { + ac.mu.Unlock() + return + } + addrs := ac.addrs + backoffFor := ac.dopts.bs.Backoff(ac.backoffIdx) + // This will be the duration that dial gets to finish. + dialDuration := minConnectTimeout + if ac.dopts.minConnectTimeout != nil { + dialDuration = ac.dopts.minConnectTimeout() + } + + if dialDuration < backoffFor { + // Give dial more time as we keep failing to connect. + dialDuration = backoffFor + } + // We can potentially spend all the time trying the first address, and + // if the server accepts the connection and then hangs, the following + // addresses will never be tried. + // + // The spec doesn't mention what should be done for multiple addresses. 
+ // https://github.com/grpc/grpc/blob/master/doc/connection-backoff.md#proposed-backoff-algorithm + connectDeadline := time.Now().Add(dialDuration) + + ac.updateConnectivityState(connectivity.Connecting, nil) + ac.mu.Unlock() + + if err := ac.tryAllAddrs(addrs, connectDeadline); err != nil { + ac.cc.resolveNow(resolver.ResolveNowOptions{}) + // After exhausting all addresses, the addrConn enters + // TRANSIENT_FAILURE. ac.mu.Lock() if ac.state == connectivity.Shutdown { ac.mu.Unlock() return } + ac.updateConnectivityState(connectivity.TransientFailure, err) - addrs := ac.addrs - backoffFor := ac.dopts.bs.Backoff(ac.backoffIdx) - // This will be the duration that dial gets to finish. - dialDuration := minConnectTimeout - if ac.dopts.minConnectTimeout != nil { - dialDuration = ac.dopts.minConnectTimeout() - } - - if dialDuration < backoffFor { - // Give dial more time as we keep failing to connect. - dialDuration = backoffFor - } - // We can potentially spend all the time trying the first address, and - // if the server accepts the connection and then hangs, the following - // addresses will never be tried. - // - // The spec doesn't mention what should be done for multiple addresses. - // https://github.com/grpc/grpc/blob/master/doc/connection-backoff.md#proposed-backoff-algorithm - connectDeadline := time.Now().Add(dialDuration) - - ac.updateConnectivityState(connectivity.Connecting, nil) - ac.transport = nil + // Backoff. + b := ac.resetBackoff ac.mu.Unlock() - newTr, addr, reconnect, err := ac.tryAllAddrs(addrs, connectDeadline) - if err != nil { - // After exhausting all addresses, the addrConn enters - // TRANSIENT_FAILURE. + timer := time.NewTimer(backoffFor) + select { + case <-timer.C: ac.mu.Lock() - if ac.state == connectivity.Shutdown { - ac.mu.Unlock() - return - } - ac.updateConnectivityState(connectivity.TransientFailure, err) - - // Backoff. 
- b := ac.resetBackoff + ac.backoffIdx++ ac.mu.Unlock() - - timer := time.NewTimer(backoffFor) - select { - case <-timer.C: - ac.mu.Lock() - ac.backoffIdx++ - ac.mu.Unlock() - case <-b: - timer.Stop() - case <-ac.ctx.Done(): - timer.Stop() - return - } - continue + case <-b: + timer.Stop() + case <-ac.ctx.Done(): + timer.Stop() + return } ac.mu.Lock() - if ac.state == connectivity.Shutdown { - ac.mu.Unlock() - newTr.Close() - return + if ac.state != connectivity.Shutdown { + ac.updateConnectivityState(connectivity.Idle, err) } - ac.curAddr = addr - ac.transport = newTr - ac.backoffIdx = 0 - - hctx, hcancel := context.WithCancel(ac.ctx) - ac.startHealthCheck(hctx) ac.mu.Unlock() - - // Block until the created transport is down. And when this happens, - // we restart from the top of the addr list. - <-reconnect.Done() - hcancel() - // restart connecting - the top of the loop will set state to - // CONNECTING. This is against the current connectivity semantics doc, - // however it allows for graceful behavior for RPCs not yet dispatched - // - unfortunate timing would otherwise lead to the RPC failing even - // though the TRANSIENT_FAILURE state (called for by the doc) would be - // instantaneous. - // - // Ideally we should transition to Idle here and block until there is - // RPC activity that leads to the balancer requesting a reconnect of - // the associated SubConn. + return } + // Success; reset backoff. + ac.mu.Lock() + ac.backoffIdx = 0 + ac.mu.Unlock() } -// tryAllAddrs tries to creates a connection to the addresses, and stop when at the -// first successful one. It returns the transport, the address and a Event in -// the successful case. The Event fires when the returned transport disconnects. -func (ac *addrConn) tryAllAddrs(addrs []resolver.Address, connectDeadline time.Time) (transport.ClientTransport, resolver.Address, *grpcsync.Event, error) { +// tryAllAddrs tries to creates a connection to the addresses, and stop when at +// the first successful one. 
It returns an error if no address was successfully +// connected, or updates ac appropriately with the new transport. +func (ac *addrConn) tryAllAddrs(addrs []resolver.Address, connectDeadline time.Time) error { var firstConnErr error for _, addr := range addrs { ac.mu.Lock() if ac.state == connectivity.Shutdown { ac.mu.Unlock() - return nil, resolver.Address{}, nil, errConnClosing + return errConnClosing } ac.cc.mu.RLock() @@ -1249,9 +1252,9 @@ func (ac *addrConn) tryAllAddrs(addrs []resolver.Address, connectDeadline time.T channelz.Infof(logger, ac.channelzID, "Subchannel picks a new address %q to connect", addr.Addr) - newTr, reconnect, err := ac.createTransport(addr, copts, connectDeadline) + err := ac.createTransport(addr, copts, connectDeadline) if err == nil { - return newTr, addr, reconnect, nil + return nil } if firstConnErr == nil { firstConnErr = err @@ -1260,57 +1263,54 @@ func (ac *addrConn) tryAllAddrs(addrs []resolver.Address, connectDeadline time.T } // Couldn't connect to any address. - return nil, resolver.Address{}, nil, firstConnErr + return firstConnErr } -// createTransport creates a connection to addr. It returns the transport and a -// Event in the successful case. The Event fires when the returned transport -// disconnects. -func (ac *addrConn) createTransport(addr resolver.Address, copts transport.ConnectOptions, connectDeadline time.Time) (transport.ClientTransport, *grpcsync.Event, error) { - prefaceReceived := make(chan struct{}) - onCloseCalled := make(chan struct{}) - reconnect := grpcsync.NewEvent() +// createTransport creates a connection to addr. It returns an error if the +// address was not successfully connected, or updates ac appropriately with the +// new transport. +func (ac *addrConn) createTransport(addr resolver.Address, copts transport.ConnectOptions, connectDeadline time.Time) error { + // TODO: Delete prefaceReceived and move the logic to wait for it into the + // transport. 
+ prefaceReceived := grpcsync.NewEvent() + connClosed := grpcsync.NewEvent() // addr.ServerName takes precedent over ClientConn authority, if present. if addr.ServerName == "" { addr.ServerName = ac.cc.authority } - once := sync.Once{} - onGoAway := func(r transport.GoAwayReason) { - ac.mu.Lock() - ac.adjustParams(r) - once.Do(func() { - if ac.state == connectivity.Ready { - // Prevent this SubConn from being used for new RPCs by setting its - // state to Connecting. - // - // TODO: this should be Idle when grpc-go properly supports it. - ac.updateConnectivityState(connectivity.Connecting, nil) - } - }) - ac.mu.Unlock() - reconnect.Fire() - } + hctx, hcancel := context.WithCancel(ac.ctx) + hcStarted := false // protected by ac.mu onClose := func() { ac.mu.Lock() - once.Do(func() { - if ac.state == connectivity.Ready { - // Prevent this SubConn from being used for new RPCs by setting its - // state to Connecting. - // - // TODO: this should be Idle when grpc-go properly supports it. - ac.updateConnectivityState(connectivity.Connecting, nil) - } - }) - ac.mu.Unlock() - close(onCloseCalled) - reconnect.Fire() + defer ac.mu.Unlock() + defer connClosed.Fire() + if !hcStarted || hctx.Err() != nil { + // We didn't start the health check or set the state to READY, so + // no need to do anything else here. + // + // OR, we have already cancelled the health check context, meaning + // we have already called onClose once for this transport. In this + // case it would be dangerous to clear the transport and update the + // state, since there may be a new transport in this addrConn. 
+ return + } + hcancel() + ac.transport = nil + // Refresh the name resolver + ac.cc.resolveNow(resolver.ResolveNowOptions{}) + if ac.state != connectivity.Shutdown { + ac.updateConnectivityState(connectivity.Idle, nil) + } } - onPrefaceReceipt := func() { - close(prefaceReceived) + onGoAway := func(r transport.GoAwayReason) { + ac.mu.Lock() + ac.adjustParams(r) + ac.mu.Unlock() + onClose() } connectCtx, cancel := context.WithDeadline(ac.ctx, connectDeadline) @@ -1319,27 +1319,67 @@ func (ac *addrConn) createTransport(addr resolver.Address, copts transport.Conne copts.ChannelzParentID = ac.channelzID } - newTr, err := transport.NewClientTransport(connectCtx, ac.cc.ctx, addr, copts, onPrefaceReceipt, onGoAway, onClose) + newTr, err := transport.NewClientTransport(connectCtx, ac.cc.ctx, addr, copts, func() { prefaceReceived.Fire() }, onGoAway, onClose) if err != nil { // newTr is either nil, or closed. - channelz.Warningf(logger, ac.channelzID, "grpc: addrConn.createTransport failed to connect to %v. Err: %v. Reconnecting...", addr, err) - return nil, nil, err + channelz.Warningf(logger, ac.channelzID, "grpc: addrConn.createTransport failed to connect to %v. Err: %v", addr, err) + return err } select { - case <-time.After(time.Until(connectDeadline)): + case <-connectCtx.Done(): // We didn't get the preface in time. - newTr.Close() - channelz.Warningf(logger, ac.channelzID, "grpc: addrConn.createTransport failed to connect to %v: didn't receive server preface in time. Reconnecting...", addr) - return nil, nil, errors.New("timed out waiting for server handshake") - case <-prefaceReceived: + // The error we pass to Close() is immaterial since there are no open + // streams at this point, so no trailers with error details will be sent + // out. We just need to pass a non-nil error. 
+ newTr.Close(transport.ErrConnClosing) + if connectCtx.Err() == context.DeadlineExceeded { + err := errors.New("failed to receive server preface within timeout") + channelz.Warningf(logger, ac.channelzID, "grpc: addrConn.createTransport failed to connect to %v: %v", addr, err) + return err + } + return nil + case <-prefaceReceived.Done(): // We got the preface - huzzah! things are good. - case <-onCloseCalled: - // The transport has already closed - noop. - return nil, nil, errors.New("connection closed") - // TODO(deklerk) this should bail on ac.ctx.Done(). Add a test and fix. + ac.mu.Lock() + defer ac.mu.Unlock() + if connClosed.HasFired() { + // onClose called first; go idle but do nothing else. + if ac.state != connectivity.Shutdown { + ac.updateConnectivityState(connectivity.Idle, nil) + } + return nil + } + if ac.state == connectivity.Shutdown { + // This can happen if the subConn was removed while in `Connecting` + // state. tearDown() would have set the state to `Shutdown`, but + // would not have closed the transport since ac.transport would not + // been set at that point. + // + // We run this in a goroutine because newTr.Close() calls onClose() + // inline, which requires locking ac.mu. + // + // The error we pass to Close() is immaterial since there are no open + // streams at this point, so no trailers with error details will be sent + // out. We just need to pass a non-nil error. + go newTr.Close(transport.ErrConnClosing) + return nil + } + ac.curAddr = addr + ac.transport = newTr + hcStarted = true + ac.startHealthCheck(hctx) // Will set state to READY if appropriate. + return nil + case <-connClosed.Done(): + // The transport has already closed. If we received the preface, too, + // this is not an error. 
+ select { + case <-prefaceReceived.Done(): + return nil + default: + return errors.New("connection closed before server preface received") + } } - return newTr, reconnect, nil } // startHealthCheck starts the health checking stream (RPC) to watch the health @@ -1423,33 +1463,20 @@ func (ac *addrConn) resetConnectBackoff() { ac.mu.Unlock() } -// getReadyTransport returns the transport if ac's state is READY. -// Otherwise it returns nil, false. -// If ac's state is IDLE, it will trigger ac to connect. -func (ac *addrConn) getReadyTransport() (transport.ClientTransport, bool) { +// getReadyTransport returns the transport if ac's state is READY or nil if not. +func (ac *addrConn) getReadyTransport() transport.ClientTransport { ac.mu.Lock() - if ac.state == connectivity.Ready && ac.transport != nil { - t := ac.transport - ac.mu.Unlock() - return t, true - } - var idle bool - if ac.state == connectivity.Idle { - idle = true - } - ac.mu.Unlock() - // Trigger idle ac to connect. - if idle { - ac.connect() + defer ac.mu.Unlock() + if ac.state == connectivity.Ready { + return ac.transport } - return nil, false + return nil } // tearDown starts to tear down the addrConn. -// TODO(zhaoq): Make this synchronous to avoid unbounded memory consumption in -// some edge cases (e.g., the caller opens and closes many addrConn's in a -// tight loop. -// tearDown doesn't remove ac from ac.cc.conns. +// +// Note that tearDown doesn't remove ac from ac.cc.conns, so the addrConn struct +// will leak. In most cases, call cc.removeAddrConn() instead. func (ac *addrConn) tearDown(err error) { ac.mu.Lock() if ac.state == connectivity.Shutdown { diff --git a/clientconn_authority_test.go b/clientconn_authority_test.go new file mode 100644 index 00000000000..5cd705e2d4f --- /dev/null +++ b/clientconn_authority_test.go @@ -0,0 +1,122 @@ +/* + * + * Copyright 2021 gRPC authors. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package grpc + +import ( + "context" + "net" + "testing" + + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/testdata" +) + +func (s) TestClientConnAuthority(t *testing.T) { + serverNameOverride := "over.write.server.name" + creds, err := credentials.NewClientTLSFromFile(testdata.Path("x509/server_ca_cert.pem"), serverNameOverride) + if err != nil { + t.Fatalf("credentials.NewClientTLSFromFile(_, %q) failed: %v", err, serverNameOverride) + } + + tests := []struct { + name string + target string + opts []DialOption + wantAuthority string + }{ + { + name: "default", + target: "Non-Existent.Server:8080", + opts: []DialOption{WithInsecure()}, + wantAuthority: "Non-Existent.Server:8080", + }, + { + name: "override-via-creds", + target: "Non-Existent.Server:8080", + opts: []DialOption{WithTransportCredentials(creds)}, + wantAuthority: serverNameOverride, + }, + { + name: "override-via-WithAuthority", + target: "Non-Existent.Server:8080", + opts: []DialOption{WithInsecure(), WithAuthority("authority-override")}, + wantAuthority: "authority-override", + }, + { + name: "override-via-creds-and-WithAuthority", + target: "Non-Existent.Server:8080", + // WithAuthority override works only for insecure creds. 
+ opts: []DialOption{WithTransportCredentials(creds), WithAuthority("authority-override")}, + wantAuthority: serverNameOverride, + }, + { + name: "unix relative", + target: "unix:sock.sock", + opts: []DialOption{WithInsecure()}, + wantAuthority: "localhost", + }, + { + name: "unix relative with custom dialer", + target: "unix:sock.sock", + opts: []DialOption{WithInsecure(), WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) { + return (&net.Dialer{}).DialContext(ctx, "", addr) + })}, + wantAuthority: "localhost", + }, + { + name: "unix absolute", + target: "unix:/sock.sock", + opts: []DialOption{WithInsecure()}, + wantAuthority: "localhost", + }, + { + name: "unix absolute with custom dialer", + target: "unix:///sock.sock", + opts: []DialOption{WithInsecure(), WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) { + return (&net.Dialer{}).DialContext(ctx, "", addr) + })}, + wantAuthority: "localhost", + }, + { + name: "localhost colon port", + target: "localhost:50051", + opts: []DialOption{WithInsecure()}, + wantAuthority: "localhost:50051", + }, + { + name: "colon port", + target: ":50051", + opts: []DialOption{WithInsecure()}, + wantAuthority: "localhost:50051", + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + cc, err := Dial(test.target, test.opts...) + if err != nil { + t.Fatalf("Dial(%q) failed: %v", test.target, err) + } + defer cc.Close() + if cc.authority != test.wantAuthority { + t.Fatalf("cc.authority = %q, want %q", cc.authority, test.wantAuthority) + } + }) + } +} diff --git a/clientconn_parsed_target_test.go b/clientconn_parsed_target_test.go new file mode 100644 index 00000000000..fda06f9fa14 --- /dev/null +++ b/clientconn_parsed_target_test.go @@ -0,0 +1,183 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package grpc + +import ( + "context" + "errors" + "net" + "testing" + "time" + + "google.golang.org/grpc/resolver" +) + +func (s) TestParsedTarget_Success_WithoutCustomDialer(t *testing.T) { + defScheme := resolver.GetDefaultScheme() + tests := []struct { + target string + wantParsed resolver.Target + }{ + // No scheme is specified. + {target: "", wantParsed: resolver.Target{Scheme: defScheme, Authority: "", Endpoint: ""}}, + {target: "://", wantParsed: resolver.Target{Scheme: defScheme, Authority: "", Endpoint: "://"}}, + {target: ":///", wantParsed: resolver.Target{Scheme: defScheme, Authority: "", Endpoint: ":///"}}, + {target: "://a/", wantParsed: resolver.Target{Scheme: defScheme, Authority: "", Endpoint: "://a/"}}, + {target: ":///a", wantParsed: resolver.Target{Scheme: defScheme, Authority: "", Endpoint: ":///a"}}, + {target: "://a/b", wantParsed: resolver.Target{Scheme: defScheme, Authority: "", Endpoint: "://a/b"}}, + {target: "/", wantParsed: resolver.Target{Scheme: defScheme, Authority: "", Endpoint: "/"}}, + {target: "a/b", wantParsed: resolver.Target{Scheme: defScheme, Authority: "", Endpoint: "a/b"}}, + {target: "a//b", wantParsed: resolver.Target{Scheme: defScheme, Authority: "", Endpoint: "a//b"}}, + {target: "google.com", wantParsed: resolver.Target{Scheme: defScheme, Authority: "", Endpoint: "google.com"}}, + {target: "google.com/?a=b", wantParsed: resolver.Target{Scheme: defScheme, Authority: "", Endpoint: "google.com/?a=b"}}, + {target: "/unix/socket/address", wantParsed: resolver.Target{Scheme: defScheme, Authority: 
"", Endpoint: "/unix/socket/address"}}, + + // An unregistered scheme is specified. + {target: "a:///", wantParsed: resolver.Target{Scheme: defScheme, Authority: "", Endpoint: "a:///"}}, + {target: "a://b/", wantParsed: resolver.Target{Scheme: defScheme, Authority: "", Endpoint: "a://b/"}}, + {target: "a:///b", wantParsed: resolver.Target{Scheme: defScheme, Authority: "", Endpoint: "a:///b"}}, + {target: "a://b/c", wantParsed: resolver.Target{Scheme: defScheme, Authority: "", Endpoint: "a://b/c"}}, + {target: "a:b", wantParsed: resolver.Target{Scheme: defScheme, Authority: "", Endpoint: "a:b"}}, + {target: "a:/b", wantParsed: resolver.Target{Scheme: defScheme, Authority: "", Endpoint: "a:/b"}}, + {target: "a://b", wantParsed: resolver.Target{Scheme: defScheme, Authority: "", Endpoint: "a://b"}}, + + // A registered scheme is specified. + {target: "dns:///google.com", wantParsed: resolver.Target{Scheme: "dns", Authority: "", Endpoint: "google.com"}}, + {target: "dns://a.server.com/google.com", wantParsed: resolver.Target{Scheme: "dns", Authority: "a.server.com", Endpoint: "google.com"}}, + {target: "dns://a.server.com/google.com/?a=b", wantParsed: resolver.Target{Scheme: "dns", Authority: "a.server.com", Endpoint: "google.com/?a=b"}}, + {target: "unix:///a/b/c", wantParsed: resolver.Target{Scheme: "unix", Authority: "", Endpoint: "/a/b/c"}}, + {target: "unix-abstract:a/b/c", wantParsed: resolver.Target{Scheme: "unix-abstract", Authority: "", Endpoint: "a/b/c"}}, + {target: "unix-abstract:a b", wantParsed: resolver.Target{Scheme: "unix-abstract", Authority: "", Endpoint: "a b"}}, + {target: "unix-abstract:a:b", wantParsed: resolver.Target{Scheme: "unix-abstract", Authority: "", Endpoint: "a:b"}}, + {target: "unix-abstract:a-b", wantParsed: resolver.Target{Scheme: "unix-abstract", Authority: "", Endpoint: "a-b"}}, + {target: "unix-abstract:/ a///://::!@#$%^&*()b", wantParsed: resolver.Target{Scheme: "unix-abstract", Authority: "", Endpoint: "/ a///://::!@#$%^&*()b"}}, 
+ {target: "unix-abstract:passthrough:abc", wantParsed: resolver.Target{Scheme: "unix-abstract", Authority: "", Endpoint: "passthrough:abc"}}, + {target: "unix-abstract:unix:///abc", wantParsed: resolver.Target{Scheme: "unix-abstract", Authority: "", Endpoint: "unix:///abc"}}, + {target: "unix-abstract:///a/b/c", wantParsed: resolver.Target{Scheme: "unix-abstract", Authority: "", Endpoint: "/a/b/c"}}, + {target: "unix-abstract:///", wantParsed: resolver.Target{Scheme: "unix-abstract", Authority: "", Endpoint: "/"}}, + {target: "unix-abstract://authority", wantParsed: resolver.Target{Scheme: "unix-abstract", Authority: "", Endpoint: "//authority"}}, + {target: "unix://domain", wantParsed: resolver.Target{Scheme: defScheme, Authority: "", Endpoint: "unix://domain"}}, + {target: "passthrough:///unix:///a/b/c", wantParsed: resolver.Target{Scheme: "passthrough", Authority: "", Endpoint: "unix:///a/b/c"}}, + } + + for _, test := range tests { + t.Run(test.target, func(t *testing.T) { + cc, err := Dial(test.target, WithInsecure()) + if err != nil { + t.Fatalf("Dial(%q) failed: %v", test.target, err) + } + defer cc.Close() + + if gotParsed := cc.parsedTarget; gotParsed != test.wantParsed { + t.Errorf("cc.parsedTarget = %+v, want %+v", gotParsed, test.wantParsed) + } + }) + } +} + +func (s) TestParsedTarget_Failure_WithoutCustomDialer(t *testing.T) { + targets := []string{ + "unix://a/b/c", + "unix-abstract://authority/a/b/c", + } + + for _, target := range targets { + t.Run(target, func(t *testing.T) { + if cc, err := Dial(target, WithInsecure()); err == nil { + defer cc.Close() + t.Fatalf("Dial(%q) succeeded cc.parsedTarget = %+v, expected to fail", target, cc.parsedTarget) + } + }) + } +} + +func (s) TestParsedTarget_WithCustomDialer(t *testing.T) { + defScheme := resolver.GetDefaultScheme() + tests := []struct { + target string + wantParsed resolver.Target + wantDialerAddress string + }{ + // unix:[local_path], unix:[/absolute], and unix://[/absolute] have + // 
different behaviors with a custom dialer. + { + target: "unix:a/b/c", + wantParsed: resolver.Target{Scheme: defScheme, Authority: "", Endpoint: "unix:a/b/c"}, + wantDialerAddress: "unix:a/b/c", + }, + { + target: "unix:/a/b/c", + wantParsed: resolver.Target{Scheme: defScheme, Authority: "", Endpoint: "unix:/a/b/c"}, + wantDialerAddress: "unix:/a/b/c", + }, + { + target: "unix:///a/b/c", + wantParsed: resolver.Target{Scheme: "unix", Authority: "", Endpoint: "/a/b/c"}, + wantDialerAddress: "unix:///a/b/c", + }, + { + target: "dns:///127.0.0.1:50051", + wantParsed: resolver.Target{Scheme: "dns", Authority: "", Endpoint: "127.0.0.1:50051"}, + wantDialerAddress: "127.0.0.1:50051", + }, + { + target: ":///127.0.0.1:50051", + wantParsed: resolver.Target{Scheme: defScheme, Authority: "", Endpoint: ":///127.0.0.1:50051"}, + wantDialerAddress: ":///127.0.0.1:50051", + }, + { + target: "dns://authority/127.0.0.1:50051", + wantParsed: resolver.Target{Scheme: "dns", Authority: "authority", Endpoint: "127.0.0.1:50051"}, + wantDialerAddress: "127.0.0.1:50051", + }, + { + target: "://authority/127.0.0.1:50051", + wantParsed: resolver.Target{Scheme: defScheme, Authority: "", Endpoint: "://authority/127.0.0.1:50051"}, + wantDialerAddress: "://authority/127.0.0.1:50051", + }, + } + + for _, test := range tests { + t.Run(test.target, func(t *testing.T) { + addrCh := make(chan string, 1) + dialer := func(ctx context.Context, address string) (net.Conn, error) { + addrCh <- address + return nil, errors.New("dialer error") + } + + cc, err := Dial(test.target, WithInsecure(), WithContextDialer(dialer)) + if err != nil { + t.Fatalf("Dial(%q) failed: %v", test.target, err) + } + defer cc.Close() + + select { + case addr := <-addrCh: + if addr != test.wantDialerAddress { + t.Fatalf("address in custom dialer is %q, want %q", addr, test.wantDialerAddress) + } + case <-time.After(time.Second): + t.Fatal("timeout when waiting for custom dialer to be invoked") + } + if gotParsed := 
cc.parsedTarget; gotParsed != test.wantParsed { + t.Errorf("cc.parsedTarget for dial target %q = %+v, want %+v", test.target, gotParsed, test.wantParsed) + } + }) + } +} diff --git a/clientconn_state_transition_test.go b/clientconn_state_transition_test.go index 0c58131a1c6..2090c8de689 100644 --- a/clientconn_state_transition_test.go +++ b/clientconn_state_transition_test.go @@ -75,7 +75,7 @@ func (s) TestStateTransitions_SingleAddress(t *testing.T) { }, }, { - desc: "When the connection is closed, the client enters TRANSIENT FAILURE.", + desc: "When the connection is closed before the preface is sent, the client enters TRANSIENT FAILURE.", want: []connectivity.State{ connectivity.Connecting, connectivity.TransientFailure, @@ -167,6 +167,7 @@ func testStateTransitionSingleAddress(t *testing.T, want []connectivity.State, s t.Fatal(err) } defer client.Close() + go stayConnected(client) stateNotifications := testBalancerBuilder.nextStateNotifier() @@ -193,11 +194,12 @@ func testStateTransitionSingleAddress(t *testing.T, want []connectivity.State, s } } -// When a READY connection is closed, the client enters CONNECTING. +// When a READY connection is closed, the client enters IDLE then CONNECTING. func (s) TestStateTransitions_ReadyToConnecting(t *testing.T) { want := []connectivity.State{ connectivity.Connecting, connectivity.Ready, + connectivity.Idle, connectivity.Connecting, } @@ -210,7 +212,8 @@ func (s) TestStateTransitions_ReadyToConnecting(t *testing.T) { } defer lis.Close() - sawReady := make(chan struct{}) + sawReady := make(chan struct{}, 1) + defer close(sawReady) // Launch the server. 
go func() { @@ -239,6 +242,7 @@ func (s) TestStateTransitions_ReadyToConnecting(t *testing.T) { t.Fatal(err) } defer client.Close() + go stayConnected(client) stateNotifications := testBalancerBuilder.nextStateNotifier() @@ -250,7 +254,7 @@ func (s) TestStateTransitions_ReadyToConnecting(t *testing.T) { t.Fatalf("timed out waiting for state %d (%v) in flow %v", i, want[i], want) case seen := <-stateNotifications: if seen == connectivity.Ready { - close(sawReady) + sawReady <- struct{}{} } if seen != want[i] { t.Fatalf("expected to see %v at position %d in flow %v, got %v", want[i], i, want, seen) @@ -358,6 +362,7 @@ func (s) TestStateTransitions_MultipleAddrsEntersReady(t *testing.T) { want := []connectivity.State{ connectivity.Connecting, connectivity.Ready, + connectivity.Idle, connectivity.Connecting, } @@ -378,7 +383,8 @@ func (s) TestStateTransitions_MultipleAddrsEntersReady(t *testing.T) { defer lis2.Close() server1Done := make(chan struct{}) - sawReady := make(chan struct{}) + sawReady := make(chan struct{}, 1) + defer close(sawReady) // Launch server 1. 
go func() { @@ -400,12 +406,6 @@ func (s) TestStateTransitions_MultipleAddrsEntersReady(t *testing.T) { conn.Close() - _, err = lis1.Accept() - if err != nil { - t.Error(err) - return - } - close(server1Done) }() @@ -419,6 +419,7 @@ func (s) TestStateTransitions_MultipleAddrsEntersReady(t *testing.T) { t.Fatal(err) } defer client.Close() + go stayConnected(client) stateNotifications := testBalancerBuilder.nextStateNotifier() @@ -430,7 +431,7 @@ func (s) TestStateTransitions_MultipleAddrsEntersReady(t *testing.T) { t.Fatalf("timed out waiting for state %d (%v) in flow %v", i, want[i], want) case seen := <-stateNotifications: if seen == connectivity.Ready { - close(sawReady) + sawReady <- struct{}{} } if seen != want[i] { t.Fatalf("expected to see %v at position %d in flow %v, got %v", want[i], i, want, seen) diff --git a/clientconn_test.go b/clientconn_test.go index 6c61666b7ef..d276c7b5f2f 100644 --- a/clientconn_test.go +++ b/clientconn_test.go @@ -217,7 +217,7 @@ func (s) TestDialWaitsForServerSettingsAndFails(t *testing.T) { client.Close() t.Fatalf("Unexpected success (err=nil) while dialing") } - expectedMsg := "server handshake" + expectedMsg := "server preface" if !strings.Contains(err.Error(), context.DeadlineExceeded.Error()) || !strings.Contains(err.Error(), expectedMsg) { t.Fatalf("DialContext(_) = %v; want a message that includes both %q and %q", err, context.DeadlineExceeded.Error(), expectedMsg) } @@ -289,6 +289,9 @@ func (s) TestCloseConnectionWhenServerPrefaceNotReceived(t *testing.T) { if err != nil { t.Fatalf("Error while dialing. Err: %v", err) } + + go stayConnected(client) + // wait for connection to be accepted on the server. timer := time.NewTimer(time.Second * 10) select { @@ -311,9 +314,7 @@ func (s) TestBackoffWhenNoServerPrefaceReceived(t *testing.T) { defer lis.Close() done := make(chan struct{}) go func() { // Launch the server. 
- defer func() { - close(done) - }() + defer close(done) conn, err := lis.Accept() // Accept the connection only to close it immediately. if err != nil { t.Errorf("Error while accepting. Err: %v", err) @@ -340,13 +341,13 @@ func (s) TestBackoffWhenNoServerPrefaceReceived(t *testing.T) { prevAt = meow } }() - client, err := Dial(lis.Addr().String(), WithInsecure()) + cc, err := Dial(lis.Addr().String(), WithInsecure()) if err != nil { t.Fatalf("Error while dialing. Err: %v", err) } - defer client.Close() + defer cc.Close() + go stayConnected(cc) <-done - } func (s) TestWithTimeout(t *testing.T) { @@ -375,62 +376,6 @@ func (s) TestWithTransportCredentialsTLS(t *testing.T) { } } -func (s) TestDefaultAuthority(t *testing.T) { - target := "Non-Existent.Server:8080" - conn, err := Dial(target, WithInsecure()) - if err != nil { - t.Fatalf("Dial(_, _) = _, %v, want _, ", err) - } - defer conn.Close() - if conn.authority != target { - t.Fatalf("%v.authority = %v, want %v", conn, conn.authority, target) - } -} - -func (s) TestTLSServerNameOverwrite(t *testing.T) { - overwriteServerName := "over.write.server.name" - creds, err := credentials.NewClientTLSFromFile(testdata.Path("x509/server_ca_cert.pem"), overwriteServerName) - if err != nil { - t.Fatalf("Failed to create credentials %v", err) - } - conn, err := Dial("passthrough:///Non-Existent.Server:80", WithTransportCredentials(creds)) - if err != nil { - t.Fatalf("Dial(_, _) = _, %v, want _, ", err) - } - defer conn.Close() - if conn.authority != overwriteServerName { - t.Fatalf("%v.authority = %v, want %v", conn, conn.authority, overwriteServerName) - } -} - -func (s) TestWithAuthority(t *testing.T) { - overwriteServerName := "over.write.server.name" - conn, err := Dial("passthrough:///Non-Existent.Server:80", WithInsecure(), WithAuthority(overwriteServerName)) - if err != nil { - t.Fatalf("Dial(_, _) = _, %v, want _, ", err) - } - defer conn.Close() - if conn.authority != overwriteServerName { - t.Fatalf("%v.authority = 
%v, want %v", conn, conn.authority, overwriteServerName) - } -} - -func (s) TestWithAuthorityAndTLS(t *testing.T) { - overwriteServerName := "over.write.server.name" - creds, err := credentials.NewClientTLSFromFile(testdata.Path("x509/server_ca_cert.pem"), overwriteServerName) - if err != nil { - t.Fatalf("Failed to create credentials %v", err) - } - conn, err := Dial("passthrough:///Non-Existent.Server:80", WithTransportCredentials(creds), WithAuthority("no.effect.authority")) - if err != nil { - t.Fatalf("Dial(_, _) = _, %v, want _, ", err) - } - defer conn.Close() - if conn.authority != overwriteServerName { - t.Fatalf("%v.authority = %v, want %v", conn, conn.authority, overwriteServerName) - } -} - // When creating a transport configured with n addresses, only calculate the // backoff once per "round" of attempts instead of once per address (n times // per "round" of attempts). @@ -735,16 +680,15 @@ func (s) TestClientUpdatesParamsAfterGoAway(t *testing.T) { time.Sleep(10 * time.Millisecond) cc.mu.RLock() v := cc.mkp.Time + cc.mu.RUnlock() if v == 20*time.Second { // Success - cc.mu.RUnlock() return } if ctx.Err() != nil { // Timeout t.Fatalf("cc.dopts.copts.Keepalive.Time = %v , want 20s", v) } - cc.mu.RUnlock() } } @@ -832,6 +776,7 @@ func (s) TestResetConnectBackoff(t *testing.T) { t.Fatalf("Dial() = _, %v; want _, nil", err) } defer cc.Close() + go stayConnected(cc) select { case <-dials: case <-time.NewTimer(10 * time.Second).C: @@ -986,6 +931,7 @@ func (s) TestUpdateAddresses_RetryFromFirstAddr(t *testing.T) { t.Fatal(err) } defer client.Close() + go stayConnected(client) timeout := time.After(5 * time.Second) @@ -1113,3 +1059,23 @@ func testDefaultServiceConfigWhenResolverReturnInvalidServiceConfig(t *testing.T t.Fatal("default service config failed to be applied after 1s") } } + +// stayConnected makes cc stay connected by repeatedly calling cc.Connect() +// until the state becomes Shutdown or until 10 seconds elapses. 
+func stayConnected(cc *ClientConn) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + for { + state := cc.GetState() + switch state { + case connectivity.Idle: + cc.Connect() + case connectivity.Shutdown: + return + } + if !cc.WaitForStateChange(ctx, state) { + return + } + } +} diff --git a/cmd/protoc-gen-go-grpc/grpc.go b/cmd/protoc-gen-go-grpc/grpc.go index 1e787344ebc..f45e0403fd4 100644 --- a/cmd/protoc-gen-go-grpc/grpc.go +++ b/cmd/protoc-gen-go-grpc/grpc.go @@ -24,7 +24,6 @@ import ( "strings" "google.golang.org/protobuf/compiler/protogen" - "google.golang.org/protobuf/types/descriptorpb" ) @@ -43,6 +42,14 @@ func generateFile(gen *protogen.Plugin, file *protogen.File) *protogen.Generated filename := file.GeneratedFilenamePrefix + "_grpc.pb.go" g := gen.NewGeneratedFile(filename, file.GoImportPath) g.P("// Code generated by protoc-gen-go-grpc. DO NOT EDIT.") + g.P("// versions:") + g.P("// - protoc-gen-go-grpc v", version) + g.P("// - protoc ", protocVersion(gen)) + if file.Proto.GetOptions().GetDeprecated() { + g.P("// ", file.Desc.Path(), " is a deprecated file.") + } else { + g.P("// source: ", file.Desc.Path()) + } g.P() g.P("package ", file.GoPackageName) g.P() @@ -50,6 +57,18 @@ func generateFile(gen *protogen.Plugin, file *protogen.File) *protogen.Generated return g } +func protocVersion(gen *protogen.Plugin) string { + v := gen.Request.GetCompilerVersion() + if v == nil { + return "(unknown)" + } + var suffix string + if s := v.GetSuffix(); s != "" { + suffix = "-" + s + } + return fmt.Sprintf("v%d.%d.%d%s", v.GetMajor(), v.GetMinor(), v.GetPatch(), suffix) +} + // generateFileContent generates the gRPC service definitions, excluding the package statement. 
func generateFileContent(gen *protogen.Plugin, file *protogen.File, g *protogen.GeneratedFile) { if len(file.Services) == 0 { @@ -188,7 +207,7 @@ func genService(gen *protogen.Plugin, file *protogen.File, g *protogen.Generated g.P() // Server handler implementations. - var handlerNames []string + handlerNames := make([]string, 0, len(service.Methods)) for _, method := range service.Methods { hname := genServerMethod(gen, file, g, method) handlerNames = append(handlerNames, hname) diff --git a/connectivity/connectivity.go b/connectivity/connectivity.go index 01015626150..4a89926422b 100644 --- a/connectivity/connectivity.go +++ b/connectivity/connectivity.go @@ -18,7 +18,6 @@ // Package connectivity defines connectivity semantics. // For details, see https://github.com/grpc/grpc/blob/master/doc/connectivity-semantics-and-api.md. -// All APIs in this package are experimental. package connectivity import ( @@ -45,7 +44,7 @@ func (s State) String() string { return "SHUTDOWN" default: logger.Errorf("unknown connectivity state: %d", s) - return "Invalid-State" + return "INVALID_STATE" } } @@ -61,3 +60,35 @@ const ( // Shutdown indicates the ClientConn has started shutting down. Shutdown ) + +// ServingMode indicates the current mode of operation of the server. +// +// Only xDS enabled gRPC servers currently report their serving mode. +type ServingMode int + +const ( + // ServingModeStarting indicates that the server is starting up. + ServingModeStarting ServingMode = iota + // ServingModeServing indicates that the server contains all required + // configuration and is serving RPCs. + ServingModeServing + // ServingModeNotServing indicates that the server is not accepting new + // connections. Existing connections will be closed gracefully, allowing + // in-progress RPCs to complete. A server enters this mode when it does not + // contain the required configuration to serve RPCs. 
+ ServingModeNotServing +) + +func (s ServingMode) String() string { + switch s { + case ServingModeStarting: + return "STARTING" + case ServingModeServing: + return "SERVING" + case ServingModeNotServing: + return "NOT_SERVING" + default: + logger.Errorf("unknown serving mode: %d", s) + return "INVALID_MODE" + } +} diff --git a/credentials/alts/alts.go b/credentials/alts/alts.go index 729c4b43b5f..579adf210c4 100644 --- a/credentials/alts/alts.go +++ b/credentials/alts/alts.go @@ -37,6 +37,7 @@ import ( "google.golang.org/grpc/credentials/alts/internal/handshaker/service" altspb "google.golang.org/grpc/credentials/alts/internal/proto/grpc_gcp" "google.golang.org/grpc/grpclog" + "google.golang.org/grpc/internal/googlecloud" ) const ( @@ -54,6 +55,7 @@ const ( ) var ( + vmOnGCP bool once sync.Once maxRPCVersion = &altspb.RpcProtocolVersions_Version{ Major: protocolVersionMaxMajor, @@ -149,9 +151,8 @@ func NewServerCreds(opts *ServerOptions) credentials.TransportCredentials { func newALTS(side core.Side, accounts []string, hsAddress string) credentials.TransportCredentials { once.Do(func() { - vmOnGCP = isRunningOnGCP() + vmOnGCP = googlecloud.OnGCE() }) - if hsAddress == "" { hsAddress = hypervisorHandshakerServiceAddress } diff --git a/credentials/alts/alts_test.go b/credentials/alts/alts_test.go index cbb1656d20c..22ad5a48b09 100644 --- a/credentials/alts/alts_test.go +++ b/credentials/alts/alts_test.go @@ -1,3 +1,4 @@ +//go:build linux || windows // +build linux windows /* diff --git a/credentials/alts/internal/conn/record_test.go b/credentials/alts/internal/conn/record_test.go index 59d4f41e9e1..c18f902b401 100644 --- a/credentials/alts/internal/conn/record_test.go +++ b/credentials/alts/internal/conn/record_test.go @@ -40,11 +40,15 @@ func Test(t *testing.T) { grpctest.RunSubTests(t, s{}) } +const ( + rekeyRecordProtocol = "ALTSRP_GCM_AES128_REKEY" +) + var ( - nextProtocols = []string{"ALTSRP_GCM_AES128"} + recordProtocols = []string{rekeyRecordProtocol} 
altsRecordFuncs = map[string]ALTSRecordFunc{ // ALTS handshaker protocols. - "ALTSRP_GCM_AES128": func(s core.Side, keyData []byte) (ALTSRecordCrypto, error) { + rekeyRecordProtocol: func(s core.Side, keyData []byte) (ALTSRecordCrypto, error) { return NewAES128GCM(s, keyData) }, } @@ -77,7 +81,7 @@ func (c *testConn) Close() error { return nil } -func newTestALTSRecordConn(in, out *bytes.Buffer, side core.Side, np string, protected []byte) *conn { +func newTestALTSRecordConn(in, out *bytes.Buffer, side core.Side, rp string, protected []byte) *conn { key := []byte{ // 16 arbitrary bytes. 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x09, 0x6e, 0x88, 0x02, 0xff, 0xe2, 0xd2, 0x4c, 0xce, 0x4f, 0x49} @@ -85,23 +89,23 @@ func newTestALTSRecordConn(in, out *bytes.Buffer, side core.Side, np string, pro in: in, out: out, } - c, err := NewConn(&tc, side, np, key, protected) + c, err := NewConn(&tc, side, rp, key, protected) if err != nil { panic(fmt.Sprintf("Unexpected error creating test ALTS record connection: %v", err)) } return c.(*conn) } -func newConnPair(np string, clientProtected []byte, serverProtected []byte) (client, server *conn) { +func newConnPair(rp string, clientProtected []byte, serverProtected []byte) (client, server *conn) { clientBuf := new(bytes.Buffer) serverBuf := new(bytes.Buffer) - clientConn := newTestALTSRecordConn(clientBuf, serverBuf, core.ClientSide, np, clientProtected) - serverConn := newTestALTSRecordConn(serverBuf, clientBuf, core.ServerSide, np, serverProtected) + clientConn := newTestALTSRecordConn(clientBuf, serverBuf, core.ClientSide, rp, clientProtected) + serverConn := newTestALTSRecordConn(serverBuf, clientBuf, core.ServerSide, rp, serverProtected) return clientConn, serverConn } -func testPingPong(t *testing.T, np string) { - clientConn, serverConn := newConnPair(np, nil, nil) +func testPingPong(t *testing.T, rp string) { + clientConn, serverConn := newConnPair(rp, nil, nil) clientMsg := []byte("Client Message") if n, err := 
clientConn.Write(clientMsg); n != len(clientMsg) || err != nil { t.Fatalf("Client Write() = %v, %v; want %v, ", n, err, len(clientMsg)) @@ -128,13 +132,13 @@ func testPingPong(t *testing.T, np string) { } func (s) TestPingPong(t *testing.T) { - for _, np := range nextProtocols { - testPingPong(t, np) + for _, rp := range recordProtocols { + testPingPong(t, rp) } } -func testSmallReadBuffer(t *testing.T, np string) { - clientConn, serverConn := newConnPair(np, nil, nil) +func testSmallReadBuffer(t *testing.T, rp string) { + clientConn, serverConn := newConnPair(rp, nil, nil) msg := []byte("Very Important Message") if n, err := clientConn.Write(msg); err != nil { t.Fatalf("Write() = %v, %v; want %v, ", n, err, len(msg)) @@ -155,13 +159,13 @@ func testSmallReadBuffer(t *testing.T, np string) { } func (s) TestSmallReadBuffer(t *testing.T) { - for _, np := range nextProtocols { - testSmallReadBuffer(t, np) + for _, rp := range recordProtocols { + testSmallReadBuffer(t, rp) } } -func testLargeMsg(t *testing.T, np string) { - clientConn, serverConn := newConnPair(np, nil, nil) +func testLargeMsg(t *testing.T, rp string) { + clientConn, serverConn := newConnPair(rp, nil, nil) // msgLen is such that the length in the framing is larger than the // default size of one frame. msgLen := altsRecordDefaultLength - msgTypeFieldSize - clientConn.crypto.EncryptionOverhead() + 1 @@ -179,12 +183,12 @@ func testLargeMsg(t *testing.T, np string) { } func (s) TestLargeMsg(t *testing.T) { - for _, np := range nextProtocols { - testLargeMsg(t, np) + for _, rp := range recordProtocols { + testLargeMsg(t, rp) } } -func testIncorrectMsgType(t *testing.T, np string) { +func testIncorrectMsgType(t *testing.T, rp string) { // framedMsg is an empty ciphertext with correct framing but wrong // message type. 
framedMsg := make([]byte, MsgLenFieldSize+msgTypeFieldSize) @@ -193,7 +197,7 @@ func testIncorrectMsgType(t *testing.T, np string) { binary.LittleEndian.PutUint32(framedMsg[MsgLenFieldSize:], wrongMsgType) in := bytes.NewBuffer(framedMsg) - c := newTestALTSRecordConn(in, nil, core.ClientSide, np, nil) + c := newTestALTSRecordConn(in, nil, core.ClientSide, rp, nil) b := make([]byte, 1) if n, err := c.Read(b); n != 0 || err == nil { t.Fatalf("Read() = , want %v", fmt.Errorf("received frame with incorrect message type %v", wrongMsgType)) @@ -201,15 +205,15 @@ func testIncorrectMsgType(t *testing.T, np string) { } func (s) TestIncorrectMsgType(t *testing.T) { - for _, np := range nextProtocols { - testIncorrectMsgType(t, np) + for _, rp := range recordProtocols { + testIncorrectMsgType(t, rp) } } -func testFrameTooLarge(t *testing.T, np string) { +func testFrameTooLarge(t *testing.T, rp string) { buf := new(bytes.Buffer) - clientConn := newTestALTSRecordConn(nil, buf, core.ClientSide, np, nil) - serverConn := newTestALTSRecordConn(buf, nil, core.ServerSide, np, nil) + clientConn := newTestALTSRecordConn(nil, buf, core.ClientSide, rp, nil) + serverConn := newTestALTSRecordConn(buf, nil, core.ServerSide, rp, nil) // payloadLen is such that the length in the framing is larger than // allowed in one frame. payloadLen := altsRecordLengthLimit - msgTypeFieldSize - clientConn.crypto.EncryptionOverhead() + 1 @@ -234,15 +238,15 @@ func testFrameTooLarge(t *testing.T, np string) { } func (s) TestFrameTooLarge(t *testing.T) { - for _, np := range nextProtocols { - testFrameTooLarge(t, np) + for _, rp := range recordProtocols { + testFrameTooLarge(t, rp) } } -func testWriteLargeData(t *testing.T, np string) { +func testWriteLargeData(t *testing.T, rp string) { // Test sending and receiving messages larger than the maximum write // buffer size. 
- clientConn, serverConn := newConnPair(np, nil, nil) + clientConn, serverConn := newConnPair(rp, nil, nil) // Message size is intentionally chosen to not be multiple of // payloadLengthLimtit. msgSize := altsWriteBufferMaxSize + (100 * 1024) @@ -277,25 +281,25 @@ func testWriteLargeData(t *testing.T, np string) { } func (s) TestWriteLargeData(t *testing.T) { - for _, np := range nextProtocols { - testWriteLargeData(t, np) + for _, rp := range recordProtocols { + testWriteLargeData(t, rp) } } -func testProtectedBuffer(t *testing.T, np string) { +func testProtectedBuffer(t *testing.T, rp string) { key := []byte{ // 16 arbitrary bytes. 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x09, 0x6e, 0x88, 0x02, 0xff, 0xe2, 0xd2, 0x4c, 0xce, 0x4f, 0x49} // Encrypt a message to be passed to NewConn as a client-side protected // buffer. - newCrypto := protocols[np] + newCrypto := protocols[rp] if newCrypto == nil { - t.Fatalf("Unknown next protocol %q", np) + t.Fatalf("Unknown record protocol %q", rp) } crypto, err := newCrypto(core.ClientSide, key) if err != nil { - t.Fatalf("Failed to create a crypter for protocol %q: %v", np, err) + t.Fatalf("Failed to create a crypter for protocol %q: %v", rp, err) } msg := []byte("Client Protected Message") encryptedMsg, err := crypto.Encrypt(nil, msg) @@ -307,7 +311,7 @@ func testProtectedBuffer(t *testing.T, np string) { binary.LittleEndian.PutUint32(protectedMsg[4:], altsRecordMsgType) protectedMsg = append(protectedMsg, encryptedMsg...) 
- _, serverConn := newConnPair(np, nil, protectedMsg) + _, serverConn := newConnPair(rp, nil, protectedMsg) rcvClientMsg := make([]byte, len(msg)) if n, err := serverConn.Read(rcvClientMsg); n != len(rcvClientMsg) || err != nil { t.Fatalf("Server Read() = %v, %v; want %v, ", n, err, len(rcvClientMsg)) @@ -318,7 +322,7 @@ func testProtectedBuffer(t *testing.T, np string) { } func (s) TestProtectedBuffer(t *testing.T) { - for _, np := range nextProtocols { - testProtectedBuffer(t, np) + for _, rp := range recordProtocols { + testProtectedBuffer(t, rp) } } diff --git a/credentials/alts/internal/handshaker/handshaker_test.go b/credentials/alts/internal/handshaker/handshaker_test.go index bf516dc53c8..14a0721054f 100644 --- a/credentials/alts/internal/handshaker/handshaker_test.go +++ b/credentials/alts/internal/handshaker/handshaker_test.go @@ -21,6 +21,7 @@ package handshaker import ( "bytes" "context" + "errors" "testing" "time" @@ -163,7 +164,8 @@ func (s) TestClientHandshake(t *testing.T) { go func() { _, context, err := chs.ClientHandshake(ctx) if err == nil && context == nil { - panic("expected non-nil ALTS context") + errc <- errors.New("expected non-nil ALTS context") + return } errc <- err chs.Close() @@ -219,7 +221,8 @@ func (s) TestServerHandshake(t *testing.T) { go func() { _, context, err := shs.ServerHandshake(ctx) if err == nil && context == nil { - panic("expected non-nil ALTS context") + errc <- errors.New("expected non-nil ALTS context") + return } errc <- err shs.Close() diff --git a/credentials/alts/internal/proto/grpc_gcp/handshaker_grpc.pb.go b/credentials/alts/internal/proto/grpc_gcp/handshaker_grpc.pb.go index efdbd13fa30..a02c4582815 100644 --- a/credentials/alts/internal/proto/grpc_gcp/handshaker_grpc.pb.go +++ b/credentials/alts/internal/proto/grpc_gcp/handshaker_grpc.pb.go @@ -1,4 +1,8 @@ // Code generated by protoc-gen-go-grpc. DO NOT EDIT. 
+// versions: +// - protoc-gen-go-grpc v1.1.0 +// - protoc v3.14.0 +// source: grpc/gcp/handshaker.proto package grpc_gcp diff --git a/credentials/alts/utils.go b/credentials/alts/utils.go index 9a300bc19aa..cbfd056cfb1 100644 --- a/credentials/alts/utils.go +++ b/credentials/alts/utils.go @@ -21,14 +21,6 @@ package alts import ( "context" "errors" - "fmt" - "io" - "io/ioutil" - "log" - "os" - "os/exec" - "regexp" - "runtime" "strings" "google.golang.org/grpc/codes" @@ -36,92 +28,6 @@ import ( "google.golang.org/grpc/status" ) -const ( - linuxProductNameFile = "/sys/class/dmi/id/product_name" - windowsCheckCommand = "powershell.exe" - windowsCheckCommandArgs = "Get-WmiObject -Class Win32_BIOS" - powershellOutputFilter = "Manufacturer" - windowsManufacturerRegex = ":(.*)" -) - -type platformError string - -func (k platformError) Error() string { - return fmt.Sprintf("%s is not supported", string(k)) -} - -var ( - // The following two variables will be reassigned in tests. - runningOS = runtime.GOOS - manufacturerReader = func() (io.Reader, error) { - switch runningOS { - case "linux": - return os.Open(linuxProductNameFile) - case "windows": - cmd := exec.Command(windowsCheckCommand, windowsCheckCommandArgs) - out, err := cmd.Output() - if err != nil { - return nil, err - } - - for _, line := range strings.Split(strings.TrimSuffix(string(out), "\n"), "\n") { - if strings.HasPrefix(line, powershellOutputFilter) { - re := regexp.MustCompile(windowsManufacturerRegex) - name := re.FindString(line) - name = strings.TrimLeft(name, ":") - return strings.NewReader(name), nil - } - } - - return nil, errors.New("cannot determine the machine's manufacturer") - default: - return nil, platformError(runningOS) - } - } - vmOnGCP bool -) - -// isRunningOnGCP checks whether the local system, without doing a network request is -// running on GCP. 
-func isRunningOnGCP() bool { - manufacturer, err := readManufacturer() - if os.IsNotExist(err) { - return false - } - if err != nil { - log.Fatalf("failure to read manufacturer information: %v", err) - } - name := string(manufacturer) - switch runningOS { - case "linux": - name = strings.TrimSpace(name) - return name == "Google" || name == "Google Compute Engine" - case "windows": - name = strings.Replace(name, " ", "", -1) - name = strings.Replace(name, "\n", "", -1) - name = strings.Replace(name, "\r", "", -1) - return name == "Google" - default: - log.Fatal(platformError(runningOS)) - } - return false -} - -func readManufacturer() ([]byte, error) { - reader, err := manufacturerReader() - if err != nil { - return nil, err - } - if reader == nil { - return nil, errors.New("got nil reader") - } - manufacturer, err := ioutil.ReadAll(reader) - if err != nil { - return nil, fmt.Errorf("failed reading %v: %v", linuxProductNameFile, err) - } - return manufacturer, nil -} - // AuthInfoFromContext extracts the alts.AuthInfo object from the given context, // if it exists. This API should be used by gRPC server RPC handlers to get // information about the communicating peer. For client-side, use grpc.Peer() diff --git a/credentials/alts/utils_test.go b/credentials/alts/utils_test.go index 5b54b1d5f77..531cdfce6e3 100644 --- a/credentials/alts/utils_test.go +++ b/credentials/alts/utils_test.go @@ -1,3 +1,4 @@ +//go:build linux || windows // +build linux windows /* @@ -22,8 +23,6 @@ package alts import ( "context" - "io" - "os" "strings" "testing" "time" @@ -42,67 +41,6 @@ const ( defaultTestTimeout = 10 * time.Second ) -func setupManufacturerReader(testOS string, reader func() (io.Reader, error)) func() { - tmpOS := runningOS - tmpReader := manufacturerReader - - // Set test OS and reader function. 
- runningOS = testOS - manufacturerReader = reader - return func() { - runningOS = tmpOS - manufacturerReader = tmpReader - } - -} - -func setup(testOS string, testReader io.Reader) func() { - reader := func() (io.Reader, error) { - return testReader, nil - } - return setupManufacturerReader(testOS, reader) -} - -func setupError(testOS string, err error) func() { - reader := func() (io.Reader, error) { - return nil, err - } - return setupManufacturerReader(testOS, reader) -} - -func (s) TestIsRunningOnGCP(t *testing.T) { - for _, tc := range []struct { - description string - testOS string - testReader io.Reader - out bool - }{ - // Linux tests. - {"linux: not a GCP platform", "linux", strings.NewReader("not GCP"), false}, - {"Linux: GCP platform (Google)", "linux", strings.NewReader("Google"), true}, - {"Linux: GCP platform (Google Compute Engine)", "linux", strings.NewReader("Google Compute Engine"), true}, - {"Linux: GCP platform (Google Compute Engine) with extra spaces", "linux", strings.NewReader(" Google Compute Engine "), true}, - // Windows tests. 
- {"windows: not a GCP platform", "windows", strings.NewReader("not GCP"), false}, - {"windows: GCP platform (Google)", "windows", strings.NewReader("Google"), true}, - {"windows: GCP platform (Google) with extra spaces", "windows", strings.NewReader(" Google "), true}, - } { - reverseFunc := setup(tc.testOS, tc.testReader) - if got, want := isRunningOnGCP(), tc.out; got != want { - t.Errorf("%v: isRunningOnGCP()=%v, want %v", tc.description, got, want) - } - reverseFunc() - } -} - -func (s) TestIsRunningOnGCPNoProductNameFile(t *testing.T) { - reverseFunc := setupError("linux", os.ErrNotExist) - if isRunningOnGCP() { - t.Errorf("ErrNotExist: isRunningOnGCP()=true, want false") - } - reverseFunc() -} - func (s) TestAuthInfoFromContext(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() diff --git a/credentials/credentials.go b/credentials/credentials.go index e69562e7878..7eee7e4ec12 100644 --- a/credentials/credentials.go +++ b/credentials/credentials.go @@ -30,7 +30,7 @@ import ( "github.com/golang/protobuf/proto" "google.golang.org/grpc/attributes" - "google.golang.org/grpc/internal" + icredentials "google.golang.org/grpc/internal/credentials" ) // PerRPCCredentials defines the common interface for the credentials which need to @@ -188,15 +188,12 @@ type RequestInfo struct { AuthInfo AuthInfo } -// requestInfoKey is a struct to be used as the key when attaching a RequestInfo to a context object. -type requestInfoKey struct{} - // RequestInfoFromContext extracts the RequestInfo from the context if it exists. // // This API is experimental. func RequestInfoFromContext(ctx context.Context) (ri RequestInfo, ok bool) { - ri, ok = ctx.Value(requestInfoKey{}).(RequestInfo) - return + ri, ok = icredentials.RequestInfoFromContext(ctx).(RequestInfo) + return ri, ok } // ClientHandshakeInfo holds data to be passed to ClientHandshake. 
This makes @@ -211,16 +208,12 @@ type ClientHandshakeInfo struct { Attributes *attributes.Attributes } -// clientHandshakeInfoKey is a struct used as the key to store -// ClientHandshakeInfo in a context. -type clientHandshakeInfoKey struct{} - // ClientHandshakeInfoFromContext returns the ClientHandshakeInfo struct stored // in ctx. // // This API is experimental. func ClientHandshakeInfoFromContext(ctx context.Context) ClientHandshakeInfo { - chi, _ := ctx.Value(clientHandshakeInfoKey{}).(ClientHandshakeInfo) + chi, _ := icredentials.ClientHandshakeInfoFromContext(ctx).(ClientHandshakeInfo) return chi } @@ -249,15 +242,6 @@ func CheckSecurityLevel(ai AuthInfo, level SecurityLevel) error { return nil } -func init() { - internal.NewRequestInfoContext = func(ctx context.Context, ri RequestInfo) context.Context { - return context.WithValue(ctx, requestInfoKey{}, ri) - } - internal.NewClientHandshakeInfoContext = func(ctx context.Context, chi ClientHandshakeInfo) context.Context { - return context.WithValue(ctx, clientHandshakeInfoKey{}, chi) - } -} - // ChannelzSecurityInfo defines the interface that security protocols should implement // in order to provide security info to channelz. // diff --git a/credentials/go12.go b/credentials/go12.go deleted file mode 100644 index ccbf35b3312..00000000000 --- a/credentials/go12.go +++ /dev/null @@ -1,30 +0,0 @@ -// +build go1.12 - -/* - * - * Copyright 2019 gRPC authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- * - */ - -package credentials - -import "crypto/tls" - -// This init function adds cipher suite constants only defined in Go 1.12. -func init() { - cipherSuiteLookup[tls.TLS_AES_128_GCM_SHA256] = "TLS_AES_128_GCM_SHA256" - cipherSuiteLookup[tls.TLS_AES_256_GCM_SHA384] = "TLS_AES_256_GCM_SHA384" - cipherSuiteLookup[tls.TLS_CHACHA20_POLY1305_SHA256] = "TLS_CHACHA20_POLY1305_SHA256" -} diff --git a/credentials/google/google.go b/credentials/google/google.go index 7f3e240e475..63625a4b680 100644 --- a/credentials/google/google.go +++ b/credentials/google/google.go @@ -35,57 +35,63 @@ const tokenRequestTimeout = 30 * time.Second var logger = grpclog.Component("credentials") -// NewDefaultCredentials returns a credentials bundle that is configured to work -// with google services. +// DefaultCredentialsOptions constructs options to build DefaultCredentials. +type DefaultCredentialsOptions struct { + // PerRPCCreds is a per RPC credentials that is passed to a bundle. + PerRPCCreds credentials.PerRPCCredentials +} + +// NewDefaultCredentialsWithOptions returns a credentials bundle that is +// configured to work with google services. // // This API is experimental. 
-func NewDefaultCredentials() credentials.Bundle { - c := &creds{ - newPerRPCCreds: func() credentials.PerRPCCredentials { - ctx, cancel := context.WithTimeout(context.Background(), tokenRequestTimeout) - defer cancel() - perRPCCreds, err := oauth.NewApplicationDefault(ctx) - if err != nil { - logger.Warningf("google default creds: failed to create application oauth: %v", err) - } - return perRPCCreds - }, +func NewDefaultCredentialsWithOptions(opts DefaultCredentialsOptions) credentials.Bundle { + if opts.PerRPCCreds == nil { + ctx, cancel := context.WithTimeout(context.Background(), tokenRequestTimeout) + defer cancel() + var err error + opts.PerRPCCreds, err = oauth.NewApplicationDefault(ctx) + if err != nil { + logger.Warningf("NewDefaultCredentialsWithOptions: failed to create application oauth: %v", err) + } } + c := &creds{opts: opts} bundle, err := c.NewWithMode(internal.CredsBundleModeFallback) if err != nil { - logger.Warningf("google default creds: failed to create new creds: %v", err) + logger.Warningf("NewDefaultCredentialsWithOptions: failed to create new creds: %v", err) } return bundle } +// NewDefaultCredentials returns a credentials bundle that is configured to work +// with google services. +// +// This API is experimental. +func NewDefaultCredentials() credentials.Bundle { + return NewDefaultCredentialsWithOptions(DefaultCredentialsOptions{}) +} + // NewComputeEngineCredentials returns a credentials bundle that is configured to work // with google services. This API must only be used when running on GCE. Authentication configured // by this API represents the GCE VM's default service account. // // This API is experimental. 
func NewComputeEngineCredentials() credentials.Bundle { - c := &creds{ - newPerRPCCreds: func() credentials.PerRPCCredentials { - return oauth.NewComputeEngine() - }, - } - bundle, err := c.NewWithMode(internal.CredsBundleModeFallback) - if err != nil { - logger.Warningf("compute engine creds: failed to create new creds: %v", err) - } - return bundle + return NewDefaultCredentialsWithOptions(DefaultCredentialsOptions{ + PerRPCCreds: oauth.NewComputeEngine(), + }) } // creds implements credentials.Bundle. type creds struct { + opts DefaultCredentialsOptions + // Supported modes are defined in internal/internal.go. mode string - // The transport credentials associated with this bundle. + // The active transport credentials associated with this bundle. transportCreds credentials.TransportCredentials - // The per RPC credentials associated with this bundle. + // The active per RPC credentials associated with this bundle. perRPCCreds credentials.PerRPCCredentials - // Creates new per RPC credentials - newPerRPCCreds func() credentials.PerRPCCredentials } func (c *creds) TransportCredentials() credentials.TransportCredentials { @@ -99,28 +105,37 @@ func (c *creds) PerRPCCredentials() credentials.PerRPCCredentials { return c.perRPCCreds } +var ( + newTLS = func() credentials.TransportCredentials { + return credentials.NewTLS(nil) + } + newALTS = func() credentials.TransportCredentials { + return alts.NewClientCreds(alts.DefaultClientOptions()) + } +) + // NewWithMode should make a copy of Bundle, and switch mode. Modifying the // existing Bundle may cause races. func (c *creds) NewWithMode(mode string) (credentials.Bundle, error) { newCreds := &creds{ - mode: mode, - newPerRPCCreds: c.newPerRPCCreds, + opts: c.opts, + mode: mode, } // Create transport credentials. 
switch mode { case internal.CredsBundleModeFallback: - newCreds.transportCreds = credentials.NewTLS(nil) + newCreds.transportCreds = newClusterTransportCreds(newTLS(), newALTS()) case internal.CredsBundleModeBackendFromBalancer, internal.CredsBundleModeBalancer: // Only the clients can use google default credentials, so we only need // to create new ALTS client creds here. - newCreds.transportCreds = alts.NewClientCreds(alts.DefaultClientOptions()) + newCreds.transportCreds = newALTS() default: return nil, fmt.Errorf("unsupported mode: %v", mode) } if mode == internal.CredsBundleModeFallback || mode == internal.CredsBundleModeBackendFromBalancer { - newCreds.perRPCCreds = newCreds.newPerRPCCreds() + newCreds.perRPCCreds = newCreds.opts.PerRPCCreds } return newCreds, nil diff --git a/credentials/google/google_test.go b/credentials/google/google_test.go new file mode 100644 index 00000000000..6a6e492ee77 --- /dev/null +++ b/credentials/google/google_test.go @@ -0,0 +1,131 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ *
+ */
+
+package google
+
+import (
+	"context"
+	"net"
+	"testing"
+
+	"google.golang.org/grpc/credentials"
+	"google.golang.org/grpc/internal"
+	icredentials "google.golang.org/grpc/internal/credentials"
+	"google.golang.org/grpc/resolver"
+)
+
+type testCreds struct {
+	credentials.TransportCredentials
+	typ string
+}
+
+func (c *testCreds) ClientHandshake(ctx context.Context, authority string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
+	return nil, &testAuthInfo{typ: c.typ}, nil
+}
+
+func (c *testCreds) ServerHandshake(conn net.Conn) (net.Conn, credentials.AuthInfo, error) {
+	return nil, &testAuthInfo{typ: c.typ}, nil
+}
+
+type testAuthInfo struct {
+	typ string
+}
+
+func (t *testAuthInfo) AuthType() string {
+	return t.typ
+}
+
+var (
+	testTLS  = &testCreds{typ: "tls"}
+	testALTS = &testCreds{typ: "alts"}
+)
+
+func overrideNewCredsFuncs() func() {
+	oldNewTLS := newTLS
+	newTLS = func() credentials.TransportCredentials {
+		return testTLS
+	}
+	oldNewALTS := newALTS
+	newALTS = func() credentials.TransportCredentials {
+		return testALTS
+	}
+	return func() {
+		newTLS = oldNewTLS
+		newALTS = oldNewALTS
+	}
+}
+
+// TestClientHandshakeBasedOnClusterName verifies that by default (without
+// switching modes), ClientHandshake does either tls or alts based on the
+// cluster name in attributes.
+func TestClientHandshakeBasedOnClusterName(t *testing.T) { + defer overrideNewCredsFuncs()() + for bundleTyp, tc := range map[string]credentials.Bundle{ + "defaultCredsWithOptions": NewDefaultCredentialsWithOptions(DefaultCredentialsOptions{}), + "defaultCreds": NewDefaultCredentials(), + "computeCreds": NewComputeEngineCredentials(), + } { + tests := []struct { + name string + ctx context.Context + wantTyp string + }{ + { + name: "no cluster name", + ctx: context.Background(), + wantTyp: "tls", + }, + { + name: "with non-CFE cluster name", + ctx: icredentials.NewClientHandshakeInfoContext(context.Background(), credentials.ClientHandshakeInfo{ + Attributes: internal.SetXDSHandshakeClusterName(resolver.Address{}, "lalala").Attributes, + }), + // non-CFE backends should use alts. + wantTyp: "alts", + }, + { + name: "with CFE cluster name", + ctx: icredentials.NewClientHandshakeInfoContext(context.Background(), credentials.ClientHandshakeInfo{ + Attributes: internal.SetXDSHandshakeClusterName(resolver.Address{}, cfeClusterName).Attributes, + }), + // CFE should use tls. + wantTyp: "tls", + }, + } + for _, tt := range tests { + t.Run(bundleTyp+" "+tt.name, func(t *testing.T) { + _, info, err := tc.TransportCredentials().ClientHandshake(tt.ctx, "", nil) + if err != nil { + t.Fatalf("ClientHandshake failed: %v", err) + } + if gotType := info.AuthType(); gotType != tt.wantTyp { + t.Fatalf("unexpected authtype: %v, want: %v", gotType, tt.wantTyp) + } + + _, infoServer, err := tc.TransportCredentials().ServerHandshake(nil) + if err != nil { + t.Fatalf("ClientHandshake failed: %v", err) + } + // ServerHandshake should always do TLS. 
+				if gotType := infoServer.AuthType(); gotType != "tls" {
+					t.Fatalf("unexpected server authtype: %v, want: %v", gotType, "tls")
+				}
+			})
+		}
+	}
+}
diff --git a/credentials/google/xds.go b/credentials/google/xds.go
new file mode 100644
index 00000000000..588c685e259
--- /dev/null
+++ b/credentials/google/xds.go
@@ -0,0 +1,90 @@
+/*
+ *
+ * Copyright 2021 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+package google
+
+import (
+	"context"
+	"net"
+
+	"google.golang.org/grpc/credentials"
+	"google.golang.org/grpc/internal"
+)
+
+const cfeClusterName = "google-cfe"
+
+// clusterTransportCreds is a combo of TLS + ALTS.
+//
+// On the client, ClientHandshake picks TLS or ALTS based on address attributes.
+// - if attributes has cluster name
+//   - if cluster name is "google-cfe", use TLS
+//   - otherwise, use ALTS
+// - else, do TLS
+//
+// On the server, ServerHandshake always does TLS.
+type clusterTransportCreds struct { + tls credentials.TransportCredentials + alts credentials.TransportCredentials +} + +func newClusterTransportCreds(tls, alts credentials.TransportCredentials) *clusterTransportCreds { + return &clusterTransportCreds{ + tls: tls, + alts: alts, + } +} + +func (c *clusterTransportCreds) ClientHandshake(ctx context.Context, authority string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { + chi := credentials.ClientHandshakeInfoFromContext(ctx) + if chi.Attributes == nil { + return c.tls.ClientHandshake(ctx, authority, rawConn) + } + cn, ok := internal.GetXDSHandshakeClusterName(chi.Attributes) + if !ok || cn == cfeClusterName { + return c.tls.ClientHandshake(ctx, authority, rawConn) + } + // If attributes have cluster name, and cluster name is not cfe, it's a + // backend address, use ALTS. + return c.alts.ClientHandshake(ctx, authority, rawConn) +} + +func (c *clusterTransportCreds) ServerHandshake(conn net.Conn) (net.Conn, credentials.AuthInfo, error) { + return c.tls.ServerHandshake(conn) +} + +func (c *clusterTransportCreds) Info() credentials.ProtocolInfo { + // TODO: this always returns tls.Info now, because we don't have a cluster + // name to check when this method is called. This method doesn't affect + // anything important now. We may want to revisit this if it becomes more + // important later. 
+ return c.tls.Info() +} + +func (c *clusterTransportCreds) Clone() credentials.TransportCredentials { + return &clusterTransportCreds{ + tls: c.tls.Clone(), + alts: c.alts.Clone(), + } +} + +func (c *clusterTransportCreds) OverrideServerName(s string) error { + if err := c.tls.OverrideServerName(s); err != nil { + return err + } + return c.alts.OverrideServerName(s) +} diff --git a/credentials/local/local_test.go b/credentials/local/local_test.go index 00ae39f07e5..47f8dbb4ec8 100644 --- a/credentials/local/local_test.go +++ b/credentials/local/local_test.go @@ -131,11 +131,13 @@ func serverHandle(hs serverHandshake, done chan testServerHandleResult, lis net. serverRawConn, err := lis.Accept() if err != nil { done <- testServerHandleResult{authInfo: nil, err: fmt.Errorf("Server failed to accept connection. Error: %v", err)} + return } serverAuthInfo, err := hs(serverRawConn) if err != nil { serverRawConn.Close() done <- testServerHandleResult{authInfo: nil, err: fmt.Errorf("Server failed while handshake. Error: %v", err)} + return } done <- testServerHandleResult{authInfo: serverAuthInfo, err: nil} } diff --git a/credentials/oauth/oauth.go b/credentials/oauth/oauth.go index 852ae375cfc..c748fd21ce2 100644 --- a/credentials/oauth/oauth.go +++ b/credentials/oauth/oauth.go @@ -23,6 +23,7 @@ import ( "context" "fmt" "io/ioutil" + "net/url" "sync" "golang.org/x/oauth2" @@ -56,6 +57,16 @@ func (ts TokenSource) RequireTransportSecurity() bool { return true } +// removeServiceNameFromJWTURI removes RPC service name from URI. 
+func removeServiceNameFromJWTURI(uri string) (string, error) { + parsed, err := url.Parse(uri) + if err != nil { + return "", err + } + parsed.Path = "/" + return parsed.String(), nil +} + type jwtAccess struct { jsonKey []byte } @@ -75,9 +86,15 @@ func NewJWTAccessFromKey(jsonKey []byte) (credentials.PerRPCCredentials, error) } func (j jwtAccess) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) { + // Remove RPC service name from URI that will be used as audience + // in a self-signed JWT token. It follows https://google.aip.dev/auth/4111. + aud, err := removeServiceNameFromJWTURI(uri[0]) + if err != nil { + return nil, err + } // TODO: the returned TokenSource is reusable. Store it in a sync.Map, with // uri as the key, to avoid recreating for every RPC. - ts, err := google.JWTAccessTokenSourceFromJSON(j.jsonKey, uri[0]) + ts, err := google.JWTAccessTokenSourceFromJSON(j.jsonKey, aud) if err != nil { return nil, err } diff --git a/credentials/oauth/oauth_test.go b/credentials/oauth/oauth_test.go new file mode 100644 index 00000000000..7e62ecb36c1 --- /dev/null +++ b/credentials/oauth/oauth_test.go @@ -0,0 +1,60 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +package oauth + +import ( + "strings" + "testing" +) + +func checkErrorMsg(err error, msg string) bool { + if err == nil && msg == "" { + return true + } else if err != nil { + return strings.Contains(err.Error(), msg) + } + return false +} + +func TestRemoveServiceNameFromJWTURI(t *testing.T) { + tests := []struct { + name string + uri string + wantedURI string + wantedErrMsg string + }{ + { + name: "invalid URI", + uri: "ht tp://foo.com", + wantedErrMsg: "first path segment in URL cannot contain colon", + }, + { + name: "valid URI", + uri: "https://foo.com/go/", + wantedURI: "https://foo.com/", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got, err := removeServiceNameFromJWTURI(tt.uri); got != tt.wantedURI || !checkErrorMsg(err, tt.wantedErrMsg) { + t.Errorf("RemoveServiceNameFromJWTURI() = %s, %v, want %s, %v", got, err, tt.wantedURI, tt.wantedErrMsg) + } + }) + } +} diff --git a/credentials/sts/sts.go b/credentials/sts/sts.go index 9285192a8eb..da5fa1ad16f 100644 --- a/credentials/sts/sts.go +++ b/credentials/sts/sts.go @@ -1,5 +1,3 @@ -// +build go1.13 - /* * * Copyright 2020 gRPC authors. diff --git a/credentials/sts/sts_test.go b/credentials/sts/sts_test.go index ac680e00111..dd634361d7c 100644 --- a/credentials/sts/sts_test.go +++ b/credentials/sts/sts_test.go @@ -1,5 +1,3 @@ -// +build go1.13 - /* * * Copyright 2020 gRPC authors. 
@@ -37,7 +35,7 @@ import ( "github.com/google/go-cmp/cmp" "google.golang.org/grpc/credentials" - "google.golang.org/grpc/internal" + icredentials "google.golang.org/grpc/internal/credentials" "google.golang.org/grpc/internal/grpctest" "google.golang.org/grpc/internal/testutils" ) @@ -104,7 +102,7 @@ func createTestContext(ctx context.Context, s credentials.SecurityLevel) context Method: "testInfo", AuthInfo: auth, } - return internal.NewRequestInfoContext.(func(context.Context, credentials.RequestInfo) context.Context)(ctx, ri) + return icredentials.NewRequestInfoContext(ctx, ri) } // errReader implements the io.Reader interface and returns an error from the diff --git a/credentials/tls.go b/credentials/tls.go index 8ee7124f226..784822d0560 100644 --- a/credentials/tls.go +++ b/credentials/tls.go @@ -230,4 +230,7 @@ var cipherSuiteLookup = map[uint16]string{ tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256: "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256", tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305: "TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305", tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305: "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305", + tls.TLS_AES_128_GCM_SHA256: "TLS_AES_128_GCM_SHA256", + tls.TLS_AES_256_GCM_SHA384: "TLS_AES_256_GCM_SHA384", + tls.TLS_CHACHA20_POLY1305_SHA256: "TLS_CHACHA20_POLY1305_SHA256", } diff --git a/credentials/tls/certprovider/distributor_test.go b/credentials/tls/certprovider/distributor_test.go index bec00e919bc..48d51375616 100644 --- a/credentials/tls/certprovider/distributor_test.go +++ b/credentials/tls/certprovider/distributor_test.go @@ -1,5 +1,3 @@ -// +build go1.13 - /* * * Copyright 2020 gRPC authors. diff --git a/credentials/tls/certprovider/meshca/builder.go b/credentials/tls/certprovider/meshca/builder.go deleted file mode 100644 index 4b8af7c9b3c..00000000000 --- a/credentials/tls/certprovider/meshca/builder.go +++ /dev/null @@ -1,165 +0,0 @@ -// +build go1.13 - -/* - * - * Copyright 2020 gRPC authors. 
- * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - */ - -package meshca - -import ( - "crypto/x509" - "encoding/json" - "fmt" - "sync" - - "google.golang.org/grpc" - "google.golang.org/grpc/credentials" - "google.golang.org/grpc/credentials/sts" - "google.golang.org/grpc/credentials/tls/certprovider" - "google.golang.org/grpc/internal/backoff" -) - -const pluginName = "mesh_ca" - -// For overriding in unit tests. -var ( - grpcDialFunc = grpc.Dial - backoffFunc = backoff.DefaultExponential.Backoff -) - -func init() { - certprovider.Register(newPluginBuilder()) -} - -func newPluginBuilder() *pluginBuilder { - return &pluginBuilder{clients: make(map[ccMapKey]*refCountedCC)} -} - -// Key for the map containing ClientConns to the MeshCA server. Only the server -// name and the STS options (which is used to create call creds) from the plugin -// configuration determine if two configs can share the same ClientConn. Hence -// only those form the key to this map. -type ccMapKey struct { - name string - stsOpts sts.Options -} - -// refCountedCC wraps a grpc.ClientConn to MeshCA along with a reference count. -type refCountedCC struct { - cc *grpc.ClientConn - refCnt int -} - -// pluginBuilder is an implementation of the certprovider.Builder interface, -// which builds certificate provider instances to get certificates signed from -// the MeshCA. -type pluginBuilder struct { - // A collection of ClientConns to the MeshCA server along with a reference - // count. 
Provider instances whose config point to the same server name will - // end up sharing the ClientConn. - mu sync.Mutex - clients map[ccMapKey]*refCountedCC -} - -// ParseConfig parses the configuration to be passed to the MeshCA plugin -// implementation. Expects the config to be a json.RawMessage which contains a -// serialized JSON representation of the meshca_experimental.GoogleMeshCaConfig -// proto message. -// -// Takes care of sharing the ClientConn to the MeshCA server among -// different plugin instantiations. -func (b *pluginBuilder) ParseConfig(c interface{}) (*certprovider.BuildableConfig, error) { - data, ok := c.(json.RawMessage) - if !ok { - return nil, fmt.Errorf("meshca: unsupported config type: %T", c) - } - cfg, err := pluginConfigFromJSON(data) - if err != nil { - return nil, err - } - return certprovider.NewBuildableConfig(pluginName, cfg.canonical(), func(opts certprovider.BuildOptions) certprovider.Provider { - return b.buildFromConfig(cfg, opts) - }), nil -} - -// buildFromConfig builds a certificate provider instance for the given config -// and options. Provider instances are shared wherever possible. -func (b *pluginBuilder) buildFromConfig(cfg *pluginConfig, opts certprovider.BuildOptions) certprovider.Provider { - b.mu.Lock() - defer b.mu.Unlock() - - ccmk := ccMapKey{ - name: cfg.serverURI, - stsOpts: cfg.stsOpts, - } - rcc, ok := b.clients[ccmk] - if !ok { - // STS call credentials take care of exchanging a locally provisioned - // JWT token for an access token which will be accepted by the MeshCA. - callCreds, err := sts.NewCredentials(cfg.stsOpts) - if err != nil { - logger.Errorf("sts.NewCredentials() failed: %v", err) - return nil - } - - // MeshCA is a public endpoint whose certificate is Web-PKI compliant. - // So, we just need to use the system roots to authenticate the MeshCA. 
- cp, err := x509.SystemCertPool() - if err != nil { - logger.Errorf("x509.SystemCertPool() failed: %v", err) - return nil - } - transportCreds := credentials.NewClientTLSFromCert(cp, "") - - cc, err := grpcDialFunc(cfg.serverURI, grpc.WithTransportCredentials(transportCreds), grpc.WithPerRPCCredentials(callCreds)) - if err != nil { - logger.Errorf("grpc.Dial(%s) failed: %v", cfg.serverURI, err) - return nil - } - - rcc = &refCountedCC{cc: cc} - b.clients[ccmk] = rcc - } - rcc.refCnt++ - - p := newProviderPlugin(providerParams{ - cc: rcc.cc, - cfg: cfg, - opts: opts, - backoff: backoffFunc, - doneFunc: func() { - // The plugin implementation will invoke this function when it is - // being closed, and here we take care of closing the ClientConn - // when there are no more plugins using it. We need to acquire the - // lock before accessing the rcc from the enclosing function. - b.mu.Lock() - defer b.mu.Unlock() - rcc.refCnt-- - if rcc.refCnt == 0 { - logger.Infof("Closing grpc.ClientConn to %s", ccmk.name) - rcc.cc.Close() - delete(b.clients, ccmk) - } - }, - }) - return p -} - -// Name returns the MeshCA plugin name. -func (b *pluginBuilder) Name() string { - return pluginName -} diff --git a/credentials/tls/certprovider/meshca/builder_test.go b/credentials/tls/certprovider/meshca/builder_test.go deleted file mode 100644 index 79035d008d9..00000000000 --- a/credentials/tls/certprovider/meshca/builder_test.go +++ /dev/null @@ -1,177 +0,0 @@ -// +build go1.13 - -/* - * - * Copyright 2020 gRPC authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - * - */ - -package meshca - -import ( - "context" - "encoding/json" - "fmt" - "testing" - - "google.golang.org/grpc" - "google.golang.org/grpc/connectivity" - "google.golang.org/grpc/credentials/tls/certprovider" - "google.golang.org/grpc/internal/testutils" -) - -func overrideHTTPFuncs() func() { - // Directly override the functions which are used to read the zone and - // audience instead of overriding the http.Client. - origReadZone := readZoneFunc - readZoneFunc = func(httpDoer) string { return "test-zone" } - origReadAudience := readAudienceFunc - readAudienceFunc = func(httpDoer) string { return "test-audience" } - return func() { - readZoneFunc = origReadZone - readAudienceFunc = origReadAudience - } -} - -func (s) TestBuildSameConfig(t *testing.T) { - defer overrideHTTPFuncs()() - - // We will attempt to create `cnt` number of providers. So we create a - // channel of the same size here, even though we expect only one ClientConn - // to be pushed into this channel. This makes sure that even if more than - // one ClientConn ends up being created, the Build() call does not block. - const cnt = 5 - ccChan := testutils.NewChannelWithSize(cnt) - - // Override the dial func to dial a dummy MeshCA endpoint, and also push the - // returned ClientConn on a channel to be inspected by the test. - origDialFunc := grpcDialFunc - grpcDialFunc = func(string, ...grpc.DialOption) (*grpc.ClientConn, error) { - cc, err := grpc.Dial("dummy-meshca-endpoint", grpc.WithInsecure()) - ccChan.Send(cc) - return cc, err - } - defer func() { grpcDialFunc = origDialFunc }() - - // Parse a good config to generate a stable config which will be passed to - // invocations of Build(). 
- builder := newPluginBuilder() - buildableConfig, err := builder.ParseConfig(goodConfigFullySpecified) - if err != nil { - t.Fatalf("builder.ParseConfig(%q) failed: %v", goodConfigFullySpecified, err) - } - - // Create multiple providers with the same config. All these providers must - // end up sharing the same ClientConn. - providers := []certprovider.Provider{} - for i := 0; i < cnt; i++ { - p, err := buildableConfig.Build(certprovider.BuildOptions{}) - if err != nil { - t.Fatalf("Build(%+v) failed: %v", buildableConfig, err) - } - providers = append(providers, p) - } - - // Make sure only one ClientConn is created. - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() - val, err := ccChan.Receive(ctx) - if err != nil { - t.Fatalf("Failed to create ClientConn: %v", err) - } - testCC := val.(*grpc.ClientConn) - - // Attempt to read the second ClientConn should timeout. - ctx, cancel = context.WithTimeout(context.Background(), defaultTestShortTimeout) - defer cancel() - if _, err := ccChan.Receive(ctx); err != context.DeadlineExceeded { - t.Fatal("Builder created more than one ClientConn") - } - - for _, p := range providers { - p.Close() - } - - for { - state := testCC.GetState() - if state == connectivity.Shutdown { - break - } - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() - if !testCC.WaitForStateChange(ctx, state) { - t.Fatalf("timeout waiting for clientConn state to change from %s", state) - } - } -} - -func (s) TestBuildDifferentConfig(t *testing.T) { - defer overrideHTTPFuncs()() - - // We will attempt to create two providers with different configs. So we - // expect two ClientConns to be pushed on to this channel. - const cnt = 2 - ccChan := testutils.NewChannelWithSize(cnt) - - // Override the dial func to dial a dummy MeshCA endpoint, and also push the - // returned ClientConn on a channel to be inspected by the test. 
- origDialFunc := grpcDialFunc - grpcDialFunc = func(string, ...grpc.DialOption) (*grpc.ClientConn, error) { - cc, err := grpc.Dial("dummy-meshca-endpoint", grpc.WithInsecure()) - ccChan.Send(cc) - return cc, err - } - defer func() { grpcDialFunc = origDialFunc }() - - builder := newPluginBuilder() - providers := []certprovider.Provider{} - for i := 0; i < cnt; i++ { - // Copy the good test config and modify the serverURI to make sure that - // a new provider is created for the config. - inputConfig := json.RawMessage(fmt.Sprintf(goodConfigFormatStr, fmt.Sprintf("test-mesh-ca:%d", i))) - buildableConfig, err := builder.ParseConfig(inputConfig) - if err != nil { - t.Fatalf("builder.ParseConfig(%q) failed: %v", inputConfig, err) - } - - p, err := buildableConfig.Build(certprovider.BuildOptions{}) - if err != nil { - t.Fatalf("Build(%+v) failed: %v", buildableConfig, err) - } - providers = append(providers, p) - } - - // Make sure two ClientConns are created. - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() - for i := 0; i < cnt; i++ { - if _, err := ccChan.Receive(ctx); err != nil { - t.Fatalf("Failed to create ClientConn: %v", err) - } - } - - // Close the first provider, and attempt to read key material from the - // second provider. The call to read key material should timeout, but it - // should not return certprovider.errProviderClosed. - providers[0].Close() - ctx, cancel = context.WithTimeout(context.Background(), defaultTestShortTimeout) - defer cancel() - if _, err := providers[1].KeyMaterial(ctx); err != context.DeadlineExceeded { - t.Fatalf("provider.KeyMaterial(ctx) = %v, want contextDeadlineExceeded", err) - } - - // Close the second provider to make sure that the leakchecker is happy. 
- providers[1].Close() -} diff --git a/credentials/tls/certprovider/meshca/config.go b/credentials/tls/certprovider/meshca/config.go deleted file mode 100644 index c0772b3bb7e..00000000000 --- a/credentials/tls/certprovider/meshca/config.go +++ /dev/null @@ -1,310 +0,0 @@ -// +build go1.13 - -/* - * - * Copyright 2020 gRPC authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - */ - -package meshca - -import ( - "encoding/json" - "errors" - "fmt" - "io/ioutil" - "net/http" - "net/http/httputil" - "path" - "strings" - "time" - - v3corepb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" - "github.com/golang/protobuf/ptypes" - "google.golang.org/protobuf/encoding/protojson" - "google.golang.org/protobuf/types/known/durationpb" - - "google.golang.org/grpc/credentials/sts" -) - -const ( - // GKE metadata server endpoint. - mdsBaseURI = "http://metadata.google.internal/" - mdsRequestTimeout = 5 * time.Second - - // The following are default values used in the interaction with MeshCA. - defaultMeshCaEndpoint = "meshca.googleapis.com" - defaultCallTimeout = 10 * time.Second - defaultCertLifetimeSecs = 86400 // 24h in seconds - defaultCertGraceTimeSecs = 43200 // 12h in seconds - defaultKeyTypeRSA = "RSA" - defaultKeySize = 2048 - - // The following are default values used in the interaction with STS or - // Secure Token Service, which is used to exchange the JWT token for an - // access token. 
- defaultSTSEndpoint = "securetoken.googleapis.com" - defaultCloudPlatformScope = "https://www.googleapis.com/auth/cloud-platform" - defaultRequestedTokenType = "urn:ietf:params:oauth:token-type:access_token" - defaultSubjectTokenType = "urn:ietf:params:oauth:token-type:jwt" -) - -// For overriding in unit tests. -var ( - makeHTTPDoer = makeHTTPClient - readZoneFunc = readZone - readAudienceFunc = readAudience -) - -type pluginConfig struct { - serverURI string - stsOpts sts.Options - callTimeout time.Duration - certLifetime time.Duration - certGraceTime time.Duration - keyType string - keySize int - location string -} - -// Type of key to be embedded in CSRs sent to the MeshCA. -const ( - keyTypeUnknown = 0 - keyTypeRSA = 1 -) - -// pluginConfigFromJSON parses the provided config in JSON. -// -// For certain values missing in the config, we use default values defined at -// the top of this file. -// -// If the location field or STS audience field is missing, we try talking to the -// GKE Metadata server and try to infer these values. If this attempt does not -// succeed, we let those fields have empty values. -func pluginConfigFromJSON(data json.RawMessage) (*pluginConfig, error) { - // This anonymous struct corresponds to the expected JSON config. 
- cfgJSON := &struct { - Server json.RawMessage `json:"server,omitempty"` // Expect a v3corepb.ApiConfigSource - CertificateLifetime json.RawMessage `json:"certificate_lifetime,omitempty"` // Expect a durationpb.Duration - RenewalGracePeriod json.RawMessage `json:"renewal_grace_period,omitempty"` // Expect a durationpb.Duration - KeyType int `json:"key_type,omitempty"` - KeySize int `json:"key_size,omitempty"` - Location string `json:"location,omitempty"` - }{} - if err := json.Unmarshal(data, cfgJSON); err != nil { - return nil, fmt.Errorf("meshca: failed to unmarshal config: %v", err) - } - - // Further unmarshal fields represented as json.RawMessage in the above - // anonymous struct, and use default values if not specified. - server := &v3corepb.ApiConfigSource{} - if cfgJSON.Server != nil { - if err := protojson.Unmarshal(cfgJSON.Server, server); err != nil { - return nil, fmt.Errorf("meshca: protojson.Unmarshal(%+v) failed: %v", cfgJSON.Server, err) - } - } - certLifetime := &durationpb.Duration{Seconds: defaultCertLifetimeSecs} - if cfgJSON.CertificateLifetime != nil { - if err := protojson.Unmarshal(cfgJSON.CertificateLifetime, certLifetime); err != nil { - return nil, fmt.Errorf("meshca: protojson.Unmarshal(%+v) failed: %v", cfgJSON.CertificateLifetime, err) - } - } - certGraceTime := &durationpb.Duration{Seconds: defaultCertGraceTimeSecs} - if cfgJSON.RenewalGracePeriod != nil { - if err := protojson.Unmarshal(cfgJSON.RenewalGracePeriod, certGraceTime); err != nil { - return nil, fmt.Errorf("meshca: protojson.Unmarshal(%+v) failed: %v", cfgJSON.RenewalGracePeriod, err) - } - } - - if api := server.GetApiType(); api != v3corepb.ApiConfigSource_GRPC { - return nil, fmt.Errorf("meshca: server has apiType %s, want %s", api, v3corepb.ApiConfigSource_GRPC) - } - - pc := &pluginConfig{ - certLifetime: certLifetime.AsDuration(), - certGraceTime: certGraceTime.AsDuration(), - } - gs := server.GetGrpcServices() - if l := len(gs); l != 1 { - return nil, 
fmt.Errorf("meshca: number of gRPC services in config is %d, expected 1", l) - } - grpcService := gs[0] - googGRPC := grpcService.GetGoogleGrpc() - if googGRPC == nil { - return nil, errors.New("meshca: missing google gRPC service in config") - } - pc.serverURI = googGRPC.GetTargetUri() - if pc.serverURI == "" { - pc.serverURI = defaultMeshCaEndpoint - } - - callCreds := googGRPC.GetCallCredentials() - if len(callCreds) == 0 { - return nil, errors.New("meshca: missing call credentials in config") - } - var stsCallCreds *v3corepb.GrpcService_GoogleGrpc_CallCredentials_StsService - for _, cc := range callCreds { - if stsCallCreds = cc.GetStsService(); stsCallCreds != nil { - break - } - } - if stsCallCreds == nil { - return nil, errors.New("meshca: missing STS call credentials in config") - } - if stsCallCreds.GetSubjectTokenPath() == "" { - return nil, errors.New("meshca: missing subjectTokenPath in STS call credentials config") - } - pc.stsOpts = makeStsOptsWithDefaults(stsCallCreds) - - var err error - if pc.callTimeout, err = ptypes.Duration(grpcService.GetTimeout()); err != nil { - pc.callTimeout = defaultCallTimeout - } - switch cfgJSON.KeyType { - case keyTypeUnknown, keyTypeRSA: - pc.keyType = defaultKeyTypeRSA - default: - return nil, fmt.Errorf("meshca: unsupported key type: %s, only support RSA keys", pc.keyType) - } - pc.keySize = cfgJSON.KeySize - if pc.keySize == 0 { - pc.keySize = defaultKeySize - } - pc.location = cfgJSON.Location - if pc.location == "" { - pc.location = readZoneFunc(makeHTTPDoer()) - } - - return pc, nil -} - -func (pc *pluginConfig) canonical() []byte { - return []byte(fmt.Sprintf("%s:%s:%s:%s:%s:%s:%d:%s", pc.serverURI, pc.stsOpts, pc.callTimeout, pc.certLifetime, pc.certGraceTime, pc.keyType, pc.keySize, pc.location)) -} - -func makeStsOptsWithDefaults(stsCallCreds *v3corepb.GrpcService_GoogleGrpc_CallCredentials_StsService) sts.Options { - opts := sts.Options{ - TokenExchangeServiceURI: stsCallCreds.GetTokenExchangeServiceUri(), 
- Resource: stsCallCreds.GetResource(), - Audience: stsCallCreds.GetAudience(), - Scope: stsCallCreds.GetScope(), - RequestedTokenType: stsCallCreds.GetRequestedTokenType(), - SubjectTokenPath: stsCallCreds.GetSubjectTokenPath(), - SubjectTokenType: stsCallCreds.GetSubjectTokenType(), - ActorTokenPath: stsCallCreds.GetActorTokenPath(), - ActorTokenType: stsCallCreds.GetActorTokenType(), - } - - // Use sane defaults for unspecified fields. - if opts.TokenExchangeServiceURI == "" { - opts.TokenExchangeServiceURI = defaultSTSEndpoint - } - if opts.Audience == "" { - opts.Audience = readAudienceFunc(makeHTTPDoer()) - } - if opts.Scope == "" { - opts.Scope = defaultCloudPlatformScope - } - if opts.RequestedTokenType == "" { - opts.RequestedTokenType = defaultRequestedTokenType - } - if opts.SubjectTokenType == "" { - opts.SubjectTokenType = defaultSubjectTokenType - } - return opts -} - -// httpDoer wraps the single method on the http.Client type that we use. This -// helps with overriding in unit tests. 
-type httpDoer interface { - Do(req *http.Request) (*http.Response, error) -} - -func makeHTTPClient() httpDoer { - return &http.Client{Timeout: mdsRequestTimeout} -} - -func readMetadata(client httpDoer, uriPath string) (string, error) { - req, err := http.NewRequest("GET", mdsBaseURI+uriPath, nil) - if err != nil { - return "", err - } - req.Header.Add("Metadata-Flavor", "Google") - - resp, err := client.Do(req) - if err != nil { - return "", err - } - defer resp.Body.Close() - body, err := ioutil.ReadAll(resp.Body) - if err != nil { - return "", err - } - if resp.StatusCode != http.StatusOK { - dump, err := httputil.DumpRequestOut(req, false) - if err != nil { - logger.Warningf("Failed to dump HTTP request: %v", err) - } - logger.Warningf("Request %q returned status %v", dump, resp.StatusCode) - } - return string(body), err -} - -func readZone(client httpDoer) string { - zoneURI := "computeMetadata/v1/instance/zone" - data, err := readMetadata(client, zoneURI) - if err != nil { - logger.Warningf("GET %s failed: %v", path.Join(mdsBaseURI, zoneURI), err) - return "" - } - - // The output returned by the metadata server looks like this: - // projects//zones/ - parts := strings.Split(data, "/") - if len(parts) == 0 { - logger.Warningf("GET %s returned {%s}, does not match expected format {projects//zones/}", path.Join(mdsBaseURI, zoneURI)) - return "" - } - return parts[len(parts)-1] -} - -// readAudience constructs the audience field to be used in the STS request, if -// it is not specified in the plugin configuration. -// -// "identitynamespace:{TRUST_DOMAIN}:{GKE_CLUSTER_URL}" is the format of the -// audience field. When workload identity is enabled on a GCP project, a default -// trust domain is created whose value is "{PROJECT_ID}.svc.id.goog". The format -// of the GKE_CLUSTER_URL is: -// https://container.googleapis.com/v1/projects/{PROJECT_ID}/zones/{ZONE}/clusters/{CLUSTER_NAME}. 
-func readAudience(client httpDoer) string { - projURI := "computeMetadata/v1/project/project-id" - project, err := readMetadata(client, projURI) - if err != nil { - logger.Warningf("GET %s failed: %v", path.Join(mdsBaseURI, projURI), err) - return "" - } - trustDomain := fmt.Sprintf("%s.svc.id.goog", project) - - clusterURI := "computeMetadata/v1/instance/attributes/cluster-name" - cluster, err := readMetadata(client, clusterURI) - if err != nil { - logger.Warningf("GET %s failed: %v", path.Join(mdsBaseURI, clusterURI), err) - return "" - } - zone := readZoneFunc(client) - clusterURL := fmt.Sprintf("https://container.googleapis.com/v1/projects/%s/zones/%s/clusters/%s", project, zone, cluster) - audience := fmt.Sprintf("identitynamespace:%s:%s", trustDomain, clusterURL) - return audience -} diff --git a/credentials/tls/certprovider/meshca/config_test.go b/credentials/tls/certprovider/meshca/config_test.go deleted file mode 100644 index 5deb484f341..00000000000 --- a/credentials/tls/certprovider/meshca/config_test.go +++ /dev/null @@ -1,375 +0,0 @@ -// +build go1.13 - -/* - * - * Copyright 2020 gRPC authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- * - */ - -package meshca - -import ( - "bytes" - "context" - "encoding/json" - "fmt" - "io/ioutil" - "net/http" - "strings" - "testing" - - "github.com/google/go-cmp/cmp" - - "google.golang.org/grpc/internal/grpctest" - "google.golang.org/grpc/internal/testutils" -) - -const ( - testProjectID = "test-project-id" - testGKECluster = "test-gke-cluster" - testGCEZone = "test-zone" -) - -type s struct { - grpctest.Tester -} - -func Test(t *testing.T) { - grpctest.RunSubTests(t, s{}) -} - -var ( - goodConfigFormatStr = ` - { - "server": { - "api_type": 2, - "grpc_services": [ - { - "googleGrpc": { - "target_uri": %q, - "call_credentials": [ - { - "access_token": "foo" - }, - { - "sts_service": { - "token_exchange_service_uri": "http://test-sts", - "resource": "test-resource", - "audience": "test-audience", - "scope": "test-scope", - "requested_token_type": "test-requested-token-type", - "subject_token_path": "test-subject-token-path", - "subject_token_type": "test-subject-token-type", - "actor_token_path": "test-actor-token-path", - "actor_token_type": "test-actor-token-type" - } - } - ] - }, - "timeout": "10s" - } - ] - }, - "certificate_lifetime": "86400s", - "renewal_grace_period": "43200s", - "key_type": 1, - "key_size": 2048, - "location": "us-west1-b" - }` - goodConfigWithDefaults = json.RawMessage(` - { - "server": { - "api_type": 2, - "grpc_services": [ - { - "googleGrpc": { - "call_credentials": [ - { - "sts_service": { - "subject_token_path": "test-subject-token-path" - } - } - ] - }, - "timeout": "10s" - } - ] - } - }`) -) - -var goodConfigFullySpecified = json.RawMessage(fmt.Sprintf(goodConfigFormatStr, "test-meshca")) - -// verifyReceivedRequest reads the HTTP request received by the fake client -// (exposed through a channel), and verifies that it matches the expected -// request. 
-func verifyReceivedRequest(fc *testutils.FakeHTTPClient, wantURI string) error { - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() - val, err := fc.ReqChan.Receive(ctx) - if err != nil { - return err - } - gotReq := val.(*http.Request) - if gotURI := gotReq.URL.String(); gotURI != wantURI { - return fmt.Errorf("request contains URL %q want %q", gotURI, wantURI) - } - if got, want := gotReq.Header.Get("Metadata-Flavor"), "Google"; got != want { - return fmt.Errorf("request contains flavor %q want %q", got, want) - } - return nil -} - -// TestParseConfigSuccessFullySpecified tests the case where the config is fully -// specified and no defaults are required. -func (s) TestParseConfigSuccessFullySpecified(t *testing.T) { - wantConfig := "test-meshca:http://test-sts:test-resource:test-audience:test-scope:test-requested-token-type:test-subject-token-path:test-subject-token-type:test-actor-token-path:test-actor-token-type:10s:24h0m0s:12h0m0s:RSA:2048:us-west1-b" - - cfg, err := pluginConfigFromJSON(goodConfigFullySpecified) - if err != nil { - t.Fatalf("pluginConfigFromJSON(%q) failed: %v", goodConfigFullySpecified, err) - } - gotConfig := cfg.canonical() - if diff := cmp.Diff(wantConfig, string(gotConfig)); diff != "" { - t.Errorf("pluginConfigFromJSON(%q) returned config does not match expected (-want +got):\n%s", string(goodConfigFullySpecified), diff) - } -} - -// TestParseConfigSuccessWithDefaults tests cases where the config is not fully -// specified, and we end up using some sane defaults. -func (s) TestParseConfigSuccessWithDefaults(t *testing.T) { - wantConfig := fmt.Sprintf("%s:%s:%s:%s:%s:%s:%s:%s:%s:%s:%s:%s:%s:%s:%s:%s", - "meshca.googleapis.com", // Mesh CA Server URI. - "securetoken.googleapis.com", // STS Server URI. - "", // STS Resource Name. 
- "identitynamespace:test-project-id.svc.id.goog:https://container.googleapis.com/v1/projects/test-project-id/zones/test-zone/clusters/test-gke-cluster", // STS Audience. - "https://www.googleapis.com/auth/cloud-platform", // STS Scope. - "urn:ietf:params:oauth:token-type:access_token", // STS requested token type. - "test-subject-token-path", // STS subject token path. - "urn:ietf:params:oauth:token-type:jwt", // STS subject token type. - "", // STS actor token path. - "", // STS actor token type. - "10s", // Call timeout. - "24h0m0s", // Cert life time. - "12h0m0s", // Cert grace time. - "RSA", // Key type - "2048", // Key size - "test-zone", // Zone - ) - - // We expect the config parser to make four HTTP requests and receive four - // responses. Hence we setup the request and response channels in the fake - // client with appropriate buffer size. - fc := &testutils.FakeHTTPClient{ - ReqChan: testutils.NewChannelWithSize(4), - RespChan: testutils.NewChannelWithSize(4), - } - // Set up the responses to be delivered to the config parser by the fake - // client. The config parser expects responses with project_id, - // gke_cluster_id and gce_zone. The zone is read twice, once as part of - // reading the STS audience and once to get location metadata. 
- fc.RespChan.Send(&http.Response{ - Status: "200 OK", - StatusCode: http.StatusOK, - Body: ioutil.NopCloser(bytes.NewReader([]byte(testProjectID))), - }) - fc.RespChan.Send(&http.Response{ - Status: "200 OK", - StatusCode: http.StatusOK, - Body: ioutil.NopCloser(bytes.NewReader([]byte(testGKECluster))), - }) - fc.RespChan.Send(&http.Response{ - Status: "200 OK", - StatusCode: http.StatusOK, - Body: ioutil.NopCloser(bytes.NewReader([]byte(fmt.Sprintf("projects/%s/zones/%s", testProjectID, testGCEZone)))), - }) - fc.RespChan.Send(&http.Response{ - Status: "200 OK", - StatusCode: http.StatusOK, - Body: ioutil.NopCloser(bytes.NewReader([]byte(fmt.Sprintf("projects/%s/zones/%s", testProjectID, testGCEZone)))), - }) - // Override the http.Client with our fakeClient. - origMakeHTTPDoer := makeHTTPDoer - makeHTTPDoer = func() httpDoer { return fc } - defer func() { makeHTTPDoer = origMakeHTTPDoer }() - - // Spawn a goroutine to verify the HTTP requests sent out as part of the - // config parsing. 
- errCh := make(chan error, 1) - go func() { - if err := verifyReceivedRequest(fc, "http://metadata.google.internal/computeMetadata/v1/project/project-id"); err != nil { - errCh <- err - return - } - if err := verifyReceivedRequest(fc, "http://metadata.google.internal/computeMetadata/v1/instance/attributes/cluster-name"); err != nil { - errCh <- err - return - } - if err := verifyReceivedRequest(fc, "http://metadata.google.internal/computeMetadata/v1/instance/zone"); err != nil { - errCh <- err - return - } - errCh <- nil - }() - - cfg, err := pluginConfigFromJSON(goodConfigWithDefaults) - if err != nil { - t.Fatalf("pluginConfigFromJSON(%q) failed: %v", goodConfigWithDefaults, err) - } - gotConfig := cfg.canonical() - if diff := cmp.Diff(wantConfig, string(gotConfig)); diff != "" { - t.Errorf("builder.ParseConfig(%q) returned config does not match expected (-want +got):\n%s", goodConfigWithDefaults, diff) - } - - if err := <-errCh; err != nil { - t.Fatal(err) - } -} - -// TestParseConfigFailureCases tests several invalid configs which all result in -// config parsing failures. 
-func (s) TestParseConfigFailureCases(t *testing.T) { - tests := []struct { - desc string - inputConfig json.RawMessage - wantErr string - }{ - { - desc: "invalid JSON", - inputConfig: json.RawMessage(`bad bad json`), - wantErr: "failed to unmarshal config", - }, - { - desc: "bad apiType", - inputConfig: json.RawMessage(` - { - "server": { - "api_type": 1 - } - }`), - wantErr: "server has apiType REST, want GRPC", - }, - { - desc: "no grpc services", - inputConfig: json.RawMessage(` - { - "server": { - "api_type": 2 - } - }`), - wantErr: "number of gRPC services in config is 0, expected 1", - }, - { - desc: "too many grpc services", - inputConfig: json.RawMessage(` - { - "server": { - "api_type": 2, - "grpc_services": [{}, {}] - } - }`), - wantErr: "number of gRPC services in config is 2, expected 1", - }, - { - desc: "missing google grpc service", - inputConfig: json.RawMessage(` - { - "server": { - "api_type": 2, - "grpc_services": [ - { - "envoyGrpc": {} - } - ] - } - }`), - wantErr: "missing google gRPC service in config", - }, - { - desc: "missing call credentials", - inputConfig: json.RawMessage(` - { - "server": { - "api_type": 2, - "grpc_services": [ - { - "googleGrpc": { - "target_uri": "foo" - } - } - ] - } - }`), - wantErr: "missing call credentials in config", - }, - { - desc: "missing STS call credentials", - inputConfig: json.RawMessage(` - { - "server": { - "api_type": 2, - "grpc_services": [ - { - "googleGrpc": { - "target_uri": "foo", - "call_credentials": [ - { - "access_token": "foo" - } - ] - } - } - ] - } - }`), - wantErr: "missing STS call credentials in config", - }, - { - desc: "with no defaults", - inputConfig: json.RawMessage(` - { - "server": { - "api_type": 2, - "grpc_services": [ - { - "googleGrpc": { - "target_uri": "foo", - "call_credentials": [ - { - "sts_service": {} - } - ] - } - } - ] - } - }`), - wantErr: "missing subjectTokenPath in STS call credentials config", - }, - } - - for _, test := range tests { - t.Run(test.desc, func(t 
*testing.T) { - cfg, err := pluginConfigFromJSON(test.inputConfig) - if err == nil { - t.Fatalf("pluginConfigFromJSON(%q) = %v, expected to return error (%v)", test.inputConfig, string(cfg.canonical()), test.wantErr) - - } - if !strings.Contains(err.Error(), test.wantErr) { - t.Fatalf("builder.ParseConfig(%q) = (%v), want error (%v)", test.inputConfig, err, test.wantErr) - } - }) - } -} diff --git a/credentials/tls/certprovider/meshca/internal/v1/meshca.pb.go b/credentials/tls/certprovider/meshca/internal/v1/meshca.pb.go deleted file mode 100644 index 387f8c55abc..00000000000 --- a/credentials/tls/certprovider/meshca/internal/v1/meshca.pb.go +++ /dev/null @@ -1,276 +0,0 @@ -// Copyright 2019 Istio Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Code generated by protoc-gen-go. DO NOT EDIT. -// versions: -// protoc-gen-go v1.25.0 -// protoc v3.14.0 -// source: istio/google/security/meshca/v1/meshca.proto - -package google_security_meshca_v1 - -import ( - proto "github.com/golang/protobuf/proto" - protoreflect "google.golang.org/protobuf/reflect/protoreflect" - protoimpl "google.golang.org/protobuf/runtime/protoimpl" - durationpb "google.golang.org/protobuf/types/known/durationpb" - reflect "reflect" - sync "sync" -) - -const ( - // Verify that this generated code is sufficiently up-to-date. - _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) - // Verify that runtime/protoimpl is sufficiently up-to-date. 
- _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) -) - -// This is a compile-time assertion that a sufficiently up-to-date version -// of the legacy proto package is being used. -const _ = proto.ProtoPackageIsVersion4 - -// Certificate request message. -type MeshCertificateRequest struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache - unknownFields protoimpl.UnknownFields - - // The request ID must be a valid UUID with the exception that zero UUID is - // not supported (00000000-0000-0000-0000-000000000000). - RequestId string `protobuf:"bytes,1,opt,name=request_id,json=requestId,proto3" json:"request_id,omitempty"` - // PEM-encoded certificate request. - Csr string `protobuf:"bytes,2,opt,name=csr,proto3" json:"csr,omitempty"` - // Optional: requested certificate validity period. - Validity *durationpb.Duration `protobuf:"bytes,3,opt,name=validity,proto3" json:"validity,omitempty"` // Reserved 4 -} - -func (x *MeshCertificateRequest) Reset() { - *x = MeshCertificateRequest{} - if protoimpl.UnsafeEnabled { - mi := &file_istio_google_security_meshca_v1_meshca_proto_msgTypes[0] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } -} - -func (x *MeshCertificateRequest) String() string { - return protoimpl.X.MessageStringOf(x) -} - -func (*MeshCertificateRequest) ProtoMessage() {} - -func (x *MeshCertificateRequest) ProtoReflect() protoreflect.Message { - mi := &file_istio_google_security_meshca_v1_meshca_proto_msgTypes[0] - if protoimpl.UnsafeEnabled && x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - ms.StoreMessageInfo(mi) - } - return ms - } - return mi.MessageOf(x) -} - -// Deprecated: Use MeshCertificateRequest.ProtoReflect.Descriptor instead. 
-func (*MeshCertificateRequest) Descriptor() ([]byte, []int) { - return file_istio_google_security_meshca_v1_meshca_proto_rawDescGZIP(), []int{0} -} - -func (x *MeshCertificateRequest) GetRequestId() string { - if x != nil { - return x.RequestId - } - return "" -} - -func (x *MeshCertificateRequest) GetCsr() string { - if x != nil { - return x.Csr - } - return "" -} - -func (x *MeshCertificateRequest) GetValidity() *durationpb.Duration { - if x != nil { - return x.Validity - } - return nil -} - -// Certificate response message. -type MeshCertificateResponse struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache - unknownFields protoimpl.UnknownFields - - // PEM-encoded certificate chain. - // Leaf cert is element '0'. Root cert is element 'n'. - CertChain []string `protobuf:"bytes,1,rep,name=cert_chain,json=certChain,proto3" json:"cert_chain,omitempty"` -} - -func (x *MeshCertificateResponse) Reset() { - *x = MeshCertificateResponse{} - if protoimpl.UnsafeEnabled { - mi := &file_istio_google_security_meshca_v1_meshca_proto_msgTypes[1] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } -} - -func (x *MeshCertificateResponse) String() string { - return protoimpl.X.MessageStringOf(x) -} - -func (*MeshCertificateResponse) ProtoMessage() {} - -func (x *MeshCertificateResponse) ProtoReflect() protoreflect.Message { - mi := &file_istio_google_security_meshca_v1_meshca_proto_msgTypes[1] - if protoimpl.UnsafeEnabled && x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - ms.StoreMessageInfo(mi) - } - return ms - } - return mi.MessageOf(x) -} - -// Deprecated: Use MeshCertificateResponse.ProtoReflect.Descriptor instead. 
-func (*MeshCertificateResponse) Descriptor() ([]byte, []int) { - return file_istio_google_security_meshca_v1_meshca_proto_rawDescGZIP(), []int{1} -} - -func (x *MeshCertificateResponse) GetCertChain() []string { - if x != nil { - return x.CertChain - } - return nil -} - -var File_istio_google_security_meshca_v1_meshca_proto protoreflect.FileDescriptor - -var file_istio_google_security_meshca_v1_meshca_proto_rawDesc = []byte{ - 0x0a, 0x2c, 0x69, 0x73, 0x74, 0x69, 0x6f, 0x2f, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x73, - 0x65, 0x63, 0x75, 0x72, 0x69, 0x74, 0x79, 0x2f, 0x6d, 0x65, 0x73, 0x68, 0x63, 0x61, 0x2f, 0x76, - 0x31, 0x2f, 0x6d, 0x65, 0x73, 0x68, 0x63, 0x61, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x19, - 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x73, 0x65, 0x63, 0x75, 0x72, 0x69, 0x74, 0x79, 0x2e, - 0x6d, 0x65, 0x73, 0x68, 0x63, 0x61, 0x2e, 0x76, 0x31, 0x1a, 0x1e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, - 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x64, 0x75, 0x72, 0x61, 0x74, - 0x69, 0x6f, 0x6e, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0x80, 0x01, 0x0a, 0x16, 0x4d, 0x65, - 0x73, 0x68, 0x43, 0x65, 0x72, 0x74, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, - 0x75, 0x65, 0x73, 0x74, 0x12, 0x1d, 0x0a, 0x0a, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x5f, - 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, - 0x74, 0x49, 0x64, 0x12, 0x10, 0x0a, 0x03, 0x63, 0x73, 0x72, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x03, 0x63, 0x73, 0x72, 0x12, 0x35, 0x0a, 0x08, 0x76, 0x61, 0x6c, 0x69, 0x64, 0x69, 0x74, - 0x79, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x19, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, - 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x44, 0x75, 0x72, 0x61, 0x74, 0x69, - 0x6f, 0x6e, 0x52, 0x08, 0x76, 0x61, 0x6c, 0x69, 0x64, 0x69, 0x74, 0x79, 0x22, 0x38, 0x0a, 0x17, - 0x4d, 0x65, 0x73, 0x68, 0x43, 0x65, 0x72, 0x74, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x65, 0x52, - 
0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x1d, 0x0a, 0x0a, 0x63, 0x65, 0x72, 0x74, 0x5f, - 0x63, 0x68, 0x61, 0x69, 0x6e, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x09, 0x63, 0x65, 0x72, - 0x74, 0x43, 0x68, 0x61, 0x69, 0x6e, 0x32, 0x96, 0x01, 0x0a, 0x16, 0x4d, 0x65, 0x73, 0x68, 0x43, - 0x65, 0x72, 0x74, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x65, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, - 0x65, 0x12, 0x7c, 0x0a, 0x11, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x43, 0x65, 0x72, 0x74, 0x69, - 0x66, 0x69, 0x63, 0x61, 0x74, 0x65, 0x12, 0x31, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, - 0x73, 0x65, 0x63, 0x75, 0x72, 0x69, 0x74, 0x79, 0x2e, 0x6d, 0x65, 0x73, 0x68, 0x63, 0x61, 0x2e, - 0x76, 0x31, 0x2e, 0x4d, 0x65, 0x73, 0x68, 0x43, 0x65, 0x72, 0x74, 0x69, 0x66, 0x69, 0x63, 0x61, - 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x32, 0x2e, 0x67, 0x6f, 0x6f, 0x67, - 0x6c, 0x65, 0x2e, 0x73, 0x65, 0x63, 0x75, 0x72, 0x69, 0x74, 0x79, 0x2e, 0x6d, 0x65, 0x73, 0x68, - 0x63, 0x61, 0x2e, 0x76, 0x31, 0x2e, 0x4d, 0x65, 0x73, 0x68, 0x43, 0x65, 0x72, 0x74, 0x69, 0x66, - 0x69, 0x63, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x42, - 0x2e, 0x0a, 0x1d, 0x63, 0x6f, 0x6d, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x73, 0x65, - 0x63, 0x75, 0x72, 0x69, 0x74, 0x79, 0x2e, 0x6d, 0x65, 0x73, 0x68, 0x63, 0x61, 0x2e, 0x76, 0x31, - 0x42, 0x0b, 0x4d, 0x65, 0x73, 0x68, 0x43, 0x61, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x50, 0x01, 0x62, - 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, -} - -var ( - file_istio_google_security_meshca_v1_meshca_proto_rawDescOnce sync.Once - file_istio_google_security_meshca_v1_meshca_proto_rawDescData = file_istio_google_security_meshca_v1_meshca_proto_rawDesc -) - -func file_istio_google_security_meshca_v1_meshca_proto_rawDescGZIP() []byte { - file_istio_google_security_meshca_v1_meshca_proto_rawDescOnce.Do(func() { - file_istio_google_security_meshca_v1_meshca_proto_rawDescData = 
protoimpl.X.CompressGZIP(file_istio_google_security_meshca_v1_meshca_proto_rawDescData) - }) - return file_istio_google_security_meshca_v1_meshca_proto_rawDescData -} - -var file_istio_google_security_meshca_v1_meshca_proto_msgTypes = make([]protoimpl.MessageInfo, 2) -var file_istio_google_security_meshca_v1_meshca_proto_goTypes = []interface{}{ - (*MeshCertificateRequest)(nil), // 0: google.security.meshca.v1.MeshCertificateRequest - (*MeshCertificateResponse)(nil), // 1: google.security.meshca.v1.MeshCertificateResponse - (*durationpb.Duration)(nil), // 2: google.protobuf.Duration -} -var file_istio_google_security_meshca_v1_meshca_proto_depIdxs = []int32{ - 2, // 0: google.security.meshca.v1.MeshCertificateRequest.validity:type_name -> google.protobuf.Duration - 0, // 1: google.security.meshca.v1.MeshCertificateService.CreateCertificate:input_type -> google.security.meshca.v1.MeshCertificateRequest - 1, // 2: google.security.meshca.v1.MeshCertificateService.CreateCertificate:output_type -> google.security.meshca.v1.MeshCertificateResponse - 2, // [2:3] is the sub-list for method output_type - 1, // [1:2] is the sub-list for method input_type - 1, // [1:1] is the sub-list for extension type_name - 1, // [1:1] is the sub-list for extension extendee - 0, // [0:1] is the sub-list for field type_name -} - -func init() { file_istio_google_security_meshca_v1_meshca_proto_init() } -func file_istio_google_security_meshca_v1_meshca_proto_init() { - if File_istio_google_security_meshca_v1_meshca_proto != nil { - return - } - if !protoimpl.UnsafeEnabled { - file_istio_google_security_meshca_v1_meshca_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*MeshCertificateRequest); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_istio_google_security_meshca_v1_meshca_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { - switch v := 
v.(*MeshCertificateResponse); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - } - type x struct{} - out := protoimpl.TypeBuilder{ - File: protoimpl.DescBuilder{ - GoPackagePath: reflect.TypeOf(x{}).PkgPath(), - RawDescriptor: file_istio_google_security_meshca_v1_meshca_proto_rawDesc, - NumEnums: 0, - NumMessages: 2, - NumExtensions: 0, - NumServices: 1, - }, - GoTypes: file_istio_google_security_meshca_v1_meshca_proto_goTypes, - DependencyIndexes: file_istio_google_security_meshca_v1_meshca_proto_depIdxs, - MessageInfos: file_istio_google_security_meshca_v1_meshca_proto_msgTypes, - }.Build() - File_istio_google_security_meshca_v1_meshca_proto = out.File - file_istio_google_security_meshca_v1_meshca_proto_rawDesc = nil - file_istio_google_security_meshca_v1_meshca_proto_goTypes = nil - file_istio_google_security_meshca_v1_meshca_proto_depIdxs = nil -} diff --git a/credentials/tls/certprovider/meshca/internal/v1/meshca_grpc.pb.go b/credentials/tls/certprovider/meshca/internal/v1/meshca_grpc.pb.go deleted file mode 100644 index e53a61598ab..00000000000 --- a/credentials/tls/certprovider/meshca/internal/v1/meshca_grpc.pb.go +++ /dev/null @@ -1,106 +0,0 @@ -// Code generated by protoc-gen-go-grpc. DO NOT EDIT. - -package google_security_meshca_v1 - -import ( - context "context" - grpc "google.golang.org/grpc" - codes "google.golang.org/grpc/codes" - status "google.golang.org/grpc/status" -) - -// This is a compile-time assertion to ensure that this generated file -// is compatible with the grpc package it is being compiled against. -// Requires gRPC-Go v1.32.0 or later. -const _ = grpc.SupportPackageIsVersion7 - -// MeshCertificateServiceClient is the client API for MeshCertificateService service. -// -// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. 
-type MeshCertificateServiceClient interface { - // Using provided CSR, returns a signed certificate that represents a GCP - // service account identity. - CreateCertificate(ctx context.Context, in *MeshCertificateRequest, opts ...grpc.CallOption) (*MeshCertificateResponse, error) -} - -type meshCertificateServiceClient struct { - cc grpc.ClientConnInterface -} - -func NewMeshCertificateServiceClient(cc grpc.ClientConnInterface) MeshCertificateServiceClient { - return &meshCertificateServiceClient{cc} -} - -func (c *meshCertificateServiceClient) CreateCertificate(ctx context.Context, in *MeshCertificateRequest, opts ...grpc.CallOption) (*MeshCertificateResponse, error) { - out := new(MeshCertificateResponse) - err := c.cc.Invoke(ctx, "/google.security.meshca.v1.MeshCertificateService/CreateCertificate", in, out, opts...) - if err != nil { - return nil, err - } - return out, nil -} - -// MeshCertificateServiceServer is the server API for MeshCertificateService service. -// All implementations must embed UnimplementedMeshCertificateServiceServer -// for forward compatibility -type MeshCertificateServiceServer interface { - // Using provided CSR, returns a signed certificate that represents a GCP - // service account identity. - CreateCertificate(context.Context, *MeshCertificateRequest) (*MeshCertificateResponse, error) - mustEmbedUnimplementedMeshCertificateServiceServer() -} - -// UnimplementedMeshCertificateServiceServer must be embedded to have forward compatible implementations. 
-type UnimplementedMeshCertificateServiceServer struct { -} - -func (UnimplementedMeshCertificateServiceServer) CreateCertificate(context.Context, *MeshCertificateRequest) (*MeshCertificateResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method CreateCertificate not implemented") -} -func (UnimplementedMeshCertificateServiceServer) mustEmbedUnimplementedMeshCertificateServiceServer() { -} - -// UnsafeMeshCertificateServiceServer may be embedded to opt out of forward compatibility for this service. -// Use of this interface is not recommended, as added methods to MeshCertificateServiceServer will -// result in compilation errors. -type UnsafeMeshCertificateServiceServer interface { - mustEmbedUnimplementedMeshCertificateServiceServer() -} - -func RegisterMeshCertificateServiceServer(s grpc.ServiceRegistrar, srv MeshCertificateServiceServer) { - s.RegisterService(&MeshCertificateService_ServiceDesc, srv) -} - -func _MeshCertificateService_CreateCertificate_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { - in := new(MeshCertificateRequest) - if err := dec(in); err != nil { - return nil, err - } - if interceptor == nil { - return srv.(MeshCertificateServiceServer).CreateCertificate(ctx, in) - } - info := &grpc.UnaryServerInfo{ - Server: srv, - FullMethod: "/google.security.meshca.v1.MeshCertificateService/CreateCertificate", - } - handler := func(ctx context.Context, req interface{}) (interface{}, error) { - return srv.(MeshCertificateServiceServer).CreateCertificate(ctx, req.(*MeshCertificateRequest)) - } - return interceptor(ctx, in, info, handler) -} - -// MeshCertificateService_ServiceDesc is the grpc.ServiceDesc for MeshCertificateService service. 
-// It's only intended for direct use with grpc.RegisterService, -// and not to be introspected or modified (even as a copy) -var MeshCertificateService_ServiceDesc = grpc.ServiceDesc{ - ServiceName: "google.security.meshca.v1.MeshCertificateService", - HandlerType: (*MeshCertificateServiceServer)(nil), - Methods: []grpc.MethodDesc{ - { - MethodName: "CreateCertificate", - Handler: _MeshCertificateService_CreateCertificate_Handler, - }, - }, - Streams: []grpc.StreamDesc{}, - Metadata: "istio/google/security/meshca/v1/meshca.proto", -} diff --git a/credentials/tls/certprovider/meshca/plugin.go b/credentials/tls/certprovider/meshca/plugin.go deleted file mode 100644 index ab1958ac1fd..00000000000 --- a/credentials/tls/certprovider/meshca/plugin.go +++ /dev/null @@ -1,289 +0,0 @@ -// +build go1.13 - -/* - * - * Copyright 2020 gRPC authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - */ - -// Package meshca provides an implementation of the Provider interface which -// communicates with MeshCA to get certificates signed. 
-package meshca - -import ( - "context" - "crypto" - "crypto/rand" - "crypto/rsa" - "crypto/tls" - "crypto/x509" - "encoding/pem" - "fmt" - "time" - - durationpb "github.com/golang/protobuf/ptypes/duration" - "github.com/google/uuid" - - "google.golang.org/grpc" - "google.golang.org/grpc/credentials/tls/certprovider" - meshgrpc "google.golang.org/grpc/credentials/tls/certprovider/meshca/internal/v1" - meshpb "google.golang.org/grpc/credentials/tls/certprovider/meshca/internal/v1" - "google.golang.org/grpc/internal/grpclog" - "google.golang.org/grpc/metadata" -) - -// In requests sent to the MeshCA, we add a metadata header with this key and -// the value being the GCE zone in which the workload is running in. -const locationMetadataKey = "x-goog-request-params" - -// For overriding from unit tests. -var newDistributorFunc = func() distributor { return certprovider.NewDistributor() } - -// distributor wraps the methods on certprovider.Distributor which are used by -// the plugin. This is very useful in tests which need to know exactly when the -// plugin updates its key material. -type distributor interface { - KeyMaterial(ctx context.Context) (*certprovider.KeyMaterial, error) - Set(km *certprovider.KeyMaterial, err error) - Stop() -} - -// providerPlugin is an implementation of the certprovider.Provider interface, -// which gets certificates signed by communicating with the MeshCA. -type providerPlugin struct { - distributor // Holds the key material. - cancel context.CancelFunc - cc *grpc.ClientConn // Connection to MeshCA server. - cfg *pluginConfig // Plugin configuration. - opts certprovider.BuildOptions // Key material options. - logger *grpclog.PrefixLogger // Plugin instance specific prefix. - backoff func(int) time.Duration // Exponential backoff. - doneFunc func() // Notify the builder when done. -} - -// providerParams wraps params passed to the provider plugin at creation time. 
-type providerParams struct { - // This ClientConn to the MeshCA server is owned by the builder. - cc *grpc.ClientConn - cfg *pluginConfig - opts certprovider.BuildOptions - backoff func(int) time.Duration - doneFunc func() -} - -func newProviderPlugin(params providerParams) *providerPlugin { - ctx, cancel := context.WithCancel(context.Background()) - p := &providerPlugin{ - cancel: cancel, - cc: params.cc, - cfg: params.cfg, - opts: params.opts, - backoff: params.backoff, - doneFunc: params.doneFunc, - distributor: newDistributorFunc(), - } - p.logger = prefixLogger((p)) - p.logger.Infof("plugin created") - go p.run(ctx) - return p -} - -func (p *providerPlugin) Close() { - p.logger.Infof("plugin closed") - p.Stop() // Stop the embedded distributor. - p.cancel() - p.doneFunc() -} - -// run is a long running goroutine which periodically sends out CSRs to the -// MeshCA, and updates the underlying Distributor with the new key material. -func (p *providerPlugin) run(ctx context.Context) { - // We need to start fetching key material right away. The next attempt will - // be triggered by the timer firing. - for { - certValidity, err := p.updateKeyMaterial(ctx) - if err != nil { - return - } - - // We request a certificate with the configured validity duration (which - // is usually twice as much as the grace period). But the server is free - // to return a certificate with whatever validity time it deems right. - refreshAfter := p.cfg.certGraceTime - if refreshAfter > certValidity { - // The default value of cert grace time is half that of the default - // cert validity time. So here, when we have to use a non-default - // cert life time, we will set the grace time again to half that of - // the validity time. - refreshAfter = certValidity / 2 - } - timer := time.NewTimer(refreshAfter) - select { - case <-ctx.Done(): - return - case <-timer.C: - } - } -} - -// updateKeyMaterial generates a CSR and attempts to get it signed from the -// MeshCA. 
It retries with an exponential backoff till it succeeds or the -// deadline specified in ctx expires. Once it gets the CSR signed from the -// MeshCA, it updates the Distributor with the new key material. -// -// It returns the amount of time the new certificate is valid for. -func (p *providerPlugin) updateKeyMaterial(ctx context.Context) (time.Duration, error) { - client := meshgrpc.NewMeshCertificateServiceClient(p.cc) - retries := 0 - for { - if ctx.Err() != nil { - return 0, ctx.Err() - } - - if retries != 0 { - bi := p.backoff(retries) - p.logger.Warningf("Backing off for %s before attempting the next CreateCertificate() request", bi) - timer := time.NewTimer(bi) - select { - case <-timer.C: - case <-ctx.Done(): - return 0, ctx.Err() - } - } - retries++ - - privKey, err := rsa.GenerateKey(rand.Reader, p.cfg.keySize) - if err != nil { - p.logger.Warningf("RSA key generation failed: %v", err) - continue - } - // We do not set any fields in the CSR (we use an empty - // x509.CertificateRequest as the template) because the MeshCA discards - // them anyways, and uses the workload identity from the access token - // that we present (as part of the STS call creds). - csrBytes, err := x509.CreateCertificateRequest(rand.Reader, &x509.CertificateRequest{}, crypto.PrivateKey(privKey)) - if err != nil { - p.logger.Warningf("CSR creation failed: %v", err) - continue - } - csrPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE REQUEST", Bytes: csrBytes}) - - // Send out the CSR with a call timeout and location metadata, as - // specified in the plugin configuration. 
- req := &meshpb.MeshCertificateRequest{ - RequestId: uuid.New().String(), - Csr: string(csrPEM), - Validity: &durationpb.Duration{Seconds: int64(p.cfg.certLifetime / time.Second)}, - } - p.logger.Debugf("Sending CreateCertificate() request: %v", req) - - callCtx, ctxCancel := context.WithTimeout(context.Background(), p.cfg.callTimeout) - callCtx = metadata.NewOutgoingContext(callCtx, metadata.Pairs(locationMetadataKey, p.cfg.location)) - resp, err := client.CreateCertificate(callCtx, req) - if err != nil { - p.logger.Warningf("CreateCertificate request failed: %v", err) - ctxCancel() - continue - } - ctxCancel() - - // The returned cert chain must contain more than one cert. Leaf cert is - // element '0', while root cert is element 'n', and the intermediate - // entries form the chain from the root to the leaf. - certChain := resp.GetCertChain() - if l := len(certChain); l <= 1 { - p.logger.Errorf("Received certificate chain contains %d certificates, need more than one", l) - continue - } - - // We need to explicitly parse the PEM cert contents as an - // x509.Certificate to read the certificate validity period. We use this - // to decide when to refresh the cert. Even though the call to - // tls.X509KeyPair actually parses the PEM contents into an - // x509.Certificate, it does not store that in the `Leaf` field. See: - // https://golang.org/pkg/crypto/tls/#X509KeyPair. 
- identity, intermediates, roots, err := parseCertChain(certChain) - if err != nil { - p.logger.Errorf(err.Error()) - continue - } - _, err = identity.Verify(x509.VerifyOptions{ - Intermediates: intermediates, - Roots: roots, - }) - if err != nil { - p.logger.Errorf("Certificate verification failed for return certChain: %v", err) - continue - } - - key := x509.MarshalPKCS1PrivateKey(privKey) - keyPEM := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: key}) - certPair, err := tls.X509KeyPair([]byte(certChain[0]), keyPEM) - if err != nil { - p.logger.Errorf("Failed to create x509 key pair: %v", err) - continue - } - - // At this point, the received response has been deemed good. - retries = 0 - - // All certs signed by the MeshCA roll up to the same root. And treating - // the last element of the returned chain as the root is the only - // supported option to get the root certificate. So, we ignore the - // options specified in the call to Build(), which contain certificate - // name and whether the caller is interested in identity or root cert. - p.Set(&certprovider.KeyMaterial{Certs: []tls.Certificate{certPair}, Roots: roots}, nil) - return time.Until(identity.NotAfter), nil - } -} - -// ParseCertChain parses the result returned by the MeshCA which consists of a -// list of PEM encoded certs. The first element in the list is the leaf or -// identity cert, while the last element is the root, and everything in between -// form the chain of trust. -// -// Caller needs to make sure that certChain has at least two elements. 
-func parseCertChain(certChain []string) (*x509.Certificate, *x509.CertPool, *x509.CertPool, error) { - identity, err := parseCert([]byte(certChain[0])) - if err != nil { - return nil, nil, nil, err - } - - intermediates := x509.NewCertPool() - for _, cert := range certChain[1 : len(certChain)-1] { - i, err := parseCert([]byte(cert)) - if err != nil { - return nil, nil, nil, err - } - intermediates.AddCert(i) - } - - roots := x509.NewCertPool() - root, err := parseCert([]byte(certChain[len(certChain)-1])) - if err != nil { - return nil, nil, nil, err - } - roots.AddCert(root) - - return identity, intermediates, roots, nil -} - -func parseCert(certPEM []byte) (*x509.Certificate, error) { - block, _ := pem.Decode(certPEM) - if block == nil { - return nil, fmt.Errorf("failed to decode received PEM data: %v", certPEM) - } - return x509.ParseCertificate(block.Bytes) -} diff --git a/credentials/tls/certprovider/meshca/plugin_test.go b/credentials/tls/certprovider/meshca/plugin_test.go deleted file mode 100644 index 51f545d6a0e..00000000000 --- a/credentials/tls/certprovider/meshca/plugin_test.go +++ /dev/null @@ -1,459 +0,0 @@ -// +build go1.13 - -/* - * - * Copyright 2020 gRPC authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- * - */ - -package meshca - -import ( - "context" - "crypto/rand" - "crypto/rsa" - "crypto/x509" - "crypto/x509/pkix" - "encoding/json" - "encoding/pem" - "errors" - "fmt" - "math/big" - "net" - "reflect" - "testing" - "time" - - "google.golang.org/grpc" - "google.golang.org/grpc/credentials/tls/certprovider" - meshgrpc "google.golang.org/grpc/credentials/tls/certprovider/meshca/internal/v1" - meshpb "google.golang.org/grpc/credentials/tls/certprovider/meshca/internal/v1" - "google.golang.org/grpc/internal/testutils" -) - -const ( - // Used when waiting for something that is expected to *not* happen. - defaultTestShortTimeout = 10 * time.Millisecond - defaultTestTimeout = 5 * time.Second - defaultTestCertLife = time.Hour - shortTestCertLife = 2 * time.Second - maxErrCount = 2 -) - -// fakeCA provides a very simple fake implementation of the certificate signing -// service as exported by the MeshCA. -type fakeCA struct { - meshgrpc.UnimplementedMeshCertificateServiceServer - - withErrors bool // Whether the CA returns errors to begin with. - withShortLife bool // Whether to create certs with short lifetime - - ccChan *testutils.Channel // Channel to get notified about CreateCertificate calls. - errors int // Error count. - key *rsa.PrivateKey // Private key of CA. - cert *x509.Certificate // Signing certificate. - certPEM []byte // PEM encoding of signing certificate. -} - -// Returns a new instance of the fake Mesh CA. It generates a new RSA key and a -// self-signed certificate which will be used to sign CSRs received in incoming -// requests. -// withErrors controls whether the fake returns errors before succeeding, while -// withShortLife controls whether the fake returns certs with very small -// lifetimes (to test plugin refresh behavior). Every time a CreateCertificate() -// call succeeds, an event is pushed on the ccChan. 
-func newFakeMeshCA(ccChan *testutils.Channel, withErrors, withShortLife bool) (*fakeCA, error) { - key, err := rsa.GenerateKey(rand.Reader, 2048) - if err != nil { - return nil, fmt.Errorf("RSA key generation failed: %v", err) - } - - now := time.Now() - tmpl := &x509.Certificate{ - Subject: pkix.Name{CommonName: "my-fake-ca"}, - SerialNumber: big.NewInt(10), - NotBefore: now.Add(-time.Hour), - NotAfter: now.Add(time.Hour), - KeyUsage: x509.KeyUsageCertSign, - IsCA: true, - BasicConstraintsValid: true, - } - certDER, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &key.PublicKey, key) - if err != nil { - return nil, fmt.Errorf("x509.CreateCertificate(%v) failed: %v", tmpl, err) - } - // The PEM encoding of the self-signed certificate is stored because we need - // to return a chain of certificates in the response, starting with the - // client certificate and ending in the root. - certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER}) - cert, err := x509.ParseCertificate(certDER) - if err != nil { - return nil, fmt.Errorf("x509.ParseCertificate(%v) failed: %v", certDER, err) - } - - return &fakeCA{ - withErrors: withErrors, - withShortLife: withShortLife, - ccChan: ccChan, - key: key, - cert: cert, - certPEM: certPEM, - }, nil -} - -// CreateCertificate helps implement the MeshCA service. -// -// If the fakeMeshCA was created with `withErrors` set to true, the first -// `maxErrCount` number of RPC return errors. Subsequent requests are signed and -// returned without error. 
-func (f *fakeCA) CreateCertificate(ctx context.Context, req *meshpb.MeshCertificateRequest) (*meshpb.MeshCertificateResponse, error) { - if f.withErrors { - if f.errors < maxErrCount { - f.errors++ - return nil, errors.New("fake Mesh CA error") - - } - } - - csrPEM := []byte(req.GetCsr()) - block, _ := pem.Decode(csrPEM) - if block == nil { - return nil, fmt.Errorf("failed to decode received CSR: %v", csrPEM) - } - csr, err := x509.ParseCertificateRequest(block.Bytes) - if err != nil { - return nil, fmt.Errorf("failed to parse received CSR: %v", csrPEM) - } - - // By default, we create certs which are valid for an hour. But if - // `withShortLife` is set, we create certs which are valid only for a couple - // of seconds. - now := time.Now() - notBefore, notAfter := now.Add(-defaultTestCertLife), now.Add(defaultTestCertLife) - if f.withShortLife { - notBefore, notAfter = now.Add(-shortTestCertLife), now.Add(shortTestCertLife) - } - tmpl := &x509.Certificate{ - Subject: pkix.Name{CommonName: "signed-cert"}, - SerialNumber: big.NewInt(10), - NotBefore: notBefore, - NotAfter: notAfter, - KeyUsage: x509.KeyUsageDigitalSignature, - } - certDER, err := x509.CreateCertificate(rand.Reader, tmpl, f.cert, csr.PublicKey, f.key) - if err != nil { - return nil, fmt.Errorf("x509.CreateCertificate(%v) failed: %v", tmpl, err) - } - certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER}) - - // Push to ccChan to indicate that the RPC is processed. - f.ccChan.Send(nil) - - certChain := []string{ - string(certPEM), // Signed certificate corresponding to CSR - string(f.certPEM), // Root certificate - } - return &meshpb.MeshCertificateResponse{CertChain: certChain}, nil -} - -// opts wraps the options to be passed to setup. -type opts struct { - // Whether the CA returns certs with short lifetime. Used to test client refresh. - withShortLife bool - // Whether the CA returns errors to begin with. Used to test client backoff. 
- withbackoff bool -} - -// events wraps channels which indicate different events. -type events struct { - // Pushed to when the plugin dials the MeshCA. - dialDone *testutils.Channel - // Pushed to when CreateCertifcate() succeeds on the MeshCA. - createCertDone *testutils.Channel - // Pushed to when the plugin updates the distributor with new key material. - keyMaterialDone *testutils.Channel - // Pushed to when the client backs off after a failed CreateCertificate(). - backoffDone *testutils.Channel -} - -// setup performs tasks common to all tests in this file. -func setup(t *testing.T, o opts) (events, string, func()) { - t.Helper() - - // Create a fake MeshCA which pushes events on the passed channel for - // successful RPCs. - createCertDone := testutils.NewChannel() - fs, err := newFakeMeshCA(createCertDone, o.withbackoff, o.withShortLife) - if err != nil { - t.Fatal(err) - } - - // Create a gRPC server and register the fake MeshCA on it. - server := grpc.NewServer() - meshgrpc.RegisterMeshCertificateServiceServer(server, fs) - - // Start a net.Listener on a local port, and pass it to the gRPC server - // created above and start serving. - lis, err := net.Listen("tcp", "localhost:0") - if err != nil { - t.Fatal(err) - } - addr := lis.Addr().String() - go server.Serve(lis) - - // Override the plugin's dial function and perform a blocking dial. Also - // push on dialDone once the dial is complete so that test can block on this - // event before verifying other things. 
- dialDone := testutils.NewChannel() - origDialFunc := grpcDialFunc - grpcDialFunc = func(uri string, _ ...grpc.DialOption) (*grpc.ClientConn, error) { - if uri != addr { - t.Fatalf("plugin dialing MeshCA at %s, want %s", uri, addr) - } - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() - cc, err := grpc.DialContext(ctx, uri, grpc.WithInsecure(), grpc.WithBlock()) - if err != nil { - t.Fatalf("grpc.DialContext(%s) failed: %v", addr, err) - } - dialDone.Send(nil) - return cc, nil - } - - // Override the plugin's newDistributorFunc and return a wrappedDistributor - // which allows the test to be notified whenever the plugin pushes new key - // material into the distributor. - origDistributorFunc := newDistributorFunc - keyMaterialDone := testutils.NewChannel() - d := newWrappedDistributor(keyMaterialDone) - newDistributorFunc = func() distributor { return d } - - // Override the plugin's backoff function to perform no real backoff, but - // push on a channel so that the test can verifiy that backoff actually - // happened. - backoffDone := testutils.NewChannelWithSize(maxErrCount) - origBackoffFunc := backoffFunc - if o.withbackoff { - // Override the plugin's backoff function with this, so that we can verify - // that a backoff actually was triggered. - backoffFunc = func(v int) time.Duration { - backoffDone.Send(v) - return 0 - } - } - - // Return all the channels, and a cancel function to undo all the overrides. - e := events{ - dialDone: dialDone, - createCertDone: createCertDone, - keyMaterialDone: keyMaterialDone, - backoffDone: backoffDone, - } - done := func() { - server.Stop() - grpcDialFunc = origDialFunc - newDistributorFunc = origDistributorFunc - backoffFunc = origBackoffFunc - } - return e, addr, done -} - -// wrappedDistributor wraps a distributor and pushes on a channel whenever new -// key material is pushed to the distributor. 
-type wrappedDistributor struct { - *certprovider.Distributor - kmChan *testutils.Channel -} - -func newWrappedDistributor(kmChan *testutils.Channel) *wrappedDistributor { - return &wrappedDistributor{ - kmChan: kmChan, - Distributor: certprovider.NewDistributor(), - } -} - -func (wd *wrappedDistributor) Set(km *certprovider.KeyMaterial, err error) { - wd.Distributor.Set(km, err) - wd.kmChan.Send(nil) -} - -// TestCreateCertificate verifies the simple case where the MeshCA server -// returns a good certificate. -func (s) TestCreateCertificate(t *testing.T) { - e, addr, cancel := setup(t, opts{}) - defer cancel() - - // Set the MeshCA targetURI to point to our fake MeshCA. - inputConfig := json.RawMessage(fmt.Sprintf(goodConfigFormatStr, addr)) - - // Lookup MeshCA plugin builder, parse config and start the plugin. - prov, err := certprovider.GetProvider(pluginName, inputConfig, certprovider.BuildOptions{}) - if err != nil { - t.Fatalf("GetProvider(%s, %s) failed: %v", pluginName, string(inputConfig), err) - } - defer prov.Close() - - // Wait till the plugin dials the MeshCA server. - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() - if _, err := e.dialDone.Receive(ctx); err != nil { - t.Fatal("timeout waiting for plugin to dial MeshCA") - } - - // Wait till the plugin makes a CreateCertificate() call. - ctx, cancel = context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() - if _, err := e.createCertDone.Receive(ctx); err != nil { - t.Fatal("timeout waiting for plugin to make CreateCertificate RPC") - } - - // We don't really care about the exact key material returned here. All we - // care about is whether we get any key material at all, and that we don't - // get any errors. 
- ctx, cancel = context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() - if _, err = prov.KeyMaterial(ctx); err != nil { - t.Fatalf("provider.KeyMaterial(ctx) failed: %v", err) - } -} - -// TestCreateCertificateWithBackoff verifies the case where the MeshCA server -// returns errors initially and then returns a good certificate. The test makes -// sure that the client backs off when the server returns errors. -func (s) TestCreateCertificateWithBackoff(t *testing.T) { - e, addr, cancel := setup(t, opts{withbackoff: true}) - defer cancel() - - // Set the MeshCA targetURI to point to our fake MeshCA. - inputConfig := json.RawMessage(fmt.Sprintf(goodConfigFormatStr, addr)) - - // Lookup MeshCA plugin builder, parse config and start the plugin. - prov, err := certprovider.GetProvider(pluginName, inputConfig, certprovider.BuildOptions{}) - if err != nil { - t.Fatalf("GetProvider(%s, %s) failed: %v", pluginName, string(inputConfig), err) - } - defer prov.Close() - - // Wait till the plugin dials the MeshCA server. - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() - if _, err := e.dialDone.Receive(ctx); err != nil { - t.Fatal("timeout waiting for plugin to dial MeshCA") - } - - // Making the CreateCertificateRPC involves generating the keys, creating - // the CSR etc which seem to take reasonable amount of time. And in this - // test, the first two attempts will fail. Hence we give it a reasonable - // deadline here. - ctx, cancel = context.WithTimeout(context.Background(), 3*defaultTestTimeout) - defer cancel() - if _, err := e.createCertDone.Receive(ctx); err != nil { - t.Fatal("timeout waiting for plugin to make CreateCertificate RPC") - } - - // The first `maxErrCount` calls to CreateCertificate end in failure, and - // should lead to a backoff. 
- for i := 0; i < maxErrCount; i++ { - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() - if _, err := e.backoffDone.Receive(ctx); err != nil { - t.Fatalf("plugin failed to backoff after error from fake server: %v", err) - } - } - - // We don't really care about the exact key material returned here. All we - // care about is whether we get any key material at all, and that we don't - // get any errors. - ctx, cancel = context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() - if _, err = prov.KeyMaterial(ctx); err != nil { - t.Fatalf("provider.KeyMaterial(ctx) failed: %v", err) - } -} - -// TestCreateCertificateWithRefresh verifies the case where the MeshCA returns a -// certificate with a really short lifetime, and makes sure that the plugin -// refreshes the cert in time. -func (s) TestCreateCertificateWithRefresh(t *testing.T) { - e, addr, cancel := setup(t, opts{withShortLife: true}) - defer cancel() - - // Set the MeshCA targetURI to point to our fake MeshCA. - inputConfig := json.RawMessage(fmt.Sprintf(goodConfigFormatStr, addr)) - - // Lookup MeshCA plugin builder, parse config and start the plugin. - prov, err := certprovider.GetProvider(pluginName, inputConfig, certprovider.BuildOptions{}) - if err != nil { - t.Fatalf("GetProvider(%s, %s) failed: %v", pluginName, string(inputConfig), err) - } - defer prov.Close() - - // Wait till the plugin dials the MeshCA server. - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() - if _, err := e.dialDone.Receive(ctx); err != nil { - t.Fatal("timeout waiting for plugin to dial MeshCA") - } - - // Wait till the plugin makes a CreateCertificate() call. 
- ctx, cancel = context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() - if _, err := e.createCertDone.Receive(ctx); err != nil { - t.Fatal("timeout waiting for plugin to make CreateCertificate RPC") - } - - ctx, cancel = context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() - km1, err := prov.KeyMaterial(ctx) - if err != nil { - t.Fatalf("provider.KeyMaterial(ctx) failed: %v", err) - } - - // At this point, we have read the first key material, and since the - // returned key material has a really short validity period, we expect the - // key material to be refreshed quite soon. We drain the channel on which - // the event corresponding to setting of new key material is pushed. This - // enables us to block on the same channel, waiting for refreshed key - // material. - // Since we do not expect this call to block, it is OK to pass the - // background context. - e.keyMaterialDone.Receive(context.Background()) - - // Wait for the next call to CreateCertificate() to refresh the certificate - // returned earlier. - ctx, cancel = context.WithTimeout(context.Background(), 2*shortTestCertLife) - defer cancel() - if _, err := e.keyMaterialDone.Receive(ctx); err != nil { - t.Fatalf("CreateCertificate() RPC not made: %v", err) - } - - ctx, cancel = context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() - km2, err := prov.KeyMaterial(ctx) - if err != nil { - t.Fatalf("provider.KeyMaterial(ctx) failed: %v", err) - } - - // TODO(easwars): Remove all references to reflect.DeepEqual and use - // cmp.Equal instead. Currently, the later panics because x509.Certificate - // type defines an Equal method, but does not check for nil. This has been - // fixed in - // https://github.com/golang/go/commit/89865f8ba64ccb27f439cce6daaa37c9aa38f351, - // but this is only available starting go1.14. So, once we remove support - // for go1.13, we can make the switch. 
- if reflect.DeepEqual(km1, km2) { - t.Error("certificate refresh did not happen in the background") - } -} diff --git a/credentials/tls/certprovider/pemfile/watcher_test.go b/credentials/tls/certprovider/pemfile/watcher_test.go index e43cf7358ec..6cc65bd5000 100644 --- a/credentials/tls/certprovider/pemfile/watcher_test.go +++ b/credentials/tls/certprovider/pemfile/watcher_test.go @@ -22,7 +22,6 @@ import ( "context" "fmt" "io/ioutil" - "math/big" "os" "path" "testing" @@ -30,7 +29,6 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" - "google.golang.org/grpc/credentials/tls/certprovider" "google.golang.org/grpc/internal/grpctest" "google.golang.org/grpc/internal/testutils" @@ -57,17 +55,15 @@ func Test(t *testing.T) { } func compareKeyMaterial(got, want *certprovider.KeyMaterial) error { - // x509.Certificate type defines an Equal() method, but does not check for - // nil. This has been fixed in - // https://github.com/golang/go/commit/89865f8ba64ccb27f439cce6daaa37c9aa38f351, - // but this is only available starting go1.14. - // TODO(easwars): Remove this check once we remove support for go1.13. - if (got.Certs == nil && want.Certs != nil) || (want.Certs == nil && got.Certs != nil) { + if len(got.Certs) != len(want.Certs) { return fmt.Errorf("keyMaterial certs = %+v, want %+v", got, want) } - if !cmp.Equal(got.Certs, want.Certs, cmp.AllowUnexported(big.Int{})) { - return fmt.Errorf("keyMaterial certs = %+v, want %+v", got, want) + for i := 0; i < len(got.Certs); i++ { + if !got.Certs[i].Leaf.Equal(want.Certs[i].Leaf) { + return fmt.Errorf("keyMaterial certs = %+v, want %+v", got, want) + } } + // x509.CertPool contains only unexported fields some of which contain other // unexported fields. So usage of cmp.AllowUnexported() or // cmpopts.IgnoreUnexported() does not help us much here. 
Also, the standard diff --git a/credentials/tls/certprovider/store_test.go b/credentials/tls/certprovider/store_test.go index 00d33a2be87..ee1f4a358ba 100644 --- a/credentials/tls/certprovider/store_test.go +++ b/credentials/tls/certprovider/store_test.go @@ -1,5 +1,3 @@ -// +build go1.13 - /* * * Copyright 2020 gRPC authors. @@ -27,10 +25,11 @@ import ( "errors" "fmt" "io/ioutil" - "reflect" "testing" "time" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" "google.golang.org/grpc/internal/grpctest" "google.golang.org/grpc/internal/testutils" "google.golang.org/grpc/testdata" @@ -154,15 +153,23 @@ func readAndVerifyKeyMaterial(ctx context.Context, kmr kmReader, wantKM *KeyMate } func compareKeyMaterial(got, want *KeyMaterial) error { - // TODO(easwars): Remove all references to reflect.DeepEqual and use - // cmp.Equal instead. Currently, the later panics because x509.Certificate - // type defines an Equal method, but does not check for nil. This has been - // fixed in - // https://github.com/golang/go/commit/89865f8ba64ccb27f439cce6daaa37c9aa38f351, - // but this is only available starting go1.14. So, once we remove support - // for go1.13, we can make the switch. - if !reflect.DeepEqual(got, want) { - return fmt.Errorf("provider.KeyMaterial() = %+v, want %+v", got, want) + if len(got.Certs) != len(want.Certs) { + return fmt.Errorf("keyMaterial certs = %+v, want %+v", got, want) + } + for i := 0; i < len(got.Certs); i++ { + if !got.Certs[i].Leaf.Equal(want.Certs[i].Leaf) { + return fmt.Errorf("keyMaterial certs = %+v, want %+v", got, want) + } + } + + // x509.CertPool contains only unexported fields some of which contain other + // unexported fields. So usage of cmp.AllowUnexported() or + // cmpopts.IgnoreUnexported() does not help us much here. Also, the standard + // library does not provide a way to compare CertPool values. Comparing the + // subjects field of the certs in the CertPool seems like a reasonable + // approach. 
+ if gotR, wantR := got.Roots.Subjects(), want.Roots.Subjects(); !cmp.Equal(gotR, wantR, cmpopts.EquateEmpty()) { + return fmt.Errorf("keyMaterial roots = %v, want %v", gotR, wantR) } return nil } diff --git a/credentials/xds/xds.go b/credentials/xds/xds.go index ede0806d70d..680ea9cfa10 100644 --- a/credentials/xds/xds.go +++ b/credentials/xds/xds.go @@ -18,11 +18,6 @@ // Package xds provides a transport credentials implementation where the // security configuration is pushed by a management server using xDS APIs. -// -// Experimental -// -// Notice: All APIs in this package are EXPERIMENTAL and may be removed in a -// later release. package xds import ( @@ -216,12 +211,15 @@ func (c *credsImpl) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.Aut // `HandshakeInfo` does not contain the information we are looking for, we // delegate the handshake to the fallback credentials. hiConn, ok := rawConn.(interface { - XDSHandshakeInfo() *xdsinternal.HandshakeInfo + XDSHandshakeInfo() (*xdsinternal.HandshakeInfo, error) }) if !ok { return c.fallback.ServerHandshake(rawConn) } - hi := hiConn.XDSHandshakeInfo() + hi, err := hiConn.XDSHandshakeInfo() + if err != nil { + return nil, nil, err + } if hi.UseFallbackCreds() { return c.fallback.ServerHandshake(rawConn) } diff --git a/credentials/xds/xds_client_test.go b/credentials/xds/xds_client_test.go index 219d0aefcba..f4b86df060b 100644 --- a/credentials/xds/xds_client_test.go +++ b/credentials/xds/xds_client_test.go @@ -32,18 +32,19 @@ import ( "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/tls/certprovider" - "google.golang.org/grpc/internal" + icredentials "google.golang.org/grpc/internal/credentials" xdsinternal "google.golang.org/grpc/internal/credentials/xds" "google.golang.org/grpc/internal/grpctest" "google.golang.org/grpc/internal/testutils" + "google.golang.org/grpc/internal/xds/matcher" "google.golang.org/grpc/resolver" "google.golang.org/grpc/testdata" ) const ( - 
defaultTestTimeout = 10 * time.Second + defaultTestTimeout = 1 * time.Second defaultTestShortTimeout = 10 * time.Millisecond - defaultTestCertSAN = "*.test.example.com" + defaultTestCertSAN = "abc.test.example.com" authority = "authority" ) @@ -214,18 +215,20 @@ func makeRootProvider(t *testing.T, caPath string) *fakeProvider { // newTestContextWithHandshakeInfo returns a copy of parent with HandshakeInfo // context value added to it. -func newTestContextWithHandshakeInfo(parent context.Context, root, identity certprovider.Provider, sans ...string) context.Context { +func newTestContextWithHandshakeInfo(parent context.Context, root, identity certprovider.Provider, sanExactMatch string) context.Context { // Creating the HandshakeInfo and adding it to the attributes is very // similar to what the CDS balancer would do when it intercepts calls to // NewSubConn(). - info := xdsinternal.NewHandshakeInfo(root, identity, sans...) + info := xdsinternal.NewHandshakeInfo(root, identity) + if sanExactMatch != "" { + info.SetSANMatchers([]matcher.StringMatcher{matcher.StringMatcherForTesting(newStringP(sanExactMatch), nil, nil, nil, nil, false)}) + } addr := xdsinternal.SetHandshakeInfo(resolver.Address{}, info) // Moving the attributes from the resolver.Address to the context passed to // the handshaker is done in the transport layer. Since we directly call the // handshaker in these tests, we need to do the same here. 
- contextWithHandshakeInfo := internal.NewClientHandshakeInfoContext.(func(context.Context, credentials.ClientHandshakeInfo) context.Context) - return contextWithHandshakeInfo(parent, credentials.ClientHandshakeInfo{Attributes: addr.Attributes}) + return icredentials.NewClientHandshakeInfoContext(parent, credentials.ClientHandshakeInfo{Attributes: addr.Attributes}) } // compareAuthInfo compares the AuthInfo received on the client side after a @@ -292,7 +295,7 @@ func (s) TestClientCredsInvalidHandshakeInfo(t *testing.T) { pCtx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() - ctx := newTestContextWithHandshakeInfo(pCtx, nil, &fakeProvider{}) + ctx := newTestContextWithHandshakeInfo(pCtx, nil, &fakeProvider{}, "") if _, _, err := creds.ClientHandshake(ctx, authority, nil); err == nil { t.Fatal("ClientHandshake succeeded without root certificate provider in HandshakeInfo") } @@ -329,7 +332,7 @@ func (s) TestClientCredsProviderFailure(t *testing.T) { t.Run(test.desc, func(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() - ctx = newTestContextWithHandshakeInfo(ctx, test.rootProvider, test.identityProvider) + ctx = newTestContextWithHandshakeInfo(ctx, test.rootProvider, test.identityProvider, "") if _, _, err := creds.ClientHandshake(ctx, authority, nil); err == nil || !strings.Contains(err.Error(), test.wantErr) { t.Fatalf("ClientHandshake() returned error: %q, wantErr: %q", err, test.wantErr) } @@ -371,7 +374,7 @@ func (s) TestClientCredsSuccess(t *testing.T) { desc: "mTLS with no acceptedSANs specified", handshakeFunc: testServerMutualTLSHandshake, handshakeInfoCtx: func(ctx context.Context) context.Context { - return newTestContextWithHandshakeInfo(ctx, makeRootProvider(t, "x509/server_ca_cert.pem"), makeIdentityProvider(t, "x509/server1_cert.pem", "x509/server1_key.pem")) + return newTestContextWithHandshakeInfo(ctx, makeRootProvider(t, "x509/server_ca_cert.pem"), 
makeIdentityProvider(t, "x509/server1_cert.pem", "x509/server1_key.pem"), "") }, }, } @@ -530,14 +533,14 @@ func (s) TestClientCredsProviderSwitch(t *testing.T) { // Create a root provider which will fail the handshake because it does not // use the correct trust roots. root1 := makeRootProvider(t, "x509/client_ca_cert.pem") - handshakeInfo := xdsinternal.NewHandshakeInfo(root1, nil, defaultTestCertSAN) + handshakeInfo := xdsinternal.NewHandshakeInfo(root1, nil) + handshakeInfo.SetSANMatchers([]matcher.StringMatcher{matcher.StringMatcherForTesting(newStringP(defaultTestCertSAN), nil, nil, nil, nil, false)}) // We need to repeat most of what newTestContextWithHandshakeInfo() does // here because we need access to the underlying HandshakeInfo so that we // can update it before the next call to ClientHandshake(). addr := xdsinternal.SetHandshakeInfo(resolver.Address{}, handshakeInfo) - contextWithHandshakeInfo := internal.NewClientHandshakeInfoContext.(func(context.Context, credentials.ClientHandshakeInfo) context.Context) - ctx = contextWithHandshakeInfo(ctx, credentials.ClientHandshakeInfo{Attributes: addr.Attributes}) + ctx = icredentials.NewClientHandshakeInfoContext(ctx, credentials.ClientHandshakeInfo{Attributes: addr.Attributes}) if _, _, err := creds.ClientHandshake(ctx, authority, conn); err == nil { t.Fatal("ClientHandshake() succeeded when expected to fail") } @@ -582,3 +585,7 @@ func (s) TestClientClone(t *testing.T) { t.Fatal("return value from Clone() doesn't point to new credentials instance") } } + +func newStringP(s string) *string { + return &s +} diff --git a/credentials/xds/xds_server_test.go b/credentials/xds/xds_server_test.go index 68d92b28e28..5c29ba38c28 100644 --- a/credentials/xds/xds_server_test.go +++ b/credentials/xds/xds_server_test.go @@ -95,12 +95,13 @@ func (s) TestServerCredsWithoutFallback(t *testing.T) { type wrapperConn struct { net.Conn - xdsHI *xdsinternal.HandshakeInfo - deadline time.Time + xdsHI *xdsinternal.HandshakeInfo + 
deadline time.Time + handshakeInfoErr error } -func (wc *wrapperConn) XDSHandshakeInfo() *xdsinternal.HandshakeInfo { - return wc.xdsHI +func (wc *wrapperConn) XDSHandshakeInfo() (*xdsinternal.HandshakeInfo, error) { + return wc.xdsHI, wc.handshakeInfoErr } func (wc *wrapperConn) GetDeadline() time.Time { @@ -166,6 +167,58 @@ func (s) TestServerCredsProviderFailure(t *testing.T) { } } +// TestServerCredsHandshake_XDSHandshakeInfoError verifies the case where the +// call to XDSHandshakeInfo() from the ServerHandshake() method returns an +// error, and the test verifies that the ServerHandshake() fails with the +// expected error. +func (s) TestServerCredsHandshake_XDSHandshakeInfoError(t *testing.T) { + opts := ServerOptions{FallbackCreds: &errorCreds{}} + creds, err := NewServerCredentials(opts) + if err != nil { + t.Fatalf("NewServerCredentials(%v) failed: %v", opts, err) + } + + // Create a test server which uses the xDS server credentials created above + // to perform TLS handshake on incoming connections. + ts := newTestServerWithHandshakeFunc(func(rawConn net.Conn) handshakeResult { + // Create a wrapped conn which returns a nil HandshakeInfo and a non-nil error. + conn := newWrappedConn(rawConn, nil, time.Now().Add(defaultTestTimeout)) + hiErr := errors.New("xdsHandshakeInfo error") + conn.handshakeInfoErr = hiErr + + // Invoke the ServerHandshake() method on the xDS credentials and verify + // that the error returned by the XDSHandshakeInfo() method on the + // wrapped conn is returned here. + _, _, err := creds.ServerHandshake(conn) + if !errors.Is(err, hiErr) { + return handshakeResult{err: fmt.Errorf("ServerHandshake() returned err: %v, wantErr: %v", err, hiErr)} + } + return handshakeResult{} + }) + defer ts.stop() + + // Dial the test server, but don't trigger the TLS handshake. This will + // cause ServerHandshake() to fail. 
+ rawConn, err := net.Dial("tcp", ts.address) + if err != nil { + t.Fatalf("net.Dial(%s) failed: %v", ts.address, err) + } + defer rawConn.Close() + + // Read handshake result from the testServer which will return an error if + // the handshake succeeded. + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + val, err := ts.hsResult.Receive(ctx) + if err != nil { + t.Fatalf("testServer failed to return handshake result: %v", err) + } + hsr := val.(handshakeResult) + if hsr.err != nil { + t.Fatalf("testServer handshake failure: %v", hsr.err) + } +} + // TestServerCredsHandshakeTimeout verifies the case where the client does not // send required handshake data before the deadline set on the net.Conn passed // to ServerHandshake(). diff --git a/dialoptions.go b/dialoptions.go index e7f86e6d7c8..7a497237bbd 100644 --- a/dialoptions.go +++ b/dialoptions.go @@ -66,11 +66,7 @@ type dialOptions struct { minConnectTimeout func() time.Duration defaultServiceConfig *ServiceConfig // defaultServiceConfig is parsed from defaultServiceConfigRawJSON. defaultServiceConfigRawJSON *string - // This is used by ccResolverWrapper to backoff between successive calls to - // resolver.ResolveNow(). The user will have no need to configure this, but - // we need to be able to configure this in tests. - resolveNowBackoff func(int) time.Duration - resolvers []resolver.Builder + resolvers []resolver.Builder } // DialOption configures how we set up the connection. @@ -596,7 +592,6 @@ func defaultDialOptions() dialOptions { ReadBufferSize: defaultReadBufSize, UseProxy: true, }, - resolveNowBackoff: internalbackoff.DefaultExponential.Backoff, } } @@ -611,16 +606,6 @@ func withMinConnectDeadline(f func() time.Duration) DialOption { }) } -// withResolveNowBackoff specifies the function that clientconn uses to backoff -// between successive calls to resolver.ResolveNow(). -// -// For testing purpose only. 
-func withResolveNowBackoff(f func(int) time.Duration) DialOption { - return newFuncDialOption(func(o *dialOptions) { - o.resolveNowBackoff = f - }) -} - // WithResolvers allows a list of resolver implementations to be registered // locally with the ClientConn without needing to be globally registered via // resolver.Register. They will be matched against the scheme used for the diff --git a/examples/examples_test.sh b/examples/examples_test.sh index 9015272f33e..f5c82d062b2 100755 --- a/examples/examples_test.sh +++ b/examples/examples_test.sh @@ -58,6 +58,7 @@ EXAMPLES=( "features/metadata" "features/multiplex" "features/name_resolving" + "features/unix_abstract" ) declare -A EXPECTED_SERVER_OUTPUT=( @@ -73,6 +74,7 @@ declare -A EXPECTED_SERVER_OUTPUT=( ["features/metadata"]="message:\"this is examples/metadata\", sending echo" ["features/multiplex"]=":50051" ["features/name_resolving"]="serving on localhost:50051" + ["features/unix_abstract"]="serving on @abstract-unix-socket" ) declare -A EXPECTED_CLIENT_OUTPUT=( @@ -88,6 +90,7 @@ declare -A EXPECTED_CLIENT_OUTPUT=( ["features/metadata"]="this is examples/metadata" ["features/multiplex"]="Greeting: Hello multiplex" ["features/name_resolving"]="calling helloworld.Greeter/SayHello to \"example:///resolver.example.grpc.io\"" + ["features/unix_abstract"]="calling echo.Echo/UnaryEcho to unix-abstract:abstract-unix-socket" ) cd ./examples diff --git a/examples/features/encryption/README.md b/examples/features/encryption/README.md index a00188d66a2..2afca1d785f 100644 --- a/examples/features/encryption/README.md +++ b/examples/features/encryption/README.md @@ -42,8 +42,8 @@ configure TLS and create the server credential using On client side, we provide the path to the "ca_cert.pem" to configure TLS and create the client credential using [`credentials.NewClientTLSFromFile`](https://godoc.org/google.golang.org/grpc/credentials#NewClientTLSFromFile). 
-Note that we override the server name with "x.test.youtube.com", as the server -certificate is valid for *.test.youtube.com but not localhost. It is solely for +Note that we override the server name with "x.test.example.com", as the server +certificate is valid for *.test.example.com but not localhost. It is solely for the convenience of making an example. Once the credentials have been created at both sides, we can start the server diff --git a/examples/features/proto/echo/echo_grpc.pb.go b/examples/features/proto/echo/echo_grpc.pb.go index 052087dae36..e1d24b1e830 100644 --- a/examples/features/proto/echo/echo_grpc.pb.go +++ b/examples/features/proto/echo/echo_grpc.pb.go @@ -1,4 +1,8 @@ // Code generated by protoc-gen-go-grpc. DO NOT EDIT. +// versions: +// - protoc-gen-go-grpc v1.1.0 +// - protoc v3.14.0 +// source: examples/features/proto/echo/echo.proto package echo diff --git a/examples/features/unix_abstract/README.md b/examples/features/unix_abstract/README.md new file mode 100644 index 00000000000..32b3bd5f262 --- /dev/null +++ b/examples/features/unix_abstract/README.md @@ -0,0 +1,29 @@ +# Unix abstract sockets + +This examples shows how to start a gRPC server listening on a unix abstract +socket and how to get a gRPC client to connect to it. + +## What is a unix abstract socket + +An abstract socket address is distinguished from a regular unix socket by the +fact that the first byte of the address is a null byte ('\0'). The address has +no connection with filesystem path names. + +## Try it + +``` +go run server/main.go +``` + +``` +go run client/main.go +``` + +## Explanation + +The gRPC server in this example listens on an address starting with a null byte +and the network is `unix`. The client uses the `unix-abstract` scheme with the +endpoint set to the abstract unix socket address without the null byte. The +`unix` resolver takes care of adding the null byte on the client. 
See +https://github.com/grpc/grpc/blob/master/doc/naming.md for the more details. + diff --git a/examples/features/unix_abstract/client/main.go b/examples/features/unix_abstract/client/main.go new file mode 100644 index 00000000000..4f48aca9bdf --- /dev/null +++ b/examples/features/unix_abstract/client/main.go @@ -0,0 +1,68 @@ +//go:build linux +// +build linux + +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// Binary client is an example client which dials a server on an abstract unix +// socket. +package main + +import ( + "context" + "fmt" + "log" + "time" + + "google.golang.org/grpc" + ecpb "google.golang.org/grpc/examples/features/proto/echo" +) + +func callUnaryEcho(c ecpb.EchoClient, message string) { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + r, err := c.UnaryEcho(ctx, &ecpb.EchoRequest{Message: message}) + if err != nil { + log.Fatalf("could not greet: %v", err) + } + fmt.Println(r.Message) +} + +func makeRPCs(cc *grpc.ClientConn, n int) { + hwc := ecpb.NewEchoClient(cc) + for i := 0; i < n; i++ { + callUnaryEcho(hwc, "this is examples/unix_abstract") + } +} + +func main() { + // A dial target of `unix:@abstract-unix-socket` should also work fine for + // this example because of golang conventions (net.Dial behavior). But we do + // not recommend this since we explicitly added the `unix-abstract` scheme + // for cross-language compatibility. 
+ addr := "unix-abstract:abstract-unix-socket" + cc, err := grpc.Dial(addr, grpc.WithInsecure(), grpc.WithBlock()) + if err != nil { + log.Fatalf("grpc.Dial(%q) failed: %v", addr, err) + } + defer cc.Close() + + fmt.Printf("--- calling echo.Echo/UnaryEcho to %s\n", addr) + makeRPCs(cc, 10) + fmt.Println() +} diff --git a/examples/features/unix_abstract/server/main.go b/examples/features/unix_abstract/server/main.go new file mode 100644 index 00000000000..a82b957c1f0 --- /dev/null +++ b/examples/features/unix_abstract/server/main.go @@ -0,0 +1,58 @@ +//go:build linux +// +build linux + +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// Binary server is an example server listening for gRPC connections on an +// abstract unix socket. 
+package main + +import ( + "context" + "fmt" + "log" + "net" + + "google.golang.org/grpc" + + pb "google.golang.org/grpc/examples/features/proto/echo" +) + +type ecServer struct { + pb.UnimplementedEchoServer + addr string +} + +func (s *ecServer) UnaryEcho(ctx context.Context, req *pb.EchoRequest) (*pb.EchoResponse, error) { + return &pb.EchoResponse{Message: fmt.Sprintf("%s (from %s)", req.Message, s.addr)}, nil +} + +func main() { + netw, addr := "unix", "\x00abstract-unix-socket" + lis, err := net.Listen(netw, addr) + if err != nil { + log.Fatalf("net.Listen(%q, %q) failed: %v", netw, addr, err) + } + s := grpc.NewServer() + pb.RegisterEchoServer(s, &ecServer{addr: addr}) + log.Printf("serving on %s\n", lis.Addr().String()) + if err := s.Serve(lis); err != nil { + log.Fatalf("failed to serve: %v", err) + } +} diff --git a/examples/features/xds/client/main.go b/examples/features/xds/client/main.go index b1daa1cae9c..97918faa224 100644 --- a/examples/features/xds/client/main.go +++ b/examples/features/xds/client/main.go @@ -16,78 +16,56 @@ * */ -// Package main implements a client for Greeter service. +// Binary main implements a client for Greeter service using gRPC's client-side +// support for xDS APIs. package main import ( "context" "flag" - "fmt" "log" + "strings" "time" "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + xdscreds "google.golang.org/grpc/credentials/xds" pb "google.golang.org/grpc/examples/helloworld/helloworld" _ "google.golang.org/grpc/xds" // To install the xds resolvers and balancers. ) -const ( - defaultTarget = "localhost:50051" - defaultName = "world" +var ( + target = flag.String("target", "xds:///localhost:50051", "uri of the Greeter Server, e.g. 
'xds:///helloworld-service:8080'") + name = flag.String("name", "world", "name you wished to be greeted by the server") + xdsCreds = flag.Bool("xds_creds", false, "whether the server should use xDS APIs to receive security configuration") ) -var help = flag.Bool("help", false, "Print usage information") - -func init() { - flag.Usage = func() { - fmt.Fprintf(flag.CommandLine.Output(), ` -Usage: client [name [target]] - - name - The name you wish to be greeted by. Defaults to %q - target - The URI of the server, e.g. "xds:///helloworld-service". Defaults to %q -`, defaultName, defaultTarget) - - flag.PrintDefaults() - } -} - func main() { flag.Parse() - if *help { - flag.Usage() - return - } - args := flag.Args() - - if len(args) > 2 { - flag.Usage() - return - } - name := defaultName - if len(args) > 0 { - name = args[0] + if !strings.HasPrefix(*target, "xds:///") { + log.Fatalf("-target must use a URI with scheme set to 'xds'") } - target := defaultTarget - if len(args) > 1 { - target = args[1] + creds := insecure.NewCredentials() + if *xdsCreds { + log.Println("Using xDS credentials...") + var err error + if creds, err = xdscreds.NewClientCredentials(xdscreds.ClientOptions{FallbackCreds: insecure.NewCredentials()}); err != nil { + log.Fatalf("failed to create client-side xDS credentials: %v", err) + } } - - // Set up a connection to the server. 
- conn, err := grpc.Dial(target, grpc.WithInsecure()) + conn, err := grpc.Dial(*target, grpc.WithTransportCredentials(creds)) if err != nil { - log.Fatalf("did not connect: %v", err) + log.Fatalf("grpc.Dial(%s) failed: %v", *target, err) } defer conn.Close() - c := pb.NewGreeterClient(conn) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - r, err := c.SayHello(ctx, &pb.HelloRequest{Name: name}) + c := pb.NewGreeterClient(conn) + r, err := c.SayHello(ctx, &pb.HelloRequest{Name: *name}) if err != nil { log.Fatalf("could not greet: %v", err) } diff --git a/examples/features/xds/server/main.go b/examples/features/xds/server/main.go index 7e0815645e5..0367060f4b5 100644 --- a/examples/features/xds/server/main.go +++ b/examples/features/xds/server/main.go @@ -16,7 +16,8 @@ * */ -// Package main starts Greeter service that will response with the hostname. +// Binary server demonstrated gRPC's support for xDS APIs on the server-side. It +// exposes the Greeter service that will response with the hostname. package main import ( @@ -27,36 +28,29 @@ import ( "math/rand" "net" "os" - "strconv" "time" "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + xdscreds "google.golang.org/grpc/credentials/xds" pb "google.golang.org/grpc/examples/helloworld/helloworld" "google.golang.org/grpc/health" healthpb "google.golang.org/grpc/health/grpc_health_v1" - "google.golang.org/grpc/reflection" + "google.golang.org/grpc/xds" ) -var help = flag.Bool("help", false, "Print usage information") - -const ( - defaultPort = 50051 +var ( + port = flag.Int("port", 50051, "the port to serve Greeter service requests on. Health service will be served on `port+1`") + xdsCreds = flag.Bool("xds_creds", false, "whether the server should use xDS APIs to receive security configuration") ) -// server is used to implement helloworld.GreeterServer. +// server implements helloworld.GreeterServer interface. 
type server struct { pb.UnimplementedGreeterServer - serverName string } -func newServer(serverName string) *server { - return &server{ - serverName: serverName, - } -} - -// SayHello implements helloworld.GreeterServer +// SayHello implements helloworld.GreeterServer interface. func (s *server) SayHello(ctx context.Context, in *pb.HelloRequest) (*pb.HelloReply, error) { log.Printf("Received: %v", in.GetName()) return &pb.HelloReply{Message: "Hello " + in.GetName() + ", from " + s.serverName}, nil @@ -72,65 +66,40 @@ func determineHostname() string { return hostname } -func init() { - flag.Usage = func() { - fmt.Fprintf(flag.CommandLine.Output(), ` -Usage: server [port [hostname]] - - port - The listen port. Defaults to %d - hostname - The name clients will see in greet responses. Defaults to the machine's hostname -`, defaultPort) - - flag.PrintDefaults() - } -} - func main() { flag.Parse() - if *help { - flag.Usage() - return - } - args := flag.Args() - if len(args) > 2 { - flag.Usage() - return + greeterPort := fmt.Sprintf(":%d", *port) + greeterLis, err := net.Listen("tcp4", greeterPort) + if err != nil { + log.Fatalf("net.Listen(tcp4, %q) failed: %v", greeterPort, err) } - port := defaultPort - if len(args) > 0 { + creds := insecure.NewCredentials() + if *xdsCreds { + log.Println("Using xDS credentials...") var err error - port, err = strconv.Atoi(args[0]) - if err != nil { - log.Printf("Invalid port number: %v", err) - flag.Usage() - return + if creds, err = xdscreds.NewServerCredentials(xdscreds.ServerOptions{FallbackCreds: insecure.NewCredentials()}); err != nil { + log.Fatalf("failed to create server-side xDS credentials: %v", err) } } - var hostname string - if len(args) > 1 { - hostname = args[1] - } - if hostname == "" { - hostname = determineHostname() - } + greeterServer := xds.NewGRPCServer(grpc.Creds(creds)) + pb.RegisterGreeterServer(greeterServer, &server{serverName: determineHostname()}) - lis, err := net.Listen("tcp", fmt.Sprintf("localhost:%d", 
port)) + healthPort := fmt.Sprintf(":%d", *port+1) + healthLis, err := net.Listen("tcp4", healthPort) if err != nil { - log.Fatalf("failed to listen: %v", err) + log.Fatalf("net.Listen(tcp4, %q) failed: %v", healthPort, err) } - s := grpc.NewServer() - pb.RegisterGreeterServer(s, newServer(hostname)) - - reflection.Register(s) + grpcServer := grpc.NewServer() healthServer := health.NewServer() healthServer.SetServingStatus("", healthpb.HealthCheckResponse_SERVING) - healthpb.RegisterHealthServer(s, healthServer) + healthpb.RegisterHealthServer(grpcServer, healthServer) - log.Printf("serving on %s, hostname %s", lis.Addr(), hostname) - s.Serve(lis) + log.Printf("Serving GreeterService on %s and HealthService on %s", greeterLis.Addr().String(), healthLis.Addr().String()) + go func() { + greeterServer.Serve(greeterLis) + }() + grpcServer.Serve(healthLis) } diff --git a/examples/go.mod b/examples/go.mod index 18c67afed96..4f19b852edd 100644 --- a/examples/go.mod +++ b/examples/go.mod @@ -1,12 +1,12 @@ module google.golang.org/grpc/examples -go 1.11 +go 1.14 require ( - github.com/golang/protobuf v1.4.2 + github.com/golang/protobuf v1.4.3 golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d google.golang.org/genproto v0.0.0-20200806141610-86f49bd18e98 - google.golang.org/grpc v1.31.0 + google.golang.org/grpc v1.36.0 google.golang.org/protobuf v1.25.0 ) diff --git a/examples/go.sum b/examples/go.sum index c6aa163b01a..a359cfc183f 100644 --- a/examples/go.sum +++ b/examples/go.sum @@ -1,66 +1,77 @@ cloud.google.com/go v0.34.0 h1:eOI3/cP2VTU6uZLDYAoic+eyzzB9YyGmJ7eIjl8rOPg= cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= +github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY= github.com/census-instrumentation/opencensus-proto v0.2.1 h1:glEXhBS5PSLLv4IXzLA5yPRVX4bilULVyxxbrfOtDAk= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= 
+github.com/cespare/xxhash/v2 v2.1.1 h1:6MnRN8NT7+YBpUIWxHtefFZOKTAPgGjpQSxqLNn0+qY= +github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403 h1:cqQfy1jclcSy/FwLjemeg3SR1yaINm74aQyupQ0Bl8M= github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= +github.com/cncf/xds/go v0.0.0-20210805033703-aa0b78936158 h1:CevA8fI91PAnP8vpnXuB8ZYAZ5wqY86nAbxfgK8tWO4= +github.com/cncf/xds/go v0.0.0-20210805033703-aa0b78936158/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= +github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/envoyproxy/go-control-plane v0.9.9-0.20210217033140-668b12f5399d h1:QyzYnTnPE15SQyUeqU6qLbWxMkwyAyu+vGksa0b7j00= -github.com/envoyproxy/go-control-plane v0.9.9-0.20210217033140-668b12f5399d/go.mod h1:cXg6YxExXjJnVBQHBLXeUAgxn2UodCpnH306RInaBQk= +github.com/envoyproxy/go-control-plane v0.9.10-0.20210907150352-cf90f659a021 h1:fP+fF0up6oPY49OrjPrhIJ8yQfdIM85NXMLkMg1EXVs= +github.com/envoyproxy/go-control-plane v0.9.10-0.20210907150352-cf90f659a021/go.mod h1:AFq3mo9L8Lqqiid3OhADV3RfLJnjiw63cSpi+fDTRC0= github.com/envoyproxy/protoc-gen-validate v0.1.0 h1:EQciDnbrYxy13PgWoY8AqoxGiPrpgBZ1R8UNe3ddc+A= github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= +github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= 
github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA= github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs= github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w= -github.com/golang/protobuf v1.4.0 h1:oOuy+ugB+P/kBdUnG5QaMXSIyJ1q38wWSojYCb3z5VQ= github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QDs8UjoX8= -github.com/golang/protobuf v1.4.2 h1:+Z5KGCizgyZCbGh1KZqA0fcLLkwbsjIzS4aV2v7wJX0= -github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= +github.com/golang/protobuf v1.4.3 h1:JjCZWpVbqXDqFVmTfYWEVTMIYrL/NPdPSCHPJ0T/raM= +github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= -github.com/google/go-cmp v0.4.0 h1:xsAVV57WRhGj6kEIi8ReJzQlHHqcBYCElAvkovg3B/4= github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.0 h1:/QaMHBdZ26BB3SSst0Iwl10Epc+xhTquomWX0oZEB6w= github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/uuid v1.1.2 h1:EVhdT+1Kseyi1/pUmXKaFxYsDNy9RQYkMWRH68J/W7Y= github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 
github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= +github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= +github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +go.opentelemetry.io/proto/otlp v0.7.0/go.mod h1:PqfVotwruBrMGOCsRd/89rSnXhoiJIqeYNgFYFoEGnI= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= -golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20190311183353-d8887717615a h1:oWX7TPOiFAMXLq8o0ikBYfCJVlRHBcsciT5bXOrH628= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net 
v0.0.0-20200822124328-c89045814202 h1:VvcQYSHwXgi7W+TpUR6A9g6Up98WAHf3f/ulnJ62IyA= +golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d h1:TzXSXBo42m9gQenoE3b9BGiEpg5IG2JkU5FkPIawgtw= golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a h1:1BGLXjeY4akVXGgbC9HugT3Jv3hCI0z56oJR5vAMgBU= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd h1:xhmwyvizuTgC2qz7ZlMluP20uW+C3Rm0FD/WLDX8884= +golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= -golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors 
v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/appengine v1.4.0 h1:/wp5JvzpHIxhs/dumFmF7BXTf3Z+dd4uXta4kVyO508= google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= -google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55 h1:gSJIx1SDwno+2ElGhA4+qG2zF97qiUzTM+rQ0klBOcE= -google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= +google.golang.org/genproto v0.0.0-20200513103714-09dca8ec2884/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo= google.golang.org/genproto v0.0.0-20200806141610-86f49bd18e98 h1:LCO0fg4kb6WwkXQXRQQgUYsFeFb5taTX5WAx5O/Vt28= google.golang.org/genproto v0.0.0-20200806141610-86f49bd18e98/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= @@ -68,16 +79,16 @@ google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLY google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE= -google.golang.org/protobuf v1.21.0 h1:qdOKuR/EIArgaWNjetjgTzgVTAZ+S/WXVrq9HW9zimw= google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= google.golang.org/protobuf v1.22.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= 
-google.golang.org/protobuf v1.24.0 h1:UhZDfRO8JRQru4/+LlLE0BRKGF8L+PICnvYZmx/fEGA= google.golang.org/protobuf v1.24.0/go.mod h1:r/3tXBNzIEhYS9I1OUVjXDlt8tc493IdKGjtUeSXeh4= google.golang.org/protobuf v1.25.0 h1:Ejskq+SyPohKW+1uil0JJMtmHCgJPJ/qWTxr8qp+R4c= google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +gopkg.in/yaml.v2 v2.2.3/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= diff --git a/examples/helloworld/greeter_server/main.go b/examples/helloworld/greeter_server/main.go index 15604f9fc1f..4d077db92cf 100644 --- a/examples/helloworld/greeter_server/main.go +++ b/examples/helloworld/greeter_server/main.go @@ -50,6 +50,7 @@ func main() { } s := grpc.NewServer() pb.RegisterGreeterServer(s, &server{}) + log.Printf("server listening at %v", lis.Addr()) if err := s.Serve(lis); err != nil { log.Fatalf("failed to serve: %v", err) } diff --git a/examples/helloworld/helloworld/helloworld_grpc.pb.go b/examples/helloworld/helloworld/helloworld_grpc.pb.go index 39a0301c16b..ae27dfa3cfe 100644 --- a/examples/helloworld/helloworld/helloworld_grpc.pb.go +++ b/examples/helloworld/helloworld/helloworld_grpc.pb.go @@ -1,4 +1,8 @@ // Code generated by protoc-gen-go-grpc. DO NOT EDIT. 
+// versions: +// - protoc-gen-go-grpc v1.1.0 +// - protoc v3.14.0 +// source: examples/helloworld/helloworld/helloworld.proto package helloworld diff --git a/examples/route_guide/client/client.go b/examples/route_guide/client/client.go index 172f10fb308..f18c10af8b1 100644 --- a/examples/route_guide/client/client.go +++ b/examples/route_guide/client/client.go @@ -40,7 +40,7 @@ var ( tls = flag.Bool("tls", false, "Connection uses TLS if true, else plain TCP") caFile = flag.String("ca_file", "", "The file containing the CA root cert file") serverAddr = flag.String("server_addr", "localhost:10000", "The server address in the format of host:port") - serverHostOverride = flag.String("server_host_override", "x.test.youtube.com", "The server name used to verify the hostname returned by the TLS handshake") + serverHostOverride = flag.String("server_host_override", "x.test.example.com", "The server name used to verify the hostname returned by the TLS handshake") ) // printFeature gets the feature for the given point. diff --git a/examples/route_guide/routeguide/route_guide_grpc.pb.go b/examples/route_guide/routeguide/route_guide_grpc.pb.go index 66860e63c47..efa7c28ce6f 100644 --- a/examples/route_guide/routeguide/route_guide_grpc.pb.go +++ b/examples/route_guide/routeguide/route_guide_grpc.pb.go @@ -1,4 +1,8 @@ // Code generated by protoc-gen-go-grpc. DO NOT EDIT. 
+// versions: +// - protoc-gen-go-grpc v1.1.0 +// - protoc v3.14.0 +// source: examples/route_guide/routeguide/route_guide.proto package routeguide diff --git a/go.mod b/go.mod index b177cfa66df..022cc9828fe 100644 --- a/go.mod +++ b/go.mod @@ -1,17 +1,18 @@ module google.golang.org/grpc -go 1.11 +go 1.14 require ( + github.com/cespare/xxhash/v2 v2.1.1 github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403 - github.com/envoyproxy/go-control-plane v0.9.9-0.20210217033140-668b12f5399d + github.com/envoyproxy/go-control-plane v0.9.10-0.20210907150352-cf90f659a021 github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b - github.com/golang/protobuf v1.4.2 + github.com/golang/protobuf v1.4.3 github.com/google/go-cmp v0.5.0 github.com/google/uuid v1.1.2 - golang.org/x/net v0.0.0-20190311183353-d8887717615a - golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be - golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a + golang.org/x/net v0.0.0-20200822124328-c89045814202 + golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d + golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013 google.golang.org/protobuf v1.25.0 ) diff --git a/go.sum b/go.sum index bb25cd49156..6e7ae0db2b3 100644 --- a/go.sum +++ b/go.sum @@ -1,34 +1,44 @@ -cloud.google.com/go v0.26.0 h1:e0WKqKTd5BnrG8aKH3J3h+QvEIQtSUcf2n5UZ5ZgLtQ= cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= -github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ= +cloud.google.com/go v0.34.0 h1:eOI3/cP2VTU6uZLDYAoic+eyzzB9YyGmJ7eIjl8rOPg= +cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= +github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY= github.com/census-instrumentation/opencensus-proto v0.2.1 h1:glEXhBS5PSLLv4IXzLA5yPRVX4bilULVyxxbrfOtDAk= 
github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= -github.com/client9/misspell v0.3.4 h1:ta993UF76GwbvJcIo3Y68y/M3WxlpEHPWIGDkJYwzJI= +github.com/cespare/xxhash/v2 v2.1.1 h1:6MnRN8NT7+YBpUIWxHtefFZOKTAPgGjpQSxqLNn0+qY= +github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= +github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403 h1:cqQfy1jclcSy/FwLjemeg3SR1yaINm74aQyupQ0Bl8M= github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= +github.com/cncf/xds/go v0.0.0-20210805033703-aa0b78936158 h1:CevA8fI91PAnP8vpnXuB8ZYAZ5wqY86nAbxfgK8tWO4= +github.com/cncf/xds/go v0.0.0-20210805033703-aa0b78936158/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= -github.com/envoyproxy/go-control-plane v0.9.9-0.20210217033140-668b12f5399d h1:QyzYnTnPE15SQyUeqU6qLbWxMkwyAyu+vGksa0b7j00= -github.com/envoyproxy/go-control-plane v0.9.9-0.20210217033140-668b12f5399d/go.mod h1:cXg6YxExXjJnVBQHBLXeUAgxn2UodCpnH306RInaBQk= +github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= +github.com/envoyproxy/go-control-plane v0.9.9-0.20201210154907-fd9021fe5dad/go.mod h1:cXg6YxExXjJnVBQHBLXeUAgxn2UodCpnH306RInaBQk= +github.com/envoyproxy/go-control-plane v0.9.10-0.20210907150352-cf90f659a021 
h1:fP+fF0up6oPY49OrjPrhIJ8yQfdIM85NXMLkMg1EXVs= +github.com/envoyproxy/go-control-plane v0.9.10-0.20210907150352-cf90f659a021/go.mod h1:AFq3mo9L8Lqqiid3OhADV3RfLJnjiw63cSpi+fDTRC0= github.com/envoyproxy/protoc-gen-validate v0.1.0 h1:EQciDnbrYxy13PgWoY8AqoxGiPrpgBZ1R8UNe3ddc+A= github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= +github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b h1:VKtxabqXZkF25pY9ekfRL6a582T4P37/31XEstQ5p58= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA= github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs= github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w= github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QDs8UjoX8= -github.com/golang/protobuf v1.4.2 h1:+Z5KGCizgyZCbGh1KZqA0fcLLkwbsjIzS4aV2v7wJX0= github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= +github.com/golang/protobuf v1.4.3 h1:JjCZWpVbqXDqFVmTfYWEVTMIYrL/NPdPSCHPJ0T/raM= +github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= 
github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= @@ -37,50 +47,65 @@ github.com/google/go-cmp v0.5.0 h1:/QaMHBdZ26BB3SSst0Iwl10Epc+xhTquomWX0oZEB6w= github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/uuid v1.1.2 h1:EVhdT+1Kseyi1/pUmXKaFxYsDNy9RQYkMWRH68J/W7Y= github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= +github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/testify v1.5.1 h1:nOGnQDM7FYENwehXlg/kFVnos3rEvtKTjRvOWSzb6H4= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= -golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2 h1:VklqNMn3ovrHsnt90PveolxSbWFaJdECFbxSq0Mqo2M= +github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +go.opentelemetry.io/proto/otlp v0.7.0/go.mod h1:PqfVotwruBrMGOCsRd/89rSnXhoiJIqeYNgFYFoEGnI= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/exp 
v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20190311183353-d8887717615a h1:oWX7TPOiFAMXLq8o0ikBYfCJVlRHBcsciT5bXOrH628= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be h1:vEDujvNQGv4jgYKudGeI/+DAX4Jffq6hpD55MmoEvKs= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20200822124328-c89045814202 h1:VvcQYSHwXgi7W+TpUR6A9g6Up98WAHf3f/ulnJ62IyA= +golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= +golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d h1:TzXSXBo42m9gQenoE3b9BGiEpg5IG2JkU5FkPIawgtw= +golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod 
h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a h1:1BGLXjeY4akVXGgbC9HugT3Jv3hCI0z56oJR5vAMgBU= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd h1:xhmwyvizuTgC2qz7ZlMluP20uW+C3Rm0FD/WLDX8884= +golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= -golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/appengine v1.1.0/go.mod 
h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= google.golang.org/appengine v1.4.0 h1:/wp5JvzpHIxhs/dumFmF7BXTf3Z+dd4uXta4kVyO508= google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= +google.golang.org/genproto v0.0.0-20200513103714-09dca8ec2884/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013 h1:+kGHl1aib/qcwaRi1CbqBZ1rk19r85MNUf8HaBghugY= google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo= google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY= google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= +google.golang.org/grpc v1.33.1/go.mod h1:fr5YgcSWrqhRRxogOsw7RzIpsmvOZ6IcH4kBYTpR3n0= +google.golang.org/grpc v1.36.0/go.mod h1:qjiiYl8FncCW8feJPdyg3v6XW24KsRHe+dy9BAGRRjU= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= @@ -93,7 +118,9 @@ google.golang.org/protobuf v1.25.0 h1:Ejskq+SyPohKW+1uil0JJMtmHCgJPJ/qWTxr8qp+R4 google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 
v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.3/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= diff --git a/grpclog/loggerv2.go b/grpclog/loggerv2.go index 4ee33171e00..34098bb8eb5 100644 --- a/grpclog/loggerv2.go +++ b/grpclog/loggerv2.go @@ -19,11 +19,14 @@ package grpclog import ( + "encoding/json" + "fmt" "io" "io/ioutil" "log" "os" "strconv" + "strings" "google.golang.org/grpc/internal/grpclog" ) @@ -95,8 +98,9 @@ var severityName = []string{ // loggerT is the default logger used by grpclog. type loggerT struct { - m []*log.Logger - v int + m []*log.Logger + v int + jsonFormat bool } // NewLoggerV2 creates a loggerV2 with the provided writers. @@ -105,19 +109,32 @@ type loggerT struct { // Warning logs will be written to warningW and infoW. // Info logs will be written to infoW. func NewLoggerV2(infoW, warningW, errorW io.Writer) LoggerV2 { - return NewLoggerV2WithVerbosity(infoW, warningW, errorW, 0) + return newLoggerV2WithConfig(infoW, warningW, errorW, loggerV2Config{}) } // NewLoggerV2WithVerbosity creates a loggerV2 with the provided writers and // verbosity level. 
func NewLoggerV2WithVerbosity(infoW, warningW, errorW io.Writer, v int) LoggerV2 { + return newLoggerV2WithConfig(infoW, warningW, errorW, loggerV2Config{verbose: v}) +} + +type loggerV2Config struct { + verbose int + jsonFormat bool +} + +func newLoggerV2WithConfig(infoW, warningW, errorW io.Writer, c loggerV2Config) LoggerV2 { var m []*log.Logger - m = append(m, log.New(infoW, severityName[infoLog]+": ", log.LstdFlags)) - m = append(m, log.New(io.MultiWriter(infoW, warningW), severityName[warningLog]+": ", log.LstdFlags)) + flag := log.LstdFlags + if c.jsonFormat { + flag = 0 + } + m = append(m, log.New(infoW, "", flag)) + m = append(m, log.New(io.MultiWriter(infoW, warningW), "", flag)) ew := io.MultiWriter(infoW, warningW, errorW) // ew will be used for error and fatal. - m = append(m, log.New(ew, severityName[errorLog]+": ", log.LstdFlags)) - m = append(m, log.New(ew, severityName[fatalLog]+": ", log.LstdFlags)) - return &loggerT{m: m, v: v} + m = append(m, log.New(ew, "", flag)) + m = append(m, log.New(ew, "", flag)) + return &loggerT{m: m, v: c.verbose, jsonFormat: c.jsonFormat} } // newLoggerV2 creates a loggerV2 to be used as default logger. @@ -142,58 +159,79 @@ func newLoggerV2() LoggerV2 { if vl, err := strconv.Atoi(vLevel); err == nil { v = vl } - return NewLoggerV2WithVerbosity(infoW, warningW, errorW, v) + + jsonFormat := strings.EqualFold(os.Getenv("GRPC_GO_LOG_FORMATTER"), "json") + + return newLoggerV2WithConfig(infoW, warningW, errorW, loggerV2Config{ + verbose: v, + jsonFormat: jsonFormat, + }) +} + +func (g *loggerT) output(severity int, s string) { + sevStr := severityName[severity] + if !g.jsonFormat { + g.m[severity].Output(2, fmt.Sprintf("%v: %v", sevStr, s)) + return + } + // TODO: we can also include the logging component, but that needs more + // (API) changes. 
+ b, _ := json.Marshal(map[string]string{ + "severity": sevStr, + "message": s, + }) + g.m[severity].Output(2, string(b)) } func (g *loggerT) Info(args ...interface{}) { - g.m[infoLog].Print(args...) + g.output(infoLog, fmt.Sprint(args...)) } func (g *loggerT) Infoln(args ...interface{}) { - g.m[infoLog].Println(args...) + g.output(infoLog, fmt.Sprintln(args...)) } func (g *loggerT) Infof(format string, args ...interface{}) { - g.m[infoLog].Printf(format, args...) + g.output(infoLog, fmt.Sprintf(format, args...)) } func (g *loggerT) Warning(args ...interface{}) { - g.m[warningLog].Print(args...) + g.output(warningLog, fmt.Sprint(args...)) } func (g *loggerT) Warningln(args ...interface{}) { - g.m[warningLog].Println(args...) + g.output(warningLog, fmt.Sprintln(args...)) } func (g *loggerT) Warningf(format string, args ...interface{}) { - g.m[warningLog].Printf(format, args...) + g.output(warningLog, fmt.Sprintf(format, args...)) } func (g *loggerT) Error(args ...interface{}) { - g.m[errorLog].Print(args...) + g.output(errorLog, fmt.Sprint(args...)) } func (g *loggerT) Errorln(args ...interface{}) { - g.m[errorLog].Println(args...) + g.output(errorLog, fmt.Sprintln(args...)) } func (g *loggerT) Errorf(format string, args ...interface{}) { - g.m[errorLog].Printf(format, args...) + g.output(errorLog, fmt.Sprintf(format, args...)) } func (g *loggerT) Fatal(args ...interface{}) { - g.m[fatalLog].Fatal(args...) - // No need to call os.Exit() again because log.Logger.Fatal() calls os.Exit(). + g.output(fatalLog, fmt.Sprint(args...)) + os.Exit(1) } func (g *loggerT) Fatalln(args ...interface{}) { - g.m[fatalLog].Fatalln(args...) - // No need to call os.Exit() again because log.Logger.Fatal() calls os.Exit(). + g.output(fatalLog, fmt.Sprintln(args...)) + os.Exit(1) } func (g *loggerT) Fatalf(format string, args ...interface{}) { - g.m[fatalLog].Fatalf(format, args...) - // No need to call os.Exit() again because log.Logger.Fatal() calls os.Exit(). 
+ g.output(fatalLog, fmt.Sprintf(format, args...)) + os.Exit(1) } func (g *loggerT) V(l int) bool { diff --git a/grpclog/loggerv2_test.go b/grpclog/loggerv2_test.go index 756f215f9c8..0b2c8b23d66 100644 --- a/grpclog/loggerv2_test.go +++ b/grpclog/loggerv2_test.go @@ -52,9 +52,9 @@ func TestLoggerV2Severity(t *testing.T) { } // check if b is in the format of: -// WARNING: 2017/04/07 14:55:42 WARNING +// 2017/04/07 14:55:42 WARNING: WARNING func checkLogForSeverity(s int, b []byte) error { - expected := regexp.MustCompile(fmt.Sprintf(`^%s: [0-9]{4}/[0-9]{2}/[0-9]{2} [0-9]{2}:[0-9]{2}:[0-9]{2} %s\n$`, severityName[s], severityName[s])) + expected := regexp.MustCompile(fmt.Sprintf(`^[0-9]{4}/[0-9]{2}/[0-9]{2} [0-9]{2}:[0-9]{2}:[0-9]{2} %s: %s\n$`, severityName[s], severityName[s])) if m := expected.Match(b); !m { return fmt.Errorf("got: %v, want string in format of: %v", string(b), severityName[s]+": 2016/10/05 17:09:26 "+severityName[s]) } diff --git a/health/grpc_health_v1/health_grpc.pb.go b/health/grpc_health_v1/health_grpc.pb.go index 386d16ce62d..bdc3ae284e7 100644 --- a/health/grpc_health_v1/health_grpc.pb.go +++ b/health/grpc_health_v1/health_grpc.pb.go @@ -1,4 +1,8 @@ // Code generated by protoc-gen-go-grpc. DO NOT EDIT. 
+// versions: +// - protoc-gen-go-grpc v1.1.0 +// - protoc v3.14.0 +// source: grpc/health/v1/health.proto package grpc_health_v1 diff --git a/install_gae.sh b/install_gae.sh deleted file mode 100755 index 15ff9facdd7..00000000000 --- a/install_gae.sh +++ /dev/null @@ -1,6 +0,0 @@ -#!/bin/bash - -TMP=$(mktemp -d /tmp/sdk.XXX) \ -&& curl -o $TMP.zip "https://storage.googleapis.com/appengine-sdks/featured/go_appengine_sdk_linux_amd64-1.9.68.zip" \ -&& unzip -q $TMP.zip -d $TMP \ -&& export PATH="$PATH:$TMP/go_appengine" \ No newline at end of file diff --git a/internal/admin/admin.go b/internal/admin/admin.go new file mode 100644 index 00000000000..a9285ee7484 --- /dev/null +++ b/internal/admin/admin.go @@ -0,0 +1,60 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// Package admin contains internal implementation for admin service. +package admin + +import "google.golang.org/grpc" + +// services is a map from name to service register functions. +var services []func(grpc.ServiceRegistrar) (func(), error) + +// AddService adds a service to the list of admin services. +// +// NOTE: this function must only be called during initialization time (i.e. in +// an init() function), and is not thread-safe. +// +// If multiple services with the same service name are added (e.g. two services +// for `grpc.channelz.v1.Channelz`), the server will panic on `Register()`. 
+func AddService(f func(grpc.ServiceRegistrar) (func(), error)) { + services = append(services, f) +} + +// Register registers the set of admin services to the given server. +func Register(s grpc.ServiceRegistrar) (cleanup func(), _ error) { + var cleanups []func() + for _, f := range services { + cleanup, err := f(s) + if err != nil { + callFuncs(cleanups) + return nil, err + } + if cleanup != nil { + cleanups = append(cleanups, cleanup) + } + } + return func() { + callFuncs(cleanups) + }, nil +} + +func callFuncs(fs []func()) { + for _, f := range fs { + f() + } +} diff --git a/internal/balancer/stub/stub.go b/internal/balancer/stub/stub.go index e3757c1a50b..950eaaa0278 100644 --- a/internal/balancer/stub/stub.go +++ b/internal/balancer/stub/stub.go @@ -33,6 +33,7 @@ type BalancerFuncs struct { ResolverError func(*BalancerData, error) UpdateSubConnState func(*BalancerData, balancer.SubConn, balancer.SubConnState) Close func(*BalancerData) + ExitIdle func(*BalancerData) } // BalancerData contains data relevant to a stub balancer. 
@@ -75,6 +76,12 @@ func (b *bal) Close() { } } +func (b *bal) ExitIdle() { + if b.bf.ExitIdle != nil { + b.bf.ExitIdle(b.bd) + } +} + type bb struct { name string bf BalancerFuncs diff --git a/internal/binarylog/sink.go b/internal/binarylog/sink.go index 7d7a3056b71..c2fdd58b319 100644 --- a/internal/binarylog/sink.go +++ b/internal/binarylog/sink.go @@ -69,7 +69,8 @@ type writerSink struct { func (ws *writerSink) Write(e *pb.GrpcLogEntry) error { b, err := proto.Marshal(e) if err != nil { - grpclogLogger.Infof("binary logging: failed to marshal proto message: %v", err) + grpclogLogger.Errorf("binary logging: failed to marshal proto message: %v", err) + return err } hdr := make([]byte, 4) binary.BigEndian.PutUint32(hdr, uint32(len(b))) @@ -85,24 +86,27 @@ func (ws *writerSink) Write(e *pb.GrpcLogEntry) error { func (ws *writerSink) Close() error { return nil } type bufferedSink struct { - mu sync.Mutex - closer io.Closer - out Sink // out is built on buf. - buf *bufio.Writer // buf is kept for flush. - - writeStartOnce sync.Once - writeTicker *time.Ticker + mu sync.Mutex + closer io.Closer + out Sink // out is built on buf. + buf *bufio.Writer // buf is kept for flush. + flusherStarted bool + + writeTicker *time.Ticker + done chan struct{} } func (fs *bufferedSink) Write(e *pb.GrpcLogEntry) error { - // Start the write loop when Write is called. - fs.writeStartOnce.Do(fs.startFlushGoroutine) fs.mu.Lock() + defer fs.mu.Unlock() + if !fs.flusherStarted { + // Start the write loop when Write is called. 
+ fs.startFlushGoroutine() + fs.flusherStarted = true + } if err := fs.out.Write(e); err != nil { - fs.mu.Unlock() return err } - fs.mu.Unlock() return nil } @@ -113,7 +117,12 @@ const ( func (fs *bufferedSink) startFlushGoroutine() { fs.writeTicker = time.NewTicker(bufFlushDuration) go func() { - for range fs.writeTicker.C { + for { + select { + case <-fs.done: + return + case <-fs.writeTicker.C: + } fs.mu.Lock() if err := fs.buf.Flush(); err != nil { grpclogLogger.Warningf("failed to flush to Sink: %v", err) @@ -124,10 +133,12 @@ func (fs *bufferedSink) startFlushGoroutine() { } func (fs *bufferedSink) Close() error { + fs.mu.Lock() + defer fs.mu.Unlock() if fs.writeTicker != nil { fs.writeTicker.Stop() } - fs.mu.Lock() + close(fs.done) if err := fs.buf.Flush(); err != nil { grpclogLogger.Warningf("failed to flush to Sink: %v", err) } @@ -137,7 +148,6 @@ func (fs *bufferedSink) Close() error { if err := fs.out.Close(); err != nil { grpclogLogger.Warningf("failed to close the Sink: %v", err) } - fs.mu.Unlock() return nil } @@ -155,5 +165,6 @@ func NewBufferedSink(o io.WriteCloser) Sink { closer: o, out: newWriterSink(bufW), buf: bufW, + done: make(chan struct{}), } } diff --git a/internal/channelz/funcs.go b/internal/channelz/funcs.go index f7314139303..6d5760d9514 100644 --- a/internal/channelz/funcs.go +++ b/internal/channelz/funcs.go @@ -630,7 +630,7 @@ func (c *channelMap) GetServerSockets(id int64, startID int64, maxResults int64) if count == 0 { end = true } - var s []*SocketMetric + s := make([]*SocketMetric, 0, len(sks)) for _, ns := range sks { sm := &SocketMetric{} sm.SocketData = ns.s.ChannelzMetric() diff --git a/internal/channelz/types_linux.go b/internal/channelz/types_linux.go index 692dd618177..1b1c4cce34a 100644 --- a/internal/channelz/types_linux.go +++ b/internal/channelz/types_linux.go @@ -1,5 +1,3 @@ -// +build !appengine - /* * * Copyright 2018 gRPC authors. 
diff --git a/internal/channelz/types_nonlinux.go b/internal/channelz/types_nonlinux.go index 19c2fc521dc..8b06eed1ab8 100644 --- a/internal/channelz/types_nonlinux.go +++ b/internal/channelz/types_nonlinux.go @@ -1,4 +1,5 @@ -// +build !linux appengine +//go:build !linux +// +build !linux /* * @@ -37,6 +38,6 @@ type SocketOptionData struct { // Windows OS doesn't support Socket Option func (s *SocketOptionData) Getsockopt(fd uintptr) { once.Do(func() { - logger.Warning("Channelz: socket options are not supported on non-linux os and appengine.") + logger.Warning("Channelz: socket options are not supported on non-linux environments") }) } diff --git a/internal/channelz/util_linux.go b/internal/channelz/util_linux.go index fdf409d55de..8d194e44e1d 100644 --- a/internal/channelz/util_linux.go +++ b/internal/channelz/util_linux.go @@ -1,5 +1,3 @@ -// +build linux,!appengine - /* * * Copyright 2018 gRPC authors. diff --git a/internal/channelz/util_nonlinux.go b/internal/channelz/util_nonlinux.go index 8864a081116..837ddc40240 100644 --- a/internal/channelz/util_nonlinux.go +++ b/internal/channelz/util_nonlinux.go @@ -1,4 +1,5 @@ -// +build !linux appengine +//go:build !linux +// +build !linux /* * diff --git a/internal/channelz/util_test.go b/internal/channelz/util_test.go index 3d1a1183fa4..9de6679043d 100644 --- a/internal/channelz/util_test.go +++ b/internal/channelz/util_test.go @@ -1,4 +1,5 @@ -// +build linux,!appengine +//go:build linux +// +build linux /* * diff --git a/internal/credentials/credentials.go b/internal/credentials/credentials.go new file mode 100644 index 00000000000..32c9b59033c --- /dev/null +++ b/internal/credentials/credentials.go @@ -0,0 +1,49 @@ +/* + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package credentials + +import ( + "context" +) + +// requestInfoKey is a struct to be used as the key to store RequestInfo in a +// context. +type requestInfoKey struct{} + +// NewRequestInfoContext creates a context with ri. +func NewRequestInfoContext(ctx context.Context, ri interface{}) context.Context { + return context.WithValue(ctx, requestInfoKey{}, ri) +} + +// RequestInfoFromContext extracts the RequestInfo from ctx. +func RequestInfoFromContext(ctx context.Context) interface{} { + return ctx.Value(requestInfoKey{}) +} + +// clientHandshakeInfoKey is a struct used as the key to store +// ClientHandshakeInfo in a context. +type clientHandshakeInfoKey struct{} + +// ClientHandshakeInfoFromContext extracts the ClientHandshakeInfo from ctx. +func ClientHandshakeInfoFromContext(ctx context.Context) interface{} { + return ctx.Value(clientHandshakeInfoKey{}) +} + +// NewClientHandshakeInfoContext creates a context with chi. +func NewClientHandshakeInfoContext(ctx context.Context, chi interface{}) context.Context { + return context.WithValue(ctx, clientHandshakeInfoKey{}, chi) +} diff --git a/internal/credentials/spiffe.go b/internal/credentials/spiffe.go index be70b6cdfc3..25ade623058 100644 --- a/internal/credentials/spiffe.go +++ b/internal/credentials/spiffe.go @@ -1,5 +1,3 @@ -// +build !appengine - /* * * Copyright 2020 gRPC authors. 
diff --git a/internal/credentials/spiffe_appengine.go b/internal/credentials/spiffe_appengine.go deleted file mode 100644 index af6f5771976..00000000000 --- a/internal/credentials/spiffe_appengine.go +++ /dev/null @@ -1,31 +0,0 @@ -// +build appengine - -/* - * - * Copyright 2020 gRPC authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - */ - -package credentials - -import ( - "crypto/tls" - "net/url" -) - -// SPIFFEIDFromState is a no-op for appengine builds. -func SPIFFEIDFromState(state tls.ConnectionState) *url.URL { - return nil -} diff --git a/internal/credentials/syscallconn.go b/internal/credentials/syscallconn.go index f499a614c20..2919632d657 100644 --- a/internal/credentials/syscallconn.go +++ b/internal/credentials/syscallconn.go @@ -1,5 +1,3 @@ -// +build !appengine - /* * * Copyright 2018 gRPC authors. diff --git a/internal/credentials/syscallconn_test.go b/internal/credentials/syscallconn_test.go index ee17a0ca67b..b229a47d116 100644 --- a/internal/credentials/syscallconn_test.go +++ b/internal/credentials/syscallconn_test.go @@ -1,5 +1,3 @@ -// +build !appengine - /* * * Copyright 2018 gRPC authors. 
diff --git a/internal/credentials/util.go b/internal/credentials/util.go index 55664fa46b8..f792fd22caf 100644 --- a/internal/credentials/util.go +++ b/internal/credentials/util.go @@ -18,7 +18,9 @@ package credentials -import "crypto/tls" +import ( + "crypto/tls" +) const alpnProtoStrH2 = "h2" diff --git a/internal/credentials/xds/handshake_info.go b/internal/credentials/xds/handshake_info.go index 8b203566063..6ef43cc89fa 100644 --- a/internal/credentials/xds/handshake_info.go +++ b/internal/credentials/xds/handshake_info.go @@ -25,11 +25,13 @@ import ( "crypto/x509" "errors" "fmt" + "strings" "sync" "google.golang.org/grpc/attributes" "google.golang.org/grpc/credentials/tls/certprovider" "google.golang.org/grpc/internal" + "google.golang.org/grpc/internal/xds/matcher" "google.golang.org/grpc/resolver" ) @@ -64,8 +66,8 @@ type HandshakeInfo struct { mu sync.Mutex rootProvider certprovider.Provider identityProvider certprovider.Provider - acceptedSANs map[string]bool // Only on the client side. - requireClientCert bool // Only on server side. + sanMatchers []matcher.StringMatcher // Only on the client side. + requireClientCert bool // Only on server side. } // SetRootCertProvider updates the root certificate provider. @@ -82,13 +84,10 @@ func (hi *HandshakeInfo) SetIdentityCertProvider(identity certprovider.Provider) hi.mu.Unlock() } -// SetAcceptedSANs updates the list of accepted SANs. -func (hi *HandshakeInfo) SetAcceptedSANs(sans []string) { +// SetSANMatchers updates the list of SAN matchers. +func (hi *HandshakeInfo) SetSANMatchers(sanMatchers []matcher.StringMatcher) { hi.mu.Lock() - hi.acceptedSANs = make(map[string]bool, len(sans)) - for _, san := range sans { - hi.acceptedSANs[san] = true - } + hi.sanMatchers = sanMatchers hi.mu.Unlock() } @@ -112,6 +111,14 @@ func (hi *HandshakeInfo) UseFallbackCreds() bool { return hi.identityProvider == nil && hi.rootProvider == nil } +// GetSANMatchersForTesting returns the SAN matchers stored in HandshakeInfo. 
+// To be used only for testing purposes. +func (hi *HandshakeInfo) GetSANMatchersForTesting() []matcher.StringMatcher { + hi.mu.Lock() + defer hi.mu.Unlock() + return append([]matcher.StringMatcher{}, hi.sanMatchers...) +} + // ClientSideTLSConfig constructs a tls.Config to be used in a client-side // handshake based on the contents of the HandshakeInfo. func (hi *HandshakeInfo) ClientSideTLSConfig(ctx context.Context) (*tls.Config, error) { @@ -131,7 +138,10 @@ func (hi *HandshakeInfo) ClientSideTLSConfig(ctx context.Context) (*tls.Config, // Currently the Go stdlib does complete verification of the cert (which // includes hostname verification) or none. We are forced to go with the // latter and perform the normal cert validation ourselves. - cfg := &tls.Config{InsecureSkipVerify: true} + cfg := &tls.Config{ + InsecureSkipVerify: true, + NextProtos: []string{"h2"}, + } km, err := rootProv.KeyMaterial(ctx) if err != nil { @@ -152,7 +162,10 @@ func (hi *HandshakeInfo) ClientSideTLSConfig(ctx context.Context) (*tls.Config, // ServerSideTLSConfig constructs a tls.Config to be used in a server-side // handshake based on the contents of the HandshakeInfo. func (hi *HandshakeInfo) ServerSideTLSConfig(ctx context.Context) (*tls.Config, error) { - cfg := &tls.Config{ClientAuth: tls.NoClientCert} + cfg := &tls.Config{ + ClientAuth: tls.NoClientCert, + NextProtos: []string{"h2"}, + } hi.mu.Lock() // On the server side, identityProvider is mandatory. RootProvider is // optional based on whether the server is doing TLS or mTLS. @@ -184,47 +197,115 @@ func (hi *HandshakeInfo) ServerSideTLSConfig(ctx context.Context) (*tls.Config, return cfg, nil } -// MatchingSANExists returns true if the SAN contained in the passed in -// certificate is present in the list of accepted SANs in the HandshakeInfo. +// MatchingSANExists returns true if the SANs contained in cert match the +// criteria enforced by the list of SAN matchers in HandshakeInfo. 
// -// If the list of accepted SANs in the HandshakeInfo is empty, this function +// If the list of SAN matchers in the HandshakeInfo is empty, this function // returns true for all input certificates. func (hi *HandshakeInfo) MatchingSANExists(cert *x509.Certificate) bool { - if len(hi.acceptedSANs) == 0 { + hi.mu.Lock() + defer hi.mu.Unlock() + if len(hi.sanMatchers) == 0 { return true } - var sans []string // SANs can be specified in any of these four fields on the parsed cert. - sans = append(sans, cert.DNSNames...) - sans = append(sans, cert.EmailAddresses...) - for _, ip := range cert.IPAddresses { - sans = append(sans, ip.String()) + for _, san := range cert.DNSNames { + if hi.matchSAN(san, true) { + return true + } } - for _, uri := range cert.URIs { - sans = append(sans, uri.String()) + for _, san := range cert.EmailAddresses { + if hi.matchSAN(san, false) { + return true + } + } + for _, san := range cert.IPAddresses { + if hi.matchSAN(san.String(), false) { + return true + } + } + for _, san := range cert.URIs { + if hi.matchSAN(san.String(), false) { + return true + } } + return false +} - hi.mu.Lock() - defer hi.mu.Unlock() - for _, san := range sans { - if hi.acceptedSANs[san] { +// Caller must hold mu. +func (hi *HandshakeInfo) matchSAN(san string, isDNS bool) bool { + for _, matcher := range hi.sanMatchers { + if em := matcher.ExactMatch(); em != "" && isDNS { + // This is a special case which is documented in the xDS protos. + // If the DNS SAN is a wildcard entry, and the match criteria is + // `exact`, then we need to perform DNS wildcard matching + // instead of regular string comparison. + if dnsMatch(em, san) { + return true + } + continue + } + if matcher.Match(san) { return true } } return false } -// NewHandshakeInfo returns a new instance of HandshakeInfo with the given root -// and identity certificate providers. 
-func NewHandshakeInfo(root, identity certprovider.Provider, sans ...string) *HandshakeInfo { - acceptedSANs := make(map[string]bool, len(sans)) - for _, san := range sans { - acceptedSANs[san] = true +// dnsMatch implements a DNS wildcard matching algorithm based on RFC2828 and +// grpc-java's implementation in `OkHostnameVerifier` class. +// +// NOTE: Here the `host` argument is the one from the set of string matchers in +// the xDS proto and the `san` argument is a DNS SAN from the certificate, and +// this is the one which can potentially contain a wildcard pattern. +func dnsMatch(host, san string) bool { + // Add trailing "." and turn them into absolute domain names. + if !strings.HasSuffix(host, ".") { + host += "." + } + if !strings.HasSuffix(san, ".") { + san += "." } - return &HandshakeInfo{ - rootProvider: root, - identityProvider: identity, - acceptedSANs: acceptedSANs, + // Domain names are case-insensitive. + host = strings.ToLower(host) + san = strings.ToLower(san) + + // If san does not contain a wildcard, do exact match. + if !strings.Contains(san, "*") { + return host == san + } + + // Wildcard dns matching rules + // - '*' is only permitted in the left-most label and must be the only + // character in that label. For example, *.example.com is permitted, while + // *a.example.com, a*.example.com, a*b.example.com, a.*.example.com are + // not permitted. + // - '*' matches a single domain name component. For example, *.example.com + // matches test.example.com but does not match sub.test.example.com. + // - Wildcard patterns for single-label domain names are not permitted. + if san == "*." || !strings.HasPrefix(san, "*.") || strings.Contains(san[1:], "*") { + return false + } + // Optimization: at this point, we know that the san contains a '*' and + // is the first domain component of san. So, the host name must be at + // least as long as the san to be able to match. 
+ if len(host) < len(san) { + return false + } + // Hostname must end with the non-wildcard portion of san. + if !strings.HasSuffix(host, san[1:]) { + return false } + // At this point we know that the hostName and san share the same suffix + // (the non-wildcard portion of san). Now, we just need to make sure + // that the '*' does not match across domain components. + hostPrefix := strings.TrimSuffix(host, san[1:]) + return !strings.Contains(hostPrefix, ".") +} + +// NewHandshakeInfo returns a new instance of HandshakeInfo with the given root +// and identity certificate providers. +func NewHandshakeInfo(root, identity certprovider.Provider) *HandshakeInfo { + return &HandshakeInfo{rootProvider: root, identityProvider: identity} } diff --git a/internal/credentials/xds/handshake_info_test.go b/internal/credentials/xds/handshake_info_test.go new file mode 100644 index 00000000000..91257a1925d --- /dev/null +++ b/internal/credentials/xds/handshake_info_test.go @@ -0,0 +1,304 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +package xds + +import ( + "crypto/x509" + "net" + "net/url" + "regexp" + "testing" + + "google.golang.org/grpc/internal/xds/matcher" +) + +func TestDNSMatch(t *testing.T) { + tests := []struct { + desc string + host string + pattern string + wantMatch bool + }{ + { + desc: "invalid wildcard 1", + host: "aa.example.com", + pattern: "*a.example.com", + wantMatch: false, + }, + { + desc: "invalid wildcard 2", + host: "aa.example.com", + pattern: "a*.example.com", + wantMatch: false, + }, + { + desc: "invalid wildcard 3", + host: "abc.example.com", + pattern: "a*c.example.com", + wantMatch: false, + }, + { + desc: "wildcard in one of the middle components", + host: "abc.test.example.com", + pattern: "abc.*.example.com", + wantMatch: false, + }, + { + desc: "single component wildcard", + host: "a.example.com", + pattern: "*", + wantMatch: false, + }, + { + desc: "short host name", + host: "a.com", + pattern: "*.example.com", + wantMatch: false, + }, + { + desc: "suffix mismatch", + host: "a.notexample.com", + pattern: "*.example.com", + wantMatch: false, + }, + { + desc: "wildcard match across components", + host: "sub.test.example.com", + pattern: "*.example.com.", + wantMatch: false, + }, + { + desc: "host doesn't end in period", + host: "test.example.com", + pattern: "test.example.com.", + wantMatch: true, + }, + { + desc: "pattern doesn't end in period", + host: "test.example.com.", + pattern: "test.example.com", + wantMatch: true, + }, + { + desc: "case insensitive", + host: "TEST.EXAMPLE.COM.", + pattern: "test.example.com.", + wantMatch: true, + }, + { + desc: "simple match", + host: "test.example.com", + pattern: "test.example.com", + wantMatch: true, + }, + { + desc: "good wildcard", + host: "a.example.com", + pattern: "*.example.com", + wantMatch: true, + }, + } + + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + gotMatch := dnsMatch(test.host, test.pattern) + if gotMatch != test.wantMatch { + t.Fatalf("dnsMatch(%s, %s) = %v, 
want %v", test.host, test.pattern, gotMatch, test.wantMatch) + } + }) + } +} + +func TestMatchingSANExists_FailureCases(t *testing.T) { + url1, err := url.Parse("http://golang.org") + if err != nil { + t.Fatalf("url.Parse() failed: %v", err) + } + url2, err := url.Parse("https://github.com/grpc/grpc-go") + if err != nil { + t.Fatalf("url.Parse() failed: %v", err) + } + inputCert := &x509.Certificate{ + DNSNames: []string{"foo.bar.example.com", "bar.baz.test.com", "*.example.com"}, + EmailAddresses: []string{"foobar@example.com", "barbaz@test.com"}, + IPAddresses: []net.IP{net.ParseIP("192.0.0.1"), net.ParseIP("2001:db8::68")}, + URIs: []*url.URL{url1, url2}, + } + + tests := []struct { + desc string + sanMatchers []matcher.StringMatcher + }{ + { + desc: "exact match", + sanMatchers: []matcher.StringMatcher{ + matcher.StringMatcherForTesting(newStringP("abcd.test.com"), nil, nil, nil, nil, false), + matcher.StringMatcherForTesting(newStringP("http://golang"), nil, nil, nil, nil, false), + matcher.StringMatcherForTesting(newStringP("HTTP://GOLANG.ORG"), nil, nil, nil, nil, false), + }, + }, + { + desc: "prefix match", + sanMatchers: []matcher.StringMatcher{ + matcher.StringMatcherForTesting(nil, newStringP("i-aint-the-one"), nil, nil, nil, false), + matcher.StringMatcherForTesting(nil, newStringP("192.168.1.1"), nil, nil, nil, false), + matcher.StringMatcherForTesting(nil, newStringP("FOO.BAR"), nil, nil, nil, false), + }, + }, + { + desc: "suffix match", + sanMatchers: []matcher.StringMatcher{ + matcher.StringMatcherForTesting(nil, nil, newStringP("i-aint-the-one"), nil, nil, false), + matcher.StringMatcherForTesting(nil, nil, newStringP("1::68"), nil, nil, false), + matcher.StringMatcherForTesting(nil, nil, newStringP(".COM"), nil, nil, false), + }, + }, + { + desc: "regex match", + sanMatchers: []matcher.StringMatcher{ + matcher.StringMatcherForTesting(nil, nil, nil, nil, regexp.MustCompile(`.*\.examples\.com`), false), + matcher.StringMatcherForTesting(nil, nil, 
nil, nil, regexp.MustCompile(`192\.[0-9]{1,3}\.1\.1`), false), + }, + }, + { + desc: "contains match", + sanMatchers: []matcher.StringMatcher{ + matcher.StringMatcherForTesting(nil, nil, nil, newStringP("i-aint-the-one"), nil, false), + matcher.StringMatcherForTesting(nil, nil, nil, newStringP("2001:db8:1:1::68"), nil, false), + matcher.StringMatcherForTesting(nil, nil, nil, newStringP("GRPC"), nil, false), + }, + }, + } + + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + hi := NewHandshakeInfo(nil, nil) + hi.SetSANMatchers(test.sanMatchers) + + if hi.MatchingSANExists(inputCert) { + t.Fatalf("hi.MatchingSANExists(%+v) with SAN matchers +%v succeeded when expected to fail", inputCert, test.sanMatchers) + } + }) + } +} + +func TestMatchingSANExists_Success(t *testing.T) { + url1, err := url.Parse("http://golang.org") + if err != nil { + t.Fatalf("url.Parse() failed: %v", err) + } + url2, err := url.Parse("https://github.com/grpc/grpc-go") + if err != nil { + t.Fatalf("url.Parse() failed: %v", err) + } + inputCert := &x509.Certificate{ + DNSNames: []string{"baz.test.com", "*.example.com"}, + EmailAddresses: []string{"foobar@example.com", "barbaz@test.com"}, + IPAddresses: []net.IP{net.ParseIP("192.0.0.1"), net.ParseIP("2001:db8::68")}, + URIs: []*url.URL{url1, url2}, + } + + tests := []struct { + desc string + sanMatchers []matcher.StringMatcher + }{ + { + desc: "no san matchers", + }, + { + desc: "exact match dns wildcard", + sanMatchers: []matcher.StringMatcher{ + matcher.StringMatcherForTesting(nil, newStringP("192.168.1.1"), nil, nil, nil, false), + matcher.StringMatcherForTesting(newStringP("https://github.com/grpc/grpc-java"), nil, nil, nil, nil, false), + matcher.StringMatcherForTesting(newStringP("abc.example.com"), nil, nil, nil, nil, false), + }, + }, + { + desc: "exact match ignore case", + sanMatchers: []matcher.StringMatcher{ + matcher.StringMatcherForTesting(newStringP("FOOBAR@EXAMPLE.COM"), nil, nil, nil, nil, true), + }, + }, + 
{ + desc: "prefix match", + sanMatchers: []matcher.StringMatcher{ + matcher.StringMatcherForTesting(nil, nil, newStringP(".co.in"), nil, nil, false), + matcher.StringMatcherForTesting(nil, newStringP("192.168.1.1"), nil, nil, nil, false), + matcher.StringMatcherForTesting(nil, newStringP("baz.test"), nil, nil, nil, false), + }, + }, + { + desc: "prefix match ignore case", + sanMatchers: []matcher.StringMatcher{ + matcher.StringMatcherForTesting(nil, newStringP("BAZ.test"), nil, nil, nil, true), + }, + }, + { + desc: "suffix match", + sanMatchers: []matcher.StringMatcher{ + matcher.StringMatcherForTesting(nil, nil, nil, nil, regexp.MustCompile(`192\.[0-9]{1,3}\.1\.1`), false), + matcher.StringMatcherForTesting(nil, nil, newStringP("192.168.1.1"), nil, nil, false), + matcher.StringMatcherForTesting(nil, nil, newStringP("@test.com"), nil, nil, false), + }, + }, + { + desc: "suffix match ignore case", + sanMatchers: []matcher.StringMatcher{ + matcher.StringMatcherForTesting(nil, nil, newStringP("@test.COM"), nil, nil, true), + }, + }, + { + desc: "regex match", + sanMatchers: []matcher.StringMatcher{ + matcher.StringMatcherForTesting(nil, nil, nil, newStringP("https://github.com/grpc/grpc-java"), nil, false), + matcher.StringMatcherForTesting(nil, nil, nil, nil, regexp.MustCompile(`192\.[0-9]{1,3}\.1\.1`), false), + matcher.StringMatcherForTesting(nil, nil, nil, nil, regexp.MustCompile(`.*\.test\.com`), false), + }, + }, + { + desc: "contains match", + sanMatchers: []matcher.StringMatcher{ + matcher.StringMatcherForTesting(newStringP("https://github.com/grpc/grpc-java"), nil, nil, nil, nil, false), + matcher.StringMatcherForTesting(nil, nil, nil, newStringP("2001:68::db8"), nil, false), + matcher.StringMatcherForTesting(nil, nil, nil, newStringP("192.0.0"), nil, false), + }, + }, + { + desc: "contains match ignore case", + sanMatchers: []matcher.StringMatcher{ + matcher.StringMatcherForTesting(nil, nil, nil, newStringP("GRPC"), nil, true), + }, + }, + } + + for _, test 
:= range tests { + t.Run(test.desc, func(t *testing.T) { + hi := NewHandshakeInfo(nil, nil) + hi.SetSANMatchers(test.sanMatchers) + + if !hi.MatchingSANExists(inputCert) { + t.Fatalf("hi.MatchingSANExists(%+v) with SAN matchers +%v failed when expected to succeed", inputCert, test.sanMatchers) + } + }) + } +} + +func newStringP(s string) *string { + return &s +} diff --git a/internal/envconfig/envconfig.go b/internal/envconfig/envconfig.go index 73931a94bca..e766ac04af2 100644 --- a/internal/envconfig/envconfig.go +++ b/internal/envconfig/envconfig.go @@ -22,6 +22,8 @@ package envconfig import ( "os" "strings" + + xdsenv "google.golang.org/grpc/internal/xds/env" ) const ( @@ -31,8 +33,8 @@ const ( ) var ( - // Retry is set if retry is explicitly enabled via "GRPC_GO_RETRY=on". - Retry = strings.EqualFold(os.Getenv(retryStr), "on") + // Retry is set if retry is explicitly enabled via "GRPC_GO_RETRY=on" or if XDS retry support is enabled. + Retry = strings.EqualFold(os.Getenv(retryStr), "on") || xdsenv.RetrySupport // TXTErrIgnore is set if TXT errors should be ignored ("GRPC_GO_IGNORE_TXT_ERRORS" is not "false"). TXTErrIgnore = !strings.EqualFold(os.Getenv(txtErrIgnoreStr), "false") ) diff --git a/internal/googlecloud/googlecloud.go b/internal/googlecloud/googlecloud.go new file mode 100644 index 00000000000..d6c9e03fc4c --- /dev/null +++ b/internal/googlecloud/googlecloud.go @@ -0,0 +1,128 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// Package googlecloud contains internal helpful functions for google cloud. +package googlecloud + +import ( + "errors" + "fmt" + "io" + "io/ioutil" + "os" + "os/exec" + "regexp" + "runtime" + "strings" + "sync" + + "google.golang.org/grpc/grpclog" + internalgrpclog "google.golang.org/grpc/internal/grpclog" +) + +const ( + linuxProductNameFile = "/sys/class/dmi/id/product_name" + windowsCheckCommand = "powershell.exe" + windowsCheckCommandArgs = "Get-WmiObject -Class Win32_BIOS" + powershellOutputFilter = "Manufacturer" + windowsManufacturerRegex = ":(.*)" + + logPrefix = "[googlecloud]" +) + +var ( + // The following two variables will be reassigned in tests. + runningOS = runtime.GOOS + manufacturerReader = func() (io.Reader, error) { + switch runningOS { + case "linux": + return os.Open(linuxProductNameFile) + case "windows": + cmd := exec.Command(windowsCheckCommand, windowsCheckCommandArgs) + out, err := cmd.Output() + if err != nil { + return nil, err + } + for _, line := range strings.Split(strings.TrimSuffix(string(out), "\n"), "\n") { + if strings.HasPrefix(line, powershellOutputFilter) { + re := regexp.MustCompile(windowsManufacturerRegex) + name := re.FindString(line) + name = strings.TrimLeft(name, ":") + return strings.NewReader(name), nil + } + } + return nil, errors.New("cannot determine the machine's manufacturer") + default: + return nil, fmt.Errorf("%s is not supported", runningOS) + } + } + + vmOnGCEOnce sync.Once + vmOnGCE bool + + logger = internalgrpclog.NewPrefixLogger(grpclog.Component("googlecloud"), logPrefix) +) + +// OnGCE returns whether the client is running on GCE. +// +// It provides similar functionality as metadata.OnGCE from the cloud library +// package. We keep this to avoid depending on the cloud library module. 
+func OnGCE() bool { + vmOnGCEOnce.Do(func() { + vmOnGCE = isRunningOnGCE() + }) + return vmOnGCE +} + +// isRunningOnGCE checks whether the local system, without doing a network request is +// running on GCP. +func isRunningOnGCE() bool { + manufacturer, err := readManufacturer() + if err != nil { + logger.Infof("failed to read manufacturer %v, returning OnGCE=false", err) + return false + } + name := string(manufacturer) + switch runningOS { + case "linux": + name = strings.TrimSpace(name) + return name == "Google" || name == "Google Compute Engine" + case "windows": + name = strings.Replace(name, " ", "", -1) + name = strings.Replace(name, "\n", "", -1) + name = strings.Replace(name, "\r", "", -1) + return name == "Google" + default: + return false + } +} + +func readManufacturer() ([]byte, error) { + reader, err := manufacturerReader() + if err != nil { + return nil, err + } + if reader == nil { + return nil, errors.New("got nil reader") + } + manufacturer, err := ioutil.ReadAll(reader) + if err != nil { + return nil, fmt.Errorf("failed reading %v: %v", linuxProductNameFile, err) + } + return manufacturer, nil +} diff --git a/internal/googlecloud/googlecloud_test.go b/internal/googlecloud/googlecloud_test.go new file mode 100644 index 00000000000..bd5a42ffab9 --- /dev/null +++ b/internal/googlecloud/googlecloud_test.go @@ -0,0 +1,86 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +package googlecloud + +import ( + "io" + "os" + "strings" + "testing" +) + +func setupManufacturerReader(testOS string, reader func() (io.Reader, error)) func() { + tmpOS := runningOS + tmpReader := manufacturerReader + + // Set test OS and reader function. + runningOS = testOS + manufacturerReader = reader + return func() { + runningOS = tmpOS + manufacturerReader = tmpReader + } +} + +func setup(testOS string, testReader io.Reader) func() { + reader := func() (io.Reader, error) { + return testReader, nil + } + return setupManufacturerReader(testOS, reader) +} + +func setupError(testOS string, err error) func() { + reader := func() (io.Reader, error) { + return nil, err + } + return setupManufacturerReader(testOS, reader) +} + +func TestIsRunningOnGCE(t *testing.T) { + for _, tc := range []struct { + description string + testOS string + testReader io.Reader + out bool + }{ + // Linux tests. + {"linux: not a GCP platform", "linux", strings.NewReader("not GCP"), false}, + {"Linux: GCP platform (Google)", "linux", strings.NewReader("Google"), true}, + {"Linux: GCP platform (Google Compute Engine)", "linux", strings.NewReader("Google Compute Engine"), true}, + {"Linux: GCP platform (Google Compute Engine) with extra spaces", "linux", strings.NewReader(" Google Compute Engine "), true}, + // Windows tests. 
+ {"windows: not a GCP platform", "windows", strings.NewReader("not GCP"), false}, + {"windows: GCP platform (Google)", "windows", strings.NewReader("Google"), true}, + {"windows: GCP platform (Google) with extra spaces", "windows", strings.NewReader(" Google "), true}, + } { + reverseFunc := setup(tc.testOS, tc.testReader) + if got, want := isRunningOnGCE(), tc.out; got != want { + t.Errorf("%v: isRunningOnGCE()=%v, want %v", tc.description, got, want) + } + reverseFunc() + } +} + +func TestIsRunningOnGCENoProductNameFile(t *testing.T) { + reverseFunc := setupError("linux", os.ErrNotExist) + if isRunningOnGCE() { + t.Errorf("ErrNotExist: isRunningOnGCE()=true, want false") + } + reverseFunc() +} diff --git a/internal/grpcrand/grpcrand.go b/internal/grpcrand/grpcrand.go index 200b115ca20..740f83c2b76 100644 --- a/internal/grpcrand/grpcrand.go +++ b/internal/grpcrand/grpcrand.go @@ -31,26 +31,37 @@ var ( mu sync.Mutex ) +// Int implements rand.Int on the grpcrand global source. +func Int() int { + mu.Lock() + defer mu.Unlock() + return r.Int() +} + // Int63n implements rand.Int63n on the grpcrand global source. func Int63n(n int64) int64 { mu.Lock() - res := r.Int63n(n) - mu.Unlock() - return res + defer mu.Unlock() + return r.Int63n(n) } // Intn implements rand.Intn on the grpcrand global source. func Intn(n int) int { mu.Lock() - res := r.Intn(n) - mu.Unlock() - return res + defer mu.Unlock() + return r.Intn(n) } // Float64 implements rand.Float64 on the grpcrand global source. func Float64() float64 { mu.Lock() - res := r.Float64() - mu.Unlock() - return res + defer mu.Unlock() + return r.Float64() +} + +// Uint64 implements rand.Uint64 on the grpcrand global source. 
+func Uint64() uint64 { + mu.Lock() + defer mu.Unlock() + return r.Uint64() } diff --git a/internal/grpctest/tlogger.go b/internal/grpctest/tlogger.go index 95c3598d1d5..bbb2a2ff4fb 100644 --- a/internal/grpctest/tlogger.go +++ b/internal/grpctest/tlogger.go @@ -41,19 +41,34 @@ const callingFrame = 4 type logType int +func (l logType) String() string { + switch l { + case infoLog: + return "INFO" + case warningLog: + return "WARNING" + case errorLog: + return "ERROR" + case fatalLog: + return "FATAL" + } + return "UNKNOWN" +} + const ( - logLog logType = iota + infoLog logType = iota + warningLog errorLog fatalLog ) type tLogger struct { v int - t *testing.T - start time.Time initialized bool - m sync.Mutex // protects errors + mu sync.Mutex // guards t, start, and errors + t *testing.T + start time.Time errors map[*regexp.Regexp]int } @@ -76,12 +91,14 @@ func getCallingPrefix(depth int) (string, error) { // log logs the message with the specified parameters to the tLogger. func (g *tLogger) log(ltype logType, depth int, format string, args ...interface{}) { + g.mu.Lock() + defer g.mu.Unlock() prefix, err := getCallingPrefix(callingFrame + depth) if err != nil { g.t.Error(err) return } - args = append([]interface{}{prefix}, args...) + args = append([]interface{}{ltype.String() + " " + prefix}, args...) args = append(args, fmt.Sprintf(" (t=+%s)", time.Since(g.start))) if format == "" { @@ -119,14 +136,14 @@ func (g *tLogger) log(ltype logType, depth int, format string, args ...interface // Update updates the testing.T that the testing logger logs to. Should be done // before every test. It also initializes the tLogger if it has not already. 
func (g *tLogger) Update(t *testing.T) { + g.mu.Lock() + defer g.mu.Unlock() if !g.initialized { grpclog.SetLoggerV2(TLogger) g.initialized = true } g.t = t g.start = time.Now() - g.m.Lock() - defer g.m.Unlock() g.errors = map[*regexp.Regexp]int{} } @@ -141,20 +158,20 @@ func (g *tLogger) ExpectError(expr string) { // ExpectErrorN declares an error to be expected n times. func (g *tLogger) ExpectErrorN(expr string, n int) { + g.mu.Lock() + defer g.mu.Unlock() re, err := regexp.Compile(expr) if err != nil { g.t.Error(err) return } - g.m.Lock() - defer g.m.Unlock() g.errors[re] += n } // EndTest checks if expected errors were not encountered. func (g *tLogger) EndTest(t *testing.T) { - g.m.Lock() - defer g.m.Unlock() + g.mu.Lock() + defer g.mu.Unlock() for re, count := range g.errors { if count > 0 { t.Errorf("Expected error '%v' not encountered", re.String()) @@ -165,8 +182,6 @@ func (g *tLogger) EndTest(t *testing.T) { // expected determines if the error string is protected or not. func (g *tLogger) expected(s string) bool { - g.m.Lock() - defer g.m.Unlock() for re, count := range g.errors { if re.FindStringIndex(s) != nil { g.errors[re]-- @@ -180,35 +195,35 @@ func (g *tLogger) expected(s string) bool { } func (g *tLogger) Info(args ...interface{}) { - g.log(logLog, 0, "", args...) + g.log(infoLog, 0, "", args...) } func (g *tLogger) Infoln(args ...interface{}) { - g.log(logLog, 0, "", args...) + g.log(infoLog, 0, "", args...) } func (g *tLogger) Infof(format string, args ...interface{}) { - g.log(logLog, 0, format, args...) + g.log(infoLog, 0, format, args...) } func (g *tLogger) InfoDepth(depth int, args ...interface{}) { - g.log(logLog, depth, "", args...) + g.log(infoLog, depth, "", args...) } func (g *tLogger) Warning(args ...interface{}) { - g.log(logLog, 0, "", args...) + g.log(warningLog, 0, "", args...) } func (g *tLogger) Warningln(args ...interface{}) { - g.log(logLog, 0, "", args...) + g.log(warningLog, 0, "", args...) 
} func (g *tLogger) Warningf(format string, args ...interface{}) { - g.log(logLog, 0, format, args...) + g.log(warningLog, 0, format, args...) } func (g *tLogger) WarningDepth(depth int, args ...interface{}) { - g.log(logLog, depth, "", args...) + g.log(warningLog, depth, "", args...) } func (g *tLogger) Error(args ...interface{}) { diff --git a/internal/grpcutil/target_test.go b/internal/grpcutil/target_test.go deleted file mode 100644 index f6c586dd080..00000000000 --- a/internal/grpcutil/target_test.go +++ /dev/null @@ -1,114 +0,0 @@ -/* - * - * Copyright 2020 gRPC authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- * - */ - -package grpcutil - -import ( - "testing" - - "google.golang.org/grpc/resolver" -) - -func TestParseTarget(t *testing.T) { - for _, test := range []resolver.Target{ - {Scheme: "dns", Authority: "", Endpoint: "google.com"}, - {Scheme: "dns", Authority: "a.server.com", Endpoint: "google.com"}, - {Scheme: "dns", Authority: "a.server.com", Endpoint: "google.com/?a=b"}, - {Scheme: "passthrough", Authority: "", Endpoint: "/unix/socket/address"}, - } { - str := test.Scheme + "://" + test.Authority + "/" + test.Endpoint - got := ParseTarget(str, false) - if got != test { - t.Errorf("ParseTarget(%q, false) = %+v, want %+v", str, got, test) - } - got = ParseTarget(str, true) - if got != test { - t.Errorf("ParseTarget(%q, true) = %+v, want %+v", str, got, test) - } - } -} - -func TestParseTargetString(t *testing.T) { - for _, test := range []struct { - targetStr string - want resolver.Target - wantWithDialer resolver.Target - }{ - {targetStr: "", want: resolver.Target{Scheme: "", Authority: "", Endpoint: ""}}, - {targetStr: ":///", want: resolver.Target{Scheme: "", Authority: "", Endpoint: ""}}, - {targetStr: "a:///", want: resolver.Target{Scheme: "a", Authority: "", Endpoint: ""}}, - {targetStr: "://a/", want: resolver.Target{Scheme: "", Authority: "a", Endpoint: ""}}, - {targetStr: ":///a", want: resolver.Target{Scheme: "", Authority: "", Endpoint: "a"}}, - {targetStr: "a://b/", want: resolver.Target{Scheme: "a", Authority: "b", Endpoint: ""}}, - {targetStr: "a:///b", want: resolver.Target{Scheme: "a", Authority: "", Endpoint: "b"}}, - {targetStr: "://a/b", want: resolver.Target{Scheme: "", Authority: "a", Endpoint: "b"}}, - {targetStr: "a://b/c", want: resolver.Target{Scheme: "a", Authority: "b", Endpoint: "c"}}, - {targetStr: "dns:///google.com", want: resolver.Target{Scheme: "dns", Authority: "", Endpoint: "google.com"}}, - {targetStr: "dns://a.server.com/google.com", want: resolver.Target{Scheme: "dns", Authority: "a.server.com", Endpoint: "google.com"}}, - 
{targetStr: "dns://a.server.com/google.com/?a=b", want: resolver.Target{Scheme: "dns", Authority: "a.server.com", Endpoint: "google.com/?a=b"}}, - - {targetStr: "/", want: resolver.Target{Scheme: "", Authority: "", Endpoint: "/"}}, - {targetStr: "google.com", want: resolver.Target{Scheme: "", Authority: "", Endpoint: "google.com"}}, - {targetStr: "google.com/?a=b", want: resolver.Target{Scheme: "", Authority: "", Endpoint: "google.com/?a=b"}}, - {targetStr: "/unix/socket/address", want: resolver.Target{Scheme: "", Authority: "", Endpoint: "/unix/socket/address"}}, - - // If we can only parse part of the target. - {targetStr: "://", want: resolver.Target{Scheme: "", Authority: "", Endpoint: "://"}}, - {targetStr: "unix://domain", want: resolver.Target{Scheme: "", Authority: "", Endpoint: "unix://domain"}}, - {targetStr: "unix://a/b/c", want: resolver.Target{Scheme: "unix", Authority: "a", Endpoint: "/b/c"}}, - {targetStr: "a:b", want: resolver.Target{Scheme: "", Authority: "", Endpoint: "a:b"}}, - {targetStr: "a/b", want: resolver.Target{Scheme: "", Authority: "", Endpoint: "a/b"}}, - {targetStr: "a:/b", want: resolver.Target{Scheme: "", Authority: "", Endpoint: "a:/b"}}, - {targetStr: "a//b", want: resolver.Target{Scheme: "", Authority: "", Endpoint: "a//b"}}, - {targetStr: "a://b", want: resolver.Target{Scheme: "", Authority: "", Endpoint: "a://b"}}, - - // Unix cases without custom dialer. - // unix:[local_path], unix:[/absolute], and unix://[/absolute] have different - // behaviors with a custom dialer, to prevent behavior changes with custom dialers. 
- {targetStr: "unix:a/b/c", want: resolver.Target{Scheme: "unix", Authority: "", Endpoint: "a/b/c"}, wantWithDialer: resolver.Target{Scheme: "", Authority: "", Endpoint: "unix:a/b/c"}}, - {targetStr: "unix:/a/b/c", want: resolver.Target{Scheme: "unix", Authority: "", Endpoint: "/a/b/c"}, wantWithDialer: resolver.Target{Scheme: "", Authority: "", Endpoint: "unix:/a/b/c"}}, - {targetStr: "unix:///a/b/c", want: resolver.Target{Scheme: "unix", Authority: "", Endpoint: "/a/b/c"}}, - - {targetStr: "unix-abstract:a/b/c", want: resolver.Target{Scheme: "unix-abstract", Authority: "", Endpoint: "a/b/c"}}, - {targetStr: "unix-abstract:a b", want: resolver.Target{Scheme: "unix-abstract", Authority: "", Endpoint: "a b"}}, - {targetStr: "unix-abstract:a:b", want: resolver.Target{Scheme: "unix-abstract", Authority: "", Endpoint: "a:b"}}, - {targetStr: "unix-abstract:a-b", want: resolver.Target{Scheme: "unix-abstract", Authority: "", Endpoint: "a-b"}}, - {targetStr: "unix-abstract:/ a///://::!@#$%^&*()b", want: resolver.Target{Scheme: "unix-abstract", Authority: "", Endpoint: "/ a///://::!@#$%^&*()b"}}, - {targetStr: "unix-abstract:passthrough:abc", want: resolver.Target{Scheme: "unix-abstract", Authority: "", Endpoint: "passthrough:abc"}}, - {targetStr: "unix-abstract:unix:///abc", want: resolver.Target{Scheme: "unix-abstract", Authority: "", Endpoint: "unix:///abc"}}, - {targetStr: "unix-abstract:///a/b/c", want: resolver.Target{Scheme: "unix-abstract", Authority: "", Endpoint: "/a/b/c"}}, - {targetStr: "unix-abstract://authority/a/b/c", want: resolver.Target{Scheme: "unix-abstract", Authority: "authority", Endpoint: "/a/b/c"}}, - {targetStr: "unix-abstract:///", want: resolver.Target{Scheme: "unix-abstract", Authority: "", Endpoint: "/"}}, - {targetStr: "unix-abstract://authority", want: resolver.Target{Scheme: "unix-abstract", Authority: "", Endpoint: "//authority"}}, - - {targetStr: "passthrough:///unix:///a/b/c", want: resolver.Target{Scheme: "passthrough", Authority: "", 
Endpoint: "unix:///a/b/c"}}, - } { - got := ParseTarget(test.targetStr, false) - if got != test.want { - t.Errorf("ParseTarget(%q, false) = %+v, want %+v", test.targetStr, got, test.want) - } - wantWithDialer := test.wantWithDialer - if wantWithDialer == (resolver.Target{}) { - wantWithDialer = test.want - } - got = ParseTarget(test.targetStr, true) - if got != wantWithDialer { - t.Errorf("ParseTarget(%q, true) = %+v, want %+v", test.targetStr, got, wantWithDialer) - } - } -} diff --git a/internal/internal.go b/internal/internal.go index 1e2834c70f6..1b596bf3579 100644 --- a/internal/internal.go +++ b/internal/internal.go @@ -38,12 +38,6 @@ var ( // KeepaliveMinPingTime is the minimum ping interval. This must be 10s by // default, but tests may wish to set it lower for convenience. KeepaliveMinPingTime = 10 * time.Second - // NewRequestInfoContext creates a new context based on the argument context attaching - // the passed in RequestInfo to the new context. - NewRequestInfoContext interface{} // func(context.Context, credentials.RequestInfo) context.Context - // NewClientHandshakeInfoContext returns a copy of the input context with - // the passed in ClientHandshakeInfo struct added to it. - NewClientHandshakeInfoContext interface{} // func(context.Context, credentials.ClientHandshakeInfo) context.Context // ParseServiceConfigForTesting is for creating a fake // ClientConn for resolver testing only ParseServiceConfigForTesting interface{} // func(string) *serviceconfig.ParseResult @@ -65,6 +59,11 @@ var ( // gRPC server. An xDS-enabled server needs to know what type of credentials // is configured on the underlying gRPC server. This is set by server.go. GetServerCredentials interface{} // func (*grpc.Server) credentials.TransportCredentials + // DrainServerTransports initiates a graceful close of existing connections + // on a gRPC server accepted on the provided listener address. 
An + // xDS-enabled server invokes this method on a grpc.Server when a particular + // listener moves to "not-serving" mode. + DrainServerTransports interface{} // func(*grpc.Server, string) ) // HealthChecker defines the signature of the client-side LB channel health checking function. diff --git a/internal/leakcheck/leakcheck.go b/internal/leakcheck/leakcheck.go index 1d4fcef994b..946c575f140 100644 --- a/internal/leakcheck/leakcheck.go +++ b/internal/leakcheck/leakcheck.go @@ -42,7 +42,6 @@ var goroutinesToIgnore = []string{ "runtime_mcall", "(*loggingT).flushDaemon", "goroutine in C code", - "httputil.DumpRequestOut", // TODO: Remove this once Go1.13 support is removed. https://github.com/golang/go/issues/37669. } // RegisterIgnoreGoroutine appends s into the ignore goroutine list. The diff --git a/internal/pretty/pretty.go b/internal/pretty/pretty.go new file mode 100644 index 00000000000..0177af4b511 --- /dev/null +++ b/internal/pretty/pretty.go @@ -0,0 +1,82 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// Package pretty defines helper functions to pretty-print structs for logging. 
+package pretty + +import ( + "bytes" + "encoding/json" + "fmt" + + "github.com/golang/protobuf/jsonpb" + protov1 "github.com/golang/protobuf/proto" + "google.golang.org/protobuf/encoding/protojson" + protov2 "google.golang.org/protobuf/proto" +) + +const jsonIndent = " " + +// ToJSON marshals the input into a json string. +// +// If marshal fails, it falls back to fmt.Sprintf("%+v"). +func ToJSON(e interface{}) string { + switch ee := e.(type) { + case protov1.Message: + mm := jsonpb.Marshaler{Indent: jsonIndent} + ret, err := mm.MarshalToString(ee) + if err != nil { + // This may fail for proto.Anys, e.g. for xDS v2, LDS, the v2 + // messages are not imported, and this will fail because the message + // is not found. + return fmt.Sprintf("%+v", ee) + } + return ret + case protov2.Message: + mm := protojson.MarshalOptions{ + Multiline: true, + Indent: jsonIndent, + } + ret, err := mm.Marshal(ee) + if err != nil { + // This may fail for proto.Anys, e.g. for xDS v2, LDS, the v2 + // messages are not imported, and this will fail because the message + // is not found. + return fmt.Sprintf("%+v", ee) + } + return string(ret) + default: + ret, err := json.MarshalIndent(ee, "", jsonIndent) + if err != nil { + return fmt.Sprintf("%+v", ee) + } + return string(ret) + } +} + +// FormatJSON formats the input json bytes with indentation. +// +// If Indent fails, it returns the unchanged input as string. +func FormatJSON(b []byte) string { + var out bytes.Buffer + err := json.Indent(&out, b, "", jsonIndent) + if err != nil { + return string(b) + } + return out.String() +} diff --git a/internal/profiling/buffer/buffer.go b/internal/profiling/buffer/buffer.go index 45745cd0919..f4cd4201de1 100644 --- a/internal/profiling/buffer/buffer.go +++ b/internal/profiling/buffer/buffer.go @@ -1,5 +1,3 @@ -// +build !appengine - /* * * Copyright 2019 gRPC authors. 
diff --git a/internal/profiling/buffer/buffer_appengine.go b/internal/profiling/buffer/buffer_appengine.go deleted file mode 100644 index c92599e5b9c..00000000000 --- a/internal/profiling/buffer/buffer_appengine.go +++ /dev/null @@ -1,43 +0,0 @@ -// +build appengine - -/* - * - * Copyright 2019 gRPC authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - */ - -package buffer - -// CircularBuffer is a no-op implementation for appengine builds. -// -// Appengine does not support stats because of lack of the support for unsafe -// pointers, which are necessary to efficiently store and retrieve things into -// and from a circular buffer. As a result, Push does not do anything and Drain -// returns an empty slice. -type CircularBuffer struct{} - -// NewCircularBuffer returns a no-op for appengine builds. -func NewCircularBuffer(size uint32) (*CircularBuffer, error) { - return nil, nil -} - -// Push returns a no-op for appengine builds. -func (cb *CircularBuffer) Push(x interface{}) { -} - -// Drain returns a no-op for appengine builds. -func (cb *CircularBuffer) Drain() []interface{} { - return nil -} diff --git a/internal/profiling/buffer/buffer_test.go b/internal/profiling/buffer/buffer_test.go index 86bd77d4a2e..a7f3b61e4af 100644 --- a/internal/profiling/buffer/buffer_test.go +++ b/internal/profiling/buffer/buffer_test.go @@ -1,5 +1,3 @@ -// +build !appengine - /* * * Copyright 2019 gRPC authors. 
diff --git a/internal/profiling/goid_modified.go b/internal/profiling/goid_modified.go index b186499cd0d..ff1a5f5933b 100644 --- a/internal/profiling/goid_modified.go +++ b/internal/profiling/goid_modified.go @@ -1,3 +1,4 @@ +//go:build grpcgoid // +build grpcgoid /* diff --git a/internal/profiling/goid_regular.go b/internal/profiling/goid_regular.go index 891c2e98f9d..042933227d8 100644 --- a/internal/profiling/goid_regular.go +++ b/internal/profiling/goid_regular.go @@ -1,3 +1,4 @@ +//go:build !grpcgoid // +build !grpcgoid /* diff --git a/internal/resolver/config_selector.go b/internal/resolver/config_selector.go index 5e7f36703d4..be7e13d5859 100644 --- a/internal/resolver/config_selector.go +++ b/internal/resolver/config_selector.go @@ -117,9 +117,12 @@ type ClientInterceptor interface { NewStream(ctx context.Context, ri RPCInfo, done func(), newStream func(ctx context.Context, done func()) (ClientStream, error)) (ClientStream, error) } -// ServerInterceptor is unimplementable; do not use. +// ServerInterceptor is an interceptor for incoming RPC's on gRPC server side. type ServerInterceptor interface { - notDefined() + // AllowRPC checks if an incoming RPC is allowed to proceed based on + // information about connection RPC was received on, and HTTP Headers. This + // information will be piped into context. + AllowRPC(ctx context.Context) error // TODO: Make this a real interceptor for filters such as rate limiting. 
} type csKeyType string diff --git a/internal/resolver/config_selector_test.go b/internal/resolver/config_selector_test.go index e5a50995df1..e1dae8bde27 100644 --- a/internal/resolver/config_selector_test.go +++ b/internal/resolver/config_selector_test.go @@ -48,6 +48,8 @@ func (s) TestSafeConfigSelector(t *testing.T) { retChan1 := make(chan *RPCConfig) retChan2 := make(chan *RPCConfig) + defer close(retChan1) + defer close(retChan2) one := 1 two := 2 @@ -55,8 +57,8 @@ func (s) TestSafeConfigSelector(t *testing.T) { resp1 := &RPCConfig{MethodConfig: serviceconfig.MethodConfig{MaxReqSize: &one}} resp2 := &RPCConfig{MethodConfig: serviceconfig.MethodConfig{MaxReqSize: &two}} - cs1Called := make(chan struct{}) - cs2Called := make(chan struct{}) + cs1Called := make(chan struct{}, 1) + cs2Called := make(chan struct{}, 1) cs1 := &fakeConfigSelector{ selectConfig: func(r RPCInfo) (*RPCConfig, error) { diff --git a/internal/resolver/dns/dns_resolver.go b/internal/resolver/dns/dns_resolver.go index 30423556658..75301c51491 100644 --- a/internal/resolver/dns/dns_resolver.go +++ b/internal/resolver/dns/dns_resolver.go @@ -34,6 +34,7 @@ import ( grpclbstate "google.golang.org/grpc/balancer/grpclb/state" "google.golang.org/grpc/grpclog" + "google.golang.org/grpc/internal/backoff" "google.golang.org/grpc/internal/envconfig" "google.golang.org/grpc/internal/grpcrand" "google.golang.org/grpc/resolver" @@ -46,6 +47,13 @@ var EnableSRVLookups = false var logger = grpclog.Component("dns") +// Globals to stub out in tests. TODO: Perhaps these two can be combined into a +// single variable for testing the resolver? 
+var ( + newTimer = time.NewTimer + newTimerDNSResRate = time.NewTimer +) + func init() { resolver.Register(NewBuilder()) } @@ -143,7 +151,6 @@ func (b *dnsBuilder) Build(target resolver.Target, cc resolver.ClientConn, opts d.wg.Add(1) go d.watcher() - d.ResolveNow(resolver.ResolveNowOptions{}) return d, nil } @@ -201,28 +208,38 @@ func (d *dnsResolver) Close() { func (d *dnsResolver) watcher() { defer d.wg.Done() + backoffIndex := 1 for { - select { - case <-d.ctx.Done(): - return - case <-d.rn: - } - state, err := d.lookup() if err != nil { + // Report error to the underlying grpc.ClientConn. d.cc.ReportError(err) } else { - d.cc.UpdateState(*state) + err = d.cc.UpdateState(*state) } - // Sleep to prevent excessive re-resolutions. Incoming resolution requests - // will be queued in d.rn. - t := time.NewTimer(minDNSResRate) + var timer *time.Timer + if err == nil { + // Success resolving, wait for the next ResolveNow. However, also wait 30 seconds at the very least + // to prevent constantly re-resolving. + backoffIndex = 1 + timer = newTimerDNSResRate(minDNSResRate) + select { + case <-d.ctx.Done(): + timer.Stop() + return + case <-d.rn: + } + } else { + // Poll on an error found in DNS Resolver or an error received from ClientConn. + timer = newTimer(backoff.DefaultExponential.Backoff(backoffIndex)) + backoffIndex++ + } select { - case <-t.C: case <-d.ctx.Done(): - t.Stop() + timer.Stop() return + case <-timer.C: } } } @@ -260,18 +277,13 @@ func (d *dnsResolver) lookupSRV() ([]resolver.Address, error) { return newAddrs, nil } -var filterError = func(err error) error { +func handleDNSError(err error, lookupType string) error { if dnsErr, ok := err.(*net.DNSError); ok && !dnsErr.IsTimeout && !dnsErr.IsTemporary { // Timeouts and temporary errors should be communicated to gRPC to // attempt another DNS query (with backoff). Other errors should be // suppressed (they may represent the absence of a TXT record). 
return nil } - return err -} - -func handleDNSError(err error, lookupType string) error { - err = filterError(err) if err != nil { err = fmt.Errorf("dns: %v record lookup error: %v", lookupType, err) logger.Info(err) @@ -306,12 +318,12 @@ func (d *dnsResolver) lookupTXT() *serviceconfig.ParseResult { } func (d *dnsResolver) lookupHost() ([]resolver.Address, error) { - var newAddrs []resolver.Address addrs, err := d.resolver.LookupHost(d.ctx, d.host) if err != nil { err = handleDNSError(err, "A") return nil, err } + newAddrs := make([]resolver.Address, 0, len(addrs)) for _, a := range addrs { ip, ok := formatIP(a) if !ok { diff --git a/internal/resolver/dns/dns_resolver_test.go b/internal/resolver/dns/dns_resolver_test.go index 1c8469a275a..69307a981cf 100644 --- a/internal/resolver/dns/dns_resolver_test.go +++ b/internal/resolver/dns/dns_resolver_test.go @@ -30,9 +30,13 @@ import ( "testing" "time" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "google.golang.org/grpc/balancer" grpclbstate "google.golang.org/grpc/balancer/grpclb/state" "google.golang.org/grpc/internal/envconfig" "google.golang.org/grpc/internal/leakcheck" + "google.golang.org/grpc/internal/testutils" "google.golang.org/grpc/resolver" "google.golang.org/grpc/serviceconfig" ) @@ -41,13 +45,15 @@ func TestMain(m *testing.M) { // Set a non-zero duration only for tests which are actually testing that // feature. 
replaceDNSResRate(time.Duration(0)) // No nead to clean up since we os.Exit - replaceNetFunc(nil) // No nead to clean up since we os.Exit + overrideDefaultResolver(false) // No nead to clean up since we os.Exit code := m.Run() os.Exit(code) } const ( - txtBytesLimit = 255 + txtBytesLimit = 255 + defaultTestTimeout = 10 * time.Second + defaultTestShortTimeout = 10 * time.Millisecond ) type testClientConn struct { @@ -57,13 +63,17 @@ type testClientConn struct { state resolver.State updateStateCalls int errChan chan error + updateStateErr error } -func (t *testClientConn) UpdateState(s resolver.State) { +func (t *testClientConn) UpdateState(s resolver.State) error { t.m1.Lock() defer t.m1.Unlock() t.state = s t.updateStateCalls++ + // This error determines whether DNS Resolver actually decides to exponentially backoff or not. + // This can be any error. + return t.updateStateErr } func (t *testClientConn) getState() (resolver.State, int) { @@ -99,12 +109,12 @@ type testResolver struct { // A write to this channel is made when this resolver receives a resolution // request. Tests can rely on reading from this channel to be notified about // resolution requests instead of sleeping for a predefined period of time. - ch chan struct{} + lookupHostCh *testutils.Channel } func (tr *testResolver) LookupHost(ctx context.Context, host string) ([]string, error) { - if tr.ch != nil { - tr.ch <- struct{}{} + if tr.lookupHostCh != nil { + tr.lookupHostCh.Send(nil) } return hostLookup(host) } @@ -117,9 +127,17 @@ func (*testResolver) LookupTXT(ctx context.Context, host string) ([]string, erro return txtLookup(host) } -func replaceNetFunc(ch chan struct{}) func() { +// overrideDefaultResolver overrides the defaultResolver used by the code with +// an instance of the testResolver. pushOnLookup controls whether the +// testResolver created here pushes lookupHost events on its channel. 
+func overrideDefaultResolver(pushOnLookup bool) func() { oldResolver := defaultResolver - defaultResolver = &testResolver{ch: ch} + + var lookupHostCh *testutils.Channel + if pushOnLookup { + lookupHostCh = testutils.NewChannel() + } + defaultResolver = &testResolver{lookupHostCh: lookupHostCh} return func() { defaultResolver = oldResolver @@ -669,6 +687,13 @@ func TestResolve(t *testing.T) { func testDNSResolver(t *testing.T) { defer leakcheck.Check(t) + defer func(nt func(d time.Duration) *time.Timer) { + newTimer = nt + }(newTimer) + newTimer = func(_ time.Duration) *time.Timer { + // Will never fire on its own, will protect from triggering exponential backoff. + return time.NewTimer(time.Hour) + } tests := []struct { target string addrWant []resolver.Address @@ -725,7 +750,7 @@ func testDNSResolver(t *testing.T) { if cnt == 0 { t.Fatalf("UpdateState not called after 2s; aborting") } - if !reflect.DeepEqual(a.addrWant, state.Addresses) { + if !cmp.Equal(a.addrWant, state.Addresses, cmpopts.EquateEmpty()) { t.Errorf("Resolved addresses of target: %q = %+v, want %+v", a.target, state.Addresses, a.addrWant) } sc := scFromState(state) @@ -736,12 +761,151 @@ func testDNSResolver(t *testing.T) { } } +// DNS Resolver immediately starts polling on an error from grpc. This should continue until the ClientConn doesn't +// send back an error from updating the DNS Resolver's state. +func TestDNSResolverExponentialBackoff(t *testing.T) { + defer leakcheck.Check(t) + defer func(nt func(d time.Duration) *time.Timer) { + newTimer = nt + }(newTimer) + timerChan := testutils.NewChannel() + newTimer = func(d time.Duration) *time.Timer { + // Will never fire on its own, allows this test to call timer immediately. 
+ t := time.NewTimer(time.Hour) + timerChan.Send(t) + return t + } + tests := []struct { + name string + target string + addrWant []resolver.Address + scWant string + }{ + { + "happy case default port", + "foo.bar.com", + []resolver.Address{{Addr: "1.2.3.4" + colonDefaultPort}, {Addr: "5.6.7.8" + colonDefaultPort}}, + generateSC("foo.bar.com"), + }, + { + "happy case specified port", + "foo.bar.com:1234", + []resolver.Address{{Addr: "1.2.3.4:1234"}, {Addr: "5.6.7.8:1234"}}, + generateSC("foo.bar.com"), + }, + { + "happy case another default port", + "srv.ipv4.single.fake", + []resolver.Address{{Addr: "2.4.6.8" + colonDefaultPort}}, + generateSC("srv.ipv4.single.fake"), + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + b := NewBuilder() + cc := &testClientConn{target: test.target} + // Cause ClientConn to return an error. + cc.updateStateErr = balancer.ErrBadResolverState + r, err := b.Build(resolver.Target{Endpoint: test.target}, cc, resolver.BuildOptions{}) + if err != nil { + t.Fatalf("Error building resolver for target %v: %v", test.target, err) + } + var state resolver.State + var cnt int + for i := 0; i < 2000; i++ { + state, cnt = cc.getState() + if cnt > 0 { + break + } + time.Sleep(time.Millisecond) + } + if cnt == 0 { + t.Fatalf("UpdateState not called after 2s; aborting") + } + if !reflect.DeepEqual(test.addrWant, state.Addresses) { + t.Errorf("Resolved addresses of target: %q = %+v, want %+v", test.target, state.Addresses, test.addrWant) + } + sc := scFromState(state) + if test.scWant != sc { + t.Errorf("Resolved service config of target: %q = %+v, want %+v", test.target, sc, test.scWant) + } + ctx, ctxCancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer ctxCancel() + // Cause timer to go off 10 times, and see if it calls updateState() correctly. 
+ for i := 0; i < 10; i++ { + timer, err := timerChan.Receive(ctx) + if err != nil { + t.Fatalf("Error receiving timer from mock NewTimer call: %v", err) + } + timerPointer := timer.(*time.Timer) + timerPointer.Reset(0) + } + // Poll to see if DNS Resolver updated state the correct number of times, which allows time for the DNS Resolver to call + // ClientConn update state. + deadline := time.Now().Add(defaultTestTimeout) + for { + cc.m1.Lock() + got := cc.updateStateCalls + cc.m1.Unlock() + if got == 11 { + break + } + + if time.Now().After(deadline) { + t.Fatalf("Exponential backoff is not working as expected - should update state 11 times instead of %d", got) + } + + time.Sleep(time.Millisecond) + } + + // Update resolver.ClientConn to not return an error anymore - this should stop it from backing off. + cc.updateStateErr = nil + timer, err := timerChan.Receive(ctx) + if err != nil { + t.Fatalf("Error receiving timer from mock NewTimer call: %v", err) + } + timerPointer := timer.(*time.Timer) + timerPointer.Reset(0) + // Poll to see if DNS Resolver updated state the correct number of times, which allows time for the DNS Resolver to call + // ClientConn update state the final time. The DNS Resolver should then stop polling. 
+ deadline = time.Now().Add(defaultTestTimeout) + for { + cc.m1.Lock() + got := cc.updateStateCalls + cc.m1.Unlock() + if got == 12 { + break + } + + if time.Now().After(deadline) { + t.Fatalf("Exponential backoff is not working as expected - should stop backing off at 12 total UpdateState calls instead of %d", got) + } + + _, err := timerChan.ReceiveOrFail() + if err { + t.Fatalf("Should not poll again after Client Conn stops returning error.") + } + + time.Sleep(time.Millisecond) + } + r.Close() + }) + } +} + func testDNSResolverWithSRV(t *testing.T) { EnableSRVLookups = true defer func() { EnableSRVLookups = false }() defer leakcheck.Check(t) + defer func(nt func(d time.Duration) *time.Timer) { + newTimer = nt + }(newTimer) + newTimer = func(_ time.Duration) *time.Timer { + // Will never fire on its own, will protect from triggering exponential backoff. + return time.NewTimer(time.Hour) + } tests := []struct { target string addrWant []resolver.Address @@ -814,7 +978,7 @@ func testDNSResolverWithSRV(t *testing.T) { if cnt == 0 { t.Fatalf("UpdateState not called after 2s; aborting") } - if !reflect.DeepEqual(a.addrWant, state.Addresses) { + if !cmp.Equal(a.addrWant, state.Addresses, cmpopts.EquateEmpty()) { t.Errorf("Resolved addresses of target: %q = %+v, want %+v", a.target, state.Addresses, a.addrWant) } gs := grpclbstate.Get(state) @@ -855,6 +1019,13 @@ func mutateTbl(target string) func() { func testDNSResolveNow(t *testing.T) { defer leakcheck.Check(t) + defer func(nt func(d time.Duration) *time.Timer) { + newTimer = nt + }(newTimer) + newTimer = func(_ time.Duration) *time.Timer { + // Will never fire on its own, will protect from triggering exponential backoff. 
+ return time.NewTimer(time.Hour) + } tests := []struct { target string addrWant []resolver.Address @@ -926,6 +1097,13 @@ const colonDefaultPort = ":" + defaultPort func testIPResolver(t *testing.T) { defer leakcheck.Check(t) + defer func(nt func(d time.Duration) *time.Timer) { + newTimer = nt + }(newTimer) + newTimer = func(_ time.Duration) *time.Timer { + // Will never fire on its own, will protect from triggering exponential backoff. + return time.NewTimer(time.Hour) + } tests := []struct { target string want []resolver.Address @@ -975,6 +1153,13 @@ func testIPResolver(t *testing.T) { func TestResolveFunc(t *testing.T) { defer leakcheck.Check(t) + defer func(nt func(d time.Duration) *time.Timer) { + newTimer = nt + }(newTimer) + newTimer = func(d time.Duration) *time.Timer { + // Will never fire on its own, will protect from triggering exponential backoff. + return time.NewTimer(time.Hour) + } tests := []struct { addr string want error @@ -1013,6 +1198,13 @@ func TestResolveFunc(t *testing.T) { func TestDisableServiceConfig(t *testing.T) { defer leakcheck.Check(t) + defer func(nt func(d time.Duration) *time.Timer) { + newTimer = nt + }(newTimer) + newTimer = func(d time.Duration) *time.Timer { + // Will never fire on its own, will protect from triggering exponential backoff. + return time.NewTimer(time.Hour) + } tests := []struct { target string scWant string @@ -1059,6 +1251,13 @@ func TestDisableServiceConfig(t *testing.T) { func TestTXTError(t *testing.T) { defer leakcheck.Check(t) + defer func(nt func(d time.Duration) *time.Timer) { + newTimer = nt + }(newTimer) + newTimer = func(d time.Duration) *time.Timer { + // Will never fire on its own, will protect from triggering exponential backoff. 
+ return time.NewTimer(time.Hour) + } defer func(v bool) { envconfig.TXTErrIgnore = v }(envconfig.TXTErrIgnore) for _, ignore := range []bool{false, true} { envconfig.TXTErrIgnore = ignore @@ -1090,6 +1289,13 @@ func TestTXTError(t *testing.T) { } func TestDNSResolverRetry(t *testing.T) { + defer func(nt func(d time.Duration) *time.Timer) { + newTimer = nt + }(newTimer) + newTimer = func(d time.Duration) *time.Timer { + // Will never fire on its own, will protect from triggering exponential backoff. + return time.NewTimer(time.Hour) + } b := NewBuilder() target := "ipv4.single.fake" cc := &testClientConn{target: target} @@ -1144,6 +1350,13 @@ func TestDNSResolverRetry(t *testing.T) { func TestCustomAuthority(t *testing.T) { defer leakcheck.Check(t) + defer func(nt func(d time.Duration) *time.Timer) { + newTimer = nt + }(newTimer) + newTimer = func(d time.Duration) *time.Timer { + // Will never fire on its own, will protect from triggering exponential backoff. + return time.NewTimer(time.Hour) + } tests := []struct { authority string @@ -1249,16 +1462,33 @@ func TestCustomAuthority(t *testing.T) { // requests. It sets the re-resolution rate to a small value and repeatedly // calls ResolveNow() and ensures only the expected number of resolution // requests are made. + func TestRateLimitedResolve(t *testing.T) { defer leakcheck.Check(t) - - const dnsResRate = 10 * time.Millisecond - dc := replaceDNSResRate(dnsResRate) - defer dc() + defer func(nt func(d time.Duration) *time.Timer) { + newTimer = nt + }(newTimer) + newTimer = func(d time.Duration) *time.Timer { + // Will never fire on its own, will protect from triggering exponential + // backoff. + return time.NewTimer(time.Hour) + } + defer func(nt func(d time.Duration) *time.Timer) { + newTimerDNSResRate = nt + }(newTimerDNSResRate) + + timerChan := testutils.NewChannel() + newTimerDNSResRate = func(d time.Duration) *time.Timer { + // Will never fire on its own, allows this test to call timer + // immediately. 
+ t := time.NewTimer(time.Hour) + timerChan.Send(t) + return t + } // Create a new testResolver{} for this test because we want the exact count // of the number of times the resolver was invoked. - nc := replaceNetFunc(make(chan struct{})) + nc := overrideDefaultResolver(true) defer nc() target := "foo.bar.com" @@ -1281,55 +1511,65 @@ func TestRateLimitedResolve(t *testing.T) { t.Fatalf("delegate resolver returned unexpected type: %T\n", tr) } - // Observe the time before unblocking the lookupHost call. The 100ms rate - // limiting timer will begin immediately after that. This means the next - // resolution could happen less than 100ms if we read the time *after* - // receiving from tr.ch - start := time.Now() + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() // Wait for the first resolution request to be done. This happens as part - // of the first iteration of the for loop in watcher() because we call - // ResolveNow in Build. - <-tr.ch - - // Here we start a couple of goroutines. One repeatedly calls ResolveNow() - // until asked to stop, and the other waits for two resolution requests to be - // made to our testResolver and stops the former. We measure the start and - // end times, and expect the duration elapsed to be in the interval - // {wantCalls*dnsResRate, wantCalls*dnsResRate} - done := make(chan struct{}) - go func() { - for { - select { - case <-done: - return - default: - r.ResolveNow(resolver.ResolveNowOptions{}) - time.Sleep(1 * time.Millisecond) - } - } - }() + // of the first iteration of the for loop in watcher(). 
+ if _, err := tr.lookupHostCh.Receive(ctx); err != nil { + t.Fatalf("Timed out waiting for lookup() call.") + } - gotCalls := 0 - const wantCalls = 3 - min, max := wantCalls*dnsResRate, (wantCalls+1)*dnsResRate - tMax := time.NewTimer(max) - for gotCalls != wantCalls { - select { - case <-tr.ch: - gotCalls++ - case <-tMax.C: - t.Fatalf("Timed out waiting for %v calls after %v; got %v", wantCalls, max, gotCalls) - } + // Call Resolve Now 100 times, shouldn't continue onto next iteration of + // watcher, thus shouldn't lookup again. + for i := 0; i <= 100; i++ { + r.ResolveNow(resolver.ResolveNowOptions{}) } - close(done) - elapsed := time.Since(start) - if gotCalls != wantCalls { - t.Fatalf("resolve count mismatch for target: %q = %+v, want %+v\n", target, gotCalls, wantCalls) + continueCtx, continueCancel := context.WithTimeout(context.Background(), defaultTestShortTimeout) + defer continueCancel() + + if _, err := tr.lookupHostCh.Receive(continueCtx); err == nil { + t.Fatalf("Should not have looked up again as DNS Min Res Rate timer has not gone off.") } - if elapsed < min { - t.Fatalf("elapsed time: %v, wanted it to be between {%v and %v}", elapsed, min, max) + + // Make the DNSMinResRate timer fire immediately (by receiving it, then + // resetting to 0), this will unblock the resolver which is currently + // blocked on the DNS Min Res Rate timer going off, which will allow it to + // continue to the next iteration of the watcher loop. + timer, err := timerChan.Receive(ctx) + if err != nil { + t.Fatalf("Error receiving timer from mock NewTimer call: %v", err) + } + timerPointer := timer.(*time.Timer) + timerPointer.Reset(0) + + // Now that DNS Min Res Rate timer has gone off, it should lookup again. + if _, err := tr.lookupHostCh.Receive(ctx); err != nil { + t.Fatalf("Timed out waiting for lookup() call.") + } + + // Resolve Now 1000 more times, shouldn't lookup again as DNS Min Res Rate + // timer has not gone off. 
+ for i := 0; i < 1000; i++ { + r.ResolveNow(resolver.ResolveNowOptions{}) + } + + if _, err = tr.lookupHostCh.Receive(continueCtx); err == nil { + t.Fatalf("Should not have looked up again as DNS Min Res Rate timer has not gone off.") + } + + // Make the DNSMinResRate timer fire immediately again. + timer, err = timerChan.Receive(ctx) + if err != nil { + t.Fatalf("Error receiving timer from mock NewTimer call: %v", err) + } + timerPointer = timer.(*time.Timer) + timerPointer.Reset(0) + + // Now that DNS Min Res Rate timer has gone off, it should lookup again. + if _, err = tr.lookupHostCh.Receive(ctx); err != nil { + t.Fatalf("Timed out waiting for lookup() call.") } wantAddrs := []resolver.Address{{Addr: "1.2.3.4" + colonDefaultPort}, {Addr: "5.6.7.8" + colonDefaultPort}} @@ -1347,21 +1587,66 @@ func TestRateLimitedResolve(t *testing.T) { } } +// DNS Resolver immediately starts polling on an error. This will cause the re-resolution to return another error. +// Thus, test that it constantly sends errors to the grpc.ClientConn. func TestReportError(t *testing.T) { const target = "notfoundaddress" + defer func(nt func(d time.Duration) *time.Timer) { + newTimer = nt + }(newTimer) + timerChan := testutils.NewChannel() + newTimer = func(d time.Duration) *time.Timer { + // Will never fire on its own, allows this test to call timer immediately. + t := time.NewTimer(time.Hour) + timerChan.Send(t) + return t + } cc := &testClientConn{target: target, errChan: make(chan error)} + totalTimesCalledError := 0 b := NewBuilder() r, err := b.Build(resolver.Target{Endpoint: target}, cc, resolver.BuildOptions{}) if err != nil { - t.Fatalf("%v\n", err) + t.Fatalf("Error building resolver for target %v: %v", target, err) + } + // Should receive first error. 
+ err = <-cc.errChan + if !strings.Contains(err.Error(), "hostLookup error") { + t.Fatalf(`ReportError(err=%v) called; want err contains "hostLookupError"`, err) } + totalTimesCalledError++ + ctx, ctxCancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer ctxCancel() + timer, err := timerChan.Receive(ctx) + if err != nil { + t.Fatalf("Error receiving timer from mock NewTimer call: %v", err) + } + timerPointer := timer.(*time.Timer) + timerPointer.Reset(0) defer r.Close() - select { - case err := <-cc.errChan: + + // Cause timer to go off 10 times, and see if it matches DNS Resolver updating Error. + for i := 0; i < 10; i++ { + // Should call ReportError(). + err = <-cc.errChan if !strings.Contains(err.Error(), "hostLookup error") { t.Fatalf(`ReportError(err=%v) called; want err contains "hostLookupError"`, err) } - case <-time.After(time.Second): - t.Fatalf("did not receive error after 1s") + totalTimesCalledError++ + timer, err := timerChan.Receive(ctx) + if err != nil { + t.Fatalf("Error receiving timer from mock NewTimer call: %v", err) + } + timerPointer := timer.(*time.Timer) + timerPointer.Reset(0) + } + + if totalTimesCalledError != 11 { + t.Errorf("ReportError() not called 11 times, instead called %d times.", totalTimesCalledError) + } + // Clean up final watcher iteration. + <-cc.errChan + _, err = timerChan.Receive(ctx) + if err != nil { + t.Fatalf("Error receiving timer from mock NewTimer call: %v", err) } } diff --git a/internal/resolver/dns/go113.go b/internal/resolver/dns/go113.go deleted file mode 100644 index 8783a8cf821..00000000000 --- a/internal/resolver/dns/go113.go +++ /dev/null @@ -1,33 +0,0 @@ -// +build go1.13 - -/* - * - * Copyright 2019 gRPC authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - */ - -package dns - -import "net" - -func init() { - filterError = func(err error) error { - if dnsErr, ok := err.(*net.DNSError); ok && dnsErr.IsNotFound { - // The name does not exist; not an error. - return nil - } - return err - } -} diff --git a/internal/serviceconfig/serviceconfig.go b/internal/serviceconfig/serviceconfig.go index bd4b8875f1a..badbdbf597f 100644 --- a/internal/serviceconfig/serviceconfig.go +++ b/internal/serviceconfig/serviceconfig.go @@ -46,6 +46,22 @@ type BalancerConfig struct { type intermediateBalancerConfig []map[string]json.RawMessage +// MarshalJSON implements the json.Marshaler interface. +// +// It marshals the balancer and config into a length-1 slice +// ([]map[string]config). +func (bc *BalancerConfig) MarshalJSON() ([]byte, error) { + if bc.Config == nil { + // If config is nil, return empty config `{}`. + return []byte(fmt.Sprintf(`[{%q: %v}]`, bc.Name, "{}")), nil + } + c, err := json.Marshal(bc.Config) + if err != nil { + return nil, err + } + return []byte(fmt.Sprintf(`[{%q: %s}]`, bc.Name, c)), nil +} + // UnmarshalJSON implements the json.Unmarshaler interface. 
// // ServiceConfig contains a list of loadBalancingConfigs, each with a name and @@ -62,6 +78,7 @@ func (bc *BalancerConfig) UnmarshalJSON(b []byte) error { return err } + var names []string for i, lbcfg := range ir { if len(lbcfg) != 1 { return fmt.Errorf("invalid loadBalancingConfig: entry %v does not contain exactly 1 policy/config pair: %q", i, lbcfg) @@ -76,6 +93,7 @@ func (bc *BalancerConfig) UnmarshalJSON(b []byte) error { for name, jsonCfg = range lbcfg { } + names = append(names, name) builder := balancer.Get(name) if builder == nil { // If the balancer is not registered, move on to the next config. @@ -104,7 +122,7 @@ func (bc *BalancerConfig) UnmarshalJSON(b []byte) error { // return. This means we had a loadBalancingConfig slice but did not // encounter a registered policy. The config is considered invalid in this // case. - return fmt.Errorf("invalid loadBalancingConfig: no supported policies found") + return fmt.Errorf("invalid loadBalancingConfig: no supported policies found in %v", names) } // MethodConfig defines the configuration recommended by the service providers for a diff --git a/internal/serviceconfig/serviceconfig_test.go b/internal/serviceconfig/serviceconfig_test.go index b8abaae027e..770ee2efeb8 100644 --- a/internal/serviceconfig/serviceconfig_test.go +++ b/internal/serviceconfig/serviceconfig_test.go @@ -29,16 +29,18 @@ import ( ) type testBalancerConfigType struct { - externalserviceconfig.LoadBalancingConfig + externalserviceconfig.LoadBalancingConfig `json:"-"` + + Check bool `json:"check"` } -var testBalancerConfig = testBalancerConfigType{} +var testBalancerConfig = testBalancerConfigType{Check: true} const ( testBalancerBuilderName = "test-bb" testBalancerBuilderNotParserName = "test-bb-not-parser" - testBalancerConfigJSON = `{"test-balancer-config":"true"}` + testBalancerConfigJSON = `{"check":true}` ) type testBalancerBuilder struct { @@ -133,3 +135,48 @@ func TestBalancerConfigUnmarshalJSON(t *testing.T) { }) } } + +func 
TestBalancerConfigMarshalJSON(t *testing.T) { + tests := []struct { + name string + bc BalancerConfig + wantJSON string + }{ + { + name: "OK", + bc: BalancerConfig{ + Name: testBalancerBuilderName, + Config: testBalancerConfig, + }, + wantJSON: fmt.Sprintf(`[{"test-bb": {"check":true}}]`), + }, + { + name: "OK config is nil", + bc: BalancerConfig{ + Name: testBalancerBuilderNotParserName, + Config: nil, // nil should be marshalled to an empty config "{}". + }, + wantJSON: fmt.Sprintf(`[{"test-bb-not-parser": {}}]`), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + b, err := tt.bc.MarshalJSON() + if err != nil { + t.Fatalf("failed to marshal: %v", err) + } + + if str := string(b); str != tt.wantJSON { + t.Fatalf("got str %q, want %q", str, tt.wantJSON) + } + + var bc BalancerConfig + if err := bc.UnmarshalJSON(b); err != nil { + t.Errorf("failed to mnmarshal: %v", err) + } + if !cmp.Equal(bc, tt.bc) { + t.Errorf("diff: %v", cmp.Diff(bc, tt.bc)) + } + }) + } +} diff --git a/internal/status/status.go b/internal/status/status.go index 710223b8ded..e5c6513edd1 100644 --- a/internal/status/status.go +++ b/internal/status/status.go @@ -97,7 +97,7 @@ func (s *Status) Err() error { if s.Code() == codes.OK { return nil } - return &Error{e: s.Proto()} + return &Error{s: s} } // WithDetails returns a new status with the provided details messages appended to the status. @@ -136,19 +136,23 @@ func (s *Status) Details() []interface{} { return details } +func (s *Status) String() string { + return fmt.Sprintf("rpc error: code = %s desc = %s", s.Code(), s.Message()) +} + // Error wraps a pointer of a status proto. It implements error and Status, // and a nil *Error should never be returned by this package. 
type Error struct { - e *spb.Status + s *Status } func (e *Error) Error() string { - return fmt.Sprintf("rpc error: code = %s desc = %s", codes.Code(e.e.GetCode()), e.e.GetMessage()) + return e.s.String() } // GRPCStatus returns the Status represented by se. func (e *Error) GRPCStatus() *Status { - return FromProto(e.e) + return e.s } // Is implements future error.Is functionality. @@ -158,5 +162,5 @@ func (e *Error) Is(target error) bool { if !ok { return false } - return proto.Equal(e.e, tse.e) + return proto.Equal(e.s.s, tse.s.s) } diff --git a/internal/syscall/syscall_linux.go b/internal/syscall/syscall_linux.go index 4b2964f2a1e..b3a72276dee 100644 --- a/internal/syscall/syscall_linux.go +++ b/internal/syscall/syscall_linux.go @@ -1,5 +1,3 @@ -// +build !appengine - /* * * Copyright 2018 gRPC authors. diff --git a/internal/syscall/syscall_nonlinux.go b/internal/syscall/syscall_nonlinux.go index 7913ef1dbfb..999f52cd75b 100644 --- a/internal/syscall/syscall_nonlinux.go +++ b/internal/syscall/syscall_nonlinux.go @@ -1,4 +1,5 @@ -// +build !linux appengine +//go:build !linux +// +build !linux /* * @@ -35,41 +36,41 @@ var logger = grpclog.Component("core") func log() { once.Do(func() { - logger.Info("CPU time info is unavailable on non-linux or appengine environment.") + logger.Info("CPU time info is unavailable on non-linux environments.") }) } -// GetCPUTime returns the how much CPU time has passed since the start of this process. -// It always returns 0 under non-linux or appengine environment. +// GetCPUTime returns the how much CPU time has passed since the start of this +// process. It always returns 0 under non-linux environments. func GetCPUTime() int64 { log() return 0 } -// Rusage is an empty struct under non-linux or appengine environment. +// Rusage is an empty struct under non-linux environments. type Rusage struct{} -// GetRusage is a no-op function under non-linux or appengine environment. 
+// GetRusage is a no-op function under non-linux environments. func GetRusage() *Rusage { log() return nil } // CPUTimeDiff returns the differences of user CPU time and system CPU time used -// between two Rusage structs. It a no-op function for non-linux or appengine environment. +// between two Rusage structs. It a no-op function for non-linux environments. func CPUTimeDiff(first *Rusage, latest *Rusage) (float64, float64) { log() return 0, 0 } -// SetTCPUserTimeout is a no-op function under non-linux or appengine environments +// SetTCPUserTimeout is a no-op function under non-linux environments. func SetTCPUserTimeout(conn net.Conn, timeout time.Duration) error { log() return nil } -// GetTCPUserTimeout is a no-op function under non-linux or appengine environments -// a negative return value indicates the operation is not supported +// GetTCPUserTimeout is a no-op function under non-linux environments. +// A negative return value indicates the operation is not supported func GetTCPUserTimeout(conn net.Conn) (int, error) { log() return -1, nil diff --git a/credentials/tls/certprovider/meshca/logging.go b/internal/testutils/marshal_any.go similarity index 55% rename from credentials/tls/certprovider/meshca/logging.go rename to internal/testutils/marshal_any.go index ae20059c4f7..9ddef6de15d 100644 --- a/credentials/tls/certprovider/meshca/logging.go +++ b/internal/testutils/marshal_any.go @@ -1,8 +1,6 @@ -// +build go1.13 - /* * - * Copyright 2020 gRPC authors. + * Copyright 2021 gRPC authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,22 +13,24 @@ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
- * */ -package meshca +package testutils import ( "fmt" - "google.golang.org/grpc/grpclog" - internalgrpclog "google.golang.org/grpc/internal/grpclog" + "github.com/golang/protobuf/proto" + "github.com/golang/protobuf/ptypes" + "google.golang.org/protobuf/types/known/anypb" ) -const prefix = "[%p] " - -var logger = grpclog.Component("meshca") - -func prefixLogger(p *providerPlugin) *internalgrpclog.PrefixLogger { - return internalgrpclog.NewPrefixLogger(logger, fmt.Sprintf(prefix, p)) +// MarshalAny is a convenience function to marshal protobuf messages into any +// protos. It will panic if the marshaling fails. +func MarshalAny(m proto.Message) *anypb.Any { + a, err := ptypes.MarshalAny(m) + if err != nil { + panic(fmt.Sprintf("ptypes.MarshalAny(%+v) failed: %v", m, err)) + } + return a } diff --git a/internal/transport/controlbuf.go b/internal/transport/controlbuf.go index 40ef23923fd..8394d252df0 100644 --- a/internal/transport/controlbuf.go +++ b/internal/transport/controlbuf.go @@ -20,13 +20,17 @@ package transport import ( "bytes" + "errors" "fmt" "runtime" + "strconv" "sync" "sync/atomic" "golang.org/x/net/http2" "golang.org/x/net/http2/hpack" + "google.golang.org/grpc/internal/grpcutil" + "google.golang.org/grpc/status" ) var updateHeaderTblSize = func(e *hpack.Encoder, v uint32) { @@ -128,6 +132,15 @@ type cleanupStream struct { func (c *cleanupStream) isTransportResponseFrame() bool { return c.rst } // Results in a RST_STREAM +type earlyAbortStream struct { + httpStatus uint32 + streamID uint32 + contentSubtype string + status *status.Status +} + +func (*earlyAbortStream) isTransportResponseFrame() bool { return false } + type dataFrame struct { streamID uint32 endStream bool @@ -284,7 +297,7 @@ type controlBuffer struct { // closed and nilled when transportResponseFrames drops below the // threshold. Both fields are protected by mu. 
transportResponseFrames int - trfChan atomic.Value // *chan struct{} + trfChan atomic.Value // chan struct{} } func newControlBuffer(done <-chan struct{}) *controlBuffer { @@ -298,10 +311,10 @@ func newControlBuffer(done <-chan struct{}) *controlBuffer { // throttle blocks if there are too many incomingSettings/cleanupStreams in the // controlbuf. func (c *controlBuffer) throttle() { - ch, _ := c.trfChan.Load().(*chan struct{}) + ch, _ := c.trfChan.Load().(chan struct{}) if ch != nil { select { - case <-*ch: + case <-ch: case <-c.done: } } @@ -335,8 +348,7 @@ func (c *controlBuffer) executeAndPut(f func(it interface{}) bool, it cbItem) (b if c.transportResponseFrames == maxQueuedTransportResponseFrames { // We are adding the frame that puts us over the threshold; create // a throttling channel. - ch := make(chan struct{}) - c.trfChan.Store(&ch) + c.trfChan.Store(make(chan struct{})) } } c.mu.Unlock() @@ -377,9 +389,9 @@ func (c *controlBuffer) get(block bool) (interface{}, error) { if c.transportResponseFrames == maxQueuedTransportResponseFrames { // We are removing the frame that put us over the // threshold; close and clear the throttling channel. - ch := c.trfChan.Load().(*chan struct{}) - close(*ch) - c.trfChan.Store((*chan struct{})(nil)) + ch := c.trfChan.Load().(chan struct{}) + close(ch) + c.trfChan.Store((chan struct{})(nil)) } c.transportResponseFrames-- } @@ -395,7 +407,6 @@ func (c *controlBuffer) get(block bool) (interface{}, error) { select { case <-c.ch: case <-c.done: - c.finish() return nil, ErrConnClosing } } @@ -420,6 +431,14 @@ func (c *controlBuffer) finish() { hdr.onOrphaned(ErrConnClosing) } } + // In case throttle() is currently in flight, it needs to be unblocked. + // Otherwise, the transport may not close, since the transport is closed by + // the reader encountering the connection error. 
+ ch, _ := c.trfChan.Load().(chan struct{}) + if ch != nil { + close(ch) + } + c.trfChan.Store((chan struct{})(nil)) c.mu.Unlock() } @@ -749,6 +768,27 @@ func (l *loopyWriter) cleanupStreamHandler(c *cleanupStream) error { return nil } +func (l *loopyWriter) earlyAbortStreamHandler(eas *earlyAbortStream) error { + if l.side == clientSide { + return errors.New("earlyAbortStream not handled on client") + } + // In case the caller forgets to set the http status, default to 200. + if eas.httpStatus == 0 { + eas.httpStatus = 200 + } + headerFields := []hpack.HeaderField{ + {Name: ":status", Value: strconv.Itoa(int(eas.httpStatus))}, + {Name: "content-type", Value: grpcutil.ContentType(eas.contentSubtype)}, + {Name: "grpc-status", Value: strconv.Itoa(int(eas.status.Code()))}, + {Name: "grpc-message", Value: encodeGrpcMessage(eas.status.Message())}, + } + + if err := l.writeHeader(eas.streamID, true, headerFields, nil); err != nil { + return err + } + return nil +} + func (l *loopyWriter) incomingGoAwayHandler(*incomingGoAway) error { if l.side == clientSide { l.draining = true @@ -787,6 +827,8 @@ func (l *loopyWriter) handle(i interface{}) error { return l.registerStreamHandler(i) case *cleanupStream: return l.cleanupStreamHandler(i) + case *earlyAbortStream: + return l.earlyAbortStreamHandler(i) case *incomingGoAway: return l.incomingGoAwayHandler(i) case *dataFrame: diff --git a/internal/transport/handler_server.go b/internal/transport/handler_server.go index 05d3871e628..1c3459c2b4c 100644 --- a/internal/transport/handler_server.go +++ b/internal/transport/handler_server.go @@ -141,9 +141,8 @@ type serverHandlerTransport struct { stats stats.Handler } -func (ht *serverHandlerTransport) Close() error { +func (ht *serverHandlerTransport) Close() { ht.closeOnce.Do(ht.closeCloseChanOnce) - return nil } func (ht *serverHandlerTransport) closeCloseChanOnce() { close(ht.closedCh) } diff --git a/internal/transport/handler_server_test.go 
b/internal/transport/handler_server_test.go index f9efdfb0716..b08dcaaf3c4 100644 --- a/internal/transport/handler_server_test.go +++ b/internal/transport/handler_server_test.go @@ -62,7 +62,6 @@ func (s) TestHandlerTransport_NewServerHandlerTransport(t *testing.T) { ProtoMajor: 2, Method: "GET", Header: http.Header{}, - RequestURI: "/", }, wantErr: "invalid gRPC request method", }, @@ -74,7 +73,6 @@ func (s) TestHandlerTransport_NewServerHandlerTransport(t *testing.T) { Header: http.Header{ "Content-Type": {"application/foo"}, }, - RequestURI: "/service/foo.bar", }, wantErr: "invalid gRPC request content-type", }, @@ -86,7 +84,6 @@ func (s) TestHandlerTransport_NewServerHandlerTransport(t *testing.T) { Header: http.Header{ "Content-Type": {"application/grpc"}, }, - RequestURI: "/service/foo.bar", }, modrw: func(w http.ResponseWriter) http.ResponseWriter { // Return w without its Flush method @@ -109,7 +106,6 @@ func (s) TestHandlerTransport_NewServerHandlerTransport(t *testing.T) { URL: &url.URL{ Path: "/service/foo.bar", }, - RequestURI: "/service/foo.bar", }, check: func(t *serverHandlerTransport, tt *testCase) error { if t.req != tt.req { @@ -133,7 +129,6 @@ func (s) TestHandlerTransport_NewServerHandlerTransport(t *testing.T) { URL: &url.URL{ Path: "/service/foo.bar", }, - RequestURI: "/service/foo.bar", }, check: func(t *serverHandlerTransport, tt *testCase) error { if !t.timeoutSet { @@ -157,7 +152,6 @@ func (s) TestHandlerTransport_NewServerHandlerTransport(t *testing.T) { URL: &url.URL{ Path: "/service/foo.bar", }, - RequestURI: "/service/foo.bar", }, wantErr: `rpc error: code = Internal desc = malformed time-out: transport: timeout unit is not recognized: "tomorrow"`, }, @@ -175,7 +169,6 @@ func (s) TestHandlerTransport_NewServerHandlerTransport(t *testing.T) { URL: &url.URL{ Path: "/service/foo.bar", }, - RequestURI: "/service/foo.bar", }, check: func(ht *serverHandlerTransport, tt *testCase) error { want := metadata.MD{ @@ -247,8 +240,7 @@ func 
newHandleStreamTest(t *testing.T) *handleStreamTest { URL: &url.URL{ Path: "/service/foo.bar", }, - RequestURI: "/service/foo.bar", - Body: bodyr, + Body: bodyr, } rw := newTestHandlerResponseWriter().(testHandlerResponseWriter) ht, err := NewServerHandlerTransport(rw, req, nil) @@ -359,8 +351,7 @@ func (s) TestHandlerTransport_HandleStreams_Timeout(t *testing.T) { URL: &url.URL{ Path: "/service/foo.bar", }, - RequestURI: "/service/foo.bar", - Body: bodyr, + Body: bodyr, } rw := newTestHandlerResponseWriter().(testHandlerResponseWriter) ht, err := NewServerHandlerTransport(rw, req, nil) diff --git a/internal/transport/http2_client.go b/internal/transport/http2_client.go index a76310c6e13..e3203adbd0b 100644 --- a/internal/transport/http2_client.go +++ b/internal/transport/http2_client.go @@ -24,6 +24,7 @@ import ( "io" "math" "net" + "net/http" "strconv" "strings" "sync" @@ -32,15 +33,14 @@ import ( "golang.org/x/net/http2" "golang.org/x/net/http2/hpack" - "google.golang.org/grpc/internal/grpcutil" - imetadata "google.golang.org/grpc/internal/metadata" - "google.golang.org/grpc/internal/transport/networktype" - "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials" - "google.golang.org/grpc/internal" "google.golang.org/grpc/internal/channelz" + icredentials "google.golang.org/grpc/internal/credentials" + "google.golang.org/grpc/internal/grpcutil" + imetadata "google.golang.org/grpc/internal/metadata" "google.golang.org/grpc/internal/syscall" + "google.golang.org/grpc/internal/transport/networktype" "google.golang.org/grpc/keepalive" "google.golang.org/grpc/metadata" "google.golang.org/grpc/peer" @@ -116,6 +116,9 @@ type http2Client struct { // goAwayReason records the http2.ErrCode and debug data received with the // GoAway frame. goAwayReason GoAwayReason + // goAwayDebugMessage contains a detailed human readable string about a + // GoAway frame, useful for error messages. 
+ goAwayDebugMessage string // A condition variable used to signal when the keepalive goroutine should // go dormant. The condition for dormancy is based on the number of active // streams and the `PermitWithoutStream` keepalive client parameter. And @@ -238,9 +241,16 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts // Attributes field of resolver.Address, which is shoved into connectCtx // and passed to the credential handshaker. This makes it possible for // address specific arbitrary data to reach the credential handshaker. - contextWithHandshakeInfo := internal.NewClientHandshakeInfoContext.(func(context.Context, credentials.ClientHandshakeInfo) context.Context) - connectCtx = contextWithHandshakeInfo(connectCtx, credentials.ClientHandshakeInfo{Attributes: addr.Attributes}) - conn, authInfo, err = transportCreds.ClientHandshake(connectCtx, addr.ServerName, conn) + connectCtx = icredentials.NewClientHandshakeInfoContext(connectCtx, credentials.ClientHandshakeInfo{Attributes: addr.Attributes}) + rawConn := conn + // Pull the deadline from the connectCtx, which will be used for + // timeouts in the authentication protocol handshake. Can ignore the + // boolean as the deadline will return the zero value, which will make + // the conn not timeout on I/O operations. + deadline, _ := connectCtx.Deadline() + rawConn.SetDeadline(deadline) + conn, authInfo, err = transportCreds.ClientHandshake(connectCtx, addr.ServerName, rawConn) + rawConn.SetDeadline(time.Time{}) if err != nil { return nil, connectionErrorf(isTemporary(err), err, "transport: authentication handshake failed: %v", err) } @@ -347,12 +357,14 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts // Send connection preface to server. 
n, err := t.conn.Write(clientPreface) if err != nil { - t.Close() - return nil, connectionErrorf(true, err, "transport: failed to write client preface: %v", err) + err = connectionErrorf(true, err, "transport: failed to write client preface: %v", err) + t.Close(err) + return nil, err } if n != len(clientPreface) { - t.Close() - return nil, connectionErrorf(true, err, "transport: preface mismatch, wrote %d bytes; want %d", n, len(clientPreface)) + err = connectionErrorf(true, nil, "transport: preface mismatch, wrote %d bytes; want %d", n, len(clientPreface)) + t.Close(err) + return nil, err } var ss []http2.Setting @@ -370,14 +382,16 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts } err = t.framer.fr.WriteSettings(ss...) if err != nil { - t.Close() - return nil, connectionErrorf(true, err, "transport: failed to write initial settings frame: %v", err) + err = connectionErrorf(true, err, "transport: failed to write initial settings frame: %v", err) + t.Close(err) + return nil, err } // Adjust the connection flow control window if needed. if delta := uint32(icwz - defaultWindowSize); delta > 0 { if err := t.framer.fr.WriteWindowUpdate(0, delta); err != nil { - t.Close() - return nil, connectionErrorf(true, err, "transport: failed to write window update: %v", err) + err = connectionErrorf(true, err, "transport: failed to write window update: %v", err) + t.Close(err) + return nil, err } } @@ -403,11 +417,10 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts logger.Errorf("transport: loopyWriter.run returning. Err: %v", err) } } - // If it's a connection error, let reader goroutine handle it - // since there might be data in the buffers. - if _, ok := err.(net.Error); !ok { - t.conn.Close() - } + // Do not close the transport. Let reader goroutine handle it since + // there might be data in the buffers. 
+ t.conn.Close() + t.controlBuf.finish() close(t.writerDone) }() return t, nil @@ -463,7 +476,7 @@ func (t *http2Client) createHeaderFields(ctx context.Context, callHdr *CallHdr) Method: callHdr.Method, AuthInfo: t.authInfo, } - ctxWithRequestInfo := internal.NewRequestInfoContext.(func(context.Context, credentials.RequestInfo) context.Context)(ctx, ri) + ctxWithRequestInfo := icredentials.NewRequestInfoContext(ctx, ri) authData, err := t.getTrAuthData(ctxWithRequestInfo, aud) if err != nil { return nil, err @@ -612,26 +625,35 @@ func (t *http2Client) getCallAuthData(ctx context.Context, audience string, call return callAuthData, nil } -// PerformedIOError wraps an error to indicate IO may have been performed -// before the error occurred. -type PerformedIOError struct { +// NewStreamError wraps an error and reports additional information. Typically +// NewStream errors result in transparent retry, as they mean nothing went onto +// the wire. However, there are two notable exceptions: +// +// 1. If the stream headers violate the max header list size allowed by the +// server. In this case there is no reason to retry at all, as it is +// assumed the RPC would continue to fail on subsequent attempts. +// 2. If the credentials errored when requesting their headers. In this case, +// it's possible a retry can fix the problem, but indefinitely transparently +// retrying is not appropriate as it is likely the credentials, if they can +// eventually succeed, would need I/O to do so. +type NewStreamError struct { Err error + + DoNotRetry bool + DoNotTransparentRetry bool } -// Error implements error. -func (p PerformedIOError) Error() string { - return p.Err.Error() +func (e NewStreamError) Error() string { + return e.Err.Error() } // NewStream creates a stream and registers it into the transport as "active" -// streams. +// streams. All non-nil errors returned will be *NewStreamError. 
func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Stream, err error) { ctx = peer.NewContext(ctx, t.getPeer()) headerFields, err := t.createHeaderFields(ctx, callHdr) if err != nil { - // We may have performed I/O in the per-RPC creds callback, so do not - // allow transparent retry. - return nil, PerformedIOError{err} + return nil, &NewStreamError{Err: err, DoNotTransparentRetry: true} } s := t.newStream(ctx, callHdr) cleanup := func(err error) { @@ -731,23 +753,23 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea return true }, hdr) if err != nil { - return nil, err + return nil, &NewStreamError{Err: err} } if success { break } if hdrListSizeErr != nil { - return nil, hdrListSizeErr + return nil, &NewStreamError{Err: hdrListSizeErr, DoNotRetry: true} } firstTry = false select { case <-ch: - case <-s.ctx.Done(): - return nil, ContextErr(s.ctx.Err()) + case <-ctx.Done(): + return nil, &NewStreamError{Err: ContextErr(ctx.Err())} case <-t.goAway: - return nil, errStreamDrain + return nil, &NewStreamError{Err: errStreamDrain} case <-t.ctx.Done(): - return nil, ErrConnClosing + return nil, &NewStreamError{Err: ErrConnClosing} } } if t.statsHandler != nil { @@ -854,12 +876,12 @@ func (t *http2Client) closeStream(s *Stream, err error, rst bool, rstCode http2. // This method blocks until the addrConn that initiated this transport is // re-connected. This happens because t.onClose() begins reconnect logic at the // addrConn level and blocks until the addrConn is successfully connected. -func (t *http2Client) Close() error { +func (t *http2Client) Close(err error) { t.mu.Lock() // Make sure we only Close once. if t.state == closing { t.mu.Unlock() - return nil + return } // Call t.onClose before setting the state to closing to prevent the client // from attempting to create new streams ASAP. 
@@ -875,13 +897,25 @@ func (t *http2Client) Close() error { t.mu.Unlock() t.controlBuf.finish() t.cancel() - err := t.conn.Close() + t.conn.Close() if channelz.IsOn() { channelz.RemoveEntry(t.channelzID) } + // Append info about previous goaways if there were any, since this may be important + // for understanding the root cause for this connection to be closed. + _, goAwayDebugMessage := t.GetGoAwayReason() + + var st *status.Status + if len(goAwayDebugMessage) > 0 { + st = status.Newf(codes.Unavailable, "closing transport due to: %v, received prior goaway: %v", err, goAwayDebugMessage) + err = st.Err() + } else { + st = status.New(codes.Unavailable, err.Error()) + } + // Notify all active streams. for _, s := range streams { - t.closeStream(s, ErrConnClosing, false, http2.ErrCodeNo, status.New(codes.Unavailable, ErrConnClosing.Desc), nil, false) + t.closeStream(s, err, false, http2.ErrCodeNo, st, nil, false) } if t.statsHandler != nil { connEnd := &stats.ConnEnd{ @@ -889,7 +923,6 @@ func (t *http2Client) Close() error { } t.statsHandler.HandleConn(t.ctx, connEnd) } - return err } // GracefulClose sets the state to draining, which prevents new streams from @@ -908,7 +941,7 @@ func (t *http2Client) GracefulClose() { active := len(t.activeStreams) t.mu.Unlock() if active == 0 { - t.Close() + t.Close(ErrConnClosing) return } t.controlBuf.put(&incomingGoAway{}) @@ -1049,7 +1082,7 @@ func (t *http2Client) handleData(f *http2.DataFrame) { } // The server has closed the stream without sending trailers. Record that // the read direction is closed, and set the status appropriately. 
- if f.FrameHeader.Flags.Has(http2.FlagDataEndStream) { + if f.StreamEnded() { t.closeStream(s, io.EOF, false, http2.ErrCodeNo, status.New(codes.Internal, "server closed the stream without sending trailers"), nil, true) } } @@ -1154,9 +1187,9 @@ func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) { } } id := f.LastStreamID - if id > 0 && id%2 != 1 { + if id > 0 && id%2 == 0 { t.mu.Unlock() - t.Close() + t.Close(connectionErrorf(true, nil, "received goaway with non-zero even-numbered numbered stream id: %v", id)) return } // A client can receive multiple GoAways from the server (see @@ -1174,7 +1207,7 @@ func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) { // If there are multiple GoAways the first one should always have an ID greater than the following ones. if id > t.prevGoAwayID { t.mu.Unlock() - t.Close() + t.Close(connectionErrorf(true, nil, "received goaway with stream id: %v, which exceeds stream id of previous goaway: %v", id, t.prevGoAwayID)) return } default: @@ -1204,7 +1237,7 @@ func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) { active := len(t.activeStreams) t.mu.Unlock() if active == 0 { - t.Close() + t.Close(connectionErrorf(true, nil, "received goaway and there are no active streams")) } } @@ -1220,12 +1253,17 @@ func (t *http2Client) setGoAwayReason(f *http2.GoAwayFrame) { t.goAwayReason = GoAwayTooManyPings } } + if len(f.DebugData()) == 0 { + t.goAwayDebugMessage = fmt.Sprintf("code: %s", f.ErrCode) + } else { + t.goAwayDebugMessage = fmt.Sprintf("code: %s, debug data: %q", f.ErrCode, string(f.DebugData())) + } } -func (t *http2Client) GetGoAwayReason() GoAwayReason { +func (t *http2Client) GetGoAwayReason() (GoAwayReason, string) { t.mu.Lock() defer t.mu.Unlock() - return t.goAwayReason + return t.goAwayReason, t.goAwayDebugMessage } func (t *http2Client) handleWindowUpdate(f *http2.WindowUpdateFrame) { @@ -1252,35 +1290,128 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) { return } - state := 
&decodeState{} - // Initialize isGRPC value to be !initialHeader, since if a gRPC Response-Headers has already been received, then it means that the peer is speaking gRPC and we are in gRPC mode. - state.data.isGRPC = !initialHeader - if h2code, err := state.decodeHeader(frame); err != nil { - t.closeStream(s, err, true, h2code, status.Convert(err), nil, endStream) + // frame.Truncated is set to true when framer detects that the current header + // list size hits MaxHeaderListSize limit. + if frame.Truncated { + se := status.New(codes.Internal, "peer header list size exceeded limit") + t.closeStream(s, se.Err(), true, http2.ErrCodeFrameSize, se, nil, endStream) return } - isHeader := false - defer func() { - if t.statsHandler != nil { - if isHeader { - inHeader := &stats.InHeader{ - Client: true, - WireLength: int(frame.Header().Length), - Header: s.header.Copy(), - Compression: s.recvCompress, - } - t.statsHandler.HandleRPC(s.ctx, inHeader) - } else { - inTrailer := &stats.InTrailer{ - Client: true, - WireLength: int(frame.Header().Length), - Trailer: s.trailer.Copy(), - } - t.statsHandler.HandleRPC(s.ctx, inTrailer) + var ( + // If a gRPC Response-Headers has already been received, then it means + // that the peer is speaking gRPC and we are in gRPC mode. 
+ isGRPC = !initialHeader + mdata = make(map[string][]string) + contentTypeErr = "malformed header: missing HTTP content-type" + grpcMessage string + statusGen *status.Status + recvCompress string + httpStatusCode *int + httpStatusErr string + rawStatusCode = codes.Unknown + // headerError is set if an error is encountered while parsing the headers + headerError string + ) + + if initialHeader { + httpStatusErr = "malformed header: missing HTTP status" + } + + for _, hf := range frame.Fields { + switch hf.Name { + case "content-type": + if _, validContentType := grpcutil.ContentSubtype(hf.Value); !validContentType { + contentTypeErr = fmt.Sprintf("transport: received unexpected content-type %q", hf.Value) + break + } + contentTypeErr = "" + mdata[hf.Name] = append(mdata[hf.Name], hf.Value) + isGRPC = true + case "grpc-encoding": + recvCompress = hf.Value + case "grpc-status": + code, err := strconv.ParseInt(hf.Value, 10, 32) + if err != nil { + se := status.New(codes.Internal, fmt.Sprintf("transport: malformed grpc-status: %v", err)) + t.closeStream(s, se.Err(), true, http2.ErrCodeProtocol, se, nil, endStream) + return } + rawStatusCode = codes.Code(uint32(code)) + case "grpc-message": + grpcMessage = decodeGrpcMessage(hf.Value) + case "grpc-status-details-bin": + var err error + statusGen, err = decodeGRPCStatusDetails(hf.Value) + if err != nil { + headerError = fmt.Sprintf("transport: malformed grpc-status-details-bin: %v", err) + } + case ":status": + if hf.Value == "200" { + httpStatusErr = "" + statusCode := 200 + httpStatusCode = &statusCode + break + } + + c, err := strconv.ParseInt(hf.Value, 10, 32) + if err != nil { + se := status.New(codes.Internal, fmt.Sprintf("transport: malformed http-status: %v", err)) + t.closeStream(s, se.Err(), true, http2.ErrCodeProtocol, se, nil, endStream) + return + } + statusCode := int(c) + httpStatusCode = &statusCode + + httpStatusErr = fmt.Sprintf( + "unexpected HTTP status code received from server: %d (%s)", + 
statusCode, + http.StatusText(statusCode), + ) + default: + if isReservedHeader(hf.Name) && !isWhitelistedHeader(hf.Name) { + break + } + v, err := decodeMetadataHeader(hf.Name, hf.Value) + if err != nil { + headerError = fmt.Sprintf("transport: malformed %s: %v", hf.Name, err) + logger.Warningf("Failed to decode metadata header (%q, %q): %v", hf.Name, hf.Value, err) + break + } + mdata[hf.Name] = append(mdata[hf.Name], v) } - }() + } + + if !isGRPC || httpStatusErr != "" { + var code = codes.Internal // when header does not include HTTP status, return INTERNAL + + if httpStatusCode != nil { + var ok bool + code, ok = HTTPStatusConvTab[*httpStatusCode] + if !ok { + code = codes.Unknown + } + } + var errs []string + if httpStatusErr != "" { + errs = append(errs, httpStatusErr) + } + if contentTypeErr != "" { + errs = append(errs, contentTypeErr) + } + // Verify the HTTP response is a 200. + se := status.New(code, strings.Join(errs, "; ")) + t.closeStream(s, se.Err(), true, http2.ErrCodeProtocol, se, nil, endStream) + return + } + + if headerError != "" { + se := status.New(codes.Internal, headerError) + t.closeStream(s, se.Err(), true, http2.ErrCodeProtocol, se, nil, endStream) + return + } + + isHeader := false // If headerChan hasn't been closed yet if atomic.CompareAndSwapUint32(&s.headerChanClosed, 0, 1) { @@ -1291,9 +1422,9 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) { // These values can be set without any synchronization because // stream goroutine will read it only after seeing a closed // headerChan which we'll close after setting this. - s.recvCompress = state.data.encoding - if len(state.data.mdata) > 0 { - s.header = state.data.mdata + s.recvCompress = recvCompress + if len(mdata) > 0 { + s.header = mdata } } else { // HEADERS frame block carries a Trailers-Only. 
@@ -1302,13 +1433,36 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) { close(s.headerChan) } + if t.statsHandler != nil { + if isHeader { + inHeader := &stats.InHeader{ + Client: true, + WireLength: int(frame.Header().Length), + Header: metadata.MD(mdata).Copy(), + Compression: s.recvCompress, + } + t.statsHandler.HandleRPC(s.ctx, inHeader) + } else { + inTrailer := &stats.InTrailer{ + Client: true, + WireLength: int(frame.Header().Length), + Trailer: metadata.MD(mdata).Copy(), + } + t.statsHandler.HandleRPC(s.ctx, inTrailer) + } + } + if !endStream { return } + if statusGen == nil { + statusGen = status.New(rawStatusCode, grpcMessage) + } + // if client received END_STREAM from server while stream was still active, send RST_STREAM rst := s.getState() == streamActive - t.closeStream(s, io.EOF, rst, http2.ErrCodeNo, state.status(), state.data.mdata, true) + t.closeStream(s, io.EOF, rst, http2.ErrCodeNo, statusGen, mdata, true) } // reader runs as a separate goroutine in charge of reading data from network @@ -1322,7 +1476,8 @@ func (t *http2Client) reader() { // Check the validity of server preface. frame, err := t.framer.fr.ReadFrame() if err != nil { - t.Close() // this kicks off resetTransport, so must be last before return + err = connectionErrorf(true, err, "error reading server preface: %v", err) + t.Close(err) // this kicks off resetTransport, so must be last before return return } t.conn.SetReadDeadline(time.Time{}) // reset deadline once we get the settings frame (we didn't time out, yay!) 
@@ -1331,7 +1486,8 @@ func (t *http2Client) reader() { } sf, ok := frame.(*http2.SettingsFrame) if !ok { - t.Close() // this kicks off resetTransport, so must be last before return + // this kicks off resetTransport, so must be last before return + t.Close(connectionErrorf(true, nil, "initial http2 frame from server is not a settings frame: %T", frame)) return } t.onPrefaceReceipt() @@ -1367,7 +1523,7 @@ func (t *http2Client) reader() { continue } else { // Transport error. - t.Close() + t.Close(connectionErrorf(true, err, "error reading from server: %v", err)) return } } @@ -1426,7 +1582,7 @@ func (t *http2Client) keepalive() { continue } if outstandingPing && timeoutLeft <= 0 { - t.Close() + t.Close(connectionErrorf(true, nil, "keepalive ping failed to receive ACK within timeout")) return } t.mu.Lock() diff --git a/internal/transport/http2_server.go b/internal/transport/http2_server.go index 7c6c89d4f9b..f2cad9ebc31 100644 --- a/internal/transport/http2_server.go +++ b/internal/transport/http2_server.go @@ -102,11 +102,11 @@ type http2Server struct { mu sync.Mutex // guard the following - // drainChan is initialized when drain(...) is called the first time. + // drainChan is initialized when Drain() is called the first time. // After which the server writes out the first GoAway(with ID 2^31-1) frame. // Then an independent goroutine will be launched to later send the second GoAway. // During this time we don't want to write another first GoAway(with ID 2^31 -1) frame. - // Thus call to drain(...) will be a no-op if drainChan is already initialized since draining is + // Thus call to Drain() will be a no-op if drainChan is already initialized since draining is // already underway. drainChan chan struct{} state transportState @@ -125,9 +125,30 @@ type http2Server struct { connectionID uint64 } -// newHTTP2Server constructs a ServerTransport based on HTTP2. ConnectionError is -// returned if something goes wrong. 
-func newHTTP2Server(conn net.Conn, config *ServerConfig) (_ ServerTransport, err error) {
+// NewServerTransport creates an http2 transport with conn and configuration
+// options from config.
+//
+// It returns a non-nil transport and a nil error on success. On failure, it
+// returns a nil transport and a non-nil error. For a special case where the
+// underlying conn gets closed before the client preface could be read, it
+// returns a nil transport and io.EOF.
+func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport, err error) {
+	var authInfo credentials.AuthInfo
+	rawConn := conn
+	if config.Credentials != nil {
+		var err error
+		conn, authInfo, err = config.Credentials.ServerHandshake(rawConn)
+		if err != nil {
+			// ErrConnDispatched means that the connection was dispatched away
+			// from gRPC; those connections should be left open. io.EOF means
+			// the connection was closed before handshaking completed, which can
+			// happen naturally from probers. Return these errors directly.
+ if err == credentials.ErrConnDispatched || err == io.EOF { + return nil, err + } + return nil, connectionErrorf(false, err, "ServerHandshake(%q) failed: %v", rawConn.RemoteAddr(), err) + } + } writeBufSize := config.WriteBufferSize readBufSize := config.ReadBufferSize maxHeaderListSize := defaultServerMaxHeaderListSize @@ -210,14 +231,15 @@ func newHTTP2Server(conn net.Conn, config *ServerConfig) (_ ServerTransport, err if kep.MinTime == 0 { kep.MinTime = defaultKeepalivePolicyMinTime } + done := make(chan struct{}) t := &http2Server{ - ctx: context.Background(), + ctx: setConnection(context.Background(), rawConn), done: done, conn: conn, remoteAddr: conn.RemoteAddr(), localAddr: conn.LocalAddr(), - authInfo: config.AuthInfo, + authInfo: authInfo, framer: framer, readerDone: make(chan struct{}), writerDone: make(chan struct{}), @@ -266,6 +288,14 @@ func newHTTP2Server(conn net.Conn, config *ServerConfig) (_ ServerTransport, err // Check the validity of client preface. preface := make([]byte, len(clientPreface)) if _, err := io.ReadFull(t.conn, preface); err != nil { + // In deployments where a gRPC server runs behind a cloud load balancer + // which performs regular TCP level health checks, the connection is + // closed immediately by the latter. Returning io.EOF here allows the + // grpc server implementation to recognize this scenario and suppress + // logging to reduce spam. + if err == io.EOF { + return nil, io.EOF + } return nil, connectionErrorf(false, err, "transport: http2Server.HandleStreams failed to receive the preface from client: %v", err) } if !bytes.Equal(preface, clientPreface) { @@ -295,6 +325,7 @@ func newHTTP2Server(conn net.Conn, config *ServerConfig) (_ ServerTransport, err } } t.conn.Close() + t.controlBuf.finish() close(t.writerDone) }() go t.keepalive() @@ -304,37 +335,131 @@ func newHTTP2Server(conn net.Conn, config *ServerConfig) (_ ServerTransport, err // operateHeader takes action on the decoded headers. 
func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(*Stream), traceCtx func(context.Context, string) context.Context) (fatal bool) { streamID := frame.Header().StreamID - state := &decodeState{ - serverSide: true, - } - if h2code, err := state.decodeHeader(frame); err != nil { - if _, ok := status.FromError(err); ok { - t.controlBuf.put(&cleanupStream{ - streamID: streamID, - rst: true, - rstCode: h2code, - onWrite: func() {}, - }) - } + + // frame.Truncated is set to true when framer detects that the current header + // list size hits MaxHeaderListSize limit. + if frame.Truncated { + t.controlBuf.put(&cleanupStream{ + streamID: streamID, + rst: true, + rstCode: http2.ErrCodeFrameSize, + onWrite: func() {}, + }) return false } buf := newRecvBuffer() s := &Stream{ - id: streamID, - st: t, - buf: buf, - fc: &inFlow{limit: uint32(t.initialWindowSize)}, - recvCompress: state.data.encoding, - method: state.data.method, - contentSubtype: state.data.contentSubtype, + id: streamID, + st: t, + buf: buf, + fc: &inFlow{limit: uint32(t.initialWindowSize)}, + } + + var ( + // If a gRPC Response-Headers has already been received, then it means + // that the peer is speaking gRPC and we are in gRPC mode. 
+ isGRPC = false + mdata = make(map[string][]string) + httpMethod string + // headerError is set if an error is encountered while parsing the headers + headerError bool + + timeoutSet bool + timeout time.Duration + ) + + for _, hf := range frame.Fields { + switch hf.Name { + case "content-type": + contentSubtype, validContentType := grpcutil.ContentSubtype(hf.Value) + if !validContentType { + break + } + mdata[hf.Name] = append(mdata[hf.Name], hf.Value) + s.contentSubtype = contentSubtype + isGRPC = true + case "grpc-encoding": + s.recvCompress = hf.Value + case ":method": + httpMethod = hf.Value + case ":path": + s.method = hf.Value + case "grpc-timeout": + timeoutSet = true + var err error + if timeout, err = decodeTimeout(hf.Value); err != nil { + headerError = true + } + // "Transports must consider requests containing the Connection header + // as malformed." - A41 + case "connection": + if logger.V(logLevel) { + logger.Errorf("transport: http2Server.operateHeaders parsed a :connection header which makes a request malformed as per the HTTP/2 spec") + } + headerError = true + default: + if isReservedHeader(hf.Name) && !isWhitelistedHeader(hf.Name) { + break + } + v, err := decodeMetadataHeader(hf.Name, hf.Value) + if err != nil { + headerError = true + logger.Warningf("Failed to decode metadata header (%q, %q): %v", hf.Name, hf.Value, err) + break + } + mdata[hf.Name] = append(mdata[hf.Name], v) + } + } + + // "If multiple Host headers or multiple :authority headers are present, the + // request must be rejected with an HTTP status code 400 as required by Host + // validation in RFC 7230 §5.4, gRPC status code INTERNAL, or RST_STREAM + // with HTTP/2 error code PROTOCOL_ERROR." - A41. Since this is a HTTP/2 + // error, this takes precedence over a client not speaking gRPC. 
+ if len(mdata[":authority"]) > 1 || len(mdata["host"]) > 1 { + errMsg := fmt.Sprintf("num values of :authority: %v, num values of host: %v, both must only have 1 value as per HTTP/2 spec", len(mdata[":authority"]), len(mdata["host"])) + if logger.V(logLevel) { + logger.Errorf("transport: %v", errMsg) + } + t.controlBuf.put(&earlyAbortStream{ + httpStatus: 400, + streamID: streamID, + contentSubtype: s.contentSubtype, + status: status.New(codes.Internal, errMsg), + }) + return false + } + + if !isGRPC || headerError { + t.controlBuf.put(&cleanupStream{ + streamID: streamID, + rst: true, + rstCode: http2.ErrCodeProtocol, + onWrite: func() {}, + }) + return false + } + + // "If :authority is missing, Host must be renamed to :authority." - A41 + if len(mdata[":authority"]) == 0 { + // No-op if host isn't present, no eventual :authority header is a valid + // RPC. + if host, ok := mdata["host"]; ok { + mdata[":authority"] = host + delete(mdata, "host") + } + } else { + // "If :authority is present, Host must be discarded" - A41 + delete(mdata, "host") } + if frame.StreamEnded() { // s is just created by the caller. No lock needed. s.state = streamReadDone } - if state.data.timeoutSet { - s.ctx, s.cancel = context.WithTimeout(t.ctx, state.data.timeout) + if timeoutSet { + s.ctx, s.cancel = context.WithTimeout(t.ctx, timeout) } else { s.ctx, s.cancel = context.WithCancel(t.ctx) } @@ -347,33 +472,13 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func( } s.ctx = peer.NewContext(s.ctx, pr) // Attach the received metadata to the context. 
- if len(state.data.mdata) > 0 { - s.ctx = metadata.NewIncomingContext(s.ctx, state.data.mdata) - } - if state.data.statsTags != nil { - s.ctx = stats.SetIncomingTags(s.ctx, state.data.statsTags) - } - if state.data.statsTrace != nil { - s.ctx = stats.SetIncomingTrace(s.ctx, state.data.statsTrace) - } - if t.inTapHandle != nil { - var err error - info := &tap.Info{ - FullMethodName: state.data.method, + if len(mdata) > 0 { + s.ctx = metadata.NewIncomingContext(s.ctx, mdata) + if statsTags := mdata["grpc-tags-bin"]; len(statsTags) > 0 { + s.ctx = stats.SetIncomingTags(s.ctx, []byte(statsTags[len(statsTags)-1])) } - s.ctx, err = t.inTapHandle(s.ctx, info) - if err != nil { - if logger.V(logLevel) { - logger.Warningf("transport: http2Server.operateHeaders got an error from InTapHandle: %v", err) - } - t.controlBuf.put(&cleanupStream{ - streamID: s.id, - rst: true, - rstCode: http2.ErrCodeRefusedStream, - onWrite: func() {}, - }) - s.cancel() - return false + if statsTrace := mdata["grpc-trace-bin"]; len(statsTrace) > 0 { + s.ctx = stats.SetIncomingTrace(s.ctx, []byte(statsTrace[len(statsTrace)-1])) } } t.mu.Lock() @@ -403,10 +508,10 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func( return true } t.maxStreamID = streamID - if state.data.httpMethod != http.MethodPost { + if httpMethod != http.MethodPost { t.mu.Unlock() if logger.V(logLevel) { - logger.Warningf("transport: http2Server.operateHeaders parsed a :method field: %v which should be POST", state.data.httpMethod) + logger.Infof("transport: http2Server.operateHeaders parsed a :method field: %v which should be POST", httpMethod) } t.controlBuf.put(&cleanupStream{ streamID: streamID, @@ -417,6 +522,26 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func( s.cancel() return false } + if t.inTapHandle != nil { + var err error + if s.ctx, err = t.inTapHandle(s.ctx, &tap.Info{FullMethodName: s.method}); err != nil { + t.mu.Unlock() + if logger.V(logLevel) { 
+ logger.Infof("transport: http2Server.operateHeaders got an error from InTapHandle: %v", err) + } + stat, ok := status.FromError(err) + if !ok { + stat = status.New(codes.PermissionDenied, err.Error()) + } + t.controlBuf.put(&earlyAbortStream{ + httpStatus: 200, + streamID: s.id, + contentSubtype: s.contentSubtype, + status: stat, + }) + return false + } + } t.activeStreams[streamID] = s if len(t.activeStreams) == 1 { t.idle = time.Time{} @@ -438,7 +563,7 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func( LocalAddr: t.localAddr, Compression: s.recvCompress, WireLength: int(frame.Header().Length), - Header: metadata.MD(state.data.mdata).Copy(), + Header: metadata.MD(mdata).Copy(), } t.stats.HandleRPC(s.ctx, inHeader) } @@ -650,7 +775,7 @@ func (t *http2Server) handleData(f *http2.DataFrame) { s.write(recvMsg{buffer: buffer}) } } - if f.Header().Flags.Has(http2.FlagDataEndStream) { + if f.StreamEnded() { // Received the end of stream from the client. s.compareAndSwapState(streamActive, streamReadDone) s.write(recvMsg{err: io.EOF}) @@ -1005,12 +1130,12 @@ func (t *http2Server) keepalive() { if val <= 0 { // The connection has been idle for a duration of keepalive.MaxConnectionIdle or more. // Gracefully close the connection. - t.drain(http2.ErrCodeNo, []byte{}) + t.Drain() return } idleTimer.Reset(val) case <-ageTimer.C: - t.drain(http2.ErrCodeNo, []byte{}) + t.Drain() ageTimer.Reset(t.kp.MaxConnectionAgeGrace) select { case <-ageTimer.C: @@ -1064,11 +1189,11 @@ func (t *http2Server) keepalive() { // Close starts shutting down the http2Server transport. // TODO(zhaoq): Now the destruction is not blocked on any pending streams. This // could cause some resource issue. Revisit this later. 
-func (t *http2Server) Close() error { +func (t *http2Server) Close() { t.mu.Lock() if t.state == closing { t.mu.Unlock() - return errors.New("transport: Close() was already called") + return } t.state = closing streams := t.activeStreams @@ -1076,7 +1201,9 @@ func (t *http2Server) Close() error { t.mu.Unlock() t.controlBuf.finish() close(t.done) - err := t.conn.Close() + if err := t.conn.Close(); err != nil && logger.V(logLevel) { + logger.Infof("transport: error closing conn during Close: %v", err) + } if channelz.IsOn() { channelz.RemoveEntry(t.channelzID) } @@ -1088,7 +1215,6 @@ func (t *http2Server) Close() error { connEnd := &stats.ConnEnd{} t.stats.HandleConn(t.ctx, connEnd) } - return err } // deleteStream deletes the stream s from transport's active streams. @@ -1153,17 +1279,13 @@ func (t *http2Server) RemoteAddr() net.Addr { } func (t *http2Server) Drain() { - t.drain(http2.ErrCodeNo, []byte{}) -} - -func (t *http2Server) drain(code http2.ErrCode, debugData []byte) { t.mu.Lock() defer t.mu.Unlock() if t.drainChan != nil { return } t.drainChan = make(chan struct{}) - t.controlBuf.put(&goAway{code: code, debugData: debugData, headsUp: true}) + t.controlBuf.put(&goAway{code: http2.ErrCodeNo, debugData: []byte{}, headsUp: true}) } var goAwayPing = &ping{data: [8]byte{1, 6, 1, 8, 0, 3, 3, 9}} @@ -1281,3 +1403,18 @@ func getJitter(v time.Duration) time.Duration { j := grpcrand.Int63n(2*r) - r return time.Duration(j) } + +type connectionKey struct{} + +// GetConnection gets the connection from the context. +func GetConnection(ctx context.Context) net.Conn { + conn, _ := ctx.Value(connectionKey{}).(net.Conn) + return conn +} + +// SetConnection adds the connection to the context to be able to get +// information about the destination ip and port for an incoming RPC. This also +// allows any unary or streaming interceptors to see the connection. 
+func setConnection(ctx context.Context, conn net.Conn) context.Context { + return context.WithValue(ctx, connectionKey{}, conn) +} diff --git a/internal/transport/http_util.go b/internal/transport/http_util.go index c7dee140cf1..d8247bcdf69 100644 --- a/internal/transport/http_util.go +++ b/internal/transport/http_util.go @@ -39,7 +39,6 @@ import ( spb "google.golang.org/genproto/googleapis/rpc/status" "google.golang.org/grpc/codes" "google.golang.org/grpc/grpclog" - "google.golang.org/grpc/internal/grpcutil" "google.golang.org/grpc/status" ) @@ -96,53 +95,6 @@ var ( logger = grpclog.Component("transport") ) -type parsedHeaderData struct { - encoding string - // statusGen caches the stream status received from the trailer the server - // sent. Client side only. Do not access directly. After all trailers are - // parsed, use the status method to retrieve the status. - statusGen *status.Status - // rawStatusCode and rawStatusMsg are set from the raw trailer fields and are not - // intended for direct access outside of parsing. - rawStatusCode *int - rawStatusMsg string - httpStatus *int - // Server side only fields. - timeoutSet bool - timeout time.Duration - method string - httpMethod string - // key-value metadata map from the peer. - mdata map[string][]string - statsTags []byte - statsTrace []byte - contentSubtype string - - // isGRPC field indicates whether the peer is speaking gRPC (otherwise HTTP). - // - // We are in gRPC mode (peer speaking gRPC) if: - // * We are client side and have already received a HEADER frame that indicates gRPC peer. - // * The header contains valid a content-type, i.e. a string starts with "application/grpc" - // And we should handle error specific to gRPC. - // - // Otherwise (i.e. a content-type string starts without "application/grpc", or does not exist), we - // are in HTTP fallback mode, and should handle error specific to HTTP. 
- isGRPC bool - grpcErr error - httpErr error - contentTypeErr string -} - -// decodeState configures decoding criteria and records the decoded data. -type decodeState struct { - // whether decoding on server side or not - serverSide bool - - // Records the states during HPACK decoding. It will be filled with info parsed from HTTP HEADERS - // frame once decodeHeader function has been invoked and returned. - data parsedHeaderData -} - // isReservedHeader checks whether hdr belongs to HTTP2 headers // reserved by gRPC protocol. Any other headers are classified as the // user-specified metadata. @@ -180,14 +132,6 @@ func isWhitelistedHeader(hdr string) bool { } } -func (d *decodeState) status() *status.Status { - if d.data.statusGen == nil { - // No status-details were provided; generate status using code/msg. - d.data.statusGen = status.New(codes.Code(int32(*(d.data.rawStatusCode))), d.data.rawStatusMsg) - } - return d.data.statusGen -} - const binHdrSuffix = "-bin" func encodeBinHeader(v []byte) string { @@ -217,168 +161,16 @@ func decodeMetadataHeader(k, v string) (string, error) { return v, nil } -func (d *decodeState) decodeHeader(frame *http2.MetaHeadersFrame) (http2.ErrCode, error) { - // frame.Truncated is set to true when framer detects that the current header - // list size hits MaxHeaderListSize limit. - if frame.Truncated { - return http2.ErrCodeFrameSize, status.Error(codes.Internal, "peer header list size exceeded limit") - } - - for _, hf := range frame.Fields { - d.processHeaderField(hf) - } - - if d.data.isGRPC { - if d.data.grpcErr != nil { - return http2.ErrCodeProtocol, d.data.grpcErr - } - if d.serverSide { - return http2.ErrCodeNo, nil - } - if d.data.rawStatusCode == nil && d.data.statusGen == nil { - // gRPC status doesn't exist. - // Set rawStatusCode to be unknown and return nil error. - // So that, if the stream has ended this Unknown status - // will be propagated to the user. - // Otherwise, it will be ignored. 
In which case, status from - // a later trailer, that has StreamEnded flag set, is propagated. - code := int(codes.Unknown) - d.data.rawStatusCode = &code - } - return http2.ErrCodeNo, nil - } - - // HTTP fallback mode - if d.data.httpErr != nil { - return http2.ErrCodeProtocol, d.data.httpErr - } - - var ( - code = codes.Internal // when header does not include HTTP status, return INTERNAL - ok bool - ) - - if d.data.httpStatus != nil { - code, ok = HTTPStatusConvTab[*(d.data.httpStatus)] - if !ok { - code = codes.Unknown - } - } - - return http2.ErrCodeProtocol, status.Error(code, d.constructHTTPErrMsg()) -} - -// constructErrMsg constructs error message to be returned in HTTP fallback mode. -// Format: HTTP status code and its corresponding message + content-type error message. -func (d *decodeState) constructHTTPErrMsg() string { - var errMsgs []string - - if d.data.httpStatus == nil { - errMsgs = append(errMsgs, "malformed header: missing HTTP status") - } else { - errMsgs = append(errMsgs, fmt.Sprintf("%s: HTTP status code %d", http.StatusText(*(d.data.httpStatus)), *d.data.httpStatus)) - } - - if d.data.contentTypeErr == "" { - errMsgs = append(errMsgs, "transport: missing content-type field") - } else { - errMsgs = append(errMsgs, d.data.contentTypeErr) - } - - return strings.Join(errMsgs, "; ") -} - -func (d *decodeState) addMetadata(k, v string) { - if d.data.mdata == nil { - d.data.mdata = make(map[string][]string) +func decodeGRPCStatusDetails(rawDetails string) (*status.Status, error) { + v, err := decodeBinHeader(rawDetails) + if err != nil { + return nil, err } - d.data.mdata[k] = append(d.data.mdata[k], v) -} - -func (d *decodeState) processHeaderField(f hpack.HeaderField) { - switch f.Name { - case "content-type": - contentSubtype, validContentType := grpcutil.ContentSubtype(f.Value) - if !validContentType { - d.data.contentTypeErr = fmt.Sprintf("transport: received the unexpected content-type %q", f.Value) - return - } - d.data.contentSubtype = 
contentSubtype - // TODO: do we want to propagate the whole content-type in the metadata, - // or come up with a way to just propagate the content-subtype if it was set? - // ie {"content-type": "application/grpc+proto"} or {"content-subtype": "proto"} - // in the metadata? - d.addMetadata(f.Name, f.Value) - d.data.isGRPC = true - case "grpc-encoding": - d.data.encoding = f.Value - case "grpc-status": - code, err := strconv.Atoi(f.Value) - if err != nil { - d.data.grpcErr = status.Errorf(codes.Internal, "transport: malformed grpc-status: %v", err) - return - } - d.data.rawStatusCode = &code - case "grpc-message": - d.data.rawStatusMsg = decodeGrpcMessage(f.Value) - case "grpc-status-details-bin": - v, err := decodeBinHeader(f.Value) - if err != nil { - d.data.grpcErr = status.Errorf(codes.Internal, "transport: malformed grpc-status-details-bin: %v", err) - return - } - s := &spb.Status{} - if err := proto.Unmarshal(v, s); err != nil { - d.data.grpcErr = status.Errorf(codes.Internal, "transport: malformed grpc-status-details-bin: %v", err) - return - } - d.data.statusGen = status.FromProto(s) - case "grpc-timeout": - d.data.timeoutSet = true - var err error - if d.data.timeout, err = decodeTimeout(f.Value); err != nil { - d.data.grpcErr = status.Errorf(codes.Internal, "transport: malformed time-out: %v", err) - } - case ":path": - d.data.method = f.Value - case ":status": - code, err := strconv.Atoi(f.Value) - if err != nil { - d.data.httpErr = status.Errorf(codes.Internal, "transport: malformed http-status: %v", err) - return - } - d.data.httpStatus = &code - case "grpc-tags-bin": - v, err := decodeBinHeader(f.Value) - if err != nil { - d.data.grpcErr = status.Errorf(codes.Internal, "transport: malformed grpc-tags-bin: %v", err) - return - } - d.data.statsTags = v - d.addMetadata(f.Name, string(v)) - case "grpc-trace-bin": - v, err := decodeBinHeader(f.Value) - if err != nil { - d.data.grpcErr = status.Errorf(codes.Internal, "transport: malformed grpc-trace-bin: 
%v", err) - return - } - d.data.statsTrace = v - d.addMetadata(f.Name, string(v)) - case ":method": - d.data.httpMethod = f.Value - default: - if isReservedHeader(f.Name) && !isWhitelistedHeader(f.Name) { - break - } - v, err := decodeMetadataHeader(f.Name, f.Value) - if err != nil { - if logger.V(logLevel) { - logger.Errorf("Failed to decode metadata header (%q, %q): %v", f.Name, f.Value, err) - } - return - } - d.addMetadata(f.Name, v) + st := &spb.Status{} + if err = proto.Unmarshal(v, st); err != nil { + return nil, err } + return status.FromProto(st), nil } type timeoutUnit uint8 diff --git a/internal/transport/http_util_test.go b/internal/transport/http_util_test.go index 2205050acea..bbd53180471 100644 --- a/internal/transport/http_util_test.go +++ b/internal/transport/http_util_test.go @@ -23,9 +23,6 @@ import ( "reflect" "testing" "time" - - "golang.org/x/net/http2" - "golang.org/x/net/http2/hpack" ) func (s) TestTimeoutDecode(t *testing.T) { @@ -189,68 +186,6 @@ func (s) TestDecodeMetadataHeader(t *testing.T) { } } -func (s) TestDecodeHeaderH2ErrCode(t *testing.T) { - for _, test := range []struct { - name string - // input - metaHeaderFrame *http2.MetaHeadersFrame - serverSide bool - // output - wantCode http2.ErrCode - }{ - { - name: "valid header", - metaHeaderFrame: &http2.MetaHeadersFrame{Fields: []hpack.HeaderField{ - {Name: "content-type", Value: "application/grpc"}, - }}, - wantCode: http2.ErrCodeNo, - }, - { - name: "valid header serverSide", - metaHeaderFrame: &http2.MetaHeadersFrame{Fields: []hpack.HeaderField{ - {Name: "content-type", Value: "application/grpc"}, - }}, - serverSide: true, - wantCode: http2.ErrCodeNo, - }, - { - name: "invalid grpc status header field", - metaHeaderFrame: &http2.MetaHeadersFrame{Fields: []hpack.HeaderField{ - {Name: "content-type", Value: "application/grpc"}, - {Name: "grpc-status", Value: "xxxx"}, - }}, - wantCode: http2.ErrCodeProtocol, - }, - { - name: "invalid http content type", - metaHeaderFrame: 
&http2.MetaHeadersFrame{Fields: []hpack.HeaderField{ - {Name: "content-type", Value: "application/json"}, - }}, - wantCode: http2.ErrCodeProtocol, - }, - { - name: "http fallback and invalid http status", - metaHeaderFrame: &http2.MetaHeadersFrame{Fields: []hpack.HeaderField{ - // No content type provided then fallback into handling http error. - {Name: ":status", Value: "xxxx"}, - }}, - wantCode: http2.ErrCodeProtocol, - }, - { - name: "http2 frame size exceeds", - metaHeaderFrame: &http2.MetaHeadersFrame{Fields: nil, Truncated: true}, - wantCode: http2.ErrCodeFrameSize, - }, - } { - t.Run(test.name, func(t *testing.T) { - state := &decodeState{serverSide: test.serverSide} - if h2code, _ := state.decodeHeader(test.metaHeaderFrame); h2code != test.wantCode { - t.Fatalf("decodeState.decodeHeader(%v) = %v, want %v", test.metaHeaderFrame, h2code, test.wantCode) - } - }) - } -} - func (s) TestParseDialTarget(t *testing.T) { for _, test := range []struct { target, wantNet, wantAddr string diff --git a/internal/transport/keepalive_test.go b/internal/transport/keepalive_test.go index c8f177fecf1..c4021925f32 100644 --- a/internal/transport/keepalive_test.go +++ b/internal/transport/keepalive_test.go @@ -24,6 +24,7 @@ package transport import ( "context" + "fmt" "io" "net" "testing" @@ -47,7 +48,7 @@ func (s) TestMaxConnectionIdle(t *testing.T) { } server, client, cancel := setUpWithOptions(t, 0, serverConfig, suspended, ConnectOptions{}) defer func() { - client.Close() + client.Close(fmt.Errorf("closed manually by test")) server.stop() cancel() }() @@ -68,7 +69,7 @@ func (s) TestMaxConnectionIdle(t *testing.T) { if !timeout.Stop() { <-timeout.C } - if reason := client.GetGoAwayReason(); reason != GoAwayNoReason { + if reason, _ := client.GetGoAwayReason(); reason != GoAwayNoReason { t.Fatalf("GoAwayReason is %v, want %v", reason, GoAwayNoReason) } case <-timeout.C: @@ -86,7 +87,7 @@ func (s) TestMaxConnectionIdleBusyClient(t *testing.T) { } server, client, cancel := 
setUpWithOptions(t, 0, serverConfig, suspended, ConnectOptions{}) defer func() { - client.Close() + client.Close(fmt.Errorf("closed manually by test")) server.stop() cancel() }() @@ -122,7 +123,7 @@ func (s) TestMaxConnectionAge(t *testing.T) { } server, client, cancel := setUpWithOptions(t, 0, serverConfig, suspended, ConnectOptions{}) defer func() { - client.Close() + client.Close(fmt.Errorf("closed manually by test")) server.stop() cancel() }() @@ -142,7 +143,7 @@ func (s) TestMaxConnectionAge(t *testing.T) { if !timeout.Stop() { <-timeout.C } - if reason := client.GetGoAwayReason(); reason != GoAwayNoReason { + if reason, _ := client.GetGoAwayReason(); reason != GoAwayNoReason { t.Fatalf("GoAwayReason is %v, want %v", reason, GoAwayNoReason) } case <-timeout.C: @@ -169,7 +170,7 @@ func (s) TestKeepaliveServerClosesUnresponsiveClient(t *testing.T) { } server, client, cancel := setUpWithOptions(t, 0, serverConfig, suspended, ConnectOptions{}) defer func() { - client.Close() + client.Close(fmt.Errorf("closed manually by test")) server.stop() cancel() }() @@ -192,7 +193,7 @@ func (s) TestKeepaliveServerClosesUnresponsiveClient(t *testing.T) { // We read from the net.Conn till we get an error, which is expected when // the server closes the connection as part of the keepalive logic. 
- errCh := make(chan error) + errCh := make(chan error, 1) go func() { b := make([]byte, 24) for { @@ -228,7 +229,7 @@ func (s) TestKeepaliveServerWithResponsiveClient(t *testing.T) { } server, client, cancel := setUpWithOptions(t, 0, serverConfig, suspended, ConnectOptions{}) defer func() { - client.Close() + client.Close(fmt.Errorf("closed manually by test")) server.stop() cancel() }() @@ -257,7 +258,7 @@ func (s) TestKeepaliveClientClosesUnresponsiveServer(t *testing.T) { PermitWithoutStream: true, }}, connCh) defer cancel() - defer client.Close() + defer client.Close(fmt.Errorf("closed manually by test")) conn, ok := <-connCh if !ok { @@ -288,7 +289,7 @@ func (s) TestKeepaliveClientOpenWithUnresponsiveServer(t *testing.T) { Timeout: 1 * time.Second, }}, connCh) defer cancel() - defer client.Close() + defer client.Close(fmt.Errorf("closed manually by test")) conn, ok := <-connCh if !ok { @@ -317,7 +318,7 @@ func (s) TestKeepaliveClientClosesWithActiveStreams(t *testing.T) { Timeout: 1 * time.Second, }}, connCh) defer cancel() - defer client.Close() + defer client.Close(fmt.Errorf("closed manually by test")) conn, ok := <-connCh if !ok { @@ -345,14 +346,21 @@ func (s) TestKeepaliveClientClosesWithActiveStreams(t *testing.T) { // responds to keepalive pings, and makes sure than a client transport stays // healthy without any active streams. 
func (s) TestKeepaliveClientStaysHealthyWithResponsiveServer(t *testing.T) { - server, client, cancel := setUpWithOptions(t, 0, &ServerConfig{}, normal, ConnectOptions{ - KeepaliveParams: keepalive.ClientParameters{ - Time: 1 * time.Second, - Timeout: 1 * time.Second, - PermitWithoutStream: true, - }}) + server, client, cancel := setUpWithOptions(t, 0, + &ServerConfig{ + KeepalivePolicy: keepalive.EnforcementPolicy{ + PermitWithoutStream: true, + }, + }, + normal, + ConnectOptions{ + KeepaliveParams: keepalive.ClientParameters{ + Time: 1 * time.Second, + Timeout: 1 * time.Second, + PermitWithoutStream: true, + }}) defer func() { - client.Close() + client.Close(fmt.Errorf("closed manually by test")) server.stop() cancel() }() @@ -391,7 +399,7 @@ func (s) TestKeepaliveClientFrequency(t *testing.T) { } server, client, cancel := setUpWithOptions(t, 0, serverConfig, normal, clientOptions) defer func() { - client.Close() + client.Close(fmt.Errorf("closed manually by test")) server.stop() cancel() }() @@ -402,7 +410,7 @@ func (s) TestKeepaliveClientFrequency(t *testing.T) { if !timeout.Stop() { <-timeout.C } - if reason := client.GetGoAwayReason(); reason != GoAwayTooManyPings { + if reason, _ := client.GetGoAwayReason(); reason != GoAwayTooManyPings { t.Fatalf("GoAwayReason is %v, want %v", reason, GoAwayTooManyPings) } case <-timeout.C: @@ -436,7 +444,7 @@ func (s) TestKeepaliveServerEnforcementWithAbusiveClientNoRPC(t *testing.T) { } server, client, cancel := setUpWithOptions(t, 0, serverConfig, normal, clientOptions) defer func() { - client.Close() + client.Close(fmt.Errorf("closed manually by test")) server.stop() cancel() }() @@ -447,7 +455,7 @@ func (s) TestKeepaliveServerEnforcementWithAbusiveClientNoRPC(t *testing.T) { if !timeout.Stop() { <-timeout.C } - if reason := client.GetGoAwayReason(); reason != GoAwayTooManyPings { + if reason, _ := client.GetGoAwayReason(); reason != GoAwayTooManyPings { t.Fatalf("GoAwayReason is %v, want %v", reason, 
GoAwayTooManyPings) } case <-timeout.C: @@ -480,7 +488,7 @@ func (s) TestKeepaliveServerEnforcementWithAbusiveClientWithRPC(t *testing.T) { } server, client, cancel := setUpWithOptions(t, 0, serverConfig, suspended, clientOptions) defer func() { - client.Close() + client.Close(fmt.Errorf("closed manually by test")) server.stop() cancel() }() @@ -497,7 +505,7 @@ func (s) TestKeepaliveServerEnforcementWithAbusiveClientWithRPC(t *testing.T) { if !timeout.Stop() { <-timeout.C } - if reason := client.GetGoAwayReason(); reason != GoAwayTooManyPings { + if reason, _ := client.GetGoAwayReason(); reason != GoAwayTooManyPings { t.Fatalf("GoAwayReason is %v, want %v", reason, GoAwayTooManyPings) } case <-timeout.C: @@ -530,7 +538,7 @@ func (s) TestKeepaliveServerEnforcementWithObeyingClientNoRPC(t *testing.T) { } server, client, cancel := setUpWithOptions(t, 0, serverConfig, normal, clientOptions) defer func() { - client.Close() + client.Close(fmt.Errorf("closed manually by test")) server.stop() cancel() }() @@ -564,7 +572,7 @@ func (s) TestKeepaliveServerEnforcementWithObeyingClientWithRPC(t *testing.T) { } server, client, cancel := setUpWithOptions(t, 0, serverConfig, suspended, clientOptions) defer func() { - client.Close() + client.Close(fmt.Errorf("closed manually by test")) server.stop() cancel() }() @@ -604,7 +612,7 @@ func (s) TestKeepaliveServerEnforcementWithDormantKeepaliveOnClient(t *testing.T } server, client, cancel := setUpWithOptions(t, 0, serverConfig, normal, clientOptions) defer func() { - client.Close() + client.Close(fmt.Errorf("closed manually by test")) server.stop() cancel() }() @@ -658,7 +666,7 @@ func (s) TestTCPUserTimeout(t *testing.T) { }, ) defer func() { - client.Close() + client.Close(fmt.Errorf("closed manually by test")) server.stop() cancel() }() diff --git a/internal/transport/networktype/networktype.go b/internal/transport/networktype/networktype.go index 96967428b51..7bb53cff101 100644 --- a/internal/transport/networktype/networktype.go 
+++ b/internal/transport/networktype/networktype.go @@ -17,7 +17,7 @@ */ // Package networktype declares the network type to be used in the default -// dailer. Attribute of a resolver.Address. +// dialer. Attribute of a resolver.Address. package networktype import ( diff --git a/internal/transport/proxy_test.go b/internal/transport/proxy_test.go index a2f1aa43854..404354a19db 100644 --- a/internal/transport/proxy_test.go +++ b/internal/transport/proxy_test.go @@ -1,3 +1,4 @@ +//go:build !race // +build !race /* @@ -119,7 +120,7 @@ func testHTTPConnect(t *testing.T, proxyURLModify func(*url.URL) *url.URL, proxy msg := []byte{4, 3, 5, 2} recvBuf := make([]byte, len(msg)) - done := make(chan error) + done := make(chan error, 1) go func() { in, err := blis.Accept() if err != nil { diff --git a/internal/transport/transport.go b/internal/transport/transport.go index 5cf7c5f80fe..d3bf65b2bdf 100644 --- a/internal/transport/transport.go +++ b/internal/transport/transport.go @@ -30,6 +30,7 @@ import ( "net" "sync" "sync/atomic" + "time" "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials" @@ -518,7 +519,8 @@ const ( // ServerConfig consists of all the configurations to establish a server transport. type ServerConfig struct { MaxStreams uint32 - AuthInfo credentials.AuthInfo + ConnectionTimeout time.Duration + Credentials credentials.TransportCredentials InTapHandle tap.ServerInHandle StatsHandler stats.Handler KeepaliveParams keepalive.ServerParameters @@ -532,12 +534,6 @@ type ServerConfig struct { HeaderTableSize *uint32 } -// NewServerTransport creates a ServerTransport with conn or non-nil error -// if it fails. -func NewServerTransport(protocol string, conn net.Conn, config *ServerConfig) (ServerTransport, error) { - return newHTTP2Server(conn, config) -} - // ConnectOptions covers all relevant options for communicating with the server. type ConnectOptions struct { // UserAgent is the application user agent. 
@@ -622,7 +618,7 @@ type ClientTransport interface { // Close tears down this transport. Once it returns, the transport // should not be accessed any more. The caller must make sure this // is called only once. - Close() error + Close(err error) // GracefulClose starts to tear down the transport: the transport will stop // accepting new RPCs and NewStream will return error. Once all streams are @@ -656,8 +652,9 @@ type ClientTransport interface { // HTTP/2). GoAway() <-chan struct{} - // GetGoAwayReason returns the reason why GoAway frame was received. - GetGoAwayReason() GoAwayReason + // GetGoAwayReason returns the reason why GoAway frame was received, along + // with a human readable string with debug info. + GetGoAwayReason() (GoAwayReason, string) // RemoteAddr returns the remote network address. RemoteAddr() net.Addr @@ -693,7 +690,7 @@ type ServerTransport interface { // Close tears down the transport. Once it is called, the transport // should not be accessed any more. All the pending streams and their // handlers will be terminated asynchronously. - Close() error + Close() // RemoteAddr returns the remote network address. 
RemoteAddr() net.Addr diff --git a/internal/transport/transport_test.go b/internal/transport/transport_test.go index 1d8d3ed355d..4e561a73c4c 100644 --- a/internal/transport/transport_test.go +++ b/internal/transport/transport_test.go @@ -323,7 +323,7 @@ func (s *server) start(t *testing.T, port int, serverConfig *ServerConfig, ht hT if err != nil { return } - transport, err := NewServerTransport("http2", conn, serverConfig) + transport, err := NewServerTransport(conn, serverConfig) if err != nil { return } @@ -481,7 +481,7 @@ func (s) TestInflightStreamClosing(t *testing.T) { server, client, cancel := setUpWithOptions(t, 0, serverConfig, suspended, ConnectOptions{}) defer cancel() defer server.stop() - defer client.Close() + defer client.Close(fmt.Errorf("closed manually by test")) ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() @@ -550,7 +550,7 @@ func (s) TestClientSendAndReceive(t *testing.T) { if recvErr != io.EOF { t.Fatalf("Error: %v; want ", recvErr) } - ct.Close() + ct.Close(fmt.Errorf("closed manually by test")) server.stop() } @@ -560,7 +560,7 @@ func (s) TestClientErrorNotify(t *testing.T) { go server.stop() // ct.reader should detect the error and activate ct.Error(). 
<-ct.Error() - ct.Close() + ct.Close(fmt.Errorf("closed manually by test")) } func performOneRPC(ct ClientTransport) { @@ -597,7 +597,7 @@ func (s) TestClientMix(t *testing.T) { }(s) go func(ct ClientTransport) { <-ct.Error() - ct.Close() + ct.Close(fmt.Errorf("closed manually by test")) }(ct) for i := 0; i < 1000; i++ { time.Sleep(10 * time.Millisecond) @@ -636,7 +636,7 @@ func (s) TestLargeMessage(t *testing.T) { }() } wg.Wait() - ct.Close() + ct.Close(fmt.Errorf("closed manually by test")) server.stop() } @@ -653,7 +653,7 @@ func (s) TestLargeMessageWithDelayRead(t *testing.T) { server, ct, cancel := setUpWithOptions(t, 0, sc, delayRead, co) defer cancel() defer server.stop() - defer ct.Close() + defer ct.Close(fmt.Errorf("closed manually by test")) server.mu.Lock() ready := server.ready server.mu.Unlock() @@ -780,7 +780,7 @@ func (s) TestGracefulClose(t *testing.T) { go func() { defer wg.Done() str, err := ct.NewStream(ctx, &CallHdr{}) - if err == ErrConnClosing { + if err != nil && err.(*NewStreamError).Err == ErrConnClosing { return } else if err != nil { t.Errorf("_.NewStream(_, _) = _, %v, want _, %v", err, ErrConnClosing) @@ -831,7 +831,7 @@ func (s) TestLargeMessageSuspension(t *testing.T) { if _, err := s.Read(make([]byte, 8)); err.Error() != expectedErr.Error() { t.Fatalf("Read got %v of type %T, want %v", err, err, expectedErr) } - ct.Close() + ct.Close(fmt.Errorf("closed manually by test")) server.stop() } @@ -841,7 +841,7 @@ func (s) TestMaxStreams(t *testing.T) { } server, ct, cancel := setUpWithOptions(t, 0, serverConfig, suspended, ConnectOptions{}) defer cancel() - defer ct.Close() + defer ct.Close(fmt.Errorf("closed manually by test")) defer server.stop() callHdr := &CallHdr{ Host: "localhost", @@ -901,7 +901,7 @@ func (s) TestMaxStreams(t *testing.T) { // Close the first stream created so that the new stream can finally be created. 
ct.CloseStream(s, nil) <-done - ct.Close() + ct.Close(fmt.Errorf("closed manually by test")) <-ct.writerDone if ct.maxConcurrentStreams != 1 { t.Fatalf("ct.maxConcurrentStreams: %d, want 1", ct.maxConcurrentStreams) @@ -960,7 +960,7 @@ func (s) TestServerContextCanceledOnClosedConnection(t *testing.T) { sc.mu.Unlock() break } - ct.Close() + ct.Close(fmt.Errorf("closed manually by test")) select { case <-ss.Context().Done(): if ss.Context().Err() != context.Canceled { @@ -980,7 +980,7 @@ func (s) TestClientConnDecoupledFromApplicationRead(t *testing.T) { server, client, cancel := setUpWithOptions(t, 0, &ServerConfig{}, notifyCall, connectOptions) defer cancel() defer server.stop() - defer client.Close() + defer client.Close(fmt.Errorf("closed manually by test")) waitWhileTrue(t, func() (bool, error) { server.mu.Lock() @@ -1069,7 +1069,7 @@ func (s) TestServerConnDecoupledFromApplicationRead(t *testing.T) { server, client, cancel := setUpWithOptions(t, 0, serverConfig, suspended, ConnectOptions{}) defer cancel() defer server.stop() - defer client.Close() + defer client.Close(fmt.Errorf("closed manually by test")) waitWhileTrue(t, func() (bool, error) { server.mu.Lock() defer server.mu.Unlock() @@ -1302,7 +1302,7 @@ func (s) TestClientWithMisbehavedServer(t *testing.T) { if err != nil { t.Fatalf("Error while creating client transport: %v", err) } - defer ct.Close() + defer ct.Close(fmt.Errorf("closed manually by test")) str, err := ct.NewStream(connectCtx, &CallHdr{}) if err != nil { t.Fatalf("Error while creating stream: %v", err) @@ -1345,7 +1345,7 @@ func (s) TestEncodingRequiredStatus(t *testing.T) { if !testutils.StatusErrEqual(s.Status().Err(), encodingTestStatus.Err()) { t.Fatalf("stream with status %v, want %v", s.Status(), encodingTestStatus) } - ct.Close() + ct.Close(fmt.Errorf("closed manually by test")) server.stop() } @@ -1367,7 +1367,7 @@ func (s) TestInvalidHeaderField(t *testing.T) { if se, ok := status.FromError(err); !ok || se.Code() != 
codes.Internal || !strings.Contains(err.Error(), expectedInvalidHeaderField) { t.Fatalf("Read got error %v, want error with code %s and contains %q", err, codes.Internal, expectedInvalidHeaderField) } - ct.Close() + ct.Close(fmt.Errorf("closed manually by test")) server.stop() } @@ -1375,7 +1375,7 @@ func (s) TestHeaderChanClosedAfterReceivingAnInvalidHeader(t *testing.T) { server, ct, cancel := setUp(t, 0, math.MaxUint32, invalidHeaderField) defer cancel() defer server.stop() - defer ct.Close() + defer ct.Close(fmt.Errorf("closed manually by test")) ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() s, err := ct.NewStream(ctx, &CallHdr{Host: "localhost", Method: "foo"}) @@ -1481,7 +1481,7 @@ func testFlowControlAccountCheck(t *testing.T, msgSize int, wc windowSizeConfig) server, client, cancel := setUpWithOptions(t, 0, sc, pingpong, co) defer cancel() defer server.stop() - defer client.Close() + defer client.Close(fmt.Errorf("closed manually by test")) waitWhileTrue(t, func() (bool, error) { server.mu.Lock() defer server.mu.Unlock() @@ -1563,7 +1563,7 @@ func testFlowControlAccountCheck(t *testing.T, msgSize int, wc windowSizeConfig) } // Close down both server and client so that their internals can be read without data // races. - client.Close() + client.Close(fmt.Errorf("closed manually by test")) st.Close() <-st.readerDone <-st.writerDone @@ -1663,81 +1663,291 @@ func (s) TestReadGivesSameErrorAfterAnyErrorOccurs(t *testing.T) { } } -// If the client sends an HTTP/2 request with a :method header with a value other than POST, as specified in -// the gRPC over HTTP/2 specification, the server should close the stream. -func (s) TestServerWithClientSendingWrongMethod(t *testing.T) { - server := setUpServerOnly(t, 0, &ServerConfig{}, suspended) - defer server.stop() - // Create a client directly to not couple what you can send to API of http2_client.go. 
- mconn, err := net.Dial("tcp", server.lis.Addr().String()) - if err != nil { - t.Fatalf("Client failed to dial: %v", err) +// TestHeadersCausingStreamError tests headers that should cause a stream protocol +// error, which would end up with a RST_STREAM being sent to the client and also +// the server closing the stream. +func (s) TestHeadersCausingStreamError(t *testing.T) { + tests := []struct { + name string + headers []struct { + name string + values []string + } + }{ + // If the client sends an HTTP/2 request with a :method header with a + // value other than POST, as specified in the gRPC over HTTP/2 + // specification, the server should close the stream. + { + name: "Client Sending Wrong Method", + headers: []struct { + name string + values []string + }{ + {name: ":method", values: []string{"PUT"}}, + {name: ":path", values: []string{"foo"}}, + {name: ":authority", values: []string{"localhost"}}, + {name: "content-type", values: []string{"application/grpc"}}, + }, + }, + // "Transports must consider requests containing the Connection header + // as malformed" - A41 Malformed requests map to a stream error of type + // PROTOCOL_ERROR. + { + name: "Connection header present", + headers: []struct { + name string + values []string + }{ + {name: ":method", values: []string{"POST"}}, + {name: ":path", values: []string{"foo"}}, + {name: ":authority", values: []string{"localhost"}}, + {name: "content-type", values: []string{"application/grpc"}}, + {name: "connection", values: []string{"not-supported"}}, + }, + }, + // multiple :authority or multiple Host headers would make the eventual + // :authority ambiguous as per A41. Since these headers won't have a + // content-type that corresponds to a grpc-client, the server should + // simply write a RST_STREAM to the wire. + { + // Note: multiple authority headers are handled by the framer + // itself, which will cause a stream error. 
Thus, it will never get + // to operateHeaders with the check in operateHeaders for stream + // error, but the server transport will still send a stream error. + name: "Multiple authority headers", + headers: []struct { + name string + values []string + }{ + {name: ":method", values: []string{"POST"}}, + {name: ":path", values: []string{"foo"}}, + {name: ":authority", values: []string{"localhost", "localhost2"}}, + {name: "host", values: []string{"localhost"}}, + }, + }, } - defer mconn.Close() + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + server := setUpServerOnly(t, 0, &ServerConfig{}, suspended) + defer server.stop() + // Create a client directly to not tie what you can send to API of + // http2_client.go (i.e. control headers being sent). + mconn, err := net.Dial("tcp", server.lis.Addr().String()) + if err != nil { + t.Fatalf("Client failed to dial: %v", err) + } + defer mconn.Close() - if n, err := mconn.Write(clientPreface); err != nil || n != len(clientPreface) { - t.Fatalf("mconn.Write(clientPreface) = %d, %v, want %d, ", n, err, len(clientPreface)) + if n, err := mconn.Write(clientPreface); err != nil || n != len(clientPreface) { + t.Fatalf("mconn.Write(clientPreface) = %d, %v, want %d, ", n, err, len(clientPreface)) + } + + framer := http2.NewFramer(mconn, mconn) + if err := framer.WriteSettings(); err != nil { + t.Fatalf("Error while writing settings: %v", err) + } + + // result chan indicates that reader received a RSTStream from server. + // An error will be passed on it if any other frame is received. + result := testutils.NewChannel() + + // Launch a reader goroutine. + go func() { + for { + frame, err := framer.ReadFrame() + if err != nil { + return + } + switch frame := frame.(type) { + case *http2.SettingsFrame: + // Do nothing. A settings frame is expected from server preface. 
+ case *http2.RSTStreamFrame: + if frame.Header().StreamID != 1 || http2.ErrCode(frame.ErrCode) != http2.ErrCodeProtocol { + // Client only created a single stream, so RST Stream should be for that single stream. + result.Send(fmt.Errorf("RST stream received with streamID: %d and code %v, want streamID: 1 and code: http.ErrCodeFlowControl", frame.Header().StreamID, http2.ErrCode(frame.ErrCode))) + } + // Records that client successfully received RST Stream frame. + result.Send(nil) + return + default: + // The server should send nothing but a single RST Stream frame. + result.Send(errors.New("the client received a frame other than RST Stream")) + } + } + }() + + var buf bytes.Buffer + henc := hpack.NewEncoder(&buf) + + // Needs to build headers deterministically to conform to gRPC over + // HTTP/2 spec. + for _, header := range test.headers { + for _, value := range header.values { + if err := henc.WriteField(hpack.HeaderField{Name: header.name, Value: value}); err != nil { + t.Fatalf("Error while encoding header: %v", err) + } + } + } + + if err := framer.WriteHeaders(http2.HeadersFrameParam{StreamID: 1, BlockFragment: buf.Bytes(), EndHeaders: true}); err != nil { + t.Fatalf("Error while writing headers: %v", err) + } + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + r, err := result.Receive(ctx) + if err != nil { + t.Fatalf("Error receiving from channel: %v", err) + } + if r != nil { + t.Fatalf("want nil, got %v", r) + } + }) } +} - framer := http2.NewFramer(mconn, mconn) - if err := framer.WriteSettings(); err != nil { - t.Fatalf("Error while writing settings: %v", err) +// TestHeadersMultipleHosts tests that a request with multiple hosts gets +// rejected with HTTP Status 400 and gRPC status Internal, regardless of whether +// the client is speaking gRPC or not. 
+func (s) TestHeadersMultipleHosts(t *testing.T) { + tests := []struct { + name string + headers []struct { + name string + values []string + } + }{ + // Note: multiple authority headers are handled by the framer itself, + // which will cause a stream error. Thus, it will never get to + // operateHeaders with the check in operateHeaders for possible grpc-status sent back. + + // multiple :authority or multiple Host headers would make the eventual + // :authority ambiguous as per A41. This takes precedence even over the + // fact a request is non grpc. All of these requests should be rejected + // with grpc-status Internal. + { + name: "Multiple host headers non grpc", + headers: []struct { + name string + values []string + }{ + {name: ":method", values: []string{"POST"}}, + {name: ":path", values: []string{"foo"}}, + {name: ":authority", values: []string{"localhost"}}, + {name: "host", values: []string{"localhost", "localhost2"}}, + }, + }, + { + name: "Multiple host headers grpc", + headers: []struct { + name string + values []string + }{ + {name: ":method", values: []string{"POST"}}, + {name: ":path", values: []string{"foo"}}, + {name: ":authority", values: []string{"localhost"}}, + {name: "content-type", values: []string{"application/grpc"}}, + {name: "host", values: []string{"localhost", "localhost2"}}, + }, + }, } + for _, test := range tests { + server := setUpServerOnly(t, 0, &ServerConfig{}, suspended) + defer server.stop() + // Create a client directly to not tie what you can send to API of + // http2_client.go (i.e. control headers being sent). + mconn, err := net.Dial("tcp", server.lis.Addr().String()) + if err != nil { + t.Fatalf("Client failed to dial: %v", err) + } + defer mconn.Close() - // success chan indicates that reader received a RSTStream from server. - // An error will be passed on it if any other frame is received. 
- success := testutils.NewChannel() + if n, err := mconn.Write(clientPreface); err != nil || n != len(clientPreface) { + t.Fatalf("mconn.Write(clientPreface) = %d, %v, want %d, ", n, err, len(clientPreface)) + } - // Launch a reader goroutine. - go func() { - for { - frame, err := framer.ReadFrame() - if err != nil { - return + framer := http2.NewFramer(mconn, mconn) + framer.ReadMetaHeaders = hpack.NewDecoder(4096, nil) + if err := framer.WriteSettings(); err != nil { + t.Fatalf("Error while writing settings: %v", err) + } + + // result chan indicates that reader received a Headers Frame with + // desired grpc status and message from server. An error will be passed + // on it if any other frame is received. + result := testutils.NewChannel() + + // Launch a reader goroutine. + go func() { + for { + frame, err := framer.ReadFrame() + if err != nil { + return + } + switch frame := frame.(type) { + case *http2.SettingsFrame: + // Do nothing. A settings frame is expected from server preface. + case *http2.MetaHeadersFrame: + var status, grpcStatus, grpcMessage string + for _, header := range frame.Fields { + if header.Name == ":status" { + status = header.Value + } + if header.Name == "grpc-status" { + grpcStatus = header.Value + } + if header.Name == "grpc-message" { + grpcMessage = header.Value + } + } + if status != "400" { + result.Send(fmt.Errorf("incorrect HTTP Status got %v, want 200", status)) + return + } + if grpcStatus != "13" { // grpc status code internal + result.Send(fmt.Errorf("incorrect gRPC Status got %v, want 13", grpcStatus)) + return + } + if !strings.Contains(grpcMessage, "both must only have 1 value as per HTTP/2 spec") { + result.Send(fmt.Errorf("incorrect gRPC message")) + return + } + + // Records that client successfully received a HeadersFrame + // with expected Trailers-Only response. + result.Send(nil) + return + default: + // The server should send nothing but a single Settings and Headers frame. 
+ result.Send(errors.New("the client received a frame other than Settings or Headers")) + } } - switch frame := frame.(type) { - case *http2.SettingsFrame: - // Do nothing. A settings frame is expected from server preface. - case *http2.RSTStreamFrame: - if frame.Header().StreamID != 1 || http2.ErrCode(frame.ErrCode) != http2.ErrCodeProtocol { - // Client only created a single stream, so RST Stream should be for that single stream. - t.Errorf("RST stream received with streamID: %d and code %v, want streamID: 1 and code: http.ErrCodeFlowControl", frame.Header().StreamID, http2.ErrCode(frame.ErrCode)) + }() + + var buf bytes.Buffer + henc := hpack.NewEncoder(&buf) + + // Needs to build headers deterministically to conform to gRPC over + // HTTP/2 spec. + for _, header := range test.headers { + for _, value := range header.values { + if err := henc.WriteField(hpack.HeaderField{Name: header.name, Value: value}); err != nil { + t.Fatalf("Error while encoding header: %v", err) } - // Records that client successfully received RST Stream frame. - success.Send(nil) - return - default: - // The server should send nothing but a single RST Stream frame. - success.Send(errors.New("The client received a frame other than RST Stream")) } } - }() - - // Done with HTTP/2 setup - now create a stream with a bad method header. - var buf bytes.Buffer - henc := hpack.NewEncoder(&buf) - // Method is required to be POST in a gRPC call. - if err := henc.WriteField(hpack.HeaderField{Name: ":method", Value: "PUT"}); err != nil { - t.Fatalf("Error while encoding header: %v", err) - } - // Have the rest of the headers be ok and within the gRPC over HTTP/2 spec. 
- if err := henc.WriteField(hpack.HeaderField{Name: ":path", Value: "foo"}); err != nil { - t.Fatalf("Error while encoding header: %v", err) - } - if err := henc.WriteField(hpack.HeaderField{Name: ":authority", Value: "localhost"}); err != nil { - t.Fatalf("Error while encoding header: %v", err) - } - if err := henc.WriteField(hpack.HeaderField{Name: "content-type", Value: "application/grpc"}); err != nil { - t.Fatalf("Error while encoding header: %v", err) - } - if err := framer.WriteHeaders(http2.HeadersFrameParam{StreamID: 1, BlockFragment: buf.Bytes(), EndHeaders: true}); err != nil { - t.Fatalf("Error while writing headers: %v", err) - } - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() - if e, err := success.Receive(ctx); e != nil || err != nil { - t.Fatalf("Error in frame server should send: %v. Error receiving from channel: %v", e, err) + if err := framer.WriteHeaders(http2.HeadersFrameParam{StreamID: 1, BlockFragment: buf.Bytes(), EndHeaders: true}); err != nil { + t.Fatalf("Error while writing headers: %v", err) + } + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + r, err := result.Receive(ctx) + if err != nil { + t.Fatalf("Error receiving from channel: %v", err) + } + if r != nil { + t.Fatalf("want nil, got %v", r) + } } } @@ -1762,7 +1972,7 @@ func runPingPongTest(t *testing.T, msgSize int) { server, client, cancel := setUp(t, 0, 0, pingpong) defer cancel() defer server.stop() - defer client.Close() + defer client.Close(fmt.Errorf("closed manually by test")) waitWhileTrue(t, func() (bool, error) { server.mu.Lock() defer server.mu.Unlock() @@ -1850,7 +2060,7 @@ func (s) TestHeaderTblSize(t *testing.T) { server, ct, cancel := setUp(t, 0, math.MaxUint32, normal) defer cancel() - defer ct.Close() + defer ct.Close(fmt.Errorf("closed manually by test")) defer server.stop() ctx, ctxCancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer 
ctxCancel() @@ -1969,10 +2179,186 @@ func (s) TestClientHandshakeInfo(t *testing.T) { if err != nil { t.Fatalf("NewClientTransport(): %v", err) } - defer tr.Close() + defer tr.Close(fmt.Errorf("closed manually by test")) wantAttr := attributes.New(testAttrKey, testAttrVal) if gotAttr := creds.attr; !cmp.Equal(gotAttr, wantAttr, cmp.AllowUnexported(attributes.Attributes{})) { t.Fatalf("received attributes %v in creds, want %v", gotAttr, wantAttr) } } + +func (s) TestClientDecodeHeaderStatusErr(t *testing.T) { + testStream := func() *Stream { + return &Stream{ + done: make(chan struct{}), + headerChan: make(chan struct{}), + buf: &recvBuffer{ + c: make(chan recvMsg), + mu: sync.Mutex{}, + }, + } + } + + testClient := func(ts *Stream) *http2Client { + return &http2Client{ + mu: sync.Mutex{}, + activeStreams: map[uint32]*Stream{ + 0: ts, + }, + controlBuf: &controlBuffer{ + ch: make(chan struct{}), + done: make(chan struct{}), + list: &itemList{}, + }, + } + } + + for _, test := range []struct { + name string + // input + metaHeaderFrame *http2.MetaHeadersFrame + // output + wantStatus *status.Status + }{ + { + name: "valid header", + metaHeaderFrame: &http2.MetaHeadersFrame{ + Fields: []hpack.HeaderField{ + {Name: "content-type", Value: "application/grpc"}, + {Name: "grpc-status", Value: "0"}, + {Name: ":status", Value: "200"}, + }, + }, + // no error + wantStatus: status.New(codes.OK, ""), + }, + { + name: "missing content-type header", + metaHeaderFrame: &http2.MetaHeadersFrame{ + Fields: []hpack.HeaderField{ + {Name: "grpc-status", Value: "0"}, + {Name: ":status", Value: "200"}, + }, + }, + wantStatus: status.New( + codes.Unknown, + "malformed header: missing HTTP content-type", + ), + }, + { + name: "invalid grpc status header field", + metaHeaderFrame: &http2.MetaHeadersFrame{ + Fields: []hpack.HeaderField{ + {Name: "content-type", Value: "application/grpc"}, + {Name: "grpc-status", Value: "xxxx"}, + {Name: ":status", Value: "200"}, + }, + }, + wantStatus: 
status.New( + codes.Internal, + "transport: malformed grpc-status: strconv.ParseInt: parsing \"xxxx\": invalid syntax", + ), + }, + { + name: "invalid http content type", + metaHeaderFrame: &http2.MetaHeadersFrame{ + Fields: []hpack.HeaderField{ + {Name: "content-type", Value: "application/json"}, + }, + }, + wantStatus: status.New( + codes.Internal, + "malformed header: missing HTTP status; transport: received unexpected content-type \"application/json\"", + ), + }, + { + name: "http fallback and invalid http status", + metaHeaderFrame: &http2.MetaHeadersFrame{ + Fields: []hpack.HeaderField{ + // No content type provided then fallback into handling http error. + {Name: ":status", Value: "xxxx"}, + }, + }, + wantStatus: status.New( + codes.Internal, + "transport: malformed http-status: strconv.ParseInt: parsing \"xxxx\": invalid syntax", + ), + }, + { + name: "http2 frame size exceeds", + metaHeaderFrame: &http2.MetaHeadersFrame{ + Fields: nil, + Truncated: true, + }, + wantStatus: status.New( + codes.Internal, + "peer header list size exceeded limit", + ), + }, + { + name: "bad status in grpc mode", + metaHeaderFrame: &http2.MetaHeadersFrame{ + Fields: []hpack.HeaderField{ + {Name: "content-type", Value: "application/grpc"}, + {Name: "grpc-status", Value: "0"}, + {Name: ":status", Value: "504"}, + }, + }, + wantStatus: status.New( + codes.Unavailable, + "unexpected HTTP status code received from server: 504 (Gateway Timeout)", + ), + }, + { + name: "missing http status", + metaHeaderFrame: &http2.MetaHeadersFrame{ + Fields: []hpack.HeaderField{ + {Name: "content-type", Value: "application/grpc"}, + }, + }, + wantStatus: status.New( + codes.Internal, + "malformed header: missing HTTP status", + ), + }, + } { + + t.Run(test.name, func(t *testing.T) { + ts := testStream() + s := testClient(ts) + + test.metaHeaderFrame.HeadersFrame = &http2.HeadersFrame{ + FrameHeader: http2.FrameHeader{ + StreamID: 0, + }, + } + + s.operateHeaders(test.metaHeaderFrame) + + got := 
ts.status + want := test.wantStatus + if got.Code() != want.Code() || got.Message() != want.Message() { + t.Fatalf("operateHeaders(%v); status = \ngot: %s\nwant: %s", test.metaHeaderFrame, got, want) + } + }) + t.Run(fmt.Sprintf("%s-end_stream", test.name), func(t *testing.T) { + ts := testStream() + s := testClient(ts) + + test.metaHeaderFrame.HeadersFrame = &http2.HeadersFrame{ + FrameHeader: http2.FrameHeader{ + StreamID: 0, + Flags: http2.FlagHeadersEndStream, + }, + } + + s.operateHeaders(test.metaHeaderFrame) + + got := ts.status + want := test.wantStatus + if got.Code() != want.Code() || got.Message() != want.Message() { + t.Fatalf("operateHeaders(%v); status = \ngot: %s\nwant: %s", test.metaHeaderFrame, got, want) + } + }) + } +} diff --git a/internal/xds/bootstrap.go b/internal/xds/bootstrap.go new file mode 100644 index 00000000000..1d74ab46a11 --- /dev/null +++ b/internal/xds/bootstrap.go @@ -0,0 +1,147 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// Package xds contains types that need to be shared between code under +// google.golang.org/grpc/xds/... and the rest of gRPC. +package xds + +import ( + "encoding/json" + "fmt" + "io/ioutil" + "os" + + "google.golang.org/grpc/grpclog" + "google.golang.org/grpc/internal/xds/env" +) + +var logger = grpclog.Component("internal/xds") + +// TransportAPI refers to the API version for xDS transport protocol. 
+type TransportAPI int + +const ( + // TransportV2 refers to the v2 xDS transport protocol. + TransportV2 TransportAPI = iota + // TransportV3 refers to the v3 xDS transport protocol. + TransportV3 +) + +// BootstrapOptions wraps the parameters passed to SetupBootstrapFile. +type BootstrapOptions struct { + // Version is the xDS transport protocol version. + Version TransportAPI + // NodeID is the node identifier of the gRPC client/server node in the + // proxyless service mesh. + NodeID string + // ServerURI is the address of the management server. + ServerURI string + // ServerListenerResourceNameTemplate is the Listener resource name to fetch. + ServerListenerResourceNameTemplate string + // CertificateProviders is the certificate providers configuration. + CertificateProviders map[string]json.RawMessage +} + +// SetupBootstrapFile creates a temporary file with bootstrap contents, based on +// the passed in options, and updates the bootstrap environment variable to +// point to this file. +// +// Returns a cleanup function which will be non-nil if the setup process was +// completed successfully. It is the responsibility of the caller to invoke the +// cleanup function at the end of the test. 
+func SetupBootstrapFile(opts BootstrapOptions) (func(), error) { + bootstrapContents, err := BootstrapContents(opts) + if err != nil { + return nil, err + } + f, err := ioutil.TempFile("", "test_xds_bootstrap_*") + if err != nil { + return nil, fmt.Errorf("failed to created bootstrap file: %v", err) + } + + if err := ioutil.WriteFile(f.Name(), bootstrapContents, 0644); err != nil { + return nil, fmt.Errorf("failed to created bootstrap file: %v", err) + } + logger.Infof("Created bootstrap file at %q with contents: %s\n", f.Name(), bootstrapContents) + + origBootstrapFileName := env.BootstrapFileName + env.BootstrapFileName = f.Name() + return func() { + os.Remove(f.Name()) + env.BootstrapFileName = origBootstrapFileName + }, nil +} + +// BootstrapContents returns the contents to go into a bootstrap file, +// environment, or configuration passed to +// xds.NewXDSResolverWithConfigForTesting. +func BootstrapContents(opts BootstrapOptions) ([]byte, error) { + cfg := &bootstrapConfig{ + XdsServers: []server{ + { + ServerURI: opts.ServerURI, + ChannelCreds: []creds{ + { + Type: "insecure", + }, + }, + }, + }, + Node: node{ + ID: opts.NodeID, + }, + CertificateProviders: opts.CertificateProviders, + ServerListenerResourceNameTemplate: opts.ServerListenerResourceNameTemplate, + } + switch opts.Version { + case TransportV2: + // TODO: Add any v2 specific fields. 
+ case TransportV3: + cfg.XdsServers[0].ServerFeatures = append(cfg.XdsServers[0].ServerFeatures, "xds_v3") + default: + return nil, fmt.Errorf("unsupported xDS transport protocol version: %v", opts.Version) + } + + bootstrapContents, err := json.MarshalIndent(cfg, "", " ") + if err != nil { + return nil, fmt.Errorf("failed to created bootstrap file: %v", err) + } + return bootstrapContents, nil +} + +type bootstrapConfig struct { + XdsServers []server `json:"xds_servers,omitempty"` + Node node `json:"node,omitempty"` + CertificateProviders map[string]json.RawMessage `json:"certificate_providers,omitempty"` + ServerListenerResourceNameTemplate string `json:"server_listener_resource_name_template,omitempty"` +} + +type server struct { + ServerURI string `json:"server_uri,omitempty"` + ChannelCreds []creds `json:"channel_creds,omitempty"` + ServerFeatures []string `json:"server_features,omitempty"` +} + +type creds struct { + Type string `json:"type,omitempty"` + Config interface{} `json:"config,omitempty"` +} + +type node struct { + ID string `json:"id,omitempty"` +} diff --git a/xds/internal/env/env.go b/internal/xds/env/env.go similarity index 52% rename from xds/internal/env/env.go rename to internal/xds/env/env.go index a28d741f356..2977bfa6285 100644 --- a/xds/internal/env/env.go +++ b/internal/xds/env/env.go @@ -37,42 +37,59 @@ const ( // and kept in variable BootstrapFileName. // // When both bootstrap FileName and FileContent are set, FileName is used. 
- BootstrapFileContentEnv = "GRPC_XDS_BOOTSTRAP_CONFIG" - circuitBreakingSupportEnv = "GRPC_XDS_EXPERIMENTAL_CIRCUIT_BREAKING" - timeoutSupportEnv = "GRPC_XDS_EXPERIMENTAL_ENABLE_TIMEOUT" - faultInjectionSupportEnv = "GRPC_XDS_EXPERIMENTAL_FAULT_INJECTION" + BootstrapFileContentEnv = "GRPC_XDS_BOOTSTRAP_CONFIG" + + ringHashSupportEnv = "GRPC_XDS_EXPERIMENTAL_ENABLE_RING_HASH" clientSideSecuritySupportEnv = "GRPC_XDS_EXPERIMENTAL_SECURITY_SUPPORT" + aggregateAndDNSSupportEnv = "GRPC_XDS_EXPERIMENTAL_ENABLE_AGGREGATE_AND_LOGICAL_DNS_CLUSTER" + retrySupportEnv = "GRPC_XDS_EXPERIMENTAL_ENABLE_RETRY" + rbacSupportEnv = "GRPC_XDS_EXPERIMENTAL_RBAC" + + c2pResolverSupportEnv = "GRPC_EXPERIMENTAL_GOOGLE_C2P_RESOLVER" + c2pResolverTestOnlyTrafficDirectorURIEnv = "GRPC_TEST_ONLY_GOOGLE_C2P_RESOLVER_TRAFFIC_DIRECTOR_URI" ) var ( // BootstrapFileName holds the name of the file which contains xDS bootstrap // configuration. Users can specify the location of the bootstrap file by - // setting the environment variable "GRPC_XDS_BOOSTRAP". + // setting the environment variable "GRPC_XDS_BOOTSTRAP". // // When both bootstrap FileName and FileContent are set, FileName is used. BootstrapFileName = os.Getenv(BootstrapFileNameEnv) // BootstrapFileContent holds the content of the xDS bootstrap // configuration. Users can specify the bootstrap config by - // setting the environment variable "GRPC_XDS_BOOSTRAP_CONFIG". + // setting the environment variable "GRPC_XDS_BOOTSTRAP_CONFIG". // // When both bootstrap FileName and FileContent are set, FileName is used. BootstrapFileContent = os.Getenv(BootstrapFileContentEnv) - // CircuitBreakingSupport indicates whether circuit breaking support is - // enabled, which can be done by setting the environment variable - // "GRPC_XDS_EXPERIMENTAL_CIRCUIT_BREAKING" to "true". 
- CircuitBreakingSupport = strings.EqualFold(os.Getenv(circuitBreakingSupportEnv), "true") - // TimeoutSupport indicates whether support for max_stream_duration in - // route actions is enabled. This can be enabled by setting the - // environment variable "GRPC_XDS_EXPERIMENTAL_ENABLE_TIMEOUT" to "true". - TimeoutSupport = strings.EqualFold(os.Getenv(timeoutSupportEnv), "true") - // FaultInjectionSupport is used to control both fault injection and HTTP - // filter support. - FaultInjectionSupport = strings.EqualFold(os.Getenv(faultInjectionSupportEnv), "true") + // RingHashSupport indicates whether ring hash support is enabled, which can + // be disabled by setting the environment variable + // "GRPC_XDS_EXPERIMENTAL_ENABLE_RING_HASH" to "false". + RingHashSupport = !strings.EqualFold(os.Getenv(ringHashSupportEnv), "false") // ClientSideSecuritySupport is used to control processing of security // configuration on the client-side. // // Note that there is no env var protection for the server-side because we // have a brand new API on the server-side and users explicitly need to use // the new API to get security integration on the server. - ClientSideSecuritySupport = strings.EqualFold(os.Getenv(clientSideSecuritySupportEnv), "true") + ClientSideSecuritySupport = !strings.EqualFold(os.Getenv(clientSideSecuritySupportEnv), "false") + // AggregateAndDNSSupportEnv indicates whether processing of aggregated + // cluster and DNS cluster is enabled, which can be enabled by setting the + // environment variable + // "GRPC_XDS_EXPERIMENTAL_ENABLE_AGGREGATE_AND_LOGICAL_DNS_CLUSTER" to + // "true". + AggregateAndDNSSupportEnv = strings.EqualFold(os.Getenv(aggregateAndDNSSupportEnv), "true") + + // RetrySupport indicates whether xDS retry is enabled. + RetrySupport = !strings.EqualFold(os.Getenv(retrySupportEnv), "false") + + // RBACSupport indicates whether xDS configured RBAC HTTP Filter is enabled. 
+ RBACSupport = strings.EqualFold(os.Getenv(rbacSupportEnv), "true") + + // C2PResolverSupport indicates whether support for C2P resolver is enabled. + // This can be enabled by setting the environment variable + // "GRPC_EXPERIMENTAL_GOOGLE_C2P_RESOLVER" to "true". + C2PResolverSupport = strings.EqualFold(os.Getenv(c2pResolverSupportEnv), "true") + // C2PResolverTestOnlyTrafficDirectorURI is the TD URI for testing. + C2PResolverTestOnlyTrafficDirectorURI = os.Getenv(c2pResolverTestOnlyTrafficDirectorURIEnv) ) diff --git a/internal/xds/matcher/matcher_header.go b/internal/xds/matcher/matcher_header.go new file mode 100644 index 00000000000..35a22adadcf --- /dev/null +++ b/internal/xds/matcher/matcher_header.go @@ -0,0 +1,253 @@ +/* + * + * Copyright 2020 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package matcher + +import ( + "fmt" + "regexp" + "strconv" + "strings" + + "google.golang.org/grpc/metadata" +) + +// HeaderMatcher is an interface for header matchers. These are +// documented in (EnvoyProxy link here?). These matchers will match on different +// aspects of HTTP header name/value pairs. +type HeaderMatcher interface { + Match(metadata.MD) bool + String() string +} + +// mdValuesFromOutgoingCtx retrieves metadata from context. If there are +// multiple values, the values are concatenated with "," (comma and no space). +// +// All header matchers only match against the comma-concatenated string. 
+func mdValuesFromOutgoingCtx(md metadata.MD, key string) (string, bool) { + vs, ok := md[key] + if !ok { + return "", false + } + return strings.Join(vs, ","), true +} + +// HeaderExactMatcher matches on an exact match of the value of the header. +type HeaderExactMatcher struct { + key string + exact string +} + +// NewHeaderExactMatcher returns a new HeaderExactMatcher. +func NewHeaderExactMatcher(key, exact string) *HeaderExactMatcher { + return &HeaderExactMatcher{key: key, exact: exact} +} + +// Match returns whether the passed in HTTP Headers match according to the +// HeaderExactMatcher. +func (hem *HeaderExactMatcher) Match(md metadata.MD) bool { + v, ok := mdValuesFromOutgoingCtx(md, hem.key) + if !ok { + return false + } + return v == hem.exact +} + +func (hem *HeaderExactMatcher) String() string { + return fmt.Sprintf("headerExact:%v:%v", hem.key, hem.exact) +} + +// HeaderRegexMatcher matches on whether the entire request header value matches +// the regex. +type HeaderRegexMatcher struct { + key string + re *regexp.Regexp +} + +// NewHeaderRegexMatcher returns a new HeaderRegexMatcher. +func NewHeaderRegexMatcher(key string, re *regexp.Regexp) *HeaderRegexMatcher { + return &HeaderRegexMatcher{key: key, re: re} +} + +// Match returns whether the passed in HTTP Headers match according to the +// HeaderRegexMatcher. +func (hrm *HeaderRegexMatcher) Match(md metadata.MD) bool { + v, ok := mdValuesFromOutgoingCtx(md, hrm.key) + if !ok { + return false + } + return hrm.re.MatchString(v) +} + +func (hrm *HeaderRegexMatcher) String() string { + return fmt.Sprintf("headerRegex:%v:%v", hrm.key, hrm.re.String()) +} + +// HeaderRangeMatcher matches on whether the request header value is within the +// range. The header value must be an integer in base 10 notation. +type HeaderRangeMatcher struct { + key string + start, end int64 // represents [start, end). +} + +// NewHeaderRangeMatcher returns a new HeaderRangeMatcher. 
+func NewHeaderRangeMatcher(key string, start, end int64) *HeaderRangeMatcher { + return &HeaderRangeMatcher{key: key, start: start, end: end} +} + +// Match returns whether the passed in HTTP Headers match according to the +// HeaderRangeMatcher. +func (hrm *HeaderRangeMatcher) Match(md metadata.MD) bool { + v, ok := mdValuesFromOutgoingCtx(md, hrm.key) + if !ok { + return false + } + if i, err := strconv.ParseInt(v, 10, 64); err == nil && i >= hrm.start && i < hrm.end { + return true + } + return false +} + +func (hrm *HeaderRangeMatcher) String() string { + return fmt.Sprintf("headerRange:%v:[%d,%d)", hrm.key, hrm.start, hrm.end) +} + +// HeaderPresentMatcher will match based on whether the header is present in the +// whole request. +type HeaderPresentMatcher struct { + key string + present bool +} + +// NewHeaderPresentMatcher returns a new HeaderPresentMatcher. +func NewHeaderPresentMatcher(key string, present bool) *HeaderPresentMatcher { + return &HeaderPresentMatcher{key: key, present: present} +} + +// Match returns whether the passed in HTTP Headers match according to the +// HeaderPresentMatcher. +func (hpm *HeaderPresentMatcher) Match(md metadata.MD) bool { + vs, ok := mdValuesFromOutgoingCtx(md, hpm.key) + present := ok && len(vs) > 0 + return present == hpm.present +} + +func (hpm *HeaderPresentMatcher) String() string { + return fmt.Sprintf("headerPresent:%v:%v", hpm.key, hpm.present) +} + +// HeaderPrefixMatcher matches on whether the prefix of the header value matches +// the prefix passed into this struct. +type HeaderPrefixMatcher struct { + key string + prefix string +} + +// NewHeaderPrefixMatcher returns a new HeaderPrefixMatcher. +func NewHeaderPrefixMatcher(key string, prefix string) *HeaderPrefixMatcher { + return &HeaderPrefixMatcher{key: key, prefix: prefix} +} + +// Match returns whether the passed in HTTP Headers match according to the +// HeaderPrefixMatcher. 
+func (hpm *HeaderPrefixMatcher) Match(md metadata.MD) bool { + v, ok := mdValuesFromOutgoingCtx(md, hpm.key) + if !ok { + return false + } + return strings.HasPrefix(v, hpm.prefix) +} + +func (hpm *HeaderPrefixMatcher) String() string { + return fmt.Sprintf("headerPrefix:%v:%v", hpm.key, hpm.prefix) +} + +// HeaderSuffixMatcher matches on whether the suffix of the header value matches +// the suffix passed into this struct. +type HeaderSuffixMatcher struct { + key string + suffix string +} + +// NewHeaderSuffixMatcher returns a new HeaderSuffixMatcher. +func NewHeaderSuffixMatcher(key string, suffix string) *HeaderSuffixMatcher { + return &HeaderSuffixMatcher{key: key, suffix: suffix} +} + +// Match returns whether the passed in HTTP Headers match according to the +// HeaderSuffixMatcher. +func (hsm *HeaderSuffixMatcher) Match(md metadata.MD) bool { + v, ok := mdValuesFromOutgoingCtx(md, hsm.key) + if !ok { + return false + } + return strings.HasSuffix(v, hsm.suffix) +} + +func (hsm *HeaderSuffixMatcher) String() string { + return fmt.Sprintf("headerSuffix:%v:%v", hsm.key, hsm.suffix) +} + +// HeaderContainsMatcher matches on whether the header value contains the +// value passed into this struct. +type HeaderContainsMatcher struct { + key string + contains string +} + +// NewHeaderContainsMatcher returns a new HeaderContainsMatcher. key is the HTTP +// Header key to match on, and contains is the value that the header should +// should contain for a successful match. An empty contains string does not +// work, use HeaderPresentMatcher in that case. +func NewHeaderContainsMatcher(key string, contains string) *HeaderContainsMatcher { + return &HeaderContainsMatcher{key: key, contains: contains} +} + +// Match returns whether the passed in HTTP Headers match according to the +// HeaderContainsMatcher. 
+func (hcm *HeaderContainsMatcher) Match(md metadata.MD) bool { + v, ok := mdValuesFromOutgoingCtx(md, hcm.key) + if !ok { + return false + } + return strings.Contains(v, hcm.contains) +} + +func (hcm *HeaderContainsMatcher) String() string { + return fmt.Sprintf("headerContains:%v%v", hcm.key, hcm.contains) +} + +// InvertMatcher inverts the match result of the underlying header matcher. +type InvertMatcher struct { + m HeaderMatcher +} + +// NewInvertMatcher returns a new InvertMatcher. +func NewInvertMatcher(m HeaderMatcher) *InvertMatcher { + return &InvertMatcher{m: m} +} + +// Match returns whether the passed in HTTP Headers match according to the +// InvertMatcher. +func (i *InvertMatcher) Match(md metadata.MD) bool { + return !i.m.Match(md) +} + +func (i *InvertMatcher) String() string { + return fmt.Sprintf("invert{%s}", i.m) +} diff --git a/xds/internal/resolver/matcher_header_test.go b/internal/xds/matcher/matcher_header_test.go similarity index 88% rename from xds/internal/resolver/matcher_header_test.go rename to internal/xds/matcher/matcher_header_test.go index fb87cc5dd32..9a0d51300d0 100644 --- a/xds/internal/resolver/matcher_header_test.go +++ b/internal/xds/matcher/matcher_header_test.go @@ -16,7 +16,7 @@ * */ -package resolver +package matcher import ( "regexp" @@ -64,8 +64,8 @@ func TestHeaderExactMatcherMatch(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - hem := newHeaderExactMatcher(tt.key, tt.exact) - if got := hem.match(tt.md); got != tt.want { + hem := NewHeaderExactMatcher(tt.key, tt.exact) + if got := hem.Match(tt.md); got != tt.want { t.Errorf("match() = %v, want %v", got, tt.want) } }) @@ -110,8 +110,8 @@ func TestHeaderRegexMatcherMatch(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - hrm := newHeaderRegexMatcher(tt.key, regexp.MustCompile(tt.regexStr)) - if got := hrm.match(tt.md); got != tt.want { + hrm := NewHeaderRegexMatcher(tt.key, 
regexp.MustCompile(tt.regexStr)) + if got := hrm.Match(tt.md); got != tt.want { t.Errorf("match() = %v, want %v", got, tt.want) } }) @@ -157,8 +157,8 @@ func TestHeaderRangeMatcherMatch(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - hrm := newHeaderRangeMatcher(tt.key, tt.start, tt.end) - if got := hrm.match(tt.md); got != tt.want { + hrm := NewHeaderRangeMatcher(tt.key, tt.start, tt.end) + if got := hrm.Match(tt.md); got != tt.want { t.Errorf("match() = %v, want %v", got, tt.want) } }) @@ -204,8 +204,8 @@ func TestHeaderPresentMatcherMatch(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - hpm := newHeaderPresentMatcher(tt.key, tt.present) - if got := hpm.match(tt.md); got != tt.want { + hpm := NewHeaderPresentMatcher(tt.key, tt.present) + if got := hpm.Match(tt.md); got != tt.want { t.Errorf("match() = %v, want %v", got, tt.want) } }) @@ -250,8 +250,8 @@ func TestHeaderPrefixMatcherMatch(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - hpm := newHeaderPrefixMatcher(tt.key, tt.prefix) - if got := hpm.match(tt.md); got != tt.want { + hpm := NewHeaderPrefixMatcher(tt.key, tt.prefix) + if got := hpm.Match(tt.md); got != tt.want { t.Errorf("match() = %v, want %v", got, tt.want) } }) @@ -296,8 +296,8 @@ func TestHeaderSuffixMatcherMatch(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - hsm := newHeaderSuffixMatcher(tt.key, tt.suffix) - if got := hsm.match(tt.md); got != tt.want { + hsm := NewHeaderSuffixMatcher(tt.key, tt.suffix) + if got := hsm.Match(tt.md); got != tt.want { t.Errorf("match() = %v, want %v", got, tt.want) } }) @@ -307,24 +307,24 @@ func TestHeaderSuffixMatcherMatch(t *testing.T) { func TestInvertMatcherMatch(t *testing.T) { tests := []struct { name string - m headerMatcherInterface + m HeaderMatcher md metadata.MD }{ { name: "true->false", - m: newHeaderExactMatcher("th", "tv"), + m: NewHeaderExactMatcher("th", "tv"), 
md: metadata.Pairs("th", "tv"), }, { name: "false->true", - m: newHeaderExactMatcher("th", "abc"), + m: NewHeaderExactMatcher("th", "abc"), md: metadata.Pairs("th", "tv"), }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := newInvertMatcher(tt.m).match(tt.md) - want := !tt.m.match(tt.md) + got := NewInvertMatcher(tt.m).Match(tt.md) + want := !tt.m.Match(tt.md) if got != want { t.Errorf("match() = %v, want %v", got, want) } diff --git a/internal/xds/matcher/string_matcher.go b/internal/xds/matcher/string_matcher.go new file mode 100644 index 00000000000..d7df6a1e2b4 --- /dev/null +++ b/internal/xds/matcher/string_matcher.go @@ -0,0 +1,183 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// Package matcher contains types that need to be shared between code under +// google.golang.org/grpc/xds/... and the rest of gRPC. +package matcher + +import ( + "errors" + "fmt" + "regexp" + "strings" + + v3matcherpb "github.com/envoyproxy/go-control-plane/envoy/type/matcher/v3" +) + +// StringMatcher contains match criteria for matching a string, and is an +// internal representation of the `StringMatcher` proto defined at +// https://github.com/envoyproxy/envoy/blob/main/api/envoy/type/matcher/v3/string.proto. +type StringMatcher struct { + // Since these match fields are part of a `oneof` in the corresponding xDS + // proto, only one of them is expected to be set. 
+ exactMatch *string + prefixMatch *string + suffixMatch *string + regexMatch *regexp.Regexp + containsMatch *string + // If true, indicates the exact/prefix/suffix/contains matching should be + // case insensitive. This has no effect on the regex match. + ignoreCase bool +} + +// Match returns true if input matches the criteria in the given StringMatcher. +func (sm StringMatcher) Match(input string) bool { + if sm.ignoreCase { + input = strings.ToLower(input) + } + switch { + case sm.exactMatch != nil: + return input == *sm.exactMatch + case sm.prefixMatch != nil: + return strings.HasPrefix(input, *sm.prefixMatch) + case sm.suffixMatch != nil: + return strings.HasSuffix(input, *sm.suffixMatch) + case sm.regexMatch != nil: + return sm.regexMatch.MatchString(input) + case sm.containsMatch != nil: + return strings.Contains(input, *sm.containsMatch) + } + return false +} + +// StringMatcherFromProto is a helper function to create a StringMatcher from +// the corresponding StringMatcher proto. +// +// Returns a non-nil error if matcherProto is invalid. 
+func StringMatcherFromProto(matcherProto *v3matcherpb.StringMatcher) (StringMatcher, error) { + if matcherProto == nil { + return StringMatcher{}, errors.New("input StringMatcher proto is nil") + } + + matcher := StringMatcher{ignoreCase: matcherProto.GetIgnoreCase()} + switch mt := matcherProto.GetMatchPattern().(type) { + case *v3matcherpb.StringMatcher_Exact: + matcher.exactMatch = &mt.Exact + if matcher.ignoreCase { + *matcher.exactMatch = strings.ToLower(*matcher.exactMatch) + } + case *v3matcherpb.StringMatcher_Prefix: + if matcherProto.GetPrefix() == "" { + return StringMatcher{}, errors.New("empty prefix is not allowed in StringMatcher") + } + matcher.prefixMatch = &mt.Prefix + if matcher.ignoreCase { + *matcher.prefixMatch = strings.ToLower(*matcher.prefixMatch) + } + case *v3matcherpb.StringMatcher_Suffix: + if matcherProto.GetSuffix() == "" { + return StringMatcher{}, errors.New("empty suffix is not allowed in StringMatcher") + } + matcher.suffixMatch = &mt.Suffix + if matcher.ignoreCase { + *matcher.suffixMatch = strings.ToLower(*matcher.suffixMatch) + } + case *v3matcherpb.StringMatcher_SafeRegex: + regex := matcherProto.GetSafeRegex().GetRegex() + re, err := regexp.Compile(regex) + if err != nil { + return StringMatcher{}, fmt.Errorf("safe_regex matcher %q is invalid", regex) + } + matcher.regexMatch = re + case *v3matcherpb.StringMatcher_Contains: + if matcherProto.GetContains() == "" { + return StringMatcher{}, errors.New("empty contains is not allowed in StringMatcher") + } + matcher.containsMatch = &mt.Contains + if matcher.ignoreCase { + *matcher.containsMatch = strings.ToLower(*matcher.containsMatch) + } + default: + return StringMatcher{}, fmt.Errorf("unrecognized string matcher: %+v", matcherProto) + } + return matcher, nil +} + +// StringMatcherForTesting is a helper function to create a StringMatcher based +// on the given arguments. Intended only for testing purposes. 
+func StringMatcherForTesting(exact, prefix, suffix, contains *string, regex *regexp.Regexp, ignoreCase bool) StringMatcher { + sm := StringMatcher{ + exactMatch: exact, + prefixMatch: prefix, + suffixMatch: suffix, + regexMatch: regex, + containsMatch: contains, + ignoreCase: ignoreCase, + } + if ignoreCase { + switch { + case sm.exactMatch != nil: + *sm.exactMatch = strings.ToLower(*exact) + case sm.prefixMatch != nil: + *sm.prefixMatch = strings.ToLower(*prefix) + case sm.suffixMatch != nil: + *sm.suffixMatch = strings.ToLower(*suffix) + case sm.containsMatch != nil: + *sm.containsMatch = strings.ToLower(*contains) + } + } + return sm +} + +// ExactMatch returns the value of the configured exact match or an empty string +// if exact match criteria was not specified. +func (sm StringMatcher) ExactMatch() string { + if sm.exactMatch != nil { + return *sm.exactMatch + } + return "" +} + +// Equal returns true if other and sm are equivalent to each other. +func (sm StringMatcher) Equal(other StringMatcher) bool { + if sm.ignoreCase != other.ignoreCase { + return false + } + + if (sm.exactMatch != nil) != (other.exactMatch != nil) || + (sm.prefixMatch != nil) != (other.prefixMatch != nil) || + (sm.suffixMatch != nil) != (other.suffixMatch != nil) || + (sm.regexMatch != nil) != (other.regexMatch != nil) || + (sm.containsMatch != nil) != (other.containsMatch != nil) { + return false + } + + switch { + case sm.exactMatch != nil: + return *sm.exactMatch == *other.exactMatch + case sm.prefixMatch != nil: + return *sm.prefixMatch == *other.prefixMatch + case sm.suffixMatch != nil: + return *sm.suffixMatch == *other.suffixMatch + case sm.regexMatch != nil: + return sm.regexMatch.String() == other.regexMatch.String() + case sm.containsMatch != nil: + return *sm.containsMatch == *other.containsMatch + } + return true +} diff --git a/internal/xds/matcher/string_matcher_test.go b/internal/xds/matcher/string_matcher_test.go new file mode 100644 index 00000000000..389963b94e9 --- 
/dev/null +++ b/internal/xds/matcher/string_matcher_test.go @@ -0,0 +1,309 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package matcher + +import ( + "regexp" + "testing" + + v3matcherpb "github.com/envoyproxy/go-control-plane/envoy/type/matcher/v3" + "github.com/google/go-cmp/cmp" +) + +func TestStringMatcherFromProto(t *testing.T) { + tests := []struct { + desc string + inputProto *v3matcherpb.StringMatcher + wantMatcher StringMatcher + wantErr bool + }{ + { + desc: "nil proto", + wantErr: true, + }, + { + desc: "empty prefix", + inputProto: &v3matcherpb.StringMatcher{ + MatchPattern: &v3matcherpb.StringMatcher_Prefix{Prefix: ""}, + }, + wantErr: true, + }, + { + desc: "empty suffix", + inputProto: &v3matcherpb.StringMatcher{ + MatchPattern: &v3matcherpb.StringMatcher_Suffix{Suffix: ""}, + }, + wantErr: true, + }, + { + desc: "empty contains", + inputProto: &v3matcherpb.StringMatcher{ + MatchPattern: &v3matcherpb.StringMatcher_Contains{Contains: ""}, + }, + wantErr: true, + }, + { + desc: "invalid regex", + inputProto: &v3matcherpb.StringMatcher{ + MatchPattern: &v3matcherpb.StringMatcher_SafeRegex{ + SafeRegex: &v3matcherpb.RegexMatcher{Regex: "??"}, + }, + }, + wantErr: true, + }, + { + desc: "happy case exact", + inputProto: &v3matcherpb.StringMatcher{ + MatchPattern: &v3matcherpb.StringMatcher_Exact{Exact: "exact"}, + }, + wantMatcher: StringMatcher{exactMatch: newStringP("exact")}, + }, + { + 
desc: "happy case exact ignore case", + inputProto: &v3matcherpb.StringMatcher{ + MatchPattern: &v3matcherpb.StringMatcher_Exact{Exact: "EXACT"}, + IgnoreCase: true, + }, + wantMatcher: StringMatcher{ + exactMatch: newStringP("exact"), + ignoreCase: true, + }, + }, + { + desc: "happy case prefix", + inputProto: &v3matcherpb.StringMatcher{ + MatchPattern: &v3matcherpb.StringMatcher_Prefix{Prefix: "prefix"}, + }, + wantMatcher: StringMatcher{prefixMatch: newStringP("prefix")}, + }, + { + desc: "happy case prefix ignore case", + inputProto: &v3matcherpb.StringMatcher{ + MatchPattern: &v3matcherpb.StringMatcher_Prefix{Prefix: "PREFIX"}, + IgnoreCase: true, + }, + wantMatcher: StringMatcher{ + prefixMatch: newStringP("prefix"), + ignoreCase: true, + }, + }, + { + desc: "happy case suffix", + inputProto: &v3matcherpb.StringMatcher{ + MatchPattern: &v3matcherpb.StringMatcher_Suffix{Suffix: "suffix"}, + }, + wantMatcher: StringMatcher{suffixMatch: newStringP("suffix")}, + }, + { + desc: "happy case suffix ignore case", + inputProto: &v3matcherpb.StringMatcher{ + MatchPattern: &v3matcherpb.StringMatcher_Suffix{Suffix: "SUFFIX"}, + IgnoreCase: true, + }, + wantMatcher: StringMatcher{ + suffixMatch: newStringP("suffix"), + ignoreCase: true, + }, + }, + { + desc: "happy case regex", + inputProto: &v3matcherpb.StringMatcher{ + MatchPattern: &v3matcherpb.StringMatcher_SafeRegex{ + SafeRegex: &v3matcherpb.RegexMatcher{Regex: "good?regex?"}, + }, + }, + wantMatcher: StringMatcher{regexMatch: regexp.MustCompile("good?regex?")}, + }, + { + desc: "happy case contains", + inputProto: &v3matcherpb.StringMatcher{ + MatchPattern: &v3matcherpb.StringMatcher_Contains{Contains: "contains"}, + }, + wantMatcher: StringMatcher{containsMatch: newStringP("contains")}, + }, + { + desc: "happy case contains ignore case", + inputProto: &v3matcherpb.StringMatcher{ + MatchPattern: &v3matcherpb.StringMatcher_Contains{Contains: "CONTAINS"}, + IgnoreCase: true, + }, + wantMatcher: StringMatcher{ + 
containsMatch: newStringP("contains"), + ignoreCase: true, + }, + }, + } + + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + gotMatcher, err := StringMatcherFromProto(test.inputProto) + if (err != nil) != test.wantErr { + t.Fatalf("StringMatcherFromProto(%+v) returned err: %v, wantErr: %v", test.inputProto, err, test.wantErr) + } + if diff := cmp.Diff(gotMatcher, test.wantMatcher, cmp.AllowUnexported(regexp.Regexp{})); diff != "" { + t.Fatalf("StringMatcherFromProto(%+v) returned unexpected diff (-got, +want):\n%s", test.inputProto, diff) + } + }) + } +} + +func TestMatch(t *testing.T) { + var ( + exactMatcher, _ = StringMatcherFromProto(&v3matcherpb.StringMatcher{MatchPattern: &v3matcherpb.StringMatcher_Exact{Exact: "exact"}}) + prefixMatcher, _ = StringMatcherFromProto(&v3matcherpb.StringMatcher{MatchPattern: &v3matcherpb.StringMatcher_Prefix{Prefix: "prefix"}}) + suffixMatcher, _ = StringMatcherFromProto(&v3matcherpb.StringMatcher{MatchPattern: &v3matcherpb.StringMatcher_Suffix{Suffix: "suffix"}}) + regexMatcher, _ = StringMatcherFromProto(&v3matcherpb.StringMatcher{MatchPattern: &v3matcherpb.StringMatcher_SafeRegex{SafeRegex: &v3matcherpb.RegexMatcher{Regex: "good?regex?"}}}) + containsMatcher, _ = StringMatcherFromProto(&v3matcherpb.StringMatcher{MatchPattern: &v3matcherpb.StringMatcher_Contains{Contains: "contains"}}) + exactMatcherIgnoreCase, _ = StringMatcherFromProto(&v3matcherpb.StringMatcher{ + MatchPattern: &v3matcherpb.StringMatcher_Exact{Exact: "exact"}, + IgnoreCase: true, + }) + prefixMatcherIgnoreCase, _ = StringMatcherFromProto(&v3matcherpb.StringMatcher{ + MatchPattern: &v3matcherpb.StringMatcher_Prefix{Prefix: "prefix"}, + IgnoreCase: true, + }) + suffixMatcherIgnoreCase, _ = StringMatcherFromProto(&v3matcherpb.StringMatcher{ + MatchPattern: &v3matcherpb.StringMatcher_Suffix{Suffix: "suffix"}, + IgnoreCase: true, + }) + containsMatcherIgnoreCase, _ = StringMatcherFromProto(&v3matcherpb.StringMatcher{ + MatchPattern: 
&v3matcherpb.StringMatcher_Contains{Contains: "contains"}, + IgnoreCase: true, + }) + ) + + tests := []struct { + desc string + matcher StringMatcher + input string + wantMatch bool + }{ + { + desc: "exact match success", + matcher: exactMatcher, + input: "exact", + wantMatch: true, + }, + { + desc: "exact match failure", + matcher: exactMatcher, + input: "not-exact", + }, + { + desc: "exact match success with ignore case", + matcher: exactMatcherIgnoreCase, + input: "EXACT", + wantMatch: true, + }, + { + desc: "exact match failure with ignore case", + matcher: exactMatcherIgnoreCase, + input: "not-exact", + }, + { + desc: "prefix match success", + matcher: prefixMatcher, + input: "prefixIsHere", + wantMatch: true, + }, + { + desc: "prefix match failure", + matcher: prefixMatcher, + input: "not-prefix", + }, + { + desc: "prefix match success with ignore case", + matcher: prefixMatcherIgnoreCase, + input: "PREFIXisHere", + wantMatch: true, + }, + { + desc: "prefix match failure with ignore case", + matcher: prefixMatcherIgnoreCase, + input: "not-PREFIX", + }, + { + desc: "suffix match success", + matcher: suffixMatcher, + input: "hereIsThesuffix", + wantMatch: true, + }, + { + desc: "suffix match failure", + matcher: suffixMatcher, + input: "suffix-is-not-here", + }, + { + desc: "suffix match success with ignore case", + matcher: suffixMatcherIgnoreCase, + input: "hereIsTheSuFFix", + wantMatch: true, + }, + { + desc: "suffix match failure with ignore case", + matcher: suffixMatcherIgnoreCase, + input: "SUFFIX-is-not-here", + }, + { + desc: "regex match success", + matcher: regexMatcher, + input: "goodregex", + wantMatch: true, + }, + { + desc: "regex match failure", + matcher: regexMatcher, + input: "regex-is-not-here", + }, + { + desc: "contains match success", + matcher: containsMatcher, + input: "IScontainsHERE", + wantMatch: true, + }, + { + desc: "contains match failure", + matcher: containsMatcher, + input: "con-tains-is-not-here", + }, + { + desc: "contains 
match success with ignore case", + matcher: containsMatcherIgnoreCase, + input: "isCONTAINShere", + wantMatch: true, + }, + { + desc: "contains match failure with ignore case", + matcher: containsMatcherIgnoreCase, + input: "CON-TAINS-is-not-here", + }, + } + + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + if gotMatch := test.matcher.Match(test.input); gotMatch != test.wantMatch { + t.Errorf("StringMatcher.Match(%s) returned %v, want %v", test.input, gotMatch, test.wantMatch) + } + }) + } +} + +func newStringP(s string) *string { + return &s +} diff --git a/internal/xds/rbac/matchers.go b/internal/xds/rbac/matchers.go new file mode 100644 index 00000000000..28dabf46591 --- /dev/null +++ b/internal/xds/rbac/matchers.go @@ -0,0 +1,426 @@ +/* + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package rbac + +import ( + "errors" + "fmt" + "net" + "regexp" + + v3corepb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" + v3rbacpb "github.com/envoyproxy/go-control-plane/envoy/config/rbac/v3" + v3route_componentspb "github.com/envoyproxy/go-control-plane/envoy/config/route/v3" + v3matcherpb "github.com/envoyproxy/go-control-plane/envoy/type/matcher/v3" + internalmatcher "google.golang.org/grpc/internal/xds/matcher" +) + +// matcher is an interface that takes data about incoming RPC's and returns +// whether it matches with whatever matcher implements this interface. 
+type matcher interface { + match(data *rpcData) bool +} + +// policyMatcher helps determine whether an incoming RPC call matches a policy. +// A policy is a logical role (e.g. Service Admin), which is comprised of +// permissions and principals. A principal is an identity (or identities) for a +// downstream subject which are assigned the policy (role), and a permission is +// an action(s) that a principal(s) can take. A policy matches if both a +// permission and a principal match, which will be determined by the child or +// permissions and principal matchers. policyMatcher implements the matcher +// interface. +type policyMatcher struct { + permissions *orMatcher + principals *orMatcher +} + +func newPolicyMatcher(policy *v3rbacpb.Policy) (*policyMatcher, error) { + permissions, err := matchersFromPermissions(policy.Permissions) + if err != nil { + return nil, err + } + principals, err := matchersFromPrincipals(policy.Principals) + if err != nil { + return nil, err + } + return &policyMatcher{ + permissions: &orMatcher{matchers: permissions}, + principals: &orMatcher{matchers: principals}, + }, nil +} + +func (pm *policyMatcher) match(data *rpcData) bool { + // A policy matches if and only if at least one of its permissions match the + // action taking place AND at least one if its principals match the + // downstream peer. + return pm.permissions.match(data) && pm.principals.match(data) +} + +// matchersFromPermissions takes a list of permissions (can also be +// a single permission, e.g. from a not matcher which is logically !permission) +// and returns a list of matchers which correspond to that permission. This will +// be called in many instances throughout the initial construction of the RBAC +// engine from the AND and OR matchers and also from the NOT matcher. 
+func matchersFromPermissions(permissions []*v3rbacpb.Permission) ([]matcher, error) { + var matchers []matcher + for _, permission := range permissions { + switch permission.GetRule().(type) { + case *v3rbacpb.Permission_AndRules: + mList, err := matchersFromPermissions(permission.GetAndRules().Rules) + if err != nil { + return nil, err + } + matchers = append(matchers, &andMatcher{matchers: mList}) + case *v3rbacpb.Permission_OrRules: + mList, err := matchersFromPermissions(permission.GetOrRules().Rules) + if err != nil { + return nil, err + } + matchers = append(matchers, &orMatcher{matchers: mList}) + case *v3rbacpb.Permission_Any: + matchers = append(matchers, &alwaysMatcher{}) + case *v3rbacpb.Permission_Header: + m, err := newHeaderMatcher(permission.GetHeader()) + if err != nil { + return nil, err + } + matchers = append(matchers, m) + case *v3rbacpb.Permission_UrlPath: + m, err := newURLPathMatcher(permission.GetUrlPath()) + if err != nil { + return nil, err + } + matchers = append(matchers, m) + case *v3rbacpb.Permission_DestinationIp: + // Due to this being on server side, the destination IP is the local + // IP. + m, err := newLocalIPMatcher(permission.GetDestinationIp()) + if err != nil { + return nil, err + } + matchers = append(matchers, m) + case *v3rbacpb.Permission_DestinationPort: + matchers = append(matchers, newPortMatcher(permission.GetDestinationPort())) + case *v3rbacpb.Permission_NotRule: + mList, err := matchersFromPermissions([]*v3rbacpb.Permission{{Rule: permission.GetNotRule().Rule}}) + if err != nil { + return nil, err + } + matchers = append(matchers, ¬Matcher{matcherToNot: mList[0]}) + case *v3rbacpb.Permission_Metadata: + // Not supported in gRPC RBAC currently - a permission typed as + // Metadata in the initial config will be a no-op. + case *v3rbacpb.Permission_RequestedServerName: + // Not supported in gRPC RBAC currently - a permission typed as + // requested server name in the initial config will be a no-op. 
+ } + } + return matchers, nil +} + +func matchersFromPrincipals(principals []*v3rbacpb.Principal) ([]matcher, error) { + var matchers []matcher + for _, principal := range principals { + switch principal.GetIdentifier().(type) { + case *v3rbacpb.Principal_AndIds: + mList, err := matchersFromPrincipals(principal.GetAndIds().Ids) + if err != nil { + return nil, err + } + matchers = append(matchers, &andMatcher{matchers: mList}) + case *v3rbacpb.Principal_OrIds: + mList, err := matchersFromPrincipals(principal.GetOrIds().Ids) + if err != nil { + return nil, err + } + matchers = append(matchers, &orMatcher{matchers: mList}) + case *v3rbacpb.Principal_Any: + matchers = append(matchers, &alwaysMatcher{}) + case *v3rbacpb.Principal_Authenticated_: + authenticatedMatcher, err := newAuthenticatedMatcher(principal.GetAuthenticated()) + if err != nil { + return nil, err + } + matchers = append(matchers, authenticatedMatcher) + case *v3rbacpb.Principal_DirectRemoteIp: + m, err := newRemoteIPMatcher(principal.GetDirectRemoteIp()) + if err != nil { + return nil, err + } + matchers = append(matchers, m) + case *v3rbacpb.Principal_Header: + // Do we need an error here? + m, err := newHeaderMatcher(principal.GetHeader()) + if err != nil { + return nil, err + } + matchers = append(matchers, m) + case *v3rbacpb.Principal_UrlPath: + m, err := newURLPathMatcher(principal.GetUrlPath()) + if err != nil { + return nil, err + } + matchers = append(matchers, m) + case *v3rbacpb.Principal_NotId: + mList, err := matchersFromPrincipals([]*v3rbacpb.Principal{{Identifier: principal.GetNotId().Identifier}}) + if err != nil { + return nil, err + } + matchers = append(matchers, ¬Matcher{matcherToNot: mList[0]}) + case *v3rbacpb.Principal_SourceIp: + // The source ip principal identifier is deprecated. Thus, a + // principal typed as a source ip in the identifier will be a no-op. + // The config should use DirectRemoteIp instead. 
+ case *v3rbacpb.Principal_RemoteIp: + // RBAC in gRPC treats direct_remote_ip and remote_ip as logically + // equivalent, as per A41. + m, err := newRemoteIPMatcher(principal.GetRemoteIp()) + if err != nil { + return nil, err + } + matchers = append(matchers, m) + case *v3rbacpb.Principal_Metadata: + // Not supported in gRPC RBAC currently - a principal typed as + // Metadata in the initial config will be a no-op. + } + } + return matchers, nil +} + +// orMatcher is a matcher where it successfully matches if one of it's +// children successfully match. It also logically represents a principal or +// permission, but can also be it's own entity further down the tree of +// matchers. orMatcher implements the matcher interface. +type orMatcher struct { + matchers []matcher +} + +func (om *orMatcher) match(data *rpcData) bool { + // Range through child matchers and pass in data about incoming RPC, and + // only one child matcher has to match to be logically successful. + for _, m := range om.matchers { + if m.match(data) { + return true + } + } + return false +} + +// andMatcher is a matcher that is successful if every child matcher +// matches. andMatcher implements the matcher interface. +type andMatcher struct { + matchers []matcher +} + +func (am *andMatcher) match(data *rpcData) bool { + for _, m := range am.matchers { + if !m.match(data) { + return false + } + } + return true +} + +// alwaysMatcher is a matcher that will always match. This logically +// represents an any rule for a permission or a principal. alwaysMatcher +// implements the matcher interface. +type alwaysMatcher struct { +} + +func (am *alwaysMatcher) match(data *rpcData) bool { + return true +} + +// notMatcher is a matcher that nots an underlying matcher. notMatcher +// implements the matcher interface. 
+type notMatcher struct { + matcherToNot matcher +} + +func (nm *notMatcher) match(data *rpcData) bool { + return !nm.matcherToNot.match(data) +} + +// headerMatcher is a matcher that matches on incoming HTTP Headers present +// in the incoming RPC. headerMatcher implements the matcher interface. +type headerMatcher struct { + matcher internalmatcher.HeaderMatcher +} + +func newHeaderMatcher(headerMatcherConfig *v3route_componentspb.HeaderMatcher) (*headerMatcher, error) { + var m internalmatcher.HeaderMatcher + switch headerMatcherConfig.HeaderMatchSpecifier.(type) { + case *v3route_componentspb.HeaderMatcher_ExactMatch: + m = internalmatcher.NewHeaderExactMatcher(headerMatcherConfig.Name, headerMatcherConfig.GetExactMatch()) + case *v3route_componentspb.HeaderMatcher_SafeRegexMatch: + regex, err := regexp.Compile(headerMatcherConfig.GetSafeRegexMatch().Regex) + if err != nil { + return nil, err + } + m = internalmatcher.NewHeaderRegexMatcher(headerMatcherConfig.Name, regex) + case *v3route_componentspb.HeaderMatcher_RangeMatch: + m = internalmatcher.NewHeaderRangeMatcher(headerMatcherConfig.Name, headerMatcherConfig.GetRangeMatch().Start, headerMatcherConfig.GetRangeMatch().End) + case *v3route_componentspb.HeaderMatcher_PresentMatch: + m = internalmatcher.NewHeaderPresentMatcher(headerMatcherConfig.Name, headerMatcherConfig.GetPresentMatch()) + case *v3route_componentspb.HeaderMatcher_PrefixMatch: + m = internalmatcher.NewHeaderPrefixMatcher(headerMatcherConfig.Name, headerMatcherConfig.GetPrefixMatch()) + case *v3route_componentspb.HeaderMatcher_SuffixMatch: + m = internalmatcher.NewHeaderSuffixMatcher(headerMatcherConfig.Name, headerMatcherConfig.GetSuffixMatch()) + case *v3route_componentspb.HeaderMatcher_ContainsMatch: + m = internalmatcher.NewHeaderContainsMatcher(headerMatcherConfig.Name, headerMatcherConfig.GetContainsMatch()) + default: + return nil, errors.New("unknown header matcher type") + } + if headerMatcherConfig.InvertMatch { + m = 
internalmatcher.NewInvertMatcher(m) + } + return &headerMatcher{matcher: m}, nil +} + +func (hm *headerMatcher) match(data *rpcData) bool { + return hm.matcher.Match(data.md) +} + +// urlPathMatcher matches on the URL Path of the incoming RPC. In gRPC, this +// logically maps to the full method name the RPC is calling on the server side. +// urlPathMatcher implements the matcher interface. +type urlPathMatcher struct { + stringMatcher internalmatcher.StringMatcher +} + +func newURLPathMatcher(pathMatcher *v3matcherpb.PathMatcher) (*urlPathMatcher, error) { + stringMatcher, err := internalmatcher.StringMatcherFromProto(pathMatcher.GetPath()) + if err != nil { + return nil, err + } + return &urlPathMatcher{stringMatcher: stringMatcher}, nil +} + +func (upm *urlPathMatcher) match(data *rpcData) bool { + return upm.stringMatcher.Match(data.fullMethod) +} + +// remoteIPMatcher and localIPMatcher both are matchers that match against +// a CIDR Range. Two different matchers are needed as the remote and destination +// ip addresses come from different parts of the data about incoming RPC's +// passed in. Matching a CIDR Range means to determine whether the IP Address +// falls within the CIDR Range or not. They both implement the matcher +// interface. +type remoteIPMatcher struct { + // ipNet represents the CidrRange that this matcher was configured with. + // This is what will remote and destination IP's will be matched against. + ipNet *net.IPNet +} + +func newRemoteIPMatcher(cidrRange *v3corepb.CidrRange) (*remoteIPMatcher, error) { + // Convert configuration to a cidrRangeString, as Go standard library has + // methods that parse cidr string. 
+ cidrRangeString := fmt.Sprintf("%s/%d", cidrRange.AddressPrefix, cidrRange.PrefixLen.Value) + _, ipNet, err := net.ParseCIDR(cidrRangeString) + if err != nil { + return nil, err + } + return &remoteIPMatcher{ipNet: ipNet}, nil +} + +func (sim *remoteIPMatcher) match(data *rpcData) bool { + return sim.ipNet.Contains(net.IP(net.ParseIP(data.peerInfo.Addr.String()))) +} + +type localIPMatcher struct { + ipNet *net.IPNet +} + +func newLocalIPMatcher(cidrRange *v3corepb.CidrRange) (*localIPMatcher, error) { + cidrRangeString := fmt.Sprintf("%s/%d", cidrRange.AddressPrefix, cidrRange.PrefixLen.Value) + _, ipNet, err := net.ParseCIDR(cidrRangeString) + if err != nil { + return nil, err + } + return &localIPMatcher{ipNet: ipNet}, nil +} + +func (dim *localIPMatcher) match(data *rpcData) bool { + return dim.ipNet.Contains(net.IP(net.ParseIP(data.localAddr.String()))) +} + +// portMatcher matches on whether the destination port of the RPC matches the +// destination port this matcher was instantiated with. portMatcher +// implements the matcher interface. +type portMatcher struct { + destinationPort uint32 +} + +func newPortMatcher(destinationPort uint32) *portMatcher { + return &portMatcher{destinationPort: destinationPort} +} + +func (pm *portMatcher) match(data *rpcData) bool { + return data.destinationPort == pm.destinationPort +} + +// authenticatedMatcher matches on the name of the Principal. If set, the URI +// SAN or DNS SAN in that order is used from the certificate, otherwise the +// subject field is used. If unset, it applies to any user that is +// authenticated. authenticatedMatcher implements the matcher interface. 
+type authenticatedMatcher struct { + stringMatcher *internalmatcher.StringMatcher +} + +func newAuthenticatedMatcher(authenticatedMatcherConfig *v3rbacpb.Principal_Authenticated) (*authenticatedMatcher, error) { + // Represents this line in the RBAC documentation = "If unset, it applies to + // any user that is authenticated" (see package-level comments). + if authenticatedMatcherConfig.PrincipalName == nil { + return &authenticatedMatcher{}, nil + } + stringMatcher, err := internalmatcher.StringMatcherFromProto(authenticatedMatcherConfig.PrincipalName) + if err != nil { + return nil, err + } + return &authenticatedMatcher{stringMatcher: &stringMatcher}, nil +} + +func (am *authenticatedMatcher) match(data *rpcData) bool { + // Represents this line in the RBAC documentation = "If unset, it applies to + // any user that is authenticated" (see package-level comments). An + // authenticated downstream in a stateful TLS connection will have to + // provide a certificate to prove their identity. Thus, you can simply check + // if there is a certificate present. + if am.stringMatcher == nil { + return len(data.certs) != 0 + } + // "If there is no client certificate (thus no SAN nor Subject), check if "" + // (empty string) matches. If it matches, the principal_name is said to + // match" - A41 + if len(data.certs) == 0 { + return am.stringMatcher.Match("") + } + cert := data.certs[0] + // The order of matching as per the RBAC documentation (see package-level comments) + // is as follows: URI SANs, DNS SANs, and then subject name. 
+ for _, uriSAN := range cert.URIs { + if am.stringMatcher.Match(uriSAN.String()) { + return true + } + } + for _, dnsSAN := range cert.DNSNames { + if am.stringMatcher.Match(dnsSAN) { + return true + } + } + return am.stringMatcher.Match(cert.Subject.String()) +} diff --git a/internal/xds/rbac/rbac_engine.go b/internal/xds/rbac/rbac_engine.go new file mode 100644 index 00000000000..a25f9cfdeef --- /dev/null +++ b/internal/xds/rbac/rbac_engine.go @@ -0,0 +1,225 @@ +/* + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package rbac provides service-level and method-level access control for a +// service. See +// https://www.envoyproxy.io/docs/envoy/latest/api-v3/config/rbac/v3/rbac.proto#role-based-access-control-rbac +// for documentation. +package rbac + +import ( + "context" + "crypto/x509" + "errors" + "fmt" + "net" + "strconv" + + v3rbacpb "github.com/envoyproxy/go-control-plane/envoy/config/rbac/v3" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/grpclog" + "google.golang.org/grpc/internal/transport" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/peer" + "google.golang.org/grpc/status" +) + +const logLevel = 2 + +var logger = grpclog.Component("rbac") + +var getConnection = transport.GetConnection + +// ChainEngine represents a chain of RBAC Engines, used to make authorization +// decisions on incoming RPCs. 
+type ChainEngine struct { + chainedEngines []*engine +} + +// NewChainEngine returns a chain of RBAC engines, used to make authorization +// decisions on incoming RPCs. Returns a non-nil error for invalid policies. +func NewChainEngine(policies []*v3rbacpb.RBAC) (*ChainEngine, error) { + engines := make([]*engine, 0, len(policies)) + for _, policy := range policies { + engine, err := newEngine(policy) + if err != nil { + return nil, err + } + engines = append(engines, engine) + } + return &ChainEngine{chainedEngines: engines}, nil +} + +// IsAuthorized determines if an incoming RPC is authorized based on the chain of RBAC +// engines and their associated actions. +// +// Errors returned by this function are compatible with the status package. +func (cre *ChainEngine) IsAuthorized(ctx context.Context) error { + // This conversion step (i.e. pulling things out of ctx) can be done once, + // and then be used for the whole chain of RBAC Engines. + rpcData, err := newRPCData(ctx) + if err != nil { + logger.Errorf("newRPCData: %v", err) + return status.Errorf(codes.Internal, "gRPC RBAC: %v", err) + } + for _, engine := range cre.chainedEngines { + matchingPolicyName, ok := engine.findMatchingPolicy(rpcData) + if logger.V(logLevel) && ok { + logger.Infof("incoming RPC matched to policy %v in engine with action %v", matchingPolicyName, engine.action) + } + + switch { + case engine.action == v3rbacpb.RBAC_ALLOW && !ok: + return status.Errorf(codes.PermissionDenied, "incoming RPC did not match an allow policy") + case engine.action == v3rbacpb.RBAC_DENY && ok: + return status.Errorf(codes.PermissionDenied, "incoming RPC matched a deny policy %q", matchingPolicyName) + } + // Every policy in the engine list must be queried. Thus, iterate to the + // next policy. + } + // If the incoming RPC gets through all of the engines successfully (i.e. + // doesn't not match an allow or match a deny engine), the RPC is authorized + // to proceed. 
+ return nil +} + +// engine is used for matching incoming RPCs to policies. +type engine struct { + policies map[string]*policyMatcher + // action must be ALLOW or DENY. + action v3rbacpb.RBAC_Action +} + +// newEngine creates an RBAC Engine based on the contents of policy. Returns a +// non-nil error if the policy is invalid. +func newEngine(config *v3rbacpb.RBAC) (*engine, error) { + a := *config.Action.Enum() + if a != v3rbacpb.RBAC_ALLOW && a != v3rbacpb.RBAC_DENY { + return nil, fmt.Errorf("unsupported action %s", config.Action) + } + + policies := make(map[string]*policyMatcher, len(config.Policies)) + for name, policy := range config.Policies { + matcher, err := newPolicyMatcher(policy) + if err != nil { + return nil, err + } + policies[name] = matcher + } + return &engine{ + policies: policies, + action: a, + }, nil +} + +// findMatchingPolicy determines if an incoming RPC matches a policy. On a +// successful match, it returns the name of the matching policy and a true bool +// to specify that there was a matching policy found. It returns false in +// the case of not finding a matching policy. +func (r *engine) findMatchingPolicy(rpcData *rpcData) (string, bool) { + for policy, matcher := range r.policies { + if matcher.match(rpcData) { + return policy, true + } + } + return "", false +} + +// newRPCData takes an incoming context (should be a context representing state +// needed for server RPC Call with metadata, peer info (used for source ip/port +// and TLS information) and connection (used for destination ip/port) piped into +// it) and the method name of the Service being called server side and populates +// an rpcData struct ready to be passed to the RBAC Engine to find a matching +// policy. +func newRPCData(ctx context.Context) (*rpcData, error) { + // The caller should populate all of these fields (i.e. for empty headers, + // pipe an empty md into context). 
+ md, ok := metadata.FromIncomingContext(ctx) + if !ok { + return nil, errors.New("missing metadata in incoming context") + } + // ":method can be hard-coded to POST if unavailable" - A41 + md[":method"] = []string{"POST"} + // "If the transport exposes TE in Metadata, then RBAC must special-case the + // header to treat it as not present." - A41 + delete(md, "TE") + + pi, ok := peer.FromContext(ctx) + if !ok { + return nil, errors.New("missing peer info in incoming context") + } + + // The methodName will be available in the passed in ctx from a unary or streaming + // interceptor, as grpc.Server pipes in a transport stream which contains the methodName + // into contexts available in both unary or streaming interceptors. + mn, ok := grpc.Method(ctx) + if !ok { + return nil, errors.New("missing method in incoming context") + } + + // The connection is needed in order to find the destination address and + // port of the incoming RPC Call. + conn := getConnection(ctx) + if conn == nil { + return nil, errors.New("missing connection in incoming context") + } + _, dPort, err := net.SplitHostPort(conn.LocalAddr().String()) + if err != nil { + return nil, fmt.Errorf("error parsing local address: %v", err) + } + dp, err := strconv.ParseUint(dPort, 10, 32) + if err != nil { + return nil, fmt.Errorf("error parsing local address: %v", err) + } + + var peerCertificates []*x509.Certificate + if pi.AuthInfo != nil { + tlsInfo, ok := pi.AuthInfo.(credentials.TLSInfo) + if ok { + peerCertificates = tlsInfo.State.PeerCertificates + } + } + + return &rpcData{ + md: md, + peerInfo: pi, + fullMethod: mn, + destinationPort: uint32(dp), + localAddr: conn.LocalAddr(), + certs: peerCertificates, + }, nil +} + +// rpcData wraps data pulled from an incoming RPC that the RBAC engine needs to +// find a matching policy. +type rpcData struct { + // md is the HTTP Headers that are present in the incoming RPC. + md metadata.MD + // peerInfo is information about the downstream peer. 
+ peerInfo *peer.Peer + // fullMethod is the method name being called on the upstream service. + fullMethod string + // destinationPort is the port that the RPC is being sent to on the + // server. + destinationPort uint32 + // localAddr is the address that the RPC is being sent to. + localAddr net.Addr + // certs are the certificates presented by the peer during a TLS + // handshake. + certs []*x509.Certificate +} diff --git a/internal/xds/rbac/rbac_engine_test.go b/internal/xds/rbac/rbac_engine_test.go new file mode 100644 index 00000000000..17832458209 --- /dev/null +++ b/internal/xds/rbac/rbac_engine_test.go @@ -0,0 +1,1007 @@ +/* + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package rbac + +import ( + "context" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "net" + "net/url" + "testing" + + v3corepb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" + v3rbacpb "github.com/envoyproxy/go-control-plane/envoy/config/rbac/v3" + v3routepb "github.com/envoyproxy/go-control-plane/envoy/config/route/v3" + v3matcherpb "github.com/envoyproxy/go-control-plane/envoy/type/matcher/v3" + v3typepb "github.com/envoyproxy/go-control-plane/envoy/type/v3" + wrapperspb "github.com/golang/protobuf/ptypes/wrappers" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/internal/grpctest" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/peer" + "google.golang.org/grpc/status" +) + +type s struct { + grpctest.Tester +} + +func Test(t *testing.T) { + grpctest.RunSubTests(t, s{}) +} + +type addr struct { + ipAddress string +} + +func (addr) Network() string { return "" } +func (a *addr) String() string { return a.ipAddress } + +// TestNewChainEngine tests the construction of the ChainEngine. Due to some +// types of RBAC configuration being logically wrong and returning an error +// rather than successfully constructing the RBAC Engine, this test tests both +// RBAC Configurations deemed successful and also RBAC Configurations that will +// raise errors. 
+func (s) TestNewChainEngine(t *testing.T) { + tests := []struct { + name string + policies []*v3rbacpb.RBAC + wantErr bool + }{ + { + name: "SuccessCaseAnyMatchSingular", + policies: []*v3rbacpb.RBAC{ + { + Action: v3rbacpb.RBAC_ALLOW, + Policies: map[string]*v3rbacpb.Policy{ + "anyone": { + Permissions: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_Any{Any: true}}, + }, + Principals: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_Any{Any: true}}, + }, + }, + }, + }, + }, + }, + { + name: "SuccessCaseAnyMatchMultiple", + policies: []*v3rbacpb.RBAC{ + { + Action: v3rbacpb.RBAC_ALLOW, + Policies: map[string]*v3rbacpb.Policy{ + "anyone": { + Permissions: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_Any{Any: true}}, + }, + Principals: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_Any{Any: true}}, + }, + }, + }, + }, + { + Action: v3rbacpb.RBAC_DENY, + Policies: map[string]*v3rbacpb.Policy{ + "anyone": { + Permissions: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_Any{Any: true}}, + }, + Principals: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_Any{Any: true}}, + }, + }, + }, + }, + }, + }, + { + name: "SuccessCaseSimplePolicySingular", + policies: []*v3rbacpb.RBAC{ + { + Action: v3rbacpb.RBAC_ALLOW, + Policies: map[string]*v3rbacpb.Policy{ + "localhost-fan": { + Permissions: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_DestinationPort{DestinationPort: 8080}}, + {Rule: &v3rbacpb.Permission_UrlPath{UrlPath: &v3matcherpb.PathMatcher{Rule: &v3matcherpb.PathMatcher_Path{Path: &v3matcherpb.StringMatcher{MatchPattern: &v3matcherpb.StringMatcher_Exact{Exact: "localhost-fan-page"}}}}}}, + }, + Principals: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_Any{Any: true}}, + }, + }, + }, + }, + }, + }, + // SuccessCaseSimplePolicyTwoPolicies tests the construction of the + // chained engines in the case where there are two policies in a list, + // one with an allow policy and one with a deny policy. 
A situation + // where two policies (allow and deny) is a very common use case for + // this API, and should successfully build. + { + name: "SuccessCaseSimplePolicyTwoPolicies", + policies: []*v3rbacpb.RBAC{ + { + Action: v3rbacpb.RBAC_ALLOW, + Policies: map[string]*v3rbacpb.Policy{ + "localhost-fan": { + Permissions: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_DestinationPort{DestinationPort: 8080}}, + {Rule: &v3rbacpb.Permission_UrlPath{UrlPath: &v3matcherpb.PathMatcher{Rule: &v3matcherpb.PathMatcher_Path{Path: &v3matcherpb.StringMatcher{MatchPattern: &v3matcherpb.StringMatcher_Exact{Exact: "localhost-fan-page"}}}}}}, + }, + Principals: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_Any{Any: true}}, + }, + }, + }, + }, + { + Action: v3rbacpb.RBAC_DENY, + Policies: map[string]*v3rbacpb.Policy{ + "localhost-fan": { + Permissions: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_DestinationPort{DestinationPort: 8080}}, + {Rule: &v3rbacpb.Permission_UrlPath{UrlPath: &v3matcherpb.PathMatcher{Rule: &v3matcherpb.PathMatcher_Path{Path: &v3matcherpb.StringMatcher{MatchPattern: &v3matcherpb.StringMatcher_Exact{Exact: "localhost-fan-page"}}}}}}, + }, + Principals: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_Any{Any: true}}, + }, + }, + }, + }, + }, + }, + { + name: "SuccessCaseEnvoyExampleSingular", + policies: []*v3rbacpb.RBAC{ + { + Action: v3rbacpb.RBAC_ALLOW, + Policies: map[string]*v3rbacpb.Policy{ + "service-admin": { + Permissions: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_Any{Any: true}}, + }, + Principals: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_Authenticated_{Authenticated: &v3rbacpb.Principal_Authenticated{PrincipalName: &v3matcherpb.StringMatcher{MatchPattern: &v3matcherpb.StringMatcher_Exact{Exact: "cluster.local/ns/default/sa/admin"}}}}}, + {Identifier: &v3rbacpb.Principal_Authenticated_{Authenticated: &v3rbacpb.Principal_Authenticated{PrincipalName: &v3matcherpb.StringMatcher{MatchPattern: 
&v3matcherpb.StringMatcher_Exact{Exact: "cluster.local/ns/default/sa/superuser"}}}}}, + }, + }, + "product-viewer": { + Permissions: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_AndRules{AndRules: &v3rbacpb.Permission_Set{ + Rules: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_Header{Header: &v3routepb.HeaderMatcher{Name: ":method", HeaderMatchSpecifier: &v3routepb.HeaderMatcher_ExactMatch{ExactMatch: "GET"}}}}, + {Rule: &v3rbacpb.Permission_UrlPath{UrlPath: &v3matcherpb.PathMatcher{Rule: &v3matcherpb.PathMatcher_Path{Path: &v3matcherpb.StringMatcher{MatchPattern: &v3matcherpb.StringMatcher_Prefix{Prefix: "/products"}}}}}}, + {Rule: &v3rbacpb.Permission_OrRules{OrRules: &v3rbacpb.Permission_Set{ + Rules: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_DestinationPort{DestinationPort: 80}}, + {Rule: &v3rbacpb.Permission_DestinationPort{DestinationPort: 443}}, + }, + }, + }, + }, + }, + }, + }, + }, + }, + Principals: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_Any{Any: true}}, + }, + }, + }, + }, + }, + }, + { + name: "SourceIpMatcherSuccessSingular", + policies: []*v3rbacpb.RBAC{ + { + Action: v3rbacpb.RBAC_ALLOW, + Policies: map[string]*v3rbacpb.Policy{ + "certain-source-ip": { + Permissions: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_Any{Any: true}}, + }, + Principals: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_DirectRemoteIp{DirectRemoteIp: &v3corepb.CidrRange{AddressPrefix: "0.0.0.0", PrefixLen: &wrapperspb.UInt32Value{Value: uint32(10)}}}}, + }, + }, + }, + }, + }, + }, + { + name: "SourceIpMatcherFailureSingular", + policies: []*v3rbacpb.RBAC{ + { + Action: v3rbacpb.RBAC_ALLOW, + Policies: map[string]*v3rbacpb.Policy{ + "certain-source-ip": { + Permissions: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_Any{Any: true}}, + }, + Principals: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_DirectRemoteIp{DirectRemoteIp: &v3corepb.CidrRange{AddressPrefix: "not a correct address", 
PrefixLen: &wrapperspb.UInt32Value{Value: uint32(10)}}}}, + }, + }, + }, + }, + }, + wantErr: true, + }, + { + name: "DestinationIpMatcherSuccess", + policies: []*v3rbacpb.RBAC{ + { + Action: v3rbacpb.RBAC_ALLOW, + Policies: map[string]*v3rbacpb.Policy{ + "certain-destination-ip": { + Permissions: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_DestinationIp{DestinationIp: &v3corepb.CidrRange{AddressPrefix: "0.0.0.0", PrefixLen: &wrapperspb.UInt32Value{Value: uint32(10)}}}}, + }, + Principals: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_Any{Any: true}}, + }, + }, + }, + }, + }, + }, + { + name: "DestinationIpMatcherFailure", + policies: []*v3rbacpb.RBAC{ + { + Action: v3rbacpb.RBAC_ALLOW, + Policies: map[string]*v3rbacpb.Policy{ + "certain-destination-ip": { + Permissions: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_DestinationIp{DestinationIp: &v3corepb.CidrRange{AddressPrefix: "not a correct address", PrefixLen: &wrapperspb.UInt32Value{Value: uint32(10)}}}}, + }, + Principals: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_Any{Any: true}}, + }, + }, + }, + }, + }, + wantErr: true, + }, + { + name: "MatcherToNotPolicy", + policies: []*v3rbacpb.RBAC{ + { + Action: v3rbacpb.RBAC_ALLOW, + Policies: map[string]*v3rbacpb.Policy{ + "not-secret-content": { + Permissions: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_NotRule{NotRule: &v3rbacpb.Permission{Rule: &v3rbacpb.Permission_UrlPath{UrlPath: &v3matcherpb.PathMatcher{Rule: &v3matcherpb.PathMatcher_Path{Path: &v3matcherpb.StringMatcher{MatchPattern: &v3matcherpb.StringMatcher_Prefix{Prefix: "/secret-content"}}}}}}}}, + }, + Principals: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_Any{Any: true}}, + }, + }, + }, + }, + }, + }, + { + name: "MatcherToNotPrinicipal", + policies: []*v3rbacpb.RBAC{ + { + Action: v3rbacpb.RBAC_ALLOW, + Policies: map[string]*v3rbacpb.Policy{ + "not-from-certain-ip": { + Permissions: []*v3rbacpb.Permission{ + {Rule: 
&v3rbacpb.Permission_Any{Any: true}}, + }, + Principals: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_NotId{NotId: &v3rbacpb.Principal{Identifier: &v3rbacpb.Principal_DirectRemoteIp{DirectRemoteIp: &v3corepb.CidrRange{AddressPrefix: "0.0.0.0", PrefixLen: &wrapperspb.UInt32Value{Value: uint32(10)}}}}}}, + }, + }, + }, + }, + }, + }, + // PrinicpalProductViewer tests the construction of a chained engine + // with a policy that allows any downstream to send a GET request on a + // certain path. + { + name: "PrincipalProductViewer", + policies: []*v3rbacpb.RBAC{ + { + Action: v3rbacpb.RBAC_ALLOW, + Policies: map[string]*v3rbacpb.Policy{ + "product-viewer": { + Permissions: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_Any{Any: true}}, + }, + Principals: []*v3rbacpb.Principal{ + { + Identifier: &v3rbacpb.Principal_AndIds{AndIds: &v3rbacpb.Principal_Set{Ids: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_Header{Header: &v3routepb.HeaderMatcher{Name: ":method", HeaderMatchSpecifier: &v3routepb.HeaderMatcher_ExactMatch{ExactMatch: "GET"}}}}, + {Identifier: &v3rbacpb.Principal_OrIds{OrIds: &v3rbacpb.Principal_Set{ + Ids: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_UrlPath{UrlPath: &v3matcherpb.PathMatcher{Rule: &v3matcherpb.PathMatcher_Path{Path: &v3matcherpb.StringMatcher{MatchPattern: &v3matcherpb.StringMatcher_Prefix{Prefix: "/books"}}}}}}, + {Identifier: &v3rbacpb.Principal_UrlPath{UrlPath: &v3matcherpb.PathMatcher{Rule: &v3matcherpb.PathMatcher_Path{Path: &v3matcherpb.StringMatcher{MatchPattern: &v3matcherpb.StringMatcher_Prefix{Prefix: "/cars"}}}}}}, + }, + }}}, + }}}, + }, + }, + }, + }, + }, + }, + }, + // Certain Headers tests the construction of a chained engine with a + // policy that allows any downstream to send an HTTP request with + // certain headers. 
+ { + name: "CertainHeaders", + policies: []*v3rbacpb.RBAC{ + { + Policies: map[string]*v3rbacpb.Policy{ + "certain-headers": { + Permissions: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_Any{Any: true}}, + }, + Principals: []*v3rbacpb.Principal{ + { + Identifier: &v3rbacpb.Principal_OrIds{OrIds: &v3rbacpb.Principal_Set{Ids: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_Header{Header: &v3routepb.HeaderMatcher{Name: ":method", HeaderMatchSpecifier: &v3routepb.HeaderMatcher_ExactMatch{ExactMatch: "GET"}}}}, + {Identifier: &v3rbacpb.Principal_Header{Header: &v3routepb.HeaderMatcher{Name: ":method", HeaderMatchSpecifier: &v3routepb.HeaderMatcher_SafeRegexMatch{SafeRegexMatch: &v3matcherpb.RegexMatcher{Regex: "GET"}}}}}, + {Identifier: &v3rbacpb.Principal_Header{Header: &v3routepb.HeaderMatcher{Name: ":method", HeaderMatchSpecifier: &v3routepb.HeaderMatcher_RangeMatch{RangeMatch: &v3typepb.Int64Range{ + Start: 0, + End: 64, + }}}}}, + {Identifier: &v3rbacpb.Principal_Header{Header: &v3routepb.HeaderMatcher{Name: ":method", HeaderMatchSpecifier: &v3routepb.HeaderMatcher_PresentMatch{PresentMatch: true}}}}, + {Identifier: &v3rbacpb.Principal_Header{Header: &v3routepb.HeaderMatcher{Name: ":method", HeaderMatchSpecifier: &v3routepb.HeaderMatcher_PrefixMatch{PrefixMatch: "GET"}}}}, + {Identifier: &v3rbacpb.Principal_Header{Header: &v3routepb.HeaderMatcher{Name: ":method", HeaderMatchSpecifier: &v3routepb.HeaderMatcher_SuffixMatch{SuffixMatch: "GET"}}}}, + {Identifier: &v3rbacpb.Principal_Header{Header: &v3routepb.HeaderMatcher{Name: ":method", HeaderMatchSpecifier: &v3routepb.HeaderMatcher_ContainsMatch{ContainsMatch: "GET"}}}}, + }}}, + }, + }, + }, + }, + }, + }, + }, + { + name: "LogAction", + policies: []*v3rbacpb.RBAC{ + { + Action: v3rbacpb.RBAC_LOG, + Policies: map[string]*v3rbacpb.Policy{ + "anyone": { + Permissions: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_Any{Any: true}}, + }, + Principals: []*v3rbacpb.Principal{ + {Identifier: 
&v3rbacpb.Principal_Any{Any: true}}, + }, + }, + }, + }, + }, + wantErr: true, + }, + { + name: "ActionNotSpecified", + policies: []*v3rbacpb.RBAC{ + { + Policies: map[string]*v3rbacpb.Policy{ + "anyone": { + Permissions: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_Any{Any: true}}, + }, + Principals: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_Any{Any: true}}, + }, + }, + }, + }, + }, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if _, err := NewChainEngine(test.policies); (err != nil) != test.wantErr { + t.Fatalf("NewChainEngine(%+v) returned err: %v, wantErr: %v", test.policies, err, test.wantErr) + } + }) + } +} + +// TestChainEngine tests the chain of RBAC Engines by configuring the chain of +// engines in a certain way in different scenarios. After configuring the chain +// of engines in a certain way, this test pings the chain of engines with +// different types of data representing incoming RPC's (piped into a context), +// and verifies that it works as expected. +func (s) TestChainEngine(t *testing.T) { + defer func(gc func(ctx context.Context) net.Conn) { + getConnection = gc + }(getConnection) + tests := []struct { + name string + rbacConfigs []*v3rbacpb.RBAC + rbacQueries []struct { + rpcData *rpcData + wantStatusCode codes.Code + } + }{ + // SuccessCaseAnyMatch tests a single RBAC Engine instantiated with + // a config with a policy with any rules for both permissions and + // principals, meaning that any data about incoming RPC's that the RBAC + // Engine is queried with should match that policy. 
+ { + name: "SuccessCaseAnyMatch", + rbacConfigs: []*v3rbacpb.RBAC{ + { + Policies: map[string]*v3rbacpb.Policy{ + "anyone": { + Permissions: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_Any{Any: true}}, + }, + Principals: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_Any{Any: true}}, + }, + }, + }, + }, + }, + rbacQueries: []struct { + rpcData *rpcData + wantStatusCode codes.Code + }{ + { + rpcData: &rpcData{ + fullMethod: "some method", + peerInfo: &peer.Peer{ + Addr: &addr{ipAddress: "0.0.0.0"}, + }, + }, + wantStatusCode: codes.OK, + }, + }, + }, + // SuccessCaseSimplePolicy is a test that tests a single policy + // that only allows an rpc to proceed if the rpc is calling with a certain + // path. + { + name: "SuccessCaseSimplePolicy", + rbacConfigs: []*v3rbacpb.RBAC{ + { + Policies: map[string]*v3rbacpb.Policy{ + "localhost-fan": { + Permissions: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_UrlPath{UrlPath: &v3matcherpb.PathMatcher{Rule: &v3matcherpb.PathMatcher_Path{Path: &v3matcherpb.StringMatcher{MatchPattern: &v3matcherpb.StringMatcher_Exact{Exact: "localhost-fan-page"}}}}}}, + }, + Principals: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_Any{Any: true}}, + }, + }, + }, + }, + }, + rbacQueries: []struct { + rpcData *rpcData + wantStatusCode codes.Code + }{ + // This RPC should match with the local host fan policy. Thus, + // this RPC should be allowed to proceed. + { + rpcData: &rpcData{ + fullMethod: "localhost-fan-page", + peerInfo: &peer.Peer{ + Addr: &addr{ipAddress: "0.0.0.0"}, + }, + }, + wantStatusCode: codes.OK, + }, + + // This RPC shouldn't match with the local host fan policy. Thus, + // this rpc shouldn't be allowed to proceed. + { + rpcData: &rpcData{ + peerInfo: &peer.Peer{ + Addr: &addr{ipAddress: "0.0.0.0"}, + }, + }, + wantStatusCode: codes.PermissionDenied, + }, + }, + }, + // SuccessCaseEnvoyExample is a test based on the example provided + // in the EnvoyProxy docs. 
The RBAC Config contains two policies, + // service admin and product viewer, that provides an example of a real + // RBAC Config that might be configured for a given for a given backend + // service. + { + name: "SuccessCaseEnvoyExample", + rbacConfigs: []*v3rbacpb.RBAC{ + { + Policies: map[string]*v3rbacpb.Policy{ + "service-admin": { + Permissions: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_Any{Any: true}}, + }, + Principals: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_Authenticated_{Authenticated: &v3rbacpb.Principal_Authenticated{PrincipalName: &v3matcherpb.StringMatcher{MatchPattern: &v3matcherpb.StringMatcher_Exact{Exact: "//cluster.local/ns/default/sa/admin"}}}}}, + {Identifier: &v3rbacpb.Principal_Authenticated_{Authenticated: &v3rbacpb.Principal_Authenticated{PrincipalName: &v3matcherpb.StringMatcher{MatchPattern: &v3matcherpb.StringMatcher_Exact{Exact: "//cluster.local/ns/default/sa/superuser"}}}}}, + }, + }, + "product-viewer": { + Permissions: []*v3rbacpb.Permission{ + { + Rule: &v3rbacpb.Permission_AndRules{AndRules: &v3rbacpb.Permission_Set{ + Rules: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_Header{Header: &v3routepb.HeaderMatcher{Name: ":method", HeaderMatchSpecifier: &v3routepb.HeaderMatcher_ExactMatch{ExactMatch: "GET"}}}}, + {Rule: &v3rbacpb.Permission_UrlPath{UrlPath: &v3matcherpb.PathMatcher{Rule: &v3matcherpb.PathMatcher_Path{Path: &v3matcherpb.StringMatcher{MatchPattern: &v3matcherpb.StringMatcher_Prefix{Prefix: "/products"}}}}}}, + }, + }, + }, + }, + }, + Principals: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_Any{Any: true}}, + }, + }, + }, + }, + }, + rbacQueries: []struct { + rpcData *rpcData + wantStatusCode codes.Code + }{ + // This incoming RPC Call should match with the service admin + // policy. 
+ { + rpcData: &rpcData{ + fullMethod: "some method", + peerInfo: &peer.Peer{ + Addr: &addr{ipAddress: "0.0.0.0"}, + AuthInfo: credentials.TLSInfo{ + State: tls.ConnectionState{ + PeerCertificates: []*x509.Certificate{ + { + URIs: []*url.URL{ + { + Host: "cluster.local", + Path: "/ns/default/sa/admin", + }, + }, + }, + }, + }, + }, + }, + }, + wantStatusCode: codes.OK, + }, + // These incoming RPC calls should not match any policy. + { + rpcData: &rpcData{ + peerInfo: &peer.Peer{ + Addr: &addr{ipAddress: "0.0.0.0"}, + }, + }, + wantStatusCode: codes.PermissionDenied, + }, + { + rpcData: &rpcData{ + fullMethod: "get-product-list", + peerInfo: &peer.Peer{ + Addr: &addr{ipAddress: "0.0.0.0"}, + }, + }, + wantStatusCode: codes.PermissionDenied, + }, + { + rpcData: &rpcData{ + peerInfo: &peer.Peer{ + Addr: &addr{ipAddress: "0.0.0.0"}, + AuthInfo: credentials.TLSInfo{ + State: tls.ConnectionState{ + PeerCertificates: []*x509.Certificate{ + { + Subject: pkix.Name{ + CommonName: "localhost", + }, + }, + }, + }, + }, + }, + }, + wantStatusCode: codes.PermissionDenied, + }, + }, + }, + { + name: "NotMatcher", + rbacConfigs: []*v3rbacpb.RBAC{ + { + Policies: map[string]*v3rbacpb.Policy{ + "not-secret-content": { + Permissions: []*v3rbacpb.Permission{ + { + Rule: &v3rbacpb.Permission_NotRule{ + NotRule: &v3rbacpb.Permission{Rule: &v3rbacpb.Permission_UrlPath{UrlPath: &v3matcherpb.PathMatcher{Rule: &v3matcherpb.PathMatcher_Path{Path: &v3matcherpb.StringMatcher{MatchPattern: &v3matcherpb.StringMatcher_Prefix{Prefix: "/secret-content"}}}}}}, + }, + }, + }, + Principals: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_Any{Any: true}}, + }, + }, + }, + }, + }, + rbacQueries: []struct { + rpcData *rpcData + wantStatusCode codes.Code + }{ + // This incoming RPC Call should match with the not-secret-content policy. 
+ { + rpcData: &rpcData{ + fullMethod: "/regular-content", + peerInfo: &peer.Peer{ + Addr: &addr{ipAddress: "0.0.0.0"}, + }, + }, + wantStatusCode: codes.OK, + }, + // This incoming RPC Call shouldn't match with the not-secret-content-policy. + { + rpcData: &rpcData{ + fullMethod: "/secret-content", + peerInfo: &peer.Peer{ + Addr: &addr{ipAddress: "0.0.0.0"}, + }, + }, + wantStatusCode: codes.PermissionDenied, + }, + }, + }, + { + name: "DirectRemoteIpMatcher", + rbacConfigs: []*v3rbacpb.RBAC{ + { + Policies: map[string]*v3rbacpb.Policy{ + "certain-direct-remote-ip": { + Permissions: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_Any{Any: true}}, + }, + Principals: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_DirectRemoteIp{DirectRemoteIp: &v3corepb.CidrRange{AddressPrefix: "0.0.0.0", PrefixLen: &wrapperspb.UInt32Value{Value: uint32(10)}}}}, + }, + }, + }, + }, + }, + rbacQueries: []struct { + rpcData *rpcData + wantStatusCode codes.Code + }{ + // This incoming RPC Call should match with the certain-direct-remote-ip policy. + { + rpcData: &rpcData{ + peerInfo: &peer.Peer{ + Addr: &addr{ipAddress: "0.0.0.0"}, + }, + }, + wantStatusCode: codes.OK, + }, + // This incoming RPC Call shouldn't match with the certain-direct-remote-ip policy. + { + rpcData: &rpcData{ + peerInfo: &peer.Peer{ + Addr: &addr{ipAddress: "10.0.0.0"}, + }, + }, + wantStatusCode: codes.PermissionDenied, + }, + }, + }, + // This test tests a RBAC policy configured with a remote-ip policy. + // This should be logically equivalent to configuring a Engine with a + // direct-remote-ip policy, as per A41 - "allow equating RBAC's + // direct_remote_ip and remote_ip." 
+ { + name: "RemoteIpMatcher", + rbacConfigs: []*v3rbacpb.RBAC{ + { + Policies: map[string]*v3rbacpb.Policy{ + "certain-remote-ip": { + Permissions: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_Any{Any: true}}, + }, + Principals: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_RemoteIp{RemoteIp: &v3corepb.CidrRange{AddressPrefix: "0.0.0.0", PrefixLen: &wrapperspb.UInt32Value{Value: uint32(10)}}}}, + }, + }, + }, + }, + }, + rbacQueries: []struct { + rpcData *rpcData + wantStatusCode codes.Code + }{ + // This incoming RPC Call should match with the certain-remote-ip policy. + { + rpcData: &rpcData{ + peerInfo: &peer.Peer{ + Addr: &addr{ipAddress: "0.0.0.0"}, + }, + }, + wantStatusCode: codes.OK, + }, + // This incoming RPC Call shouldn't match with the certain-remote-ip policy. + { + rpcData: &rpcData{ + peerInfo: &peer.Peer{ + Addr: &addr{ipAddress: "10.0.0.0"}, + }, + }, + wantStatusCode: codes.PermissionDenied, + }, + }, + }, + { + name: "DestinationIpMatcher", + rbacConfigs: []*v3rbacpb.RBAC{ + { + Policies: map[string]*v3rbacpb.Policy{ + "certain-destination-ip": { + Permissions: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_DestinationIp{DestinationIp: &v3corepb.CidrRange{AddressPrefix: "0.0.0.0", PrefixLen: &wrapperspb.UInt32Value{Value: uint32(10)}}}}, + }, + Principals: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_Any{Any: true}}, + }, + }, + }, + }, + }, + rbacQueries: []struct { + rpcData *rpcData + wantStatusCode codes.Code + }{ + // This incoming RPC Call shouldn't match with the + // certain-destination-ip policy, as the test listens on local + // host. + { + rpcData: &rpcData{ + peerInfo: &peer.Peer{ + Addr: &addr{ipAddress: "10.0.0.0"}, + }, + }, + wantStatusCode: codes.PermissionDenied, + }, + }, + }, + // AllowAndDenyPolicy tests a policy with an allow (on path) and + // deny (on port) policy chained together. 
This represents how a user + // configured interceptor would use this, and also is a potential + // configuration for a dynamic xds interceptor. + { + name: "AllowAndDenyPolicy", + rbacConfigs: []*v3rbacpb.RBAC{ + { + Policies: map[string]*v3rbacpb.Policy{ + "certain-source-ip": { + Permissions: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_Any{Any: true}}, + }, + Principals: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_DirectRemoteIp{DirectRemoteIp: &v3corepb.CidrRange{AddressPrefix: "0.0.0.0", PrefixLen: &wrapperspb.UInt32Value{Value: uint32(10)}}}}, + }, + }, + }, + Action: v3rbacpb.RBAC_ALLOW, + }, + { + Policies: map[string]*v3rbacpb.Policy{ + "localhost-fan": { + Permissions: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_UrlPath{UrlPath: &v3matcherpb.PathMatcher{Rule: &v3matcherpb.PathMatcher_Path{Path: &v3matcherpb.StringMatcher{MatchPattern: &v3matcherpb.StringMatcher_Exact{Exact: "localhost-fan-page"}}}}}}, + }, + Principals: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_Any{Any: true}}, + }, + }, + }, + Action: v3rbacpb.RBAC_DENY, + }, + }, + rbacQueries: []struct { + rpcData *rpcData + wantStatusCode codes.Code + }{ + // This RPC should match with the allow policy, and shouldn't + // match with the deny and thus should be allowed to proceed. + { + rpcData: &rpcData{ + peerInfo: &peer.Peer{ + Addr: &addr{ipAddress: "0.0.0.0"}, + }, + }, + wantStatusCode: codes.OK, + }, + // This RPC should match with both the allow policy and deny policy + // and thus shouldn't be allowed to proceed as matched with deny. + { + rpcData: &rpcData{ + fullMethod: "localhost-fan-page", + peerInfo: &peer.Peer{ + Addr: &addr{ipAddress: "0.0.0.0"}, + }, + }, + wantStatusCode: codes.PermissionDenied, + }, + // This RPC shouldn't match with either policy, and thus + // shouldn't be allowed to proceed as didn't match with allow. 
+ { + rpcData: &rpcData{ + peerInfo: &peer.Peer{ + Addr: &addr{ipAddress: "10.0.0.0"}, + }, + }, + wantStatusCode: codes.PermissionDenied, + }, + // This RPC shouldn't match with allow, match with deny, and + // thus shouldn't be allowed to proceed. + { + rpcData: &rpcData{ + fullMethod: "localhost-fan-page", + peerInfo: &peer.Peer{ + Addr: &addr{ipAddress: "10.0.0.0"}, + }, + }, + wantStatusCode: codes.PermissionDenied, + }, + }, + }, + // This test tests that when there are no SANs or Subject's + // distinguished name in incoming RPC's, that authenticated matchers + // match against the empty string. + { + name: "default-matching-no-credentials", + rbacConfigs: []*v3rbacpb.RBAC{ + { + Policies: map[string]*v3rbacpb.Policy{ + "service-admin": { + Permissions: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_Any{Any: true}}, + }, + Principals: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_Authenticated_{Authenticated: &v3rbacpb.Principal_Authenticated{PrincipalName: &v3matcherpb.StringMatcher{MatchPattern: &v3matcherpb.StringMatcher_Exact{Exact: ""}}}}}, + }, + }, + }, + }, + }, + rbacQueries: []struct { + rpcData *rpcData + wantStatusCode codes.Code + }{ + // This incoming RPC Call should match with the service admin + // policy. No authentication info is provided, so the + // authenticated matcher should match to the string matcher on + // the empty string, matching to the service-admin policy. + { + rpcData: &rpcData{ + fullMethod: "some method", + peerInfo: &peer.Peer{ + Addr: &addr{ipAddress: "0.0.0.0"}, + }, + }, + wantStatusCode: codes.OK, + }, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + // Instantiate the chainedRBACEngine with different configurations that are + // interesting to test and to query. 
+ cre, err := NewChainEngine(test.rbacConfigs) + if err != nil { + t.Fatalf("Error constructing RBAC Engine: %v", err) + } + // Query the created chain of RBAC Engines with different args to see + // if the chain of RBAC Engines configured as such works as intended. + for _, data := range test.rbacQueries { + func() { + // Construct the context with three data points that have enough + // information to represent incoming RPC's. This will be how a + // user uses this API. A user will have to put MD, PeerInfo, and + // the connection the RPC is sent on in the context. + ctx := metadata.NewIncomingContext(context.Background(), data.rpcData.md) + + // Make a TCP connection with a certain destination port. The + // address/port of this connection will be used to populate the + // destination ip/port in RPCData struct. This represents what + // the user of ChainEngine will have to place into + // context, as this is only way to get destination ip and port. + lis, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatalf("Error listening: %v", err) + } + defer lis.Close() + connCh := make(chan net.Conn, 1) + go func() { + conn, err := lis.Accept() + if err != nil { + t.Errorf("Error accepting connection: %v", err) + return + } + connCh <- conn + }() + _, err = net.Dial("tcp", lis.Addr().String()) + if err != nil { + t.Fatalf("Error dialing: %v", err) + } + conn := <-connCh + defer conn.Close() + getConnection = func(context.Context) net.Conn { + return conn + } + ctx = peer.NewContext(ctx, data.rpcData.peerInfo) + stream := &ServerTransportStreamWithMethod{ + method: data.rpcData.fullMethod, + } + + ctx = grpc.NewContextWithServerTransportStream(ctx, stream) + err = cre.IsAuthorized(ctx) + if gotCode := status.Code(err); gotCode != data.wantStatusCode { + t.Fatalf("IsAuthorized(%+v, %+v) returned (%+v), want(%+v)", ctx, data.rpcData.fullMethod, gotCode, data.wantStatusCode) + } + }() + } + }) + } +} + +type ServerTransportStreamWithMethod struct { + method 
string +} + +func (sts *ServerTransportStreamWithMethod) Method() string { + return sts.method +} + +func (sts *ServerTransportStreamWithMethod) SetHeader(md metadata.MD) error { + return nil +} + +func (sts *ServerTransportStreamWithMethod) SendHeader(md metadata.MD) error { + return nil +} + +func (sts *ServerTransportStreamWithMethod) SetTrailer(md metadata.MD) error { + return nil +} diff --git a/internal/xds_handshake_cluster.go b/internal/xds_handshake_cluster.go new file mode 100644 index 00000000000..3677c3f04f8 --- /dev/null +++ b/internal/xds_handshake_cluster.go @@ -0,0 +1,40 @@ +/* + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package internal + +import ( + "google.golang.org/grpc/attributes" + "google.golang.org/grpc/resolver" +) + +// handshakeClusterNameKey is the type used as the key to store cluster name in +// the Attributes field of resolver.Address. +type handshakeClusterNameKey struct{} + +// SetXDSHandshakeClusterName returns a copy of addr in which the Attributes field +// is updated with the cluster name. +func SetXDSHandshakeClusterName(addr resolver.Address, clusterName string) resolver.Address { + addr.Attributes = addr.Attributes.WithValues(handshakeClusterNameKey{}, clusterName) + return addr +} + +// GetXDSHandshakeClusterName returns cluster name stored in attr. 
+func GetXDSHandshakeClusterName(attr *attributes.Attributes) (string, bool) { + v := attr.Value(handshakeClusterNameKey{}) + name, ok := v.(string) + return name, ok +} diff --git a/interop/client/client.go b/interop/client/client.go index 8854ed2d76e..f41f56fbbd5 100644 --- a/interop/client/client.go +++ b/interop/client/client.go @@ -20,9 +20,13 @@ package main import ( + "crypto/tls" + "crypto/x509" "flag" + "io/ioutil" "net" "strconv" + "time" "google.golang.org/grpc" _ "google.golang.org/grpc/balancer/grpclb" @@ -34,6 +38,7 @@ import ( "google.golang.org/grpc/interop" "google.golang.org/grpc/resolver" "google.golang.org/grpc/testdata" + _ "google.golang.org/grpc/xds/googledirectpath" testgrpc "google.golang.org/grpc/interop/grpc_testing" ) @@ -44,19 +49,24 @@ const ( ) var ( - caFile = flag.String("ca_file", "", "The file containning the CA root cert file") - useTLS = flag.Bool("use_tls", false, "Connection uses TLS if true") - useALTS = flag.Bool("use_alts", false, "Connection uses ALTS if true (this option can only be used on GCP)") - customCredentialsType = flag.String("custom_credentials_type", "", "Custom creds to use, excluding TLS or ALTS") - altsHSAddr = flag.String("alts_handshaker_service_address", "", "ALTS handshaker gRPC service address") - testCA = flag.Bool("use_test_ca", false, "Whether to replace platform root CAs with test CA as the CA root") - serviceAccountKeyFile = flag.String("service_account_key_file", "", "Path to service account json key file") - oauthScope = flag.String("oauth_scope", "", "The scope for OAuth2 tokens") - defaultServiceAccount = flag.String("default_service_account", "", "Email of GCE default service account") - serverHost = flag.String("server_host", "localhost", "The server host name") - serverPort = flag.Int("server_port", 10000, "The server port number") - tlsServerName = flag.String("server_host_override", "", "The server name use to verify the hostname returned by TLS handshake if it is not empty. 
Otherwise, --server_host is used.") - testCase = flag.String("test_case", "large_unary", + caFile = flag.String("ca_file", "", "The file containning the CA root cert file") + useTLS = flag.Bool("use_tls", false, "Connection uses TLS if true") + useALTS = flag.Bool("use_alts", false, "Connection uses ALTS if true (this option can only be used on GCP)") + customCredentialsType = flag.String("custom_credentials_type", "", "Custom creds to use, excluding TLS or ALTS") + altsHSAddr = flag.String("alts_handshaker_service_address", "", "ALTS handshaker gRPC service address") + testCA = flag.Bool("use_test_ca", false, "Whether to replace platform root CAs with test CA as the CA root") + serviceAccountKeyFile = flag.String("service_account_key_file", "", "Path to service account json key file") + oauthScope = flag.String("oauth_scope", "", "The scope for OAuth2 tokens") + defaultServiceAccount = flag.String("default_service_account", "", "Email of GCE default service account") + serverHost = flag.String("server_host", "localhost", "The server host name") + serverPort = flag.Int("server_port", 10000, "The server port number") + serviceConfigJSON = flag.String("service_config_json", "", "Disables service config lookups and sets the provided string as the default service config.") + soakIterations = flag.Int("soak_iterations", 10, "The number of iterations to use for the two soak tests: rpc_soak and channel_soak") + soakMaxFailures = flag.Int("soak_max_failures", 0, "The number of iterations in soak tests that are allowed to fail (either due to non-OK status code or exceeding the per-iteration max acceptable latency).") + soakPerIterationMaxAcceptableLatencyMs = flag.Int("soak_per_iteration_max_acceptable_latency_ms", 1000, "The number of milliseconds a single iteration in the two soak tests (rpc_soak and channel_soak) should take.") + soakOverallTimeoutSeconds = flag.Int("soak_overall_timeout_seconds", 10, "The overall number of seconds after which a soak test should stop and 
fail, if the desired number of iterations have not yet completed.") + tlsServerName = flag.String("server_host_override", "", "The server name used to verify the hostname returned by TLS handshake if it is not empty. Otherwise, --server_host is used.") + testCase = flag.String("test_case", "large_unary", `Configure different test cases. Valid options are: empty_unary : empty (zero bytes) request and response; large_unary : single request and (large) response; @@ -126,26 +136,32 @@ func main() { } resolver.SetDefaultScheme("dns") - serverAddr := net.JoinHostPort(*serverHost, strconv.Itoa(*serverPort)) + serverAddr := *serverHost + if *serverPort != 0 { + serverAddr = net.JoinHostPort(*serverHost, strconv.Itoa(*serverPort)) + } var opts []grpc.DialOption switch credsChosen { case credsTLS: - var sn string - if *tlsServerName != "" { - sn = *tlsServerName - } - var creds credentials.TransportCredentials + var roots *x509.CertPool if *testCA { - var err error if *caFile == "" { *caFile = testdata.Path("ca.pem") } - creds, err = credentials.NewClientTLSFromFile(*caFile, sn) + b, err := ioutil.ReadFile(*caFile) if err != nil { - logger.Fatalf("Failed to create TLS credentials %v", err) + logger.Fatalf("Failed to read root certificate file %q: %v", *caFile, err) + } + roots = x509.NewCertPool() + if !roots.AppendCertsFromPEM(b) { + logger.Fatalf("Failed to append certificates: %s", string(b)) } + } + var creds credentials.TransportCredentials + if *tlsServerName != "" { + creds = credentials.NewClientTLSFromCert(roots, *tlsServerName) } else { - creds = credentials.NewClientTLSFromCert(nil, sn) + creds = credentials.NewTLS(&tls.Config{RootCAs: roots}) } opts = append(opts, grpc.WithTransportCredentials(creds)) case credsALTS: @@ -183,7 +199,9 @@ func main() { opts = append(opts, grpc.WithPerRPCCredentials(oauth.NewOauthAccess(interop.GetToken(*serviceAccountKeyFile, *oauthScope)))) } } - opts = append(opts, grpc.WithBlock()) + if len(*serviceConfigJSON) > 0 { + opts = 
append(opts, grpc.WithDisableServiceConfig(), grpc.WithDefaultServiceConfig(*serviceConfigJSON)) + } conn, err := grpc.Dial(serverAddr, opts...) if err != nil { logger.Fatalf("Fail to dial: %v", err) @@ -278,6 +296,12 @@ func main() { case "pick_first_unary": interop.DoPickFirstUnary(tc) logger.Infoln("PickFirstUnary done") + case "rpc_soak": + interop.DoSoakTest(tc, serverAddr, opts, false /* resetChannel */, *soakIterations, *soakMaxFailures, time.Duration(*soakPerIterationMaxAcceptableLatencyMs)*time.Millisecond, time.Now().Add(time.Duration(*soakOverallTimeoutSeconds)*time.Second)) + logger.Infoln("RpcSoak done") + case "channel_soak": + interop.DoSoakTest(tc, serverAddr, opts, true /* resetChannel */, *soakIterations, *soakMaxFailures, time.Duration(*soakPerIterationMaxAcceptableLatencyMs)*time.Millisecond, time.Now().Add(time.Duration(*soakOverallTimeoutSeconds)*time.Second)) + logger.Infoln("ChannelSoak done") default: logger.Fatal("Unsupported test case: ", *testCase) } diff --git a/interop/grpc_testing/benchmark_service_grpc.pb.go b/interop/grpc_testing/benchmark_service_grpc.pb.go index 1dcba4587d2..f4e4436e97e 100644 --- a/interop/grpc_testing/benchmark_service_grpc.pb.go +++ b/interop/grpc_testing/benchmark_service_grpc.pb.go @@ -1,4 +1,8 @@ // Code generated by protoc-gen-go-grpc. DO NOT EDIT. +// versions: +// - protoc-gen-go-grpc v1.1.0 +// - protoc v3.14.0 +// source: grpc/testing/benchmark_service.proto package grpc_testing diff --git a/interop/grpc_testing/report_qps_scenario_service_grpc.pb.go b/interop/grpc_testing/report_qps_scenario_service_grpc.pb.go index b0fe8c8f5ee..4bf3fce68ab 100644 --- a/interop/grpc_testing/report_qps_scenario_service_grpc.pb.go +++ b/interop/grpc_testing/report_qps_scenario_service_grpc.pb.go @@ -1,4 +1,8 @@ // Code generated by protoc-gen-go-grpc. DO NOT EDIT. 
+// versions: +// - protoc-gen-go-grpc v1.1.0 +// - protoc v3.14.0 +// source: grpc/testing/report_qps_scenario_service.proto package grpc_testing diff --git a/interop/grpc_testing/test_grpc.pb.go b/interop/grpc_testing/test_grpc.pb.go index ad5310aed62..137a1e98ce6 100644 --- a/interop/grpc_testing/test_grpc.pb.go +++ b/interop/grpc_testing/test_grpc.pb.go @@ -1,4 +1,8 @@ // Code generated by protoc-gen-go-grpc. DO NOT EDIT. +// versions: +// - protoc-gen-go-grpc v1.1.0 +// - protoc v3.14.0 +// source: grpc/testing/test.proto package grpc_testing diff --git a/interop/grpc_testing/worker_service_grpc.pb.go b/interop/grpc_testing/worker_service_grpc.pb.go index cc49b22b926..a97366df09a 100644 --- a/interop/grpc_testing/worker_service_grpc.pb.go +++ b/interop/grpc_testing/worker_service_grpc.pb.go @@ -1,4 +1,8 @@ // Code generated by protoc-gen-go-grpc. DO NOT EDIT. +// versions: +// - protoc-gen-go-grpc v1.1.0 +// - protoc v3.14.0 +// source: grpc/testing/worker_service.proto package grpc_testing diff --git a/interop/grpclb_fallback/client.go b/interop/grpclb_fallback/client_linux.go similarity index 99% rename from interop/grpclb_fallback/client.go rename to interop/grpclb_fallback/client_linux.go index 61b2fae6968..c9b25a894b3 100644 --- a/interop/grpclb_fallback/client.go +++ b/interop/grpclb_fallback/client_linux.go @@ -1,5 +1,3 @@ -// +build linux,!appengine - /* * * Copyright 2019 gRPC authors. 
diff --git a/interop/test_utils.go b/interop/test_utils.go index cbcbcc4da17..19a5c1f7cd3 100644 --- a/interop/test_utils.go +++ b/interop/test_utils.go @@ -20,10 +20,12 @@ package interop import ( + "bytes" "context" "fmt" "io" "io/ioutil" + "os" "strings" "time" @@ -31,6 +33,7 @@ import ( "golang.org/x/oauth2" "golang.org/x/oauth2/google" "google.golang.org/grpc" + "google.golang.org/grpc/benchmark/stats" "google.golang.org/grpc/codes" "google.golang.org/grpc/grpclog" "google.golang.org/grpc/metadata" @@ -673,6 +676,91 @@ func DoPickFirstUnary(tc testgrpc.TestServiceClient) { } } +func doOneSoakIteration(ctx context.Context, tc testgrpc.TestServiceClient, resetChannel bool, serverAddr string, dopts []grpc.DialOption) (latency time.Duration, err error) { + start := time.Now() + client := tc + if resetChannel { + var conn *grpc.ClientConn + conn, err = grpc.Dial(serverAddr, dopts...) + if err != nil { + return + } + defer conn.Close() + client = testgrpc.NewTestServiceClient(conn) + } + // per test spec, don't include channel shutdown in latency measurement + defer func() { latency = time.Since(start) }() + // do a large-unary RPC + pl := ClientNewPayload(testpb.PayloadType_COMPRESSABLE, largeReqSize) + req := &testpb.SimpleRequest{ + ResponseType: testpb.PayloadType_COMPRESSABLE, + ResponseSize: int32(largeRespSize), + Payload: pl, + } + var reply *testpb.SimpleResponse + reply, err = client.UnaryCall(ctx, req) + if err != nil { + err = fmt.Errorf("/TestService/UnaryCall RPC failed: %s", err) + return + } + t := reply.GetPayload().GetType() + s := len(reply.GetPayload().GetBody()) + if t != testpb.PayloadType_COMPRESSABLE || s != largeRespSize { + err = fmt.Errorf("got the reply with type %d len %d; want %d, %d", t, s, testpb.PayloadType_COMPRESSABLE, largeRespSize) + return + } + return +} + +// DoSoakTest runs large unary RPCs in a loop for a configurable number of times, with configurable failure thresholds. 
+// If resetChannel is false, then each RPC will be performed on tc. Otherwise, each RPC will be performed on a new +// stub that is created with the provided server address and dial options. +func DoSoakTest(tc testgrpc.TestServiceClient, serverAddr string, dopts []grpc.DialOption, resetChannel bool, soakIterations int, maxFailures int, perIterationMaxAcceptableLatency time.Duration, overallDeadline time.Time) { + start := time.Now() + ctx, cancel := context.WithDeadline(context.Background(), overallDeadline) + defer cancel() + iterationsDone := 0 + totalFailures := 0 + hopts := stats.HistogramOptions{ + NumBuckets: 20, + GrowthFactor: 1, + BaseBucketSize: 1, + MinValue: 0, + } + h := stats.NewHistogram(hopts) + for i := 0; i < soakIterations; i++ { + if time.Now().After(overallDeadline) { + break + } + iterationsDone++ + latency, err := doOneSoakIteration(ctx, tc, resetChannel, serverAddr, dopts) + latencyMs := int64(latency / time.Millisecond) + h.Add(latencyMs) + if err != nil { + totalFailures++ + fmt.Fprintf(os.Stderr, "soak iteration: %d elapsed_ms: %d failed: %s\n", i, latencyMs, err) + continue + } + if latency > perIterationMaxAcceptableLatency { + totalFailures++ + fmt.Fprintf(os.Stderr, "soak iteration: %d elapsed_ms: %d exceeds max acceptable latency: %d\n", i, latencyMs, perIterationMaxAcceptableLatency.Milliseconds()) + continue + } + fmt.Fprintf(os.Stderr, "soak iteration: %d elapsed_ms: %d succeeded\n", i, latencyMs) + } + var b bytes.Buffer + h.Print(&b) + fmt.Fprintln(os.Stderr, "Histogram of per-iteration latencies in milliseconds:") + fmt.Fprintln(os.Stderr, b.String()) + fmt.Fprintf(os.Stderr, "soak test ran: %d / %d iterations. total failures: %d. max failures threshold: %d. 
See breakdown above for which iterations succeeded, failed, and why for more info.\n", iterationsDone, soakIterations, totalFailures, maxFailures) + if iterationsDone < soakIterations { + logger.Fatalf("soak test consumed all %f seconds of time and quit early, only having ran %d out of desired %d iterations.", overallDeadline.Sub(start).Seconds(), iterationsDone, soakIterations) + } + if totalFailures > maxFailures { + logger.Fatalf("soak test total failures: %d exceeds max failures threshold: %d.", totalFailures, maxFailures) + } +} + type testServer struct { testgrpc.UnimplementedTestServiceServer } diff --git a/interop/xds/client/Dockerfile b/interop/xds/client/Dockerfile new file mode 100644 index 00000000000..533bb6adb3e --- /dev/null +++ b/interop/xds/client/Dockerfile @@ -0,0 +1,37 @@ +# Copyright 2021 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Dockerfile for building the xDS interop client. To build the image, run the +# following command from grpc-go directory: +# docker build -t -f interop/xds/client/Dockerfile . + +FROM golang:1.16-alpine as build + +# Make a grpc-go directory and copy the repo into it. +WORKDIR /go/src/grpc-go +COPY . . + +# Build a static binary without cgo so that we can copy just the binary in the +# final image, and can get rid of Go compiler and gRPC-Go dependencies. 
+RUN go build -tags osusergo,netgo interop/xds/client/client.go + +# Second stage of the build which copies over only the client binary and skips +# the Go compiler and gRPC repo from the earlier stage. This significantly +# reduces the docker image size. +FROM alpine +COPY --from=build /go/src/grpc-go/client . +ENV GRPC_GO_LOG_VERBOSITY_LEVEL=2 +ENV GRPC_GO_LOG_SEVERITY_LEVEL="info" +ENV GRPC_GO_LOG_FORMATTER="json" +ENTRYPOINT ["./client"] diff --git a/interop/xds/client/client.go b/interop/xds/client/client.go index 5b755272d3e..190c8a7d786 100644 --- a/interop/xds/client/client.go +++ b/interop/xds/client/client.go @@ -31,9 +31,13 @@ import ( "time" "google.golang.org/grpc" + "google.golang.org/grpc/admin" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/credentials/xds" "google.golang.org/grpc/grpclog" "google.golang.org/grpc/metadata" "google.golang.org/grpc/peer" + "google.golang.org/grpc/reflection" "google.golang.org/grpc/status" _ "google.golang.org/grpc/xds" @@ -173,6 +177,7 @@ var ( rpcTimeout = flag.Duration("rpc_timeout", 20*time.Second, "Per RPC timeout") server = flag.String("server", "localhost:8080", "Address of server to connect to") statsPort = flag.Int("stats_port", 8081, "Port to expose peer distribution stats service") + secureMode = flag.Bool("secure_mode", false, "If true, retrieve security configuration from the management server. 
Else, use insecure credentials.") rpcCfgs atomic.Value @@ -309,12 +314,13 @@ const ( emptyCall string = "EmptyCall" ) -func parseRPCTypes(rpcStr string) (ret []string) { +func parseRPCTypes(rpcStr string) []string { if len(rpcStr) == 0 { return []string{unaryCall} } rpcs := strings.Split(rpcStr, ",") + ret := make([]string, 0, len(rpcStr)) for _, r := range rpcs { switch r { case unaryCall, emptyCall: @@ -324,7 +330,7 @@ func parseRPCTypes(rpcStr string) (ret []string) { log.Fatalf("unsupported RPC type: %v", r) } } - return + return ret } type rpcConfig struct { @@ -370,11 +376,26 @@ func main() { defer s.Stop() testgrpc.RegisterLoadBalancerStatsServiceServer(s, &statsService{}) testgrpc.RegisterXdsUpdateClientConfigureServiceServer(s, &configureService{}) + reflection.Register(s) + cleanup, err := admin.Register(s) + if err != nil { + logger.Fatalf("Failed to register admin: %v", err) + } + defer cleanup() go s.Serve(lis) + creds := insecure.NewCredentials() + if *secureMode { + var err error + creds, err = xds.NewClientCredentials(xds.ClientOptions{FallbackCreds: insecure.NewCredentials()}) + if err != nil { + logger.Fatalf("Failed to create xDS credentials: %v", err) + } + } + clients := make([]testgrpc.TestServiceClient, *numChannels) for i := 0; i < *numChannels; i++ { - conn, err := grpc.Dial(*server, grpc.WithInsecure()) + conn, err := grpc.Dial(*server, grpc.WithTransportCredentials(creds)) if err != nil { logger.Fatalf("Fail to dial: %v", err) } diff --git a/interop/xds/server/Dockerfile b/interop/xds/server/Dockerfile new file mode 100644 index 00000000000..cd8dcb5ccaa --- /dev/null +++ b/interop/xds/server/Dockerfile @@ -0,0 +1,36 @@ +# Copyright 2021 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Dockerfile for building the xDS interop server. To build the image, run the +# following command from grpc-go directory: +# docker build -t -f interop/xds/server/Dockerfile . + +FROM golang:1.16-alpine as build + +# Make a grpc-go directory and copy the repo into it. +WORKDIR /go/src/grpc-go +COPY . . + +# Build a static binary without cgo so that we can copy just the binary in the +# final image, and can get rid of the Go compiler and gRPC-Go dependencies. +RUN go build -tags osusergo,netgo interop/xds/server/server.go + +# Second stage of the build which copies over only the client binary and skips +# the Go compiler and gRPC repo from the earlier stage. This significantly +# reduces the docker image size. +FROM alpine +COPY --from=build /go/src/grpc-go/server . +ENV GRPC_GO_LOG_VERBOSITY_LEVEL=2 +ENV GRPC_GO_LOG_SEVERITY_LEVEL="info" +ENTRYPOINT ["./server"] diff --git a/interop/xds/server/server.go b/interop/xds/server/server.go index 4989eb728ee..afbbc56af89 100644 --- a/interop/xds/server/server.go +++ b/interop/xds/server/server.go @@ -1,6 +1,6 @@ /* * - * Copyright 2020 gRPC authors. + * Copyright 2021 gRPC authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,34 +16,46 @@ * */ -// Binary server for xDS interop tests. +// Binary server is the server used for xDS interop tests. 
package main import ( "context" "flag" + "fmt" "log" "net" "os" - "strconv" "google.golang.org/grpc" + "google.golang.org/grpc/admin" + "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/grpclog" + "google.golang.org/grpc/health" "google.golang.org/grpc/metadata" + "google.golang.org/grpc/reflection" + "google.golang.org/grpc/xds" + xdscreds "google.golang.org/grpc/credentials/xds" + healthpb "google.golang.org/grpc/health/grpc_health_v1" testgrpc "google.golang.org/grpc/interop/grpc_testing" testpb "google.golang.org/grpc/interop/grpc_testing" ) var ( - port = flag.Int("port", 8080, "The server port") - serverID = flag.String("server_id", "go_server", "Server ID included in response") - hostname = getHostname() + port = flag.Int("port", 8080, "Listening port for test service") + maintenancePort = flag.Int("maintenance_port", 8081, "Listening port for maintenance services like health, reflection, channelz etc when -secure_mode is true. When -secure_mode is false, all these services will be registered on -port") + serverID = flag.String("server_id", "go_server", "Server ID included in response") + secureMode = flag.Bool("secure_mode", false, "If true, retrieve security configuration from the management server. Else, use insecure credentials.") + hostNameOverride = flag.String("host_name_override", "", "If set, use this as the hostname instead of the real hostname") logger = grpclog.Component("interop") ) func getHostname() string { + if *hostNameOverride != "" { + return *hostNameOverride + } hostname, err := os.Hostname() if err != nil { log.Fatalf("failed to get hostname: %v", err) @@ -51,28 +63,127 @@ func getHostname() string { return hostname } -type server struct { +// testServiceImpl provides an implementation of the TestService defined in +// grpc.testing package. 
+type testServiceImpl struct { testgrpc.UnimplementedTestServiceServer + hostname string + serverID string +} + +func (s *testServiceImpl) EmptyCall(ctx context.Context, _ *testpb.Empty) (*testpb.Empty, error) { + grpc.SetHeader(ctx, metadata.Pairs("hostname", s.hostname)) + return &testpb.Empty{}, nil +} + +func (s *testServiceImpl) UnaryCall(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) { + grpc.SetHeader(ctx, metadata.Pairs("hostname", s.hostname)) + return &testpb.SimpleResponse{ServerId: s.serverID, Hostname: s.hostname}, nil +} + +// xdsUpdateHealthServiceImpl provides an implementation of the +// XdsUpdateHealthService defined in grpc.testing package. +type xdsUpdateHealthServiceImpl struct { + testgrpc.UnimplementedXdsUpdateHealthServiceServer + healthServer *health.Server } -func (s *server) EmptyCall(ctx context.Context, _ *testpb.Empty) (*testpb.Empty, error) { - grpc.SetHeader(ctx, metadata.Pairs("hostname", hostname)) +func (x *xdsUpdateHealthServiceImpl) SetServing(_ context.Context, _ *testpb.Empty) (*testpb.Empty, error) { + x.healthServer.SetServingStatus("", healthpb.HealthCheckResponse_SERVING) + return &testpb.Empty{}, nil + +} + +func (x *xdsUpdateHealthServiceImpl) SetNotServing(_ context.Context, _ *testpb.Empty) (*testpb.Empty, error) { + x.healthServer.SetServingStatus("", healthpb.HealthCheckResponse_NOT_SERVING) return &testpb.Empty{}, nil } -func (s *server) UnaryCall(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) { - grpc.SetHeader(ctx, metadata.Pairs("hostname", hostname)) - return &testpb.SimpleResponse{ServerId: *serverID, Hostname: hostname}, nil +func xdsServingModeCallback(addr net.Addr, args xds.ServingModeChangeArgs) { + logger.Infof("Serving mode for xDS server at %s changed to %s", addr.String(), args.Mode) + if args.Err != nil { + logger.Infof("ServingModeCallback returned error: %v", args.Err) + } } func main() { flag.Parse() - p := strconv.Itoa(*port) - lis, 
err := net.Listen("tcp", ":"+p) + + if *secureMode && *port == *maintenancePort { + logger.Fatal("-port and -maintenance_port must be different when -secure_mode is set") + } + + testService := &testServiceImpl{hostname: getHostname(), serverID: *serverID} + healthServer := health.NewServer() + updateHealthService := &xdsUpdateHealthServiceImpl{healthServer: healthServer} + + // If -secure_mode is not set, expose all services on -port with a regular + // gRPC server. + if !*secureMode { + lis, err := net.Listen("tcp4", fmt.Sprintf(":%d", *port)) + if err != nil { + logger.Fatalf("net.Listen(%s) failed: %v", fmt.Sprintf(":%d", *port), err) + } + + server := grpc.NewServer() + testgrpc.RegisterTestServiceServer(server, testService) + healthServer.SetServingStatus("", healthpb.HealthCheckResponse_SERVING) + healthpb.RegisterHealthServer(server, healthServer) + testgrpc.RegisterXdsUpdateHealthServiceServer(server, updateHealthService) + reflection.Register(server) + cleanup, err := admin.Register(server) + if err != nil { + logger.Fatalf("Failed to register admin services: %v", err) + } + defer cleanup() + if err := server.Serve(lis); err != nil { + logger.Errorf("Serve() failed: %v", err) + } + return + } + + // Create a listener on -port to expose the test service. + testLis, err := net.Listen("tcp4", fmt.Sprintf(":%d", *port)) if err != nil { - logger.Fatalf("failed to listen: %v", err) + logger.Fatalf("net.Listen(%s) failed: %v", fmt.Sprintf(":%d", *port), err) + } + + // Create server-side xDS credentials with a plaintext fallback. + creds, err := xdscreds.NewServerCredentials(xdscreds.ServerOptions{FallbackCreds: insecure.NewCredentials()}) + if err != nil { + logger.Fatalf("Failed to create xDS credentials: %v", err) + } + + // Create an xDS enabled gRPC server, register the test service + // implementation and start serving. 
+ testServer := xds.NewGRPCServer(grpc.Creds(creds), xds.ServingModeCallback(xdsServingModeCallback)) + testgrpc.RegisterTestServiceServer(testServer, testService) + go func() { + if err := testServer.Serve(testLis); err != nil { + logger.Errorf("test server Serve() failed: %v", err) + } + }() + defer testServer.Stop() + + // Create a listener on -maintenance_port to expose other services. + maintenanceLis, err := net.Listen("tcp4", fmt.Sprintf(":%d", *maintenancePort)) + if err != nil { + logger.Fatalf("net.Listen(%s) failed: %v", fmt.Sprintf(":%d", *maintenancePort), err) + } + + // Create a regular gRPC server and register the maintenance services on + // it and start serving. + maintenanceServer := grpc.NewServer() + healthServer.SetServingStatus("", healthpb.HealthCheckResponse_SERVING) + healthpb.RegisterHealthServer(maintenanceServer, healthServer) + testgrpc.RegisterXdsUpdateHealthServiceServer(maintenanceServer, updateHealthService) + reflection.Register(maintenanceServer) + cleanup, err := admin.Register(maintenanceServer) + if err != nil { + logger.Fatalf("Failed to register admin services: %v", err) + } + defer cleanup() + if err := maintenanceServer.Serve(maintenanceLis); err != nil { + logger.Errorf("maintenance server Serve() failed: %v", err) } - s := grpc.NewServer() - testgrpc.RegisterTestServiceServer(s, &server{}) - s.Serve(lis) } diff --git a/metadata/metadata.go b/metadata/metadata.go index cf6d1b94781..3604c7819fd 100644 --- a/metadata/metadata.go +++ b/metadata/metadata.go @@ -75,13 +75,9 @@ func Pairs(kv ...string) MD { panic(fmt.Sprintf("metadata: Pairs got the odd number of input pairs for metadata: %d", len(kv))) } md := MD{} - var key string - for i, s := range kv { - if i%2 == 0 { - key = strings.ToLower(s) - continue - } - md[key] = append(md[key], s) + for i := 0; i < len(kv); i += 2 { + key := strings.ToLower(kv[i]) + md[key] = append(md[key], kv[i+1]) } return md } @@ -97,12 +93,16 @@ func (md MD) Copy() MD { } // Get obtains the 
values for a given key. +// +// k is converted to lowercase before searching in md. func (md MD) Get(k string) []string { k = strings.ToLower(k) return md[k] } // Set sets the value of a given key with a slice of values. +// +// k is converted to lowercase before storing in md. func (md MD) Set(k string, vals ...string) { if len(vals) == 0 { return @@ -111,7 +111,10 @@ func (md MD) Set(k string, vals ...string) { md[k] = vals } -// Append adds the values to key k, not overwriting what was already stored at that key. +// Append adds the values to key k, not overwriting what was already stored at +// that key. +// +// k is converted to lowercase before storing in md. func (md MD) Append(k string, vals ...string) { if len(vals) == 0 { return @@ -120,9 +123,17 @@ func (md MD) Append(k string, vals ...string) { md[k] = append(md[k], vals...) } +// Delete removes the values for a given key k which is converted to lowercase +// before removing it from md. +func (md MD) Delete(k string) { + k = strings.ToLower(k) + delete(md, k) +} + // Join joins any number of mds into a single MD. -// The order of values for each key is determined by the order in which -// the mds containing those values are presented to Join. +// +// The order of values for each key is determined by the order in which the mds +// containing those values are presented to Join. func Join(mds ...MD) MD { out := MD{} for _, md := range mds { @@ -149,8 +160,8 @@ func NewOutgoingContext(ctx context.Context, md MD) context.Context { } // AppendToOutgoingContext returns a new context with the provided kv merged -// with any existing metadata in the context. Please refer to the -// documentation of Pairs for a description of kv. +// with any existing metadata in the context. Please refer to the documentation +// of Pairs for a description of kv. 
func AppendToOutgoingContext(ctx context.Context, kv ...string) context.Context { if len(kv)%2 == 1 { panic(fmt.Sprintf("metadata: AppendToOutgoingContext got an odd number of input pairs for metadata: %d", len(kv))) @@ -163,20 +174,34 @@ func AppendToOutgoingContext(ctx context.Context, kv ...string) context.Context return context.WithValue(ctx, mdOutgoingKey{}, rawMD{md: md.md, added: added}) } -// FromIncomingContext returns the incoming metadata in ctx if it exists. The -// returned MD should not be modified. Writing to it may cause races. -// Modification should be made to copies of the returned MD. -func FromIncomingContext(ctx context.Context) (md MD, ok bool) { - md, ok = ctx.Value(mdIncomingKey{}).(MD) - return +// FromIncomingContext returns the incoming metadata in ctx if it exists. +// +// All keys in the returned MD are lowercase. +func FromIncomingContext(ctx context.Context) (MD, bool) { + md, ok := ctx.Value(mdIncomingKey{}).(MD) + if !ok { + return nil, false + } + out := MD{} + for k, v := range md { + // We need to manually convert all keys to lower case, because MD is a + // map, and there's no guarantee that the MD attached to the context is + // created using our helper functions. + key := strings.ToLower(k) + out[key] = v + } + return out, true } -// FromOutgoingContextRaw returns the un-merged, intermediary contents -// of rawMD. Remember to perform strings.ToLower on the keys. The returned -// MD should not be modified. Writing to it may cause races. Modification -// should be made to copies of the returned MD. +// FromOutgoingContextRaw returns the un-merged, intermediary contents of rawMD. // -// This is intended for gRPC-internal use ONLY. +// Remember to perform strings.ToLower on the keys, for both the returned MD (MD +// is a map, there's no guarantee it's created using our helper functions) and +// the extra kv pairs (AppendToOutgoingContext doesn't turn them into +// lowercase). +// +// This is intended for gRPC-internal use ONLY. 
Users should use +// FromOutgoingContext instead. func FromOutgoingContextRaw(ctx context.Context) (MD, [][]string, bool) { raw, ok := ctx.Value(mdOutgoingKey{}).(rawMD) if !ok { @@ -186,21 +211,34 @@ func FromOutgoingContextRaw(ctx context.Context) (MD, [][]string, bool) { return raw.md, raw.added, true } -// FromOutgoingContext returns the outgoing metadata in ctx if it exists. The -// returned MD should not be modified. Writing to it may cause races. -// Modification should be made to copies of the returned MD. +// FromOutgoingContext returns the outgoing metadata in ctx if it exists. +// +// All keys in the returned MD are lowercase. func FromOutgoingContext(ctx context.Context) (MD, bool) { raw, ok := ctx.Value(mdOutgoingKey{}).(rawMD) if !ok { return nil, false } - mds := make([]MD, 0, len(raw.added)+1) - mds = append(mds, raw.md) - for _, vv := range raw.added { - mds = append(mds, Pairs(vv...)) + out := MD{} + for k, v := range raw.md { + // We need to manually convert all keys to lower case, because MD is a + // map, and there's no guarantee that the MD attached to the context is + // created using our helper functions. 
+ key := strings.ToLower(k) + out[key] = v + } + for _, added := range raw.added { + if len(added)%2 == 1 { + panic(fmt.Sprintf("metadata: FromOutgoingContext got an odd number of input pairs for metadata: %d", len(added))) + } + + for i := 0; i < len(added); i += 2 { + key := strings.ToLower(added[i]) + out[key] = append(out[key], added[i+1]) + } } - return Join(mds...), ok + return out, ok } type rawMD struct { diff --git a/metadata/metadata_test.go b/metadata/metadata_test.go index f1fb5f6d324..89be06eaada 100644 --- a/metadata/metadata_test.go +++ b/metadata/metadata_test.go @@ -169,6 +169,35 @@ func (s) TestAppend(t *testing.T) { } } +func (s) TestDelete(t *testing.T) { + for _, test := range []struct { + md MD + deleteKey string + want MD + }{ + { + md: Pairs("My-Optional-Header", "42"), + deleteKey: "My-Optional-Header", + want: Pairs(), + }, + { + md: Pairs("My-Optional-Header", "42"), + deleteKey: "Other-Key", + want: Pairs("my-optional-header", "42"), + }, + { + md: Pairs("My-Optional-Header", "42"), + deleteKey: "my-OptIoNal-HeAder", + want: Pairs(), + }, + } { + test.md.Delete(test.deleteKey) + if !reflect.DeepEqual(test.md, test.want) { + t.Errorf("value of metadata is %v, want %v", test.md, test.want) + } + } +} + func (s) TestAppendToOutgoingContext(t *testing.T) { // Pre-existing metadata tCtx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) diff --git a/picker_wrapper.go b/picker_wrapper.go index a58174b6f43..e8367cb8993 100644 --- a/picker_wrapper.go +++ b/picker_wrapper.go @@ -144,10 +144,10 @@ func (pw *pickerWrapper) pick(ctx context.Context, failfast bool, info balancer. 
acw, ok := pickResult.SubConn.(*acBalancerWrapper) if !ok { - logger.Error("subconn returned from pick is not *acBalancerWrapper") + logger.Errorf("subconn returned from pick is type %T, not *acBalancerWrapper", pickResult.SubConn) continue } - if t, ok := acw.getAddrConn().getReadyTransport(); ok { + if t := acw.getAddrConn().getReadyTransport(); t != nil { if channelz.IsOn() { return t, doneChannelzWrapper(acw, pickResult.Done), nil } diff --git a/pickfirst.go b/pickfirst.go index b858c2a5e63..f194d14a081 100644 --- a/pickfirst.go +++ b/pickfirst.go @@ -107,10 +107,12 @@ func (b *pickfirstBalancer) UpdateSubConnState(sc balancer.SubConn, s balancer.S } switch s.ConnectivityState { - case connectivity.Ready, connectivity.Idle: + case connectivity.Ready: b.cc.UpdateState(balancer.State{ConnectivityState: s.ConnectivityState, Picker: &picker{result: balancer.PickResult{SubConn: sc}}}) case connectivity.Connecting: b.cc.UpdateState(balancer.State{ConnectivityState: s.ConnectivityState, Picker: &picker{err: balancer.ErrNoSubConnAvailable}}) + case connectivity.Idle: + b.cc.UpdateState(balancer.State{ConnectivityState: s.ConnectivityState, Picker: &idlePicker{sc: sc}}) case connectivity.TransientFailure: b.cc.UpdateState(balancer.State{ ConnectivityState: s.ConnectivityState, @@ -122,6 +124,12 @@ func (b *pickfirstBalancer) UpdateSubConnState(sc balancer.SubConn, s balancer.S func (b *pickfirstBalancer) Close() { } +func (b *pickfirstBalancer) ExitIdle() { + if b.state == connectivity.Idle { + b.sc.Connect() + } +} + type picker struct { result balancer.PickResult err error @@ -131,6 +139,17 @@ func (p *picker) Pick(info balancer.PickInfo) (balancer.PickResult, error) { return p.result, p.err } +// idlePicker is used when the SubConn is IDLE and kicks the SubConn into +// CONNECTING when Pick is called. 
+type idlePicker struct { + sc balancer.SubConn +} + +func (i *idlePicker) Pick(info balancer.PickInfo) (balancer.PickResult, error) { + i.sc.Connect() + return balancer.PickResult{}, balancer.ErrNoSubConnAvailable +} + func init() { balancer.Register(newPickfirstBuilder()) } diff --git a/profiling/proto/service_grpc.pb.go b/profiling/proto/service_grpc.pb.go index bfdcc69bffb..a0656149bda 100644 --- a/profiling/proto/service_grpc.pb.go +++ b/profiling/proto/service_grpc.pb.go @@ -1,4 +1,8 @@ // Code generated by protoc-gen-go-grpc. DO NOT EDIT. +// versions: +// - protoc-gen-go-grpc v1.1.0 +// - protoc v3.14.0 +// source: profiling/proto/service.proto package proto diff --git a/reflection/grpc_reflection_v1alpha/reflection_grpc.pb.go b/reflection/grpc_reflection_v1alpha/reflection_grpc.pb.go index c2b7429a06b..7d05c14ebd8 100644 --- a/reflection/grpc_reflection_v1alpha/reflection_grpc.pb.go +++ b/reflection/grpc_reflection_v1alpha/reflection_grpc.pb.go @@ -1,4 +1,8 @@ // Code generated by protoc-gen-go-grpc. DO NOT EDIT. +// versions: +// - protoc-gen-go-grpc v1.1.0 +// - protoc v3.14.0 +// source: reflection/grpc_reflection_v1alpha/reflection.proto package grpc_reflection_v1alpha diff --git a/reflection/grpc_testing/test_grpc.pb.go b/reflection/grpc_testing/test_grpc.pb.go index 76ec8d52d68..235b5d82484 100644 --- a/reflection/grpc_testing/test_grpc.pb.go +++ b/reflection/grpc_testing/test_grpc.pb.go @@ -1,4 +1,8 @@ // Code generated by protoc-gen-go-grpc. DO NOT EDIT. +// versions: +// - protoc-gen-go-grpc v1.1.0 +// - protoc v3.14.0 +// source: reflection/grpc_testing/test.proto package grpc_testing diff --git a/reflection/serverreflection.go b/reflection/serverreflection.go index d2696168b10..82a5ba7f244 100644 --- a/reflection/serverreflection.go +++ b/reflection/serverreflection.go @@ -54,9 +54,19 @@ import ( "google.golang.org/grpc/status" ) +// GRPCServer is the interface provided by a gRPC server. 
It is implemented by +// *grpc.Server, but could also be implemented by other concrete types. It acts +// as a registry, for accumulating the services exposed by the server. +type GRPCServer interface { + grpc.ServiceRegistrar + GetServiceInfo() map[string]grpc.ServiceInfo +} + +var _ GRPCServer = (*grpc.Server)(nil) + type serverReflectionServer struct { rpb.UnimplementedServerReflectionServer - s *grpc.Server + s GRPCServer initSymbols sync.Once serviceNames []string @@ -64,7 +74,7 @@ type serverReflectionServer struct { } // Register registers the server reflection service on the given gRPC server. -func Register(s *grpc.Server) { +func Register(s GRPCServer) { rpb.RegisterServerReflectionServer(s, &serverReflectionServer{ s: s, }) diff --git a/regenerate.sh b/regenerate.sh index fc6725b89f8..dfd3226a1d9 100755 --- a/regenerate.sh +++ b/regenerate.sh @@ -48,11 +48,6 @@ mkdir -p ${WORKDIR}/googleapis/google/rpc echo "curl https://raw.githubusercontent.com/googleapis/googleapis/master/google/rpc/code.proto" curl --silent https://raw.githubusercontent.com/googleapis/googleapis/master/google/rpc/code.proto > ${WORKDIR}/googleapis/google/rpc/code.proto -# Pull in the MeshCA service proto. 
-mkdir -p ${WORKDIR}/istio/istio/google/security/meshca/v1 -echo "curl https://raw.githubusercontent.com/istio/istio/master/security/proto/providers/google/meshca.proto" -curl --silent https://raw.githubusercontent.com/istio/istio/master/security/proto/providers/google/meshca.proto > ${WORKDIR}/istio/istio/google/security/meshca/v1/meshca.proto - mkdir -p ${WORKDIR}/out # Generates sources without the embed requirement @@ -76,7 +71,6 @@ SOURCES=( ${WORKDIR}/grpc-proto/grpc/service_config/service_config.proto ${WORKDIR}/grpc-proto/grpc/testing/*.proto ${WORKDIR}/grpc-proto/grpc/core/*.proto - ${WORKDIR}/istio/istio/google/security/meshca/v1/meshca.proto ) # These options of the form 'Mfoo.proto=bar' instruct the codegen to use an @@ -122,8 +116,4 @@ mv ${WORKDIR}/out/grpc/service_config/service_config.pb.go internal/proto/grpc_s mv ${WORKDIR}/out/grpc/testing/*.pb.go interop/grpc_testing/ mv ${WORKDIR}/out/grpc/core/*.pb.go interop/grpc_testing/core/ -# istio/google/security/meshca/v1/meshca.proto does not have a go_package option. -mkdir -p ${WORKDIR}/out/google.golang.org/grpc/credentials/tls/certprovider/meshca/internal/v1/ -mv ${WORKDIR}/out/istio/google/security/meshca/v1/* ${WORKDIR}/out/google.golang.org/grpc/credentials/tls/certprovider/meshca/internal/v1/ - cp -R ${WORKDIR}/out/google.golang.org/grpc/* . diff --git a/resolver/manual/manual.go b/resolver/manual/manual.go index 3679d702ab9..f6e7b5ae358 100644 --- a/resolver/manual/manual.go +++ b/resolver/manual/manual.go @@ -27,7 +27,9 @@ import ( // NewBuilderWithScheme creates a new test resolver builder with the given scheme. func NewBuilderWithScheme(scheme string) *Resolver { return &Resolver{ + BuildCallback: func(resolver.Target, resolver.ClientConn, resolver.BuildOptions) {}, ResolveNowCallback: func(resolver.ResolveNowOptions) {}, + CloseCallback: func() {}, scheme: scheme, } } @@ -35,11 +37,17 @@ func NewBuilderWithScheme(scheme string) *Resolver { // Resolver is also a resolver builder. 
// It's build() function always returns itself. type Resolver struct { + // BuildCallback is called when the Build method is called. Must not be + // nil. Must not be changed after the resolver may be built. + BuildCallback func(resolver.Target, resolver.ClientConn, resolver.BuildOptions) // ResolveNowCallback is called when the ResolveNow method is called on the // resolver. Must not be nil. Must not be changed after the resolver may // be built. ResolveNowCallback func(resolver.ResolveNowOptions) - scheme string + // CloseCallback is called when the Close method is called. Must not be + // nil. Must not be changed after the resolver may be built. + CloseCallback func() + scheme string // Fields actually belong to the resolver. CC resolver.ClientConn @@ -54,6 +62,7 @@ func (r *Resolver) InitialState(s resolver.State) { // Build returns itself for Resolver, because it's both a builder and a resolver. func (r *Resolver) Build(target resolver.Target, cc resolver.ClientConn, opts resolver.BuildOptions) (resolver.Resolver, error) { + r.BuildCallback(target, cc, opts) r.CC = cc if r.bootstrapState != nil { r.UpdateState(*r.bootstrapState) @@ -72,9 +81,16 @@ func (r *Resolver) ResolveNow(o resolver.ResolveNowOptions) { } // Close is a noop for Resolver. -func (*Resolver) Close() {} +func (r *Resolver) Close() { + r.CloseCallback() +} // UpdateState calls CC.UpdateState. func (r *Resolver) UpdateState(s resolver.State) { r.CC.UpdateState(s) } + +// ReportError calls CC.ReportError. +func (r *Resolver) ReportError(err error) { + r.CC.ReportError(err) +} diff --git a/resolver/resolver.go b/resolver/resolver.go index e9fa8e33d92..6a9d234a597 100644 --- a/resolver/resolver.go +++ b/resolver/resolver.go @@ -181,7 +181,7 @@ type State struct { // gRPC to add new methods to this interface. type ClientConn interface { // UpdateState updates the state of the ClientConn appropriately. 
- UpdateState(State) + UpdateState(State) error // ReportError notifies the ClientConn that the Resolver encountered an // error. The ClientConn will notify the load balancer and begin calling // ResolveNow on the Resolver with exponential backoff. diff --git a/resolver_conn_wrapper.go b/resolver_conn_wrapper.go index f2d81968f9e..2c47cd54f07 100644 --- a/resolver_conn_wrapper.go +++ b/resolver_conn_wrapper.go @@ -22,7 +22,6 @@ import ( "fmt" "strings" "sync" - "time" "google.golang.org/grpc/balancer" "google.golang.org/grpc/credentials" @@ -41,8 +40,7 @@ type ccResolverWrapper struct { done *grpcsync.Event curState resolver.State - pollingMu sync.Mutex - polling chan struct{} + incomingMu sync.Mutex // Synchronizes all the incoming calls. } // newCCResolverWrapper uses the resolver.Builder to build a Resolver and @@ -93,71 +91,37 @@ func (ccr *ccResolverWrapper) close() { ccr.resolverMu.Unlock() } -// poll begins or ends asynchronous polling of the resolver based on whether -// err is ErrBadResolverState. -func (ccr *ccResolverWrapper) poll(err error) { - ccr.pollingMu.Lock() - defer ccr.pollingMu.Unlock() - if err != balancer.ErrBadResolverState { - // stop polling - if ccr.polling != nil { - close(ccr.polling) - ccr.polling = nil - } - return - } - if ccr.polling != nil { - // already polling - return - } - p := make(chan struct{}) - ccr.polling = p - go func() { - for i := 0; ; i++ { - ccr.resolveNow(resolver.ResolveNowOptions{}) - t := time.NewTimer(ccr.cc.dopts.resolveNowBackoff(i)) - select { - case <-p: - t.Stop() - return - case <-ccr.done.Done(): - // Resolver has been closed. - t.Stop() - return - case <-t.C: - select { - case <-p: - return - default: - } - // Timer expired; re-resolve. 
- } - } - }() -} - -func (ccr *ccResolverWrapper) UpdateState(s resolver.State) { +func (ccr *ccResolverWrapper) UpdateState(s resolver.State) error { + ccr.incomingMu.Lock() + defer ccr.incomingMu.Unlock() if ccr.done.HasFired() { - return + return nil } channelz.Infof(logger, ccr.cc.channelzID, "ccResolverWrapper: sending update to cc: %v", s) if channelz.IsOn() { ccr.addChannelzTraceEvent(s) } ccr.curState = s - ccr.poll(ccr.cc.updateResolverState(ccr.curState, nil)) + if err := ccr.cc.updateResolverState(ccr.curState, nil); err == balancer.ErrBadResolverState { + return balancer.ErrBadResolverState + } + return nil } func (ccr *ccResolverWrapper) ReportError(err error) { + ccr.incomingMu.Lock() + defer ccr.incomingMu.Unlock() if ccr.done.HasFired() { return } channelz.Warningf(logger, ccr.cc.channelzID, "ccResolverWrapper: reporting error to cc: %v", err) - ccr.poll(ccr.cc.updateResolverState(resolver.State{}, err)) + ccr.cc.updateResolverState(resolver.State{}, err) } // NewAddress is called by the resolver implementation to send addresses to gRPC. func (ccr *ccResolverWrapper) NewAddress(addrs []resolver.Address) { + ccr.incomingMu.Lock() + defer ccr.incomingMu.Unlock() if ccr.done.HasFired() { return } @@ -166,12 +130,14 @@ func (ccr *ccResolverWrapper) NewAddress(addrs []resolver.Address) { ccr.addChannelzTraceEvent(resolver.State{Addresses: addrs, ServiceConfig: ccr.curState.ServiceConfig}) } ccr.curState.Addresses = addrs - ccr.poll(ccr.cc.updateResolverState(ccr.curState, nil)) + ccr.cc.updateResolverState(ccr.curState, nil) } // NewServiceConfig is called by the resolver implementation to send service // configs to gRPC. 
func (ccr *ccResolverWrapper) NewServiceConfig(sc string) { + ccr.incomingMu.Lock() + defer ccr.incomingMu.Unlock() if ccr.done.HasFired() { return } @@ -183,14 +149,13 @@ func (ccr *ccResolverWrapper) NewServiceConfig(sc string) { scpr := parseServiceConfig(sc) if scpr.Err != nil { channelz.Warningf(logger, ccr.cc.channelzID, "ccResolverWrapper: error parsing service config: %v", scpr.Err) - ccr.poll(balancer.ErrBadResolverState) return } if channelz.IsOn() { ccr.addChannelzTraceEvent(resolver.State{Addresses: ccr.curState.Addresses, ServiceConfig: scpr}) } ccr.curState.ServiceConfig = scpr - ccr.poll(ccr.cc.updateResolverState(ccr.curState, nil)) + ccr.cc.updateResolverState(ccr.curState, nil) } func (ccr *ccResolverWrapper) ParseServiceConfig(scJSON string) *serviceconfig.ParseResult { diff --git a/resolver_conn_wrapper_test.go b/resolver_conn_wrapper_test.go index f13a408937b..81c5b9ea874 100644 --- a/resolver_conn_wrapper_test.go +++ b/resolver_conn_wrapper_test.go @@ -67,62 +67,6 @@ func (s) TestDialParseTargetUnknownScheme(t *testing.T) { } } -func testResolverErrorPolling(t *testing.T, badUpdate func(*manual.Resolver), goodUpdate func(*manual.Resolver), dopts ...DialOption) { - boIter := make(chan int) - resolverBackoff := func(v int) time.Duration { - boIter <- v - return 0 - } - - r := manual.NewBuilderWithScheme("whatever") - rn := make(chan struct{}) - defer func() { close(rn) }() - r.ResolveNowCallback = func(resolver.ResolveNowOptions) { rn <- struct{}{} } - - defaultDialOptions := []DialOption{ - WithInsecure(), - WithResolvers(r), - withResolveNowBackoff(resolverBackoff), - } - cc, err := Dial(r.Scheme()+":///test.server", append(defaultDialOptions, dopts...)...) 
- if err != nil { - t.Fatalf("Dial(_, _) = _, %v; want _, nil", err) - } - defer cc.Close() - badUpdate(r) - - panicAfter := time.AfterFunc(5*time.Second, func() { panic("timed out polling resolver") }) - defer panicAfter.Stop() - - // Ensure ResolveNow is called, then Backoff with the right parameter, several times - for i := 0; i < 7; i++ { - <-rn - if v := <-boIter; v != i { - t.Errorf("Backoff call %v uses value %v", i, v) - } - } - - // UpdateState will block if ResolveNow is being called (which blocks on - // rn), so call it in a goroutine. - goodUpdate(r) - - // Wait awhile to ensure ResolveNow and Backoff stop being called when the - // state is OK (i.e. polling was cancelled). - for { - t := time.NewTimer(50 * time.Millisecond) - select { - case <-rn: - // ClientConn is still calling ResolveNow - <-boIter - time.Sleep(5 * time.Millisecond) - continue - case <-t.C: - // ClientConn stopped calling ResolveNow; success - } - break - } -} - const happyBalancerName = "happy balancer" func init() { @@ -136,35 +80,6 @@ func init() { stub.Register(happyBalancerName, bf) } -// TestResolverErrorPolling injects resolver errors and verifies ResolveNow is -// called with the appropriate backoff strategy being consulted between -// ResolveNow calls. -func (s) TestResolverErrorPolling(t *testing.T) { - testResolverErrorPolling(t, func(r *manual.Resolver) { - r.CC.ReportError(errors.New("res err")) - }, func(r *manual.Resolver) { - // UpdateState will block if ResolveNow is being called (which blocks on - // rn), so call it in a goroutine. - go r.CC.UpdateState(resolver.State{}) - }, - WithDefaultServiceConfig(fmt.Sprintf(`{ "loadBalancingConfig": [{"%v": {}}] }`, happyBalancerName))) -} - -// TestServiceConfigErrorPolling injects a service config error and verifies -// ResolveNow is called with the appropriate backoff strategy being consulted -// between ResolveNow calls. 
-func (s) TestServiceConfigErrorPolling(t *testing.T) { - testResolverErrorPolling(t, func(r *manual.Resolver) { - badsc := r.CC.ParseServiceConfig("bad config") - r.UpdateState(resolver.State{ServiceConfig: badsc}) - }, func(r *manual.Resolver) { - // UpdateState will block if ResolveNow is being called (which blocks on - // rn), so call it in a goroutine. - go r.CC.UpdateState(resolver.State{}) - }, - WithDefaultServiceConfig(fmt.Sprintf(`{ "loadBalancingConfig": [{"%v": {}}] }`, happyBalancerName))) -} - // TestResolverErrorInBuild makes the resolver.Builder call into the ClientConn // during the Build call. We use two separate mutexes in the code which make // sure there is no data race in this code path, and also that there is no diff --git a/rpc_util.go b/rpc_util.go index f781aaf751e..05904c816de 100644 --- a/rpc_util.go +++ b/rpc_util.go @@ -258,7 +258,8 @@ func (o PeerCallOption) after(c *callInfo, attempt *csAttempt) { } // WaitForReady configures the action to take when an RPC is attempted on broken -// connections or unreachable servers. If waitForReady is false, the RPC will fail +// connections or unreachable servers. If waitForReady is false and the +// connection is in the TRANSIENT_FAILURE state, the RPC will fail // immediately. Otherwise, the RPC client will block the call until a // connection is available (or the call is canceled or times out) and will // retry the call if it fails due to a transient error. gRPC will not retry if @@ -429,9 +430,10 @@ func (o ContentSubtypeCallOption) before(c *callInfo) error { } func (o ContentSubtypeCallOption) after(c *callInfo, attempt *csAttempt) {} -// ForceCodec returns a CallOption that will set the given Codec to be -// used for all request and response messages for a call. The result of calling -// String() will be used as the content-subtype in a case-insensitive manner. +// ForceCodec returns a CallOption that will set codec to be used for all +// request and response messages for a call. 
The result of calling Name() will +// be used as the content-subtype after converting to lowercase, unless +// CallContentSubtype is also used. // // See Content-Type on // https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#requests for @@ -834,33 +836,45 @@ func Errorf(c codes.Code, format string, a ...interface{}) error { // toRPCErr converts an error into an error from the status package. func toRPCErr(err error) error { - if err == nil || err == io.EOF { + switch err { + case nil, io.EOF: return err - } - if err == io.ErrUnexpectedEOF { + case context.DeadlineExceeded: + return status.Error(codes.DeadlineExceeded, err.Error()) + case context.Canceled: + return status.Error(codes.Canceled, err.Error()) + case io.ErrUnexpectedEOF: return status.Error(codes.Internal, err.Error()) } - if _, ok := status.FromError(err); ok { - return err - } + switch e := err.(type) { case transport.ConnectionError: return status.Error(codes.Unavailable, e.Desc) - default: - switch err { - case context.DeadlineExceeded: - return status.Error(codes.DeadlineExceeded, err.Error()) - case context.Canceled: - return status.Error(codes.Canceled, err.Error()) - } + case *transport.NewStreamError: + return toRPCErr(e.Err) } + + if _, ok := status.FromError(err); ok { + return err + } + return status.Error(codes.Unknown, err.Error()) } // setCallInfoCodec should only be called after CallOptions have been applied. func setCallInfoCodec(c *callInfo) error { if c.codec != nil { - // codec was already set by a CallOption; use it. + // codec was already set by a CallOption; use it, but set the content + // subtype if it is not set. + if c.contentSubtype == "" { + // c.codec is a baseCodec to hide the difference between grpc.Codec and + // encoding.Codec (Name vs. String method name). We only support + // setting content subtype from encoding.Codec to avoid a behavior + // change with the deprecated version. 
+ if ec, ok := c.codec.(encoding.Codec); ok { + c.contentSubtype = strings.ToLower(ec.Name()) + } + } return nil } diff --git a/security/advancedtls/advancedtls.go b/security/advancedtls/advancedtls.go index 534a3ed417b..1892c5ed766 100644 --- a/security/advancedtls/advancedtls.go +++ b/security/advancedtls/advancedtls.go @@ -181,6 +181,9 @@ type ClientOptions struct { RootOptions RootCertificateOptions // VType is the verification type on the client side. VType VerificationType + // RevocationConfig is the configurations for certificate revocation checks. + // It could be nil if such checks are not needed. + RevocationConfig *RevocationConfig } // ServerOptions contains the fields needed to be filled by the server. @@ -199,6 +202,9 @@ type ServerOptions struct { RequireClientCert bool // VType is the verification type on the server side. VType VerificationType + // RevocationConfig is the configurations for certificate revocation checks. + // It could be nil if such checks are not needed. + RevocationConfig *RevocationConfig } func (o *ClientOptions) config() (*tls.Config, error) { @@ -356,11 +362,12 @@ func (o *ServerOptions) config() (*tls.Config, error) { // advancedTLSCreds is the credentials required for authenticating a connection // using TLS. 
type advancedTLSCreds struct { - config *tls.Config - verifyFunc CustomVerificationFunc - getRootCAs func(params *GetRootCAsParams) (*GetRootCAsResults, error) - isClient bool - vType VerificationType + config *tls.Config + verifyFunc CustomVerificationFunc + getRootCAs func(params *GetRootCAsParams) (*GetRootCAsResults, error) + isClient bool + vType VerificationType + revocationConfig *RevocationConfig } func (c advancedTLSCreds) Info() credentials.ProtocolInfo { @@ -451,6 +458,14 @@ func buildVerifyFunc(c *advancedTLSCreds, return func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { chains := verifiedChains var leafCert *x509.Certificate + rawCertList := make([]*x509.Certificate, len(rawCerts)) + for i, asn1Data := range rawCerts { + cert, err := x509.ParseCertificate(asn1Data) + if err != nil { + return err + } + rawCertList[i] = cert + } if c.vType == CertAndHostVerification || c.vType == CertVerification { // perform possible trust credential reloading and certificate check rootCAs := c.config.RootCAs @@ -469,14 +484,6 @@ func buildVerifyFunc(c *advancedTLSCreds, rootCAs = results.TrustCerts } // Verify peers' certificates against RootCAs and get verifiedChains. - certs := make([]*x509.Certificate, len(rawCerts)) - for i, asn1Data := range rawCerts { - cert, err := x509.ParseCertificate(asn1Data) - if err != nil { - return err - } - certs[i] = cert - } keyUsages := []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth} if !c.isClient { keyUsages = []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth} @@ -487,7 +494,7 @@ func buildVerifyFunc(c *advancedTLSCreds, Intermediates: x509.NewCertPool(), KeyUsages: keyUsages, } - for _, cert := range certs[1:] { + for _, cert := range rawCertList[1:] { opts.Intermediates.AddCert(cert) } // Perform default hostname check if specified. 
@@ -501,11 +508,21 @@ func buildVerifyFunc(c *advancedTLSCreds, opts.DNSName = parsedName } var err error - chains, err = certs[0].Verify(opts) + chains, err = rawCertList[0].Verify(opts) if err != nil { return err } - leafCert = certs[0] + leafCert = rawCertList[0] + } + // Perform certificate revocation check if specified. + if c.revocationConfig != nil { + verifiedChains := chains + if verifiedChains == nil { + verifiedChains = [][]*x509.Certificate{rawCertList} + } + if err := CheckChainRevocation(verifiedChains, *c.revocationConfig); err != nil { + return err + } } // Perform custom verification check if specified. if c.verifyFunc != nil { @@ -529,11 +546,12 @@ func NewClientCreds(o *ClientOptions) (credentials.TransportCredentials, error) return nil, err } tc := &advancedTLSCreds{ - config: conf, - isClient: true, - getRootCAs: o.RootOptions.GetRootCertificates, - verifyFunc: o.VerifyPeer, - vType: o.VType, + config: conf, + isClient: true, + getRootCAs: o.RootOptions.GetRootCertificates, + verifyFunc: o.VerifyPeer, + vType: o.VType, + revocationConfig: o.RevocationConfig, } tc.config.NextProtos = credinternal.AppendH2ToNextProtos(tc.config.NextProtos) return tc, nil @@ -547,11 +565,12 @@ func NewServerCreds(o *ServerOptions) (credentials.TransportCredentials, error) return nil, err } tc := &advancedTLSCreds{ - config: conf, - isClient: false, - getRootCAs: o.RootOptions.GetRootCertificates, - verifyFunc: o.VerifyPeer, - vType: o.VType, + config: conf, + isClient: false, + getRootCAs: o.RootOptions.GetRootCertificates, + verifyFunc: o.VerifyPeer, + vType: o.VType, + revocationConfig: o.RevocationConfig, } tc.config.NextProtos = credinternal.AppendH2ToNextProtos(tc.config.NextProtos) return tc, nil diff --git a/security/advancedtls/advancedtls_integration_test.go b/security/advancedtls/advancedtls_integration_test.go index 4bb9e645b0a..8cddfc234b1 100644 --- a/security/advancedtls/advancedtls_integration_test.go +++ 
b/security/advancedtls/advancedtls_integration_test.go @@ -380,7 +380,7 @@ func (s) TestEnd2End(t *testing.T) { } clientTLSCreds, err := NewClientCreds(clientOptions) if err != nil { - t.Fatalf("clientTLSCreds failed to create") + t.Fatalf("clientTLSCreds failed to create: %v", err) } // ------------------------Scenario 1------------------------------------ // stage = 0, initial connection should succeed @@ -796,7 +796,7 @@ func (s) TestDefaultHostNameCheck(t *testing.T) { } clientTLSCreds, err := NewClientCreds(clientOptions) if err != nil { - t.Fatalf("clientTLSCreds failed to create") + t.Fatalf("clientTLSCreds failed to create: %v", err) } shouldFail := false if test.expectError { diff --git a/security/advancedtls/advancedtls_test.go b/security/advancedtls/advancedtls_test.go index 64da81a1700..7092d46e60f 100644 --- a/security/advancedtls/advancedtls_test.go +++ b/security/advancedtls/advancedtls_test.go @@ -27,10 +27,12 @@ import ( "net" "testing" + lru "github.com/hashicorp/golang-lru" "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/tls/certprovider" "google.golang.org/grpc/internal/grpctest" "google.golang.org/grpc/security/advancedtls/internal/testutils" + "google.golang.org/grpc/security/advancedtls/testdata" ) type s struct { @@ -339,6 +341,10 @@ func (s) TestClientServerHandshake(t *testing.T) { getRootCAsForServerBad := func(params *GetRootCAsParams) (*GetRootCAsResults, error) { return nil, fmt.Errorf("bad root certificate reloading") } + cache, err := lru.New(5) + if err != nil { + t.Fatalf("lru.New: err = %v", err) + } for _, test := range []struct { desc string clientCert []tls.Certificate @@ -349,6 +355,7 @@ func (s) TestClientServerHandshake(t *testing.T) { clientVType VerificationType clientRootProvider certprovider.Provider clientIdentityProvider certprovider.Provider + clientRevocationConfig *RevocationConfig clientExpectHandshakeError bool serverMutualTLS bool serverCert []tls.Certificate @@ -359,6 +366,7 @@ func (s) 
TestClientServerHandshake(t *testing.T) { serverVType VerificationType serverRootProvider certprovider.Provider serverIdentityProvider certprovider.Provider + serverRevocationConfig *RevocationConfig serverExpectError bool }{ // Client: nil setting except verifyFuncGood @@ -642,6 +650,30 @@ func (s) TestClientServerHandshake(t *testing.T) { serverRootProvider: fakeProvider{isClient: false}, serverVType: CertVerification, }, + // Client: set valid credentials with the revocation config + // Server: set valid credentials with the revocation config + // Expected Behavior: success, because none of the certificate chains sent in the connection are revoked + { + desc: "Client sets peer cert, reload root function with verifyFuncGood; Server sets peer cert, reload root function; mutualTLS", + clientCert: []tls.Certificate{cs.ClientCert1}, + clientGetRoot: getRootCAsForClient, + clientVerifyFunc: clientVerifyFuncGood, + clientVType: CertVerification, + clientRevocationConfig: &RevocationConfig{ + RootDir: testdata.Path("crl"), + AllowUndetermined: true, + Cache: cache, + }, + serverMutualTLS: true, + serverCert: []tls.Certificate{cs.ServerCert1}, + serverGetRoot: getRootCAsForServer, + serverVType: CertVerification, + serverRevocationConfig: &RevocationConfig{ + RootDir: testdata.Path("crl"), + AllowUndetermined: true, + Cache: cache, + }, + }, } { test := test t.Run(test.desc, func(t *testing.T) { @@ -665,6 +697,7 @@ func (s) TestClientServerHandshake(t *testing.T) { RequireClientCert: test.serverMutualTLS, VerifyPeer: test.serverVerifyFunc, VType: test.serverVType, + RevocationConfig: test.serverRevocationConfig, } go func(done chan credentials.AuthInfo, lis net.Listener, serverOptions *ServerOptions) { serverRawConn, err := lis.Accept() @@ -706,7 +739,8 @@ func (s) TestClientServerHandshake(t *testing.T) { GetRootCertificates: test.clientGetRoot, RootProvider: test.clientRootProvider, }, - VType: test.clientVType, + VType: test.clientVType, + RevocationConfig: 
test.clientRevocationConfig, } clientTLS, err := NewClientCreds(clientOptions) if err != nil { diff --git a/security/advancedtls/crl.go b/security/advancedtls/crl.go new file mode 100644 index 00000000000..3931c1ec629 --- /dev/null +++ b/security/advancedtls/crl.go @@ -0,0 +1,499 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package advancedtls + +import ( + "bytes" + "crypto/sha1" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/asn1" + "encoding/binary" + "encoding/hex" + "errors" + "fmt" + "io/ioutil" + "path/filepath" + "strings" + "time" + + "google.golang.org/grpc/grpclog" +) + +var grpclogLogger = grpclog.Component("advancedtls") + +// Cache is an interface to cache CRL files. +// The cache implementation must be concurrency safe. +// A fixed size lru cache from golang-lru is recommended. +type Cache interface { + // Add adds a value to the cache. + Add(key, value interface{}) bool + // Get looks up a key's value from the cache. + Get(key interface{}) (value interface{}, ok bool) +} + +// RevocationConfig contains options for CRL lookup. +type RevocationConfig struct { + // RootDir is the directory to search for CRL files. + // Directory format must match OpenSSL X509_LOOKUP_hash_dir(3). + RootDir string + // AllowUndetermined controls if certificate chains with RevocationUndetermined + // revocation status are allowed to complete. 
+ AllowUndetermined bool + // Cache will store CRL files if not nil, otherwise files are reloaded for every lookup. + Cache Cache +} + +// RevocationStatus is the revocation status for a certificate or chain. +type RevocationStatus int + +const ( + // RevocationUndetermined means we couldn't find or verify a CRL for the cert. + RevocationUndetermined RevocationStatus = iota + // RevocationUnrevoked means we found the CRL for the cert and the cert is not revoked. + RevocationUnrevoked + // RevocationRevoked means we found the CRL and the cert is revoked. + RevocationRevoked +) + +func (s RevocationStatus) String() string { + return [...]string{"RevocationUndetermined", "RevocationUnrevoked", "RevocationRevoked"}[s] +} + +// certificateListExt contains a pkix.CertificateList and parsed +// extensions that aren't provided by the golang CRL parser. +type certificateListExt struct { + CertList *pkix.CertificateList + // RFC5280, 5.2.1, all conforming CRLs must have a AKID with the ID method. + AuthorityKeyID []byte +} + +const tagDirectoryName = 4 + +var ( + // RFC5280, 5.2.4 id-ce-deltaCRLIndicator OBJECT IDENTIFIER ::= { id-ce 27 } + oidDeltaCRLIndicator = asn1.ObjectIdentifier{2, 5, 29, 27} + // RFC5280, 5.2.5 id-ce-issuingDistributionPoint OBJECT IDENTIFIER ::= { id-ce 28 } + oidIssuingDistributionPoint = asn1.ObjectIdentifier{2, 5, 29, 28} + // RFC5280, 5.3.3 id-ce-certificateIssuer OBJECT IDENTIFIER ::= { id-ce 29 } + oidCertificateIssuer = asn1.ObjectIdentifier{2, 5, 29, 29} + // RFC5290, 4.2.1.1 id-ce-authorityKeyIdentifier OBJECT IDENTIFIER ::= { id-ce 35 } + oidAuthorityKeyIdentifier = asn1.ObjectIdentifier{2, 5, 29, 35} +) + +// x509NameHash implements the OpenSSL X509_NAME_hash function for hashed directory lookups. +func x509NameHash(r pkix.RDNSequence) string { + var canonBytes []byte + // First, canonicalize all the strings. 
+ for _, rdnSet := range r { + for i, rdn := range rdnSet { + value, ok := rdn.Value.(string) + if !ok { + continue + } + // OpenSSL trims all whitespace, does a tolower, and removes extra spaces between words. + // Implemented in x509_name_canon in OpenSSL + canonStr := strings.Join(strings.Fields( + strings.TrimSpace(strings.ToLower(value))), " ") + // Then it changes everything to UTF8 strings + rdnSet[i].Value = asn1.RawValue{Tag: asn1.TagUTF8String, Bytes: []byte(canonStr)} + + } + } + + // Finally, OpenSSL drops the initial sequence tag + // so we marshal all the RDNs separately instead of as a group. + for _, canonRdn := range r { + b, err := asn1.Marshal(canonRdn) + if err != nil { + continue + } + canonBytes = append(canonBytes, b...) + } + + issuerHash := sha1.Sum(canonBytes) + // Openssl takes the first 4 bytes and encodes them as a little endian + // uint32 and then uses the hex to make the file name. + // In C++, this would be: + // (((unsigned long)md[0]) | ((unsigned long)md[1] << 8L) | + // ((unsigned long)md[2] << 16L) | ((unsigned long)md[3] << 24L) + // ) & 0xffffffffL; + fileHash := binary.LittleEndian.Uint32(issuerHash[0:4]) + return fmt.Sprintf("%08x", fileHash) +} + +// CheckRevocation checks the connection for revoked certificates based on RFC5280. +// This implementation has the following major limitations: +// * Indirect CRL files are not supported. +// * CRL loading is only supported from directories in the X509_LOOKUP_hash_dir format. +// * OnlySomeReasons is not supported. +// * Delta CRL files are not supported. +// * Certificate CRLDistributionPoint must be URLs, but are then ignored and converted into a file path. +// * CRL checks are done after path building, which goes against RFC4158. +func CheckRevocation(conn tls.ConnectionState, cfg RevocationConfig) error { + return CheckChainRevocation(conn.VerifiedChains, cfg) +} + +// CheckChainRevocation checks the verified certificate chain +// for revoked certificates based on RFC5280. 
+func CheckChainRevocation(verifiedChains [][]*x509.Certificate, cfg RevocationConfig) error { + // Iterate the verified chains looking for one that is RevocationUnrevoked. + // A single RevocationUnrevoked chain is enough to allow the connection, and a single RevocationRevoked + // chain does not mean the connection should fail. + count := make(map[RevocationStatus]int) + for _, chain := range verifiedChains { + switch checkChain(chain, cfg) { + case RevocationUnrevoked: + // If any chain is RevocationUnrevoked then return no error. + return nil + case RevocationRevoked: + // If this chain is revoked, keep looking for another chain. + count[RevocationRevoked]++ + continue + case RevocationUndetermined: + if cfg.AllowUndetermined { + return nil + } + count[RevocationUndetermined]++ + continue + } + } + return fmt.Errorf("no unrevoked chains found: %v", count) +} + +// checkChain will determine and check all certificates in chain against the CRL +// defined in the certificate with the following rules: +// 1. If any certificate is RevocationRevoked, return RevocationRevoked. +// 2. If any certificate is RevocationUndetermined, return RevocationUndetermined. +// 3. If all certificates are RevocationUnrevoked, return RevocationUnrevoked. +func checkChain(chain []*x509.Certificate, cfg RevocationConfig) RevocationStatus { + chainStatus := RevocationUnrevoked + for _, c := range chain { + switch checkCert(c, chain, cfg) { + case RevocationRevoked: + // Easy case, if a cert in the chain is revoked, the chain is revoked. + return RevocationRevoked + case RevocationUndetermined: + // If we couldn't find the revocation status for a cert, the chain is at best RevocationUndetermined + // keep looking to see if we find a cert in the chain that's RevocationRevoked, + // but return RevocationUndetermined at a minimum. + chainStatus = RevocationUndetermined + case RevocationUnrevoked: + // Continue iterating up the cert chain. 
+ continue + } + } + return chainStatus +} + +func cachedCrl(rawIssuer []byte, cache Cache) (*certificateListExt, bool) { + val, ok := cache.Get(hex.EncodeToString(rawIssuer)) + if !ok { + return nil, false + } + crl, ok := val.(*certificateListExt) + if !ok { + return nil, false + } + // If the CRL is expired, force a reload. + if crl.CertList.HasExpired(time.Now()) { + return nil, false + } + return crl, true +} + +// fetchIssuerCRL fetches and verifies the CRL for rawIssuer from disk or cache if configured in cfg. +func fetchIssuerCRL(crlDistributionPoint string, rawIssuer []byte, crlVerifyCrt []*x509.Certificate, cfg RevocationConfig) (*certificateListExt, error) { + if cfg.Cache != nil { + if crl, ok := cachedCrl(rawIssuer, cfg.Cache); ok { + return crl, nil + } + } + + crl, err := fetchCRL(crlDistributionPoint, rawIssuer, cfg) + if err != nil { + return nil, fmt.Errorf("fetchCRL(%v) failed err = %v", crlDistributionPoint, err) + } + + if err := verifyCRL(crl, rawIssuer, crlVerifyCrt); err != nil { + return nil, fmt.Errorf("verifyCRL(%v) failed err = %v", crlDistributionPoint, err) + } + if cfg.Cache != nil { + cfg.Cache.Add(hex.EncodeToString(rawIssuer), crl) + } + return crl, nil +} + +// checkCert checks a single certificate against the CRL defined in the certificate. +// It will fetch and verify the CRL(s) defined by CRLDistributionPoints. +// If we can't load any authoritative CRL files, the status is RevocationUndetermined. +// c is the certificate to check. +// crlVerifyCrt is the group of possible certificates to verify the crl. 
+func checkCert(c *x509.Certificate, crlVerifyCrt []*x509.Certificate, cfg RevocationConfig) RevocationStatus { + if len(c.CRLDistributionPoints) == 0 { + return RevocationUnrevoked + } + // Iterate through CRL distribution points to check for status + for _, dp := range c.CRLDistributionPoints { + crl, err := fetchIssuerCRL(dp, c.RawIssuer, crlVerifyCrt, cfg) + if err != nil { + grpclogLogger.Warningf("getIssuerCRL(%v) err = %v", c.Issuer, err) + continue + } + revocation, err := checkCertRevocation(c, crl) + if err != nil { + grpclogLogger.Warningf("checkCertRevocation(CRL %v) failed %v", crl.CertList.TBSCertList.Issuer, err) + // We couldn't check the CRL file for some reason, so continue + // to the next file + continue + } + // Here we've gotten a CRL that loads and verifies. + // We only handle all-reasons CRL files, so this file + // is authoritative for the certificate. + return revocation + + } + // We couldn't load any CRL files for the certificate, so we don't know if it's RevocationUnrevoked or not. + return RevocationUndetermined +} + +func checkCertRevocation(c *x509.Certificate, crl *certificateListExt) (RevocationStatus, error) { + // Per section 5.3.3 we prime the certificate issuer with the CRL issuer. + // Subsequent entries use the previous entry's issuer. + rawEntryIssuer, err := asn1.Marshal(crl.CertList.TBSCertList.Issuer) + if err != nil { + return RevocationUndetermined, err + } + + // Loop through all the revoked certificates. + for _, revCert := range crl.CertList.TBSCertList.RevokedCertificates { + // 5.3 Loop through CRL entry extensions for needed information. + for _, ext := range revCert.Extensions { + if oidCertificateIssuer.Equal(ext.Id) { + extIssuer, err := parseCertIssuerExt(ext) + if err != nil { + grpclogLogger.Info(err) + if ext.Critical { + return RevocationUndetermined, err + } + // Since this is a non-critical extension, we can skip it even though + // there was a parsing failure. 
+ continue + } + rawEntryIssuer = extIssuer + } else if ext.Critical { + return RevocationUndetermined, fmt.Errorf("checkCertRevocation: Unhandled critical extension: %v", ext.Id) + } + } + + // If the issuer and serial number appear in the CRL, the certificate is revoked. + if bytes.Equal(c.RawIssuer, rawEntryIssuer) && c.SerialNumber.Cmp(revCert.SerialNumber) == 0 { + // CRL contains the serial, so return revoked. + return RevocationRevoked, nil + } + } + // We did not find the serial in the CRL file that was valid for the cert + // so the certificate is not revoked. + return RevocationUnrevoked, nil +} + +func parseCertIssuerExt(ext pkix.Extension) ([]byte, error) { + // 5.3.3 Certificate Issuer + // CertificateIssuer ::= GeneralNames + // GeneralNames ::= SEQUENCE SIZE (1..MAX) OF GeneralName + var generalNames []asn1.RawValue + if rest, err := asn1.Unmarshal(ext.Value, &generalNames); err != nil || len(rest) != 0 { + return nil, fmt.Errorf("asn1.Unmarshal failed err = %v", err) + } + + for _, generalName := range generalNames { + // GeneralName ::= CHOICE { + // otherName [0] OtherName, + // rfc822Name [1] IA5String, + // dNSName [2] IA5String, + // x400Address [3] ORAddress, + // directoryName [4] Name, + // ediPartyName [5] EDIPartyName, + // uniformResourceIdentifier [6] IA5String, + // iPAddress [7] OCTET STRING, + // registeredID [8] OBJECT IDENTIFIER } + if generalName.Tag == tagDirectoryName { + return generalName.Bytes, nil + } + } + // Conforming CRL issuers MUST include in this extension the + // distinguished name (DN) from the issuer field of the certificate that + // corresponds to this CRL entry. + // If we couldn't get a directoryName, we can't reason about this file so cert status is + // RevocationUndetermined. 
+ return nil, errors.New("no DN found in certificate issuer") +} + +// RFC 5280, 4.2.1.1 +type authKeyID struct { + ID []byte `asn1:"optional,tag:0"` +} + +// RFC5280, 5.2.5 +// id-ce-issuingDistributionPoint OBJECT IDENTIFIER ::= { id-ce 28 } + +// IssuingDistributionPoint ::= SEQUENCE { +// distributionPoint [0] DistributionPointName OPTIONAL, +// onlyContainsUserCerts [1] BOOLEAN DEFAULT FALSE, +// onlyContainsCACerts [2] BOOLEAN DEFAULT FALSE, +// onlySomeReasons [3] ReasonFlags OPTIONAL, +// indirectCRL [4] BOOLEAN DEFAULT FALSE, +// onlyContainsAttributeCerts [5] BOOLEAN DEFAULT FALSE } + +// -- at most one of onlyContainsUserCerts, onlyContainsCACerts, +// -- and onlyContainsAttributeCerts may be set to TRUE. +type issuingDistributionPoint struct { + DistributionPoint asn1.RawValue `asn1:"optional,tag:0"` + OnlyContainsUserCerts bool `asn1:"optional,tag:1"` + OnlyContainsCACerts bool `asn1:"optional,tag:2"` + OnlySomeReasons asn1.BitString `asn1:"optional,tag:3"` + IndirectCRL bool `asn1:"optional,tag:4"` + OnlyContainsAttributeCerts bool `asn1:"optional,tag:5"` +} + +// parseCRLExtensions parses the extensions for a CRL +// and checks that they're supported by the parser. +func parseCRLExtensions(c *pkix.CertificateList) (*certificateListExt, error) { + if c == nil { + return nil, errors.New("c is nil, expected any value") + } + certList := &certificateListExt{CertList: c} + + for _, ext := range c.TBSCertList.Extensions { + switch { + case oidDeltaCRLIndicator.Equal(ext.Id): + return nil, fmt.Errorf("delta CRLs unsupported") + + case oidAuthorityKeyIdentifier.Equal(ext.Id): + var a authKeyID + if rest, err := asn1.Unmarshal(ext.Value, &a); err != nil { + return nil, fmt.Errorf("asn1.Unmarshal failed. 
err = %v", err) + } else if len(rest) != 0 { + return nil, errors.New("trailing data after AKID extension") + } + certList.AuthorityKeyID = a.ID + + case oidIssuingDistributionPoint.Equal(ext.Id): + var dp issuingDistributionPoint + if rest, err := asn1.Unmarshal(ext.Value, &dp); err != nil { + return nil, fmt.Errorf("asn1.Unmarshal failed. err = %v", err) + } else if len(rest) != 0 { + return nil, errors.New("trailing data after IssuingDistributionPoint extension") + } + + if dp.OnlyContainsUserCerts || dp.OnlyContainsCACerts || dp.OnlyContainsAttributeCerts { + return nil, errors.New("CRL only contains some certificate types") + } + if dp.IndirectCRL { + return nil, errors.New("indirect CRLs unsupported") + } + if dp.OnlySomeReasons.BitLength != 0 { + return nil, errors.New("onlySomeReasons unsupported") + } + + case ext.Critical: + return nil, fmt.Errorf("unsupported critical extension: %v", ext.Id) + } + } + + if len(certList.AuthorityKeyID) == 0 { + return nil, errors.New("authority key identifier extension missing") + } + return certList, nil +} + +func fetchCRL(loc string, rawIssuer []byte, cfg RevocationConfig) (*certificateListExt, error) { + var parsedCRL *certificateListExt + // 6.3.3 (a) (1) (ii) + // According to X509_LOOKUP_hash_dir the format is issuer_hash.rN where N is an increasing number. + // There are no gaps, so we break when we can't find a file. + for i := 0; ; i++ { + // Unmarshal to RDNSeqence according to http://go/godoc/crypto/x509/pkix/#Name. + var r pkix.RDNSequence + rest, err := asn1.Unmarshal(rawIssuer, &r) + if len(rest) != 0 || err != nil { + return nil, fmt.Errorf("asn1.Unmarshal(Issuer) len(rest) = %v, err = %v", len(rest), err) + } + crlPath := fmt.Sprintf("%s.r%d", filepath.Join(cfg.RootDir, x509NameHash(r)), i) + crlBytes, err := ioutil.ReadFile(crlPath) + if err != nil { + // Break when we can't read a CRL file. 
+			grpclogLogger.Infof("readFile: %v", err)
+			break
+		}
+
+		crl, err := x509.ParseCRL(crlBytes)
+		if err != nil {
+			// Parsing errors for a CRL shouldn't happen so fail.
+			return nil, fmt.Errorf("x509.ParseCrl(%v) failed err = %v", crlPath, err)
+		}
+		var certList *certificateListExt
+		if certList, err = parseCRLExtensions(crl); err != nil {
+			grpclogLogger.Infof("fetchCRL: unsupported crl %v, err = %v", crlPath, err)
+			// Continue to find a supported CRL
+			continue
+		}
+
+		rawCRLIssuer, err := asn1.Marshal(certList.CertList.TBSCertList.Issuer)
+		if err != nil {
+			return nil, fmt.Errorf("asn1.Marshal(%v) failed err = %v", certList.CertList.TBSCertList.Issuer, err)
+		}
+		// RFC5280, 6.3.3 (b) Verify the issuer and scope of the complete CRL.
+		if bytes.Equal(rawIssuer, rawCRLIssuer) {
+			parsedCRL = certList
+			// Continue to find the highest number in the .rN suffix.
+			continue
+		}
+	}
+
+	if parsedCRL == nil {
+		return nil, fmt.Errorf("fetchCrls no CRLs found for issuer")
+	}
+	return parsedCRL, nil
+}
+
+func verifyCRL(crl *certificateListExt, rawIssuer []byte, chain []*x509.Certificate) error {
+	// RFC5280, 6.3.3 (f) Obtain and validate the certification path for the issuer of the complete CRL
+	// We intentionally limit our CRLs to be signed with the same certificate path as the certificate
+	// so we can use the chain from the connection.
+	rawCRLIssuer, err := asn1.Marshal(crl.CertList.TBSCertList.Issuer)
+	if err != nil {
+		return fmt.Errorf("asn1.Marshal(%v) failed err = %v", crl.CertList.TBSCertList.Issuer, err)
+	}
+
+	for _, c := range chain {
+		// Use the key where the subject and KIDs match.
+		// This departs from RFC4158, 3.5.12 which states that KIDs
+		// cannot eliminate certificates, but RFC5280, 5.2.1 states that
+		// "Conforming CRL issuers MUST use the key identifier method, and MUST
+		// include this extension in all CRLs issued."
+		// So, this is much simpler than RFC4158 and should be compatible.
+ if bytes.Equal(c.SubjectKeyId, crl.AuthorityKeyID) && bytes.Equal(c.RawSubject, rawCRLIssuer) { + // RFC5280, 6.3.3 (g) Validate signature. + return c.CheckCRLSignature(crl.CertList) + } + } + return fmt.Errorf("verifyCRL: No certificates mached CRL issuer (%v)", crl.CertList.TBSCertList.Issuer) +} diff --git a/security/advancedtls/crl_test.go b/security/advancedtls/crl_test.go new file mode 100644 index 00000000000..ec4483304c7 --- /dev/null +++ b/security/advancedtls/crl_test.go @@ -0,0 +1,718 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +package advancedtls + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/asn1" + "encoding/hex" + "encoding/pem" + "fmt" + "io/ioutil" + "math/big" + "net" + "os" + "path" + "strings" + "testing" + "time" + + lru "github.com/hashicorp/golang-lru" + "google.golang.org/grpc/security/advancedtls/testdata" +) + +func TestX509NameHash(t *testing.T) { + nameTests := []struct { + in pkix.Name + out string + }{ + { + in: pkix.Name{ + Country: []string{"US"}, + Organization: []string{"Example"}, + }, + out: "9cdd41ff", + }, + { + in: pkix.Name{ + Country: []string{"us"}, + Organization: []string{"example"}, + }, + out: "9cdd41ff", + }, + { + in: pkix.Name{ + Country: []string{" us"}, + Organization: []string{"example"}, + }, + out: "9cdd41ff", + }, + { + in: pkix.Name{ + Country: []string{"US"}, + Province: []string{"California"}, + Locality: []string{"Mountain View"}, + Organization: []string{"BoringSSL"}, + }, + out: "c24414d9", + }, + { + in: pkix.Name{ + Country: []string{"US"}, + Province: []string{"California"}, + Locality: []string{"Mountain View"}, + Organization: []string{"BoringSSL"}, + }, + out: "c24414d9", + }, + { + in: pkix.Name{ + SerialNumber: "87f4514475ba0a2b", + }, + out: "9dc713cd", + }, + { + in: pkix.Name{ + Country: []string{"US"}, + Province: []string{"California"}, + Locality: []string{"Mountain View"}, + Organization: []string{"Google LLC"}, + OrganizationalUnit: []string{"Production", "campus-sln"}, + CommonName: "Root CA (2021-02-02T07:30:36-08:00)", + }, + out: "0b35a562", + }, + { + in: pkix.Name{ + ExtraNames: []pkix.AttributeTypeAndValue{ + {Type: asn1.ObjectIdentifier{5, 5, 5, 5}, Value: "aaaa"}, + }, + }, + out: "eea339da", + }, + } + for _, tt := range nameTests { + t.Run(tt.in.String(), func(t *testing.T) { + h := x509NameHash(tt.in.ToRDNSequence()) + if h != tt.out { + t.Errorf("x509NameHash(%v): Got %v wanted %v", tt.in, h, tt.out) + } + }) + } +} + 
+func TestUnsupportedCRLs(t *testing.T) { + crlBytesSomeReasons := []byte(`-----BEGIN X509 CRL----- +MIIEeDCCA2ACAQEwDQYJKoZIhvcNAQELBQAwQjELMAkGA1UEBhMCVVMxHjAcBgNV +BAoTFUdvb2dsZSBUcnVzdCBTZXJ2aWNlczETMBEGA1UEAxMKR1RTIENBIDFPMRcN +MjEwNDI2MTI1OTQxWhcNMjEwNTA2MTE1OTQwWjCCAn0wIgIRAPOOG3L4VLC7CAAA +AABxQgEXDTIxMDQxOTEyMTgxOFowIQIQUK0UwBZkVdQIAAAAAHFCBRcNMjEwNDE5 +MTIxODE4WjAhAhBRIXBJaKoQkQgAAAAAcULHFw0yMTA0MjAxMjE4MTdaMCICEQCv +qQWUq5UxmQgAAAAAcULMFw0yMTA0MjAxMjE4MTdaMCICEQDdv5k1kKwKTQgAAAAA +cUOQFw0yMTA0MjExMjE4MTZaMCICEQDGIEfR8N9sEAgAAAAAcUOWFw0yMTA0MjEx +MjE4MThaMCECEBHgbLXlj5yUCAAAAABxQ/IXDTIxMDQyMTIzMDAyNlowIQIQE1wT +2GGYqKwIAAAAAHFD7xcNMjEwNDIxMjMwMDI5WjAiAhEAo/bSyDjpVtsIAAAAAHFE +txcNMjEwNDIyMjMwMDI3WjAhAhARdCrSrHE0dAgAAAAAcUS/Fw0yMTA0MjIyMzAw +MjhaMCECEHONohfWn3wwCAAAAABxRX8XDTIxMDQyMzIzMDAyOVowIgIRAOYkiUPA +os4vCAAAAABxRYgXDTIxMDQyMzIzMDAyOFowIQIQRNTow5Eg2gEIAAAAAHFGShcN +MjEwNDI0MjMwMDI2WjAhAhBX32dH4/WQ6AgAAAAAcUZNFw0yMTA0MjQyMzAwMjZa +MCICEQDHnUM1vsaP/wgAAAAAcUcQFw0yMTA0MjUyMzAwMjZaMCECEEm5rvmL8sj6 +CAAAAABxRxQXDTIxMDQyNTIzMDAyN1owIQIQW16OQs4YQYkIAAAAAHFIABcNMjEw +NDI2MTI1NDA4WjAhAhAhSohpYsJtDQgAAAAAcUgEFw0yMTA0MjYxMjU0MDlaoGkw +ZzAfBgNVHSMEGDAWgBSY0fhuEOvPm+xgnxiQG6DrfQn9KzALBgNVHRQEBAICBngw +NwYDVR0cAQH/BC0wK6AmoCSGImh0dHA6Ly9jcmwucGtpLmdvb2cvR1RTMU8xY29y +ZS5jcmyBAf8wDQYJKoZIhvcNAQELBQADggEBADPBXbxVxMJ1HC7btXExRUpJHUlU +YbeCZGx6zj5F8pkopbmpV7cpewwhm848Fx4VaFFppZQZd92O08daEC6aEqoug4qF +z6ZrOLzhuKfpW8E93JjgL91v0FYN7iOcT7+ERKCwVEwEkuxszxs7ggW6OJYJNvHh +priIdmcPoiQ3ZrIRH0vE3BfUcNXnKFGATWuDkiRI0I4A5P7NiOf+lAuGZet3/eom +0chgts6sdau10GfeUpHUd4f8e93cS/QeLeG16z7LC8vRLstU3m3vrknpZbdGqSia +97w66mqcnQh9V0swZiEnVLmLufaiuDZJ+6nUzSvLqBlb/ei3T/tKV0BoKJA= +-----END X509 CRL-----`) + + crlBytesIndirect := []byte(`-----BEGIN X509 CRL----- +MIIDGjCCAgICAQEwDQYJKoZIhvcNAQELBQAwdjELMAkGA1UEBhMCVVMxEzARBgNV +BAgTCkNhbGlmb3JuaWExFDASBgNVBAoTC1Rlc3RpbmcgTHRkMSowKAYDVQQLEyFU +ZXN0aW5nIEx0ZCBDZXJ0aWZpY2F0ZSBBdXRob3JpdHkxEDAOBgNVBAMTB1Rlc3Qg 
+Q0EXDTIxMDExNjAyMjAxNloXDTIxMDEyMDA2MjAxNlowgfIwbAIBAhcNMjEwMTE2 +MDIyMDE2WjBYMAoGA1UdFQQDCgEEMEoGA1UdHQEB/wRAMD6kPDA6MQwwCgYDVQQG +EwNVU0ExDTALBgNVBAcTBGhlcmUxCzAJBgNVBAoTAnVzMQ4wDAYDVQQDEwVUZXN0 +MTAgAgEDFw0yMTAxMTYwMjIwMTZaMAwwCgYDVR0VBAMKAQEwYAIBBBcNMjEwMTE2 +MDIyMDE2WjBMMEoGA1UdHQEB/wRAMD6kPDA6MQwwCgYDVQQGEwNVU0ExDTALBgNV +BAcTBGhlcmUxCzAJBgNVBAoTAnVzMQ4wDAYDVQQDEwVUZXN0MqBjMGEwHwYDVR0j +BBgwFoAURJSDWAOfhGCryBjl8dsQjBitl3swCgYDVR0UBAMCAQEwMgYDVR0cAQH/ +BCgwJqAhoB+GHWh0dHA6Ly9jcmxzLnBraS5nb29nL3Rlc3QuY3JshAH/MA0GCSqG +SIb3DQEBCwUAA4IBAQBVXX67mr2wFPmEWCe6mf/wFnPl3xL6zNOl96YJtsd7ulcS +TEbdJpaUnWFQ23+Tpzdj/lI2aQhTg5Lvii3o+D8C5r/Jc5NhSOtVJJDI/IQLh4pG +NgGdljdbJQIT5D2Z71dgbq1ocxn8DefZIJjO3jp8VnAm7AIMX2tLTySzD2MpMeMq +XmcN4lG1e4nx+xjzp7MySYO42NRY3LkphVzJhu3dRBYhBKViRJxw9hLttChitJpF +6Kh6a0QzrEY/QDJGhE1VrAD2c5g/SKnHPDVoCWo4ACIICi76KQQSIWfIdp4W/SY3 +qsSIp8gfxSyzkJP+Ngkm2DdLjlJQCZ9R0MZP9Xj4 +-----END X509 CRL-----`) + + var tests = []struct { + desc string + in []byte + }{ + { + desc: "some reasons", + in: crlBytesSomeReasons, + }, + { + desc: "indirect", + in: crlBytesIndirect, + }, + } + + for _, tt := range tests { + t.Run(tt.desc, func(t *testing.T) { + crl, err := x509.ParseCRL(tt.in) + if err != nil { + t.Fatal(err) + } + if _, err := parseCRLExtensions(crl); err == nil { + t.Error("expected error got ok") + } + }) + } +} + +func TestCheckCertRevocation(t *testing.T) { + dummyCrlFile := []byte(`-----BEGIN X509 CRL----- +MIIDGjCCAgICAQEwDQYJKoZIhvcNAQELBQAwdjELMAkGA1UEBhMCVVMxEzARBgNV +BAgTCkNhbGlmb3JuaWExFDASBgNVBAoTC1Rlc3RpbmcgTHRkMSowKAYDVQQLEyFU +ZXN0aW5nIEx0ZCBDZXJ0aWZpY2F0ZSBBdXRob3JpdHkxEDAOBgNVBAMTB1Rlc3Qg +Q0EXDTIxMDExNjAyMjAxNloXDTIxMDEyMDA2MjAxNlowgfIwbAIBAhcNMjEwMTE2 +MDIyMDE2WjBYMAoGA1UdFQQDCgEEMEoGA1UdHQEB/wRAMD6kPDA6MQwwCgYDVQQG +EwNVU0ExDTALBgNVBAcTBGhlcmUxCzAJBgNVBAoTAnVzMQ4wDAYDVQQDEwVUZXN0 +MTAgAgEDFw0yMTAxMTYwMjIwMTZaMAwwCgYDVR0VBAMKAQEwYAIBBBcNMjEwMTE2 +MDIyMDE2WjBMMEoGA1UdHQEB/wRAMD6kPDA6MQwwCgYDVQQGEwNVU0ExDTALBgNV 
+BAcTBGhlcmUxCzAJBgNVBAoTAnVzMQ4wDAYDVQQDEwVUZXN0MqBjMGEwHwYDVR0j +BBgwFoAURJSDWAOfhGCryBjl8dsQjBitl3swCgYDVR0UBAMCAQEwMgYDVR0cAQH/ +BCgwJqAhoB+GHWh0dHA6Ly9jcmxzLnBraS5nb29nL3Rlc3QuY3JshAH/MA0GCSqG +SIb3DQEBCwUAA4IBAQBVXX67mr2wFPmEWCe6mf/wFnPl3xL6zNOl96YJtsd7ulcS +TEbdJpaUnWFQ23+Tpzdj/lI2aQhTg5Lvii3o+D8C5r/Jc5NhSOtVJJDI/IQLh4pG +NgGdljdbJQIT5D2Z71dgbq1ocxn8DefZIJjO3jp8VnAm7AIMX2tLTySzD2MpMeMq +XmcN4lG1e4nx+xjzp7MySYO42NRY3LkphVzJhu3dRBYhBKViRJxw9hLttChitJpF +6Kh6a0QzrEY/QDJGhE1VrAD2c5g/SKnHPDVoCWo4ACIICi76KQQSIWfIdp4W/SY3 +qsSIp8gfxSyzkJP+Ngkm2DdLjlJQCZ9R0MZP9Xj4 +-----END X509 CRL-----`) + crl, err := x509.ParseCRL(dummyCrlFile) + if err != nil { + t.Fatalf("x509.ParseCRL(dummyCrlFile) failed: %v", err) + } + crlExt := &certificateListExt{CertList: crl} + var crlIssuer pkix.Name + crlIssuer.FillFromRDNSequence(&crl.TBSCertList.Issuer) + + var revocationTests = []struct { + desc string + in x509.Certificate + revoked RevocationStatus + }{ + { + desc: "Single revoked", + in: x509.Certificate{ + Issuer: pkix.Name{ + Country: []string{"USA"}, + Locality: []string{"here"}, + Organization: []string{"us"}, + CommonName: "Test1", + }, + SerialNumber: big.NewInt(2), + CRLDistributionPoints: []string{"test"}, + }, + revoked: RevocationRevoked, + }, + { + desc: "Revoked no entry issuer", + in: x509.Certificate{ + Issuer: pkix.Name{ + Country: []string{"USA"}, + Locality: []string{"here"}, + Organization: []string{"us"}, + CommonName: "Test1", + }, + SerialNumber: big.NewInt(3), + CRLDistributionPoints: []string{"test"}, + }, + revoked: RevocationRevoked, + }, + { + desc: "Revoked new entry issuer", + in: x509.Certificate{ + Issuer: pkix.Name{ + Country: []string{"USA"}, + Locality: []string{"here"}, + Organization: []string{"us"}, + CommonName: "Test2", + }, + SerialNumber: big.NewInt(4), + CRLDistributionPoints: []string{"test"}, + }, + revoked: RevocationRevoked, + }, + { + desc: "Single unrevoked", + in: x509.Certificate{ + Issuer: pkix.Name{ + Country: []string{"USA"}, + 
Locality: []string{"here"}, + Organization: []string{"us"}, + CommonName: "Test2", + }, + SerialNumber: big.NewInt(1), + CRLDistributionPoints: []string{"test"}, + }, + revoked: RevocationUnrevoked, + }, + { + desc: "Single unrevoked Issuer", + in: x509.Certificate{ + Issuer: crlIssuer, + SerialNumber: big.NewInt(2), + CRLDistributionPoints: []string{"test"}, + }, + revoked: RevocationUnrevoked, + }, + } + + for _, tt := range revocationTests { + rawIssuer, err := asn1.Marshal(tt.in.Issuer.ToRDNSequence()) + if err != nil { + t.Fatalf("asn1.Marshal(%v) failed: %v", tt.in.Issuer.ToRDNSequence(), err) + } + tt.in.RawIssuer = rawIssuer + t.Run(tt.desc, func(t *testing.T) { + rev, err := checkCertRevocation(&tt.in, crlExt) + if err != nil { + t.Errorf("checkCertRevocation(%v) err = %v", tt.in.Issuer, err) + } else if rev != tt.revoked { + t.Errorf("checkCertRevocation(%v(%v)) returned %v wanted %v", + tt.in.Issuer, tt.in.SerialNumber, rev, tt.revoked) + } + }) + } +} + +func makeChain(t *testing.T, name string) []*x509.Certificate { + t.Helper() + + certChain := make([]*x509.Certificate, 0) + + rest, err := ioutil.ReadFile(name) + if err != nil { + t.Fatalf("ioutil.ReadFile(%v) failed %v", name, err) + } + for len(rest) > 0 { + var block *pem.Block + block, rest = pem.Decode(rest) + c, err := x509.ParseCertificate(block.Bytes) + if err != nil { + t.Fatalf("ParseCertificate error %v", err) + } + t.Logf("Parsed Cert sub = %v iss = %v", c.Subject, c.Issuer) + certChain = append(certChain, c) + } + return certChain +} + +func loadCRL(t *testing.T, path string) *pkix.CertificateList { + b, err := ioutil.ReadFile(path) + if err != nil { + t.Fatalf("readFile(%v) failed err = %v", path, err) + } + crl, err := x509.ParseCRL(b) + if err != nil { + t.Fatalf("ParseCrl(%v) failed err = %v", path, err) + } + return crl +} + +func TestCachedCRL(t *testing.T) { + cache, err := lru.New(5) + if err != nil { + t.Fatalf("lru.New: err = %v", err) + } + + tests := []struct { + desc string + 
val interface{} + ok bool + }{ + { + desc: "Valid", + val: &certificateListExt{ + CertList: &pkix.CertificateList{ + TBSCertList: pkix.TBSCertificateList{ + NextUpdate: time.Now().Add(time.Hour), + }, + }}, + ok: true, + }, + { + desc: "Expired", + val: &certificateListExt{ + CertList: &pkix.CertificateList{ + TBSCertList: pkix.TBSCertificateList{ + NextUpdate: time.Now().Add(-time.Hour), + }, + }}, + ok: false, + }, + { + desc: "Wrong Type", + val: "string", + ok: false, + }, + { + desc: "Empty", + val: nil, + ok: false, + }, + } + for _, tt := range tests { + t.Run(tt.desc, func(t *testing.T) { + if tt.val != nil { + cache.Add(hex.EncodeToString([]byte(tt.desc)), tt.val) + } + _, ok := cachedCrl([]byte(tt.desc), cache) + if tt.ok != ok { + t.Errorf("Cache ok error expected %v vs %v", tt.ok, ok) + } + }) + } +} + +func TestGetIssuerCRLCache(t *testing.T) { + cache, err := lru.New(5) + if err != nil { + t.Fatalf("lru.New: err = %v", err) + } + + tests := []struct { + desc string + rawIssuer []byte + certs []*x509.Certificate + }{ + { + desc: "Valid", + rawIssuer: makeChain(t, testdata.Path("crl/unrevoked.pem"))[1].RawIssuer, + certs: makeChain(t, testdata.Path("crl/unrevoked.pem")), + }, + { + desc: "Unverified", + rawIssuer: makeChain(t, testdata.Path("crl/unrevoked.pem"))[1].RawIssuer, + }, + { + desc: "Not Found", + rawIssuer: []byte("not_found"), + }, + } + + for _, tt := range tests { + t.Run(tt.desc, func(t *testing.T) { + cache.Purge() + _, err := fetchIssuerCRL("test", tt.rawIssuer, tt.certs, RevocationConfig{ + RootDir: testdata.Path("."), + Cache: cache, + }) + if err == nil && cache.Len() == 0 { + t.Error("Verified CRL not added to cache") + } + if err != nil && cache.Len() != 0 { + t.Error("Unverified CRL added to cache") + } + }) + } +} + +func TestVerifyCrl(t *testing.T) { + tampered := loadCRL(t, testdata.Path("crl/1.crl")) + // Change the signature so it won't verify + tampered.SignatureValue.Bytes[0]++ + + verifyTests := []struct { + desc string + 
crl *pkix.CertificateList + certs []*x509.Certificate + cert *x509.Certificate + errWant string + }{ + { + desc: "Pass intermediate", + crl: loadCRL(t, testdata.Path("crl/1.crl")), + certs: makeChain(t, testdata.Path("crl/unrevoked.pem")), + cert: makeChain(t, testdata.Path("crl/unrevoked.pem"))[1], + errWant: "", + }, + { + desc: "Pass leaf", + crl: loadCRL(t, testdata.Path("crl/2.crl")), + certs: makeChain(t, testdata.Path("crl/unrevoked.pem")), + cert: makeChain(t, testdata.Path("crl/unrevoked.pem"))[2], + errWant: "", + }, + { + desc: "Fail wrong cert chain", + crl: loadCRL(t, testdata.Path("crl/3.crl")), + certs: makeChain(t, testdata.Path("crl/unrevoked.pem")), + cert: makeChain(t, testdata.Path("crl/revokedInt.pem"))[1], + errWant: "No certificates mached", + }, + { + desc: "Fail no certs", + crl: loadCRL(t, testdata.Path("crl/1.crl")), + certs: []*x509.Certificate{}, + cert: makeChain(t, testdata.Path("crl/unrevoked.pem"))[1], + errWant: "No certificates mached", + }, + { + desc: "Fail Tampered signature", + crl: tampered, + certs: makeChain(t, testdata.Path("crl/unrevoked.pem")), + cert: makeChain(t, testdata.Path("crl/unrevoked.pem"))[1], + errWant: "verification failure", + }, + } + + for _, tt := range verifyTests { + t.Run(tt.desc, func(t *testing.T) { + crlExt, err := parseCRLExtensions(tt.crl) + if err != nil { + t.Fatalf("parseCRLExtensions(%v) failed, err = %v", tt.crl.TBSCertList.Issuer, err) + } + err = verifyCRL(crlExt, tt.cert.RawIssuer, tt.certs) + switch { + case tt.errWant == "" && err != nil: + t.Errorf("Valid CRL did not verify err = %v", err) + case tt.errWant != "" && err == nil: + t.Error("Invalid CRL verified") + case tt.errWant != "" && !strings.Contains(err.Error(), tt.errWant): + t.Errorf("fetchIssuerCRL(_, %v, %v, _) = %v; want Contains(%v)", tt.cert.RawIssuer, tt.certs, err, tt.errWant) + } + }) + } +} + +func TestRevokedCert(t *testing.T) { + revokedIntChain := makeChain(t, testdata.Path("crl/revokedInt.pem")) + revokedLeafChain 
:= makeChain(t, testdata.Path("crl/revokedLeaf.pem")) + validChain := makeChain(t, testdata.Path("crl/unrevoked.pem")) + cache, err := lru.New(5) + if err != nil { + t.Fatalf("lru.New: err = %v", err) + } + + var revocationTests = []struct { + desc string + in tls.ConnectionState + revoked bool + allowUndetermined bool + }{ + { + desc: "Single unrevoked", + in: tls.ConnectionState{VerifiedChains: [][]*x509.Certificate{validChain}}, + revoked: false, + }, + { + desc: "Single revoked intermediate", + in: tls.ConnectionState{VerifiedChains: [][]*x509.Certificate{revokedIntChain}}, + revoked: true, + }, + { + desc: "Single revoked leaf", + in: tls.ConnectionState{VerifiedChains: [][]*x509.Certificate{revokedLeafChain}}, + revoked: true, + }, + { + desc: "Multi one revoked", + in: tls.ConnectionState{VerifiedChains: [][]*x509.Certificate{validChain, revokedLeafChain}}, + revoked: false, + }, + { + desc: "Multi revoked", + in: tls.ConnectionState{VerifiedChains: [][]*x509.Certificate{revokedLeafChain, revokedIntChain}}, + revoked: true, + }, + { + desc: "Multi unrevoked", + in: tls.ConnectionState{VerifiedChains: [][]*x509.Certificate{validChain, validChain}}, + revoked: false, + }, + { + desc: "Undetermined revoked", + in: tls.ConnectionState{VerifiedChains: [][]*x509.Certificate{ + {&x509.Certificate{CRLDistributionPoints: []string{"test"}}}, + }}, + revoked: true, + }, + { + desc: "Undetermined allowed", + in: tls.ConnectionState{VerifiedChains: [][]*x509.Certificate{ + {&x509.Certificate{CRLDistributionPoints: []string{"test"}}}, + }}, + revoked: false, + allowUndetermined: true, + }, + } + + for _, tt := range revocationTests { + t.Run(tt.desc, func(t *testing.T) { + err := CheckRevocation(tt.in, RevocationConfig{ + RootDir: testdata.Path("crl"), + AllowUndetermined: tt.allowUndetermined, + Cache: cache, + }) + t.Logf("CheckRevocation err = %v", err) + if tt.revoked && err == nil { + t.Error("Revoked certificate chain was allowed") + } else if !tt.revoked && err != 
nil { + t.Error("Unrevoked certificate not allowed") + } + }) + } +} + +func setupTLSConn(t *testing.T) (net.Listener, *x509.Certificate, *ecdsa.PrivateKey) { + t.Helper() + templ := x509.Certificate{ + SerialNumber: big.NewInt(5), + BasicConstraintsValid: true, + NotBefore: time.Now().Add(-time.Hour), + NotAfter: time.Now().Add(time.Hour), + IsCA: true, + Subject: pkix.Name{CommonName: "test-cert"}, + KeyUsage: x509.KeyUsageCertSign, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth}, + IPAddresses: []net.IP{net.ParseIP("::1")}, + CRLDistributionPoints: []string{"http://static.corp.google.com/crl/campus-sln/borg"}, + } + + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatalf("ecdsa.GenerateKey failed err = %v", err) + } + rawCert, err := x509.CreateCertificate(rand.Reader, &templ, &templ, key.Public(), key) + if err != nil { + t.Fatalf("x509.CreateCertificate failed err = %v", err) + } + cert, err := x509.ParseCertificate(rawCert) + if err != nil { + t.Fatalf("x509.ParseCertificate failed err = %v", err) + } + + srvCfg := tls.Config{ + Certificates: []tls.Certificate{ + { + Certificate: [][]byte{cert.Raw}, + PrivateKey: key, + }, + }, + } + l, err := tls.Listen("tcp6", "[::1]:0", &srvCfg) + if err != nil { + t.Fatalf("tls.Listen failed err = %v", err) + } + return l, cert, key +} + +// TestVerifyConnection will setup a client/server connection and check revocation in the real TLS dialer +func TestVerifyConnection(t *testing.T) { + lis, cert, key := setupTLSConn(t) + defer func() { + lis.Close() + }() + + var handshakeTests = []struct { + desc string + revoked []pkix.RevokedCertificate + success bool + }{ + { + desc: "Empty CRL", + revoked: []pkix.RevokedCertificate{}, + success: true, + }, + { + desc: "Revoked Cert", + revoked: []pkix.RevokedCertificate{ + { + SerialNumber: cert.SerialNumber, + RevocationTime: time.Now(), + }, + }, + success: false, + }, + } + for _, tt := range 
handshakeTests { + t.Run(tt.desc, func(t *testing.T) { + // Accept one connection. + go func() { + conn, err := lis.Accept() + if err != nil { + t.Errorf("tls.Accept failed err = %v", err) + } else { + conn.Write([]byte("Hello, World!")) + conn.Close() + } + }() + + dir, err := ioutil.TempDir("", "crl_dir") + if err != nil { + t.Fatalf("ioutil.TempDir failed err = %v", err) + } + defer os.RemoveAll(dir) + + crl, err := cert.CreateCRL(rand.Reader, key, tt.revoked, time.Now(), time.Now().Add(time.Hour)) + if err != nil { + t.Fatalf("templ.CreateCRL failed err = %v", err) + } + + err = ioutil.WriteFile(path.Join(dir, fmt.Sprintf("%s.r0", x509NameHash(cert.Subject.ToRDNSequence()))), crl, 0777) + if err != nil { + t.Fatalf("ioutil.WriteFile failed err = %v", err) + } + + cp := x509.NewCertPool() + cp.AddCert(cert) + cliCfg := tls.Config{ + RootCAs: cp, + VerifyConnection: func(cs tls.ConnectionState) error { + return CheckRevocation(cs, RevocationConfig{RootDir: dir}) + }, + } + conn, err := tls.Dial(lis.Addr().Network(), lis.Addr().String(), &cliCfg) + t.Logf("tls.Dial err = %v", err) + if tt.success && err != nil { + t.Errorf("Expected success got err = %v", err) + } + if !tt.success && err == nil { + t.Error("Expected error, but got success") + } + if err == nil { + conn.Close() + } + }) + } +} diff --git a/security/advancedtls/examples/go.mod b/security/advancedtls/examples/go.mod index 936aa476893..20ed81e24d3 100644 --- a/security/advancedtls/examples/go.mod +++ b/security/advancedtls/examples/go.mod @@ -3,7 +3,7 @@ module google.golang.org/grpc/security/advancedtls/examples go 1.15 require ( - google.golang.org/grpc v1.33.1 + google.golang.org/grpc v1.38.0 google.golang.org/grpc/examples v0.0.0-20201112215255-90f1b3ee835b google.golang.org/grpc/security/advancedtls v0.0.0-20201112215255-90f1b3ee835b ) diff --git a/security/advancedtls/examples/go.sum b/security/advancedtls/examples/go.sum index 519267dbc27..272f1afa407 100644 --- 
a/security/advancedtls/examples/go.sum +++ b/security/advancedtls/examples/go.sum @@ -1,20 +1,25 @@ cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= +github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= +github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= +github.com/cncf/xds/go v0.0.0-20210805033703-aa0b78936158/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/envoyproxy/go-control-plane v0.9.9-0.20210217033140-668b12f5399d/go.mod h1:cXg6YxExXjJnVBQHBLXeUAgxn2UodCpnH306RInaBQk= +github.com/envoyproxy/go-control-plane v0.9.10-0.20210907150352-cf90f659a021/go.mod h1:AFq3mo9L8Lqqiid3OhADV3RfLJnjiw63cSpi+fDTRC0= github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= +github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA= github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs= github.com/golang/protobuf 
v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w= github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QDs8UjoX8= -github.com/golang/protobuf v1.4.2 h1:+Z5KGCizgyZCbGh1KZqA0fcLLkwbsjIzS4aV2v7wJX0= -github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= +github.com/golang/protobuf v1.4.3 h1:JjCZWpVbqXDqFVmTfYWEVTMIYrL/NPdPSCHPJ0T/raM= +github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= @@ -22,35 +27,41 @@ github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/go-cmp v0.5.1 h1:JFrFEBb2xKufg6XkJsJr+WbKb4FQlURi5RUcBveYu9k= github.com/google/go-cmp v0.5.1/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw= +github.com/hashicorp/golang-lru v0.5.4 h1:YDjusn29QI/Das2iO9M0BHnIbxPeyuCHsjMW+lJfyTc= +github.com/hashicorp/golang-lru v0.5.4/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uGBxy+E8yxSoD4= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= +github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/testify v1.5.1/go.mod 
h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +go.opentelemetry.io/proto/otlp v0.7.0/go.mod h1:PqfVotwruBrMGOCsRd/89rSnXhoiJIqeYNgFYFoEGnI= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= -golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20190311183353-d8887717615a h1:oWX7TPOiFAMXLq8o0ikBYfCJVlRHBcsciT5bXOrH628= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20200822124328-c89045814202 h1:VvcQYSHwXgi7W+TpUR6A9g6Up98WAHf3f/ulnJ62IyA= +golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync 
v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a h1:1BGLXjeY4akVXGgbC9HugT3Jv3hCI0z56oJR5vAMgBU= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd h1:xhmwyvizuTgC2qz7ZlMluP20uW+C3Rm0FD/WLDX8884= +golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= -golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= -google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= +google.golang.org/genproto v0.0.0-20200513103714-09dca8ec2884/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= google.golang.org/genproto 
v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo= google.golang.org/genproto v0.0.0-20200806141610-86f49bd18e98 h1:LCO0fg4kb6WwkXQXRQQgUYsFeFb5taTX5WAx5O/Vt28= google.golang.org/genproto v0.0.0-20200806141610-86f49bd18e98/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= @@ -66,6 +77,6 @@ google.golang.org/protobuf v1.24.0/go.mod h1:r/3tXBNzIEhYS9I1OUVjXDlt8tc493IdKGj google.golang.org/protobuf v1.25.0 h1:Ejskq+SyPohKW+1uil0JJMtmHCgJPJ/qWTxr8qp+R4c= google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +gopkg.in/yaml.v2 v2.2.3/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= diff --git a/security/advancedtls/go.mod b/security/advancedtls/go.mod index be35029503d..75527018ee7 100644 --- a/security/advancedtls/go.mod +++ b/security/advancedtls/go.mod @@ -4,7 +4,8 @@ go 1.14 require ( github.com/google/go-cmp v0.5.1 // indirect - google.golang.org/grpc v1.31.0 + github.com/hashicorp/golang-lru v0.5.4 + google.golang.org/grpc v1.38.0 google.golang.org/grpc/examples v0.0.0-20201112215255-90f1b3ee835b ) diff --git a/security/advancedtls/go.sum b/security/advancedtls/go.sum index 519267dbc27..272f1afa407 100644 --- a/security/advancedtls/go.sum +++ b/security/advancedtls/go.sum @@ -1,20 +1,25 @@ cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= +github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY= 
github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= +github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= +github.com/cncf/xds/go v0.0.0-20210805033703-aa0b78936158/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/envoyproxy/go-control-plane v0.9.9-0.20210217033140-668b12f5399d/go.mod h1:cXg6YxExXjJnVBQHBLXeUAgxn2UodCpnH306RInaBQk= +github.com/envoyproxy/go-control-plane v0.9.10-0.20210907150352-cf90f659a021/go.mod h1:AFq3mo9L8Lqqiid3OhADV3RfLJnjiw63cSpi+fDTRC0= github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= +github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA= github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs= github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w= github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QDs8UjoX8= -github.com/golang/protobuf 
v1.4.2 h1:+Z5KGCizgyZCbGh1KZqA0fcLLkwbsjIzS4aV2v7wJX0= -github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= +github.com/golang/protobuf v1.4.3 h1:JjCZWpVbqXDqFVmTfYWEVTMIYrL/NPdPSCHPJ0T/raM= +github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= @@ -22,35 +27,41 @@ github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/go-cmp v0.5.1 h1:JFrFEBb2xKufg6XkJsJr+WbKb4FQlURi5RUcBveYu9k= github.com/google/go-cmp v0.5.1/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw= +github.com/hashicorp/golang-lru v0.5.4 h1:YDjusn29QI/Das2iO9M0BHnIbxPeyuCHsjMW+lJfyTc= +github.com/hashicorp/golang-lru v0.5.4/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uGBxy+E8yxSoD4= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= +github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +go.opentelemetry.io/proto/otlp v0.7.0/go.mod h1:PqfVotwruBrMGOCsRd/89rSnXhoiJIqeYNgFYFoEGnI= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= 
-golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= -golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20190311183353-d8887717615a h1:oWX7TPOiFAMXLq8o0ikBYfCJVlRHBcsciT5bXOrH628= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20200822124328-c89045814202 h1:VvcQYSHwXgi7W+TpUR6A9g6Up98WAHf3f/ulnJ62IyA= +golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a h1:1BGLXjeY4akVXGgbC9HugT3Jv3hCI0z56oJR5vAMgBU= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod 
h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd h1:xhmwyvizuTgC2qz7ZlMluP20uW+C3Rm0FD/WLDX8884= +golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= -golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= -google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= +google.golang.org/genproto v0.0.0-20200513103714-09dca8ec2884/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo= google.golang.org/genproto v0.0.0-20200806141610-86f49bd18e98 h1:LCO0fg4kb6WwkXQXRQQgUYsFeFb5taTX5WAx5O/Vt28= google.golang.org/genproto v0.0.0-20200806141610-86f49bd18e98/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= @@ -66,6 +77,6 @@ google.golang.org/protobuf 
v1.24.0/go.mod h1:r/3tXBNzIEhYS9I1OUVjXDlt8tc493IdKGj google.golang.org/protobuf v1.25.0 h1:Ejskq+SyPohKW+1uil0JJMtmHCgJPJ/qWTxr8qp+R4c= google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +gopkg.in/yaml.v2 v2.2.3/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= diff --git a/security/advancedtls/sni.go b/security/advancedtls/sni.go index 120acf2b376..3e7befb1f90 100644 --- a/security/advancedtls/sni.go +++ b/security/advancedtls/sni.go @@ -1,5 +1,3 @@ -// +build !appengine,go1.14 - /* * * Copyright 2020 gRPC authors. diff --git a/security/advancedtls/sni_beforego114.go b/security/advancedtls/sni_beforego114.go deleted file mode 100644 index 26a09b98849..00000000000 --- a/security/advancedtls/sni_beforego114.go +++ /dev/null @@ -1,42 +0,0 @@ -// +build !appengine,!go1.14 - -/* - * - * Copyright 2020 gRPC authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- * - */ - -package advancedtls - -import ( - "crypto/tls" - "fmt" -) - -// buildGetCertificates returns the first cert contained in ServerOptions for -// non-appengine builds before version 1.4. -func buildGetCertificates(clientHello *tls.ClientHelloInfo, o *ServerOptions) (*tls.Certificate, error) { - if o.IdentityOptions.GetIdentityCertificatesForServer == nil { - return nil, fmt.Errorf("function GetCertificates must be specified") - } - certificates, err := o.IdentityOptions.GetIdentityCertificatesForServer(clientHello) - if err != nil { - return nil, err - } - if len(certificates) == 0 { - return nil, fmt.Errorf("no certificates configured") - } - return certificates[0], nil -} diff --git a/security/advancedtls/testdata/crl/0b35a562.r0 b/security/advancedtls/testdata/crl/0b35a562.r0 new file mode 120000 index 00000000000..1a84eabdfc7 --- /dev/null +++ b/security/advancedtls/testdata/crl/0b35a562.r0 @@ -0,0 +1 @@ +5.crl \ No newline at end of file diff --git a/security/advancedtls/testdata/crl/0b35a562.r1 b/security/advancedtls/testdata/crl/0b35a562.r1 new file mode 120000 index 00000000000..6e6f1097891 --- /dev/null +++ b/security/advancedtls/testdata/crl/0b35a562.r1 @@ -0,0 +1 @@ +1.crl \ No newline at end of file diff --git a/security/advancedtls/testdata/crl/1.crl b/security/advancedtls/testdata/crl/1.crl new file mode 100644 index 00000000000..5b12ded4a66 --- /dev/null +++ b/security/advancedtls/testdata/crl/1.crl @@ -0,0 +1,10 @@ +-----BEGIN X509 CRL----- +MIIBYDCCAQYCAQEwCgYIKoZIzj0EAwIwgaUxCzAJBgNVBAYTAlVTMRMwEQYDVQQI +EwpDYWxpZm9ybmlhMRYwFAYDVQQHEw1Nb3VudGFpbiBWaWV3MRMwEQYDVQQKEwpH +b29nbGUgTExDMSYwEQYDVQQLEwpQcm9kdWN0aW9uMBEGA1UECxMKY2FtcHVzLXNs +bjEsMCoGA1UEAxMjUm9vdCBDQSAoMjAyMS0wMi0wMlQwNzozMDozNi0wODowMCkX +DTIxMDIwMjE1MzAzNloXDTIxMDIwOTE1MzAzNlqgLzAtMB8GA1UdIwQYMBaAFPQN +tnCIBcG4ReQgoVi0kPgTROseMAoGA1UdFAQDAgEAMAoGCCqGSM49BAMCA0gAMEUC +IQDB9WEPBPHEo5xjCv8CT9okockJJnkLDOus6FypVLqj5QIgYw9/PYLwb41/Uc+4 +LLTAsfdDWh7xBJmqvVQglMoJOEc= +-----END X509 
CRL----- diff --git a/security/advancedtls/testdata/crl/1ab871c8.r0 b/security/advancedtls/testdata/crl/1ab871c8.r0 new file mode 120000 index 00000000000..f2cd877e7ed --- /dev/null +++ b/security/advancedtls/testdata/crl/1ab871c8.r0 @@ -0,0 +1 @@ +2.crl \ No newline at end of file diff --git a/security/advancedtls/testdata/crl/2.crl b/security/advancedtls/testdata/crl/2.crl new file mode 100644 index 00000000000..5ca9afd7141 --- /dev/null +++ b/security/advancedtls/testdata/crl/2.crl @@ -0,0 +1,10 @@ +-----BEGIN X509 CRL----- +MIIBYDCCAQYCAQEwCgYIKoZIzj0EAwIwgaUxCzAJBgNVBAYTAlVTMRMwEQYDVQQI +EwpDYWxpZm9ybmlhMRYwFAYDVQQHEw1Nb3VudGFpbiBWaWV3MRMwEQYDVQQKEwpH +b29nbGUgTExDMSYwEQYDVQQLEwpQcm9kdWN0aW9uMBEGA1UECxMKY2FtcHVzLXNs +bjEsMCoGA1UEAxMjbm9kZSBDQSAoMjAyMS0wMi0wMlQwNzozMDozNi0wODowMCkX +DTIxMDIwMjE1MzAzNloXDTIxMDIwOTE1MzAzNlqgLzAtMB8GA1UdIwQYMBaAFBjo +V5Jnk/gp1k7fmWwkvTk/cF/IMAoGA1UdFAQDAgEAMAoGCCqGSM49BAMCA0gAMEUC +IQDgjA1Vj/pNFtNRL0vFEdapmFoArHM2+rn4IiP8jYLsCAIgAj2KEHbbtJ3zl5XP +WVW6ZyW7r3wIX+Bt3vLJWPrQtf8= +-----END X509 CRL----- diff --git a/security/advancedtls/testdata/crl/3.crl b/security/advancedtls/testdata/crl/3.crl new file mode 100644 index 00000000000..d37ad2247f5 --- /dev/null +++ b/security/advancedtls/testdata/crl/3.crl @@ -0,0 +1,11 @@ +-----BEGIN X509 CRL----- +MIIBiDCCAS8CAQEwCgYIKoZIzj0EAwIwgaUxCzAJBgNVBAYTAlVTMRMwEQYDVQQI +EwpDYWxpZm9ybmlhMRYwFAYDVQQHEw1Nb3VudGFpbiBWaWV3MRMwEQYDVQQKEwpH +b29nbGUgTExDMSYwEQYDVQQLEwpQcm9kdWN0aW9uMBEGA1UECxMKY2FtcHVzLXNs +bjEsMCoGA1UEAxMjUm9vdCBDQSAoMjAyMS0wMi0wMlQwNzozMTo1NC0wODowMCkX +DTIxMDIwMjE1MzE1NFoXDTIxMDIwOTE1MzE1NFowJzAlAhQAroEYW855BRqTrlov +5cBCGvkutxcNMjEwMjAyMTUzMTU0WqAvMC0wHwYDVR0jBBgwFoAUeq/TQ959KbWk +/um08jSTXogXpWUwCgYDVR0UBAMCAQEwCgYIKoZIzj0EAwIDRwAwRAIgaSOIhJDg +wOLYlbXkmxW0cqy/AfOUNYbz5D/8/FfvhosCICftg7Vzlu0Nh83jikyjy+wtkiJt +ZYNvGFQ3Sp2L3A9e +-----END X509 CRL----- diff --git a/security/advancedtls/testdata/crl/4.crl b/security/advancedtls/testdata/crl/4.crl new file mode 100644 index 
00000000000..d4ee6f7cf18 --- /dev/null +++ b/security/advancedtls/testdata/crl/4.crl @@ -0,0 +1,10 @@ +-----BEGIN X509 CRL----- +MIIBYDCCAQYCAQEwCgYIKoZIzj0EAwIwgaUxCzAJBgNVBAYTAlVTMRMwEQYDVQQI +EwpDYWxpZm9ybmlhMRYwFAYDVQQHEw1Nb3VudGFpbiBWaWV3MRMwEQYDVQQKEwpH +b29nbGUgTExDMSYwEQYDVQQLEwpQcm9kdWN0aW9uMBEGA1UECxMKY2FtcHVzLXNs +bjEsMCoGA1UEAxMjbm9kZSBDQSAoMjAyMS0wMi0wMlQwNzozMTo1NC0wODowMCkX +DTIxMDIwMjE1MzE1NFoXDTIxMDIwOTE1MzE1NFqgLzAtMB8GA1UdIwQYMBaAFIVn +8tIFgZpIdhomgYJ2c5ULLzpSMAoGA1UdFAQDAgEAMAoGCCqGSM49BAMCA0gAMEUC +ICupTvOqgAyRa1nn7+Pe/1vvlJPAQ8gUfTQsQ6XX3v6oAiEA08B2PsK6aTEwzjry +pXqhlUNZFzgaXrVVQuEJbyJ1qoU= +-----END X509 CRL----- diff --git a/security/advancedtls/testdata/crl/5.crl b/security/advancedtls/testdata/crl/5.crl new file mode 100644 index 00000000000..d1c24f0f25a --- /dev/null +++ b/security/advancedtls/testdata/crl/5.crl @@ -0,0 +1,10 @@ +-----BEGIN X509 CRL----- +MIIBXzCCAQYCAQEwCgYIKoZIzj0EAwIwgaUxCzAJBgNVBAYTAlVTMRMwEQYDVQQI +EwpDYWxpZm9ybmlhMRYwFAYDVQQHEw1Nb3VudGFpbiBWaWV3MRMwEQYDVQQKEwpH +b29nbGUgTExDMSYwEQYDVQQLEwpQcm9kdWN0aW9uMBEGA1UECxMKY2FtcHVzLXNs +bjEsMCoGA1UEAxMjUm9vdCBDQSAoMjAyMS0wMi0wMlQwNzozMjo1Ny0wODowMCkX +DTIxMDIwMjE1MzI1N1oXDTIxMDIwOTE1MzI1N1qgLzAtMB8GA1UdIwQYMBaAFN+g +xTAtSTlb5Qqvrbp4rZtsaNzqMAoGA1UdFAQDAgEAMAoGCCqGSM49BAMCA0cAMEQC +IHrRKjieY7w7gxvpkJAdszPZBlaSSp/c9wILutBTy7SyAiAwhaHfgas89iRfaBs2 +EhGIeK39A+kSzqu6qEQBHpK36g== +-----END X509 CRL----- diff --git a/security/advancedtls/testdata/crl/6.crl b/security/advancedtls/testdata/crl/6.crl new file mode 100644 index 00000000000..87ef378f6ab --- /dev/null +++ b/security/advancedtls/testdata/crl/6.crl @@ -0,0 +1,11 @@ +-----BEGIN X509 CRL----- +MIIBiDCCAS8CAQEwCgYIKoZIzj0EAwIwgaUxCzAJBgNVBAYTAlVTMRMwEQYDVQQI +EwpDYWxpZm9ybmlhMRYwFAYDVQQHEw1Nb3VudGFpbiBWaWV3MRMwEQYDVQQKEwpH +b29nbGUgTExDMSYwEQYDVQQLEwpQcm9kdWN0aW9uMBEGA1UECxMKY2FtcHVzLXNs +bjEsMCoGA1UEAxMjbm9kZSBDQSAoMjAyMS0wMi0wMlQwNzozMjo1Ny0wODowMCkX +DTIxMDIwMjE1MzI1N1oXDTIxMDIwOTE1MzI1N1owJzAlAhQAxSe/pGmyvzN7mxm5 
+6ZJTYUXYuhcNMjEwMjAyMTUzMjU3WqAvMC0wHwYDVR0jBBgwFoAUpZ30UJXB4lI9 +j2SzodCtRFckrRcwCgYDVR0UBAMCAQEwCgYIKoZIzj0EAwIDRwAwRAIgRg3u7t3b +oyV5FhMuGGzWnfIwnKclpT8imnp8tEN253sCIFUY7DjiDohwu4Zup3bWs1OaZ3q3 +cm+j0H/oe8zzCAgp +-----END X509 CRL----- diff --git a/security/advancedtls/testdata/crl/71eac5a2.r0 b/security/advancedtls/testdata/crl/71eac5a2.r0 new file mode 120000 index 00000000000..9f37924cae0 --- /dev/null +++ b/security/advancedtls/testdata/crl/71eac5a2.r0 @@ -0,0 +1 @@ +4.crl \ No newline at end of file diff --git a/security/advancedtls/testdata/crl/7a1799af.r0 b/security/advancedtls/testdata/crl/7a1799af.r0 new file mode 120000 index 00000000000..f34df5b59c0 --- /dev/null +++ b/security/advancedtls/testdata/crl/7a1799af.r0 @@ -0,0 +1 @@ +3.crl \ No newline at end of file diff --git a/security/advancedtls/testdata/crl/8828a7e6.r0 b/security/advancedtls/testdata/crl/8828a7e6.r0 new file mode 120000 index 00000000000..70bead214cc --- /dev/null +++ b/security/advancedtls/testdata/crl/8828a7e6.r0 @@ -0,0 +1 @@ +6.crl \ No newline at end of file diff --git a/security/advancedtls/testdata/crl/README.md b/security/advancedtls/testdata/crl/README.md new file mode 100644 index 00000000000..00cb09c3192 --- /dev/null +++ b/security/advancedtls/testdata/crl/README.md @@ -0,0 +1,48 @@ +# CRL Test Data + +This directory contains cert chains and CRL files for revocation testing. 
+ +To print the chain, use a command like, + +```shell +openssl crl2pkcs7 -nocrl -certfile security/crl/x509/client/testdata/revokedLeaf.pem | openssl pkcs7 -print_certs -text -noout +``` + +The crl file symlinks are generated with `openssl rehash` + +## unrevoked.pem + +A certificate chain with CRL files and unrevoked certs + +* Subject: C=US, ST=California, L=Mountain View, O=Google LLC, OU=Production, + OU=campus-sln, CN=Root CA (2021-02-02T07:30:36-08:00) + * 1.crl + +NOTE: 1.crl file is symlinked with 5.crl to simulate two issuers that hash to +the same value to test that loading multiple files works. + +* Subject: C=US, ST=California, L=Mountain View, O=Google LLC, OU=Production, + OU=campus-sln, CN=node CA (2021-02-02T07:30:36-08:00) + * 2.crl + +## revokedInt.pem + +Certificate chain where the intermediate is revoked + +* Subject: C=US, ST=California, L=Mountain View, O=Google LLC, OU=Production, + OU=campus-sln, CN=Root CA (2021-02-02T07:31:54-08:00) + * 3.crl +* Subject: C=US, ST=California, L=Mountain View, O=Google LLC, OU=Production, + OU=campus-sln, CN=node CA (2021-02-02T07:31:54-08:00) + * 4.crl + +## revokedLeaf.pem + +Certificate chain where the leaf is revoked + +* Subject: C=US, ST=California, L=Mountain View, O=Google LLC, OU=Production, + OU=campus-sln, CN=Root CA (2021-02-02T07:32:57-08:00) + * 5.crl +* Subject: C=US, ST=California, L=Mountain View, O=Google LLC, OU=Production, + OU=campus-sln, CN=node CA (2021-02-02T07:32:57-08:00) + * 6.crl diff --git a/security/advancedtls/testdata/crl/deee447d.r0 b/security/advancedtls/testdata/crl/deee447d.r0 new file mode 120000 index 00000000000..1a84eabdfc7 --- /dev/null +++ b/security/advancedtls/testdata/crl/deee447d.r0 @@ -0,0 +1 @@ +5.crl \ No newline at end of file diff --git a/security/advancedtls/testdata/crl/revokedInt.pem b/security/advancedtls/testdata/crl/revokedInt.pem new file mode 100644 index 00000000000..8b7282ff822 --- /dev/null +++ b/security/advancedtls/testdata/crl/revokedInt.pem 
@@ -0,0 +1,58 @@ +-----BEGIN CERTIFICATE----- +MIIDAzCCAqmgAwIBAgITAWjKwm2dNQvkO62Jgyr5rAvVQzAKBggqhkjOPQQDAjCB +pTELMAkGA1UEBhMCVVMxEzARBgNVBAgTCkNhbGlmb3JuaWExFjAUBgNVBAcTDU1v +dW50YWluIFZpZXcxEzARBgNVBAoTCkdvb2dsZSBMTEMxJjARBgNVBAsTClByb2R1 +Y3Rpb24wEQYDVQQLEwpjYW1wdXMtc2xuMSwwKgYDVQQDEyNSb290IENBICgyMDIx +LTAyLTAyVDA3OjMxOjU0LTA4OjAwKTAgFw0yMTAyMDIxNTMxNTRaGA85OTk5MTIz +MTIzNTk1OVowgaUxCzAJBgNVBAYTAlVTMRMwEQYDVQQIEwpDYWxpZm9ybmlhMRYw +FAYDVQQHEw1Nb3VudGFpbiBWaWV3MRMwEQYDVQQKEwpHb29nbGUgTExDMSYwEQYD +VQQLEwpQcm9kdWN0aW9uMBEGA1UECxMKY2FtcHVzLXNsbjEsMCoGA1UEAxMjUm9v +dCBDQSAoMjAyMS0wMi0wMlQwNzozMTo1NC0wODowMCkwWTATBgcqhkjOPQIBBggq +hkjOPQMBBwNCAAQhA0/puhTtSxbVVHseVhL2z7QhpPyJs5Q4beKi7tpaYRDmVn6p +Phh+jbRzg8Qj4gKI/Q1rrdm4rKer63LHpdWdo4GzMIGwMA4GA1UdDwEB/wQEAwIB +BjAdBgNVHSUEFjAUBggrBgEFBQcDAQYIKwYBBQUHAwIwDwYDVR0TAQH/BAUwAwEB +/zAdBgNVHQ4EFgQUeq/TQ959KbWk/um08jSTXogXpWUwHwYDVR0jBBgwFoAUeq/T +Q959KbWk/um08jSTXogXpWUwLgYDVR0RBCcwJYYjc3BpZmZlOi8vY2FtcHVzLXNs +bi5wcm9kLmdvb2dsZS5jb20wCgYIKoZIzj0EAwIDSAAwRQIgOSQZvyDPQwVOWnpF +zWvI+DS2yXIj/2T2EOvJz2XgcK4CIQCL0mh/+DxLiO4zzbInKr0mxpGSxSeZCUk7 +1ZF7AeLlbw== +-----END CERTIFICATE----- +-----BEGIN CERTIFICATE----- +MIIDizCCAzKgAwIBAgIUAK6BGFvOeQUak65aL+XAQhr5LrcwCgYIKoZIzj0EAwIw +gaUxCzAJBgNVBAYTAlVTMRMwEQYDVQQIEwpDYWxpZm9ybmlhMRYwFAYDVQQHEw1N +b3VudGFpbiBWaWV3MRMwEQYDVQQKEwpHb29nbGUgTExDMSYwEQYDVQQLEwpQcm9k +dWN0aW9uMBEGA1UECxMKY2FtcHVzLXNsbjEsMCoGA1UEAxMjUm9vdCBDQSAoMjAy +MS0wMi0wMlQwNzozMTo1NC0wODowMCkwIBcNMjEwMjAyMTUzMTU0WhgPOTk5OTEy +MzEyMzU5NTlaMIGlMQswCQYDVQQGEwJVUzETMBEGA1UECBMKQ2FsaWZvcm5pYTEW +MBQGA1UEBxMNTW91bnRhaW4gVmlldzETMBEGA1UEChMKR29vZ2xlIExMQzEmMBEG +A1UECxMKUHJvZHVjdGlvbjARBgNVBAsTCmNhbXB1cy1zbG4xLDAqBgNVBAMTI25v +ZGUgQ0EgKDIwMjEtMDItMDJUMDc6MzE6NTQtMDg6MDApMFkwEwYHKoZIzj0CAQYI +KoZIzj0DAQcDQgAEye6UOlBos8Q3FFBiLahD9BaLTA18bO4MTPyv35T3lppvxD5X +U/AnEllOnx5OMtMjMBbIQjSkMbiQ9xNXoSqB6aOCATowggE2MA4GA1UdDwEB/wQE +AwIBBjAdBgNVHSUEFjAUBggrBgEFBQcDAQYIKwYBBQUHAwIwDwYDVR0TAQH/BAUw 
+AwEB/zAdBgNVHQ4EFgQUhWfy0gWBmkh2GiaBgnZzlQsvOlIwHwYDVR0jBBgwFoAU +eq/TQ959KbWk/um08jSTXogXpWUwMwYDVR0RBCwwKoYoc3BpZmZlOi8vbm9kZS5j +YW1wdXMtc2xuLnByb2QuZ29vZ2xlLmNvbTA7BgNVHR4BAf8EMTAvoC0wK4YpY3Nj +cy10ZWFtLm5vZGUuY2FtcHVzLXNsbi5wcm9kLmdvb2dsZS5jb20wQgYDVR0fBDsw +OTA3oDWgM4YxaHR0cDovL3N0YXRpYy5jb3JwLmdvb2dsZS5jb20vY3JsL2NhbXB1 +cy1zbG4vbm9kZTAKBggqhkjOPQQDAgNHADBEAiA79rPu6ZO1/0qB6RxL7jVz1200 +UTo8ioB4itbTzMnJqAIgJqp/Rc8OhpsfzQX8XnIIkl+SewT+tOxJT1MHVNMlVhc= +-----END CERTIFICATE----- +-----BEGIN CERTIFICATE----- +MIIC0DCCAnWgAwIBAgITXQ2c/C27OGqk4Pbu+MNJlOtpYTAKBggqhkjOPQQDAjCB +pTELMAkGA1UEBhMCVVMxEzARBgNVBAgTCkNhbGlmb3JuaWExFjAUBgNVBAcTDU1v +dW50YWluIFZpZXcxEzARBgNVBAoTCkdvb2dsZSBMTEMxJjARBgNVBAsTClByb2R1 +Y3Rpb24wEQYDVQQLEwpjYW1wdXMtc2xuMSwwKgYDVQQDEyNub2RlIENBICgyMDIx +LTAyLTAyVDA3OjMxOjU0LTA4OjAwKTAgFw0yMTAyMDIxNTMxNTRaGA85OTk5MTIz +MTIzNTk1OVowADBZMBMGByqGSM49AgEGCCqGSM49AwEHA0IABN2/1le5d3hS/piw +hrNMHjd7gPEjzXwtuXQTzdV+aaeOf3ldnC6OnEF/bggym9MldQSJZLXPYSaoj430 +Vu5PRNejggEkMIIBIDAOBgNVHQ8BAf8EBAMCB4AwHQYDVR0lBBYwFAYIKwYBBQUH +AwIGCCsGAQUFBwMBMB0GA1UdDgQWBBTEewP3JgrJPekWWGGjChVqaMhaqTAfBgNV +HSMEGDAWgBSFZ/LSBYGaSHYaJoGCdnOVCy86UjBrBgNVHREBAf8EYTBfghZqemFi +MTIucHJvZC5nb29nbGUuY29thkVzcGlmZmU6Ly9jc2NzLXRlYW0ubm9kZS5jYW1w +dXMtc2xuLnByb2QuZ29vZ2xlLmNvbS9yb2xlL2JvcmctYWRtaW4tY28wQgYDVR0f +BDswOTA3oDWgM4YxaHR0cDovL3N0YXRpYy5jb3JwLmdvb2dsZS5jb20vY3JsL2Nh +bXB1cy1zbG4vbm9kZTAKBggqhkjOPQQDAgNJADBGAiEA9w4qp3nHpXo+6d7mZc69 +QoALfP5ynfBCArt8bAlToo8CIQCgc/lTfl2BtBko+7h/w6pKxLeuoQkvCL5gHFyK +LXE6vA== +-----END CERTIFICATE----- diff --git a/security/advancedtls/testdata/crl/revokedLeaf.pem b/security/advancedtls/testdata/crl/revokedLeaf.pem new file mode 100644 index 00000000000..b7541abf621 --- /dev/null +++ b/security/advancedtls/testdata/crl/revokedLeaf.pem @@ -0,0 +1,59 @@ +-----BEGIN CERTIFICATE----- +MIIDAzCCAqmgAwIBAgITTwodm6C4ZabFVUVa5yBw0TbzJTAKBggqhkjOPQQDAjCB +pTELMAkGA1UEBhMCVVMxEzARBgNVBAgTCkNhbGlmb3JuaWExFjAUBgNVBAcTDU1v 
+dW50YWluIFZpZXcxEzARBgNVBAoTCkdvb2dsZSBMTEMxJjARBgNVBAsTClByb2R1 +Y3Rpb24wEQYDVQQLEwpjYW1wdXMtc2xuMSwwKgYDVQQDEyNSb290IENBICgyMDIx +LTAyLTAyVDA3OjMyOjU3LTA4OjAwKTAgFw0yMTAyMDIxNTMyNTdaGA85OTk5MTIz +MTIzNTk1OVowgaUxCzAJBgNVBAYTAlVTMRMwEQYDVQQIEwpDYWxpZm9ybmlhMRYw +FAYDVQQHEw1Nb3VudGFpbiBWaWV3MRMwEQYDVQQKEwpHb29nbGUgTExDMSYwEQYD +VQQLEwpQcm9kdWN0aW9uMBEGA1UECxMKY2FtcHVzLXNsbjEsMCoGA1UEAxMjUm9v +dCBDQSAoMjAyMS0wMi0wMlQwNzozMjo1Ny0wODowMCkwWTATBgcqhkjOPQIBBggq +hkjOPQMBBwNCAARoZnzQWvAoyhvCLA2cFIK17khSaA9aA+flS5X9fLRt4RsfPCx3 +kim7wYKQSmBhQdc1UM4h3969r1c1Fvsh2H9qo4GzMIGwMA4GA1UdDwEB/wQEAwIB +BjAdBgNVHSUEFjAUBggrBgEFBQcDAQYIKwYBBQUHAwIwDwYDVR0TAQH/BAUwAwEB +/zAdBgNVHQ4EFgQU36DFMC1JOVvlCq+tunitm2xo3OowHwYDVR0jBBgwFoAU36DF +MC1JOVvlCq+tunitm2xo3OowLgYDVR0RBCcwJYYjc3BpZmZlOi8vY2FtcHVzLXNs +bi5wcm9kLmdvb2dsZS5jb20wCgYIKoZIzj0EAwIDSAAwRQIgN7S9dQOQzNih92ag +7c5uQxuz+M6wnxWj/uwGQIIghRUCIQD2UDH6kkRSYQuyP0oN7XYO3XFjmZ2Yer6m +1ZS8fyWYYA== +-----END CERTIFICATE----- +-----BEGIN CERTIFICATE----- +MIIDjTCCAzKgAwIBAgIUAOmArBu9gihLTlqP3W7Et0UoocEwCgYIKoZIzj0EAwIw +gaUxCzAJBgNVBAYTAlVTMRMwEQYDVQQIEwpDYWxpZm9ybmlhMRYwFAYDVQQHEw1N +b3VudGFpbiBWaWV3MRMwEQYDVQQKEwpHb29nbGUgTExDMSYwEQYDVQQLEwpQcm9k +dWN0aW9uMBEGA1UECxMKY2FtcHVzLXNsbjEsMCoGA1UEAxMjUm9vdCBDQSAoMjAy +MS0wMi0wMlQwNzozMjo1Ny0wODowMCkwIBcNMjEwMjAyMTUzMjU3WhgPOTk5OTEy +MzEyMzU5NTlaMIGlMQswCQYDVQQGEwJVUzETMBEGA1UECBMKQ2FsaWZvcm5pYTEW +MBQGA1UEBxMNTW91bnRhaW4gVmlldzETMBEGA1UEChMKR29vZ2xlIExMQzEmMBEG +A1UECxMKUHJvZHVjdGlvbjARBgNVBAsTCmNhbXB1cy1zbG4xLDAqBgNVBAMTI25v +ZGUgQ0EgKDIwMjEtMDItMDJUMDc6MzI6NTctMDg6MDApMFkwEwYHKoZIzj0CAQYI +KoZIzj0DAQcDQgAEfrgVEVQfSEFeCF1/FGeW7oq0yxecenT1BESfj4Z0zJ8p7P9W +bj1o6Rn6dUNlEhGrx7E3/4NFJ0cL1BSNGHkjiqOCATowggE2MA4GA1UdDwEB/wQE +AwIBBjAdBgNVHSUEFjAUBggrBgEFBQcDAQYIKwYBBQUHAwIwDwYDVR0TAQH/BAUw +AwEB/zAdBgNVHQ4EFgQUpZ30UJXB4lI9j2SzodCtRFckrRcwHwYDVR0jBBgwFoAU +36DFMC1JOVvlCq+tunitm2xo3OowMwYDVR0RBCwwKoYoc3BpZmZlOi8vbm9kZS5j +YW1wdXMtc2xuLnByb2QuZ29vZ2xlLmNvbTA7BgNVHR4BAf8EMTAvoC0wK4YpY3Nj 
+cy10ZWFtLm5vZGUuY2FtcHVzLXNsbi5wcm9kLmdvb2dsZS5jb20wQgYDVR0fBDsw +OTA3oDWgM4YxaHR0cDovL3N0YXRpYy5jb3JwLmdvb2dsZS5jb20vY3JsL2NhbXB1 +cy1zbG4vbm9kZTAKBggqhkjOPQQDAgNJADBGAiEAnuONgMqmbBlj4ibw5BgDtZUM +pboACSFJtEOJu4Yqjt0CIQDI5193J4wUcAY0BK0vO9rRfbNOIc+4ke9ieBDPSuhm +mA== +-----END CERTIFICATE----- +-----BEGIN CERTIFICATE----- +MIICzzCCAnagAwIBAgIUAMUnv6Rpsr8ze5sZuemSU2FF2LowCgYIKoZIzj0EAwIw +gaUxCzAJBgNVBAYTAlVTMRMwEQYDVQQIEwpDYWxpZm9ybmlhMRYwFAYDVQQHEw1N +b3VudGFpbiBWaWV3MRMwEQYDVQQKEwpHb29nbGUgTExDMSYwEQYDVQQLEwpQcm9k +dWN0aW9uMBEGA1UECxMKY2FtcHVzLXNsbjEsMCoGA1UEAxMjbm9kZSBDQSAoMjAy +MS0wMi0wMlQwNzozMjo1Ny0wODowMCkwIBcNMjEwMjAyMTUzMjU3WhgPOTk5OTEy +MzEyMzU5NTlaMAAwWTATBgcqhkjOPQIBBggqhkjOPQMBBwNCAASCmYiIHUux5WFz +S0ksJzAPL7YTEh5o5MdXgLPB/WM6x9sVsQDSYU0PF5qc9vPNhkQzGBW79dkBnxhW +AGJkFr1Po4IBJDCCASAwDgYDVR0PAQH/BAQDAgeAMB0GA1UdJQQWMBQGCCsGAQUF +BwMCBggrBgEFBQcDATAdBgNVHQ4EFgQUCR1CGEdlks0qcxCExO0rP1B/Z7UwHwYD +VR0jBBgwFoAUpZ30UJXB4lI9j2SzodCtRFckrRcwawYDVR0RAQH/BGEwX4IWanph +YjEyLnByb2QuZ29vZ2xlLmNvbYZFc3BpZmZlOi8vY3Njcy10ZWFtLm5vZGUuY2Ft +cHVzLXNsbi5wcm9kLmdvb2dsZS5jb20vcm9sZS9ib3JnLWFkbWluLWNvMEIGA1Ud +HwQ7MDkwN6A1oDOGMWh0dHA6Ly9zdGF0aWMuY29ycC5nb29nbGUuY29tL2NybC9j +YW1wdXMtc2xuL25vZGUwCgYIKoZIzj0EAwIDRwAwRAIgK9vQYNoL8HlEwWv89ioG +aQ1+8swq6Bo/5mJBrdVLvY8CIGxo6M9vJkPdObmetWNC+lmKuZDoqJWI0AAmBT2J +mR2r +-----END CERTIFICATE----- diff --git a/security/advancedtls/testdata/crl/unrevoked.pem b/security/advancedtls/testdata/crl/unrevoked.pem new file mode 100644 index 00000000000..5c5fc58a7a5 --- /dev/null +++ b/security/advancedtls/testdata/crl/unrevoked.pem @@ -0,0 +1,58 @@ +-----BEGIN CERTIFICATE----- +MIIDBDCCAqqgAwIBAgIUALy864QhnkTdceLH52k2XVOe8IQwCgYIKoZIzj0EAwIw +gaUxCzAJBgNVBAYTAlVTMRMwEQYDVQQIEwpDYWxpZm9ybmlhMRYwFAYDVQQHEw1N +b3VudGFpbiBWaWV3MRMwEQYDVQQKEwpHb29nbGUgTExDMSYwEQYDVQQLEwpQcm9k +dWN0aW9uMBEGA1UECxMKY2FtcHVzLXNsbjEsMCoGA1UEAxMjUm9vdCBDQSAoMjAy +MS0wMi0wMlQwNzozMDozNi0wODowMCkwIBcNMjEwMjAyMTUzMDM2WhgPOTk5OTEy 
+MzEyMzU5NTlaMIGlMQswCQYDVQQGEwJVUzETMBEGA1UECBMKQ2FsaWZvcm5pYTEW +MBQGA1UEBxMNTW91bnRhaW4gVmlldzETMBEGA1UEChMKR29vZ2xlIExMQzEmMBEG +A1UECxMKUHJvZHVjdGlvbjARBgNVBAsTCmNhbXB1cy1zbG4xLDAqBgNVBAMTI1Jv +b3QgQ0EgKDIwMjEtMDItMDJUMDc6MzA6MzYtMDg6MDApMFkwEwYHKoZIzj0CAQYI +KoZIzj0DAQcDQgAEYv/JS5hQ5kIgdKqYZWTKCO/6gloHAmIb1G8lmY0oXLXYNHQ4 +qHN7/pPtlcHQp0WK/hM8IGvgOUDoynA8mj0H9KOBszCBsDAOBgNVHQ8BAf8EBAMC +AQYwHQYDVR0lBBYwFAYIKwYBBQUHAwEGCCsGAQUFBwMCMA8GA1UdEwEB/wQFMAMB +Af8wHQYDVR0OBBYEFPQNtnCIBcG4ReQgoVi0kPgTROseMB8GA1UdIwQYMBaAFPQN +tnCIBcG4ReQgoVi0kPgTROseMC4GA1UdEQQnMCWGI3NwaWZmZTovL2NhbXB1cy1z +bG4ucHJvZC5nb29nbGUuY29tMAoGCCqGSM49BAMCA0gAMEUCIQDwBn20DB4X/7Uk +Q5BR8JxQYUPxOfvuedjfeA8bPvQ2FwIgOEWa0cXJs1JxarILJeCXtdXvBgu6LEGQ +3Pk/bgz8Gek= +-----END CERTIFICATE----- +-----BEGIN CERTIFICATE----- +MIIDizCCAzKgAwIBAgIUAM/6RKQ7Vke0i4xp5LaAqV73cmIwCgYIKoZIzj0EAwIw +gaUxCzAJBgNVBAYTAlVTMRMwEQYDVQQIEwpDYWxpZm9ybmlhMRYwFAYDVQQHEw1N +b3VudGFpbiBWaWV3MRMwEQYDVQQKEwpHb29nbGUgTExDMSYwEQYDVQQLEwpQcm9k +dWN0aW9uMBEGA1UECxMKY2FtcHVzLXNsbjEsMCoGA1UEAxMjUm9vdCBDQSAoMjAy +MS0wMi0wMlQwNzozMDozNi0wODowMCkwIBcNMjEwMjAyMTUzMDM2WhgPOTk5OTEy +MzEyMzU5NTlaMIGlMQswCQYDVQQGEwJVUzETMBEGA1UECBMKQ2FsaWZvcm5pYTEW +MBQGA1UEBxMNTW91bnRhaW4gVmlldzETMBEGA1UEChMKR29vZ2xlIExMQzEmMBEG +A1UECxMKUHJvZHVjdGlvbjARBgNVBAsTCmNhbXB1cy1zbG4xLDAqBgNVBAMTI25v +ZGUgQ0EgKDIwMjEtMDItMDJUMDc6MzA6MzYtMDg6MDApMFkwEwYHKoZIzj0CAQYI +KoZIzj0DAQcDQgAEllnhxmMYiUPUgRGmenbnm10gXpM94zHx3D1/HumPs6arjYuT +Zlhx81XL+g4bu4HII2qcGdP+Hqj/MMFNDI9z4aOCATowggE2MA4GA1UdDwEB/wQE +AwIBBjAdBgNVHSUEFjAUBggrBgEFBQcDAQYIKwYBBQUHAwIwDwYDVR0TAQH/BAUw +AwEB/zAdBgNVHQ4EFgQUGOhXkmeT+CnWTt+ZbCS9OT9wX8gwHwYDVR0jBBgwFoAU +9A22cIgFwbhF5CChWLSQ+BNE6x4wMwYDVR0RBCwwKoYoc3BpZmZlOi8vbm9kZS5j +YW1wdXMtc2xuLnByb2QuZ29vZ2xlLmNvbTA7BgNVHR4BAf8EMTAvoC0wK4YpY3Nj +cy10ZWFtLm5vZGUuY2FtcHVzLXNsbi5wcm9kLmdvb2dsZS5jb20wQgYDVR0fBDsw +OTA3oDWgM4YxaHR0cDovL3N0YXRpYy5jb3JwLmdvb2dsZS5jb20vY3JsL2NhbXB1 +cy1zbG4vbm9kZTAKBggqhkjOPQQDAgNHADBEAiA86egqPw0qyapAeMGbHxrmYZYa 
+i5ARQsSKRmQixgYizQIgW+2iRWN6Kbqt4WcwpmGv/xDckdRXakF5Ign/WUDO5u4= +-----END CERTIFICATE----- +-----BEGIN CERTIFICATE----- +MIICzzCCAnWgAwIBAgITYjjKfYZUKQNUjNyF+hLDGpHJKTAKBggqhkjOPQQDAjCB +pTELMAkGA1UEBhMCVVMxEzARBgNVBAgTCkNhbGlmb3JuaWExFjAUBgNVBAcTDU1v +dW50YWluIFZpZXcxEzARBgNVBAoTCkdvb2dsZSBMTEMxJjARBgNVBAsTClByb2R1 +Y3Rpb24wEQYDVQQLEwpjYW1wdXMtc2xuMSwwKgYDVQQDEyNub2RlIENBICgyMDIx +LTAyLTAyVDA3OjMwOjM2LTA4OjAwKTAgFw0yMTAyMDIxNTMwMzZaGA85OTk5MTIz +MTIzNTk1OVowADBZMBMGByqGSM49AgEGCCqGSM49AwEHA0IABD4r4+nCgZExYF8v +CLvGn0lY/cmam8mAkJDXRN2Ja2t+JwaTOptPmbbXft+1NTk5gCg5wB+FJCnaV3I/ +HaxEhBWjggEkMIIBIDAOBgNVHQ8BAf8EBAMCB4AwHQYDVR0lBBYwFAYIKwYBBQUH +AwIGCCsGAQUFBwMBMB0GA1UdDgQWBBTTCjXX1Txjc00tBg/5cFzpeCSKuDAfBgNV +HSMEGDAWgBQY6FeSZ5P4KdZO35lsJL05P3BfyDBrBgNVHREBAf8EYTBfghZqemFi +MTIucHJvZC5nb29nbGUuY29thkVzcGlmZmU6Ly9jc2NzLXRlYW0ubm9kZS5jYW1w +dXMtc2xuLnByb2QuZ29vZ2xlLmNvbS9yb2xlL2JvcmctYWRtaW4tY28wQgYDVR0f +BDswOTA3oDWgM4YxaHR0cDovL3N0YXRpYy5jb3JwLmdvb2dsZS5jb20vY3JsL2Nh +bXB1cy1zbG4vbm9kZTAKBggqhkjOPQQDAgNIADBFAiBq3URViNyMLpvzZHC1Y+4L ++35guyIJfjHu08P3S8/xswIhAJtWSQ1ZtozdOzGxg7GfUo4hR+5SP6rBTgIqXEfq +48fW +-----END CERTIFICATE----- diff --git a/security/authorization/go.mod b/security/authorization/go.mod index 0581b3401f3..ce34742af2c 100644 --- a/security/authorization/go.mod +++ b/security/authorization/go.mod @@ -1,6 +1,6 @@ module google.golang.org/grpc/security/authorization -go 1.12 +go 1.14 require ( github.com/envoyproxy/go-control-plane v0.9.5 diff --git a/security/authorization/go.sum b/security/authorization/go.sum index a953711e01e..3c7ea6cf47f 100644 --- a/security/authorization/go.sum +++ b/security/authorization/go.sum @@ -14,14 +14,11 @@ github.com/envoyproxy/go-control-plane v0.9.5 h1:lRJIqDD8yjV1YyPRqecMdytjDLs2fTX github.com/envoyproxy/go-control-plane v0.9.5/go.mod h1:OXl5to++W0ctG+EHWTFUjiypVxC/Y4VLc/KFU+al13s= github.com/envoyproxy/protoc-gen-validate v0.1.0 h1:EQciDnbrYxy13PgWoY8AqoxGiPrpgBZ1R8UNe3ddc+A= github.com/envoyproxy/protoc-gen-validate 
v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= -github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b h1:VKtxabqXZkF25pY9ekfRL6a582T4P37/31XEstQ5p58= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= -github.com/golang/protobuf v1.3.2 h1:6nsPYzhq5kReh6QImI3k5qWzO4PEbvbIW2cwSfR/6xs= github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= -github.com/golang/protobuf v1.3.4 h1:87PNWwrRvUSnqS4dlcBU/ftvOIBep4sYuBLlh6rX2wk= github.com/golang/protobuf v1.3.4/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA= @@ -36,7 +33,6 @@ github.com/google/cel-spec v0.4.0/go.mod h1:2pBM5cU4UKjbPDXBgwWkiwBsVgnxknuEJ7C5 github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= -github.com/google/go-cmp v0.4.0 h1:xsAVV57WRhGj6kEIi8ReJzQlHHqcBYCElAvkovg3B/4= github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.0 h1:/QaMHBdZ26BB3SSst0Iwl10Epc+xhTquomWX0oZEB6w= github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= @@ -49,7 +45,6 @@ golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHl golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod 
h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20190311183353-d8887717615a h1:oWX7TPOiFAMXLq8o0ikBYfCJVlRHBcsciT5bXOrH628= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20200301022130-244492dfa37a h1:GuSPYbZzB5/dcLNCwLQLsg3obCJtX9IJhpXkvY7kzk0= golang.org/x/net v0.0.0-20200301022130-244492dfa37a/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= @@ -58,11 +53,9 @@ golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a h1:1BGLXjeY4akVXGgbC9HugT3Jv3hCI0z56oJR5vAMgBU= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20200302150141-5c8b2ff67527 h1:uYVVQ9WP/Ds2ROhcaGPeIdVq0RIXVLwsHlnvJ+cT1So= golang.org/x/sys v0.0.0-20200302150141-5c8b2ff67527/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= @@ -76,18 +69,14 @@ golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8T google.golang.org/appengine v1.1.0/go.mod 
h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= -google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55 h1:gSJIx1SDwno+2ElGhA4+qG2zF97qiUzTM+rQ0klBOcE= google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= -google.golang.org/genproto v0.0.0-20200305110556-506484158171 h1:xes2Q2k+d/+YNXVw0FpZkIDJiaux4OVrRKXRAzH6A0U= google.golang.org/genproto v0.0.0-20200305110556-506484158171/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013 h1:+kGHl1aib/qcwaRi1CbqBZ1rk19r85MNUf8HaBghugY= google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo= google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= -google.golang.org/grpc v1.25.1 h1:wdKvqQk7IttEw92GoRyKG2IDrUIpgpj6H6m81yfeMW0= google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY= google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= -google.golang.org/grpc v1.27.1 h1:zvIju4sqAGvwKspUQOhwnpcqSbzi7/H6QomNNjTL4sk= google.golang.org/grpc v1.27.1/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= google.golang.org/grpc v1.31.0 h1:T7P4R73V3SSDPhH7WW7ATbfViLtmamH0DKrP3f9AuDI= google.golang.org/grpc v1.31.0/go.mod h1:N36X2cJ7JwdamYAgDz+s+rVMFjt3numwzf/HckM8pak= diff --git a/server.go b/server.go index 7a2aa28a114..eadf9e05fd1 100644 --- a/server.go +++ b/server.go @@ -57,12 +57,22 @@ import ( const ( defaultServerMaxReceiveMessageSize = 1024 * 1024 * 4 defaultServerMaxSendMessageSize = math.MaxInt32 + + // Server transports are tracked in a map which is keyed on 
listener + // address. For regular gRPC traffic, connections are accepted in Serve() + // through a call to Accept(), and we use the actual listener address as key + // when we add it to the map. But for connections received through + // ServeHTTP(), we do not have a listener and hence use this dummy value. + listenerAddressForServeHTTP = "listenerAddressForServeHTTP" ) func init() { internal.GetServerCredentials = func(srv *Server) credentials.TransportCredentials { return srv.opts.creds } + internal.DrainServerTransports = func(srv *Server, addr string) { + srv.drainServerTransports(addr) + } } var statusOK = status.New(codes.OK, "") @@ -107,9 +117,12 @@ type serverWorkerData struct { type Server struct { opts serverOptions - mu sync.Mutex // guards following - lis map[net.Listener]bool - conns map[transport.ServerTransport]bool + mu sync.Mutex // guards following + lis map[net.Listener]bool + // conns contains all active server transports. It is a map keyed on a + // listener address with the value being the set of active transports + // belonging to that listener. + conns map[string]map[transport.ServerTransport]bool serve bool drain bool cv *sync.Cond // signaled when connections close for GracefulStop @@ -266,6 +279,35 @@ func CustomCodec(codec Codec) ServerOption { }) } +// ForceServerCodec returns a ServerOption that sets a codec for message +// marshaling and unmarshaling. +// +// This will override any lookups by content-subtype for Codecs registered +// with RegisterCodec. +// +// See Content-Type on +// https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#requests for +// more details. Also see the documentation on RegisterCodec and +// CallContentSubtype for more details on the interaction between encoding.Codec +// and content-subtype. +// +// This function is provided for advanced users; prefer to register codecs +// using encoding.RegisterCodec. 
+// The server will automatically use registered codecs based on the incoming +// requests' headers. See also +// https://github.com/grpc/grpc-go/blob/master/Documentation/encoding.md#using-a-codec. +// Will be supported throughout 1.x. +// +// Experimental +// +// Notice: This API is EXPERIMENTAL and may be changed or removed in a +// later release. +func ForceServerCodec(codec encoding.Codec) ServerOption { + return newFuncServerOption(func(o *serverOptions) { + o.codec = codec + }) +} + // RPCCompressor returns a ServerOption that sets a compressor for outbound // messages. For backward compatibility, all outbound messages will be sent // using this compressor, regardless of incoming message compression. By @@ -376,6 +418,11 @@ func ChainStreamInterceptor(interceptors ...StreamServerInterceptor) ServerOptio // InTapHandle returns a ServerOption that sets the tap handle for all the server // transport to be created. Only one can be installed. +// +// Experimental +// +// Notice: This API is EXPERIMENTAL and may be changed or removed in a +// later release. func InTapHandle(h tap.ServerInHandle) ServerOption { return newFuncServerOption(func(o *serverOptions) { if o.inTapHandle != nil { @@ -519,7 +566,7 @@ func NewServer(opt ...ServerOption) *Server { s := &Server{ lis: make(map[net.Listener]bool), opts: opts, - conns: make(map[transport.ServerTransport]bool), + conns: make(map[string]map[transport.ServerTransport]bool), services: make(map[string]*serviceInfo), quit: grpcsync.NewEvent(), done: grpcsync.NewEvent(), @@ -663,13 +710,6 @@ func (s *Server) GetServiceInfo() map[string]ServiceInfo { // the server being stopped. 
var ErrServerStopped = errors.New("grpc: the server has been stopped") -func (s *Server) useTransportAuthenticator(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { - if s.opts.creds == nil { - return rawConn, nil, nil - } - return s.opts.creds.ServerHandshake(rawConn) -} - type listenSocket struct { net.Listener channelzID int64 @@ -778,7 +818,7 @@ func (s *Server) Serve(lis net.Listener) error { // s.conns before this conn can be added. s.serveWG.Add(1) go func() { - s.handleRawConn(rawConn) + s.handleRawConn(lis.Addr().String(), rawConn) s.serveWG.Done() }() } @@ -786,49 +826,45 @@ func (s *Server) Serve(lis net.Listener) error { // handleRawConn forks a goroutine to handle a just-accepted connection that // has not had any I/O performed on it yet. -func (s *Server) handleRawConn(rawConn net.Conn) { +func (s *Server) handleRawConn(lisAddr string, rawConn net.Conn) { if s.quit.HasFired() { rawConn.Close() return } rawConn.SetDeadline(time.Now().Add(s.opts.connectionTimeout)) - conn, authInfo, err := s.useTransportAuthenticator(rawConn) - if err != nil { - // ErrConnDispatched means that the connection was dispatched away from - // gRPC; those connections should be left open. 
- if err != credentials.ErrConnDispatched { - s.mu.Lock() - s.errorf("ServerHandshake(%q) failed: %v", rawConn.RemoteAddr(), err) - s.mu.Unlock() - channelz.Warningf(logger, s.channelzID, "grpc: Server.Serve failed to complete security handshake from %q: %v", rawConn.RemoteAddr(), err) - rawConn.Close() - } - rawConn.SetDeadline(time.Time{}) - return - } // Finish handshaking (HTTP2) - st := s.newHTTP2Transport(conn, authInfo) + st := s.newHTTP2Transport(rawConn) + rawConn.SetDeadline(time.Time{}) if st == nil { return } - rawConn.SetDeadline(time.Time{}) - if !s.addConn(st) { + if !s.addConn(lisAddr, st) { return } go func() { s.serveStreams(st) - s.removeConn(st) + s.removeConn(lisAddr, st) }() } +func (s *Server) drainServerTransports(addr string) { + s.mu.Lock() + conns := s.conns[addr] + for st := range conns { + st.Drain() + } + s.mu.Unlock() +} + // newHTTP2Transport sets up a http/2 transport (using the // gRPC http2 server transport in transport/http2_server.go). -func (s *Server) newHTTP2Transport(c net.Conn, authInfo credentials.AuthInfo) transport.ServerTransport { +func (s *Server) newHTTP2Transport(c net.Conn) transport.ServerTransport { config := &transport.ServerConfig{ MaxStreams: s.opts.maxConcurrentStreams, - AuthInfo: authInfo, + ConnectionTimeout: s.opts.connectionTimeout, + Credentials: s.opts.creds, InTapHandle: s.opts.inTapHandle, StatsHandler: s.opts.statsHandler, KeepaliveParams: s.opts.keepaliveParams, @@ -841,13 +877,20 @@ func (s *Server) newHTTP2Transport(c net.Conn, authInfo credentials.AuthInfo) tr MaxHeaderListSize: s.opts.maxHeaderListSize, HeaderTableSize: s.opts.headerTableSize, } - st, err := transport.NewServerTransport("http2", c, config) + st, err := transport.NewServerTransport(c, config) if err != nil { s.mu.Lock() s.errorf("NewServerTransport(%q) failed: %v", c.RemoteAddr(), err) s.mu.Unlock() - c.Close() - channelz.Warning(logger, s.channelzID, "grpc: Server.Serve failed to create ServerTransport: ", err) + // 
ErrConnDispatched means that the connection was dispatched away from + // gRPC; those connections should be left open. + if err != credentials.ErrConnDispatched { + // Don't log on ErrConnDispatched and io.EOF to prevent log spam. + if err != io.EOF { + channelz.Warning(logger, s.channelzID, "grpc: Server.Serve failed to create ServerTransport: ", err) + } + c.Close() + } return nil } @@ -924,10 +967,10 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { http.Error(w, err.Error(), http.StatusInternalServerError) return } - if !s.addConn(st) { + if !s.addConn(listenerAddressForServeHTTP, st) { return } - defer s.removeConn(st) + defer s.removeConn(listenerAddressForServeHTTP, st) s.serveStreams(st) } @@ -955,7 +998,7 @@ func (s *Server) traceInfo(st transport.ServerTransport, stream *transport.Strea return trInfo } -func (s *Server) addConn(st transport.ServerTransport) bool { +func (s *Server) addConn(addr string, st transport.ServerTransport) bool { s.mu.Lock() defer s.mu.Unlock() if s.conns == nil { @@ -967,15 +1010,28 @@ func (s *Server) addConn(st transport.ServerTransport) bool { // immediately. st.Drain() } - s.conns[st] = true + + if s.conns[addr] == nil { + // Create a map entry if this is the first connection on this listener. + s.conns[addr] = make(map[transport.ServerTransport]bool) + } + s.conns[addr][st] = true return true } -func (s *Server) removeConn(st transport.ServerTransport) { +func (s *Server) removeConn(addr string, st transport.ServerTransport) { s.mu.Lock() defer s.mu.Unlock() - if s.conns != nil { - delete(s.conns, st) + + conns := s.conns[addr] + if conns != nil { + delete(conns, st) + if len(conns) == 0 { + // If the last connection for this address is being removed, also + // remove the map entry corresponding to the address. This is used + // in GracefulStop() when waiting for all connections to be closed. 
+ delete(s.conns, addr) + } s.cv.Broadcast() } } @@ -1040,22 +1096,29 @@ func chainUnaryServerInterceptors(s *Server) { } else if len(interceptors) == 1 { chainedInt = interceptors[0] } else { - chainedInt = func(ctx context.Context, req interface{}, info *UnaryServerInfo, handler UnaryHandler) (interface{}, error) { - return interceptors[0](ctx, req, info, getChainUnaryHandler(interceptors, 0, info, handler)) - } + chainedInt = chainUnaryInterceptors(interceptors) } s.opts.unaryInt = chainedInt } -// getChainUnaryHandler recursively generate the chained UnaryHandler -func getChainUnaryHandler(interceptors []UnaryServerInterceptor, curr int, info *UnaryServerInfo, finalHandler UnaryHandler) UnaryHandler { - if curr == len(interceptors)-1 { - return finalHandler - } - - return func(ctx context.Context, req interface{}) (interface{}, error) { - return interceptors[curr+1](ctx, req, info, getChainUnaryHandler(interceptors, curr+1, info, finalHandler)) +func chainUnaryInterceptors(interceptors []UnaryServerInterceptor) UnaryServerInterceptor { + return func(ctx context.Context, req interface{}, info *UnaryServerInfo, handler UnaryHandler) (interface{}, error) { + // the struct ensures the variables are allocated together, rather than separately, since we + // know they should be garbage collected together. This saves 1 allocation and decreases + // time/call by about 10% on the microbenchmark. + var state struct { + i int + next UnaryHandler + } + state.next = func(ctx context.Context, req interface{}) (interface{}, error) { + if state.i == len(interceptors)-1 { + return interceptors[state.i](ctx, req, info, handler) + } + state.i++ + return interceptors[state.i-1](ctx, req, info, state.next) + } + return state.next(ctx, req) } } @@ -1069,7 +1132,9 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. 
if sh != nil { beginTime := time.Now() statsBegin = &stats.Begin{ - BeginTime: beginTime, + BeginTime: beginTime, + IsClientStream: false, + IsServerStream: false, } sh.HandleRPC(stream.Context(), statsBegin) } @@ -1321,22 +1386,29 @@ func chainStreamServerInterceptors(s *Server) { } else if len(interceptors) == 1 { chainedInt = interceptors[0] } else { - chainedInt = func(srv interface{}, ss ServerStream, info *StreamServerInfo, handler StreamHandler) error { - return interceptors[0](srv, ss, info, getChainStreamHandler(interceptors, 0, info, handler)) - } + chainedInt = chainStreamInterceptors(interceptors) } s.opts.streamInt = chainedInt } -// getChainStreamHandler recursively generate the chained StreamHandler -func getChainStreamHandler(interceptors []StreamServerInterceptor, curr int, info *StreamServerInfo, finalHandler StreamHandler) StreamHandler { - if curr == len(interceptors)-1 { - return finalHandler - } - - return func(srv interface{}, ss ServerStream) error { - return interceptors[curr+1](srv, ss, info, getChainStreamHandler(interceptors, curr+1, info, finalHandler)) +func chainStreamInterceptors(interceptors []StreamServerInterceptor) StreamServerInterceptor { + return func(srv interface{}, ss ServerStream, info *StreamServerInfo, handler StreamHandler) error { + // the struct ensures the variables are allocated together, rather than separately, since we + // know they should be garbage collected together. This saves 1 allocation and decreases + // time/call by about 10% on the microbenchmark. 
+ var state struct { + i int + next StreamHandler + } + state.next = func(srv interface{}, ss ServerStream) error { + if state.i == len(interceptors)-1 { + return interceptors[state.i](srv, ss, info, handler) + } + state.i++ + return interceptors[state.i-1](srv, ss, info, state.next) + } + return state.next(srv, ss) } } @@ -1349,7 +1421,9 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp if sh != nil { beginTime := time.Now() statsBegin = &stats.Begin{ - BeginTime: beginTime, + BeginTime: beginTime, + IsClientStream: sd.ClientStreams, + IsServerStream: sd.ServerStreams, } sh.HandleRPC(stream.Context(), statsBegin) } @@ -1452,6 +1526,8 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp } } + ss.ctx = newContextWithRPCInfo(ss.ctx, false, ss.codec, ss.cp, ss.comp) + if trInfo != nil { trInfo.tr.LazyLog(&trInfo.firstLine, false) } @@ -1519,7 +1595,7 @@ func (s *Server) handleStream(t transport.ServerTransport, stream *transport.Str trInfo.tr.SetError() } errDesc := fmt.Sprintf("malformed method name: %q", stream.Method()) - if err := t.WriteStatus(stream, status.New(codes.ResourceExhausted, errDesc)); err != nil { + if err := t.WriteStatus(stream, status.New(codes.Unimplemented, errDesc)); err != nil { if trInfo != nil { trInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{err}}, true) trInfo.tr.SetError() @@ -1639,7 +1715,7 @@ func (s *Server) Stop() { s.mu.Lock() listeners := s.lis s.lis = nil - st := s.conns + conns := s.conns s.conns = nil // interrupt GracefulStop if Stop and GracefulStop are called concurrently. 
s.cv.Broadcast() @@ -1648,8 +1724,10 @@ func (s *Server) Stop() { for lis := range listeners { lis.Close() } - for c := range st { - c.Close() + for _, cs := range conns { + for st := range cs { + st.Close() + } } if s.opts.numServerWorkers > 0 { s.stopServerWorkers() @@ -1686,8 +1764,10 @@ func (s *Server) GracefulStop() { } s.lis = nil if !s.drain { - for st := range s.conns { - st.Drain() + for _, conns := range s.conns { + for st := range conns { + st.Drain() + } } s.drain = true } diff --git a/server_test.go b/server_test.go index fcfde30706c..b1593916014 100644 --- a/server_test.go +++ b/server_test.go @@ -22,6 +22,7 @@ import ( "context" "net" "reflect" + "strconv" "strings" "testing" "time" @@ -130,3 +131,59 @@ func (s) TestStreamContext(t *testing.T) { t.Fatalf("GetStreamFromContext(%v) = %v, %t, want: %v, true", ctx, stream, ok, expectedStream) } } + +func BenchmarkChainUnaryInterceptor(b *testing.B) { + for _, n := range []int{1, 3, 5, 10} { + n := n + b.Run(strconv.Itoa(n), func(b *testing.B) { + interceptors := make([]UnaryServerInterceptor, 0, n) + for i := 0; i < n; i++ { + interceptors = append(interceptors, func( + ctx context.Context, req interface{}, info *UnaryServerInfo, handler UnaryHandler, + ) (interface{}, error) { + return handler(ctx, req) + }) + } + + s := NewServer(ChainUnaryInterceptor(interceptors...)) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + if _, err := s.opts.unaryInt(context.Background(), nil, nil, + func(ctx context.Context, req interface{}) (interface{}, error) { + return nil, nil + }, + ); err != nil { + b.Fatal(err) + } + } + }) + } +} + +func BenchmarkChainStreamInterceptor(b *testing.B) { + for _, n := range []int{1, 3, 5, 10} { + n := n + b.Run(strconv.Itoa(n), func(b *testing.B) { + interceptors := make([]StreamServerInterceptor, 0, n) + for i := 0; i < n; i++ { + interceptors = append(interceptors, func( + srv interface{}, ss ServerStream, info *StreamServerInfo, handler StreamHandler, + ) error 
{ + return handler(srv, ss) + }) + } + + s := NewServer(ChainStreamInterceptor(interceptors...)) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + if err := s.opts.streamInt(nil, nil, nil, func(srv interface{}, stream ServerStream) error { + return nil + }); err != nil { + b.Fatal(err) + } + } + }) + } +} diff --git a/stats/stats.go b/stats/stats.go index 63e476ee7ff..0285dcc6a26 100644 --- a/stats/stats.go +++ b/stats/stats.go @@ -36,15 +36,22 @@ type RPCStats interface { IsClient() bool } -// Begin contains stats when an RPC begins. +// Begin contains stats when an RPC attempt begins. // FailFast is only valid if this Begin is from client side. type Begin struct { // Client is true if this Begin is from client side. Client bool - // BeginTime is the time when the RPC begins. + // BeginTime is the time when the RPC attempt begins. BeginTime time.Time // FailFast indicates if this RPC is failfast. FailFast bool + // IsClientStream indicates whether the RPC is a client streaming RPC. + IsClientStream bool + // IsServerStream indicates whether the RPC is a server streaming RPC. + IsServerStream bool + // IsTransparentRetryAttempt indicates whether this attempt was initiated + // due to transparently retrying a previous attempt. + IsTransparentRetryAttempt bool } // IsClient indicates if the stats information is from client side. 
diff --git a/stats/stats_test.go b/stats/stats_test.go index 306f2f6b8e9..dfc6edfc3d3 100644 --- a/stats/stats_test.go +++ b/stats/stats_test.go @@ -407,15 +407,17 @@ func (te *test) doServerStreamCall(c *rpcConfig) (*testpb.StreamingOutputCallReq } type expectedData struct { - method string - serverAddr string - compression string - reqIdx int - requests []proto.Message - respIdx int - responses []proto.Message - err error - failfast bool + method string + isClientStream bool + isServerStream bool + serverAddr string + compression string + reqIdx int + requests []proto.Message + respIdx int + responses []proto.Message + err error + failfast bool } type gotData struct { @@ -456,6 +458,12 @@ func checkBegin(t *testing.T, d *gotData, e *expectedData) { t.Fatalf("st.FailFast = %v, want %v", st.FailFast, e.failfast) } } + if st.IsClientStream != e.isClientStream { + t.Fatalf("st.IsClientStream = %v, want %v", st.IsClientStream, e.isClientStream) + } + if st.IsServerStream != e.isServerStream { + t.Fatalf("st.IsServerStream = %v, want %v", st.IsServerStream, e.isServerStream) + } } func checkInHeader(t *testing.T, d *gotData, e *expectedData) { @@ -847,6 +855,9 @@ func testServerStats(t *testing.T, tc *testConfig, cc *rpcConfig, checkFuncs []f err error method string + isClientStream bool + isServerStream bool + req proto.Message resp proto.Message e error @@ -864,14 +875,18 @@ func testServerStats(t *testing.T, tc *testConfig, cc *rpcConfig, checkFuncs []f reqs, resp, e = te.doClientStreamCall(cc) resps = []proto.Message{resp} err = e + isClientStream = true case serverStreamRPC: method = "/grpc.testing.TestService/StreamingOutputCall" req, resps, e = te.doServerStreamCall(cc) reqs = []proto.Message{req} err = e + isServerStream = true case fullDuplexStreamRPC: method = "/grpc.testing.TestService/FullDuplexCall" reqs, resps, err = te.doFullDuplexCallRoundtrip(cc) + isClientStream = true + isServerStream = true } if cc.success != (err == nil) { t.Fatalf("cc.success: %v, 
got error: %v", cc.success, err) @@ -900,12 +915,14 @@ func testServerStats(t *testing.T, tc *testConfig, cc *rpcConfig, checkFuncs []f } expect := &expectedData{ - serverAddr: te.srvAddr, - compression: tc.compress, - method: method, - requests: reqs, - responses: resps, - err: err, + serverAddr: te.srvAddr, + compression: tc.compress, + method: method, + requests: reqs, + responses: resps, + err: err, + isClientStream: isClientStream, + isServerStream: isServerStream, } h.mu.Lock() @@ -1138,6 +1155,9 @@ func testClientStats(t *testing.T, tc *testConfig, cc *rpcConfig, checkFuncs map method string err error + isClientStream bool + isServerStream bool + req proto.Message resp proto.Message e error @@ -1154,14 +1174,18 @@ func testClientStats(t *testing.T, tc *testConfig, cc *rpcConfig, checkFuncs map reqs, resp, e = te.doClientStreamCall(cc) resps = []proto.Message{resp} err = e + isClientStream = true case serverStreamRPC: method = "/grpc.testing.TestService/StreamingOutputCall" req, resps, e = te.doServerStreamCall(cc) reqs = []proto.Message{req} err = e + isServerStream = true case fullDuplexStreamRPC: method = "/grpc.testing.TestService/FullDuplexCall" reqs, resps, err = te.doFullDuplexCallRoundtrip(cc) + isClientStream = true + isServerStream = true } if cc.success != (err == nil) { t.Fatalf("cc.success: %v, got error: %v", cc.success, err) @@ -1194,13 +1218,15 @@ func testClientStats(t *testing.T, tc *testConfig, cc *rpcConfig, checkFuncs map } expect := &expectedData{ - serverAddr: te.srvAddr, - compression: tc.compress, - method: method, - requests: reqs, - responses: resps, - failfast: cc.failfast, - err: err, + serverAddr: te.srvAddr, + compression: tc.compress, + method: method, + requests: reqs, + responses: resps, + failfast: cc.failfast, + err: err, + isClientStream: isClientStream, + isServerStream: isServerStream, } h.mu.Lock() diff --git a/stream.go b/stream.go index 8ba0e8a5eea..fc3299ae1f0 100644 --- a/stream.go +++ b/stream.go @@ -52,14 +52,20 
@@ import ( // of the RPC. type StreamHandler func(srv interface{}, stream ServerStream) error -// StreamDesc represents a streaming RPC service's method specification. +// StreamDesc represents a streaming RPC service's method specification. Used +// on the server when registering services and on the client when initiating +// new streams. type StreamDesc struct { - StreamName string - Handler StreamHandler - - // At least one of these is true. - ServerStreams bool - ClientStreams bool + // StreamName and Handler are only used when registering handlers on a + // server. + StreamName string // the name of the method excluding the service + Handler StreamHandler // the handler called for the method + + // ServerStreams and ClientStreams are used for registering handlers on a + // server as well as defining RPC behavior when passed to NewClientStream + // and ClientConn.NewStream. At least one must be true. + ServerStreams bool // indicates the server can perform streaming sends + ClientStreams bool // indicates the client can perform streaming sends } // Stream defines the common interface a client or server stream has to satisfy. 
@@ -181,7 +187,7 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth rpcInfo := iresolver.RPCInfo{Context: ctx, Method: method} rpcConfig, err := cc.safeConfigSelector.SelectConfig(rpcInfo) if err != nil { - return nil, status.Convert(err).Err() + return nil, toRPCErr(err) } if rpcConfig != nil { @@ -196,7 +202,7 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth newStream = func(ctx context.Context, done func()) (iresolver.ClientStream, error) { cs, err := rpcConfig.Interceptor.NewStream(ctx, rpcInfo, done, ns) if err != nil { - return nil, status.Convert(err).Err() + return nil, toRPCErr(err) } return cs, nil } @@ -268,33 +274,6 @@ func newClientStreamWithParams(ctx context.Context, desc *StreamDesc, cc *Client if c.creds != nil { callHdr.Creds = c.creds } - var trInfo *traceInfo - if EnableTracing { - trInfo = &traceInfo{ - tr: trace.New("grpc.Sent."+methodFamily(method), method), - firstLine: firstLine{ - client: true, - }, - } - if deadline, ok := ctx.Deadline(); ok { - trInfo.firstLine.deadline = time.Until(deadline) - } - trInfo.tr.LazyLog(&trInfo.firstLine, false) - ctx = trace.NewContext(ctx, trInfo.tr) - } - ctx = newContextWithRPCInfo(ctx, c.failFast, c.codec, cp, comp) - sh := cc.dopts.copts.StatsHandler - var beginTime time.Time - if sh != nil { - ctx = sh.TagRPC(ctx, &stats.RPCTagInfo{FullMethodName: method, FailFast: c.failFast}) - beginTime = time.Now() - begin := &stats.Begin{ - Client: true, - BeginTime: beginTime, - FailFast: c.failFast, - } - sh.HandleRPC(ctx, begin) - } cs := &clientStream{ callHdr: callHdr, @@ -308,7 +287,6 @@ func newClientStreamWithParams(ctx context.Context, desc *StreamDesc, cc *Client cp: cp, comp: comp, cancel: cancel, - beginTime: beginTime, firstAttempt: true, onCommit: onCommit, } @@ -317,9 +295,7 @@ func newClientStreamWithParams(ctx context.Context, desc *StreamDesc, cc *Client } cs.binlog = binarylog.GetMethodLogger(method) - // Only this initial 
attempt has stats/tracing. - // TODO(dfawley): move to newAttempt when per-attempt stats are implemented. - if err := cs.newAttemptLocked(sh, trInfo); err != nil { + if err := cs.newAttemptLocked(false /* isTransparent */); err != nil { cs.finish(err) return nil, err } @@ -367,8 +343,43 @@ func newClientStreamWithParams(ctx context.Context, desc *StreamDesc, cc *Client // newAttemptLocked creates a new attempt with a transport. // If it succeeds, then it replaces clientStream's attempt with this new attempt. -func (cs *clientStream) newAttemptLocked(sh stats.Handler, trInfo *traceInfo) (retErr error) { +func (cs *clientStream) newAttemptLocked(isTransparent bool) (retErr error) { + ctx := newContextWithRPCInfo(cs.ctx, cs.callInfo.failFast, cs.callInfo.codec, cs.cp, cs.comp) + method := cs.callHdr.Method + sh := cs.cc.dopts.copts.StatsHandler + var beginTime time.Time + if sh != nil { + ctx = sh.TagRPC(ctx, &stats.RPCTagInfo{FullMethodName: method, FailFast: cs.callInfo.failFast}) + beginTime = time.Now() + begin := &stats.Begin{ + Client: true, + BeginTime: beginTime, + FailFast: cs.callInfo.failFast, + IsClientStream: cs.desc.ClientStreams, + IsServerStream: cs.desc.ServerStreams, + IsTransparentRetryAttempt: isTransparent, + } + sh.HandleRPC(ctx, begin) + } + + var trInfo *traceInfo + if EnableTracing { + trInfo = &traceInfo{ + tr: trace.New("grpc.Sent."+methodFamily(method), method), + firstLine: firstLine{ + client: true, + }, + } + if deadline, ok := ctx.Deadline(); ok { + trInfo.firstLine.deadline = time.Until(deadline) + } + trInfo.tr.LazyLog(&trInfo.firstLine, false) + ctx = trace.NewContext(ctx, trInfo.tr) + } + newAttempt := &csAttempt{ + ctx: ctx, + beginTime: beginTime, cs: cs, dc: cs.cc.dopts.dc, statsHandler: sh, @@ -383,15 +394,14 @@ func (cs *clientStream) newAttemptLocked(sh stats.Handler, trInfo *traceInfo) (r } }() - if err := cs.ctx.Err(); err != nil { + if err := ctx.Err(); err != nil { return toRPCErr(err) } - ctx := cs.ctx if 
cs.cc.parsedTarget.Scheme == "xds" { // Add extra metadata (metadata that will be added by transport) to context // so the balancer can see them. - ctx = grpcutil.WithExtraMetadata(cs.ctx, metadata.Pairs( + ctx = grpcutil.WithExtraMetadata(ctx, metadata.Pairs( "content-type", grpcutil.ContentType(cs.callHdr.ContentSubtype), )) } @@ -411,14 +421,11 @@ func (cs *clientStream) newAttemptLocked(sh stats.Handler, trInfo *traceInfo) (r func (a *csAttempt) newStream() error { cs := a.cs cs.callHdr.PreviousAttempts = cs.numRetries - s, err := a.t.NewStream(cs.ctx, cs.callHdr) + s, err := a.t.NewStream(a.ctx, cs.callHdr) if err != nil { - if _, ok := err.(transport.PerformedIOError); ok { - // Return without converting to an RPC error so retry code can - // inspect. - return err - } - return toRPCErr(err) + // Return without converting to an RPC error so retry code can + // inspect. + return err } cs.attempt.s = s cs.attempt.p = &parser{r: s} @@ -439,8 +446,7 @@ type clientStream struct { cancel context.CancelFunc // cancels all attempts - sentLast bool // sent an end stream - beginTime time.Time + sentLast bool // sent an end stream methodConfig *MethodConfig @@ -480,6 +486,7 @@ type clientStream struct { // csAttempt implements a single transport stream attempt within a // clientStream. type csAttempt struct { + ctx context.Context cs *clientStream t transport.ClientTransport s *transport.Stream @@ -498,6 +505,7 @@ type csAttempt struct { trInfo *traceInfo statsHandler stats.Handler + beginTime time.Time } func (cs *clientStream) commitAttemptLocked() { @@ -515,46 +523,57 @@ func (cs *clientStream) commitAttempt() { } // shouldRetry returns nil if the RPC should be retried; otherwise it returns -// the error that should be returned by the operation. -func (cs *clientStream) shouldRetry(err error) error { - unprocessed := false +// the error that should be returned by the operation. 
If the RPC should be +// retried, the bool indicates whether it is being retried transparently. +func (cs *clientStream) shouldRetry(err error) (bool, error) { if cs.attempt.s == nil { - pioErr, ok := err.(transport.PerformedIOError) - if ok { - // Unwrap error. - err = toRPCErr(pioErr.Err) - } else { - unprocessed = true + // Error from NewClientStream. + nse, ok := err.(*transport.NewStreamError) + if !ok { + // Unexpected, but assume no I/O was performed and the RPC is not + // fatal, so retry indefinitely. + return true, nil } - if !ok && !cs.callInfo.failFast { - // In the event of a non-IO operation error from NewStream, we - // never attempted to write anything to the wire, so we can retry - // indefinitely for non-fail-fast RPCs. - return nil + + // Unwrap and convert error. + err = toRPCErr(nse.Err) + + // Never retry DoNotRetry errors, which indicate the RPC should not be + // retried due to max header list size violation, etc. + if nse.DoNotRetry { + return false, err + } + + // In the event of a non-IO operation error from NewStream, we never + // attempted to write anything to the wire, so we can retry + // indefinitely. + if !nse.DoNotTransparentRetry { + return true, nil } } if cs.finished || cs.committed { // RPC is finished or committed; cannot retry. - return err + return false, err } // Wait for the trailers. + unprocessed := false if cs.attempt.s != nil { <-cs.attempt.s.Done() unprocessed = cs.attempt.s.Unprocessed() } if cs.firstAttempt && unprocessed { // First attempt, stream unprocessed: transparently retry. 
- return nil + return true, nil } if cs.cc.dopts.disableRetry { - return err + return false, err } pushback := 0 hasPushback := false if cs.attempt.s != nil { if !cs.attempt.s.TrailersOnly() { - return err + return false, err } // TODO(retry): Move down if the spec changes to not check server pushback @@ -565,13 +584,13 @@ func (cs *clientStream) shouldRetry(err error) error { if pushback, e = strconv.Atoi(sps[0]); e != nil || pushback < 0 { channelz.Infof(logger, cs.cc.channelzID, "Server retry pushback specified to abort (%q).", sps[0]) cs.retryThrottler.throttle() // This counts as a failure for throttling. - return err + return false, err } hasPushback = true } else if len(sps) > 1 { channelz.Warningf(logger, cs.cc.channelzID, "Server retry pushback specified multiple values (%q); not retrying.", sps) cs.retryThrottler.throttle() // This counts as a failure for throttling. - return err + return false, err } } @@ -584,16 +603,16 @@ func (cs *clientStream) shouldRetry(err error) error { rp := cs.methodConfig.RetryPolicy if rp == nil || !rp.RetryableStatusCodes[code] { - return err + return false, err } // Note: the ordering here is important; we count this as a failure // only if the code matched a retryable code. if cs.retryThrottler.throttle() { - return err + return false, err } if cs.numRetries+1 >= rp.MaxAttempts { - return err + return false, err } var dur time.Duration @@ -616,23 +635,24 @@ func (cs *clientStream) shouldRetry(err error) error { select { case <-t.C: cs.numRetries++ - return nil + return false, nil case <-cs.ctx.Done(): t.Stop() - return status.FromContextError(cs.ctx.Err()).Err() + return false, status.FromContextError(cs.ctx.Err()).Err() } } // Returns nil if a retry was performed and succeeded; error otherwise. 
func (cs *clientStream) retryLocked(lastErr error) error { for { - cs.attempt.finish(lastErr) - if err := cs.shouldRetry(lastErr); err != nil { + cs.attempt.finish(toRPCErr(lastErr)) + isTransparent, err := cs.shouldRetry(lastErr) + if err != nil { cs.commitAttemptLocked() return err } cs.firstAttempt = false - if err := cs.newAttemptLocked(nil, nil); err != nil { + if err := cs.newAttemptLocked(isTransparent); err != nil { return err } if lastErr = cs.replayBufferLocked(); lastErr == nil { @@ -653,7 +673,11 @@ func (cs *clientStream) withRetry(op func(a *csAttempt) error, onSuccess func()) for { if cs.committed { cs.mu.Unlock() - return op(cs.attempt) + // toRPCErr is used in case the error from the attempt comes from + // NewClientStream, which intentionally doesn't return a status + // error to allow for further inspection; all other errors should + // already be status errors. + return toRPCErr(op(cs.attempt)) } a := cs.attempt cs.mu.Unlock() @@ -918,7 +942,7 @@ func (a *csAttempt) sendMsg(m interface{}, hdr, payld, data []byte) error { return io.EOF } if a.statsHandler != nil { - a.statsHandler.HandleRPC(cs.ctx, outPayload(true, m, data, payld, time.Now())) + a.statsHandler.HandleRPC(a.ctx, outPayload(true, m, data, payld, time.Now())) } if channelz.IsOn() { a.t.IncrMsgSent() @@ -966,7 +990,7 @@ func (a *csAttempt) recvMsg(m interface{}, payInfo *payloadInfo) (err error) { a.mu.Unlock() } if a.statsHandler != nil { - a.statsHandler.HandleRPC(cs.ctx, &stats.InPayload{ + a.statsHandler.HandleRPC(a.ctx, &stats.InPayload{ Client: true, RecvTime: time.Now(), Payload: m, @@ -1028,12 +1052,12 @@ func (a *csAttempt) finish(err error) { if a.statsHandler != nil { end := &stats.End{ Client: true, - BeginTime: a.cs.beginTime, + BeginTime: a.beginTime, EndTime: time.Now(), Trailer: tr, Error: err, } - a.statsHandler.HandleRPC(a.cs.ctx, end) + a.statsHandler.HandleRPC(a.ctx, end) } if a.trInfo != nil && a.trInfo.tr != nil { if err == nil { diff --git 
a/stress/grpc_testing/metrics_grpc.pb.go b/stress/grpc_testing/metrics_grpc.pb.go index 2ece0325563..0730fad49a4 100644 --- a/stress/grpc_testing/metrics_grpc.pb.go +++ b/stress/grpc_testing/metrics_grpc.pb.go @@ -1,4 +1,8 @@ // Code generated by protoc-gen-go-grpc. DO NOT EDIT. +// versions: +// - protoc-gen-go-grpc v1.1.0 +// - protoc v3.14.0 +// source: stress/grpc_testing/metrics.proto package grpc_testing diff --git a/tap/tap.go b/tap/tap.go index caea1ebed6e..dbf34e6bb5f 100644 --- a/tap/tap.go +++ b/tap/tap.go @@ -37,16 +37,16 @@ type Info struct { // TODO: More to be added. } -// ServerInHandle defines the function which runs before a new stream is created -// on the server side. If it returns a non-nil error, the stream will not be -// created and a RST_STREAM will be sent back to the client with REFUSED_STREAM. -// The client will receive an RPC error "code = Unavailable, desc = stream -// terminated by RST_STREAM with error code: REFUSED_STREAM". +// ServerInHandle defines the function which runs before a new stream is +// created on the server side. If it returns a non-nil error, the stream will +// not be created and an error will be returned to the client. If the error +// returned is a status error, that status code and message will be used, +// otherwise PermissionDenied will be the code and err.Error() will be the +// message. // // It's intended to be used in situations where you don't want to waste the -// resources to accept the new stream (e.g. rate-limiting). And the content of -// the error will be ignored and won't be sent back to the client. For other -// general usages, please use interceptors. +// resources to accept the new stream (e.g. rate-limiting). For other general +// usages, please use interceptors. // // Note that it is executed in the per-connection I/O goroutine(s) instead of // per-RPC goroutine. 
Therefore, users should NOT have any diff --git a/test/authority_test.go b/test/authority_test.go index 17ae178b73c..15afa759c90 100644 --- a/test/authority_test.go +++ b/test/authority_test.go @@ -1,3 +1,4 @@ +//go:build linux // +build linux /* diff --git a/test/balancer_test.go b/test/balancer_test.go index bc22036dbac..e2fa4cf31d0 100644 --- a/test/balancer_test.go +++ b/test/balancer_test.go @@ -28,6 +28,7 @@ import ( "time" "github.com/google/go-cmp/cmp" + "google.golang.org/grpc" "google.golang.org/grpc/attributes" "google.golang.org/grpc/balancer" @@ -37,7 +38,6 @@ import ( "google.golang.org/grpc/credentials" "google.golang.org/grpc/internal/balancer/stub" "google.golang.org/grpc/internal/balancerload" - "google.golang.org/grpc/internal/grpcsync" "google.golang.org/grpc/internal/grpcutil" imetadata "google.golang.org/grpc/internal/metadata" "google.golang.org/grpc/internal/stubserver" @@ -88,7 +88,7 @@ func (b *testBalancer) UpdateClientConnState(state balancer.ClientConnState) err logger.Errorf("testBalancer: failed to NewSubConn: %v", err) return nil } - b.cc.UpdateState(balancer.State{ConnectivityState: connectivity.Connecting, Picker: &picker{sc: b.sc, bal: b}}) + b.cc.UpdateState(balancer.State{ConnectivityState: connectivity.Connecting, Picker: &picker{err: balancer.ErrNoSubConnAvailable, bal: b}}) b.sc.Connect() } return nil @@ -106,8 +106,10 @@ func (b *testBalancer) UpdateSubConnState(sc balancer.SubConn, s balancer.SubCon } switch s.ConnectivityState { - case connectivity.Ready, connectivity.Idle: + case connectivity.Ready: b.cc.UpdateState(balancer.State{ConnectivityState: s.ConnectivityState, Picker: &picker{sc: sc, bal: b}}) + case connectivity.Idle: + b.cc.UpdateState(balancer.State{ConnectivityState: s.ConnectivityState, Picker: &picker{sc: sc, bal: b, idle: true}}) case connectivity.Connecting: b.cc.UpdateState(balancer.State{ConnectivityState: s.ConnectivityState, Picker: &picker{err: balancer.ErrNoSubConnAvailable, bal: b}}) case 
connectivity.TransientFailure: @@ -117,16 +119,23 @@ func (b *testBalancer) UpdateSubConnState(sc balancer.SubConn, s balancer.SubCon func (b *testBalancer) Close() {} +func (b *testBalancer) ExitIdle() {} + type picker struct { - err error - sc balancer.SubConn - bal *testBalancer + err error + sc balancer.SubConn + bal *testBalancer + idle bool } func (p *picker) Pick(info balancer.PickInfo) (balancer.PickResult, error) { if p.err != nil { return balancer.PickResult{}, p.err } + if p.idle { + p.sc.Connect() + return balancer.PickResult{}, balancer.ErrNoSubConnAvailable + } extraMD, _ := grpcutil.ExtraMetadata(info.Ctx) info.Ctx = nil // Do not validate context. p.bal.pickInfos = append(p.bal.pickInfos, info) @@ -195,14 +204,14 @@ func testPickExtraMetadata(t *testing.T, e env) { cc := te.clientConn() tc := testpb.NewTestServiceClient(cc) - // The RPCs will fail, but we don't care. We just need the pick to happen. - ctx1, cancel1 := context.WithTimeout(context.Background(), time.Second) - defer cancel1() - tc.EmptyCall(ctx1, &testpb.Empty{}) - - ctx2, cancel2 := context.WithTimeout(context.Background(), time.Second) - defer cancel2() - tc.EmptyCall(ctx2, &testpb.Empty{}, grpc.CallContentSubtype(testSubContentType)) + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + if _, err := tc.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(true)); err != nil { + t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, %v", err, nil) + } + if _, err := tc.EmptyCall(ctx, &testpb.Empty{}, grpc.CallContentSubtype(testSubContentType)); err != nil { + t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, %v", err, nil) + } want := []metadata.MD{ // First RPC doesn't have sub-content-type. @@ -210,9 +219,8 @@ func testPickExtraMetadata(t *testing.T, e env) { // Second RPC has sub-content-type "proto". 
{"content-type": []string{"application/grpc+proto"}}, } - - if !cmp.Equal(b.pickExtraMDs, want) { - t.Fatalf("%s", cmp.Diff(b.pickExtraMDs, want)) + if diff := cmp.Diff(want, b.pickExtraMDs); diff != "" { + t.Fatalf("unexpected diff in metadata (-want, +got): %s", diff) } } @@ -374,8 +382,9 @@ func (testBalancerKeepAddresses) UpdateSubConnState(sc balancer.SubConn, s balan panic("not used") } -func (testBalancerKeepAddresses) Close() { -} +func (testBalancerKeepAddresses) Close() {} + +func (testBalancerKeepAddresses) ExitIdle() {} // Make sure that non-grpclb balancers don't get grpclb addresses even if name // resolver sends them @@ -698,10 +707,7 @@ func (s) TestEmptyAddrs(t *testing.T) { // Initialize pickfirst client pfr := manual.NewBuilderWithScheme("whatever") - pfrnCalled := grpcsync.NewEvent() - pfr.ResolveNowCallback = func(resolver.ResolveNowOptions) { - pfrnCalled.Fire() - } + pfr.InitialState(resolver.State{Addresses: []resolver.Address{{Addr: lis.Addr().String()}}}) pfcc, err := grpc.DialContext(ctx, pfr.Scheme()+":///", grpc.WithInsecure(), grpc.WithResolvers(pfr)) @@ -718,16 +724,10 @@ func (s) TestEmptyAddrs(t *testing.T) { // Remove all addresses. pfr.UpdateState(resolver.State{}) - // Wait for a ResolveNow call on the pick first client's resolver. - <-pfrnCalled.Done() // Initialize roundrobin client rrr := manual.NewBuilderWithScheme("whatever") - rrrnCalled := grpcsync.NewEvent() - rrr.ResolveNowCallback = func(resolver.ResolveNowOptions) { - rrrnCalled.Fire() - } rrr.InitialState(resolver.State{Addresses: []resolver.Address{{Addr: lis.Addr().String()}}}) rrcc, err := grpc.DialContext(ctx, rrr.Scheme()+":///", grpc.WithInsecure(), grpc.WithResolvers(rrr), @@ -745,8 +745,6 @@ func (s) TestEmptyAddrs(t *testing.T) { // Remove all addresses. rrr.UpdateState(resolver.State{}) - // Wait for a ResolveNow call on the round robin client's resolver. - <-rrrnCalled.Done() // Confirm several new RPCs succeed on pick first. 
for i := 0; i < 10; i++ { diff --git a/test/bufconn/bufconn.go b/test/bufconn/bufconn.go index 168cdb8578d..3f77f4876eb 100644 --- a/test/bufconn/bufconn.go +++ b/test/bufconn/bufconn.go @@ -21,6 +21,7 @@ package bufconn import ( + "context" "fmt" "io" "net" @@ -86,8 +87,17 @@ func (l *Listener) Addr() net.Addr { return addr{} } // providing it the server half of the connection, and returns the client half // of the connection. func (l *Listener) Dial() (net.Conn, error) { + return l.DialContext(context.Background()) +} + +// DialContext creates an in-memory full-duplex network connection, unblocks Accept by +// providing it the server half of the connection, and returns the client half +// of the connection. If ctx is Done, returns ctx.Err() +func (l *Listener) DialContext(ctx context.Context) (net.Conn, error) { p1, p2 := newPipe(l.sz), newPipe(l.sz) select { + case <-ctx.Done(): + return nil, ctx.Err() case <-l.done: return nil, errClosed case l.ch <- &conn{p1, p2}: diff --git a/test/channelz_linux_go110_test.go b/test/channelz_linux_test.go similarity index 99% rename from test/channelz_linux_go110_test.go rename to test/channelz_linux_test.go index dea374bfc08..aa6febe537a 100644 --- a/test/channelz_linux_go110_test.go +++ b/test/channelz_linux_test.go @@ -1,5 +1,3 @@ -// +build linux - /* * * Copyright 2018 gRPC authors. diff --git a/test/channelz_test.go b/test/channelz_test.go index 47e7eb92716..6cb09dd8d89 100644 --- a/test/channelz_test.go +++ b/test/channelz_test.go @@ -1689,8 +1689,22 @@ func (s) TestCZSubChannelPickedNewAddress(t *testing.T) { } te.srvs[0].Stop() te.srvs[1].Stop() - // Here, we just wait for all sockets to be up. In the future, if we implement - // IDLE, we may need to make several rpc calls to create the sockets. + // Here, we just wait for all sockets to be up. Make several rpc calls to + // create the sockets since we do not automatically reconnect. 
+ done := make(chan struct{}) + defer close(done) + go func() { + for { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + tc.EmptyCall(ctx, &testpb.Empty{}) + cancel() + select { + case <-time.After(10 * time.Millisecond): + case <-done: + return + } + } + }() if err := verifyResultWithDelay(func() (bool, error) { tcs, _ := channelz.GetTopChannels(0, 0) if len(tcs) != 1 { diff --git a/test/end2end_test.go b/test/end2end_test.go index 902e9424104..957d13f731f 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -55,12 +55,14 @@ import ( "google.golang.org/grpc/health" healthgrpc "google.golang.org/grpc/health/grpc_health_v1" healthpb "google.golang.org/grpc/health/grpc_health_v1" + "google.golang.org/grpc/internal" "google.golang.org/grpc/internal/channelz" "google.golang.org/grpc/internal/grpcsync" "google.golang.org/grpc/internal/grpctest" "google.golang.org/grpc/internal/stubserver" "google.golang.org/grpc/internal/testutils" "google.golang.org/grpc/internal/transport" + "google.golang.org/grpc/keepalive" "google.golang.org/grpc/metadata" "google.golang.org/grpc/peer" "google.golang.org/grpc/resolver" @@ -69,6 +71,7 @@ import ( "google.golang.org/grpc/stats" "google.golang.org/grpc/status" "google.golang.org/grpc/tap" + "google.golang.org/grpc/test/bufconn" testpb "google.golang.org/grpc/test/grpc_testing" "google.golang.org/grpc/testdata" ) @@ -506,10 +509,6 @@ type test struct { customDialOptions []grpc.DialOption resolverScheme string - // All test dialing is blocking by default. Set this to true if dial - // should be non-blocking. - nonBlockingDial bool - // These are are set once startServer is called. The common case is to have // only one testServer. 
srv stopper @@ -826,10 +825,6 @@ func (te *test) configDial(opts ...grpc.DialOption) ([]grpc.DialOption, string) if te.customCodec != nil { opts = append(opts, grpc.WithDefaultCallOptions(grpc.ForceCodec(te.customCodec))) } - if !te.nonBlockingDial && te.srvAddr != "" { - // Only do a blocking dial if server is up. - opts = append(opts, grpc.WithBlock()) - } if te.srvAddr == "" { te.srvAddr = "client.side.only.test" } @@ -1333,6 +1328,131 @@ func testConcurrentServerStopAndGoAway(t *testing.T, e env) { awaitNewConnLogOutput() } +func (s) TestDetailedConnectionCloseErrorPropagatesToRpcError(t *testing.T) { + rpcStartedOnServer := make(chan struct{}) + rpcDoneOnClient := make(chan struct{}) + ss := &stubserver.StubServer{ + FullDuplexCallF: func(stream testpb.TestService_FullDuplexCallServer) error { + close(rpcStartedOnServer) + <-rpcDoneOnClient + return status.Error(codes.Internal, "arbitrary status") + }, + } + if err := ss.Start(nil); err != nil { + t.Fatalf("Error starting endpoint server: %v", err) + } + defer ss.Stop() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + // The precise behavior of this test is subject to raceyness around the timing of when TCP packets + // are sent from client to server, and when we tell the server to stop, so we need to account for both + // of these possible error messages: + // 1) If the call to ss.S.Stop() causes the server's sockets to close while there's still in-fight + // data from the client on the TCP connection, then the kernel can send an RST back to the client (also + // see https://stackoverflow.com/questions/33053507/econnreset-in-send-linux-c). Note that while this + // condition is expected to be rare due to the rpcStartedOnServer synchronization, in theory it should + // be possible, e.g. if the client sends a BDP ping at the right time. 
+ // 2) If, for example, the call to ss.S.Stop() happens after the RPC headers have been received at the + // server, then the TCP connection can shutdown gracefully when the server's socket closes. + const possibleConnResetMsg = "connection reset by peer" + const possibleEOFMsg = "error reading from server: EOF" + // Start an RPC. Then, while the RPC is still being accepted or handled at the server, abruptly + // stop the server, killing the connection. The RPC error message should include details about the specific + // connection error that was encountered. + stream, err := ss.Client.FullDuplexCall(ctx) + if err != nil { + t.Fatalf("%v.FullDuplexCall = _, %v, want _, ", ss.Client, err) + } + // Block until the RPC has been started on the server. This ensures that the ClientConn will find a healthy + // connection for the RPC to go out on initially, and that the TCP connection will shut down strictly after + // the RPC has been started on it. + <-rpcStartedOnServer + ss.S.Stop() + if _, err := stream.Recv(); err == nil || (!strings.Contains(err.Error(), possibleConnResetMsg) && !strings.Contains(err.Error(), possibleEOFMsg)) { + t.Fatalf("%v.Recv() = _, %v, want _, rpc error containing substring: %q OR %q", stream, err, possibleConnResetMsg, possibleEOFMsg) + } + close(rpcDoneOnClient) +} + +func (s) TestDetailedGoawayErrorOnGracefulClosePropagatesToRPCError(t *testing.T) { + rpcDoneOnClient := make(chan struct{}) + ss := &stubserver.StubServer{ + FullDuplexCallF: func(stream testpb.TestService_FullDuplexCallServer) error { + <-rpcDoneOnClient + return status.Error(codes.Internal, "arbitrary status") + }, + } + sopts := []grpc.ServerOption{ + grpc.KeepaliveParams(keepalive.ServerParameters{ + MaxConnectionAge: time.Millisecond * 100, + MaxConnectionAgeGrace: time.Millisecond, + }), + } + if err := ss.Start(sopts); err != nil { + t.Fatalf("Error starting endpoint server: %v", err) + } + defer ss.Stop() + + ctx, cancel := context.WithTimeout(context.Background(), 
defaultTestTimeout) + defer cancel() + stream, err := ss.Client.FullDuplexCall(ctx) + if err != nil { + t.Fatalf("%v.FullDuplexCall = _, %v, want _, ", ss.Client, err) + } + const expectedErrorMessageSubstring = "received prior goaway: code: NO_ERROR" + _, err = stream.Recv() + close(rpcDoneOnClient) + if err == nil || !strings.Contains(err.Error(), expectedErrorMessageSubstring) { + t.Fatalf("%v.Recv() = _, %v, want _, rpc error containing substring: %q", stream, err, expectedErrorMessageSubstring) + } +} + +func (s) TestDetailedGoawayErrorOnAbruptClosePropagatesToRPCError(t *testing.T) { + // set the min keepalive time very low so that this test can take + // a reasonable amount of time + prev := internal.KeepaliveMinPingTime + internal.KeepaliveMinPingTime = time.Millisecond + defer func() { internal.KeepaliveMinPingTime = prev }() + + rpcDoneOnClient := make(chan struct{}) + ss := &stubserver.StubServer{ + FullDuplexCallF: func(stream testpb.TestService_FullDuplexCallServer) error { + <-rpcDoneOnClient + return status.Error(codes.Internal, "arbitrary status") + }, + } + sopts := []grpc.ServerOption{ + grpc.KeepaliveEnforcementPolicy(keepalive.EnforcementPolicy{ + MinTime: time.Second * 1000, /* arbitrary, large value */ + }), + } + dopts := []grpc.DialOption{ + grpc.WithKeepaliveParams(keepalive.ClientParameters{ + Time: time.Millisecond, /* should trigger "too many pings" error quickly */ + Timeout: time.Second * 1000, /* arbitrary, large value */ + PermitWithoutStream: false, + }), + } + if err := ss.Start(sopts, dopts...); err != nil { + t.Fatalf("Error starting endpoint server: %v", err) + } + defer ss.Stop() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + stream, err := ss.Client.FullDuplexCall(ctx) + if err != nil { + t.Fatalf("%v.FullDuplexCall = _, %v, want _, ", ss.Client, err) + } + const expectedErrorMessageSubstring = `received prior goaway: code: ENHANCE_YOUR_CALM, debug data: "too_many_pings"` + 
_, err = stream.Recv() + close(rpcDoneOnClient) + if err == nil || !strings.Contains(err.Error(), expectedErrorMessageSubstring) { + t.Fatalf("%v.Recv() = _, %v, want _, rpc error containing substring: |%v|", stream, err, expectedErrorMessageSubstring) + } +} + func (s) TestClientConnCloseAfterGoAwayWithActiveStream(t *testing.T) { for _, e := range listTestEnv() { if e.name == "handler-tls" { @@ -1744,7 +1864,6 @@ func (s) TestServiceConfigMaxMsgSize(t *testing.T) { defer te1.tearDown() te1.resolverScheme = r.Scheme() - te1.nonBlockingDial = true te1.startServer(&testServer{security: e.security}) cc1 := te1.clientConn(grpc.WithResolvers(r)) @@ -1832,7 +1951,6 @@ func (s) TestServiceConfigMaxMsgSize(t *testing.T) { // Case2: Client API set maxReqSize to 1024 (send), maxRespSize to 1024 (recv). Sc sets maxReqSize to 2048 (send), maxRespSize to 2048 (recv). te2 := testServiceConfigSetup(t, e) te2.resolverScheme = r.Scheme() - te2.nonBlockingDial = true te2.maxClientReceiveMsgSize = newInt(1024) te2.maxClientSendMsgSize = newInt(1024) @@ -1892,7 +2010,6 @@ func (s) TestServiceConfigMaxMsgSize(t *testing.T) { // Case3: Client API set maxReqSize to 4096 (send), maxRespSize to 4096 (recv). Sc sets maxReqSize to 2048 (send), maxRespSize to 2048 (recv). 
te3 := testServiceConfigSetup(t, e) te3.resolverScheme = r.Scheme() - te3.nonBlockingDial = true te3.maxClientReceiveMsgSize = newInt(4096) te3.maxClientSendMsgSize = newInt(4096) @@ -1985,7 +2102,6 @@ func (s) TestStreamingRPCWithTimeoutInServiceConfigRecv(t *testing.T) { r := manual.NewBuilderWithScheme("whatever") te.resolverScheme = r.Scheme() - te.nonBlockingDial = true cc := te.clientConn(grpc.WithResolvers(r)) tc := testpb.NewTestServiceClient(cc) @@ -2121,6 +2237,61 @@ func testPreloaderClientSend(t *testing.T, e env) { } } +func (s) TestPreloaderSenderSend(t *testing.T) { + ss := &stubserver.StubServer{ + FullDuplexCallF: func(stream testpb.TestService_FullDuplexCallServer) error { + for i := 0; i < 10; i++ { + preparedMsg := &grpc.PreparedMsg{} + err := preparedMsg.Encode(stream, &testpb.StreamingOutputCallResponse{ + Payload: &testpb.Payload{ + Body: []byte{'0' + uint8(i)}, + }, + }) + if err != nil { + return err + } + stream.SendMsg(preparedMsg) + } + return nil + }, + } + if err := ss.Start(nil); err != nil { + t.Fatalf("Error starting endpoint server: %v", err) + } + defer ss.Stop() + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + stream, err := ss.Client.FullDuplexCall(ctx) + if err != nil { + t.Fatalf("ss.Client.EmptyCall(_, _) = _, %v; want _, nil", err) + } + + var ngot int + var buf bytes.Buffer + for { + reply, err := stream.Recv() + if err == io.EOF { + break + } + if err != nil { + t.Fatal(err) + } + ngot++ + if buf.Len() > 0 { + buf.WriteByte(',') + } + buf.Write(reply.GetPayload().GetBody()) + } + if want := 10; ngot != want { + t.Errorf("Got %d replies, want %d", ngot, want) + } + if got, want := buf.String(), "0,1,2,3,4,5,6,7,8,9"; got != want { + t.Errorf("Got replies %q; want %q", got, want) + } +} + func (s) TestMaxMsgSizeClientDefault(t *testing.T) { for _, e := range listTestEnv() { testMaxMsgSizeClientDefault(t, e) @@ -2380,10 +2551,13 @@ type myTap struct { func (t *myTap) handle(ctx 
context.Context, info *tap.Info) (context.Context, error) { if info != nil { - if info.FullMethodName == "/grpc.testing.TestService/EmptyCall" { + switch info.FullMethodName { + case "/grpc.testing.TestService/EmptyCall": t.cnt++ - } else if info.FullMethodName == "/grpc.testing.TestService/UnaryCall" { + case "/grpc.testing.TestService/UnaryCall": return nil, fmt.Errorf("tap error") + case "/grpc.testing.TestService/FullDuplexCall": + return nil, status.Errorf(codes.FailedPrecondition, "test custom error") } } return ctx, nil @@ -2423,8 +2597,15 @@ func testTap(t *testing.T, e env) { ResponseSize: 45, Payload: payload, } - if _, err := tc.UnaryCall(ctx, req); status.Code(err) != codes.Unavailable { - t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, %s", err, codes.Unavailable) + if _, err := tc.UnaryCall(ctx, req); status.Code(err) != codes.PermissionDenied { + t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, %s", err, codes.PermissionDenied) + } + str, err := tc.FullDuplexCall(ctx) + if err != nil { + t.Fatalf("Unexpected error creating stream: %v", err) + } + if _, err := str.Recv(); status.Code(err) != codes.FailedPrecondition { + t.Fatalf("FullDuplexCall Recv() = _, %v, want _, %s", err, codes.FailedPrecondition) } } @@ -3512,66 +3693,79 @@ func testMalformedHTTP2Metadata(t *testing.T, e env) { } } +// Tests that the client transparently retries correctly when receiving a +// RST_STREAM with code REFUSED_STREAM. func (s) TestTransparentRetry(t *testing.T) { - for _, e := range listTestEnv() { - if e.name == "handler-tls" { - // Fails with RST_STREAM / FLOW_CONTROL_ERROR - continue - } - testTransparentRetry(t, e) - } -} - -// This test makes sure RPCs are retried times when they receive a RST_STREAM -// with the REFUSED_STREAM error code, which the InTapHandle provokes. 
-func testTransparentRetry(t *testing.T, e env) { - te := newTest(t, e) - attempts := 0 - successAttempt := 2 - te.tapHandle = func(ctx context.Context, _ *tap.Info) (context.Context, error) { - attempts++ - if attempts < successAttempt { - return nil, errors.New("not now") - } - return ctx, nil - } - te.startServer(&testServer{security: e.security}) - defer te.tearDown() - - cc := te.clientConn() - tsc := testpb.NewTestServiceClient(cc) testCases := []struct { - successAttempt int - failFast bool - errCode codes.Code + failFast bool + errCode codes.Code }{{ - successAttempt: 1, + // success attempt: 1, (stream ID 1) }, { - successAttempt: 2, + // success attempt: 2, (stream IDs 3, 5) }, { - successAttempt: 3, - errCode: codes.Unavailable, + // no success attempt (stream IDs 7, 9) + errCode: codes.Unavailable, }, { - successAttempt: 1, - failFast: true, + // success attempt: 1 (stream ID 11), + failFast: true, }, { - successAttempt: 2, - failFast: true, + // success attempt: 2 (stream IDs 13, 15), + failFast: true, }, { - successAttempt: 3, - failFast: true, - errCode: codes.Unavailable, + // no success attempt (stream IDs 17, 19) + failFast: true, + errCode: codes.Unavailable, }} - for _, tc := range testCases { - attempts = 0 - successAttempt = tc.successAttempt - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - _, err := tsc.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(!tc.failFast)) - cancel() - if status.Code(err) != tc.errCode { - t.Errorf("%+v: tsc.EmptyCall(_, _) = _, %v, want _, Code=%v", tc, err, tc.errCode) + lis, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatalf("Failed to listen. 
Err: %v", err) + } + defer lis.Close() + server := &httpServer{ + responses: []httpServerResponse{{ + trailers: [][]string{{ + ":status", "200", + "content-type", "application/grpc", + "grpc-status", "0", + }}, + }}, + refuseStream: func(i uint32) bool { + switch i { + case 1, 5, 11, 15: // these stream IDs succeed + return false + } + return true // these are refused + }, + } + server.start(t, lis) + cc, err := grpc.Dial(lis.Addr().String(), grpc.WithInsecure()) + if err != nil { + t.Fatalf("failed to dial due to err: %v", err) + } + defer cc.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + client := testpb.NewTestServiceClient(cc) + + for i, tc := range testCases { + stream, err := client.FullDuplexCall(ctx) + if err != nil { + t.Fatalf("error creating stream due to err: %v", err) + } + code := func(err error) codes.Code { + if err == io.EOF { + return codes.OK + } + return status.Code(err) } + if _, err := stream.Recv(); code(err) != tc.errCode { + t.Fatalf("%v: stream.Recv() = _, %v, want error code: %v", i, err, tc.errCode) + } + } } @@ -4937,8 +5131,7 @@ func (s) TestFlowControlLogicalRace(t *testing.T) { go s.Serve(lis) - ctx := context.Background() - cc, err := grpc.Dial(lis.Addr().String(), grpc.WithInsecure(), grpc.WithBlock()) + cc, err := grpc.Dial(lis.Addr().String(), grpc.WithInsecure()) if err != nil { t.Fatalf("grpc.Dial(%q) = %v", lis.Addr().String(), err) } @@ -4947,7 +5140,7 @@ func (s) TestFlowControlLogicalRace(t *testing.T) { failures := 0 for i := 0; i < requestCount; i++ { - ctx, cancel := context.WithTimeout(ctx, requestTimeout) + ctx, cancel := context.WithTimeout(context.Background(), requestTimeout) output, err := cl.StreamingOutputCall(ctx, &testpb.StreamingOutputCallRequest{}) if err != nil { t.Fatalf("StreamingOutputCall; err = %q", err) @@ -5145,7 +5338,7 @@ func (s) TestGRPCMethod(t *testing.T) { } defer ss.Stop() - ctx, cancel := context.WithTimeout(context.Background(), 
2*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}); err != nil { @@ -5157,6 +5350,86 @@ func (s) TestGRPCMethod(t *testing.T) { } } +// renameProtoCodec is an encoding.Codec wrapper that allows customizing the +// Name() of another codec. +type renameProtoCodec struct { + encoding.Codec + name string +} + +func (r *renameProtoCodec) Name() string { return r.name } + +// TestForceCodecName confirms that the ForceCodec call option sets the subtype +// in the content-type header according to the Name() of the codec provided. +func (s) TestForceCodecName(t *testing.T) { + wantContentTypeCh := make(chan []string, 1) + defer close(wantContentTypeCh) + + ss := &stubserver.StubServer{ + EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) { + md, ok := metadata.FromIncomingContext(ctx) + if !ok { + return nil, status.Errorf(codes.Internal, "no metadata in context") + } + if got, want := md["content-type"], <-wantContentTypeCh; !reflect.DeepEqual(got, want) { + return nil, status.Errorf(codes.Internal, "got content-type=%q; want [%q]", got, want) + } + return &testpb.Empty{}, nil + }, + } + if err := ss.Start([]grpc.ServerOption{grpc.ForceServerCodec(encoding.GetCodec("proto"))}); err != nil { + t.Fatalf("Error starting endpoint server: %v", err) + } + defer ss.Stop() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + codec := &renameProtoCodec{Codec: encoding.GetCodec("proto"), name: "some-test-name"} + wantContentTypeCh <- []string{"application/grpc+some-test-name"} + if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}, grpc.ForceCodec(codec)); err != nil { + t.Fatalf("ss.Client.EmptyCall(_, _) = _, %v; want _, nil", err) + } + + // Confirm the name is converted to lowercase before transmitting. 
+ codec.name = "aNoTHeRNaME" + wantContentTypeCh <- []string{"application/grpc+anothername"} + if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}, grpc.ForceCodec(codec)); err != nil { + t.Fatalf("ss.Client.EmptyCall(_, _) = _, %v; want _, nil", err) + } +} + +func (s) TestForceServerCodec(t *testing.T) { + ss := &stubserver.StubServer{ + EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) { + return &testpb.Empty{}, nil + }, + } + codec := &countingProtoCodec{} + if err := ss.Start([]grpc.ServerOption{grpc.ForceServerCodec(codec)}); err != nil { + t.Fatalf("Error starting endpoint server: %v", err) + } + defer ss.Stop() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}); err != nil { + t.Fatalf("ss.Client.EmptyCall(_, _) = _, %v; want _, nil", err) + } + + unmarshalCount := atomic.LoadInt32(&codec.unmarshalCount) + const wantUnmarshalCount = 1 + if unmarshalCount != wantUnmarshalCount { + t.Fatalf("protoCodec.unmarshalCount = %d; want %d", unmarshalCount, wantUnmarshalCount) + } + marshalCount := atomic.LoadInt32(&codec.marshalCount) + const wantMarshalCount = 1 + if marshalCount != wantMarshalCount { + t.Fatalf("protoCodec.marshalCount = %d; want %d", marshalCount, wantMarshalCount) + } +} + func (s) TestUnaryProxyDoesNotForwardMetadata(t *testing.T) { const mdkey = "somedata" @@ -5526,6 +5799,33 @@ func (c *errCodec) Name() string { return "Fermat's near-miss." 
} +type countingProtoCodec struct { + marshalCount int32 + unmarshalCount int32 +} + +func (p *countingProtoCodec) Marshal(v interface{}) ([]byte, error) { + atomic.AddInt32(&p.marshalCount, 1) + vv, ok := v.(proto.Message) + if !ok { + return nil, fmt.Errorf("failed to marshal, message is %T, want proto.Message", v) + } + return proto.Marshal(vv) +} + +func (p *countingProtoCodec) Unmarshal(data []byte, v interface{}) error { + atomic.AddInt32(&p.unmarshalCount, 1) + vv, ok := v.(proto.Message) + if !ok { + return fmt.Errorf("failed to unmarshal, message is %T, want proto.Message", v) + } + return proto.Unmarshal(data, vv) +} + +func (*countingProtoCodec) Name() string { + return "proto" +} + func (s) TestEncodeDoesntPanic(t *testing.T) { for _, e := range listTestEnv() { testEncodeDoesntPanic(t, e) @@ -6036,6 +6336,23 @@ func testServiceConfigMaxMsgSizeTD(t *testing.T, e env) { } } +// TestMalformedStreamMethod starts a test server and sends an RPC with a +// malformed method name. The server should respond with an UNIMPLEMENTED status +// code in this case. 
+func (s) TestMalformedStreamMethod(t *testing.T) { + const testMethod = "a-method-name-without-any-slashes" + te := newTest(t, tcpClearRREnv) + te.startServer(nil) + defer te.tearDown() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + err := te.clientConn().Invoke(ctx, testMethod, nil, nil) + if gotCode := status.Code(err); gotCode != codes.Unimplemented { + t.Fatalf("Invoke with method %q, got code %s, want %s", testMethod, gotCode, codes.Unimplemented) + } +} + func (s) TestMethodFromServerStream(t *testing.T) { const testMethod = "/package.service/method" e := tcpClearRREnv @@ -6244,7 +6561,7 @@ func (s) TestServeExitsWhenListenerClosed(t *testing.T) { close(done) }() - cc, err := grpc.Dial(lis.Addr().String(), grpc.WithInsecure(), grpc.WithBlock()) + cc, err := grpc.Dial(lis.Addr().String(), grpc.WithInsecure()) if err != nil { t.Fatalf("Failed to dial server: %v", err) } @@ -6457,15 +6774,13 @@ func (s) TestDisabledIOBuffers(t *testing.T) { t.Fatalf("Failed to create listener: %v", err) } - done := make(chan struct{}) go func() { s.Serve(lis) - close(done) }() defer s.Stop() dctx, dcancel := context.WithTimeout(context.Background(), 5*time.Second) defer dcancel() - cc, err := grpc.DialContext(dctx, lis.Addr().String(), grpc.WithInsecure(), grpc.WithBlock(), grpc.WithWriteBufferSize(0), grpc.WithReadBufferSize(0)) + cc, err := grpc.DialContext(dctx, lis.Addr().String(), grpc.WithInsecure(), grpc.WithWriteBufferSize(0), grpc.WithReadBufferSize(0)) if err != nil { t.Fatalf("Failed to dial server") } @@ -6738,6 +7053,10 @@ func (s) TestGoAwayThenClose(t *testing.T) { return &testpb.SimpleResponse{}, nil }, fullDuplexCall: func(stream testpb.TestService_FullDuplexCallServer) error { + if err := stream.Send(&testpb.StreamingOutputCallResponse{}); err != nil { + t.Errorf("unexpected error from send: %v", err) + return err + } // Wait forever. 
_, err := stream.Recv() if err == nil { @@ -6757,11 +7076,11 @@ func (s) TestGoAwayThenClose(t *testing.T) { s2 := grpc.NewServer() defer s2.Stop() testpb.RegisterTestServiceServer(s2, ts) - go s2.Serve(lis2) r := manual.NewBuilderWithScheme("whatever") r.InitialState(resolver.State{Addresses: []resolver.Address{ {Addr: lis1.Addr().String()}, + {Addr: lis2.Addr().String()}, }}) cc, err := grpc.DialContext(ctx, r.Scheme()+":///", grpc.WithResolvers(r), grpc.WithInsecure()) if err != nil { @@ -6771,30 +7090,49 @@ func (s) TestGoAwayThenClose(t *testing.T) { client := testpb.NewTestServiceClient(cc) - // Should go on connection 1. We use a long-lived RPC because it will cause GracefulStop to send GO_AWAY, but the - // connection doesn't get closed until the server stops and the client receives. + // We make a streaming RPC and do an one-message-round-trip to make sure + // it's created on connection 1. + // + // We use a long-lived RPC because it will cause GracefulStop to send + // GO_AWAY, but the connection doesn't get closed until the server stops and + // the client receives the error. stream, err := client.FullDuplexCall(ctx) if err != nil { t.Fatalf("FullDuplexCall(_) = _, %v; want _, nil", err) } + if _, err = stream.Recv(); err != nil { + t.Fatalf("unexpected error from first recv: %v", err) + } - r.UpdateState(resolver.State{Addresses: []resolver.Address{ - {Addr: lis1.Addr().String()}, - {Addr: lis2.Addr().String()}, - }}) + go s2.Serve(lis2) // Send GO_AWAY to connection 1. go s1.GracefulStop() - // Wait for connection 2 to be established. + // Wait for the ClientConn to enter IDLE state. + state := cc.GetState() + for ; state != connectivity.Idle && cc.WaitForStateChange(ctx, state); state = cc.GetState() { + } + if state != connectivity.Idle { + t.Fatalf("timed out waiting for IDLE channel state; last state = %v", state) + } + + // Initiate another RPC to create another connection. 
+ if _, err := client.UnaryCall(ctx, &testpb.SimpleRequest{}); err != nil { + t.Fatalf("UnaryCall(_) = _, %v; want _, nil", err) + } + + // Assert that connection 2 has been established. <-conn2Established.Done() + // Close the listener for server2 to prevent it from allowing new connections. + lis2.Close() + // Close connection 1. s1.Stop() // Wait for client to close. - _, err = stream.Recv() - if err == nil { + if _, err = stream.Recv(); err == nil { t.Fatal("expected the stream to die, but got a successful Recv") } @@ -6831,7 +7169,6 @@ func (s) TestRPCWaitsForResolver(t *testing.T) { r := manual.NewBuilderWithScheme("whatever") te.resolverScheme = r.Scheme() - te.nonBlockingDial = true cc := te.clientConn(grpc.WithResolvers(r)) tc := testpb.NewTestServiceClient(cc) @@ -6918,7 +7255,7 @@ func (s) TestHTTPHeaderFrameErrorHandlingInitialHeader(t *testing.T) { ":status", "403", "content-type", "application/grpc", }, - errCode: codes.Unknown, + errCode: codes.PermissionDenied, }, { // malformed grpc-status. @@ -6937,7 +7274,7 @@ func (s) TestHTTPHeaderFrameErrorHandlingInitialHeader(t *testing.T) { "grpc-status", "0", "grpc-tags-bin", "???", }, - errCode: codes.Internal, + errCode: codes.Unavailable, }, { // gRPC status error. 
@@ -6946,14 +7283,14 @@ func (s) TestHTTPHeaderFrameErrorHandlingInitialHeader(t *testing.T) { "content-type", "application/grpc", "grpc-status", "3", }, - errCode: codes.InvalidArgument, + errCode: codes.Unavailable, }, } { doHTTPHeaderTest(t, test.errCode, test.header) } } -// Testing non-Trailers-only Trailers (delievered in second HEADERS frame) +// Testing non-Trailers-only Trailers (delivered in second HEADERS frame) func (s) TestHTTPHeaderFrameErrorHandlingNormalTrailer(t *testing.T) { for _, test := range []struct { responseHeader []string @@ -6969,7 +7306,7 @@ func (s) TestHTTPHeaderFrameErrorHandlingNormalTrailer(t *testing.T) { // trailer missing grpc-status ":status", "502", }, - errCode: codes.Unknown, + errCode: codes.Unavailable, }, { responseHeader: []string{ @@ -6981,6 +7318,18 @@ func (s) TestHTTPHeaderFrameErrorHandlingNormalTrailer(t *testing.T) { "grpc-status", "0", "grpc-status-details-bin", "????", }, + errCode: codes.Unimplemented, + }, + { + responseHeader: []string{ + ":status", "200", + "content-type", "application/grpc", + }, + trailer: []string{ + // malformed grpc-status-details-bin field + "grpc-status", "0", + "grpc-status-details-bin", "????", + }, errCode: codes.Internal, }, } { @@ -6996,8 +7345,18 @@ func (s) TestHTTPHeaderFrameErrorHandlingMoreThanTwoHeaders(t *testing.T) { doHTTPHeaderTest(t, codes.Internal, header, header, header) } +type httpServerResponse struct { + headers [][]string + payload []byte + trailers [][]string +} + type httpServer struct { - headerFields [][]string + // If waitForEndStream is set, wait for the client to send a frame with end + // stream in it before sending a response/refused stream. 
+ waitForEndStream bool + refuseStream func(uint32) bool + responses []httpServerResponse } func (s *httpServer) writeHeader(framer *http2.Framer, sid uint32, headerFields []string, endStream bool) error { @@ -7021,6 +7380,10 @@ func (s *httpServer) writeHeader(framer *http2.Framer, sid uint32, headerFields }) } +func (s *httpServer) writePayload(framer *http2.Framer, sid uint32, payload []byte) error { + return framer.WriteData(sid, false, payload) +} + func (s *httpServer) start(t *testing.T, lis net.Listener) { // Launch an HTTP server to send back header. go func() { @@ -7045,24 +7408,66 @@ func (s *httpServer) start(t *testing.T, lis net.Listener) { writer.Flush() // necessary since client is expecting preface before declaring connection fully setup. var sid uint32 - // Read frames until a header is received. - for { - frame, err := framer.ReadFrame() - if err != nil { - t.Errorf("Error at server-side while reading frame. Err: %v", err) - return + // Loop until conn is closed and framer returns io.EOF + for requestNum := 0; ; requestNum = (requestNum + 1) % len(s.responses) { + // Read frames until a header is received. + for { + frame, err := framer.ReadFrame() + if err != nil { + if err != io.EOF { + t.Errorf("Error at server-side while reading frame. Err: %v", err) + } + return + } + sid = 0 + switch fr := frame.(type) { + case *http2.HeadersFrame: + // Respond after this if we are not waiting for an end + // stream or if this frame ends it. + if !s.waitForEndStream || fr.StreamEnded() { + sid = fr.Header().StreamID + } + + case *http2.DataFrame: + // Respond after this if we were waiting for an end stream + // and this frame ends it. (If we were not waiting for an + // end stream, this stream was already responded to when + // the headers were received.) 
+ if s.waitForEndStream && fr.StreamEnded() { + sid = fr.Header().StreamID + } + } + if sid != 0 { + if s.refuseStream == nil || !s.refuseStream(sid) { + break + } + framer.WriteRSTStream(sid, http2.ErrCodeRefusedStream) + writer.Flush() + } } - if hframe, ok := frame.(*http2.HeadersFrame); ok { - sid = hframe.Header().StreamID - break + + response := s.responses[requestNum] + for _, header := range response.headers { + if err = s.writeHeader(framer, sid, header, false); err != nil { + t.Errorf("Error at server-side while writing headers. Err: %v", err) + return + } + writer.Flush() } - } - for i, headers := range s.headerFields { - if err = s.writeHeader(framer, sid, headers, i == len(s.headerFields)-1); err != nil { - t.Errorf("Error at server-side while writing headers. Err: %v", err) - return + if response.payload != nil { + if err = s.writePayload(framer, sid, response.payload); err != nil { + t.Errorf("Error at server-side while writing payload. Err: %v", err) + return + } + writer.Flush() + } + for i, trailer := range response.trailers { + if err = s.writeHeader(framer, sid, trailer, i == len(response.trailers)-1); err != nil { + t.Errorf("Error at server-side while writing trailers. Err: %v", err) + return + } + writer.Flush() } - writer.Flush() } }() } @@ -7075,7 +7480,7 @@ func doHTTPHeaderTest(t *testing.T, errCode codes.Code, headerFields ...[]string } defer lis.Close() server := &httpServer{ - headerFields: headerFields, + responses: []httpServerResponse{{trailers: headerFields}}, } server.start(t, lis) cc, err := grpc.Dial(lis.Addr().String(), grpc.WithInsecure()) @@ -7251,3 +7656,395 @@ func (s) TestCanceledRPCCallOptionRace(t *testing.T) { } wg.Wait() } + +func (s) TestClientSettingsFloodCloseConn(t *testing.T) { + // Tests that the server properly closes its transport if the client floods + // settings frames and then closes the connection. + + // Minimize buffer sizes to stimulate failure condition more quickly. 
+ s := grpc.NewServer(grpc.WriteBufferSize(20)) + l := bufconn.Listen(20) + go s.Serve(l) + + // Dial our server and handshake. + conn, err := l.Dial() + if err != nil { + t.Fatalf("Error dialing bufconn: %v", err) + } + + n, err := conn.Write([]byte(http2.ClientPreface)) + if err != nil || n != len(http2.ClientPreface) { + t.Fatalf("Error writing client preface: %v, %v", n, err) + } + + fr := http2.NewFramer(conn, conn) + f, err := fr.ReadFrame() + if err != nil { + t.Fatalf("Error reading initial settings frame: %v", err) + } + if _, ok := f.(*http2.SettingsFrame); ok { + if err := fr.WriteSettingsAck(); err != nil { + t.Fatalf("Error writing settings ack: %v", err) + } + } else { + t.Fatalf("Error reading initial settings frame: type=%T", f) + } + + // Confirm settings can be written, and that an ack is read. + if err = fr.WriteSettings(); err != nil { + t.Fatalf("Error writing settings frame: %v", err) + } + if f, err = fr.ReadFrame(); err != nil { + t.Fatalf("Error reading frame: %v", err) + } + if sf, ok := f.(*http2.SettingsFrame); !ok || !sf.IsAck() { + t.Fatalf("Unexpected frame: %v", f) + } + + // Flood settings frames until a timeout occurs, indiciating the server has + // stopped reading from the connection, then close the conn. + for { + conn.SetWriteDeadline(time.Now().Add(50 * time.Millisecond)) + if err := fr.WriteSettings(); err != nil { + if to, ok := err.(interface{ Timeout() bool }); !ok || !to.Timeout() { + t.Fatalf("Received unexpected write error: %v", err) + } + break + } + } + conn.Close() + + // If the server does not handle this situation correctly, it will never + // close the transport. This is because its loopyWriter.run() will have + // exited, and thus not handle the goAway the draining process initiates. + // Also, we would see a goroutine leak in this case, as the reader would be + // blocked on the controlBuf's throttle() method indefinitely. 
+ + timer := time.AfterFunc(5*time.Second, func() { + t.Errorf("Timeout waiting for GracefulStop to return") + s.Stop() + }) + s.GracefulStop() + timer.Stop() +} + +// TestDeadlineSetOnConnectionOnClientCredentialHandshake tests that there is a deadline +// set on the net.Conn when a credential handshake happens in http2_client. +func (s) TestDeadlineSetOnConnectionOnClientCredentialHandshake(t *testing.T) { + lis, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatalf("Failed to listen: %v", err) + } + connCh := make(chan net.Conn, 1) + go func() { + defer close(connCh) + conn, err := lis.Accept() + if err != nil { + t.Errorf("Error accepting connection: %v", err) + return + } + connCh <- conn + }() + defer func() { + conn := <-connCh + if conn != nil { + conn.Close() + } + }() + deadlineCh := testutils.NewChannel() + cvd := &credentialsVerifyDeadline{ + deadlineCh: deadlineCh, + } + dOpt := grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) { + conn, err := (&net.Dialer{}).DialContext(ctx, "tcp", addr) + if err != nil { + return nil, err + } + return &infoConn{Conn: conn}, nil + }) + cc, err := grpc.Dial(lis.Addr().String(), dOpt, grpc.WithTransportCredentials(cvd)) + if err != nil { + t.Fatalf("Failed to dial: %v", err) + } + defer cc.Close() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + deadline, err := deadlineCh.Receive(ctx) + if err != nil { + t.Fatalf("Error receiving from credsInvoked: %v", err) + } + // Default connection timeout is 20 seconds, so if the deadline exceeds now + // + 18 seconds it should be valid. 
+ if !deadline.(time.Time).After(time.Now().Add(time.Second * 18)) { + t.Fatalf("Connection did not have deadline set.") + } +} + +type infoConn struct { + net.Conn + deadline time.Time +} + +func (c *infoConn) SetDeadline(t time.Time) error { + c.deadline = t + return c.Conn.SetDeadline(t) +} + +type credentialsVerifyDeadline struct { + deadlineCh *testutils.Channel +} + +func (cvd *credentialsVerifyDeadline) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { + return rawConn, nil, nil +} + +func (cvd *credentialsVerifyDeadline) ClientHandshake(ctx context.Context, authority string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { + cvd.deadlineCh.Send(rawConn.(*infoConn).deadline) + return rawConn, nil, nil +} + +func (cvd *credentialsVerifyDeadline) Info() credentials.ProtocolInfo { + return credentials.ProtocolInfo{} +} +func (cvd *credentialsVerifyDeadline) Clone() credentials.TransportCredentials { + return cvd +} +func (cvd *credentialsVerifyDeadline) OverrideServerName(s string) error { + return nil +} + +func unaryInterceptorVerifyConn(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { + conn := transport.GetConnection(ctx) + if conn == nil { + return nil, status.Error(codes.NotFound, "connection was not in context") + } + return nil, status.Error(codes.OK, "") +} + +// TestUnaryServerInterceptorGetsConnection tests whether the accepted conn on +// the server gets to any unary interceptors on the server side. 
+func (s) TestUnaryServerInterceptorGetsConnection(t *testing.T) { + ss := &stubserver.StubServer{} + if err := ss.Start([]grpc.ServerOption{grpc.UnaryInterceptor(unaryInterceptorVerifyConn)}); err != nil { + t.Fatalf("Error starting endpoint server: %v", err) + } + defer ss.Stop() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}); status.Code(err) != codes.OK { + t.Fatalf("ss.Client.EmptyCall(_, _) = _, %v, want _, error code %s", err, codes.OK) + } +} + +func streamingInterceptorVerifyConn(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { + conn := transport.GetConnection(ss.Context()) + if conn == nil { + return status.Error(codes.NotFound, "connection was not in context") + } + return status.Error(codes.OK, "") +} + +// TestStreamingServerInterceptorGetsConnection tests whether the accepted conn on +// the server gets to any streaming interceptors on the server side. +func (s) TestStreamingServerInterceptorGetsConnection(t *testing.T) { + ss := &stubserver.StubServer{} + if err := ss.Start([]grpc.ServerOption{grpc.StreamInterceptor(streamingInterceptorVerifyConn)}); err != nil { + t.Fatalf("Error starting endpoint server: %v", err) + } + defer ss.Stop() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + s, err := ss.Client.StreamingOutputCall(ctx, &testpb.StreamingOutputCallRequest{}) + if err != nil { + t.Fatalf("ss.Client.StreamingOutputCall(_) = _, %v, want _, ", err) + } + if _, err := s.Recv(); err != io.EOF { + t.Fatalf("ss.Client.StreamingInputCall(_) = _, %v, want _, %v", err, io.EOF) + } +} + +// unaryInterceptorVerifyAuthority verifies there is an unambiguous :authority +// once the request gets to an interceptor. An unambiguous :authority is defined +// as at most a single :authority header, and no host header according to A41. 
+func unaryInterceptorVerifyAuthority(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { + md, ok := metadata.FromIncomingContext(ctx) + if !ok { + return nil, status.Error(codes.NotFound, "metadata was not in context") + } + authority := md.Get(":authority") + if len(authority) > 1 { // Should be an unambiguous authority by the time it gets to interceptor. + return nil, status.Error(codes.NotFound, ":authority value had more than one value") + } + // Host header shouldn't be present by the time it gets to the interceptor + // level (should either be renamed to :authority or explicitly deleted). + host := md.Get("host") + if len(host) != 0 { + return nil, status.Error(codes.NotFound, "host header should not be present in metadata") + } + // Pass back the authority for verification on client - NotFound so + // grpc-message will be available to read for verification. + if len(authority) == 0 { + // Represent no :authority header present with an empty string. + return nil, status.Error(codes.NotFound, "") + } + return nil, status.Error(codes.NotFound, authority[0]) +} + +// TestAuthorityHeader tests that the eventual :authority that reaches the grpc +// layer is unambiguous due to logic added in A41. +func (s) TestAuthorityHeader(t *testing.T) { + tests := []struct { + name string + headers []string + wantAuthority string + }{ + // "If :authority is missing, Host must be renamed to :authority." - A41 + { + name: "Missing :authority", + // Codepath triggered by incoming headers with no authority but with + // a host. + headers: []string{ + ":method", "POST", + ":path", "/grpc.testing.TestService/UnaryCall", + "content-type", "application/grpc", + "te", "trailers", + "host", "localhost", + }, + wantAuthority: "localhost", + }, + { + name: "Missing :authority and host", + // Codepath triggered by incoming headers with no :authority and no + // host. 
+ headers: []string{ + ":method", "POST", + ":path", "/grpc.testing.TestService/UnaryCall", + "content-type", "application/grpc", + "te", "trailers", + }, + wantAuthority: "", + }, + // "If :authority is present, Host must be discarded." - A41 + { + name: ":authority and host present", + // Codepath triggered by incoming headers with both an authority + // header and a host header. + headers: []string{ + ":method", "POST", + ":path", "/grpc.testing.TestService/UnaryCall", + ":authority", "localhost", + "content-type", "application/grpc", + "host", "localhost2", + }, + wantAuthority: "localhost", + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + te := newTest(t, tcpClearRREnv) + ts := &funcServer{unaryCall: func(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) { + return &testpb.SimpleResponse{}, nil + }} + te.unaryServerInt = unaryInterceptorVerifyAuthority + te.startServer(ts) + defer te.tearDown() + success := testutils.NewChannel() + te.withServerTester(func(st *serverTester) { + st.writeHeaders(http2.HeadersFrameParam{ + StreamID: 1, + BlockFragment: st.encodeHeader(test.headers...), + EndStream: false, + EndHeaders: true, + }) + st.writeData(1, true, []byte{0, 0, 0, 0, 0}) + + for { + frame := st.wantAnyFrame() + f, ok := frame.(*http2.MetaHeadersFrame) + if !ok { + continue + } + for _, header := range f.Fields { + if header.Name == "grpc-message" { + success.Send(header.Value) + return + } + } + } + }) + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + gotAuthority, err := success.Receive(ctx) + if err != nil { + t.Fatalf("Error receiving from channel: %v", err) + } + if gotAuthority != test.wantAuthority { + t.Fatalf("gotAuthority: %v, wantAuthority %v", gotAuthority, test.wantAuthority) + } + }) + } +} + +// wrapCloseListener tracks Accepts/Closes and maintains a counter of the +// number of open connections. 
+type wrapCloseListener struct { + net.Listener + connsOpen int32 +} + +// wrapCloseListener is returned by wrapCloseListener.Accept and decrements its +// connsOpen when Close is called. +type wrapCloseConn struct { + net.Conn + lis *wrapCloseListener + closeOnce sync.Once +} + +func (w *wrapCloseListener) Accept() (net.Conn, error) { + conn, err := w.Listener.Accept() + if err != nil { + return nil, err + } + atomic.AddInt32(&w.connsOpen, 1) + return &wrapCloseConn{Conn: conn, lis: w}, nil +} + +func (w *wrapCloseConn) Close() error { + defer w.closeOnce.Do(func() { atomic.AddInt32(&w.lis.connsOpen, -1) }) + return w.Conn.Close() +} + +// TestServerClosesConn ensures conn.Close is always closed even if the client +// doesn't complete the HTTP/2 handshake. +func (s) TestServerClosesConn(t *testing.T) { + lis := bufconn.Listen(20) + wrapLis := &wrapCloseListener{Listener: lis} + + s := grpc.NewServer() + go s.Serve(wrapLis) + defer s.Stop() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + for i := 0; i < 10; i++ { + conn, err := lis.DialContext(ctx) + if err != nil { + t.Fatalf("Dial = _, %v; want _, nil", err) + } + conn.Close() + } + for ctx.Err() == nil { + if atomic.LoadInt32(&wrapLis.connsOpen) == 0 { + return + } + time.Sleep(50 * time.Millisecond) + } + t.Fatalf("timed out waiting for conns to be closed by server; still open: %v", atomic.LoadInt32(&wrapLis.connsOpen)) +} diff --git a/test/go_vet/vet.go b/test/go_vet/vet.go deleted file mode 100644 index 475e8d683fc..00000000000 --- a/test/go_vet/vet.go +++ /dev/null @@ -1,53 +0,0 @@ -/* - * - * Copyright 2018 gRPC authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - */ - -// vet checks whether files that are supposed to be built on appengine running -// Go 1.10 or earlier import an unsupported package (e.g. "unsafe", "syscall"). -package main - -import ( - "fmt" - "go/build" - "os" -) - -func main() { - fail := false - b := build.Default - b.BuildTags = []string{"appengine", "appenginevm"} - argsWithoutProg := os.Args[1:] - for _, dir := range argsWithoutProg { - p, err := b.Import(".", dir, 0) - if _, ok := err.(*build.NoGoError); ok { - continue - } else if err != nil { - fmt.Printf("build.Import failed due to %v\n", err) - fail = true - continue - } - for _, pkg := range p.Imports { - if pkg == "syscall" || pkg == "unsafe" { - fmt.Printf("Package %s/%s importing %s package without appengine build tag is NOT ALLOWED!\n", p.Dir, p.Name, pkg) - fail = true - } - } - } - if fail { - os.Exit(1) - } -} diff --git a/test/grpc_testing/test_grpc.pb.go b/test/grpc_testing/test_grpc.pb.go index ab3b68a92bc..76b3935620c 100644 --- a/test/grpc_testing/test_grpc.pb.go +++ b/test/grpc_testing/test_grpc.pb.go @@ -1,4 +1,8 @@ // Code generated by protoc-gen-go-grpc. DO NOT EDIT. 
+// versions: +// - protoc-gen-go-grpc v1.1.0 +// - protoc v3.14.0 +// source: test/grpc_testing/test.proto package grpc_testing diff --git a/test/insecure_creds_test.go b/test/insecure_creds_test.go index 19f8bb8b791..791cf650887 100644 --- a/test/insecure_creds_test.go +++ b/test/insecure_creds_test.go @@ -124,21 +124,19 @@ func (s) TestInsecureCreds(t *testing.T) { go s.Serve(lis) addr := lis.Addr().String() - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() - cOpts := []grpc.DialOption{grpc.WithBlock()} + opts := []grpc.DialOption{grpc.WithInsecure()} if test.clientInsecureCreds { - cOpts = append(cOpts, grpc.WithTransportCredentials(insecure.NewCredentials())) - } else { - cOpts = append(cOpts, grpc.WithInsecure()) + opts = []grpc.DialOption{grpc.WithTransportCredentials(insecure.NewCredentials())} } - cc, err := grpc.DialContext(ctx, addr, cOpts...) + cc, err := grpc.Dial(addr, opts...) if err != nil { t.Fatalf("grpc.Dial(%q) failed: %v", addr, err) } defer cc.Close() c := testpb.NewTestServiceClient(cc) + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() if _, err = c.EmptyCall(ctx, &testpb.Empty{}); err != nil { t.Fatalf("EmptyCall(_, _) = _, %v; want _, ", err) } @@ -151,19 +149,16 @@ func (s) TestInsecureCredsWithPerRPCCredentials(t *testing.T) { desc string perRPCCredsViaDialOptions bool perRPCCredsViaCallOptions bool - wantErr string }{ { desc: "send PerRPCCredentials via DialOptions", perRPCCredsViaDialOptions: true, perRPCCredsViaCallOptions: false, - wantErr: "context deadline exceeded", }, { desc: "send PerRPCCredentials via CallOptions", perRPCCredsViaDialOptions: false, perRPCCredsViaCallOptions: true, - wantErr: "transport: cannot send secure credentials on an insecure connection", }, } for _, test := range tests { @@ -174,44 +169,38 @@ func (s) TestInsecureCredsWithPerRPCCredentials(t *testing.T) { }, } - sOpts := []grpc.ServerOption{} - sOpts = 
append(sOpts, grpc.Creds(insecure.NewCredentials())) - s := grpc.NewServer(sOpts...) + s := grpc.NewServer(grpc.Creds(insecure.NewCredentials())) defer s.Stop() - testpb.RegisterTestServiceServer(s, ss) lis, err := net.Listen("tcp", "localhost:0") if err != nil { t.Fatalf("net.Listen(tcp, localhost:0) failed: %v", err) } - go s.Serve(lis) addr := lis.Addr().String() ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() - cOpts := []grpc.DialOption{grpc.WithBlock()} - cOpts = append(cOpts, grpc.WithTransportCredentials(insecure.NewCredentials())) + dopts := []grpc.DialOption{grpc.WithTransportCredentials(insecure.NewCredentials())} if test.perRPCCredsViaDialOptions { - cOpts = append(cOpts, grpc.WithPerRPCCredentials(testLegacyPerRPCCredentials{})) - if _, err := grpc.DialContext(ctx, addr, cOpts...); !strings.Contains(err.Error(), test.wantErr) { - t.Fatalf("InsecureCredsWithPerRPCCredentials/send_PerRPCCredentials_via_DialOptions = %v; want %s", err, test.wantErr) - } + dopts = append(dopts, grpc.WithPerRPCCredentials(testLegacyPerRPCCredentials{})) } - + copts := []grpc.CallOption{} if test.perRPCCredsViaCallOptions { - cc, err := grpc.DialContext(ctx, addr, cOpts...) - if err != nil { - t.Fatalf("grpc.Dial(%q) failed: %v", addr, err) - } - defer cc.Close() - - c := testpb.NewTestServiceClient(cc) - if _, err = c.EmptyCall(ctx, &testpb.Empty{}, grpc.PerRPCCredentials(testLegacyPerRPCCredentials{})); !strings.Contains(err.Error(), test.wantErr) { - t.Fatalf("InsecureCredsWithPerRPCCredentials/send_PerRPCCredentials_via_CallOptions = %v; want %s", err, test.wantErr) - } + copts = append(copts, grpc.PerRPCCredentials(testLegacyPerRPCCredentials{})) + } + cc, err := grpc.Dial(addr, dopts...) 
+ if err != nil { + t.Fatalf("grpc.Dial(%q) failed: %v", addr, err) + } + defer cc.Close() + + const wantErr = "transport: cannot send secure credentials on an insecure connection" + c := testpb.NewTestServiceClient(cc) + if _, err = c.EmptyCall(ctx, &testpb.Empty{}, copts...); err == nil || !strings.Contains(err.Error(), wantErr) { + t.Fatalf("InsecureCredsWithPerRPCCredentials/send_PerRPCCredentials_via_CallOptions = %v; want %s", err, wantErr) } }) } diff --git a/test/kokoro/xds.cfg b/test/kokoro/xds.cfg index d1a078217b8..a1e4ed0bb5e 100644 --- a/test/kokoro/xds.cfg +++ b/test/kokoro/xds.cfg @@ -2,7 +2,7 @@ # Location of the continuous shell script in repository. build_file: "grpc-go/test/kokoro/xds.sh" -timeout_mins: 120 +timeout_mins: 360 action { define_artifacts { regex: "**/*sponge_log.*" diff --git a/test/kokoro/xds.sh b/test/kokoro/xds.sh index 36a3f4563cf..7b7f48dba30 100755 --- a/test/kokoro/xds.sh +++ b/test/kokoro/xds.sh @@ -3,6 +3,12 @@ set -exu -o pipefail [[ -f /VERSION ]] && cat /VERSION + +echo "Remove the expired letsencrypt.org cert and update the CA certificates" +sudo apt-get install -y ca-certificates +sudo rm /usr/share/ca-certificates/mozilla/DST_Root_CA_X3.crt +sudo update-ca-certificates + cd github export GOPATH="${HOME}/gopath" @@ -27,7 +33,7 @@ grpc/tools/run_tests/helper_scripts/prep_xds.sh # they are added into "all". 
GRPC_GO_LOG_VERBOSITY_LEVEL=99 GRPC_GO_LOG_SEVERITY_LEVEL=info \ python3 grpc/tools/run_tests/run_xds_tests.py \ - --test_case="all,path_matching,header_matching,circuit_breaking,timeout" \ + --test_case="all,circuit_breaking,timeout,fault_injection,csds" \ --project_id=grpc-testing \ --project_num=830293263384 \ --source_image=projects/grpc-testing/global/images/xds-test-server-4 \ diff --git a/test/kokoro/xds_k8s.cfg b/test/kokoro/xds_k8s.cfg new file mode 100644 index 00000000000..4d5e019991f --- /dev/null +++ b/test/kokoro/xds_k8s.cfg @@ -0,0 +1,13 @@ +# Config file for internal CI + +# Location of the continuous shell script in repository. +build_file: "grpc-go/test/kokoro/xds_k8s.sh" +timeout_mins: 120 + +action { + define_artifacts { + regex: "artifacts/**/*sponge_log.xml" + regex: "artifacts/**/*sponge_log.log" + strip_prefix: "artifacts" + } +} diff --git a/test/kokoro/xds_k8s.sh b/test/kokoro/xds_k8s.sh new file mode 100755 index 00000000000..f91d1d026d6 --- /dev/null +++ b/test/kokoro/xds_k8s.sh @@ -0,0 +1,155 @@ +#!/usr/bin/env bash +# Copyright 2021 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +set -eo pipefail + +# Constants +readonly GITHUB_REPOSITORY_NAME="grpc-go" +readonly TEST_DRIVER_INSTALL_SCRIPT_URL="https://raw.githubusercontent.com/${TEST_DRIVER_REPO_OWNER:-grpc}/grpc/${TEST_DRIVER_BRANCH:-master}/tools/internal_ci/linux/grpc_xds_k8s_install_test_driver.sh" +## xDS test server/client Docker images +readonly SERVER_IMAGE_NAME="gcr.io/grpc-testing/xds-interop/go-server" +readonly CLIENT_IMAGE_NAME="gcr.io/grpc-testing/xds-interop/go-client" +readonly FORCE_IMAGE_BUILD="${FORCE_IMAGE_BUILD:-0}" + +####################################### +# Builds test app Docker images and pushes them to GCR +# Globals: +# SERVER_IMAGE_NAME: Test server Docker image name +# CLIENT_IMAGE_NAME: Test client Docker image name +# GIT_COMMIT: SHA-1 of git commit being built +# Arguments: +# None +# Outputs: +# Writes the output of `gcloud builds submit` to stdout, stderr +####################################### +build_test_app_docker_images() { + echo "Building Go xDS interop test app Docker images" + docker build -f "${SRC_DIR}/interop/xds/client/Dockerfile" -t "${CLIENT_IMAGE_NAME}:${GIT_COMMIT}" "${SRC_DIR}" + docker build -f "${SRC_DIR}/interop/xds/server/Dockerfile" -t "${SERVER_IMAGE_NAME}:${GIT_COMMIT}" "${SRC_DIR}" + gcloud -q auth configure-docker + docker push "${CLIENT_IMAGE_NAME}:${GIT_COMMIT}" + docker push "${SERVER_IMAGE_NAME}:${GIT_COMMIT}" + if [[ -n $KOKORO_JOB_NAME ]]; then + branch_name=$(echo "$KOKORO_JOB_NAME" | sed -E 's|^grpc/go/([^/]+)/.*|\1|') + tag_and_push_docker_image "${CLIENT_IMAGE_NAME}" "${GIT_COMMIT}" "${branch_name}" + tag_and_push_docker_image "${SERVER_IMAGE_NAME}" "${GIT_COMMIT}" "${branch_name}" + fi +} + +####################################### +# Builds test app and its docker images unless they already exist +# Globals: +# SERVER_IMAGE_NAME: Test server Docker image name +# CLIENT_IMAGE_NAME: Test client Docker image name +# GIT_COMMIT: SHA-1 of git commit being built +# FORCE_IMAGE_BUILD +# Arguments: +# None +# Outputs: +# 
Writes the output to stdout, stderr +####################################### +build_docker_images_if_needed() { + # Check if images already exist + server_tags="$(gcloud_gcr_list_image_tags "${SERVER_IMAGE_NAME}" "${GIT_COMMIT}")" + printf "Server image: %s:%s\n" "${SERVER_IMAGE_NAME}" "${GIT_COMMIT}" + echo "${server_tags:-Server image not found}" + + client_tags="$(gcloud_gcr_list_image_tags "${CLIENT_IMAGE_NAME}" "${GIT_COMMIT}")" + printf "Client image: %s:%s\n" "${CLIENT_IMAGE_NAME}" "${GIT_COMMIT}" + echo "${client_tags:-Client image not found}" + + # Build if any of the images are missing, or FORCE_IMAGE_BUILD=1 + if [[ "${FORCE_IMAGE_BUILD}" == "1" || -z "${server_tags}" || -z "${client_tags}" ]]; then + build_test_app_docker_images + else + echo "Skipping Go test app build" + fi +} + +####################################### +# Executes the test case +# Globals: +# TEST_DRIVER_FLAGFILE: Relative path to test driver flagfile +# KUBE_CONTEXT: The name of kubectl context with GKE cluster access +# TEST_XML_OUTPUT_DIR: Output directory for the test xUnit XML report +# SERVER_IMAGE_NAME: Test server Docker image name +# CLIENT_IMAGE_NAME: Test client Docker image name +# GIT_COMMIT: SHA-1 of git commit being built +# Arguments: +# Test case name +# Outputs: +# Writes the output of test execution to stdout, stderr +# Test xUnit report to ${TEST_XML_OUTPUT_DIR}/${test_name}/sponge_log.xml +####################################### +run_test() { + # Test driver usage: + # https://github.com/grpc/grpc/tree/master/tools/run_tests/xds_k8s_test_driver#basic-usage + local test_name="${1:?Usage: run_test test_name}" + set -x + python -m "tests.${test_name}" \ + --flagfile="${TEST_DRIVER_FLAGFILE}" \ + --kube_context="${KUBE_CONTEXT}" \ + --server_image="${SERVER_IMAGE_NAME}:${GIT_COMMIT}" \ + --client_image="${CLIENT_IMAGE_NAME}:${GIT_COMMIT}" \ + --xml_output_file="${TEST_XML_OUTPUT_DIR}/${test_name}/sponge_log.xml" \ + --force_cleanup \ + --nocheck_local_certs + set +x 
+} + +####################################### +# Main function: provision software necessary to execute tests, and run them +# Globals: +# KOKORO_ARTIFACTS_DIR +# GITHUB_REPOSITORY_NAME +# SRC_DIR: Populated with absolute path to the source repo +# TEST_DRIVER_REPO_DIR: Populated with the path to the repo containing +# the test driver +# TEST_DRIVER_FULL_DIR: Populated with the path to the test driver source code +# TEST_DRIVER_FLAGFILE: Populated with relative path to test driver flagfile +# TEST_XML_OUTPUT_DIR: Populated with the path to test xUnit XML report +# GIT_ORIGIN_URL: Populated with the origin URL of git repo used for the build +# GIT_COMMIT: Populated with the SHA-1 of git commit being built +# GIT_COMMIT_SHORT: Populated with the short SHA-1 of git commit being built +# KUBE_CONTEXT: Populated with name of kubectl context with GKE cluster access +# Arguments: +# None +# Outputs: +# Writes the output of test execution to stdout, stderr +####################################### +main() { + local script_dir + script_dir="$(dirname "$0")" + + # Source the test driver from the master branch. + echo "Sourcing test driver install script from: ${TEST_DRIVER_INSTALL_SCRIPT_URL}" + source /dev/stdin <<< "$(curl -s "${TEST_DRIVER_INSTALL_SCRIPT_URL}")" + + activate_gke_cluster GKE_CLUSTER_PSM_SECURITY + + set -x + if [[ -n "${KOKORO_ARTIFACTS_DIR}" ]]; then + kokoro_setup_test_driver "${GITHUB_REPOSITORY_NAME}" + else + local_setup_test_driver "${script_dir}" + fi + build_docker_images_if_needed + # Run tests + cd "${TEST_DRIVER_FULL_DIR}" + run_test baseline_test + run_test security_test +} + +main "$@" diff --git a/test/kokoro/xds_url_map.cfg b/test/kokoro/xds_url_map.cfg new file mode 100644 index 00000000000..f6fd84a419a --- /dev/null +++ b/test/kokoro/xds_url_map.cfg @@ -0,0 +1,13 @@ +# Config file for internal CI + +# Location of the continuous shell script in repository. 
+build_file: "grpc-go/test/kokoro/xds_url_map.sh" +timeout_mins: 60 + +action { + define_artifacts { + regex: "artifacts/**/*sponge_log.xml" + regex: "artifacts/**/*sponge_log.log" + strip_prefix: "artifacts" + } +} diff --git a/test/kokoro/xds_url_map.sh b/test/kokoro/xds_url_map.sh new file mode 100755 index 00000000000..34805d43a13 --- /dev/null +++ b/test/kokoro/xds_url_map.sh @@ -0,0 +1,138 @@ +#!/usr/bin/env bash +# Copyright 2021 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +set -eo pipefail + +# Constants +readonly GITHUB_REPOSITORY_NAME="grpc-go" +readonly TEST_DRIVER_INSTALL_SCRIPT_URL="https://raw.githubusercontent.com/${TEST_DRIVER_REPO_OWNER:-grpc}/grpc/${TEST_DRIVER_BRANCH:-master}/tools/internal_ci/linux/grpc_xds_k8s_install_test_driver.sh" +## xDS test client Docker images +readonly CLIENT_IMAGE_NAME="gcr.io/grpc-testing/xds-interop/go-client" +readonly FORCE_IMAGE_BUILD="${FORCE_IMAGE_BUILD:-0}" + +####################################### +# Builds test app Docker images and pushes them to GCR +# Globals: +# CLIENT_IMAGE_NAME: Test client Docker image name +# GIT_COMMIT: SHA-1 of git commit being built +# Arguments: +# None +# Outputs: +# Writes the output of `gcloud builds submit` to stdout, stderr +####################################### +build_test_app_docker_images() { + echo "Building Go xDS interop test app Docker images" + docker build -f "${SRC_DIR}/interop/xds/client/Dockerfile" -t "${CLIENT_IMAGE_NAME}:${GIT_COMMIT}" "${SRC_DIR}" + gcloud -q auth configure-docker + docker push "${CLIENT_IMAGE_NAME}:${GIT_COMMIT}" +} + +####################################### +# Builds test app and its docker images unless they already exist +# Globals: +# CLIENT_IMAGE_NAME: Test client Docker image name +# GIT_COMMIT: SHA-1 of git commit being built +# FORCE_IMAGE_BUILD +# Arguments: +# None +# Outputs: +# Writes the output to stdout, stderr +####################################### +build_docker_images_if_needed() { + # Check if images already exist + client_tags="$(gcloud_gcr_list_image_tags "${CLIENT_IMAGE_NAME}" "${GIT_COMMIT}")" + printf "Client image: %s:%s\n" "${CLIENT_IMAGE_NAME}" "${GIT_COMMIT}" + echo "${client_tags:-Client image not found}" + + # Build if any of the images are missing, or FORCE_IMAGE_BUILD=1 + if [[ "${FORCE_IMAGE_BUILD}" == "1" || -z "${client_tags}" ]]; then + build_test_app_docker_images + else + echo "Skipping Go test app build" + fi +} + +####################################### +# Executes the test 
case +# Globals: +# TEST_DRIVER_FLAGFILE: Relative path to test driver flagfile +# KUBE_CONTEXT: The name of kubectl context with GKE cluster access +# TEST_XML_OUTPUT_DIR: Output directory for the test xUnit XML report +# CLIENT_IMAGE_NAME: Test client Docker image name +# GIT_COMMIT: SHA-1 of git commit being built +# Arguments: +# Test case name +# Outputs: +# Writes the output of test execution to stdout, stderr +# Test xUnit report to ${TEST_XML_OUTPUT_DIR}/${test_name}/sponge_log.xml +####################################### +run_test() { + # Test driver usage: + # https://github.com/grpc/grpc/tree/master/tools/run_tests/xds_k8s_test_driver#basic-usage + local test_name="${1:?Usage: run_test test_name}" + set -x + python -m "tests.${test_name}" \ + --flagfile="${TEST_DRIVER_FLAGFILE}" \ + --kube_context="${KUBE_CONTEXT}" \ + --client_image="${CLIENT_IMAGE_NAME}:${GIT_COMMIT}" \ + --testing_version=$(echo "$KOKORO_JOB_NAME" | sed -E 's|^grpc/go/([^/]+)/.*|\1|') \ + --xml_output_file="${TEST_XML_OUTPUT_DIR}/${test_name}/sponge_log.xml" \ + --flagfile="config/url-map.cfg" + set +x +} + +####################################### +# Main function: provision software necessary to execute tests, and run them +# Globals: +# KOKORO_ARTIFACTS_DIR +# GITHUB_REPOSITORY_NAME +# SRC_DIR: Populated with absolute path to the source repo +# TEST_DRIVER_REPO_DIR: Populated with the path to the repo containing +# the test driver +# TEST_DRIVER_FULL_DIR: Populated with the path to the test driver source code +# TEST_DRIVER_FLAGFILE: Populated with relative path to test driver flagfile +# TEST_XML_OUTPUT_DIR: Populated with the path to test xUnit XML report +# GIT_ORIGIN_URL: Populated with the origin URL of git repo used for the build +# GIT_COMMIT: Populated with the SHA-1 of git commit being built +# GIT_COMMIT_SHORT: Populated with the short SHA-1 of git commit being built +# KUBE_CONTEXT: Populated with name of kubectl context with GKE cluster access +# Arguments: +# None +# 
Outputs: +# Writes the output of test execution to stdout, stderr +####################################### +main() { + local script_dir + script_dir="$(dirname "$0")" + + # Source the test driver from the master branch. + echo "Sourcing test driver install script from: ${TEST_DRIVER_INSTALL_SCRIPT_URL}" + source /dev/stdin <<< "$(curl -s "${TEST_DRIVER_INSTALL_SCRIPT_URL}")" + + activate_gke_cluster GKE_CLUSTER_PSM_BASIC + + set -x + if [[ -n "${KOKORO_ARTIFACTS_DIR}" ]]; then + kokoro_setup_test_driver "${GITHUB_REPOSITORY_NAME}" + else + local_setup_test_driver "${script_dir}" + fi + build_docker_images_if_needed + # Run tests + cd "${TEST_DRIVER_FULL_DIR}" + run_test url_map +} + +main "$@" diff --git a/test/kokoro/xds_v3.cfg b/test/kokoro/xds_v3.cfg index c4c8aad9e6f..1991efd325d 100644 --- a/test/kokoro/xds_v3.cfg +++ b/test/kokoro/xds_v3.cfg @@ -2,7 +2,7 @@ # Location of the continuous shell script in repository. build_file: "grpc-go/test/kokoro/xds_v3.sh" -timeout_mins: 120 +timeout_mins: 360 action { define_artifacts { regex: "**/*sponge_log.*" diff --git a/test/race.go b/test/race.go index acfa0dfae37..d99f0a410ac 100644 --- a/test/race.go +++ b/test/race.go @@ -1,3 +1,4 @@ +//go:build race // +build race /* diff --git a/test/retry_test.go b/test/retry_test.go index f93c9ac053f..7f068d79f44 100644 --- a/test/retry_test.go +++ b/test/retry_test.go @@ -22,9 +22,11 @@ import ( "context" "fmt" "io" - "os" + "net" + "reflect" "strconv" "strings" + "sync" "testing" "time" @@ -34,6 +36,7 @@ import ( "google.golang.org/grpc/internal/envconfig" "google.golang.org/grpc/internal/stubserver" "google.golang.org/grpc/metadata" + "google.golang.org/grpc/stats" "google.golang.org/grpc/status" testpb "google.golang.org/grpc/test/grpc_testing" ) @@ -112,68 +115,6 @@ func (s) TestRetryUnary(t *testing.T) { } } -func (s) TestRetryDisabledByDefault(t *testing.T) { - if strings.EqualFold(os.Getenv("GRPC_GO_RETRY"), "on") { - return - } - i := -1 - ss := &stubserver.StubServer{ 
- EmptyCallF: func(context.Context, *testpb.Empty) (*testpb.Empty, error) { - i++ - switch i { - case 0: - return nil, status.New(codes.AlreadyExists, "retryable error").Err() - } - return &testpb.Empty{}, nil - }, - } - if err := ss.Start([]grpc.ServerOption{}); err != nil { - t.Fatalf("Error starting endpoint server: %v", err) - } - defer ss.Stop() - ss.NewServiceConfig(`{ - "methodConfig": [{ - "name": [{"service": "grpc.testing.TestService"}], - "waitForReady": true, - "retryPolicy": { - "MaxAttempts": 4, - "InitialBackoff": ".01s", - "MaxBackoff": ".01s", - "BackoffMultiplier": 1.0, - "RetryableStatusCodes": [ "ALREADY_EXISTS" ] - } - }]}`) - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) - for { - if ctx.Err() != nil { - t.Fatalf("Timed out waiting for service config update") - } - if ss.CC.GetMethodConfig("/grpc.testing.TestService/EmptyCall").WaitForReady != nil { - break - } - time.Sleep(time.Millisecond) - } - cancel() - - testCases := []struct { - code codes.Code - count int - }{ - {codes.AlreadyExists, 0}, - } - for _, tc := range testCases { - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}) - cancel() - if status.Code(err) != tc.code { - t.Fatalf("EmptyCall(_, _) = _, %v; want _, ", err, tc.code) - } - if i != tc.count { - t.Fatalf("i = %v; want %v", i, tc.count) - } - } -} - func (s) TestRetryThrottling(t *testing.T) { defer enableRetry()() i := -1 @@ -549,3 +490,167 @@ func (s) TestRetryStreaming(t *testing.T) { }() } } + +type retryStatsHandler struct { + mu sync.Mutex + s []stats.RPCStats +} + +func (*retryStatsHandler) TagRPC(ctx context.Context, _ *stats.RPCTagInfo) context.Context { + return ctx +} +func (h *retryStatsHandler) HandleRPC(_ context.Context, s stats.RPCStats) { + h.mu.Lock() + h.s = append(h.s, s) + h.mu.Unlock() +} +func (*retryStatsHandler) TagConn(ctx context.Context, _ *stats.ConnTagInfo) context.Context { + return ctx +} 
+func (*retryStatsHandler) HandleConn(context.Context, stats.ConnStats) {} + +func (s) TestRetryStats(t *testing.T) { + defer enableRetry()() + lis, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatalf("Failed to listen. Err: %v", err) + } + defer lis.Close() + server := &httpServer{ + waitForEndStream: true, + responses: []httpServerResponse{{ + trailers: [][]string{{ + ":status", "200", + "content-type", "application/grpc", + "grpc-status", "14", // UNAVAILABLE + "grpc-message", "unavailable retry", + "grpc-retry-pushback-ms", "10", + }}, + }, { + headers: [][]string{{ + ":status", "200", + "content-type", "application/grpc", + }}, + payload: []byte{0, 0, 0, 0, 0}, // header for 0-byte response message. + trailers: [][]string{{ + "grpc-status", "0", // OK + }}, + }}, + refuseStream: func(i uint32) bool { + return i == 1 + }, + } + server.start(t, lis) + handler := &retryStatsHandler{} + cc, err := grpc.Dial(lis.Addr().String(), grpc.WithInsecure(), grpc.WithStatsHandler(handler), + grpc.WithDefaultServiceConfig((`{ + "methodConfig": [{ + "name": [{"service": "grpc.testing.TestService"}], + "retryPolicy": { + "MaxAttempts": 4, + "InitialBackoff": ".01s", + "MaxBackoff": ".01s", + "BackoffMultiplier": 1.0, + "RetryableStatusCodes": [ "UNAVAILABLE" ] + } + }]}`))) + if err != nil { + t.Fatalf("failed to dial due to err: %v", err) + } + defer cc.Close() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + client := testpb.NewTestServiceClient(cc) + + if _, err := client.EmptyCall(ctx, &testpb.Empty{}); err != nil { + t.Fatalf("unexpected EmptyCall error: %v", err) + } + handler.mu.Lock() + want := []stats.RPCStats{ + &stats.Begin{}, + &stats.OutHeader{FullMethod: "/grpc.testing.TestService/EmptyCall"}, + &stats.OutPayload{WireLength: 5}, + &stats.End{}, + + &stats.Begin{IsTransparentRetryAttempt: true}, + &stats.OutHeader{FullMethod: "/grpc.testing.TestService/EmptyCall"}, + 
&stats.OutPayload{WireLength: 5}, + &stats.InTrailer{Trailer: metadata.Pairs("content-type", "application/grpc", "grpc-retry-pushback-ms", "10")}, + &stats.End{}, + + &stats.Begin{}, + &stats.OutHeader{FullMethod: "/grpc.testing.TestService/EmptyCall"}, + &stats.OutPayload{WireLength: 5}, + &stats.InHeader{}, + &stats.InPayload{WireLength: 5}, + &stats.InTrailer{}, + &stats.End{}, + } + + toString := func(ss []stats.RPCStats) (ret []string) { + for _, s := range ss { + ret = append(ret, fmt.Sprintf("%T - %v", s, s)) + } + return ret + } + t.Logf("Handler received frames:\n%v\n---\nwant:\n%v\n", + strings.Join(toString(handler.s), "\n"), + strings.Join(toString(want), "\n")) + + if len(handler.s) != len(want) { + t.Fatalf("received unexpected number of RPCStats: got %v; want %v", len(handler.s), len(want)) + } + + // There is a race between receiving the payload (triggered by the + // application / gRPC library) and receiving the trailer (triggered at the + // transport layer). Adjust the received stats accordingly if necessary. 
+ const tIdx, pIdx = 13, 14 + _, okT := handler.s[tIdx].(*stats.InTrailer) + _, okP := handler.s[pIdx].(*stats.InPayload) + if okT && okP { + handler.s[pIdx], handler.s[tIdx] = handler.s[tIdx], handler.s[pIdx] + } + + for i := range handler.s { + w, s := want[i], handler.s[i] + + // Validate the event type + if reflect.TypeOf(w) != reflect.TypeOf(s) { + t.Fatalf("at position %v: got %T; want %T", i, s, w) + } + wv, sv := reflect.ValueOf(w).Elem(), reflect.ValueOf(s).Elem() + + // Validate that Client is always true + if sv.FieldByName("Client").Interface().(bool) != true { + t.Fatalf("at position %v: got Client=false; want true", i) + } + + // Validate any set fields in want + for i := 0; i < wv.NumField(); i++ { + if !wv.Field(i).IsZero() { + if got, want := sv.Field(i).Interface(), wv.Field(i).Interface(); !reflect.DeepEqual(got, want) { + name := reflect.TypeOf(w).Elem().Field(i).Name + t.Fatalf("at position %v, field %v: got %v; want %v", i, name, got, want) + } + } + } + + // Since the above only tests non-zero-value fields, test + // IsTransparentRetryAttempt=false explicitly when needed. + if wb, ok := w.(*stats.Begin); ok && !wb.IsTransparentRetryAttempt { + if s.(*stats.Begin).IsTransparentRetryAttempt { + t.Fatalf("at position %v: got IsTransparentRetryAttempt=true; want false", i) + } + } + } + + // Validate timings between last Begin and preceding End. 
+ end := handler.s[8].(*stats.End) + begin := handler.s[9].(*stats.Begin) + diff := begin.BeginTime.Sub(end.EndTime) + if diff < 10*time.Millisecond || diff > 50*time.Millisecond { + t.Fatalf("pushback time before final attempt = %v; want ~10ms", diff) + } +} diff --git a/test/tools/go.mod b/test/tools/go.mod index 874268d34fc..9c964971413 100644 --- a/test/tools/go.mod +++ b/test/tools/go.mod @@ -1,6 +1,6 @@ module google.golang.org/grpc/test/tools -go 1.11 +go 1.14 require ( github.com/client9/misspell v0.3.4 diff --git a/test/tools/tools.go b/test/tools/tools.go index 511dc253446..646a144ccca 100644 --- a/test/tools/tools.go +++ b/test/tools/tools.go @@ -1,3 +1,4 @@ +//go:build tools // +build tools /* @@ -18,10 +19,9 @@ * */ -// This package exists to cause `go mod` and `go get` to believe these tools -// are dependencies, even though they are not runtime dependencies of any grpc -// package. This means they will appear in our `go.mod` file, but will not be -// a part of the build. +// This file is not intended to be compiled. Because some of these imports are +// not actual go packages, we use a build constraint at the top of this file to +// prevent tools from inspecting the imports. package tools diff --git a/xds/go113.go b/test/tools/tools_vet.go similarity index 75% rename from xds/go113.go rename to test/tools/tools_vet.go index 40f82cde5c1..06ab2fd10be 100644 --- a/xds/go113.go +++ b/test/tools/tools_vet.go @@ -1,8 +1,6 @@ -// +build go1.13 - /* * - * Copyright 2020 gRPC authors. + * Copyright 2021 gRPC authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,8 +16,6 @@ * */ -package xds - -import ( - _ "google.golang.org/grpc/credentials/tls/certprovider/meshca" // Register the MeshCA certificate provider plugin. -) +// Package tools is used to pin specific versions of external tools in this +// module's go.mod that gRPC uses for internal testing. 
+package tools diff --git a/version.go b/version.go index c149b22ad8a..6ba1fd5bdb7 100644 --- a/version.go +++ b/version.go @@ -19,4 +19,4 @@ package grpc // Version is the current grpc version. -const Version = "1.37.0-dev" +const Version = "1.42.0-dev" diff --git a/vet.sh b/vet.sh index dcd939bb390..d923187a7b3 100755 --- a/vet.sh +++ b/vet.sh @@ -32,26 +32,14 @@ PATH="${HOME}/go/bin:${GOROOT}/bin:${PATH}" go version if [[ "$1" = "-install" ]]; then - # Check for module support - if go help mod >& /dev/null; then - # Install the pinned versions as defined in module tools. - pushd ./test/tools - go install \ - golang.org/x/lint/golint \ - golang.org/x/tools/cmd/goimports \ - honnef.co/go/tools/cmd/staticcheck \ - github.com/client9/misspell/cmd/misspell - popd - else - # Ye olde `go get` incantation. - # Note: this gets the latest version of all tools (vs. the pinned versions - # with Go modules). - go get -u \ - golang.org/x/lint/golint \ - golang.org/x/tools/cmd/goimports \ - honnef.co/go/tools/cmd/staticcheck \ - github.com/client9/misspell/cmd/misspell - fi + # Install the pinned versions as defined in module tools. + pushd ./test/tools + go install \ + golang.org/x/lint/golint \ + golang.org/x/tools/cmd/goimports \ + honnef.co/go/tools/cmd/staticcheck \ + github.com/client9/misspell/cmd/misspell + popd if [[ -z "${VET_SKIP_PROTO}" ]]; then if [[ "${TRAVIS}" = "true" ]]; then PROTOBUF_VERSION=3.14.0 @@ -101,16 +89,6 @@ not git grep "\(import \|^\s*\)\"github.com/golang/protobuf/ptypes/" -- "*.go" # - Ensure all xds proto imports are renamed to *pb or *grpc. git grep '"github.com/envoyproxy/go-control-plane/envoy' -- '*.go' ':(exclude)*.pb.go' | not grep -v 'pb "\|grpc "' -# - Check imports that are illegal in appengine (until Go 1.11). -# TODO: Remove when we drop Go 1.10 support -go list -f {{.Dir}} ./... | xargs go run test/go_vet/vet.go - -# - gofmt, goimports, golint (with exceptions for generated code), go vet. -gofmt -s -d -l . 
2>&1 | fail_on_output -goimports -l . 2>&1 | not grep -vE "\.pb\.go" -golint ./... 2>&1 | not grep -vE "/testv3\.pb\.go:" -go vet -all ./... - misspell -error . # - Check that generated proto files are up to date. @@ -120,12 +98,22 @@ if [[ -z "${VET_SKIP_PROTO}" ]]; then (git status; git --no-pager diff; exit 1) fi -# - Check that our modules are tidy. -if go help mod >& /dev/null; then - find . -name 'go.mod' | xargs -IXXX bash -c 'cd $(dirname XXX); go mod tidy' +# - gofmt, goimports, golint (with exceptions for generated code), go vet, +# go mod tidy. +# Perform these checks on each module inside gRPC. +for MOD_FILE in $(find . -name 'go.mod'); do + MOD_DIR=$(dirname ${MOD_FILE}) + pushd ${MOD_DIR} + go vet -all ./... | fail_on_output + gofmt -s -d -l . 2>&1 | fail_on_output + goimports -l . 2>&1 | not grep -vE "\.pb\.go" + golint ./... 2>&1 | not grep -vE "/testv3\.pb\.go:" + + go mod tidy git status --porcelain 2>&1 | fail_on_output || \ (git status; git --no-pager diff; exit 1) -fi + popd +done # - Collection of static analysis checks # diff --git a/xds/csds/csds.go b/xds/csds/csds.go new file mode 100644 index 00000000000..c4477a55d1a --- /dev/null +++ b/xds/csds/csds.go @@ -0,0 +1,305 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// Package csds implements features to dump the status (xDS responses) the +// xds_client is using. 
+// +// Notice: This package is EXPERIMENTAL and may be changed or removed in a later +// release. +package csds + +import ( + "context" + "io" + "time" + + v3adminpb "github.com/envoyproxy/go-control-plane/envoy/admin/v3" + v2corepb "github.com/envoyproxy/go-control-plane/envoy/api/v2/core" + v3corepb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" + v3statusgrpc "github.com/envoyproxy/go-control-plane/envoy/service/status/v3" + v3statuspb "github.com/envoyproxy/go-control-plane/envoy/service/status/v3" + "github.com/golang/protobuf/proto" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/grpclog" + "google.golang.org/grpc/status" + "google.golang.org/grpc/xds/internal/xdsclient" + "google.golang.org/protobuf/types/known/timestamppb" + + _ "google.golang.org/grpc/xds/internal/xdsclient/v2" // Register v2 xds_client. + _ "google.golang.org/grpc/xds/internal/xdsclient/v3" // Register v3 xds_client. +) + +var ( + logger = grpclog.Component("xds") + newXDSClient = func() xdsclient.XDSClient { + c, err := xdsclient.New() + if err != nil { + logger.Warningf("failed to create xds client: %v", err) + return nil + } + return c + } +) + +// ClientStatusDiscoveryServer implementations interface ClientStatusDiscoveryServiceServer. +type ClientStatusDiscoveryServer struct { + // xdsClient will always be the same in practice. But we keep a copy in each + // server instance for testing. + xdsClient xdsclient.XDSClient +} + +// NewClientStatusDiscoveryServer returns an implementation of the CSDS server that can be +// registered on a gRPC server. +func NewClientStatusDiscoveryServer() (*ClientStatusDiscoveryServer, error) { + return &ClientStatusDiscoveryServer{xdsClient: newXDSClient()}, nil +} + +// StreamClientStatus implementations interface ClientStatusDiscoveryServiceServer. 
+func (s *ClientStatusDiscoveryServer) StreamClientStatus(stream v3statusgrpc.ClientStatusDiscoveryService_StreamClientStatusServer) error { + for { + req, err := stream.Recv() + if err == io.EOF { + return nil + } + if err != nil { + return err + } + resp, err := s.buildClientStatusRespForReq(req) + if err != nil { + return err + } + if err := stream.Send(resp); err != nil { + return err + } + } +} + +// FetchClientStatus implementations interface ClientStatusDiscoveryServiceServer. +func (s *ClientStatusDiscoveryServer) FetchClientStatus(_ context.Context, req *v3statuspb.ClientStatusRequest) (*v3statuspb.ClientStatusResponse, error) { + return s.buildClientStatusRespForReq(req) +} + +// buildClientStatusRespForReq fetches the status from the client, and returns +// the response to be sent back to xdsclient. +// +// If it returns an error, the error is a status error. +func (s *ClientStatusDiscoveryServer) buildClientStatusRespForReq(req *v3statuspb.ClientStatusRequest) (*v3statuspb.ClientStatusResponse, error) { + if s.xdsClient == nil { + return &v3statuspb.ClientStatusResponse{}, nil + } + // Field NodeMatchers is unsupported, by design + // https://github.com/grpc/proposal/blob/master/A40-csds-support.md#detail-node-matching. + if len(req.NodeMatchers) != 0 { + return nil, status.Errorf(codes.InvalidArgument, "node_matchers are not supported, request contains node_matchers: %v", req.NodeMatchers) + } + + ret := &v3statuspb.ClientStatusResponse{ + Config: []*v3statuspb.ClientConfig{ + { + Node: nodeProtoToV3(s.xdsClient.BootstrapConfig().NodeProto), + XdsConfig: []*v3statuspb.PerXdsConfig{ + s.buildLDSPerXDSConfig(), + s.buildRDSPerXDSConfig(), + s.buildCDSPerXDSConfig(), + s.buildEDSPerXDSConfig(), + }, + }, + }, + } + return ret, nil +} + +// Close cleans up the resources. +func (s *ClientStatusDiscoveryServer) Close() { + if s.xdsClient != nil { + s.xdsClient.Close() + } +} + +// nodeProtoToV3 converts the given proto into a v3.Node. 
n is from bootstrap +// config, it can be either v2.Node or v3.Node. +// +// If n is already a v3.Node, return it. +// If n is v2.Node, marshal and unmarshal it to v3. +// Otherwise, return nil. +// +// The default case (not v2 or v3) is nil, instead of error, because the +// resources in the response are more important than the node. The worst case is +// that the user will receive no Node info, but will still get resources. +func nodeProtoToV3(n proto.Message) *v3corepb.Node { + var node *v3corepb.Node + switch nn := n.(type) { + case *v3corepb.Node: + node = nn + case *v2corepb.Node: + v2, err := proto.Marshal(nn) + if err != nil { + logger.Warningf("Failed to marshal node (%v): %v", n, err) + break + } + node = new(v3corepb.Node) + if err := proto.Unmarshal(v2, node); err != nil { + logger.Warningf("Failed to unmarshal node (%v): %v", v2, err) + } + default: + logger.Warningf("node from bootstrap is %#v, only v2.Node and v3.Node are supported", nn) + } + return node +} + +func (s *ClientStatusDiscoveryServer) buildLDSPerXDSConfig() *v3statuspb.PerXdsConfig { + version, dump := s.xdsClient.DumpLDS() + resources := make([]*v3adminpb.ListenersConfigDump_DynamicListener, 0, len(dump)) + for name, d := range dump { + configDump := &v3adminpb.ListenersConfigDump_DynamicListener{ + Name: name, + ClientStatus: serviceStatusToProto(d.MD.Status), + } + if (d.MD.Timestamp != time.Time{}) { + configDump.ActiveState = &v3adminpb.ListenersConfigDump_DynamicListenerState{ + VersionInfo: d.MD.Version, + Listener: d.Raw, + LastUpdated: timestamppb.New(d.MD.Timestamp), + } + } + if errState := d.MD.ErrState; errState != nil { + configDump.ErrorState = &v3adminpb.UpdateFailureState{ + LastUpdateAttempt: timestamppb.New(errState.Timestamp), + Details: errState.Err.Error(), + VersionInfo: errState.Version, + } + } + resources = append(resources, configDump) + } + return &v3statuspb.PerXdsConfig{ + PerXdsConfig: &v3statuspb.PerXdsConfig_ListenerConfig{ + ListenerConfig: 
&v3adminpb.ListenersConfigDump{ + VersionInfo: version, + DynamicListeners: resources, + }, + }, + } +} + +func (s *ClientStatusDiscoveryServer) buildRDSPerXDSConfig() *v3statuspb.PerXdsConfig { + _, dump := s.xdsClient.DumpRDS() + resources := make([]*v3adminpb.RoutesConfigDump_DynamicRouteConfig, 0, len(dump)) + for _, d := range dump { + configDump := &v3adminpb.RoutesConfigDump_DynamicRouteConfig{ + VersionInfo: d.MD.Version, + ClientStatus: serviceStatusToProto(d.MD.Status), + } + if (d.MD.Timestamp != time.Time{}) { + configDump.RouteConfig = d.Raw + configDump.LastUpdated = timestamppb.New(d.MD.Timestamp) + } + if errState := d.MD.ErrState; errState != nil { + configDump.ErrorState = &v3adminpb.UpdateFailureState{ + LastUpdateAttempt: timestamppb.New(errState.Timestamp), + Details: errState.Err.Error(), + VersionInfo: errState.Version, + } + } + resources = append(resources, configDump) + } + return &v3statuspb.PerXdsConfig{ + PerXdsConfig: &v3statuspb.PerXdsConfig_RouteConfig{ + RouteConfig: &v3adminpb.RoutesConfigDump{ + DynamicRouteConfigs: resources, + }, + }, + } +} + +func (s *ClientStatusDiscoveryServer) buildCDSPerXDSConfig() *v3statuspb.PerXdsConfig { + version, dump := s.xdsClient.DumpCDS() + resources := make([]*v3adminpb.ClustersConfigDump_DynamicCluster, 0, len(dump)) + for _, d := range dump { + configDump := &v3adminpb.ClustersConfigDump_DynamicCluster{ + VersionInfo: d.MD.Version, + ClientStatus: serviceStatusToProto(d.MD.Status), + } + if (d.MD.Timestamp != time.Time{}) { + configDump.Cluster = d.Raw + configDump.LastUpdated = timestamppb.New(d.MD.Timestamp) + } + if errState := d.MD.ErrState; errState != nil { + configDump.ErrorState = &v3adminpb.UpdateFailureState{ + LastUpdateAttempt: timestamppb.New(errState.Timestamp), + Details: errState.Err.Error(), + VersionInfo: errState.Version, + } + } + resources = append(resources, configDump) + } + return &v3statuspb.PerXdsConfig{ + PerXdsConfig: &v3statuspb.PerXdsConfig_ClusterConfig{ + 
ClusterConfig: &v3adminpb.ClustersConfigDump{ + VersionInfo: version, + DynamicActiveClusters: resources, + }, + }, + } +} + +func (s *ClientStatusDiscoveryServer) buildEDSPerXDSConfig() *v3statuspb.PerXdsConfig { + _, dump := s.xdsClient.DumpEDS() + resources := make([]*v3adminpb.EndpointsConfigDump_DynamicEndpointConfig, 0, len(dump)) + for _, d := range dump { + configDump := &v3adminpb.EndpointsConfigDump_DynamicEndpointConfig{ + VersionInfo: d.MD.Version, + ClientStatus: serviceStatusToProto(d.MD.Status), + } + if (d.MD.Timestamp != time.Time{}) { + configDump.EndpointConfig = d.Raw + configDump.LastUpdated = timestamppb.New(d.MD.Timestamp) + } + if errState := d.MD.ErrState; errState != nil { + configDump.ErrorState = &v3adminpb.UpdateFailureState{ + LastUpdateAttempt: timestamppb.New(errState.Timestamp), + Details: errState.Err.Error(), + VersionInfo: errState.Version, + } + } + resources = append(resources, configDump) + } + return &v3statuspb.PerXdsConfig{ + PerXdsConfig: &v3statuspb.PerXdsConfig_EndpointConfig{ + EndpointConfig: &v3adminpb.EndpointsConfigDump{ + DynamicEndpointConfigs: resources, + }, + }, + } +} + +func serviceStatusToProto(serviceStatus xdsclient.ServiceStatus) v3adminpb.ClientResourceStatus { + switch serviceStatus { + case xdsclient.ServiceStatusUnknown: + return v3adminpb.ClientResourceStatus_UNKNOWN + case xdsclient.ServiceStatusRequested: + return v3adminpb.ClientResourceStatus_REQUESTED + case xdsclient.ServiceStatusNotExist: + return v3adminpb.ClientResourceStatus_DOES_NOT_EXIST + case xdsclient.ServiceStatusACKed: + return v3adminpb.ClientResourceStatus_ACKED + case xdsclient.ServiceStatusNACKed: + return v3adminpb.ClientResourceStatus_NACKED + default: + return v3adminpb.ClientResourceStatus_UNKNOWN + } +} diff --git a/xds/csds/csds_test.go b/xds/csds/csds_test.go new file mode 100644 index 00000000000..9de83d37fec --- /dev/null +++ b/xds/csds/csds_test.go @@ -0,0 +1,739 @@ +/* + * + * Copyright 2021 gRPC authors. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package csds + +import ( + "context" + "fmt" + "strings" + "testing" + "time" + + "github.com/golang/protobuf/jsonpb" + "github.com/golang/protobuf/proto" + "github.com/golang/protobuf/ptypes" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "github.com/google/uuid" + "google.golang.org/grpc" + "google.golang.org/grpc/internal/testutils" + "google.golang.org/grpc/internal/xds" + _ "google.golang.org/grpc/xds/internal/httpfilter/router" + xtestutils "google.golang.org/grpc/xds/internal/testutils" + "google.golang.org/grpc/xds/internal/testutils/e2e" + "google.golang.org/grpc/xds/internal/xdsclient" + "google.golang.org/protobuf/testing/protocmp" + "google.golang.org/protobuf/types/known/anypb" + "google.golang.org/protobuf/types/known/timestamppb" + + v3adminpb "github.com/envoyproxy/go-control-plane/envoy/admin/v3" + v2corepb "github.com/envoyproxy/go-control-plane/envoy/api/v2/core" + v3clusterpb "github.com/envoyproxy/go-control-plane/envoy/config/cluster/v3" + v3corepb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" + v3endpointpb "github.com/envoyproxy/go-control-plane/envoy/config/endpoint/v3" + v3listenerpb "github.com/envoyproxy/go-control-plane/envoy/config/listener/v3" + v3routepb "github.com/envoyproxy/go-control-plane/envoy/config/route/v3" + v3statuspb "github.com/envoyproxy/go-control-plane/envoy/service/status/v3" + v3statuspbgrpc 
"github.com/envoyproxy/go-control-plane/envoy/service/status/v3" +) + +const ( + defaultTestTimeout = 10 * time.Second +) + +var cmpOpts = cmp.Options{ + cmpopts.EquateEmpty(), + cmp.Comparer(func(a, b *timestamppb.Timestamp) bool { return true }), + protocmp.IgnoreFields(&v3adminpb.UpdateFailureState{}, "last_update_attempt", "details"), + protocmp.SortRepeated(func(a, b *v3adminpb.ListenersConfigDump_DynamicListener) bool { + return strings.Compare(a.Name, b.Name) < 0 + }), + protocmp.SortRepeated(func(a, b *v3adminpb.RoutesConfigDump_DynamicRouteConfig) bool { + if a.RouteConfig == nil { + return false + } + if b.RouteConfig == nil { + return true + } + var at, bt v3routepb.RouteConfiguration + if err := ptypes.UnmarshalAny(a.RouteConfig, &at); err != nil { + panic("failed to unmarshal RouteConfig" + err.Error()) + } + if err := ptypes.UnmarshalAny(b.RouteConfig, &bt); err != nil { + panic("failed to unmarshal RouteConfig" + err.Error()) + } + return strings.Compare(at.Name, bt.Name) < 0 + }), + protocmp.SortRepeated(func(a, b *v3adminpb.ClustersConfigDump_DynamicCluster) bool { + if a.Cluster == nil { + return false + } + if b.Cluster == nil { + return true + } + var at, bt v3clusterpb.Cluster + if err := ptypes.UnmarshalAny(a.Cluster, &at); err != nil { + panic("failed to unmarshal Cluster" + err.Error()) + } + if err := ptypes.UnmarshalAny(b.Cluster, &bt); err != nil { + panic("failed to unmarshal Cluster" + err.Error()) + } + return strings.Compare(at.Name, bt.Name) < 0 + }), + protocmp.SortRepeated(func(a, b *v3adminpb.EndpointsConfigDump_DynamicEndpointConfig) bool { + if a.EndpointConfig == nil { + return false + } + if b.EndpointConfig == nil { + return true + } + var at, bt v3endpointpb.ClusterLoadAssignment + if err := ptypes.UnmarshalAny(a.EndpointConfig, &at); err != nil { + panic("failed to unmarshal Endpoints" + err.Error()) + } + if err := ptypes.UnmarshalAny(b.EndpointConfig, &bt); err != nil { + panic("failed to unmarshal Endpoints" + 
err.Error()) + } + return strings.Compare(at.ClusterName, bt.ClusterName) < 0 + }), + protocmp.IgnoreFields(&v3adminpb.ListenersConfigDump_DynamicListenerState{}, "last_updated"), + protocmp.IgnoreFields(&v3adminpb.RoutesConfigDump_DynamicRouteConfig{}, "last_updated"), + protocmp.IgnoreFields(&v3adminpb.ClustersConfigDump_DynamicCluster{}, "last_updated"), + protocmp.IgnoreFields(&v3adminpb.EndpointsConfigDump_DynamicEndpointConfig{}, "last_updated"), + protocmp.Transform(), +} + +var ( + ldsTargets = []string{"lds.target.good:0000", "lds.target.good:1111"} + listeners = make([]*v3listenerpb.Listener, len(ldsTargets)) + listenerAnys = make([]*anypb.Any, len(ldsTargets)) + + rdsTargets = []string{"route-config-0", "route-config-1"} + routes = make([]*v3routepb.RouteConfiguration, len(rdsTargets)) + routeAnys = make([]*anypb.Any, len(rdsTargets)) + + cdsTargets = []string{"cluster-0", "cluster-1"} + clusters = make([]*v3clusterpb.Cluster, len(cdsTargets)) + clusterAnys = make([]*anypb.Any, len(cdsTargets)) + + edsTargets = []string{"endpoints-0", "endpoints-1"} + endpoints = make([]*v3endpointpb.ClusterLoadAssignment, len(edsTargets)) + endpointAnys = make([]*anypb.Any, len(edsTargets)) + ips = []string{"0.0.0.0", "1.1.1.1"} + ports = []uint32{123, 456} +) + +func init() { + for i := range ldsTargets { + listeners[i] = e2e.DefaultClientListener(ldsTargets[i], rdsTargets[i]) + listenerAnys[i] = testutils.MarshalAny(listeners[i]) + } + for i := range rdsTargets { + routes[i] = e2e.DefaultRouteConfig(rdsTargets[i], ldsTargets[i], cdsTargets[i]) + routeAnys[i] = testutils.MarshalAny(routes[i]) + } + for i := range cdsTargets { + clusters[i] = e2e.DefaultCluster(cdsTargets[i], edsTargets[i], e2e.SecurityLevelNone) + clusterAnys[i] = testutils.MarshalAny(clusters[i]) + } + for i := range edsTargets { + endpoints[i] = e2e.DefaultEndpoint(edsTargets[i], ips[i], ports[i:i+1]) + endpointAnys[i] = testutils.MarshalAny(endpoints[i]) + } +} + +func TestCSDS(t *testing.T) { + 
const retryCount = 10 + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + xdsC, mgmServer, nodeID, stream, cleanup := commonSetup(ctx, t) + defer cleanup() + + for _, target := range ldsTargets { + xdsC.WatchListener(target, func(xdsclient.ListenerUpdate, error) {}) + } + for _, target := range rdsTargets { + xdsC.WatchRouteConfig(target, func(xdsclient.RouteConfigUpdate, error) {}) + } + for _, target := range cdsTargets { + xdsC.WatchCluster(target, func(xdsclient.ClusterUpdate, error) {}) + } + for _, target := range edsTargets { + xdsC.WatchEndpoints(target, func(xdsclient.EndpointsUpdate, error) {}) + } + + for i := 0; i < retryCount; i++ { + err := checkForRequested(stream) + if err == nil { + break + } + if i == retryCount-1 { + t.Fatalf("%v", err) + } + time.Sleep(time.Millisecond * 100) + } + + if err := mgmServer.Update(ctx, e2e.UpdateOptions{ + NodeID: nodeID, + Listeners: listeners, + Routes: routes, + Clusters: clusters, + Endpoints: endpoints, + }); err != nil { + t.Fatal(err) + } + for i := 0; i < retryCount; i++ { + err := checkForACKed(stream) + if err == nil { + break + } + if i == retryCount-1 { + t.Fatalf("%v", err) + } + time.Sleep(time.Millisecond * 100) + } + + const nackResourceIdx = 0 + var ( + nackListeners = append([]*v3listenerpb.Listener{}, listeners...) + nackRoutes = append([]*v3routepb.RouteConfiguration{}, routes...) + nackClusters = append([]*v3clusterpb.Cluster{}, clusters...) + nackEndpoints = append([]*v3endpointpb.ClusterLoadAssignment{}, endpoints...) + ) + nackListeners[0] = &v3listenerpb.Listener{Name: ldsTargets[nackResourceIdx], ApiListener: &v3listenerpb.ApiListener{}} // 0 will be nacked. 1 will stay the same. 
+ nackRoutes[0] = &v3routepb.RouteConfiguration{ + Name: rdsTargets[nackResourceIdx], VirtualHosts: []*v3routepb.VirtualHost{{Routes: []*v3routepb.Route{{}}}}, + } + nackClusters[0] = &v3clusterpb.Cluster{ + Name: cdsTargets[nackResourceIdx], ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_STATIC}, + } + nackEndpoints[0] = &v3endpointpb.ClusterLoadAssignment{ + ClusterName: edsTargets[nackResourceIdx], Endpoints: []*v3endpointpb.LocalityLbEndpoints{{}}, + } + if err := mgmServer.Update(ctx, e2e.UpdateOptions{ + NodeID: nodeID, + Listeners: nackListeners, + Routes: nackRoutes, + Clusters: nackClusters, + Endpoints: nackEndpoints, + SkipValidation: true, + }); err != nil { + t.Fatal(err) + } + for i := 0; i < retryCount; i++ { + err := checkForNACKed(nackResourceIdx, stream) + if err == nil { + break + } + if i == retryCount-1 { + t.Fatalf("%v", err) + } + time.Sleep(time.Millisecond * 100) + } +} + +func commonSetup(ctx context.Context, t *testing.T) (xdsclient.XDSClient, *e2e.ManagementServer, string, v3statuspbgrpc.ClientStatusDiscoveryService_StreamClientStatusClient, func()) { + t.Helper() + + // Spin up a xDS management server on a local port. + nodeID := uuid.New().String() + fs, err := e2e.StartManagementServer() + if err != nil { + t.Fatal(err) + } + + // Create a bootstrap file in a temporary directory. + bootstrapCleanup, err := xds.SetupBootstrapFile(xds.BootstrapOptions{ + Version: xds.TransportV3, + NodeID: nodeID, + ServerURI: fs.Address, + }) + if err != nil { + t.Fatal(err) + } + // Create xds_client. + xdsC, err := xdsclient.New() + if err != nil { + t.Fatalf("failed to create xds client: %v", err) + } + oldNewXDSClient := newXDSClient + newXDSClient = func() xdsclient.XDSClient { return xdsC } + + // Initialize an gRPC server and register CSDS on it. 
+ server := grpc.NewServer() + csdss, err := NewClientStatusDiscoveryServer() + if err != nil { + t.Fatal(err) + } + v3statuspbgrpc.RegisterClientStatusDiscoveryServiceServer(server, csdss) + // Create a local listener and pass it to Serve(). + lis, err := xtestutils.LocalTCPListener() + if err != nil { + t.Fatalf("xtestutils.LocalTCPListener() failed: %v", err) + } + go func() { + if err := server.Serve(lis); err != nil { + t.Errorf("Serve() failed: %v", err) + } + }() + + // Create CSDS client. + conn, err := grpc.Dial(lis.Addr().String(), grpc.WithInsecure()) + if err != nil { + t.Fatalf("cannot connect to server: %v", err) + } + c := v3statuspbgrpc.NewClientStatusDiscoveryServiceClient(conn) + stream, err := c.StreamClientStatus(ctx, grpc.WaitForReady(true)) + if err != nil { + t.Fatalf("cannot get ServerReflectionInfo: %v", err) + } + + return xdsC, fs, nodeID, stream, func() { + fs.Stop() + conn.Close() + server.Stop() + csdss.Close() + newXDSClient = oldNewXDSClient + xdsC.Close() + bootstrapCleanup() + } +} + +func checkForRequested(stream v3statuspbgrpc.ClientStatusDiscoveryService_StreamClientStatusClient) error { + if err := stream.Send(&v3statuspb.ClientStatusRequest{Node: nil}); err != nil { + return fmt.Errorf("failed to send request: %v", err) + } + r, err := stream.Recv() + if err != nil { + // io.EOF is not ok. 
+ return fmt.Errorf("failed to recv response: %v", err) + } + + if n := len(r.Config); n != 1 { + return fmt.Errorf("got %d configs, want 1: %v", n, proto.MarshalTextString(r)) + } + if n := len(r.Config[0].XdsConfig); n != 4 { + return fmt.Errorf("got %d xds configs (one for each type), want 4: %v", n, proto.MarshalTextString(r)) + } + for _, cfg := range r.Config[0].XdsConfig { + switch config := cfg.PerXdsConfig.(type) { + case *v3statuspb.PerXdsConfig_ListenerConfig: + var wantLis []*v3adminpb.ListenersConfigDump_DynamicListener + for i := range ldsTargets { + wantLis = append(wantLis, &v3adminpb.ListenersConfigDump_DynamicListener{ + Name: ldsTargets[i], + ClientStatus: v3adminpb.ClientResourceStatus_REQUESTED, + }) + } + wantDump := &v3adminpb.ListenersConfigDump{ + DynamicListeners: wantLis, + } + if diff := cmp.Diff(config.ListenerConfig, wantDump, cmpOpts); diff != "" { + return fmt.Errorf(diff) + } + case *v3statuspb.PerXdsConfig_RouteConfig: + var wantRoutes []*v3adminpb.RoutesConfigDump_DynamicRouteConfig + for range rdsTargets { + wantRoutes = append(wantRoutes, &v3adminpb.RoutesConfigDump_DynamicRouteConfig{ + ClientStatus: v3adminpb.ClientResourceStatus_REQUESTED, + }) + } + wantDump := &v3adminpb.RoutesConfigDump{ + DynamicRouteConfigs: wantRoutes, + } + if diff := cmp.Diff(config.RouteConfig, wantDump, cmpOpts); diff != "" { + return fmt.Errorf(diff) + } + case *v3statuspb.PerXdsConfig_ClusterConfig: + var wantCluster []*v3adminpb.ClustersConfigDump_DynamicCluster + for range cdsTargets { + wantCluster = append(wantCluster, &v3adminpb.ClustersConfigDump_DynamicCluster{ + ClientStatus: v3adminpb.ClientResourceStatus_REQUESTED, + }) + } + wantDump := &v3adminpb.ClustersConfigDump{ + DynamicActiveClusters: wantCluster, + } + if diff := cmp.Diff(config.ClusterConfig, wantDump, cmpOpts); diff != "" { + return fmt.Errorf(diff) + } + case *v3statuspb.PerXdsConfig_EndpointConfig: + var wantEndpoint []*v3adminpb.EndpointsConfigDump_DynamicEndpointConfig + 
for range cdsTargets { + wantEndpoint = append(wantEndpoint, &v3adminpb.EndpointsConfigDump_DynamicEndpointConfig{ + ClientStatus: v3adminpb.ClientResourceStatus_REQUESTED, + }) + } + wantDump := &v3adminpb.EndpointsConfigDump{ + DynamicEndpointConfigs: wantEndpoint, + } + if diff := cmp.Diff(config.EndpointConfig, wantDump, cmpOpts); diff != "" { + return fmt.Errorf(diff) + } + default: + return fmt.Errorf("unexpected PerXdsConfig: %+v; %v", cfg.PerXdsConfig, protoToJSON(r)) + } + } + return nil +} + +func checkForACKed(stream v3statuspbgrpc.ClientStatusDiscoveryService_StreamClientStatusClient) error { + const wantVersion = "1" + + if err := stream.Send(&v3statuspb.ClientStatusRequest{Node: nil}); err != nil { + return fmt.Errorf("failed to send: %v", err) + } + r, err := stream.Recv() + if err != nil { + // io.EOF is not ok. + return fmt.Errorf("failed to recv response: %v", err) + } + + if n := len(r.Config); n != 1 { + return fmt.Errorf("got %d configs, want 1: %v", n, proto.MarshalTextString(r)) + } + if n := len(r.Config[0].XdsConfig); n != 4 { + return fmt.Errorf("got %d xds configs (one for each type), want 4: %v", n, proto.MarshalTextString(r)) + } + for _, cfg := range r.Config[0].XdsConfig { + switch config := cfg.PerXdsConfig.(type) { + case *v3statuspb.PerXdsConfig_ListenerConfig: + var wantLis []*v3adminpb.ListenersConfigDump_DynamicListener + for i := range ldsTargets { + wantLis = append(wantLis, &v3adminpb.ListenersConfigDump_DynamicListener{ + Name: ldsTargets[i], + ActiveState: &v3adminpb.ListenersConfigDump_DynamicListenerState{ + VersionInfo: wantVersion, + Listener: listenerAnys[i], + LastUpdated: nil, + }, + ErrorState: nil, + ClientStatus: v3adminpb.ClientResourceStatus_ACKED, + }) + } + wantDump := &v3adminpb.ListenersConfigDump{ + VersionInfo: wantVersion, + DynamicListeners: wantLis, + } + if diff := cmp.Diff(config.ListenerConfig, wantDump, cmpOpts); diff != "" { + return fmt.Errorf(diff) + } + case *v3statuspb.PerXdsConfig_RouteConfig: 
+ var wantRoutes []*v3adminpb.RoutesConfigDump_DynamicRouteConfig + for i := range rdsTargets { + wantRoutes = append(wantRoutes, &v3adminpb.RoutesConfigDump_DynamicRouteConfig{ + VersionInfo: wantVersion, + RouteConfig: routeAnys[i], + LastUpdated: nil, + ClientStatus: v3adminpb.ClientResourceStatus_ACKED, + }) + } + wantDump := &v3adminpb.RoutesConfigDump{ + DynamicRouteConfigs: wantRoutes, + } + if diff := cmp.Diff(config.RouteConfig, wantDump, cmpOpts); diff != "" { + return fmt.Errorf(diff) + } + case *v3statuspb.PerXdsConfig_ClusterConfig: + var wantCluster []*v3adminpb.ClustersConfigDump_DynamicCluster + for i := range cdsTargets { + wantCluster = append(wantCluster, &v3adminpb.ClustersConfigDump_DynamicCluster{ + VersionInfo: wantVersion, + Cluster: clusterAnys[i], + LastUpdated: nil, + ClientStatus: v3adminpb.ClientResourceStatus_ACKED, + }) + } + wantDump := &v3adminpb.ClustersConfigDump{ + VersionInfo: wantVersion, + DynamicActiveClusters: wantCluster, + } + if diff := cmp.Diff(config.ClusterConfig, wantDump, cmpOpts); diff != "" { + return fmt.Errorf(diff) + } + case *v3statuspb.PerXdsConfig_EndpointConfig: + var wantEndpoint []*v3adminpb.EndpointsConfigDump_DynamicEndpointConfig + for i := range cdsTargets { + wantEndpoint = append(wantEndpoint, &v3adminpb.EndpointsConfigDump_DynamicEndpointConfig{ + VersionInfo: wantVersion, + EndpointConfig: endpointAnys[i], + LastUpdated: nil, + ClientStatus: v3adminpb.ClientResourceStatus_ACKED, + }) + } + wantDump := &v3adminpb.EndpointsConfigDump{ + DynamicEndpointConfigs: wantEndpoint, + } + if diff := cmp.Diff(config.EndpointConfig, wantDump, cmpOpts); diff != "" { + return fmt.Errorf(diff) + } + default: + return fmt.Errorf("unexpected PerXdsConfig: %+v; %v", cfg.PerXdsConfig, protoToJSON(r)) + } + } + return nil +} + +func checkForNACKed(nackResourceIdx int, stream v3statuspbgrpc.ClientStatusDiscoveryService_StreamClientStatusClient) error { + const ( + ackVersion = "1" + nackVersion = "2" + ) + if err := 
stream.Send(&v3statuspb.ClientStatusRequest{Node: nil}); err != nil { + return fmt.Errorf("failed to send: %v", err) + } + r, err := stream.Recv() + if err != nil { + // io.EOF is not ok. + return fmt.Errorf("failed to recv response: %v", err) + } + + if n := len(r.Config); n != 1 { + return fmt.Errorf("got %d configs, want 1: %v", n, proto.MarshalTextString(r)) + } + if n := len(r.Config[0].XdsConfig); n != 4 { + return fmt.Errorf("got %d xds configs (one for each type), want 4: %v", n, proto.MarshalTextString(r)) + } + for _, cfg := range r.Config[0].XdsConfig { + switch config := cfg.PerXdsConfig.(type) { + case *v3statuspb.PerXdsConfig_ListenerConfig: + var wantLis []*v3adminpb.ListenersConfigDump_DynamicListener + for i := range ldsTargets { + configDump := &v3adminpb.ListenersConfigDump_DynamicListener{ + Name: ldsTargets[i], + ActiveState: &v3adminpb.ListenersConfigDump_DynamicListenerState{ + VersionInfo: nackVersion, + Listener: listenerAnys[i], + LastUpdated: nil, + }, + ClientStatus: v3adminpb.ClientResourceStatus_ACKED, + } + if i == nackResourceIdx { + configDump.ActiveState.VersionInfo = ackVersion + configDump.ClientStatus = v3adminpb.ClientResourceStatus_NACKED + configDump.ErrorState = &v3adminpb.UpdateFailureState{ + Details: "blahblah", + VersionInfo: nackVersion, + } + } + wantLis = append(wantLis, configDump) + } + wantDump := &v3adminpb.ListenersConfigDump{ + VersionInfo: nackVersion, + DynamicListeners: wantLis, + } + if diff := cmp.Diff(config.ListenerConfig, wantDump, cmpOpts); diff != "" { + return fmt.Errorf(diff) + } + case *v3statuspb.PerXdsConfig_RouteConfig: + var wantRoutes []*v3adminpb.RoutesConfigDump_DynamicRouteConfig + for i := range rdsTargets { + configDump := &v3adminpb.RoutesConfigDump_DynamicRouteConfig{ + VersionInfo: nackVersion, + RouteConfig: routeAnys[i], + LastUpdated: nil, + ClientStatus: v3adminpb.ClientResourceStatus_ACKED, + } + if i == nackResourceIdx { + configDump.VersionInfo = ackVersion + 
configDump.ClientStatus = v3adminpb.ClientResourceStatus_NACKED + configDump.ErrorState = &v3adminpb.UpdateFailureState{ + Details: "blahblah", + VersionInfo: nackVersion, + } + } + wantRoutes = append(wantRoutes, configDump) + } + wantDump := &v3adminpb.RoutesConfigDump{ + DynamicRouteConfigs: wantRoutes, + } + if diff := cmp.Diff(config.RouteConfig, wantDump, cmpOpts); diff != "" { + return fmt.Errorf(diff) + } + case *v3statuspb.PerXdsConfig_ClusterConfig: + var wantCluster []*v3adminpb.ClustersConfigDump_DynamicCluster + for i := range cdsTargets { + configDump := &v3adminpb.ClustersConfigDump_DynamicCluster{ + VersionInfo: nackVersion, + Cluster: clusterAnys[i], + LastUpdated: nil, + ClientStatus: v3adminpb.ClientResourceStatus_ACKED, + } + if i == nackResourceIdx { + configDump.VersionInfo = ackVersion + configDump.ClientStatus = v3adminpb.ClientResourceStatus_NACKED + configDump.ErrorState = &v3adminpb.UpdateFailureState{ + Details: "blahblah", + VersionInfo: nackVersion, + } + } + wantCluster = append(wantCluster, configDump) + } + wantDump := &v3adminpb.ClustersConfigDump{ + VersionInfo: nackVersion, + DynamicActiveClusters: wantCluster, + } + if diff := cmp.Diff(config.ClusterConfig, wantDump, cmpOpts); diff != "" { + return fmt.Errorf(diff) + } + case *v3statuspb.PerXdsConfig_EndpointConfig: + var wantEndpoint []*v3adminpb.EndpointsConfigDump_DynamicEndpointConfig + for i := range cdsTargets { + configDump := &v3adminpb.EndpointsConfigDump_DynamicEndpointConfig{ + VersionInfo: nackVersion, + EndpointConfig: endpointAnys[i], + LastUpdated: nil, + ClientStatus: v3adminpb.ClientResourceStatus_ACKED, + } + if i == nackResourceIdx { + configDump.VersionInfo = ackVersion + configDump.ClientStatus = v3adminpb.ClientResourceStatus_NACKED + configDump.ErrorState = &v3adminpb.UpdateFailureState{ + Details: "blahblah", + VersionInfo: nackVersion, + } + } + wantEndpoint = append(wantEndpoint, configDump) + } + wantDump := &v3adminpb.EndpointsConfigDump{ + 
DynamicEndpointConfigs: wantEndpoint, + } + if diff := cmp.Diff(config.EndpointConfig, wantDump, cmpOpts); diff != "" { + return fmt.Errorf(diff) + } + default: + return fmt.Errorf("unexpected PerXdsConfig: %+v; %v", cfg.PerXdsConfig, protoToJSON(r)) + } + } + return nil +} + +func protoToJSON(p proto.Message) string { + mm := jsonpb.Marshaler{ + Indent: " ", + } + ret, _ := mm.MarshalToString(p) + return ret +} + +func TestCSDSNoXDSClient(t *testing.T) { + oldNewXDSClient := newXDSClient + newXDSClient = func() xdsclient.XDSClient { return nil } + defer func() { newXDSClient = oldNewXDSClient }() + + // Initialize an gRPC server and register CSDS on it. + server := grpc.NewServer() + csdss, err := NewClientStatusDiscoveryServer() + if err != nil { + t.Fatal(err) + } + defer csdss.Close() + v3statuspbgrpc.RegisterClientStatusDiscoveryServiceServer(server, csdss) + // Create a local listener and pass it to Serve(). + lis, err := xtestutils.LocalTCPListener() + if err != nil { + t.Fatalf("xtestutils.LocalTCPListener() failed: %v", err) + } + go func() { + if err := server.Serve(lis); err != nil { + t.Errorf("Serve() failed: %v", err) + } + }() + defer server.Stop() + + // Create CSDS client. + conn, err := grpc.Dial(lis.Addr().String(), grpc.WithInsecure()) + if err != nil { + t.Fatalf("cannot connect to server: %v", err) + } + defer conn.Close() + c := v3statuspbgrpc.NewClientStatusDiscoveryServiceClient(conn) + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + stream, err := c.StreamClientStatus(ctx, grpc.WaitForReady(true)) + if err != nil { + t.Fatalf("cannot get ServerReflectionInfo: %v", err) + } + + if err := stream.Send(&v3statuspb.ClientStatusRequest{Node: nil}); err != nil { + t.Fatalf("failed to send: %v", err) + } + r, err := stream.Recv() + if err != nil { + // io.EOF is not ok. 
+ t.Fatalf("failed to recv response: %v", err) + } + if n := len(r.Config); n != 0 { + t.Fatalf("got %d configs, want 0: %v", n, proto.MarshalTextString(r)) + } +} + +func Test_nodeProtoToV3(t *testing.T) { + const ( + testID = "test-id" + testCluster = "test-cluster" + testZone = "test-zone" + ) + tests := []struct { + name string + n proto.Message + want *v3corepb.Node + }{ + { + name: "v3", + n: &v3corepb.Node{ + Id: testID, + Cluster: testCluster, + Locality: &v3corepb.Locality{Zone: testZone}, + }, + want: &v3corepb.Node{ + Id: testID, + Cluster: testCluster, + Locality: &v3corepb.Locality{Zone: testZone}, + }, + }, + { + name: "v2", + n: &v2corepb.Node{ + Id: testID, + Cluster: testCluster, + Locality: &v2corepb.Locality{Zone: testZone}, + }, + want: &v3corepb.Node{ + Id: testID, + Cluster: testCluster, + Locality: &v3corepb.Locality{Zone: testZone}, + }, + }, + { + name: "not node", + n: &v2corepb.Locality{Zone: testZone}, + want: nil, // Input is not a node, should return nil. + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := nodeProtoToV3(tt.n) + if diff := cmp.Diff(got, tt.want, protocmp.Transform()); diff != "" { + t.Errorf("nodeProtoToV3() got unexpected result, diff (-got, +want): %v", diff) + } + }) + } +} diff --git a/xds/googledirectpath/googlec2p.go b/xds/googledirectpath/googlec2p.go new file mode 100644 index 00000000000..b9f1c712014 --- /dev/null +++ b/xds/googledirectpath/googlec2p.go @@ -0,0 +1,178 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// Package googledirectpath implements a resolver that configures xds to make +// cloud to prod directpath connection. +// +// It's a combo of DNS and xDS resolvers. It delegates to DNS if +// - not on GCE, or +// - xDS bootstrap env var is set (so this client needs to do normal xDS, not +// direct path, and clients with this scheme is not part of the xDS mesh). +package googledirectpath + +import ( + "fmt" + "time" + + v3corepb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/google" + "google.golang.org/grpc/grpclog" + "google.golang.org/grpc/internal/googlecloud" + internalgrpclog "google.golang.org/grpc/internal/grpclog" + "google.golang.org/grpc/internal/grpcrand" + "google.golang.org/grpc/internal/xds/env" + "google.golang.org/grpc/resolver" + _ "google.golang.org/grpc/xds" // To register xds resolvers and balancers. + "google.golang.org/grpc/xds/internal/version" + "google.golang.org/grpc/xds/internal/xdsclient" + "google.golang.org/grpc/xds/internal/xdsclient/bootstrap" + "google.golang.org/protobuf/types/known/structpb" +) + +const ( + c2pScheme = "google-c2p" + + tdURL = "directpath-pa.googleapis.com" + httpReqTimeout = 10 * time.Second + zoneURL = "http://metadata.google.internal/computeMetadata/v1/instance/zone" + ipv6URL = "http://metadata.google.internal/computeMetadata/v1/instance/network-interfaces/0/ipv6s" + + gRPCUserAgentName = "gRPC Go" + clientFeatureNoOverprovisioning = "envoy.lb.does_not_support_overprovisioning" + ipv6CapableMetadataName = "TRAFFICDIRECTOR_DIRECTPATH_C2P_IPV6_CAPABLE" + + logPrefix = "[google-c2p-resolver]" + + dnsName, xdsName = "dns", "xds" +) + +// For overriding in unittests. 
+var ( + onGCE = googlecloud.OnGCE + + newClientWithConfig = func(config *bootstrap.Config) (xdsclient.XDSClient, error) { + return xdsclient.NewWithConfig(config) + } + + logger = internalgrpclog.NewPrefixLogger(grpclog.Component("directpath"), logPrefix) +) + +func init() { + if env.C2PResolverSupport { + resolver.Register(c2pResolverBuilder{}) + } +} + +type c2pResolverBuilder struct{} + +func (c2pResolverBuilder) Build(t resolver.Target, cc resolver.ClientConn, opts resolver.BuildOptions) (resolver.Resolver, error) { + if !runDirectPath() { + // If not xDS, fallback to DNS. + t.Scheme = dnsName + return resolver.Get(dnsName).Build(t, cc, opts) + } + + // Note that the following calls to getZone() and getIPv6Capable() does I/O, + // and has 10 seconds timeout each. + // + // This should be fine in most of the cases. In certain error cases, this + // could block Dial() for up to 10 seconds (each blocking call has its own + // goroutine). + zoneCh, ipv6CapableCh := make(chan string), make(chan bool) + go func() { zoneCh <- getZone(httpReqTimeout) }() + go func() { ipv6CapableCh <- getIPv6Capable(httpReqTimeout) }() + + balancerName := env.C2PResolverTestOnlyTrafficDirectorURI + if balancerName == "" { + balancerName = tdURL + } + config := &bootstrap.Config{ + BalancerName: balancerName, + Creds: grpc.WithCredentialsBundle(google.NewDefaultCredentials()), + TransportAPI: version.TransportV3, + NodeProto: newNode(<-zoneCh, <-ipv6CapableCh), + } + + // Create singleton xds client with this config. The xds client will be + // used by the xds resolver later. + xdsC, err := newClientWithConfig(config) + if err != nil { + return nil, fmt.Errorf("failed to start xDS client: %v", err) + } + + // Create and return an xDS resolver. 
+ t.Scheme = xdsName + xdsR, err := resolver.Get(xdsName).Build(t, cc, opts) + if err != nil { + xdsC.Close() + return nil, err + } + return &c2pResolver{ + Resolver: xdsR, + client: xdsC, + }, nil +} + +func (c2pResolverBuilder) Scheme() string { + return c2pScheme +} + +type c2pResolver struct { + resolver.Resolver + client xdsclient.XDSClient +} + +func (r *c2pResolver) Close() { + r.Resolver.Close() + r.client.Close() +} + +var ipv6EnabledMetadata = &structpb.Struct{ + Fields: map[string]*structpb.Value{ + ipv6CapableMetadataName: structpb.NewBoolValue(true), + }, +} + +var id = fmt.Sprintf("C2P-%d", grpcrand.Int()) + +// newNode makes a copy of defaultNode, and populate it's Metadata and +// Locality fields. +func newNode(zone string, ipv6Capable bool) *v3corepb.Node { + ret := &v3corepb.Node{ + // Not all required fields are set in defaultNote. Metadata will be set + // if ipv6 is enabled. Locality will be set to the value from metadata. + Id: id, + UserAgentName: gRPCUserAgentName, + UserAgentVersionType: &v3corepb.Node_UserAgentVersion{UserAgentVersion: grpc.Version}, + ClientFeatures: []string{clientFeatureNoOverprovisioning}, + } + ret.Locality = &v3corepb.Locality{Zone: zone} + if ipv6Capable { + ret.Metadata = ipv6EnabledMetadata + } + return ret +} + +// runDirectPath returns whether this resolver should use direct path. +// +// direct path is enabled if this client is running on GCE, and the normal xDS +// is not used (bootstrap env vars are not set). +func runDirectPath() bool { + return env.BootstrapFileName == "" && env.BootstrapFileContent == "" && onGCE() +} diff --git a/xds/googledirectpath/googlec2p_test.go b/xds/googledirectpath/googlec2p_test.go new file mode 100644 index 00000000000..a208fad66c5 --- /dev/null +++ b/xds/googledirectpath/googlec2p_test.go @@ -0,0 +1,242 @@ +/* + * + * Copyright 2021 gRPC authors. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package googledirectpath + +import ( + "strconv" + "testing" + "time" + + v3corepb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "google.golang.org/grpc" + "google.golang.org/grpc/internal/xds/env" + "google.golang.org/grpc/resolver" + "google.golang.org/grpc/xds/internal/version" + "google.golang.org/grpc/xds/internal/xdsclient" + "google.golang.org/grpc/xds/internal/xdsclient/bootstrap" + "google.golang.org/protobuf/testing/protocmp" + "google.golang.org/protobuf/types/known/structpb" +) + +type emptyResolver struct { + resolver.Resolver + scheme string +} + +func (er *emptyResolver) Build(_ resolver.Target, _ resolver.ClientConn, _ resolver.BuildOptions) (resolver.Resolver, error) { + return er, nil +} + +func (er *emptyResolver) Scheme() string { + return er.scheme +} + +func (er *emptyResolver) Close() {} + +var ( + testDNSResolver = &emptyResolver{scheme: "dns"} + testXDSResolver = &emptyResolver{scheme: "xds"} +) + +func replaceResolvers() func() { + var registerForTesting bool + if resolver.Get(c2pScheme) == nil { + // If env var to enable c2p is not set, the resolver isn't registered. + // Need to register and unregister in defer. 
+ registerForTesting = true + resolver.Register(&c2pResolverBuilder{}) + } + oldDNS := resolver.Get("dns") + resolver.Register(testDNSResolver) + oldXDS := resolver.Get("xds") + resolver.Register(testXDSResolver) + return func() { + if oldDNS != nil { + resolver.Register(oldDNS) + } else { + resolver.UnregisterForTesting("dns") + } + if oldXDS != nil { + resolver.Register(oldXDS) + } else { + resolver.UnregisterForTesting("xds") + } + if registerForTesting { + resolver.UnregisterForTesting(c2pScheme) + } + } +} + +// Test that when bootstrap env is set, fallback to DNS. +func TestBuildWithBootstrapEnvSet(t *testing.T) { + defer replaceResolvers()() + builder := resolver.Get(c2pScheme) + + for i, envP := range []*string{&env.BootstrapFileName, &env.BootstrapFileContent} { + t.Run(strconv.Itoa(i), func(t *testing.T) { + // Set bootstrap config env var. + oldEnv := *envP + *envP = "does not matter" + defer func() { *envP = oldEnv }() + + // Build should return DNS, not xDS. + r, err := builder.Build(resolver.Target{}, nil, resolver.BuildOptions{}) + if err != nil { + t.Fatalf("failed to build resolver: %v", err) + } + if r != testDNSResolver { + t.Fatalf("want dns resolver, got %#v", r) + } + }) + } +} + +// Test that when not on GCE, fallback to DNS. +func TestBuildNotOnGCE(t *testing.T) { + defer replaceResolvers()() + builder := resolver.Get(c2pScheme) + + oldOnGCE := onGCE + onGCE = func() bool { return false } + defer func() { onGCE = oldOnGCE }() + + // Build should return DNS, not xDS. + r, err := builder.Build(resolver.Target{}, nil, resolver.BuildOptions{}) + if err != nil { + t.Fatalf("failed to build resolver: %v", err) + } + if r != testDNSResolver { + t.Fatalf("want dns resolver, got %#v", r) + } +} + +type testXDSClient struct { + xdsclient.XDSClient + closed chan struct{} +} + +func (c *testXDSClient) Close() { + c.closed <- struct{}{} +} + +// Test that when xDS is built, the client is built with the correct config. 
+func TestBuildXDS(t *testing.T) { + defer replaceResolvers()() + builder := resolver.Get(c2pScheme) + + oldOnGCE := onGCE + onGCE = func() bool { return true } + defer func() { onGCE = oldOnGCE }() + + const testZone = "test-zone" + oldGetZone := getZone + getZone = func(time.Duration) string { return testZone } + defer func() { getZone = oldGetZone }() + + for _, tt := range []struct { + name string + ipv6 bool + tdURI string // traffic director URI will be overridden if this is set. + }{ + {name: "ipv6 true", ipv6: true}, + {name: "ipv6 false", ipv6: false}, + {name: "override TD URI", ipv6: true, tdURI: "test-uri"}, + } { + t.Run(tt.name, func(t *testing.T) { + oldGetIPv6Capability := getIPv6Capable + getIPv6Capable = func(time.Duration) bool { return tt.ipv6 } + defer func() { getIPv6Capable = oldGetIPv6Capability }() + + if tt.tdURI != "" { + oldURI := env.C2PResolverTestOnlyTrafficDirectorURI + env.C2PResolverTestOnlyTrafficDirectorURI = tt.tdURI + defer func() { + env.C2PResolverTestOnlyTrafficDirectorURI = oldURI + }() + } + + tXDSClient := &testXDSClient{closed: make(chan struct{}, 1)} + + configCh := make(chan *bootstrap.Config, 1) + oldNewClient := newClientWithConfig + newClientWithConfig = func(config *bootstrap.Config) (xdsclient.XDSClient, error) { + configCh <- config + return tXDSClient, nil + } + defer func() { newClientWithConfig = oldNewClient }() + + // Build should return DNS, not xDS. 
+ r, err := builder.Build(resolver.Target{}, nil, resolver.BuildOptions{}) + if err != nil { + t.Fatalf("failed to build resolver: %v", err) + } + rr := r.(*c2pResolver) + if rrr := rr.Resolver; rrr != testXDSResolver { + t.Fatalf("want xds resolver, got %#v, ", rrr) + } + + wantNode := &v3corepb.Node{ + Id: id, + Metadata: nil, + Locality: &v3corepb.Locality{Zone: testZone}, + UserAgentName: gRPCUserAgentName, + UserAgentVersionType: &v3corepb.Node_UserAgentVersion{UserAgentVersion: grpc.Version}, + ClientFeatures: []string{clientFeatureNoOverprovisioning}, + } + if tt.ipv6 { + wantNode.Metadata = &structpb.Struct{ + Fields: map[string]*structpb.Value{ + ipv6CapableMetadataName: { + Kind: &structpb.Value_BoolValue{BoolValue: true}, + }, + }, + } + } + wantConfig := &bootstrap.Config{ + BalancerName: tdURL, + TransportAPI: version.TransportV3, + NodeProto: wantNode, + } + if tt.tdURI != "" { + wantConfig.BalancerName = tt.tdURI + } + cmpOpts := cmp.Options{ + cmpopts.IgnoreFields(bootstrap.Config{}, "Creds"), + protocmp.Transform(), + } + select { + case c := <-configCh: + if diff := cmp.Diff(c, wantConfig, cmpOpts); diff != "" { + t.Fatalf("%v", diff) + } + case <-time.After(time.Second): + t.Fatalf("timeout waiting for client config") + } + + r.Close() + select { + case <-tXDSClient.closed: + case <-time.After(time.Second): + t.Fatalf("timeout waiting for client close") + } + }) + } +} diff --git a/xds/googledirectpath/utils.go b/xds/googledirectpath/utils.go new file mode 100644 index 00000000000..60044197978 --- /dev/null +++ b/xds/googledirectpath/utils.go @@ -0,0 +1,96 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package googledirectpath + +import ( + "bytes" + "fmt" + "io/ioutil" + "net/http" + "net/url" + "sync" + "time" +) + +func getFromMetadata(timeout time.Duration, urlStr string) ([]byte, error) { + parsedURL, err := url.Parse(urlStr) + if err != nil { + return nil, err + } + client := &http.Client{Timeout: timeout} + req := &http.Request{ + Method: http.MethodGet, + URL: parsedURL, + Header: http.Header{"Metadata-Flavor": {"Google"}}, + } + resp, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("failed communicating with metadata server: %v", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("metadata server returned resp with non-OK: %v", resp) + } + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed reading from metadata server: %v", err) + } + return body, nil +} + +var ( + zone string + zoneOnce sync.Once +) + +// Defined as var to be overridden in tests. +var getZone = func(timeout time.Duration) string { + zoneOnce.Do(func() { + qualifiedZone, err := getFromMetadata(timeout, zoneURL) + if err != nil { + logger.Warningf("could not discover instance zone: %v", err) + return + } + i := bytes.LastIndexByte(qualifiedZone, '/') + if i == -1 { + logger.Warningf("could not parse zone from metadata server: %s", qualifiedZone) + return + } + zone = string(qualifiedZone[i+1:]) + }) + return zone +} + +var ( + ipv6Capable bool + ipv6CapableOnce sync.Once +) + +// Defined as var to be overridden in tests. 
+var getIPv6Capable = func(timeout time.Duration) bool { + ipv6CapableOnce.Do(func() { + _, err := getFromMetadata(timeout, ipv6URL) + if err != nil { + logger.Warningf("could not discover ipv6 capability: %v", err) + return + } + ipv6Capable = true + }) + return ipv6Capable +} diff --git a/xds/internal/balancer/balancer.go b/xds/internal/balancer/balancer.go index 5883027a2c5..86656736a61 100644 --- a/xds/internal/balancer/balancer.go +++ b/xds/internal/balancer/balancer.go @@ -20,8 +20,10 @@ package balancer import ( - _ "google.golang.org/grpc/xds/internal/balancer/cdsbalancer" // Register the CDS balancer - _ "google.golang.org/grpc/xds/internal/balancer/clustermanager" // Register the xds_cluster_manager balancer - _ "google.golang.org/grpc/xds/internal/balancer/edsbalancer" // Register the EDS balancer - _ "google.golang.org/grpc/xds/internal/balancer/weightedtarget" // Register the weighted_target balancer + _ "google.golang.org/grpc/xds/internal/balancer/cdsbalancer" // Register the CDS balancer + _ "google.golang.org/grpc/xds/internal/balancer/clusterimpl" // Register the xds_cluster_impl balancer + _ "google.golang.org/grpc/xds/internal/balancer/clustermanager" // Register the xds_cluster_manager balancer + _ "google.golang.org/grpc/xds/internal/balancer/clusterresolver" // Register the xds_cluster_resolver balancer + _ "google.golang.org/grpc/xds/internal/balancer/priority" // Register the priority balancer + _ "google.golang.org/grpc/xds/internal/balancer/weightedtarget" // Register the weighted_target balancer ) diff --git a/xds/internal/balancer/balancergroup/balancergroup.go b/xds/internal/balancer/balancergroup/balancergroup.go index 2ec576a4b57..5798b03ac50 100644 --- a/xds/internal/balancer/balancergroup/balancergroup.go +++ b/xds/internal/balancer/balancergroup/balancergroup.go @@ -24,7 +24,7 @@ import ( "time" orcapb "github.com/cncf/udpa/go/udpa/data/orca/v1" - "google.golang.org/grpc/xds/internal/client/load" + 
"google.golang.org/grpc/xds/internal/xdsclient/load" "google.golang.org/grpc/balancer" "google.golang.org/grpc/connectivity" @@ -104,6 +104,22 @@ func (sbc *subBalancerWrapper) startBalancer() { } } +func (sbc *subBalancerWrapper) exitIdle() { + b := sbc.balancer + if b == nil { + return + } + if ei, ok := b.(balancer.ExitIdler); ok { + ei.ExitIdle() + return + } + for sc, b := range sbc.group.scToSubBalancer { + if b == sbc { + sc.Connect() + } + } +} + func (sbc *subBalancerWrapper) updateSubConnState(sc balancer.SubConn, state balancer.SubConnState) { b := sbc.balancer if b == nil { @@ -183,7 +199,7 @@ type BalancerGroup struct { cc balancer.ClientConn buildOpts balancer.BuildOptions logger *grpclog.PrefixLogger - loadStore load.PerClusterReporter + loadStore load.PerClusterReporter // TODO: delete this, no longer needed. It was used by EDS. // stateAggregator is where the state/picker updates will be sent to. It's // provided by the parent balancer, to build a picker with all the @@ -479,6 +495,10 @@ func (bg *BalancerGroup) Close() { } bg.incomingMu.Unlock() + // Clear(true) runs clear function to close sub-balancers in cache. It + // must be called out of outgoing mutex. + bg.balancerCache.Clear(true) + bg.outgoingMu.Lock() if bg.outgoingStarted { bg.outgoingStarted = false @@ -487,9 +507,17 @@ func (bg *BalancerGroup) Close() { } } bg.outgoingMu.Unlock() - // Clear(true) runs clear function to close sub-balancers in cache. It - // must be called out of outgoing mutex. - bg.balancerCache.Clear(true) +} + +// ExitIdle should be invoked when the parent LB policy's ExitIdle is invoked. +// It will trigger this on all sub-balancers, or reconnect their subconns if +// not supported. 
+func (bg *BalancerGroup) ExitIdle() { + bg.outgoingMu.Lock() + for _, config := range bg.idToBalancerConfig { + config.exitIdle() + } + bg.outgoingMu.Unlock() } const ( diff --git a/xds/internal/balancer/balancergroup/balancergroup_test.go b/xds/internal/balancer/balancergroup/balancergroup_test.go index 0ad4bf8df10..9cc7bd072ec 100644 --- a/xds/internal/balancer/balancergroup/balancergroup_test.go +++ b/xds/internal/balancer/balancergroup/balancergroup_test.go @@ -42,8 +42,8 @@ import ( "google.golang.org/grpc/internal/balancer/stub" "google.golang.org/grpc/resolver" "google.golang.org/grpc/xds/internal/balancer/weightedtarget/weightedaggregator" - "google.golang.org/grpc/xds/internal/client/load" "google.golang.org/grpc/xds/internal/testutils" + "google.golang.org/grpc/xds/internal/xdsclient/load" ) var ( @@ -843,7 +843,7 @@ func (s) TestBalancerGroup_locality_caching_not_readd_within_timeout(t *testing. defer replaceDefaultSubBalancerCloseTimeout(time.Second)() _, _, cc, addrToSC := initBalancerGroupForCachingTest(t) - // The sub-balancer is not re-added withtin timeout. The subconns should be + // The sub-balancer is not re-added within timeout. The subconns should be // removed. removeTimeout := time.After(DefaultSubBalancerCloseTimeout) scToRemove := map[balancer.SubConn]int{ @@ -938,6 +938,36 @@ func (s) TestBalancerGroup_locality_caching_readd_with_different_builder(t *test } } +// After removing a sub-balancer, it will be kept in cache. Make sure that this +// sub-balancer's Close is called when the balancer group is closed. 
+func (s) TestBalancerGroup_CloseStopsBalancerInCache(t *testing.T) { + const balancerName = "stub-TestBalancerGroup_check_close" + closed := make(chan struct{}) + stub.Register(balancerName, stub.BalancerFuncs{Close: func(_ *stub.BalancerData) { + close(closed) + }}) + builder := balancer.Get(balancerName) + + defer replaceDefaultSubBalancerCloseTimeout(time.Second)() + gator, bg, _, _ := initBalancerGroupForCachingTest(t) + + // Add balancer, and remove + gator.Add(testBalancerIDs[2], 1) + bg.Add(testBalancerIDs[2], builder) + gator.Remove(testBalancerIDs[2]) + bg.Remove(testBalancerIDs[2]) + + // Immediately close balancergroup, before the cache timeout. + bg.Close() + + // Make sure the removed child balancer is closed eventually. + select { + case <-closed: + case <-time.After(time.Second * 2): + t.Fatalf("timeout waiting for the child balancer in cache to be closed") + } +} + // TestBalancerGroupBuildOptions verifies that the balancer.BuildOptions passed // to the balancergroup at creation time is passed to child policies. 
func (s) TestBalancerGroupBuildOptions(t *testing.T) { diff --git a/xds/internal/balancer/cdsbalancer/cdsbalancer.go b/xds/internal/balancer/cdsbalancer/cdsbalancer.go index e4d349753e1..82d2a96958e 100644 --- a/xds/internal/balancer/cdsbalancer/cdsbalancer.go +++ b/xds/internal/balancer/cdsbalancer/cdsbalancer.go @@ -31,66 +31,57 @@ import ( xdsinternal "google.golang.org/grpc/internal/credentials/xds" "google.golang.org/grpc/internal/grpclog" "google.golang.org/grpc/internal/grpcsync" + "google.golang.org/grpc/internal/pretty" + internalserviceconfig "google.golang.org/grpc/internal/serviceconfig" "google.golang.org/grpc/resolver" "google.golang.org/grpc/serviceconfig" - "google.golang.org/grpc/xds/internal/balancer/edsbalancer" - xdsclient "google.golang.org/grpc/xds/internal/client" - "google.golang.org/grpc/xds/internal/client/bootstrap" + "google.golang.org/grpc/xds/internal/balancer/clusterresolver" + "google.golang.org/grpc/xds/internal/balancer/ringhash" + "google.golang.org/grpc/xds/internal/xdsclient" ) const ( cdsName = "cds_experimental" - edsName = "eds_experimental" ) var ( errBalancerClosed = errors.New("cdsBalancer is closed") - // newEDSBalancer is a helper function to build a new edsBalancer and will be - // overridden in unittests. - newEDSBalancer = func(cc balancer.ClientConn, opts balancer.BuildOptions) (balancer.Balancer, error) { - builder := balancer.Get(edsName) + // newChildBalancer is a helper function to build a new cluster_resolver + // balancer and will be overridden in unittests. 
+ newChildBalancer = func(cc balancer.ClientConn, opts balancer.BuildOptions) (balancer.Balancer, error) { + builder := balancer.Get(clusterresolver.Name) if builder == nil { - return nil, fmt.Errorf("xds: no balancer builder with name %v", edsName) + return nil, fmt.Errorf("xds: no balancer builder with name %v", clusterresolver.Name) } - // We directly pass the parent clientConn to the - // underlying edsBalancer because the cdsBalancer does - // not deal with subConns. + // We directly pass the parent clientConn to the underlying + // cluster_resolver balancer because the cdsBalancer does not deal with + // subConns. return builder.Build(cc, opts), nil } - newXDSClient = func() (xdsClientInterface, error) { return xdsclient.New() } buildProvider = buildProviderFunc ) func init() { - balancer.Register(cdsBB{}) + balancer.Register(bb{}) } -// cdsBB (short for cdsBalancerBuilder) implements the balancer.Builder -// interface to help build a cdsBalancer. +// bb implements the balancer.Builder interface to help build a cdsBalancer. // It also implements the balancer.ConfigParser interface to help parse the // JSON service config, to be passed to the cdsBalancer. -type cdsBB struct{} +type bb struct{} // Build creates a new CDS balancer with the ClientConn. -func (cdsBB) Build(cc balancer.ClientConn, opts balancer.BuildOptions) balancer.Balancer { +func (bb) Build(cc balancer.ClientConn, opts balancer.BuildOptions) balancer.Balancer { b := &cdsBalancer{ - bOpts: opts, - updateCh: buffer.NewUnbounded(), - closed: grpcsync.NewEvent(), - cancelWatch: func() {}, // No-op at this point. 
- xdsHI: xdsinternal.NewHandshakeInfo(nil, nil), + bOpts: opts, + updateCh: buffer.NewUnbounded(), + closed: grpcsync.NewEvent(), + done: grpcsync.NewEvent(), + xdsHI: xdsinternal.NewHandshakeInfo(nil, nil), } b.logger = prefixLogger((b)) b.logger.Infof("Created") - - client, err := newXDSClient() - if err != nil { - b.logger.Errorf("failed to create xds-client: %v", err) - return nil - } - b.xdsClient = client - var creds credentials.TransportCredentials switch { case opts.DialCreds != nil: @@ -102,7 +93,7 @@ func (cdsBB) Build(cc balancer.ClientConn, opts balancer.BuildOptions) balancer. b.xdsCredsInUse = true } b.logger.Infof("xDS credentials in use: %v", b.xdsCredsInUse) - + b.clusterHandler = newClusterHandler(b) b.ccw = &ccWrapper{ ClientConn: cc, xdsHI: b.xdsHI, @@ -112,7 +103,7 @@ func (cdsBB) Build(cc balancer.ClientConn, opts balancer.BuildOptions) balancer. } // Name returns the name of balancers built by this builder. -func (cdsBB) Name() string { +func (bb) Name() string { return cdsName } @@ -125,7 +116,7 @@ type lbConfig struct { // ParseConfig parses the JSON load balancer config provided into an // internal form or returns an error if the config is invalid. -func (cdsBB) ParseConfig(c json.RawMessage) (serviceconfig.LoadBalancingConfig, error) { +func (bb) ParseConfig(c json.RawMessage) (serviceconfig.LoadBalancingConfig, error) { var cfg lbConfig if err := json.Unmarshal(c, &cfg); err != nil { return nil, fmt.Errorf("xds: unable to unmarshal lbconfig: %s, error: %v", string(c), err) @@ -133,54 +124,40 @@ func (cdsBB) ParseConfig(c json.RawMessage) (serviceconfig.LoadBalancingConfig, return &cfg, nil } -// xdsClientInterface contains methods from xdsClient.Client which are used by -// the cdsBalancer. This will be faked out in unittests. 
-type xdsClientInterface interface { - WatchCluster(string, func(xdsclient.ClusterUpdate, error)) func() - BootstrapConfig() *bootstrap.Config - Close() -} - // ccUpdate wraps a clientConn update received from gRPC (pushed from the // xdsResolver). A valid clusterName causes the cdsBalancer to register a CDS // watcher with the xdsClient, while a non-nil error causes it to cancel the -// existing watch and propagate the error to the underlying edsBalancer. +// existing watch and propagate the error to the underlying cluster_resolver +// balancer. type ccUpdate struct { clusterName string err error } // scUpdate wraps a subConn update received from gRPC. This is directly passed -// on to the edsBalancer. +// on to the cluster_resolver balancer. type scUpdate struct { subConn balancer.SubConn state balancer.SubConnState } -// watchUpdate wraps the information received from a registered CDS watcher. A -// non-nil error is propagated to the underlying edsBalancer. A valid update -// results in creating a new edsBalancer (if one doesn't already exist) and -// pushing the update to it. -type watchUpdate struct { - cds xdsclient.ClusterUpdate - err error -} +type exitIdle struct{} -// cdsBalancer implements a CDS based LB policy. It instantiates an EDS based -// LB policy to further resolve the serviceName received from CDS, into -// localities and endpoints. Implements the balancer.Balancer interface which -// is exposed to gRPC and implements the balancer.ClientConn interface which is -// exposed to the edsBalancer. +// cdsBalancer implements a CDS based LB policy. It instantiates a +// cluster_resolver balancer to further resolve the serviceName received from +// CDS, into localities and endpoints. Implements the balancer.Balancer +// interface which is exposed to gRPC and implements the balancer.ClientConn +// interface which is exposed to the cluster_resolver balancer. type cdsBalancer struct { ccw *ccWrapper // ClientConn interface passed to child LB. 
bOpts balancer.BuildOptions // BuildOptions passed to child LB. updateCh *buffer.Unbounded // Channel for gRPC and xdsClient updates. - xdsClient xdsClientInterface // xDS client to watch Cluster resource. - cancelWatch func() // Cluster watch cancel func. - edsLB balancer.Balancer // EDS child policy. - clusterToWatch string + xdsClient xdsclient.XDSClient // xDS client to watch Cluster resource. + clusterHandler *clusterHandler // To watch the clusters. + childLB balancer.Balancer logger *grpclog.PrefixLogger closed *grpcsync.Event + done *grpcsync.Event // The certificate providers are cached here to that they can be closed when // a new provider is to be created. @@ -193,25 +170,15 @@ type cdsBalancer struct { // handleClientConnUpdate handles a ClientConnUpdate received from gRPC. Good // updates lead to registration of a CDS watch. Updates with error lead to // cancellation of existing watch and propagation of the same error to the -// edsBalancer. +// cluster_resolver balancer. func (b *cdsBalancer) handleClientConnUpdate(update *ccUpdate) { // We first handle errors, if any, and then proceed with handling the // update, only if the status quo has changed. 
if err := update.err; err != nil { b.handleErrorFromUpdate(err, true) - } - if b.clusterToWatch == update.clusterName { return } - if update.clusterName != "" { - cancelWatch := b.xdsClient.WatchCluster(update.clusterName, b.handleClusterUpdate) - b.logger.Infof("Watch started on resource name %v with xds-client %p", update.clusterName, b.xdsClient) - b.cancelWatch = func() { - cancelWatch() - b.logger.Infof("Watch cancelled on resource name %v with xds-client %p", update.clusterName, b.xdsClient) - } - b.clusterToWatch = update.clusterName - } + b.clusterHandler.updateRootCluster(update.clusterName) } // handleSecurityConfig processes the security configuration received from the @@ -235,7 +202,7 @@ func (b *cdsBalancer) handleSecurityConfig(config *xdsclient.SecurityConfig) err // one where fallback credentials are to be used. b.xdsHI.SetRootCertProvider(nil) b.xdsHI.SetIdentityCertProvider(nil) - b.xdsHI.SetAcceptedSANs(nil) + b.xdsHI.SetSANMatchers(nil) return nil } @@ -278,7 +245,7 @@ func (b *cdsBalancer) handleSecurityConfig(config *xdsclient.SecurityConfig) err // could have been non-nil earlier. b.xdsHI.SetRootCertProvider(rootProvider) b.xdsHI.SetIdentityCertProvider(identityProvider) - b.xdsHI.SetAcceptedSANs(config.AcceptedSANs) + b.xdsHI.SetSANMatchers(config.SubjectAltNameMatchers) return nil } @@ -303,22 +270,22 @@ func buildProviderFunc(configs map[string]*certprovider.BuildableConfig, instanc } // handleWatchUpdate handles a watch update from the xDS Client. Good updates -// lead to clientConn updates being invoked on the underlying edsBalancer. -func (b *cdsBalancer) handleWatchUpdate(update *watchUpdate) { +// lead to clientConn updates being invoked on the underlying cluster_resolver balancer. 
+func (b *cdsBalancer) handleWatchUpdate(update clusterHandlerUpdate) { if err := update.err; err != nil { b.logger.Warningf("Watch error from xds-client %p: %v", b.xdsClient, err) b.handleErrorFromUpdate(err, false) return } - b.logger.Infof("Watch update from xds-client %p, content: %+v", b.xdsClient, update.cds) + b.logger.Infof("Watch update from xds-client %p, content: %+v, security config: %v", b.xdsClient, pretty.ToJSON(update.updates), pretty.ToJSON(update.securityCfg)) // Process the security config from the received update before building the // child policy or forwarding the update to it. We do this because the child // policy may try to create a new subConn inline. Processing the security // configuration here and setting up the handshakeInfo will make sure that // such attempts are handled properly. - if err := b.handleSecurityConfig(update.cds.SecurityCfg); err != nil { + if err := b.handleSecurityConfig(update.securityCfg); err != nil { // If the security config is invalid, for example, if the provider // instance is not found in the bootstrap config, we need to put the // channel in transient failure. @@ -328,33 +295,67 @@ func (b *cdsBalancer) handleWatchUpdate(update *watchUpdate) { } // The first good update from the watch API leads to the instantiation of an - // edsBalancer. Further updates/errors are propagated to the existing - // edsBalancer. - if b.edsLB == nil { - edsLB, err := newEDSBalancer(b.ccw, b.bOpts) + // cluster_resolver balancer. Further updates/errors are propagated to the existing + // cluster_resolver balancer. 
+ if b.childLB == nil { + childLB, err := newChildBalancer(b.ccw, b.bOpts) if err != nil { - b.logger.Errorf("Failed to create child policy of type %s, %v", edsName, err) + b.logger.Errorf("Failed to create child policy of type %s, %v", clusterresolver.Name, err) return } - b.edsLB = edsLB - b.logger.Infof("Created child policy %p of type %s", b.edsLB, edsName) + b.childLB = childLB + b.logger.Infof("Created child policy %p of type %s", b.childLB, clusterresolver.Name) + } + + dms := make([]clusterresolver.DiscoveryMechanism, len(update.updates)) + for i, cu := range update.updates { + switch cu.ClusterType { + case xdsclient.ClusterTypeEDS: + dms[i] = clusterresolver.DiscoveryMechanism{ + Type: clusterresolver.DiscoveryMechanismTypeEDS, + Cluster: cu.ClusterName, + EDSServiceName: cu.EDSServiceName, + MaxConcurrentRequests: cu.MaxRequests, + } + if cu.EnableLRS { + // An empty string here indicates that the cluster_resolver balancer should use the + // same xDS server for load reporting as it does for EDS + // requests/responses. + dms[i].LoadReportingServerName = new(string) + + } + case xdsclient.ClusterTypeLogicalDNS: + dms[i] = clusterresolver.DiscoveryMechanism{ + Type: clusterresolver.DiscoveryMechanismTypeLogicalDNS, + DNSHostname: cu.DNSHostName, + } + default: + b.logger.Infof("unexpected cluster type %v when handling update from cluster handler", cu.ClusterType) + } } - lbCfg := &edsbalancer.EDSConfig{ - EDSServiceName: update.cds.ServiceName, - MaxConcurrentRequests: update.cds.MaxRequests, + lbCfg := &clusterresolver.LBConfig{ + DiscoveryMechanisms: dms, } - if update.cds.EnableLRS { - // An empty string here indicates that the edsBalancer should use the - // same xDS server for load reporting as it does for EDS - // requests/responses. - lbCfg.LrsLoadReportingServerName = new(string) + // lbPolicy is set only when the policy is ringhash. The default (when it's + // not set) is roundrobin. 
And similarly, we only need to set XDSLBPolicy + // for ringhash (it also defaults to roundrobin). + if lbp := update.lbPolicy; lbp != nil { + lbCfg.XDSLBPolicy = &internalserviceconfig.BalancerConfig{ + Name: ringhash.Name, + Config: &ringhash.LBConfig{ + MinRingSize: lbp.MinimumRingSize, + MaxRingSize: lbp.MaximumRingSize, + }, + } } + ccState := balancer.ClientConnState{ + ResolverState: xdsclient.SetClient(resolver.State{}, b.xdsClient), BalancerConfig: lbCfg, } - if err := b.edsLB.UpdateClientConnState(ccState); err != nil { - b.logger.Errorf("xds: edsBalancer.UpdateClientConnState(%+v) returned error: %v", ccState, err) + if err := b.childLB.UpdateClientConnState(ccState); err != nil { + b.logger.Errorf("xds: cluster_resolver balancer.UpdateClientConnState(%+v) returned error: %v", ccState, err) } } @@ -371,28 +372,41 @@ func (b *cdsBalancer) run() { b.handleClientConnUpdate(update) case *scUpdate: // SubConn updates are passthrough and are simply handed over to - // the underlying edsBalancer. - if b.edsLB == nil { - b.logger.Errorf("xds: received scUpdate {%+v} with no edsBalancer", update) + // the underlying cluster_resolver balancer. + if b.childLB == nil { + b.logger.Errorf("xds: received scUpdate {%+v} with no cluster_resolver balancer", update) break } - b.edsLB.UpdateSubConnState(update.subConn, update.state) - case *watchUpdate: - b.handleWatchUpdate(update) + b.childLB.UpdateSubConnState(update.subConn, update.state) + case exitIdle: + if b.childLB == nil { + b.logger.Errorf("xds: received ExitIdle with no child balancer") + break + } + // This implementation assumes the child balancer supports + // ExitIdle (but still checks for the interface's existence to + // avoid a panic if not). If the child does not, no subconns + // will be connected. 
+ if ei, ok := b.childLB.(balancer.ExitIdler); ok { + ei.ExitIdle() + } } - - // Close results in cancellation of the CDS watch and closing of the - // underlying edsBalancer and is the only way to exit this goroutine. + case u := <-b.clusterHandler.updateChannel: + b.handleWatchUpdate(u) case <-b.closed.Done(): - b.cancelWatch() - b.cancelWatch = func() {} - - if b.edsLB != nil { - b.edsLB.Close() - b.edsLB = nil + b.clusterHandler.close() + if b.childLB != nil { + b.childLB.Close() + b.childLB = nil + } + if b.cachedRoot != nil { + b.cachedRoot.Close() + } + if b.cachedIdentity != nil { + b.cachedIdentity.Close() } - // This is the *ONLY* point of return from this function. b.logger.Infof("Shutdown") + b.done.Fire() return } } @@ -411,23 +425,22 @@ func (b *cdsBalancer) run() { // - If it's from xds client, it means CDS resource were removed. The CDS // watcher should keep watching. // -// In both cases, the error will be forwarded to EDS balancer. And if error is -// resource-not-found, the child EDS balancer will stop watching EDS. +// In both cases, the error will be forwarded to the child balancer. And if +// error is resource-not-found, the child balancer will stop watching EDS. func (b *cdsBalancer) handleErrorFromUpdate(err error, fromParent bool) { - // TODO: connection errors will be sent to the eds balancers directly, and - // also forwarded by the parent balancers/resolvers. So the eds balancer may - // see the same error multiple times. We way want to only forward the error - // to eds if it's not a connection error. - // // This is not necessary today, because xds client never sends connection // errors. if fromParent && xdsclient.ErrType(err) == xdsclient.ErrorTypeResourceNotFound { - b.cancelWatch() + b.clusterHandler.close() } - if b.edsLB != nil { - b.edsLB.ResolverError(err) + if b.childLB != nil { + if xdsclient.ErrType(err) != xdsclient.ErrorTypeConnection { + // Connection errors will be sent to the child balancers directly. 
+ // There's no need to forward them. + b.childLB.ResolverError(err) + } } else { - // If eds balancer was never created, fail the RPCs with + // If child balancer was never created, fail the RPCs with // errors. b.ccw.UpdateState(balancer.State{ ConnectivityState: connectivity.TransientFailure, @@ -436,16 +449,6 @@ func (b *cdsBalancer) handleErrorFromUpdate(err error, fromParent bool) { } } -// handleClusterUpdate is the CDS watch API callback. It simply pushes the -// received information on to the update channel for run() to pick it up. -func (b *cdsBalancer) handleClusterUpdate(cu xdsclient.ClusterUpdate, err error) { - if b.closed.HasFired() { - b.logger.Warningf("xds: received cluster update {%+v} after cdsBalancer was closed", cu) - return - } - b.updateCh.Put(&watchUpdate{cds: cu, err: err}) -} - // UpdateClientConnState receives the serviceConfig (which contains the // clusterName to watch for in CDS) and the xdsClient object from the // xdsResolver. @@ -455,7 +458,15 @@ func (b *cdsBalancer) UpdateClientConnState(state balancer.ClientConnState) erro return errBalancerClosed } - b.logger.Infof("Received update from resolver, balancer config: %+v", state.BalancerConfig) + if b.xdsClient == nil { + c := xdsclient.FromResolverState(state.ResolverState) + if c == nil { + return balancer.ErrBadResolverState + } + b.xdsClient = c + } + + b.logger.Infof("Received update from resolver, balancer config: %+v", pretty.ToJSON(state.BalancerConfig)) // The errors checked here should ideally never happen because the // ServiceConfig in this case is prepared by the xdsResolver and is not // something that is received on the wire. @@ -490,10 +501,15 @@ func (b *cdsBalancer) UpdateSubConnState(sc balancer.SubConn, state balancer.Sub b.updateCh.Put(&scUpdate{subConn: sc, state: state}) } -// Close closes the cdsBalancer and the underlying edsBalancer. +// Close cancels the CDS watch, closes the child policy and closes the +// cdsBalancer. 
func (b *cdsBalancer) Close() { b.closed.Fire() - b.xdsClient.Close() + <-b.done.Done() +} + +func (b *cdsBalancer) ExitIdle() { + b.updateCh.Put(exitIdle{}) } // ccWrapper wraps the balancer.ClientConn passed to the CDS balancer at diff --git a/xds/internal/balancer/cdsbalancer/cdsbalancer_security_test.go b/xds/internal/balancer/cdsbalancer/cdsbalancer_security_test.go index fee48c262eb..9483818e306 100644 --- a/xds/internal/balancer/cdsbalancer/cdsbalancer_security_test.go +++ b/xds/internal/balancer/cdsbalancer/cdsbalancer_security_test.go @@ -20,47 +20,63 @@ import ( "context" "errors" "fmt" + "regexp" "testing" + "github.com/google/go-cmp/cmp" "google.golang.org/grpc/attributes" "google.golang.org/grpc/balancer" "google.golang.org/grpc/credentials/local" "google.golang.org/grpc/credentials/tls/certprovider" "google.golang.org/grpc/credentials/xds" "google.golang.org/grpc/internal" - xdsinternal "google.golang.org/grpc/internal/credentials/xds" + xdscredsinternal "google.golang.org/grpc/internal/credentials/xds" "google.golang.org/grpc/internal/testutils" + "google.golang.org/grpc/internal/xds/matcher" "google.golang.org/grpc/resolver" - xdsclient "google.golang.org/grpc/xds/internal/client" - "google.golang.org/grpc/xds/internal/client/bootstrap" xdstestutils "google.golang.org/grpc/xds/internal/testutils" "google.golang.org/grpc/xds/internal/testutils/fakeclient" + "google.golang.org/grpc/xds/internal/xdsclient" + "google.golang.org/grpc/xds/internal/xdsclient/bootstrap" ) const ( fakeProvider1Name = "fake-certificate-provider-1" fakeProvider2Name = "fake-certificate-provider-2" fakeConfig = "my fake config" + testSAN = "test-san" ) var ( + testSANMatchers = []matcher.StringMatcher{ + matcher.StringMatcherForTesting(newStringP(testSAN), nil, nil, nil, nil, true), + matcher.StringMatcherForTesting(nil, newStringP(testSAN), nil, nil, nil, false), + matcher.StringMatcherForTesting(nil, nil, newStringP(testSAN), nil, nil, false), + 
matcher.StringMatcherForTesting(nil, nil, nil, nil, regexp.MustCompile(testSAN), false), + matcher.StringMatcherForTesting(nil, nil, nil, newStringP(testSAN), nil, false), + } fpb1, fpb2 *fakeProviderBuilder bootstrapConfig *bootstrap.Config cdsUpdateWithGoodSecurityCfg = xdsclient.ClusterUpdate{ - ServiceName: serviceName, + ClusterName: serviceName, SecurityCfg: &xdsclient.SecurityConfig{ - RootInstanceName: "default1", - IdentityInstanceName: "default2", + RootInstanceName: "default1", + IdentityInstanceName: "default2", + SubjectAltNameMatchers: testSANMatchers, }, } cdsUpdateWithMissingSecurityCfg = xdsclient.ClusterUpdate{ - ServiceName: serviceName, + ClusterName: serviceName, SecurityCfg: &xdsclient.SecurityConfig{ RootInstanceName: "not-default", }, } ) +func newStringP(s string) *string { + return &s +} + func init() { fpb1 = &fakeProviderBuilder{name: fakeProvider1Name} fpb2 = &fakeProviderBuilder{name: fakeProvider2Name} @@ -115,11 +131,7 @@ func (p *fakeProvider) Close() { // xDSCredentials. func setupWithXDSCreds(t *testing.T) (*fakeclient.Client, *cdsBalancer, *testEDSBalancer, *xdstestutils.TestClientConn, func()) { t.Helper() - xdsC := fakeclient.NewClient() - oldNewXDSClient := newXDSClient - newXDSClient = func() (xdsClientInterface, error) { return xdsC, nil } - builder := balancer.Get(cdsName) if builder == nil { t.Fatalf("balancer.Get(%q) returned nil", cdsName) @@ -139,14 +151,14 @@ func setupWithXDSCreds(t *testing.T) (*fakeclient.Client, *cdsBalancer, *testEDS // Override the creation of the EDS balancer to return a fake EDS balancer // implementation. 
edsB := newTestEDSBalancer() - oldEDSBalancerBuilder := newEDSBalancer - newEDSBalancer = func(cc balancer.ClientConn, opts balancer.BuildOptions) (balancer.Balancer, error) { + oldEDSBalancerBuilder := newChildBalancer + newChildBalancer = func(cc balancer.ClientConn, opts balancer.BuildOptions) (balancer.Balancer, error) { edsB.parentCC = cc return edsB, nil } // Push a ClientConnState update to the CDS balancer with a cluster name. - if err := cdsB.UpdateClientConnState(cdsCCS(clusterName)); err != nil { + if err := cdsB.UpdateClientConnState(cdsCCS(clusterName, xdsC)); err != nil { t.Fatalf("cdsBalancer.UpdateClientConnState failed with error: %v", err) } @@ -163,8 +175,8 @@ func setupWithXDSCreds(t *testing.T) (*fakeclient.Client, *cdsBalancer, *testEDS } return xdsC, cdsB.(*cdsBalancer), edsB, tcc, func() { - newXDSClient = oldNewXDSClient - newEDSBalancer = oldEDSBalancerBuilder + newChildBalancer = oldEDSBalancerBuilder + xdsC.Close() } } @@ -190,7 +202,7 @@ func makeNewSubConn(ctx context.Context, edsCC balancer.ClientConn, parentCC *xd if got, want := gotAddrs[0].Addr, addrs[0].Addr; got != want { return nil, fmt.Errorf("resolver.Address passed to parent ClientConn has address %q, want %q", got, want) } - getHI := internal.GetXDSHandshakeInfoForTesting.(func(attr *attributes.Attributes) *xdsinternal.HandshakeInfo) + getHI := internal.GetXDSHandshakeInfoForTesting.(func(attr *attributes.Attributes) *xdscredsinternal.HandshakeInfo) hi := getHI(gotAddrs[0].Attributes) if hi == nil { return nil, errors.New("resolver.Address passed to parent ClientConn doesn't contain attributes") @@ -198,6 +210,11 @@ func makeNewSubConn(ctx context.Context, edsCC balancer.ClientConn, parentCC *xd if gotFallback := hi.UseFallbackCreds(); gotFallback != wantFallback { return nil, fmt.Errorf("resolver.Address HandshakeInfo uses fallback creds? 
%v, want %v", gotFallback, wantFallback) } + if !wantFallback { + if diff := cmp.Diff(testSANMatchers, hi.GetSANMatchersForTesting(), cmp.AllowUnexported(regexp.Regexp{})); diff != "" { + return nil, fmt.Errorf("unexpected diff in the list of SAN matchers (-got, +want):\n%s", diff) + } + } } return sc, nil } @@ -232,9 +249,9 @@ func (s) TestSecurityConfigWithoutXDSCreds(t *testing.T) { // will trigger the watch handler on the CDS balancer, which will attempt to // create a new EDS balancer. The fake EDS balancer created above will be // returned to the CDS balancer, because we have overridden the - // newEDSBalancer function as part of test setup. - cdsUpdate := xdsclient.ClusterUpdate{ServiceName: serviceName} - wantCCS := edsCCS(serviceName, nil, false) + // newChildBalancer function as part of test setup. + cdsUpdate := xdsclient.ClusterUpdate{ClusterName: serviceName} + wantCCS := edsCCS(serviceName, nil, false, nil) ctx, ctxCancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer ctxCancel() if err := invokeWatchCbAndWait(ctx, xdsC, cdsWatchInfo{cdsUpdate, nil}, wantCCS, edsB); err != nil { @@ -287,10 +304,10 @@ func (s) TestNoSecurityConfigWithXDSCreds(t *testing.T) { // will trigger the watch handler on the CDS balancer, which will attempt to // create a new EDS balancer. The fake EDS balancer created above will be // returned to the CDS balancer, because we have overridden the - // newEDSBalancer function as part of test setup. No security config is + // newChildBalancer function as part of test setup. No security config is // passed to the CDS balancer as part of this update. 
- cdsUpdate := xdsclient.ClusterUpdate{ServiceName: serviceName} - wantCCS := edsCCS(serviceName, nil, false) + cdsUpdate := xdsclient.ClusterUpdate{ClusterName: serviceName} + wantCCS := edsCCS(serviceName, nil, false, nil) ctx, ctxCancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer ctxCancel() if err := invokeWatchCbAndWait(ctx, xdsC, cdsWatchInfo{cdsUpdate, nil}, wantCCS, edsB); err != nil { @@ -445,8 +462,8 @@ func (s) TestSecurityConfigUpdate_BadToGood(t *testing.T) { // will trigger the watch handler on the CDS balancer, which will attempt to // create a new EDS balancer. The fake EDS balancer created above will be // returned to the CDS balancer, because we have overridden the - // newEDSBalancer function as part of test setup. - wantCCS := edsCCS(serviceName, nil, false) + // newChildBalancer function as part of test setup. + wantCCS := edsCCS(serviceName, nil, false, nil) if err := invokeWatchCbAndWait(ctx, xdsC, cdsWatchInfo{cdsUpdateWithGoodSecurityCfg, nil}, wantCCS, edsB); err != nil { t.Fatal(err) } @@ -479,8 +496,8 @@ func (s) TestGoodSecurityConfig(t *testing.T) { // will trigger the watch handler on the CDS balancer, which will attempt to // create a new EDS balancer. The fake EDS balancer created above will be // returned to the CDS balancer, because we have overridden the - // newEDSBalancer function as part of test setup. - wantCCS := edsCCS(serviceName, nil, false) + // newChildBalancer function as part of test setup. 
+ wantCCS := edsCCS(serviceName, nil, false, nil) ctx, ctxCancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer ctxCancel() if err := invokeWatchCbAndWait(ctx, xdsC, cdsWatchInfo{cdsUpdateWithGoodSecurityCfg, nil}, wantCCS, edsB); err != nil { @@ -507,7 +524,7 @@ func (s) TestGoodSecurityConfig(t *testing.T) { if got, want := gotAddrs[0].Addr, addrs[0].Addr; got != want { t.Fatalf("resolver.Address passed to parent ClientConn through UpdateAddresses() has address %q, want %q", got, want) } - getHI := internal.GetXDSHandshakeInfoForTesting.(func(attr *attributes.Attributes) *xdsinternal.HandshakeInfo) + getHI := internal.GetXDSHandshakeInfoForTesting.(func(attr *attributes.Attributes) *xdscredsinternal.HandshakeInfo) hi := getHI(gotAddrs[0].Attributes) if hi == nil { t.Fatal("resolver.Address passed to parent ClientConn through UpdateAddresses() doesn't contain attributes") @@ -532,8 +549,8 @@ func (s) TestSecurityConfigUpdate_GoodToFallback(t *testing.T) { // will trigger the watch handler on the CDS balancer, which will attempt to // create a new EDS balancer. The fake EDS balancer created above will be // returned to the CDS balancer, because we have overridden the - // newEDSBalancer function as part of test setup. - wantCCS := edsCCS(serviceName, nil, false) + // newChildBalancer function as part of test setup. + wantCCS := edsCCS(serviceName, nil, false, nil) ctx, ctxCancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer ctxCancel() if err := invokeWatchCbAndWait(ctx, xdsC, cdsWatchInfo{cdsUpdateWithGoodSecurityCfg, nil}, wantCCS, edsB); err != nil { @@ -549,7 +566,7 @@ func (s) TestSecurityConfigUpdate_GoodToFallback(t *testing.T) { // an update which contains bad security config. So, we expect the CDS // balancer to forward this error to the EDS balancer and eventually the // channel needs to be put in a bad state. 
- cdsUpdate := xdsclient.ClusterUpdate{ServiceName: serviceName} + cdsUpdate := xdsclient.ClusterUpdate{ClusterName: serviceName} if err := invokeWatchCbAndWait(ctx, xdsC, cdsWatchInfo{cdsUpdate, nil}, wantCCS, edsB); err != nil { t.Fatal(err) } @@ -582,8 +599,8 @@ func (s) TestSecurityConfigUpdate_GoodToBad(t *testing.T) { // will trigger the watch handler on the CDS balancer, which will attempt to // create a new EDS balancer. The fake EDS balancer created above will be // returned to the CDS balancer, because we have overridden the - // newEDSBalancer function as part of test setup. - wantCCS := edsCCS(serviceName, nil, false) + // newChildBalancer function as part of test setup. + wantCCS := edsCCS(serviceName, nil, false, nil) ctx, ctxCancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer ctxCancel() if err := invokeWatchCbAndWait(ctx, xdsC, cdsWatchInfo{cdsUpdateWithGoodSecurityCfg, nil}, wantCCS, edsB); err != nil { @@ -617,7 +634,7 @@ func (s) TestSecurityConfigUpdate_GoodToBad(t *testing.T) { // registered watch should not be cancelled. sCtx, sCancel := context.WithTimeout(context.Background(), defaultTestShortTimeout) defer sCancel() - if err := xdsC.WaitForCancelClusterWatch(sCtx); err != context.DeadlineExceeded { + if _, err := xdsC.WaitForCancelClusterWatch(sCtx); err != context.DeadlineExceeded { t.Fatal("cluster watch cancelled for a non-resource-not-found-error") } } @@ -653,14 +670,15 @@ func (s) TestSecurityConfigUpdate_GoodToGood(t *testing.T) { // will trigger the watch handler on the CDS balancer, which will attempt to // create a new EDS balancer. The fake EDS balancer created above will be // returned to the CDS balancer, because we have overridden the - // newEDSBalancer function as part of test setup. + // newChildBalancer function as part of test setup. 
cdsUpdate := xdsclient.ClusterUpdate{ - ServiceName: serviceName, + ClusterName: serviceName, SecurityCfg: &xdsclient.SecurityConfig{ - RootInstanceName: "default1", + RootInstanceName: "default1", + SubjectAltNameMatchers: testSANMatchers, }, } - wantCCS := edsCCS(serviceName, nil, false) + wantCCS := edsCCS(serviceName, nil, false, nil) ctx, ctxCancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer ctxCancel() if err := invokeWatchCbAndWait(ctx, xdsC, cdsWatchInfo{cdsUpdate, nil}, wantCCS, edsB); err != nil { @@ -679,9 +697,10 @@ func (s) TestSecurityConfigUpdate_GoodToGood(t *testing.T) { // Push another update with a new security configuration. cdsUpdate = xdsclient.ClusterUpdate{ - ServiceName: serviceName, + ClusterName: serviceName, SecurityCfg: &xdsclient.SecurityConfig{ - RootInstanceName: "default2", + RootInstanceName: "default2", + SubjectAltNameMatchers: testSANMatchers, }, } if err := invokeWatchCbAndWait(ctx, xdsC, cdsWatchInfo{cdsUpdate, nil}, wantCCS, edsB); err != nil { diff --git a/xds/internal/balancer/cdsbalancer/cdsbalancer_test.go b/xds/internal/balancer/cdsbalancer/cdsbalancer_test.go index 9c7bc2362ab..30b612fc7d0 100644 --- a/xds/internal/balancer/cdsbalancer/cdsbalancer_test.go +++ b/xds/internal/balancer/cdsbalancer/cdsbalancer_test.go @@ -26,19 +26,19 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" - "google.golang.org/grpc/attributes" "google.golang.org/grpc/balancer" "google.golang.org/grpc/connectivity" "google.golang.org/grpc/internal" "google.golang.org/grpc/internal/grpctest" + internalserviceconfig "google.golang.org/grpc/internal/serviceconfig" "google.golang.org/grpc/internal/testutils" "google.golang.org/grpc/resolver" "google.golang.org/grpc/serviceconfig" - "google.golang.org/grpc/xds/internal/balancer/edsbalancer" - "google.golang.org/grpc/xds/internal/client" - xdsclient "google.golang.org/grpc/xds/internal/client" + 
"google.golang.org/grpc/xds/internal/balancer/clusterresolver" + "google.golang.org/grpc/xds/internal/balancer/ringhash" xdstestutils "google.golang.org/grpc/xds/internal/testutils" "google.golang.org/grpc/xds/internal/testutils/fakeclient" + "google.golang.org/grpc/xds/internal/xdsclient" ) const ( @@ -83,7 +83,8 @@ type testEDSBalancer struct { // resolverErrCh is a channel used to signal a resolver error. resolverErrCh *testutils.Channel // closeCh is a channel used to signal the closing of this balancer. - closeCh *testutils.Channel + closeCh *testutils.Channel + exitIdleCh *testutils.Channel // parentCC is the balancer.ClientConn passed to this test balancer as part // of the Build() call. parentCC balancer.ClientConn @@ -100,6 +101,7 @@ func newTestEDSBalancer() *testEDSBalancer { scStateCh: testutils.NewChannel(), resolverErrCh: testutils.NewChannel(), closeCh: testutils.NewChannel(), + exitIdleCh: testutils.NewChannel(), } } @@ -120,6 +122,10 @@ func (tb *testEDSBalancer) Close() { tb.closeCh.Send(struct{}{}) } +func (tb *testEDSBalancer) ExitIdle() { + tb.exitIdleCh.Send(struct{}{}) +} + // waitForClientConnUpdate verifies if the testEDSBalancer receives the // provided ClientConnState within a reasonable amount of time. 
func (tb *testEDSBalancer) waitForClientConnUpdate(ctx context.Context, wantCCS balancer.ClientConnState) error { @@ -128,8 +134,11 @@ func (tb *testEDSBalancer) waitForClientConnUpdate(ctx context.Context, wantCCS return err } gotCCS := ccs.(balancer.ClientConnState) - if !cmp.Equal(gotCCS, wantCCS, cmpopts.IgnoreUnexported(attributes.Attributes{})) { - return fmt.Errorf("received ClientConnState: %+v, want %+v", gotCCS, wantCCS) + if xdsclient.FromResolverState(gotCCS.ResolverState) == nil { + return fmt.Errorf("want resolver state with XDSClient attached, got one without") + } + if diff := cmp.Diff(gotCCS, wantCCS, cmpopts.IgnoreFields(resolver.State{}, "Attributes")); diff != "" { + return fmt.Errorf("received unexpected ClientConnState, diff (-got +want): %v", diff) } return nil } @@ -172,7 +181,7 @@ func (tb *testEDSBalancer) waitForClose(ctx context.Context) error { // cdsCCS is a helper function to construct a good update passed from the // xdsResolver to the cdsBalancer. -func cdsCCS(cluster string) balancer.ClientConnState { +func cdsCCS(cluster string, xdsC xdsclient.XDSClient) balancer.ClientConnState { const cdsLBConfig = `{ "loadBalancingConfig":[ { @@ -184,37 +193,40 @@ func cdsCCS(cluster string) balancer.ClientConnState { }` jsonSC := fmt.Sprintf(cdsLBConfig, cluster) return balancer.ClientConnState{ - ResolverState: resolver.State{ + ResolverState: xdsclient.SetClient(resolver.State{ ServiceConfig: internal.ParseServiceConfigForTesting.(func(string) *serviceconfig.ParseResult)(jsonSC), - }, + }, xdsC), BalancerConfig: &lbConfig{ClusterName: clusterName}, } } // edsCCS is a helper function to construct a good update passed from the // cdsBalancer to the edsBalancer. 
-func edsCCS(service string, countMax *uint32, enableLRS bool) balancer.ClientConnState { - lbCfg := &edsbalancer.EDSConfig{ - EDSServiceName: service, +func edsCCS(service string, countMax *uint32, enableLRS bool, xdslbpolicy *internalserviceconfig.BalancerConfig) balancer.ClientConnState { + discoveryMechanism := clusterresolver.DiscoveryMechanism{ + Type: clusterresolver.DiscoveryMechanismTypeEDS, + Cluster: service, MaxConcurrentRequests: countMax, } if enableLRS { - lbCfg.LrsLoadReportingServerName = new(string) + discoveryMechanism.LoadReportingServerName = new(string) + } + lbCfg := &clusterresolver.LBConfig{ + DiscoveryMechanisms: []clusterresolver.DiscoveryMechanism{discoveryMechanism}, + XDSLBPolicy: xdslbpolicy, + } + return balancer.ClientConnState{ BalancerConfig: lbCfg, } } // setup creates a cdsBalancer and an edsBalancer (and overrides the -// newEDSBalancer function to return it), and also returns a cleanup function. +// newChildBalancer function to return it), and also returns a cleanup function. 
func setup(t *testing.T) (*fakeclient.Client, *cdsBalancer, *testEDSBalancer, *xdstestutils.TestClientConn, func()) { t.Helper() - xdsC := fakeclient.NewClient() - oldNewXDSClient := newXDSClient - newXDSClient = func() (xdsClientInterface, error) { return xdsC, nil } - builder := balancer.Get(cdsName) if builder == nil { t.Fatalf("balancer.Get(%q) returned nil", cdsName) @@ -223,15 +235,15 @@ func setup(t *testing.T) (*fakeclient.Client, *cdsBalancer, *testEDSBalancer, *x cdsB := builder.Build(tcc, balancer.BuildOptions{}) edsB := newTestEDSBalancer() - oldEDSBalancerBuilder := newEDSBalancer - newEDSBalancer = func(cc balancer.ClientConn, opts balancer.BuildOptions) (balancer.Balancer, error) { + oldEDSBalancerBuilder := newChildBalancer + newChildBalancer = func(cc balancer.ClientConn, opts balancer.BuildOptions) (balancer.Balancer, error) { edsB.parentCC = cc return edsB, nil } return xdsC, cdsB.(*cdsBalancer), edsB, tcc, func() { - newEDSBalancer = oldEDSBalancerBuilder - newXDSClient = oldNewXDSClient + newChildBalancer = oldEDSBalancerBuilder + xdsC.Close() } } @@ -241,7 +253,7 @@ func setupWithWatch(t *testing.T) (*fakeclient.Client, *cdsBalancer, *testEDSBal t.Helper() xdsC, cdsB, edsB, tcc, cancel := setup(t) - if err := cdsB.UpdateClientConnState(cdsCCS(clusterName)); err != nil { + if err := cdsB.UpdateClientConnState(cdsCCS(clusterName, xdsC)); err != nil { t.Fatalf("cdsBalancer.UpdateClientConnState failed with error: %v", err) } @@ -261,6 +273,9 @@ func setupWithWatch(t *testing.T) (*fakeclient.Client, *cdsBalancer, *testEDSBal // cdsBalancer with different inputs and verifies that the CDS watch API on the // provided xdsClient is invoked appropriately. 
func (s) TestUpdateClientConnState(t *testing.T) { + xdsC := fakeclient.NewClient() + defer xdsC.Close() + tests := []struct { name string ccs balancer.ClientConnState @@ -279,14 +294,14 @@ func (s) TestUpdateClientConnState(t *testing.T) { }, { name: "happy-good-case", - ccs: cdsCCS(clusterName), + ccs: cdsCCS(clusterName, xdsC), wantCluster: clusterName, }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - xdsC, cdsB, _, _, cancel := setup(t) + _, cdsB, _, _, cancel := setup(t) defer func() { cancel() cdsB.Close() @@ -323,7 +338,7 @@ func (s) TestUpdateClientConnStateWithSameState(t *testing.T) { }() // This is the same clientConn update sent in setupWithWatch(). - if err := cdsB.UpdateClientConnState(cdsCCS(clusterName)); err != nil { + if err := cdsB.UpdateClientConnState(cdsCCS(clusterName, xdsC)); err != nil { t.Fatalf("cdsBalancer.UpdateClientConnState failed with error: %v", err) } // The above update should not result in a new watch being registered. @@ -352,13 +367,24 @@ func (s) TestHandleClusterUpdate(t *testing.T) { }{ { name: "happy-case-with-lrs", - cdsUpdate: xdsclient.ClusterUpdate{ServiceName: serviceName, EnableLRS: true}, - wantCCS: edsCCS(serviceName, nil, true), + cdsUpdate: xdsclient.ClusterUpdate{ClusterName: serviceName, EnableLRS: true}, + wantCCS: edsCCS(serviceName, nil, true, nil), }, { name: "happy-case-without-lrs", - cdsUpdate: xdsclient.ClusterUpdate{ServiceName: serviceName}, - wantCCS: edsCCS(serviceName, nil, false), + cdsUpdate: xdsclient.ClusterUpdate{ClusterName: serviceName}, + wantCCS: edsCCS(serviceName, nil, false, nil), + }, + { + name: "happy-case-with-ring-hash-lb-policy", + cdsUpdate: xdsclient.ClusterUpdate{ + ClusterName: serviceName, + LBPolicy: &xdsclient.ClusterLBPolicyRingHash{MinimumRingSize: 10, MaximumRingSize: 100}, + }, + wantCCS: edsCCS(serviceName, nil, false, &internalserviceconfig.BalancerConfig{ + Name: ringhash.Name, + Config: &ringhash.LBConfig{MinRingSize: 10, MaxRingSize: 100}, 
+ }), }, } @@ -397,7 +423,7 @@ func (s) TestHandleClusterUpdateError(t *testing.T) { // registered watch should not be cancelled. sCtx, sCancel := context.WithTimeout(context.Background(), defaultTestShortTimeout) defer sCancel() - if err := xdsC.WaitForCancelClusterWatch(sCtx); err != context.DeadlineExceeded { + if _, err := xdsC.WaitForCancelClusterWatch(sCtx); err != context.DeadlineExceeded { t.Fatal("cluster watch cancelled for a non-resource-not-found-error") } // The CDS balancer has not yet created an EDS balancer. So, this resolver @@ -424,9 +450,9 @@ func (s) TestHandleClusterUpdateError(t *testing.T) { // will trigger the watch handler on the CDS balancer, which will attempt to // create a new EDS balancer. The fake EDS balancer created above will be // returned to the CDS balancer, because we have overridden the - // newEDSBalancer function as part of test setup. - cdsUpdate := xdsclient.ClusterUpdate{ServiceName: serviceName} - wantCCS := edsCCS(serviceName, nil, false) + // newChildBalancer function as part of test setup. + cdsUpdate := xdsclient.ClusterUpdate{ClusterName: serviceName} + wantCCS := edsCCS(serviceName, nil, false, nil) if err := invokeWatchCbAndWait(ctx, xdsC, cdsWatchInfo{cdsUpdate, nil}, wantCCS, edsB); err != nil { t.Fatal(err) } @@ -436,7 +462,7 @@ func (s) TestHandleClusterUpdateError(t *testing.T) { // Make sure the registered watch is not cancelled. sCtx, sCancel = context.WithTimeout(context.Background(), defaultTestShortTimeout) defer sCancel() - if err := xdsC.WaitForCancelClusterWatch(sCtx); err != context.DeadlineExceeded { + if _, err := xdsC.WaitForCancelClusterWatch(sCtx); err != context.DeadlineExceeded { t.Fatal("cluster watch cancelled for a non-resource-not-found-error") } // Make sure the error is forwarded to the EDS balancer. @@ -451,7 +477,7 @@ func (s) TestHandleClusterUpdateError(t *testing.T) { // request cluster resource is not found. We should continue to watch it. 
sCtx, sCancel = context.WithTimeout(context.Background(), defaultTestShortTimeout) defer sCancel() - if err := xdsC.WaitForCancelClusterWatch(sCtx); err != context.DeadlineExceeded { + if _, err := xdsC.WaitForCancelClusterWatch(sCtx); err != context.DeadlineExceeded { t.Fatal("cluster watch cancelled for a resource-not-found-error") } // Make sure the error is forwarded to the EDS balancer. @@ -483,7 +509,7 @@ func (s) TestResolverError(t *testing.T) { // registered watch should not be cancelled. sCtx, sCancel := context.WithTimeout(context.Background(), defaultTestShortTimeout) defer sCancel() - if err := xdsC.WaitForCancelClusterWatch(sCtx); err != context.DeadlineExceeded { + if _, err := xdsC.WaitForCancelClusterWatch(sCtx); err != context.DeadlineExceeded { t.Fatal("cluster watch cancelled for a non-resource-not-found-error") } // The CDS balancer has not yet created an EDS balancer. So, this resolver @@ -509,9 +535,9 @@ func (s) TestResolverError(t *testing.T) { // will trigger the watch handler on the CDS balancer, which will attempt to // create a new EDS balancer. The fake EDS balancer created above will be // returned to the CDS balancer, because we have overridden the - // newEDSBalancer function as part of test setup. - cdsUpdate := xdsclient.ClusterUpdate{ServiceName: serviceName} - wantCCS := edsCCS(serviceName, nil, false) + // newChildBalancer function as part of test setup. + cdsUpdate := xdsclient.ClusterUpdate{ClusterName: serviceName} + wantCCS := edsCCS(serviceName, nil, false, nil) if err := invokeWatchCbAndWait(ctx, xdsC, cdsWatchInfo{cdsUpdate, nil}, wantCCS, edsB); err != nil { t.Fatal(err) } @@ -521,7 +547,7 @@ func (s) TestResolverError(t *testing.T) { // Make sure the registered watch is not cancelled. 
sCtx, sCancel = context.WithTimeout(context.Background(), defaultTestShortTimeout) defer sCancel() - if err := xdsC.WaitForCancelClusterWatch(sCtx); err != context.DeadlineExceeded { + if _, err := xdsC.WaitForCancelClusterWatch(sCtx); err != context.DeadlineExceeded { t.Fatal("cluster watch cancelled for a non-resource-not-found-error") } // Make sure the error is forwarded to the EDS balancer. @@ -533,7 +559,7 @@ func (s) TestResolverError(t *testing.T) { resourceErr := xdsclient.NewErrorf(xdsclient.ErrorTypeResourceNotFound, "cdsBalancer resource not found error") cdsB.ResolverError(resourceErr) // Make sure the registered watch is cancelled. - if err := xdsC.WaitForCancelClusterWatch(ctx); err != nil { + if _, err := xdsC.WaitForCancelClusterWatch(ctx); err != nil { t.Fatalf("want watch to be canceled, watchForCancel failed: %v", err) } // Make sure the error is forwarded to the EDS balancer. @@ -558,9 +584,9 @@ func (s) TestUpdateSubConnState(t *testing.T) { // will trigger the watch handler on the CDS balancer, which will attempt to // create a new EDS balancer. The fake EDS balancer created above will be // returned to the CDS balancer, because we have overridden the - // newEDSBalancer function as part of test setup. - cdsUpdate := xdsclient.ClusterUpdate{ServiceName: serviceName} - wantCCS := edsCCS(serviceName, nil, false) + // newChildBalancer function as part of test setup. + cdsUpdate := xdsclient.ClusterUpdate{ClusterName: serviceName} + wantCCS := edsCCS(serviceName, nil, false, nil) ctx, ctxCancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer ctxCancel() if err := invokeWatchCbAndWait(ctx, xdsC, cdsWatchInfo{cdsUpdate, nil}, wantCCS, edsB); err != nil { @@ -594,8 +620,8 @@ func (s) TestCircuitBreaking(t *testing.T) { // will trigger the watch handler on the CDS balancer, which will update // the service's counter with the new max requests. 
var maxRequests uint32 = 1 - cdsUpdate := xdsclient.ClusterUpdate{ServiceName: serviceName, MaxRequests: &maxRequests} - wantCCS := edsCCS(serviceName, &maxRequests, false) + cdsUpdate := xdsclient.ClusterUpdate{ClusterName: clusterName, MaxRequests: &maxRequests} + wantCCS := edsCCS(clusterName, &maxRequests, false, nil) ctx, ctxCancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer ctxCancel() if err := invokeWatchCbAndWait(ctx, xdsC, cdsWatchInfo{cdsUpdate, nil}, wantCCS, edsB); err != nil { @@ -604,7 +630,7 @@ func (s) TestCircuitBreaking(t *testing.T) { // Since the counter's max requests was set to 1, the first request should // succeed and the second should fail. - counter := client.GetServiceRequestsCounter(serviceName) + counter := xdsclient.GetClusterRequestsCounter(clusterName, "") if err := counter.StartRequest(maxRequests); err != nil { t.Fatal(err) } @@ -626,9 +652,9 @@ func (s) TestClose(t *testing.T) { // will trigger the watch handler on the CDS balancer, which will attempt to // create a new EDS balancer. The fake EDS balancer created above will be // returned to the CDS balancer, because we have overridden the - // newEDSBalancer function as part of test setup. - cdsUpdate := xdsclient.ClusterUpdate{ServiceName: serviceName} - wantCCS := edsCCS(serviceName, nil, false) + // newChildBalancer function as part of test setup. + cdsUpdate := xdsclient.ClusterUpdate{ClusterName: serviceName} + wantCCS := edsCCS(serviceName, nil, false, nil) ctx, ctxCancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer ctxCancel() if err := invokeWatchCbAndWait(ctx, xdsC, cdsWatchInfo{cdsUpdate, nil}, wantCCS, edsB); err != nil { @@ -640,7 +666,7 @@ func (s) TestClose(t *testing.T) { // Make sure that the cluster watch registered by the CDS balancer is // cancelled. 
- if err := xdsC.WaitForCancelClusterWatch(ctx); err != nil { + if _, err := xdsC.WaitForCancelClusterWatch(ctx); err != nil { t.Fatal(err) } @@ -659,7 +685,7 @@ func (s) TestClose(t *testing.T) { // Make sure that the UpdateClientConnState() method on the CDS balancer // returns error. - if err := cdsB.UpdateClientConnState(cdsCCS(clusterName)); err != errBalancerClosed { + if err := cdsB.UpdateClientConnState(cdsCCS(clusterName, xdsC)); err != errBalancerClosed { t.Fatalf("UpdateClientConnState() after close returned %v, want %v", err, errBalancerClosed) } @@ -683,6 +709,35 @@ func (s) TestClose(t *testing.T) { } } +func (s) TestExitIdle(t *testing.T) { + // This creates a CDS balancer, pushes a ClientConnState update with a fake + // xdsClient, and makes sure that the CDS balancer registers a watch on the + // provided xdsClient. + xdsC, cdsB, edsB, _, cancel := setupWithWatch(t) + defer func() { + cancel() + cdsB.Close() + }() + + // Here we invoke the watch callback registered on the fake xdsClient. This + // will trigger the watch handler on the CDS balancer, which will attempt to + // create a new EDS balancer. The fake EDS balancer created above will be + // returned to the CDS balancer, because we have overridden the + // newChildBalancer function as part of test setup. + cdsUpdate := xdsclient.ClusterUpdate{ClusterName: serviceName} + wantCCS := edsCCS(serviceName, nil, false, nil) + ctx, ctxCancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer ctxCancel() + if err := invokeWatchCbAndWait(ctx, xdsC, cdsWatchInfo{cdsUpdate, nil}, wantCCS, edsB); err != nil { + t.Fatal(err) + } + + // Call ExitIdle on the CDS balancer. + cdsB.ExitIdle() + + edsB.exitIdleCh.Receive(ctx) +} + // TestParseConfig verifies the ParseConfig() method in the CDS balancer. 
func (s) TestParseConfig(t *testing.T) { bb := balancer.Get(cdsName) diff --git a/xds/internal/balancer/cdsbalancer/cluster_handler.go b/xds/internal/balancer/cdsbalancer/cluster_handler.go new file mode 100644 index 00000000000..163a8c0a2e1 --- /dev/null +++ b/xds/internal/balancer/cdsbalancer/cluster_handler.go @@ -0,0 +1,318 @@ +/* + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package cdsbalancer + +import ( + "errors" + "sync" + + "google.golang.org/grpc/xds/internal/xdsclient" +) + +var errNotReceivedUpdate = errors.New("tried to construct a cluster update on a cluster that has not received an update") + +// clusterHandlerUpdate wraps the information received from the registered CDS +// watcher. A non-nil error is propagated to the underlying cluster_resolver +// balancer. A valid update results in creating a new cluster_resolver balancer +// (if one doesn't already exist) and pushing the update to it. +type clusterHandlerUpdate struct { + // securityCfg is the Security Config from the top (root) cluster. + securityCfg *xdsclient.SecurityConfig + // lbPolicy is the lb policy from the top (root) cluster. + // + // Currently, we only support roundrobin or ringhash, and since roundrobin + // does need configs, this is only set to the ringhash config, if the policy + // is ringhash. In the future, if we support more policies, we can make this + // an interface, and set it to config of the other policies. 
+ lbPolicy *xdsclient.ClusterLBPolicyRingHash + + // updates is a list of ClusterUpdates from all the leaf clusters. + updates []xdsclient.ClusterUpdate + err error +} + +// clusterHandler will be given a name representing a cluster. It will then +// update the CDS policy constantly with a list of Clusters to pass down to +// XdsClusterResolverLoadBalancingPolicyConfig in a stream like fashion. +type clusterHandler struct { + parent *cdsBalancer + + // A mutex to protect entire tree of clusters. + clusterMutex sync.Mutex + root *clusterNode + rootClusterName string + + // A way to ping CDS Balancer about any updates or errors to a Node in the + // tree. This will either get called from this handler constructing an + // update or from a child with an error. Capacity of one as the only update + // CDS Balancer cares about is the most recent update. + updateChannel chan clusterHandlerUpdate +} + +func newClusterHandler(parent *cdsBalancer) *clusterHandler { + return &clusterHandler{ + parent: parent, + updateChannel: make(chan clusterHandlerUpdate, 1), + } +} + +func (ch *clusterHandler) updateRootCluster(rootClusterName string) { + ch.clusterMutex.Lock() + defer ch.clusterMutex.Unlock() + if ch.root == nil { + // Construct a root node on first update. + ch.root = createClusterNode(rootClusterName, ch.parent.xdsClient, ch) + ch.rootClusterName = rootClusterName + return + } + // Check if root cluster was changed. If it was, delete old one and start + // new one, if not do nothing. + if rootClusterName != ch.rootClusterName { + ch.root.delete() + ch.root = createClusterNode(rootClusterName, ch.parent.xdsClient, ch) + ch.rootClusterName = rootClusterName + } +} + +// This function tries to construct a cluster update to send to CDS. +func (ch *clusterHandler) constructClusterUpdate() { + if ch.root == nil { + // If root is nil, this handler is closed, ignore the update. 
+ return + } + clusterUpdate, err := ch.root.constructClusterUpdate() + if err != nil { + // If there was an error received no op, as this simply means one of the + // children hasn't received an update yet. + return + } + // For a ClusterUpdate, the only update CDS cares about is the most + // recent one, so opportunistically drain the update channel before + // sending the new update. + select { + case <-ch.updateChannel: + default: + } + ch.updateChannel <- clusterHandlerUpdate{ + securityCfg: ch.root.clusterUpdate.SecurityCfg, + lbPolicy: ch.root.clusterUpdate.LBPolicy, + updates: clusterUpdate, + } +} + +// close() is meant to be called by CDS when the CDS balancer is closed, and it +// cancels the watches for every cluster in the cluster tree. +func (ch *clusterHandler) close() { + ch.clusterMutex.Lock() + defer ch.clusterMutex.Unlock() + if ch.root == nil { + return + } + ch.root.delete() + ch.root = nil + ch.rootClusterName = "" +} + +// This logically represents a cluster. This handles all the logic for starting +// and stopping a cluster watch, handling any updates, and constructing a list +// recursively for the ClusterHandler. +type clusterNode struct { + // A way to cancel the watch for the cluster. + cancelFunc func() + + // A list of children, as the Node can be an aggregate Cluster. + children []*clusterNode + + // A ClusterUpdate in order to build a list of cluster updates for CDS to + // send down to child XdsClusterResolverLoadBalancingPolicy. + clusterUpdate xdsclient.ClusterUpdate + + // This boolean determines whether this Node has received an update or not. + // This isn't the best practice, but this will protect a list of Cluster + // Updates from being constructed if a cluster in the tree has not received + // an update yet. + receivedUpdate bool + + clusterHandler *clusterHandler +} + +// CreateClusterNode creates a cluster node from a given clusterName. This will +// also start the watch for that cluster. 
+func createClusterNode(clusterName string, xdsClient xdsclient.XDSClient, topLevelHandler *clusterHandler) *clusterNode { + c := &clusterNode{ + clusterHandler: topLevelHandler, + } + // Communicate with the xds client here. + topLevelHandler.parent.logger.Infof("CDS watch started on %v", clusterName) + cancel := xdsClient.WatchCluster(clusterName, c.handleResp) + c.cancelFunc = func() { + topLevelHandler.parent.logger.Infof("CDS watch canceled on %v", clusterName) + cancel() + } + return c +} + +// This function cancels the cluster watch on the cluster and all of it's +// children. +func (c *clusterNode) delete() { + c.cancelFunc() + for _, child := range c.children { + child.delete() + } +} + +// Construct cluster update (potentially a list of ClusterUpdates) for a node. +func (c *clusterNode) constructClusterUpdate() ([]xdsclient.ClusterUpdate, error) { + // If the cluster has not yet received an update, the cluster update is not + // yet ready. + if !c.receivedUpdate { + return nil, errNotReceivedUpdate + } + + // Base case - LogicalDNS or EDS. Both of these cluster types will be tied + // to a single ClusterUpdate. + if c.clusterUpdate.ClusterType != xdsclient.ClusterTypeAggregate { + return []xdsclient.ClusterUpdate{c.clusterUpdate}, nil + } + + // If an aggregate construct a list by recursively calling down to all of + // it's children. + var childrenUpdates []xdsclient.ClusterUpdate + for _, child := range c.children { + childUpdateList, err := child.constructClusterUpdate() + if err != nil { + return nil, err + } + childrenUpdates = append(childrenUpdates, childUpdateList...) + } + return childrenUpdates, nil +} + +// handleResp handles a xds response for a particular cluster. This function +// also handles any logic with regards to any child state that may have changed. +// At the end of the handleResp(), the clusterUpdate will be pinged in certain +// situations to try and construct an update to send back to CDS. 
+func (c *clusterNode) handleResp(clusterUpdate xdsclient.ClusterUpdate, err error) { + c.clusterHandler.clusterMutex.Lock() + defer c.clusterHandler.clusterMutex.Unlock() + if err != nil { // Write this error for run() to pick up in CDS LB policy. + // For a ClusterUpdate, the only update CDS cares about is the most + // recent one, so opportunistically drain the update channel before + // sending the new update. + select { + case <-c.clusterHandler.updateChannel: + default: + } + c.clusterHandler.updateChannel <- clusterHandlerUpdate{err: err} + return + } + + c.receivedUpdate = true + c.clusterUpdate = clusterUpdate + + // If the cluster was a leaf node, if the cluster update received had change + // in the cluster update then the overall cluster update would change and + // there is a possibility for the overall update to build so ping cluster + // handler to return. Also, if there was any children from previously, + // delete the children, as the cluster type is no longer an aggregate + // cluster. + if clusterUpdate.ClusterType != xdsclient.ClusterTypeAggregate { + for _, child := range c.children { + child.delete() + } + c.children = nil + // This is an update in the one leaf node, should try to send an update + // to the parent CDS balancer. + // + // Note that this update might be a duplicate from the previous one. + // Because the update contains not only the cluster name to watch, but + // also the extra fields (e.g. security config). There's no good way to + // compare all the fields. + c.clusterHandler.constructClusterUpdate() + return + } + + // Aggregate cluster handling. + newChildren := make(map[string]bool) + for _, childName := range clusterUpdate.PrioritizedClusterNames { + newChildren[childName] = true + } + + // These booleans help determine whether this callback will ping the overall + // clusterHandler to try and construct an update to send back to CDS. 
This + // will be determined by whether there would be a change in the overall + // clusterUpdate for the whole tree (ex. change in clusterUpdate for current + // cluster or a deleted child) and also if there's even a possibility for + // the update to build (ex. if a child is created and a watch is started, + // that child hasn't received an update yet due to the mutex lock on this + // callback). + var createdChild, deletedChild bool + + // This map will represent the current children of the cluster. It will be + // first added to in order to represent the new children. It will then have + // any children deleted that are no longer present. Then, from the cluster + // update received, will be used to construct the new child list. + mapCurrentChildren := make(map[string]*clusterNode) + for _, child := range c.children { + mapCurrentChildren[child.clusterUpdate.ClusterName] = child + } + + // Add and construct any new child nodes. + for child := range newChildren { + if _, inChildrenAlready := mapCurrentChildren[child]; !inChildrenAlready { + createdChild = true + mapCurrentChildren[child] = createClusterNode(child, c.clusterHandler.parent.xdsClient, c.clusterHandler) + } + } + + // Delete any child nodes no longer in the aggregate cluster's children. + for child := range mapCurrentChildren { + if _, stillAChild := newChildren[child]; !stillAChild { + deletedChild = true + mapCurrentChildren[child].delete() + delete(mapCurrentChildren, child) + } + } + + // The order of the children list matters, so use the clusterUpdate from + // xdsclient as the ordering, and use that logical ordering for the new + // children list. This will be a mixture of child nodes which are all + // already constructed in the mapCurrentChildrenMap. 
+ var children = make([]*clusterNode, 0, len(clusterUpdate.PrioritizedClusterNames)) + + for _, orderedChild := range clusterUpdate.PrioritizedClusterNames { + // The cluster's already have watches started for them in xds client, so + // you can use these pointers to construct the new children list, you + // just have to put them in the correct order using the original cluster + // update. + currentChild := mapCurrentChildren[orderedChild] + children = append(children, currentChild) + } + + c.children = children + + // If the cluster is an aggregate cluster, if this callback created any new + // child cluster nodes, then there's no possibility for a full cluster + // update to successfully build, as those created children will not have + // received an update yet. However, if there was simply a child deleted, + // then there is a possibility that it will have a full cluster update to + // build and also will have a changed overall cluster update from the + // deleted child. + if deletedChild && !createdChild { + c.clusterHandler.constructClusterUpdate() + } +} diff --git a/xds/internal/balancer/cdsbalancer/cluster_handler_test.go b/xds/internal/balancer/cdsbalancer/cluster_handler_test.go new file mode 100644 index 00000000000..cb9b4e14da3 --- /dev/null +++ b/xds/internal/balancer/cdsbalancer/cluster_handler_test.go @@ -0,0 +1,685 @@ +/* + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package cdsbalancer + +import ( + "context" + "errors" + "testing" + + "github.com/google/go-cmp/cmp" + "google.golang.org/grpc/xds/internal/testutils/fakeclient" + "google.golang.org/grpc/xds/internal/xdsclient" +) + +const ( + edsService = "EDS Service" + logicalDNSService = "Logical DNS Service" + edsService2 = "EDS Service 2" + logicalDNSService2 = "Logical DNS Service 2" + aggregateClusterService = "Aggregate Cluster Service" +) + +// setupTests creates a clusterHandler with a fake xds client for control over +// xds client. +func setupTests(t *testing.T) (*clusterHandler, *fakeclient.Client) { + xdsC := fakeclient.NewClient() + ch := newClusterHandler(&cdsBalancer{xdsClient: xdsC}) + return ch, xdsC +} + +// Simplest case: the cluster handler receives a cluster name, handler starts a +// watch for that cluster, xds client returns that it is a Leaf Node (EDS or +// LogicalDNS), not a tree, so expectation that update is written to buffer +// which will be read by CDS LB. +func (s) TestSuccessCaseLeafNode(t *testing.T) { + tests := []struct { + name string + clusterName string + clusterUpdate xdsclient.ClusterUpdate + lbPolicy *xdsclient.ClusterLBPolicyRingHash + }{ + { + name: "test-update-root-cluster-EDS-success", + clusterName: edsService, + clusterUpdate: xdsclient.ClusterUpdate{ + ClusterType: xdsclient.ClusterTypeEDS, + ClusterName: edsService, + }, + }, + { + name: "test-update-root-cluster-EDS-with-ring-hash", + clusterName: logicalDNSService, + clusterUpdate: xdsclient.ClusterUpdate{ + ClusterType: xdsclient.ClusterTypeLogicalDNS, + ClusterName: logicalDNSService, + LBPolicy: &xdsclient.ClusterLBPolicyRingHash{MinimumRingSize: 10, MaximumRingSize: 100}, + }, + lbPolicy: &xdsclient.ClusterLBPolicyRingHash{MinimumRingSize: 10, MaximumRingSize: 100}, + }, + { + name: "test-update-root-cluster-Logical-DNS-success", + clusterName: logicalDNSService, + clusterUpdate: xdsclient.ClusterUpdate{ + ClusterType: xdsclient.ClusterTypeLogicalDNS, + 
ClusterName: logicalDNSService, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ch, fakeClient := setupTests(t) + // When you first update the root cluster, it should hit the code + // path which will start a cluster node for that root. Updating the + // root cluster logically represents a ping from a ClientConn. + ch.updateRootCluster(test.clusterName) + // Starting a cluster node involves communicating with the + // xdsClient, telling it to watch a cluster. + ctx, ctxCancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer ctxCancel() + gotCluster, err := fakeClient.WaitForWatchCluster(ctx) + if err != nil { + t.Fatalf("xdsClient.WatchCDS failed with error: %v", err) + } + if gotCluster != test.clusterName { + t.Fatalf("xdsClient.WatchCDS called for cluster: %v, want: %v", gotCluster, test.clusterName) + } + // Invoke callback with xds client with a certain clusterUpdate. Due + // to this cluster update filling out the whole cluster tree, as the + // cluster is of a root type (EDS or Logical DNS) and not an + // aggregate cluster, this should trigger the ClusterHandler to + // write to the update buffer to update the CDS policy. + fakeClient.InvokeWatchClusterCallback(test.clusterUpdate, nil) + select { + case chu := <-ch.updateChannel: + if diff := cmp.Diff(chu.updates, []xdsclient.ClusterUpdate{test.clusterUpdate}); diff != "" { + t.Fatalf("got unexpected cluster update, diff (-got, +want): %v", diff) + } + if diff := cmp.Diff(chu.lbPolicy, test.lbPolicy); diff != "" { + t.Fatalf("got unexpected lb policy in cluster update, diff (-got, +want): %v", diff) + } + case <-ctx.Done(): + t.Fatal("Timed out waiting for update from update channel.") + } + // Close the clusterHandler. This is meant to be called when the CDS + // Balancer is closed, and the call should cancel the watch for this + // cluster. 
+ ch.close() + clusterNameDeleted, err := fakeClient.WaitForCancelClusterWatch(ctx) + if err != nil { + t.Fatalf("xdsClient.CancelCDS failed with error: %v", err) + } + if clusterNameDeleted != test.clusterName { + t.Fatalf("xdsClient.CancelCDS called for cluster %v, want: %v", clusterNameDeleted, logicalDNSService) + } + }) + } +} + +// The cluster handler receives a cluster name, handler starts a watch for that +// cluster, xds client returns that it is a Leaf Node (EDS or LogicalDNS), not a +// tree, so expectation that first update is written to buffer which will be +// read by CDS LB. Then, send a new cluster update that is different, with the +// expectation that it is also written to the update buffer to send back to CDS. +func (s) TestSuccessCaseLeafNodeThenNewUpdate(t *testing.T) { + tests := []struct { + name string + clusterName string + clusterUpdate xdsclient.ClusterUpdate + newClusterUpdate xdsclient.ClusterUpdate + }{ + {name: "test-update-root-cluster-then-new-update-EDS-success", + clusterName: edsService, + clusterUpdate: xdsclient.ClusterUpdate{ + ClusterType: xdsclient.ClusterTypeEDS, + ClusterName: edsService, + }, + newClusterUpdate: xdsclient.ClusterUpdate{ + ClusterType: xdsclient.ClusterTypeEDS, + ClusterName: edsService2, + }, + }, + { + name: "test-update-root-cluster-then-new-update-Logical-DNS-success", + clusterName: logicalDNSService, + clusterUpdate: xdsclient.ClusterUpdate{ + ClusterType: xdsclient.ClusterTypeLogicalDNS, + ClusterName: logicalDNSService, + }, + newClusterUpdate: xdsclient.ClusterUpdate{ + ClusterType: xdsclient.ClusterTypeLogicalDNS, + ClusterName: logicalDNSService2, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ch, fakeClient := setupTests(t) + ch.updateRootCluster(test.clusterName) + ctx, ctxCancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer ctxCancel() + _, err := fakeClient.WaitForWatchCluster(ctx) + if err != nil { + 
t.Fatalf("xdsClient.WatchCDS failed with error: %v", err) + } + fakeClient.InvokeWatchClusterCallback(test.clusterUpdate, nil) + select { + case <-ch.updateChannel: + case <-ctx.Done(): + t.Fatal("Timed out waiting for update from updateChannel.") + } + + // Check that sending the same cluster update also induces an update + // to be written to update buffer. + fakeClient.InvokeWatchClusterCallback(test.clusterUpdate, nil) + shouldNotHappenCtx, shouldNotHappenCtxCancel := context.WithTimeout(context.Background(), defaultTestShortTimeout) + defer shouldNotHappenCtxCancel() + select { + case <-ch.updateChannel: + case <-shouldNotHappenCtx.Done(): + t.Fatal("Timed out waiting for update from updateChannel.") + } + + // Above represents same thing as the simple + // TestSuccessCaseLeafNode, extra behavior + validation (clusterNode + // which is a leaf receives a changed clusterUpdate, which should + // ping clusterHandler, which should then write to the update + // buffer). + fakeClient.InvokeWatchClusterCallback(test.newClusterUpdate, nil) + select { + case chu := <-ch.updateChannel: + if diff := cmp.Diff(chu.updates, []xdsclient.ClusterUpdate{test.newClusterUpdate}); diff != "" { + t.Fatalf("got unexpected cluster update, diff (-got, +want): %v", diff) + } + case <-ctx.Done(): + t.Fatal("Timed out waiting for update from updateChannel.") + } + }) + } +} + +// TestUpdateRootClusterAggregateSuccess tests the case where an aggregate +// cluster is a root pointing to two child clusters one of type EDS and the +// other of type LogicalDNS. This test will then send cluster updates for both +// the children, and at the end there should be a successful clusterUpdate +// written to the update buffer to send back to CDS. 
+func (s) TestUpdateRootClusterAggregateSuccess(t *testing.T) { + ch, fakeClient := setupTests(t) + ch.updateRootCluster(aggregateClusterService) + + ctx, ctxCancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer ctxCancel() + gotCluster, err := fakeClient.WaitForWatchCluster(ctx) + if err != nil { + t.Fatalf("xdsClient.WatchCDS failed with error: %v", err) + } + if gotCluster != aggregateClusterService { + t.Fatalf("xdsClient.WatchCDS called for cluster: %v, want: %v", gotCluster, aggregateClusterService) + } + + // The xdsClient telling the clusterNode that the cluster type is an + // aggregate cluster which will cause a lot of downstream behavior. For a + // cluster type that isn't an aggregate, the behavior is simple. The + // clusterNode will simply get a successful update, which will then ping the + // clusterHandler which will successfully build an update to send to the CDS + // policy. In the aggregate cluster case, the handleResp callback must also + // start watches for the aggregate cluster's children. The ping to the + // clusterHandler at the end of handleResp should be a no-op, as neither the + // EDS or LogicalDNS child clusters have received an update yet. + fakeClient.InvokeWatchClusterCallback(xdsclient.ClusterUpdate{ + ClusterType: xdsclient.ClusterTypeAggregate, + ClusterName: aggregateClusterService, + PrioritizedClusterNames: []string{edsService, logicalDNSService}, + }, nil) + + // xds client should be called to start a watch for one of the child + // clusters of the aggregate. The order of the children in the update + // written to the buffer to send to CDS matters, however there is no + // guarantee on the order it will start the watches of the children. 
+ gotCluster, err = fakeClient.WaitForWatchCluster(ctx) + if err != nil { + t.Fatalf("xdsClient.WatchCDS failed with error: %v", err) + } + if gotCluster != edsService { + if gotCluster != logicalDNSService { + t.Fatalf("xdsClient.WatchCDS called for cluster: %v, want: %v", gotCluster, edsService) + } + } + + // xds client should then be called to start a watch for the second child + // cluster. + gotCluster, err = fakeClient.WaitForWatchCluster(ctx) + if err != nil { + t.Fatalf("xdsClient.WatchCDS failed with error: %v", err) + } + if gotCluster != edsService { + if gotCluster != logicalDNSService { + t.Fatalf("xdsClient.WatchCDS called for cluster: %v, want: %v", gotCluster, logicalDNSService) + } + } + + // The handleResp() call on the root aggregate cluster should not ping the + // cluster handler to try and construct an update, as the handleResp() + // callback knows that when a child is created, it cannot possibly build a + // successful update yet. Thus, there should be nothing in the update + // channel. + + shouldNotHappenCtx, shouldNotHappenCtxCancel := context.WithTimeout(context.Background(), defaultTestShortTimeout) + defer shouldNotHappenCtxCancel() + + select { + case <-ch.updateChannel: + t.Fatal("Cluster Handler wrote an update to updateChannel when it shouldn't have, as each node in the full cluster tree has not yet received an update") + case <-shouldNotHappenCtx.Done(): + } + + // Send callback for the EDS child cluster. + fakeClient.InvokeWatchClusterCallback(xdsclient.ClusterUpdate{ + ClusterType: xdsclient.ClusterTypeEDS, + ClusterName: edsService, + }, nil) + + // EDS child cluster will ping the Cluster Handler, to try an update, which + // still won't successfully build as the LogicalDNS child of the root + // aggregate cluster has not yet received and handled an update. 
+ select { + case <-ch.updateChannel: + t.Fatal("Cluster Handler wrote an update to updateChannel when it shouldn't have, as each node in the full cluster tree has not yet received an update") + case <-shouldNotHappenCtx.Done(): + } + + // Invoke callback for Logical DNS child cluster. + + fakeClient.InvokeWatchClusterCallback(xdsclient.ClusterUpdate{ + ClusterType: xdsclient.ClusterTypeLogicalDNS, + ClusterName: logicalDNSService, + }, nil) + + // Will Ping Cluster Handler, which will finally successfully build an + // update as all nodes in the tree of clusters have received an update. + // Since this cluster is an aggregate cluster comprised of two children, the + // returned update should be length 2, as the xds cluster resolver LB policy + // only cares about the full list of LogicalDNS and EDS clusters + // representing the base nodes of the tree of clusters. This list should be + // ordered as per the cluster update. + select { + case chu := <-ch.updateChannel: + if diff := cmp.Diff(chu.updates, []xdsclient.ClusterUpdate{{ + ClusterType: xdsclient.ClusterTypeEDS, + ClusterName: edsService, + }, { + ClusterType: xdsclient.ClusterTypeLogicalDNS, + ClusterName: logicalDNSService, + }}); diff != "" { + t.Fatalf("got unexpected cluster update, diff (-got, +want): %v", diff) + } + case <-ctx.Done(): + t.Fatal("Timed out waiting for the cluster update to be written to the update buffer.") + } +} + +// TestUpdateRootClusterAggregateThenChangeChild tests the scenario where you +// have an aggregate cluster with an EDS child and a LogicalDNS child, then you +// change one of the children and send an update for the changed child. This +// should write a new update to the update buffer to send back to CDS. +func (s) TestUpdateRootClusterAggregateThenChangeChild(t *testing.T) { + // This initial code is the same as the test for the aggregate success case, + // except without validations. This will get this test to the point where it + // can change one of the children. 
+ ch, fakeClient := setupTests(t) + ch.updateRootCluster(aggregateClusterService) + + ctx, ctxCancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer ctxCancel() + _, err := fakeClient.WaitForWatchCluster(ctx) + if err != nil { + t.Fatalf("xdsClient.WatchCDS failed with error: %v", err) + } + + fakeClient.InvokeWatchClusterCallback(xdsclient.ClusterUpdate{ + ClusterType: xdsclient.ClusterTypeAggregate, + ClusterName: aggregateClusterService, + PrioritizedClusterNames: []string{edsService, logicalDNSService}, + }, nil) + fakeClient.WaitForWatchCluster(ctx) + fakeClient.WaitForWatchCluster(ctx) + fakeClient.InvokeWatchClusterCallback(xdsclient.ClusterUpdate{ + ClusterType: xdsclient.ClusterTypeEDS, + ClusterName: edsService, + }, nil) + fakeClient.InvokeWatchClusterCallback(xdsclient.ClusterUpdate{ + ClusterType: xdsclient.ClusterTypeLogicalDNS, + ClusterName: logicalDNSService, + }, nil) + + select { + case <-ch.updateChannel: + case <-ctx.Done(): + t.Fatal("Timed out waiting for the cluster update to be written to the update buffer.") + } + + fakeClient.InvokeWatchClusterCallback(xdsclient.ClusterUpdate{ + ClusterType: xdsclient.ClusterTypeAggregate, + ClusterName: aggregateClusterService, + PrioritizedClusterNames: []string{edsService, logicalDNSService2}, + }, nil) + + // The cluster update let's the aggregate cluster know that it's children + // are now edsService and logicalDNSService2, which implies that the + // aggregateCluster lost it's old logicalDNSService child. Thus, the + // logicalDNSService child should be deleted. + clusterNameDeleted, err := fakeClient.WaitForCancelClusterWatch(ctx) + if err != nil { + t.Fatalf("xdsClient.CancelCDS failed with error: %v", err) + } + if clusterNameDeleted != logicalDNSService { + t.Fatalf("xdsClient.CancelCDS called for cluster %v, want: %v", clusterNameDeleted, logicalDNSService) + } + + // The handleResp() callback should then start a watch for + // logicalDNSService2. 
+ clusterNameCreated, err := fakeClient.WaitForWatchCluster(ctx) + if err != nil { + t.Fatalf("xdsClient.WatchCDS failed with error: %v", err) + } + if clusterNameCreated != logicalDNSService2 { + t.Fatalf("xdsClient.WatchCDS called for cluster %v, want: %v", clusterNameCreated, logicalDNSService2) + } + + // handleResp() should try and send an update here, but it will fail as + // logicalDNSService2 has not yet received an update. + shouldNotHappenCtx, shouldNotHappenCtxCancel := context.WithTimeout(context.Background(), defaultTestShortTimeout) + defer shouldNotHappenCtxCancel() + select { + case <-ch.updateChannel: + t.Fatal("Cluster Handler wrote an update to updateChannel when it shouldn't have, as each node in the full cluster tree has not yet received an update") + case <-shouldNotHappenCtx.Done(): + } + + // Invoke a callback for the new logicalDNSService2 - this will fill out the + // tree with successful updates. + fakeClient.InvokeWatchClusterCallback(xdsclient.ClusterUpdate{ + ClusterType: xdsclient.ClusterTypeLogicalDNS, + ClusterName: logicalDNSService2, + }, nil) + + // Behavior: This update make every node in the tree of cluster have + // received an update. Thus, at the end of this callback, when you ping the + // clusterHandler to try and construct an update, the update should now + // successfully be written to update buffer to send back to CDS. This new + // update should contain the new child of LogicalDNS2. 
+ + select { + case chu := <-ch.updateChannel: + if diff := cmp.Diff(chu.updates, []xdsclient.ClusterUpdate{{ + ClusterType: xdsclient.ClusterTypeEDS, + ClusterName: edsService, + }, { + ClusterType: xdsclient.ClusterTypeLogicalDNS, + ClusterName: logicalDNSService2, + }}); diff != "" { + t.Fatalf("got unexpected cluster update, diff (-got, +want): %v", diff) + } + case <-ctx.Done(): + t.Fatal("Timed out waiting for the cluster update to be written to the update buffer.") + } +} + +// TestUpdateRootClusterAggregateThenChangeRootToEDS tests the situation where +// you have a fully updated aggregate cluster (where AggregateCluster success +// test gets you) as the root cluster, then you update that root cluster to a +// cluster of type EDS. +func (s) TestUpdateRootClusterAggregateThenChangeRootToEDS(t *testing.T) { + // This initial code is the same as the test for the aggregate success case, + // except without validations. This will get this test to the point where it + // can update the root cluster to one of type EDS. 
+ ch, fakeClient := setupTests(t) + ch.updateRootCluster(aggregateClusterService) + + ctx, ctxCancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer ctxCancel() + _, err := fakeClient.WaitForWatchCluster(ctx) + if err != nil { + t.Fatalf("xdsClient.WatchCDS failed with error: %v", err) + } + + fakeClient.InvokeWatchClusterCallback(xdsclient.ClusterUpdate{ + ClusterType: xdsclient.ClusterTypeAggregate, + ClusterName: aggregateClusterService, + PrioritizedClusterNames: []string{edsService, logicalDNSService}, + }, nil) + fakeClient.WaitForWatchCluster(ctx) + fakeClient.WaitForWatchCluster(ctx) + fakeClient.InvokeWatchClusterCallback(xdsclient.ClusterUpdate{ + ClusterType: xdsclient.ClusterTypeEDS, + ClusterName: edsService, + }, nil) + fakeClient.InvokeWatchClusterCallback(xdsclient.ClusterUpdate{ + ClusterType: xdsclient.ClusterTypeLogicalDNS, + ClusterName: logicalDNSService, + }, nil) + + select { + case <-ch.updateChannel: + case <-ctx.Done(): + t.Fatal("Timed out waiting for the cluster update to be written to the update buffer.") + } + + // Changes the root aggregate cluster to a EDS cluster. This should delete + // the root aggregate cluster and all of it's children by successfully + // canceling the watches for them. + ch.updateRootCluster(edsService2) + + // Reads from the cancel channel, should first be type Aggregate, then EDS + // then Logical DNS. 
+ clusterNameDeleted, err := fakeClient.WaitForCancelClusterWatch(ctx) + if err != nil { + t.Fatalf("xdsClient.CancelCDS failed with error: %v", err) + } + if clusterNameDeleted != aggregateClusterService { + t.Fatalf("xdsClient.CancelCDS called for cluster %v, want: %v", clusterNameDeleted, logicalDNSService) + } + + clusterNameDeleted, err = fakeClient.WaitForCancelClusterWatch(ctx) + if err != nil { + t.Fatalf("xdsClient.CancelCDS failed with error: %v", err) + } + if clusterNameDeleted != edsService { + t.Fatalf("xdsClient.CancelCDS called for cluster %v, want: %v", clusterNameDeleted, logicalDNSService) + } + + clusterNameDeleted, err = fakeClient.WaitForCancelClusterWatch(ctx) + if err != nil { + t.Fatalf("xdsClient.CancelCDS failed with error: %v", err) + } + if clusterNameDeleted != logicalDNSService { + t.Fatalf("xdsClient.CancelCDS called for cluster %v, want: %v", clusterNameDeleted, logicalDNSService) + } + + // After deletion, it should start a watch for the EDS Cluster. The behavior + // for this EDS Cluster receiving an update from xds client and then + // successfully writing an update to send back to CDS is already tested in + // the updateEDS success case. + gotCluster, err := fakeClient.WaitForWatchCluster(ctx) + if err != nil { + t.Fatalf("xdsClient.WatchCDS failed with error: %v", err) + } + if gotCluster != edsService2 { + t.Fatalf("xdsClient.WatchCDS called for cluster: %v, want: %v", gotCluster, edsService2) + } +} + +// TestHandleRespInvokedWithError tests that when handleResp is invoked with an +// error, that the error is successfully written to the update buffer. 
+func (s) TestHandleRespInvokedWithError(t *testing.T) { + ch, fakeClient := setupTests(t) + ch.updateRootCluster(edsService) + ctx, ctxCancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer ctxCancel() + _, err := fakeClient.WaitForWatchCluster(ctx) + if err != nil { + t.Fatalf("xdsClient.WatchCDS failed with error: %v", err) + } + fakeClient.InvokeWatchClusterCallback(xdsclient.ClusterUpdate{}, errors.New("some error")) + select { + case chu := <-ch.updateChannel: + if chu.err.Error() != "some error" { + t.Fatalf("Did not receive the expected error, instead received: %v", chu.err.Error()) + } + case <-ctx.Done(): + t.Fatal("Timed out waiting for update from update channel.") + } +} + +// TestSwitchClusterNodeBetweenLeafAndAggregated tests having an existing +// cluster node switch between a leaf and an aggregated cluster. When the +// cluster switches from a leaf to an aggregated cluster, it should add +// children, and when it switches back to a leaf, it should delete those new +// children and also successfully write a cluster update to the update buffer. +func (s) TestSwitchClusterNodeBetweenLeafAndAggregated(t *testing.T) { + // Getting the test to the point where there's a root cluster which is a eds + // leaf. + ch, fakeClient := setupTests(t) + ch.updateRootCluster(edsService2) + ctx, ctxCancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer ctxCancel() + _, err := fakeClient.WaitForWatchCluster(ctx) + if err != nil { + t.Fatalf("xdsClient.WatchCDS failed with error: %v", err) + } + fakeClient.InvokeWatchClusterCallback(xdsclient.ClusterUpdate{ + ClusterType: xdsclient.ClusterTypeEDS, + ClusterName: edsService2, + }, nil) + select { + case <-ch.updateChannel: + case <-ctx.Done(): + t.Fatal("Timed out waiting for update from update channel.") + } + // Switch the cluster to an aggregate cluster, this should cause two new + // child watches to be created. 
+ fakeClient.InvokeWatchClusterCallback(xdsclient.ClusterUpdate{ + ClusterType: xdsclient.ClusterTypeAggregate, + ClusterName: edsService2, + PrioritizedClusterNames: []string{edsService, logicalDNSService}, + }, nil) + + // xds client should be called to start a watch for one of the child + // clusters of the aggregate. The order of the children in the update + // written to the buffer to send to CDS matters, however there is no + // guarantee on the order it will start the watches of the children. + gotCluster, err := fakeClient.WaitForWatchCluster(ctx) + if err != nil { + t.Fatalf("xdsClient.WatchCDS failed with error: %v", err) + } + if gotCluster != edsService { + if gotCluster != logicalDNSService { + t.Fatalf("xdsClient.WatchCDS called for cluster: %v, want: %v", gotCluster, edsService) + } + } + + // xds client should then be called to start a watch for the second child + // cluster. + gotCluster, err = fakeClient.WaitForWatchCluster(ctx) + if err != nil { + t.Fatalf("xdsClient.WatchCDS failed with error: %v", err) + } + if gotCluster != edsService { + if gotCluster != logicalDNSService { + t.Fatalf("xdsClient.WatchCDS called for cluster: %v, want: %v", gotCluster, logicalDNSService) + } + } + + // After starting a watch for the second child cluster, there should be no + // more watches started on the xds client. + shouldNotHappenCtx, shouldNotHappenCtxCancel := context.WithTimeout(context.Background(), defaultTestShortTimeout) + defer shouldNotHappenCtxCancel() + gotCluster, err = fakeClient.WaitForWatchCluster(shouldNotHappenCtx) + if err == nil { + t.Fatalf("xdsClient.WatchCDS called for cluster: %v, no more watches should be started.", gotCluster) + } + + // The handleResp() call on the root aggregate cluster should not ping the + // cluster handler to try and construct an update, as the handleResp() + // callback knows that when a child is created, it cannot possibly build a + // successful update yet. 
Thus, there should be nothing in the update + // channel. + + shouldNotHappenCtx, shouldNotHappenCtxCancel = context.WithTimeout(context.Background(), defaultTestShortTimeout) + defer shouldNotHappenCtxCancel() + + select { + case <-ch.updateChannel: + t.Fatal("Cluster Handler wrote an update to updateChannel when it shouldn't have, as each node in the full cluster tree has not yet received an update") + case <-shouldNotHappenCtx.Done(): + } + + // Switch the cluster back to an EDS Cluster. This should cause the two + // children to be deleted. + fakeClient.InvokeWatchClusterCallback(xdsclient.ClusterUpdate{ + ClusterType: xdsclient.ClusterTypeEDS, + ClusterName: edsService2, + }, nil) + + // Should delete the two children (no guarantee of ordering deleted, which + // is ok), then successfully write an update to the update buffer as the + // full cluster tree has received updates. + clusterNameDeleted, err := fakeClient.WaitForCancelClusterWatch(ctx) + if err != nil { + t.Fatalf("xdsClient.CancelCDS failed with error: %v", err) + } + // No guarantee of ordering, so one of the children should be deleted first. + if clusterNameDeleted != edsService { + if clusterNameDeleted != logicalDNSService { + t.Fatalf("xdsClient.CancelCDS called for cluster %v, want either: %v or: %v", clusterNameDeleted, edsService, logicalDNSService) + } + } + // Then the other child should be deleted. + clusterNameDeleted, err = fakeClient.WaitForCancelClusterWatch(ctx) + if err != nil { + t.Fatalf("xdsClient.CancelCDS failed with error: %v", err) + } + if clusterNameDeleted != edsService { + if clusterNameDeleted != logicalDNSService { + t.Fatalf("xdsClient.CancelCDS called for cluster %v, want either: %v or: %v", clusterNameDeleted, edsService, logicalDNSService) + } + } + + // After cancelling a watch for the second child cluster, there should be no + // more watches cancelled on the xds client. 
+ shouldNotHappenCtx, shouldNotHappenCtxCancel = context.WithTimeout(context.Background(), defaultTestShortTimeout) + defer shouldNotHappenCtxCancel() + gotCluster, err = fakeClient.WaitForCancelClusterWatch(shouldNotHappenCtx) + if err == nil { + t.Fatalf("xdsClient.WatchCDS called for cluster: %v, no more watches should be cancelled.", gotCluster) + } + + // Then an update should successfully be written to the update buffer. + select { + case chu := <-ch.updateChannel: + if diff := cmp.Diff(chu.updates, []xdsclient.ClusterUpdate{{ + ClusterType: xdsclient.ClusterTypeEDS, + ClusterName: edsService2, + }}); diff != "" { + t.Fatalf("got unexpected cluster update, diff (-got, +want): %v", diff) + } + case <-ctx.Done(): + t.Fatal("Timed out waiting for update from update channel.") + } +} diff --git a/xds/internal/balancer/clusterimpl/balancer_test.go b/xds/internal/balancer/clusterimpl/balancer_test.go index 3e6ac0fd290..65ec17348f4 100644 --- a/xds/internal/balancer/clusterimpl/balancer_test.go +++ b/xds/internal/balancer/clusterimpl/balancer_test.go @@ -20,6 +20,8 @@ package clusterimpl import ( "context" + "errors" + "fmt" "strings" "testing" "time" @@ -27,20 +29,28 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "google.golang.org/grpc/balancer" + "google.golang.org/grpc/balancer/base" "google.golang.org/grpc/balancer/roundrobin" "google.golang.org/grpc/connectivity" + "google.golang.org/grpc/internal" + "google.golang.org/grpc/internal/balancer/stub" + "google.golang.org/grpc/internal/grpctest" internalserviceconfig "google.golang.org/grpc/internal/serviceconfig" "google.golang.org/grpc/resolver" - "google.golang.org/grpc/xds/internal/client/load" + xdsinternal "google.golang.org/grpc/xds/internal" "google.golang.org/grpc/xds/internal/testutils" "google.golang.org/grpc/xds/internal/testutils/fakeclient" + "google.golang.org/grpc/xds/internal/xdsclient" + "google.golang.org/grpc/xds/internal/xdsclient/load" ) const ( - 
defaultTestTimeout = 1 * time.Second - testClusterName = "test-cluster" - testServiceName = "test-eds-service" - testLRSServerName = "test-lrs-name" + defaultTestTimeout = 1 * time.Second + defaultShortTestTimeout = 100 * time.Microsecond + + testClusterName = "test-cluster" + testServiceName = "test-eds-service" + testLRSServerName = "test-lrs-name" ) var ( @@ -54,19 +64,33 @@ var ( } ) +type s struct { + grpctest.Tester +} + +func Test(t *testing.T) { + grpctest.RunSubTests(t, s{}) +} + +func subConnFromPicker(p balancer.Picker) func() balancer.SubConn { + return func() balancer.SubConn { + scst, _ := p.Pick(balancer.PickInfo{}) + return scst.SubConn + } +} + func init() { - newRandomWRR = testutils.NewTestWRR + NewRandomWRR = testutils.NewTestWRR } // TestDropByCategory verifies that the balancer correctly drops the picks, and // that the drops are reported. -func TestDropByCategory(t *testing.T) { +func (s) TestDropByCategory(t *testing.T) { + defer xdsclient.ClearCounterForTesting(testClusterName, testServiceName) xdsC := fakeclient.NewClient() - oldNewXDSClient := newXDSClient - newXDSClient = func() (xdsClientInterface, error) { return xdsC, nil } - defer func() { newXDSClient = oldNewXDSClient }() + defer xdsC.Close() - builder := balancer.Get(clusterImplName) + builder := balancer.Get(Name) cc := testutils.NewTestClientConn(t) b := builder.Build(cc, balancer.BuildOptions{}) defer b.Close() @@ -77,14 +101,12 @@ func TestDropByCategory(t *testing.T) { dropDenominator = 2 ) if err := b.UpdateClientConnState(balancer.ClientConnState{ - ResolverState: resolver.State{ - Addresses: testBackendAddrs, - }, - BalancerConfig: &lbConfig{ - Cluster: testClusterName, - EDSServiceName: testServiceName, - LRSLoadReportingServerName: newString(testLRSServerName), - DropCategories: []dropCategory{{ + ResolverState: xdsclient.SetClient(resolver.State{Addresses: testBackendAddrs}, xdsC), + BalancerConfig: &LBConfig{ + Cluster: testClusterName, + EDSServiceName: 
testServiceName, + LoadReportingServerName: newString(testLRSServerName), + DropCategories: []DropConfig{{ Category: dropReason, RequestsPerMillion: million * dropNumerator / dropDenominator, }}, @@ -150,6 +172,9 @@ func TestDropByCategory(t *testing.T) { Service: testServiceName, TotalDrops: dropCount, Drops: map[string]uint64{dropReason: dropCount}, + LocalityStats: map[string]load.LocalityData{ + assertString(xdsinternal.LocalityID{}.ToString): {RequestStats: load.RequestData{Succeeded: rpcCount - dropCount}}, + }, }} gotStatsData0 := loadStore.Stats([]string{testClusterName}) @@ -164,14 +189,12 @@ func TestDropByCategory(t *testing.T) { dropDenominator2 = 4 ) if err := b.UpdateClientConnState(balancer.ClientConnState{ - ResolverState: resolver.State{ - Addresses: testBackendAddrs, - }, - BalancerConfig: &lbConfig{ - Cluster: testClusterName, - EDSServiceName: testServiceName, - LRSLoadReportingServerName: newString(testLRSServerName), - DropCategories: []dropCategory{{ + ResolverState: xdsclient.SetClient(resolver.State{Addresses: testBackendAddrs}, xdsC), + BalancerConfig: &LBConfig{ + Cluster: testClusterName, + EDSServiceName: testServiceName, + LoadReportingServerName: newString(testLRSServerName), + DropCategories: []DropConfig{{ Category: dropReason2, RequestsPerMillion: million * dropNumerator2 / dropDenominator2, }}, @@ -207,6 +230,9 @@ func TestDropByCategory(t *testing.T) { Service: testServiceName, TotalDrops: dropCount2, Drops: map[string]uint64{dropReason2: dropCount2}, + LocalityStats: map[string]load.LocalityData{ + assertString(xdsinternal.LocalityID{}.ToString): {RequestStats: load.RequestData{Succeeded: rpcCount - dropCount2}}, + }, }} gotStatsData1 := loadStore.Stats([]string{testClusterName}) @@ -217,27 +243,24 @@ func TestDropByCategory(t *testing.T) { // TestDropCircuitBreaking verifies that the balancer correctly drops the picks // due to circuit breaking, and that the drops are reported. 
-func TestDropCircuitBreaking(t *testing.T) { +func (s) TestDropCircuitBreaking(t *testing.T) { + defer xdsclient.ClearCounterForTesting(testClusterName, testServiceName) xdsC := fakeclient.NewClient() - oldNewXDSClient := newXDSClient - newXDSClient = func() (xdsClientInterface, error) { return xdsC, nil } - defer func() { newXDSClient = oldNewXDSClient }() + defer xdsC.Close() - builder := balancer.Get(clusterImplName) + builder := balancer.Get(Name) cc := testutils.NewTestClientConn(t) b := builder.Build(cc, balancer.BuildOptions{}) defer b.Close() var maxRequest uint32 = 50 if err := b.UpdateClientConnState(balancer.ClientConnState{ - ResolverState: resolver.State{ - Addresses: testBackendAddrs, - }, - BalancerConfig: &lbConfig{ - Cluster: testClusterName, - EDSServiceName: testServiceName, - LRSLoadReportingServerName: newString(testLRSServerName), - MaxConcurrentRequests: &maxRequest, + ResolverState: xdsclient.SetClient(resolver.State{Addresses: testBackendAddrs}, xdsC), + BalancerConfig: &LBConfig{ + Cluster: testClusterName, + EDSServiceName: testServiceName, + LoadReportingServerName: newString(testLRSServerName), + MaxConcurrentRequests: &maxRequest, ChildPolicy: &internalserviceconfig.BalancerConfig{ Name: roundrobin.Name, }, @@ -317,6 +340,9 @@ func TestDropCircuitBreaking(t *testing.T) { Cluster: testClusterName, Service: testServiceName, TotalDrops: uint64(maxRequest), + LocalityStats: map[string]load.LocalityData{ + assertString(xdsinternal.LocalityID{}.ToString): {RequestStats: load.RequestData{Succeeded: uint64(rpcCount - maxRequest + 50)}}, + }, }} gotStatsData0 := loadStore.Stats([]string{testClusterName}) @@ -324,3 +350,452 @@ func TestDropCircuitBreaking(t *testing.T) { t.Fatalf("got unexpected drop reports, diff (-got, +want): %v", diff) } } + +// TestPickerUpdateAfterClose covers the case where a child policy sends a +// picker update after the cluster_impl policy is closed. 
Because picker updates +// are handled in the run() goroutine, which exits before Close() returns, we +// expect the above picker update to be dropped. +func (s) TestPickerUpdateAfterClose(t *testing.T) { + defer xdsclient.ClearCounterForTesting(testClusterName, testServiceName) + xdsC := fakeclient.NewClient() + defer xdsC.Close() + + builder := balancer.Get(Name) + cc := testutils.NewTestClientConn(t) + b := builder.Build(cc, balancer.BuildOptions{}) + + // Create a stub balancer which waits for the cluster_impl policy to be + // closed before sending a picker update (upon receipt of a subConn state + // change). + closeCh := make(chan struct{}) + const childPolicyName = "stubBalancer-TestPickerUpdateAfterClose" + stub.Register(childPolicyName, stub.BalancerFuncs{ + UpdateClientConnState: func(bd *stub.BalancerData, ccs balancer.ClientConnState) error { + // Create a subConn which will be used later on to test the race + // between UpdateSubConnState() and Close(). + bd.ClientConn.NewSubConn(ccs.ResolverState.Addresses, balancer.NewSubConnOptions{}) + return nil + }, + UpdateSubConnState: func(bd *stub.BalancerData, _ balancer.SubConn, _ balancer.SubConnState) { + go func() { + // Wait for Close() to be called on the parent policy before + // sending the picker update. + <-closeCh + bd.ClientConn.UpdateState(balancer.State{ + Picker: base.NewErrPicker(errors.New("dummy error picker")), + }) + }() + }, + }) + + var maxRequest uint32 = 50 + if err := b.UpdateClientConnState(balancer.ClientConnState{ + ResolverState: xdsclient.SetClient(resolver.State{Addresses: testBackendAddrs}, xdsC), + BalancerConfig: &LBConfig{ + Cluster: testClusterName, + EDSServiceName: testServiceName, + MaxConcurrentRequests: &maxRequest, + ChildPolicy: &internalserviceconfig.BalancerConfig{ + Name: childPolicyName, + }, + }, + }); err != nil { + b.Close() + t.Fatalf("unexpected error from UpdateClientConnState: %v", err) + } + + // Send a subConn state change to trigger a picker update. 
The stub balancer + // that we use as the child policy will not send a picker update until the + // parent policy is closed. + sc1 := <-cc.NewSubConnCh + b.UpdateSubConnState(sc1, balancer.SubConnState{ConnectivityState: connectivity.Connecting}) + b.Close() + close(closeCh) + + select { + case <-cc.NewPickerCh: + t.Fatalf("unexpected picker update after balancer is closed") + case <-time.After(defaultShortTestTimeout): + } +} + +// TestClusterNameInAddressAttributes covers the case that cluster name is +// attached to the subconn address attributes. +func (s) TestClusterNameInAddressAttributes(t *testing.T) { + defer xdsclient.ClearCounterForTesting(testClusterName, testServiceName) + xdsC := fakeclient.NewClient() + defer xdsC.Close() + + builder := balancer.Get(Name) + cc := testutils.NewTestClientConn(t) + b := builder.Build(cc, balancer.BuildOptions{}) + defer b.Close() + + if err := b.UpdateClientConnState(balancer.ClientConnState{ + ResolverState: xdsclient.SetClient(resolver.State{Addresses: testBackendAddrs}, xdsC), + BalancerConfig: &LBConfig{ + Cluster: testClusterName, + EDSServiceName: testServiceName, + ChildPolicy: &internalserviceconfig.BalancerConfig{ + Name: roundrobin.Name, + }, + }, + }); err != nil { + t.Fatalf("unexpected error from UpdateClientConnState: %v", err) + } + + sc1 := <-cc.NewSubConnCh + b.UpdateSubConnState(sc1, balancer.SubConnState{ConnectivityState: connectivity.Connecting}) + // This should get the connecting picker. 
+ p0 := <-cc.NewPickerCh + for i := 0; i < 10; i++ { + _, err := p0.Pick(balancer.PickInfo{}) + if err != balancer.ErrNoSubConnAvailable { + t.Fatalf("picker.Pick, got _,%v, want Err=%v", err, balancer.ErrNoSubConnAvailable) + } + } + + addrs1 := <-cc.NewSubConnAddrsCh + if got, want := addrs1[0].Addr, testBackendAddrs[0].Addr; got != want { + t.Fatalf("sc is created with addr %v, want %v", got, want) + } + cn, ok := internal.GetXDSHandshakeClusterName(addrs1[0].Attributes) + if !ok || cn != testClusterName { + t.Fatalf("sc is created with addr with cluster name %v, %v, want cluster name %v", cn, ok, testClusterName) + } + + b.UpdateSubConnState(sc1, balancer.SubConnState{ConnectivityState: connectivity.Ready}) + // Test pick with one backend. + p1 := <-cc.NewPickerCh + const rpcCount = 20 + for i := 0; i < rpcCount; i++ { + gotSCSt, err := p1.Pick(balancer.PickInfo{}) + if err != nil || !cmp.Equal(gotSCSt.SubConn, sc1, cmp.AllowUnexported(testutils.TestSubConn{})) { + t.Fatalf("picker.Pick, got %v, %v, want SubConn=%v", gotSCSt, err, sc1) + } + if gotSCSt.Done != nil { + gotSCSt.Done(balancer.DoneInfo{}) + } + } + + const testClusterName2 = "test-cluster-2" + var addr2 = resolver.Address{Addr: "2.2.2.2"} + if err := b.UpdateClientConnState(balancer.ClientConnState{ + ResolverState: xdsclient.SetClient(resolver.State{Addresses: []resolver.Address{addr2}}, xdsC), + BalancerConfig: &LBConfig{ + Cluster: testClusterName2, + EDSServiceName: testServiceName, + ChildPolicy: &internalserviceconfig.BalancerConfig{ + Name: roundrobin.Name, + }, + }, + }); err != nil { + t.Fatalf("unexpected error from UpdateClientConnState: %v", err) + } + + addrs2 := <-cc.NewSubConnAddrsCh + if got, want := addrs2[0].Addr, addr2.Addr; got != want { + t.Fatalf("sc is created with addr %v, want %v", got, want) + } + // New addresses should have the new cluster name. 
+ cn2, ok := internal.GetXDSHandshakeClusterName(addrs2[0].Attributes) + if !ok || cn2 != testClusterName2 { + t.Fatalf("sc is created with addr with cluster name %v, %v, want cluster name %v", cn2, ok, testClusterName2) + } +} + +// TestReResolution verifies that when a SubConn turns transient failure, +// re-resolution is triggered. +func (s) TestReResolution(t *testing.T) { + defer xdsclient.ClearCounterForTesting(testClusterName, testServiceName) + xdsC := fakeclient.NewClient() + defer xdsC.Close() + + builder := balancer.Get(Name) + cc := testutils.NewTestClientConn(t) + b := builder.Build(cc, balancer.BuildOptions{}) + defer b.Close() + + if err := b.UpdateClientConnState(balancer.ClientConnState{ + ResolverState: xdsclient.SetClient(resolver.State{Addresses: testBackendAddrs}, xdsC), + BalancerConfig: &LBConfig{ + Cluster: testClusterName, + EDSServiceName: testServiceName, + ChildPolicy: &internalserviceconfig.BalancerConfig{ + Name: roundrobin.Name, + }, + }, + }); err != nil { + t.Fatalf("unexpected error from UpdateClientConnState: %v", err) + } + + sc1 := <-cc.NewSubConnCh + b.UpdateSubConnState(sc1, balancer.SubConnState{ConnectivityState: connectivity.Connecting}) + // This should get the connecting picker. + p0 := <-cc.NewPickerCh + for i := 0; i < 10; i++ { + _, err := p0.Pick(balancer.PickInfo{}) + if err != balancer.ErrNoSubConnAvailable { + t.Fatalf("picker.Pick, got _,%v, want Err=%v", err, balancer.ErrNoSubConnAvailable) + } + } + + b.UpdateSubConnState(sc1, balancer.SubConnState{ConnectivityState: connectivity.TransientFailure}) + // This should get the transient failure picker. + p1 := <-cc.NewPickerCh + for i := 0; i < 10; i++ { + _, err := p1.Pick(balancer.PickInfo{}) + if err == nil { + t.Fatalf("picker.Pick, got _,%v, want not nil", err) + } + } + + // The transient failure should trigger a re-resolution. 
+ select { + case <-cc.ResolveNowCh: + case <-time.After(defaultTestTimeout): + t.Fatalf("timeout waiting for ResolveNow()") + } + + b.UpdateSubConnState(sc1, balancer.SubConnState{ConnectivityState: connectivity.Ready}) + // Test pick with one backend. + p2 := <-cc.NewPickerCh + want := []balancer.SubConn{sc1} + if err := testutils.IsRoundRobin(want, subConnFromPicker(p2)); err != nil { + t.Fatalf("want %v, got %v", want, err) + } + + b.UpdateSubConnState(sc1, balancer.SubConnState{ConnectivityState: connectivity.TransientFailure}) + // This should get the transient failure picker. + p3 := <-cc.NewPickerCh + for i := 0; i < 10; i++ { + _, err := p3.Pick(balancer.PickInfo{}) + if err == nil { + t.Fatalf("picker.Pick, got _,%v, want not nil", err) + } + } + + // The transient failure should trigger a re-resolution. + select { + case <-cc.ResolveNowCh: + case <-time.After(defaultTestTimeout): + t.Fatalf("timeout waiting for ResolveNow()") + } +} + +func (s) TestLoadReporting(t *testing.T) { + var testLocality = xdsinternal.LocalityID{ + Region: "test-region", + Zone: "test-zone", + SubZone: "test-sub-zone", + } + + xdsC := fakeclient.NewClient() + defer xdsC.Close() + + builder := balancer.Get(Name) + cc := testutils.NewTestClientConn(t) + b := builder.Build(cc, balancer.BuildOptions{}) + defer b.Close() + + addrs := make([]resolver.Address, len(testBackendAddrs)) + for i, a := range testBackendAddrs { + addrs[i] = xdsinternal.SetLocalityID(a, testLocality) + } + if err := b.UpdateClientConnState(balancer.ClientConnState{ + ResolverState: xdsclient.SetClient(resolver.State{Addresses: addrs}, xdsC), + BalancerConfig: &LBConfig{ + Cluster: testClusterName, + EDSServiceName: testServiceName, + LoadReportingServerName: newString(testLRSServerName), + // Locality: testLocality, + ChildPolicy: &internalserviceconfig.BalancerConfig{ + Name: roundrobin.Name, + }, + }, + }); err != nil { + t.Fatalf("unexpected error from UpdateClientConnState: %v", err) + } + + ctx, cancel := 
context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + got, err := xdsC.WaitForReportLoad(ctx) + if err != nil { + t.Fatalf("xdsClient.ReportLoad failed with error: %v", err) + } + if got.Server != testLRSServerName { + t.Fatalf("xdsClient.ReportLoad called with {%q}: want {%q}", got.Server, testLRSServerName) + } + + sc1 := <-cc.NewSubConnCh + b.UpdateSubConnState(sc1, balancer.SubConnState{ConnectivityState: connectivity.Connecting}) + // This should get the connecting picker. + p0 := <-cc.NewPickerCh + for i := 0; i < 10; i++ { + _, err := p0.Pick(balancer.PickInfo{}) + if err != balancer.ErrNoSubConnAvailable { + t.Fatalf("picker.Pick, got _,%v, want Err=%v", err, balancer.ErrNoSubConnAvailable) + } + } + + b.UpdateSubConnState(sc1, balancer.SubConnState{ConnectivityState: connectivity.Ready}) + // Test pick with one backend. + p1 := <-cc.NewPickerCh + const successCount = 5 + for i := 0; i < successCount; i++ { + gotSCSt, err := p1.Pick(balancer.PickInfo{}) + if !cmp.Equal(gotSCSt.SubConn, sc1, cmp.AllowUnexported(testutils.TestSubConn{})) { + t.Fatalf("picker.Pick, got %v, %v, want SubConn=%v", gotSCSt, err, sc1) + } + gotSCSt.Done(balancer.DoneInfo{}) + } + const errorCount = 5 + for i := 0; i < errorCount; i++ { + gotSCSt, err := p1.Pick(balancer.PickInfo{}) + if !cmp.Equal(gotSCSt.SubConn, sc1, cmp.AllowUnexported(testutils.TestSubConn{})) { + t.Fatalf("picker.Pick, got %v, %v, want SubConn=%v", gotSCSt, err, sc1) + } + gotSCSt.Done(balancer.DoneInfo{Err: fmt.Errorf("error")}) + } + + // Dump load data from the store and compare with expected counts. 
+ loadStore := xdsC.LoadStore() + if loadStore == nil { + t.Fatal("loadStore is nil in xdsClient") + } + sds := loadStore.Stats([]string{testClusterName}) + if len(sds) == 0 { + t.Fatalf("loads for cluster %v not found in store", testClusterName) + } + sd := sds[0] + if sd.Cluster != testClusterName || sd.Service != testServiceName { + t.Fatalf("got unexpected load for %q, %q, want %q, %q", sd.Cluster, sd.Service, testClusterName, testServiceName) + } + testLocalityJSON, _ := testLocality.ToString() + localityData, ok := sd.LocalityStats[testLocalityJSON] + if !ok { + t.Fatalf("loads for %v not found in store", testLocality) + } + reqStats := localityData.RequestStats + if reqStats.Succeeded != successCount { + t.Errorf("got succeeded %v, want %v", reqStats.Succeeded, successCount) + } + if reqStats.Errored != errorCount { + t.Errorf("got errord %v, want %v", reqStats.Errored, errorCount) + } + if reqStats.InProgress != 0 { + t.Errorf("got inProgress %v, want %v", reqStats.InProgress, 0) + } + + b.Close() + if err := xdsC.WaitForCancelReportLoad(ctx); err != nil { + t.Fatalf("unexpected error waiting form load report to be canceled: %v", err) + } +} + +// TestUpdateLRSServer covers the cases +// - the init config specifies "" as the LRS server +// - config modifies LRS server to a different string +// - config sets LRS server to nil to stop load reporting +func (s) TestUpdateLRSServer(t *testing.T) { + var testLocality = xdsinternal.LocalityID{ + Region: "test-region", + Zone: "test-zone", + SubZone: "test-sub-zone", + } + + xdsC := fakeclient.NewClient() + defer xdsC.Close() + + builder := balancer.Get(Name) + cc := testutils.NewTestClientConn(t) + b := builder.Build(cc, balancer.BuildOptions{}) + defer b.Close() + + addrs := make([]resolver.Address, len(testBackendAddrs)) + for i, a := range testBackendAddrs { + addrs[i] = xdsinternal.SetLocalityID(a, testLocality) + } + if err := b.UpdateClientConnState(balancer.ClientConnState{ + ResolverState: 
xdsclient.SetClient(resolver.State{Addresses: addrs}, xdsC), + BalancerConfig: &LBConfig{ + Cluster: testClusterName, + EDSServiceName: testServiceName, + LoadReportingServerName: newString(""), + ChildPolicy: &internalserviceconfig.BalancerConfig{ + Name: roundrobin.Name, + }, + }, + }); err != nil { + t.Fatalf("unexpected error from UpdateClientConnState: %v", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + got, err := xdsC.WaitForReportLoad(ctx) + if err != nil { + t.Fatalf("xdsClient.ReportLoad failed with error: %v", err) + } + if got.Server != "" { + t.Fatalf("xdsClient.ReportLoad called with {%q}: want {%q}", got.Server, "") + } + + // Update LRS server to a different name. + if err := b.UpdateClientConnState(balancer.ClientConnState{ + ResolverState: xdsclient.SetClient(resolver.State{Addresses: addrs}, xdsC), + BalancerConfig: &LBConfig{ + Cluster: testClusterName, + EDSServiceName: testServiceName, + LoadReportingServerName: newString(testLRSServerName), + ChildPolicy: &internalserviceconfig.BalancerConfig{ + Name: roundrobin.Name, + }, + }, + }); err != nil { + t.Fatalf("unexpected error from UpdateClientConnState: %v", err) + } + if err := xdsC.WaitForCancelReportLoad(ctx); err != nil { + t.Fatalf("unexpected error waiting form load report to be canceled: %v", err) + } + got2, err2 := xdsC.WaitForReportLoad(ctx) + if err2 != nil { + t.Fatalf("xdsClient.ReportLoad failed with error: %v", err2) + } + if got2.Server != testLRSServerName { + t.Fatalf("xdsClient.ReportLoad called with {%q}: want {%q}", got2.Server, testLRSServerName) + } + + // Update LRS server to nil, to disable LRS. 
+ if err := b.UpdateClientConnState(balancer.ClientConnState{ + ResolverState: xdsclient.SetClient(resolver.State{Addresses: addrs}, xdsC), + BalancerConfig: &LBConfig{ + Cluster: testClusterName, + EDSServiceName: testServiceName, + LoadReportingServerName: nil, + ChildPolicy: &internalserviceconfig.BalancerConfig{ + Name: roundrobin.Name, + }, + }, + }); err != nil { + t.Fatalf("unexpected error from UpdateClientConnState: %v", err) + } + if err := xdsC.WaitForCancelReportLoad(ctx); err != nil { + t.Fatalf("unexpected error waiting form load report to be canceled: %v", err) + } + + shortCtx, shortCancel := context.WithTimeout(context.Background(), defaultShortTestTimeout) + defer shortCancel() + if s, err := xdsC.WaitForReportLoad(shortCtx); err != context.DeadlineExceeded { + t.Fatalf("unexpected load report to server: %q", s) + } +} + +func assertString(f func() (string, error)) string { + s, err := f() + if err != nil { + panic(err.Error()) + } + return s +} diff --git a/xds/internal/balancer/clusterimpl/clusterimpl.go b/xds/internal/balancer/clusterimpl/clusterimpl.go index 4e4af5a02b4..03d357b1f4e 100644 --- a/xds/internal/balancer/clusterimpl/clusterimpl.go +++ b/xds/internal/balancer/clusterimpl/clusterimpl.go @@ -26,107 +26,132 @@ package clusterimpl import ( "encoding/json" "fmt" + "sync" + "sync/atomic" "google.golang.org/grpc/balancer" + "google.golang.org/grpc/connectivity" + "google.golang.org/grpc/internal" "google.golang.org/grpc/internal/buffer" "google.golang.org/grpc/internal/grpclog" "google.golang.org/grpc/internal/grpcsync" + "google.golang.org/grpc/internal/pretty" + "google.golang.org/grpc/resolver" "google.golang.org/grpc/serviceconfig" + xdsinternal "google.golang.org/grpc/xds/internal" "google.golang.org/grpc/xds/internal/balancer/loadstore" - xdsclient "google.golang.org/grpc/xds/internal/client" - "google.golang.org/grpc/xds/internal/client/load" + "google.golang.org/grpc/xds/internal/xdsclient" + 
"google.golang.org/grpc/xds/internal/xdsclient/load" ) const ( - clusterImplName = "xds_cluster_impl_experimental" + // Name is the name of the cluster_impl balancer. + Name = "xds_cluster_impl_experimental" defaultRequestCountMax = 1024 ) func init() { - balancer.Register(clusterImplBB{}) + balancer.Register(bb{}) } -var newXDSClient = func() (xdsClientInterface, error) { return xdsclient.New() } +type bb struct{} -type clusterImplBB struct{} - -func (clusterImplBB) Build(cc balancer.ClientConn, bOpts balancer.BuildOptions) balancer.Balancer { +func (bb) Build(cc balancer.ClientConn, bOpts balancer.BuildOptions) balancer.Balancer { b := &clusterImplBalancer{ ClientConn: cc, bOpts: bOpts, closed: grpcsync.NewEvent(), + done: grpcsync.NewEvent(), loadWrapper: loadstore.NewWrapper(), + scWrappers: make(map[balancer.SubConn]*scWrapper), pickerUpdateCh: buffer.NewUnbounded(), requestCountMax: defaultRequestCountMax, } b.logger = prefixLogger(b) - - client, err := newXDSClient() - if err != nil { - b.logger.Errorf("failed to create xds-client: %v", err) - return nil - } - b.xdsC = client go b.run() - b.logger.Infof("Created") return b } -func (clusterImplBB) Name() string { - return clusterImplName +func (bb) Name() string { + return Name } -func (clusterImplBB) ParseConfig(c json.RawMessage) (serviceconfig.LoadBalancingConfig, error) { +func (bb) ParseConfig(c json.RawMessage) (serviceconfig.LoadBalancingConfig, error) { return parseConfig(c) } -// xdsClientInterface contains only the xds_client methods needed by LRS -// balancer. It's defined so we can override xdsclient in tests. -type xdsClientInterface interface { - ReportLoad(server string) (*load.Store, func()) - Close() -} - type clusterImplBalancer struct { balancer.ClientConn - bOpts balancer.BuildOptions + + // mu guarantees mutual exclusion between Close() and handling of picker + // update to the parent ClientConn in run(). 
It's to make sure that the + // run() goroutine doesn't send picker update to parent after the balancer + // is closed. + // + // It's only used by the run() goroutine, but not the other exported + // functions. Because the exported functions are guaranteed to be + // synchronized with Close(). + mu sync.Mutex closed *grpcsync.Event - logger *grpclog.PrefixLogger - xdsC xdsClientInterface + done *grpcsync.Event - config *lbConfig + bOpts balancer.BuildOptions + logger *grpclog.PrefixLogger + xdsClient xdsclient.XDSClient + + config *LBConfig childLB balancer.Balancer cancelLoadReport func() - clusterName string edsServiceName string - lrsServerName string + lrsServerName *string loadWrapper *loadstore.Wrapper - // childState/drops/requestCounter can only be accessed in run(). And run() - // is the only goroutine that sends picker to the parent ClientConn. All + clusterNameMu sync.Mutex + clusterName string + + scWrappersMu sync.Mutex + // The SubConns passed to the child policy are wrapped in a wrapper, to keep + // locality ID. But when the parent ClientConn sends updates, it's going to + // give the original SubConn, not the wrapper. But the child policies only + // know about the wrapper, so when forwarding SubConn updates, they must be + // sent for the wrappers. + // + // This keeps a map from original SubConn to wrapper, so that when + // forwarding the SubConn state update, the child policy will get the + // wrappers. + scWrappers map[balancer.SubConn]*scWrapper + + // childState/drops/requestCounter keeps the state used by the most recently + // generated picker. All fields can only be accessed in run(). And run() is + // the only goroutine that sends picker to the parent ClientConn. All // requests to update picker need to be sent to pickerUpdateCh. 
- childState balancer.State - drops []*dropper - requestCounter *xdsclient.ServiceRequestsCounter - requestCountMax uint32 - pickerUpdateCh *buffer.Unbounded + childState balancer.State + dropCategories []DropConfig // The categories for drops. + drops []*dropper + requestCounterCluster string // The cluster name for the request counter. + requestCounterService string // The service name for the request counter. + requestCounter *xdsclient.ClusterRequestsCounter + requestCountMax uint32 + pickerUpdateCh *buffer.Unbounded } // updateLoadStore checks the config for load store, and decides whether it // needs to restart the load reporting stream. -func (cib *clusterImplBalancer) updateLoadStore(newConfig *lbConfig) error { +func (b *clusterImplBalancer) updateLoadStore(newConfig *LBConfig) error { var updateLoadClusterAndService bool // ClusterName is different, restart. ClusterName is from ClusterName and - // EdsServiceName. - if cib.clusterName != newConfig.Cluster { + // EDSServiceName. + clusterName := b.getClusterName() + if clusterName != newConfig.Cluster { updateLoadClusterAndService = true - cib.clusterName = newConfig.Cluster + b.setClusterName(newConfig.Cluster) + clusterName = newConfig.Cluster } - if cib.edsServiceName != newConfig.EDSServiceName { + if b.edsServiceName != newConfig.EDSServiceName { updateLoadClusterAndService = true - cib.edsServiceName = newConfig.EDSServiceName + b.edsServiceName = newConfig.EDSServiceName } if updateLoadClusterAndService { // This updates the clusterName and serviceName that will be reported @@ -137,39 +162,66 @@ func (cib *clusterImplBalancer) updateLoadStore(newConfig *lbConfig) error { // On the other hand, this will almost never happen. Each LRS policy // shouldn't get updated config. The parent should do a graceful switch // when the clusterName or serviceName is changed. 
- cib.loadWrapper.UpdateClusterAndService(cib.clusterName, cib.edsServiceName) + b.loadWrapper.UpdateClusterAndService(clusterName, b.edsServiceName) } + var ( + stopOldLoadReport bool + startNewLoadReport bool + ) + // Check if it's necessary to restart load report. - var newLRSServerName string - if newConfig.LRSLoadReportingServerName != nil { - newLRSServerName = *newConfig.LRSLoadReportingServerName - } - if cib.lrsServerName != newLRSServerName { - // LrsLoadReportingServerName is different, load should be report to a - // different server, restart. - cib.lrsServerName = newLRSServerName - if cib.cancelLoadReport != nil { - cib.cancelLoadReport() - cib.cancelLoadReport = nil + if b.lrsServerName == nil { + if newConfig.LoadReportingServerName != nil { + // Old is nil, new is not nil, start new LRS. + b.lrsServerName = newConfig.LoadReportingServerName + startNewLoadReport = true + } + // Old is nil, new is nil, do nothing. + } else if newConfig.LoadReportingServerName == nil { + // Old is not nil, new is nil, stop old, don't start new. + b.lrsServerName = newConfig.LoadReportingServerName + stopOldLoadReport = true + } else { + // Old is not nil, new is not nil, compare string values, if + // different, stop old and start new. + if *b.lrsServerName != *newConfig.LoadReportingServerName { + b.lrsServerName = newConfig.LoadReportingServerName + stopOldLoadReport = true + startNewLoadReport = true } + } + + if stopOldLoadReport { + if b.cancelLoadReport != nil { + b.cancelLoadReport() + b.cancelLoadReport = nil + if !startNewLoadReport { + // If a new LRS stream will be started later, no need to update + // it to nil here. 
+ b.loadWrapper.UpdateLoadStore(nil) + } + } + } + if startNewLoadReport { var loadStore *load.Store - if cib.xdsC != nil { - loadStore, cib.cancelLoadReport = cib.xdsC.ReportLoad(cib.lrsServerName) + if b.xdsClient != nil { + loadStore, b.cancelLoadReport = b.xdsClient.ReportLoad(*b.lrsServerName) } - cib.loadWrapper.UpdateLoadStore(loadStore) + b.loadWrapper.UpdateLoadStore(loadStore) } return nil } -func (cib *clusterImplBalancer) UpdateClientConnState(s balancer.ClientConnState) error { - if cib.closed.HasFired() { - cib.logger.Warningf("xds: received ClientConnState {%+v} after clusterImplBalancer was closed", s) +func (b *clusterImplBalancer) UpdateClientConnState(s balancer.ClientConnState) error { + if b.closed.HasFired() { + b.logger.Warningf("xds: received ClientConnState {%+v} after clusterImplBalancer was closed", s) return nil } - newConfig, ok := s.BalancerConfig.(*lbConfig) + b.logger.Infof("Received update from resolver, balancer config: %+v", pretty.ToJSON(s.BalancerConfig)) + newConfig, ok := s.BalancerConfig.(*LBConfig) if !ok { return fmt.Errorf("unexpected balancer config with type: %T", s.BalancerConfig) } @@ -182,59 +234,32 @@ func (cib *clusterImplBalancer) UpdateClientConnState(s balancer.ClientConnState return fmt.Errorf("balancer %q not registered", newConfig.ChildPolicy.Name) } + if b.xdsClient == nil { + c := xdsclient.FromResolverState(s.ResolverState) + if c == nil { + return balancer.ErrBadResolverState + } + b.xdsClient = c + } + // Update load reporting config. This needs to be done before updating the // child policy because we need the loadStore from the updated client to be // passed to the ccWrapper, so that the next picker from the child policy // will pick up the new loadStore. - if err := cib.updateLoadStore(newConfig); err != nil { + if err := b.updateLoadStore(newConfig); err != nil { return err } - // Compare new drop config. And update picker if it's changed. 
- var updatePicker bool - if cib.config == nil || !equalDropCategories(cib.config.DropCategories, newConfig.DropCategories) { - cib.drops = make([]*dropper, 0, len(newConfig.DropCategories)) - for _, c := range newConfig.DropCategories { - cib.drops = append(cib.drops, newDropper(c)) - } - updatePicker = true - } - - // Compare cluster name. And update picker if it's changed, because circuit - // breaking's stream counter will be different. - if cib.config == nil || cib.config.Cluster != newConfig.Cluster { - cib.requestCounter = xdsclient.GetServiceRequestsCounter(newConfig.Cluster) - updatePicker = true - } - // Compare upper bound of stream count. And update picker if it's changed. - // This is also for circuit breaking. - var newRequestCountMax uint32 = 1024 - if newConfig.MaxConcurrentRequests != nil { - newRequestCountMax = *newConfig.MaxConcurrentRequests - } - if cib.requestCountMax != newRequestCountMax { - cib.requestCountMax = newRequestCountMax - updatePicker = true - } - - if updatePicker { - cib.pickerUpdateCh.Put(&dropConfigs{ - drops: cib.drops, - requestCounter: cib.requestCounter, - requestCountMax: cib.requestCountMax, - }) - } - // If child policy is a different type, recreate the sub-balancer. - if cib.config == nil || cib.config.ChildPolicy.Name != newConfig.ChildPolicy.Name { - if cib.childLB != nil { - cib.childLB.Close() + if b.config == nil || b.config.ChildPolicy.Name != newConfig.ChildPolicy.Name { + if b.childLB != nil { + b.childLB.Close() } - cib.childLB = bb.Build(cib, cib.bOpts) + b.childLB = bb.Build(b, b.bOpts) } - cib.config = newConfig + b.config = newConfig - if cib.childLB == nil { + if b.childLB == nil { // This is not an expected situation, and should be super rare in // practice. 
// @@ -244,85 +269,274 @@ func (cib *clusterImplBalancer) UpdateClientConnState(s balancer.ClientConnState return fmt.Errorf("child policy is nil, this means balancer %q's Build() returned nil", newConfig.ChildPolicy.Name) } + // Notify run() of this new config, in case drop and request counter need + // update (which means a new picker needs to be generated). + b.pickerUpdateCh.Put(newConfig) + // Addresses and sub-balancer config are sent to sub-balancer. - return cib.childLB.UpdateClientConnState(balancer.ClientConnState{ + return b.childLB.UpdateClientConnState(balancer.ClientConnState{ ResolverState: s.ResolverState, - BalancerConfig: cib.config.ChildPolicy.Config, + BalancerConfig: b.config.ChildPolicy.Config, }) } -func (cib *clusterImplBalancer) ResolverError(err error) { - if cib.closed.HasFired() { - cib.logger.Warningf("xds: received resolver error {%+v} after clusterImplBalancer was closed", err) +func (b *clusterImplBalancer) ResolverError(err error) { + if b.closed.HasFired() { + b.logger.Warningf("xds: received resolver error {%+v} after clusterImplBalancer was closed", err) return } - if cib.childLB != nil { - cib.childLB.ResolverError(err) + if b.childLB != nil { + b.childLB.ResolverError(err) } } -func (cib *clusterImplBalancer) UpdateSubConnState(sc balancer.SubConn, s balancer.SubConnState) { - if cib.closed.HasFired() { - cib.logger.Warningf("xds: received subconn state change {%+v, %+v} after clusterImplBalancer was closed", sc, s) +func (b *clusterImplBalancer) UpdateSubConnState(sc balancer.SubConn, s balancer.SubConnState) { + if b.closed.HasFired() { + b.logger.Warningf("xds: received subconn state change {%+v, %+v} after clusterImplBalancer was closed", sc, s) return } - if cib.childLB != nil { - cib.childLB.UpdateSubConnState(sc, s) + // Trigger re-resolution when a SubConn turns transient failure. This is + // necessary for the LogicalDNS in cluster_resolver policy to re-resolve. 
+ // + // Note that this happens not only for the addresses from DNS, but also for + // EDS (cluster_impl doesn't know if it's DNS or EDS, only the parent + // knows). The parent priority policy is configured to ignore re-resolution + // signal from the EDS children. + if s.ConnectivityState == connectivity.TransientFailure { + b.ClientConn.ResolveNow(resolver.ResolveNowOptions{}) + } + + b.scWrappersMu.Lock() + if scw, ok := b.scWrappers[sc]; ok { + sc = scw + if s.ConnectivityState == connectivity.Shutdown { + // Remove this SubConn from the map on Shutdown. + delete(b.scWrappers, scw.SubConn) + } + } + b.scWrappersMu.Unlock() + if b.childLB != nil { + b.childLB.UpdateSubConnState(sc, s) } } -func (cib *clusterImplBalancer) Close() { - if cib.childLB != nil { - cib.childLB.Close() - cib.childLB = nil +func (b *clusterImplBalancer) Close() { + b.mu.Lock() + b.closed.Fire() + b.mu.Unlock() + + if b.childLB != nil { + b.childLB.Close() + b.childLB = nil + } + <-b.done.Done() + b.logger.Infof("Shutdown") +} + +func (b *clusterImplBalancer) ExitIdle() { + if b.childLB == nil { + return + } + if ei, ok := b.childLB.(balancer.ExitIdler); ok { + ei.ExitIdle() + return + } + // Fallback for children that don't support ExitIdle -- connect to all + // SubConns. + for _, sc := range b.scWrappers { + sc.Connect() } - cib.xdsC.Close() - cib.closed.Fire() - cib.logger.Infof("Shutdown") } // Override methods to accept updates from the child LB. -func (cib *clusterImplBalancer) UpdateState(state balancer.State) { +func (b *clusterImplBalancer) UpdateState(state balancer.State) { // Instead of updating parent ClientConn inline, send state to run(). 
- cib.pickerUpdateCh.Put(state) + b.pickerUpdateCh.Put(state) +} + +func (b *clusterImplBalancer) setClusterName(n string) { + b.clusterNameMu.Lock() + defer b.clusterNameMu.Unlock() + b.clusterName = n +} + +func (b *clusterImplBalancer) getClusterName() string { + b.clusterNameMu.Lock() + defer b.clusterNameMu.Unlock() + return b.clusterName +} + +// scWrapper is a wrapper of SubConn with locality ID. The locality ID can be +// retrieved from the addresses when creating SubConn. +// +// All SubConns passed to the child policies are wrapped in this, so that the +// picker can get the localityID from the picked SubConn, and do load reporting. +// +// After wrapping, all SubConns to and from the parent ClientConn (e.g. for +// SubConn state update, update/remove SubConn) must be the original SubConns. +// All SubConns to and from the child policy (NewSubConn, forwarding SubConn +// state update) must be the wrapper. The balancer keeps a map from the original +// SubConn to the wrapper for this purpose. +type scWrapper struct { + balancer.SubConn + // locality needs to be atomic because it can be updated while being read by + // the picker. 
+ locality atomic.Value // type xdsinternal.LocalityID +} + +func (scw *scWrapper) updateLocalityID(lID xdsinternal.LocalityID) { + scw.locality.Store(lID) +} + +func (scw *scWrapper) localityID() xdsinternal.LocalityID { + lID, _ := scw.locality.Load().(xdsinternal.LocalityID) + return lID +} + +func (b *clusterImplBalancer) NewSubConn(addrs []resolver.Address, opts balancer.NewSubConnOptions) (balancer.SubConn, error) { + clusterName := b.getClusterName() + newAddrs := make([]resolver.Address, len(addrs)) + var lID xdsinternal.LocalityID + for i, addr := range addrs { + newAddrs[i] = internal.SetXDSHandshakeClusterName(addr, clusterName) + lID = xdsinternal.GetLocalityID(newAddrs[i]) + } + sc, err := b.ClientConn.NewSubConn(newAddrs, opts) + if err != nil { + return nil, err + } + // Wrap this SubConn in a wrapper, and add it to the map. + b.scWrappersMu.Lock() + ret := &scWrapper{SubConn: sc} + ret.updateLocalityID(lID) + b.scWrappers[sc] = ret + b.scWrappersMu.Unlock() + return ret, nil +} + +func (b *clusterImplBalancer) RemoveSubConn(sc balancer.SubConn) { + scw, ok := sc.(*scWrapper) + if !ok { + b.ClientConn.RemoveSubConn(sc) + return + } + // Remove the original SubConn from the parent ClientConn. + // + // Note that we don't remove this SubConn from the scWrappers map. We will + // need it to forward the final SubConn state Shutdown to the child policy. + // + // This entry is kept in the map until its state changes to Shutdown, + // and will be deleted in UpdateSubConnState(). 
+ b.ClientConn.RemoveSubConn(scw.SubConn) +} + +func (b *clusterImplBalancer) UpdateAddresses(sc balancer.SubConn, addrs []resolver.Address) { + clusterName := b.getClusterName() + newAddrs := make([]resolver.Address, len(addrs)) + var lID xdsinternal.LocalityID + for i, addr := range addrs { + newAddrs[i] = internal.SetXDSHandshakeClusterName(addr, clusterName) + lID = xdsinternal.GetLocalityID(newAddrs[i]) + } + if scw, ok := sc.(*scWrapper); ok { + scw.updateLocalityID(lID) + // Need to get the original SubConn from the wrapper before calling + // parent ClientConn. + sc = scw.SubConn + } + b.ClientConn.UpdateAddresses(sc, newAddrs) } type dropConfigs struct { drops []*dropper - requestCounter *xdsclient.ServiceRequestsCounter + requestCounter *xdsclient.ClusterRequestsCounter requestCountMax uint32 } -func (cib *clusterImplBalancer) run() { +// handleDropAndRequestCount compares drop and request counter in newConfig with +// the one currently used by picker. It returns a new dropConfigs if a new +// picker needs to be generated, otherwise it returns nil. +func (b *clusterImplBalancer) handleDropAndRequestCount(newConfig *LBConfig) *dropConfigs { + // Compare new drop config. And update picker if it's changed. + var updatePicker bool + if !equalDropCategories(b.dropCategories, newConfig.DropCategories) { + b.dropCategories = newConfig.DropCategories + b.drops = make([]*dropper, 0, len(newConfig.DropCategories)) + for _, c := range newConfig.DropCategories { + b.drops = append(b.drops, newDropper(c)) + } + updatePicker = true + } + + // Compare cluster name. And update picker if it's changed, because circuit + // breaking's stream counter will be different. 
+ if b.requestCounterCluster != newConfig.Cluster || b.requestCounterService != newConfig.EDSServiceName { + b.requestCounterCluster = newConfig.Cluster + b.requestCounterService = newConfig.EDSServiceName + b.requestCounter = xdsclient.GetClusterRequestsCounter(newConfig.Cluster, newConfig.EDSServiceName) + updatePicker = true + } + // Compare upper bound of stream count. And update picker if it's changed. + // This is also for circuit breaking. + var newRequestCountMax uint32 = 1024 + if newConfig.MaxConcurrentRequests != nil { + newRequestCountMax = *newConfig.MaxConcurrentRequests + } + if b.requestCountMax != newRequestCountMax { + b.requestCountMax = newRequestCountMax + updatePicker = true + } + + if !updatePicker { + return nil + } + return &dropConfigs{ + drops: b.drops, + requestCounter: b.requestCounter, + requestCountMax: b.requestCountMax, + } +} + +func (b *clusterImplBalancer) run() { + defer b.done.Fire() for { select { - case update := <-cib.pickerUpdateCh.Get(): - cib.pickerUpdateCh.Load() + case update := <-b.pickerUpdateCh.Get(): + b.pickerUpdateCh.Load() + b.mu.Lock() + if b.closed.HasFired() { + b.mu.Unlock() + return + } switch u := update.(type) { case balancer.State: - cib.childState = u - cib.ClientConn.UpdateState(balancer.State{ - ConnectivityState: cib.childState.ConnectivityState, - Picker: newDropPicker(cib.childState, &dropConfigs{ - drops: cib.drops, - requestCounter: cib.requestCounter, - requestCountMax: cib.requestCountMax, - }, cib.loadWrapper), + b.childState = u + b.ClientConn.UpdateState(balancer.State{ + ConnectivityState: b.childState.ConnectivityState, + Picker: newPicker(b.childState, &dropConfigs{ + drops: b.drops, + requestCounter: b.requestCounter, + requestCountMax: b.requestCountMax, + }, b.loadWrapper), }) - case *dropConfigs: - cib.drops = u.drops - cib.requestCounter = u.requestCounter - if cib.childState.Picker != nil { - cib.ClientConn.UpdateState(balancer.State{ - ConnectivityState: 
cib.childState.ConnectivityState, - Picker: newDropPicker(cib.childState, u, cib.loadWrapper), + case *LBConfig: + dc := b.handleDropAndRequestCount(u) + if dc != nil && b.childState.Picker != nil { + b.ClientConn.UpdateState(balancer.State{ + ConnectivityState: b.childState.ConnectivityState, + Picker: newPicker(b.childState, dc, b.loadWrapper), }) } } - case <-cib.closed.Done(): + b.mu.Unlock() + case <-b.closed.Done(): + if b.cancelLoadReport != nil { + b.cancelLoadReport() + b.cancelLoadReport = nil + } return } } diff --git a/xds/internal/balancer/clusterimpl/config.go b/xds/internal/balancer/clusterimpl/config.go index 548ab34bce4..51ff654f6eb 100644 --- a/xds/internal/balancer/clusterimpl/config.go +++ b/xds/internal/balancer/clusterimpl/config.go @@ -25,32 +25,33 @@ import ( "google.golang.org/grpc/serviceconfig" ) -type dropCategory struct { +// DropConfig contains the category, and drop ratio. +type DropConfig struct { Category string RequestsPerMillion uint32 } -// lbConfig is the balancer config for weighted_target. -type lbConfig struct { - serviceconfig.LoadBalancingConfig +// LBConfig is the balancer config for cluster_impl balancer. 
+type LBConfig struct { + serviceconfig.LoadBalancingConfig `json:"-"` - Cluster string - EDSServiceName string - LRSLoadReportingServerName *string - MaxConcurrentRequests *uint32 - DropCategories []dropCategory - ChildPolicy *internalserviceconfig.BalancerConfig + Cluster string `json:"cluster,omitempty"` + EDSServiceName string `json:"edsServiceName,omitempty"` + LoadReportingServerName *string `json:"lrsLoadReportingServerName,omitempty"` + MaxConcurrentRequests *uint32 `json:"maxConcurrentRequests,omitempty"` + DropCategories []DropConfig `json:"dropCategories,omitempty"` + ChildPolicy *internalserviceconfig.BalancerConfig `json:"childPolicy,omitempty"` } -func parseConfig(c json.RawMessage) (*lbConfig, error) { - var cfg lbConfig +func parseConfig(c json.RawMessage) (*LBConfig, error) { + var cfg LBConfig if err := json.Unmarshal(c, &cfg); err != nil { return nil, err } return &cfg, nil } -func equalDropCategories(a, b []dropCategory) bool { +func equalDropCategories(a, b []DropConfig) bool { if len(a) != len(b) { return false } diff --git a/xds/internal/balancer/clusterimpl/config_test.go b/xds/internal/balancer/clusterimpl/config_test.go index 89696981e2a..ccb0c5e74d9 100644 --- a/xds/internal/balancer/clusterimpl/config_test.go +++ b/xds/internal/balancer/clusterimpl/config_test.go @@ -87,7 +87,7 @@ func TestParseConfig(t *testing.T) { tests := []struct { name string js string - want *lbConfig + want *LBConfig wantErr bool }{ { @@ -105,12 +105,12 @@ func TestParseConfig(t *testing.T) { { name: "OK", js: testJSONConfig, - want: &lbConfig{ - Cluster: "test_cluster", - EDSServiceName: "test-eds", - LRSLoadReportingServerName: newString("lrs_server"), - MaxConcurrentRequests: newUint32(123), - DropCategories: []dropCategory{ + want: &LBConfig{ + Cluster: "test_cluster", + EDSServiceName: "test-eds", + LoadReportingServerName: newString("lrs_server"), + MaxConcurrentRequests: newUint32(123), + DropCategories: []DropConfig{ {Category: "drop-1", 
RequestsPerMillion: 314}, {Category: "drop-2", RequestsPerMillion: 159}, }, diff --git a/xds/internal/balancer/clusterimpl/picker.go b/xds/internal/balancer/clusterimpl/picker.go index 6e9d2791153..db29c550be1 100644 --- a/xds/internal/balancer/clusterimpl/picker.go +++ b/xds/internal/balancer/clusterimpl/picker.go @@ -19,16 +19,19 @@ package clusterimpl import ( + orcapb "github.com/cncf/udpa/go/udpa/data/orca/v1" "google.golang.org/grpc/balancer" "google.golang.org/grpc/codes" "google.golang.org/grpc/connectivity" "google.golang.org/grpc/internal/wrr" "google.golang.org/grpc/status" - "google.golang.org/grpc/xds/internal/client" - "google.golang.org/grpc/xds/internal/client/load" + "google.golang.org/grpc/xds/internal/xdsclient" + "google.golang.org/grpc/xds/internal/xdsclient/load" ) -var newRandomWRR = wrr.NewRandom +// NewRandomWRR is used when calculating drops. It's exported so that tests can +// override it. +var NewRandomWRR = wrr.NewRandom const million = 1000000 @@ -47,8 +50,8 @@ func gcd(a, b uint32) uint32 { return a } -func newDropper(c dropCategory) *dropper { - w := newRandomWRR() +func newDropper(c DropConfig) *dropper { + w := NewRandomWRR() gcdv := gcd(c.RequestsPerMillion, million) // Return true for RequestPerMillion, false for the rest. w.Add(true, int64(c.RequestsPerMillion/gcdv)) @@ -64,21 +67,30 @@ func (d *dropper) drop() (ret bool) { return d.w.Next().(bool) } +const ( + serverLoadCPUName = "cpu_utilization" + serverLoadMemoryName = "mem_utilization" +) + // loadReporter wraps the methods from the loadStore that are used here. type loadReporter interface { + CallStarted(locality string) + CallFinished(locality string, err error) + CallServerLoad(locality, name string, val float64) CallDropped(locality string) } -type dropPicker struct { +// Picker implements RPC drop, circuit breaking drop and load reporting. 
+type picker struct { drops []*dropper s balancer.State loadStore loadReporter - counter *client.ServiceRequestsCounter + counter *xdsclient.ClusterRequestsCounter countMax uint32 } -func newDropPicker(s balancer.State, config *dropConfigs, loadStore load.PerClusterReporter) *dropPicker { - return &dropPicker{ +func newPicker(s balancer.State, config *dropConfigs, loadStore load.PerClusterReporter) *picker { + return &picker{ drops: config.drops, s: s, loadStore: loadStore, @@ -87,13 +99,14 @@ func newDropPicker(s balancer.State, config *dropConfigs, loadStore load.PerClus } } -func (d *dropPicker) Pick(info balancer.PickInfo) (balancer.PickResult, error) { +func (d *picker) Pick(info balancer.PickInfo) (balancer.PickResult, error) { // Don't drop unless the inner picker is READY. Similar to // https://github.com/grpc/grpc-go/issues/2622. if d.s.ConnectivityState != connectivity.Ready { return d.s.Picker.Pick(info) } + // Check if this RPC should be dropped by category. for _, dp := range d.drops { if dp.drop() { if d.loadStore != nil { @@ -103,6 +116,7 @@ func (d *dropPicker) Pick(info balancer.PickInfo) (balancer.PickResult, error) { } } + // Check if this RPC should be dropped by circuit breaking. if d.counter != nil { if err := d.counter.StartRequest(d.countMax); err != nil { // Drops by circuit breaking are reported with empty category. They @@ -112,11 +126,58 @@ func (d *dropPicker) Pick(info balancer.PickInfo) (balancer.PickResult, error) { } return balancer.PickResult{}, status.Errorf(codes.Unavailable, err.Error()) } - pr, err := d.s.Picker.Pick(info) - if err != nil { + } + + var lIDStr string + pr, err := d.s.Picker.Pick(info) + if scw, ok := pr.SubConn.(*scWrapper); ok { + // This OK check also covers the case err!=nil, because SubConn will be + // nil. + pr.SubConn = scw.SubConn + var e error + // If locality ID isn't found in the wrapper, an empty locality ID will + // be used. 
+ lIDStr, e = scw.localityID().ToString() + if e != nil { + logger.Infof("failed to marshal LocalityID: %#v, loads won't be reported", scw.localityID()) + } + } + + if err != nil { + if d.counter != nil { + // Release one request count if this pick fails. d.counter.EndRequest() - return pr, err } + return pr, err + } + + if d.loadStore != nil { + d.loadStore.CallStarted(lIDStr) + oldDone := pr.Done + pr.Done = func(info balancer.DoneInfo) { + if oldDone != nil { + oldDone(info) + } + d.loadStore.CallFinished(lIDStr, info.Err) + + load, ok := info.ServerLoad.(*orcapb.OrcaLoadReport) + if !ok { + return + } + d.loadStore.CallServerLoad(lIDStr, serverLoadCPUName, load.CpuUtilization) + d.loadStore.CallServerLoad(lIDStr, serverLoadMemoryName, load.MemUtilization) + for n, c := range load.RequestCost { + d.loadStore.CallServerLoad(lIDStr, n, c) + } + for n, c := range load.Utilization { + d.loadStore.CallServerLoad(lIDStr, n, c) + } + } + } + + if d.counter != nil { + // Update Done() so that when the RPC finishes, the request count will + // be released. oldDone := pr.Done pr.Done = func(doneInfo balancer.DoneInfo) { d.counter.EndRequest() @@ -124,8 +185,7 @@ func (d *dropPicker) Pick(info balancer.PickInfo) (balancer.PickResult, error) { oldDone(doneInfo) } } - return pr, err } - return d.s.Picker.Pick(info) + return pr, err } diff --git a/xds/internal/balancer/clustermanager/balancerstateaggregator.go b/xds/internal/balancer/clustermanager/balancerstateaggregator.go index 35eb86c3590..6e0e03299f9 100644 --- a/xds/internal/balancer/clustermanager/balancerstateaggregator.go +++ b/xds/internal/balancer/clustermanager/balancerstateaggregator.go @@ -183,13 +183,18 @@ func (bsa *balancerStateAggregator) build() balancer.State { // handling the special connecting after ready, as in UpdateState(). Then a // function to calculate the aggregated connectivity state as in this // function. 
- var readyN, connectingN int + // + // TODO: use balancer.ConnectivityStateEvaluator to calculate the aggregated + // state. + var readyN, connectingN, idleN int for _, ps := range bsa.idToPickerState { switch ps.stateToAggregate { case connectivity.Ready: readyN++ case connectivity.Connecting: connectingN++ + case connectivity.Idle: + idleN++ } } var aggregatedState connectivity.State @@ -198,6 +203,8 @@ func (bsa *balancerStateAggregator) build() balancer.State { aggregatedState = connectivity.Ready case connectingN > 0: aggregatedState = connectivity.Connecting + case idleN > 0: + aggregatedState = connectivity.Idle default: aggregatedState = connectivity.TransientFailure } diff --git a/xds/internal/balancer/clustermanager/clustermanager.go b/xds/internal/balancer/clustermanager/clustermanager.go index 1e4dee7f5d3..318545d79b0 100644 --- a/xds/internal/balancer/clustermanager/clustermanager.go +++ b/xds/internal/balancer/clustermanager/clustermanager.go @@ -27,6 +27,7 @@ import ( "google.golang.org/grpc/grpclog" internalgrpclog "google.golang.org/grpc/internal/grpclog" "google.golang.org/grpc/internal/hierarchy" + "google.golang.org/grpc/internal/pretty" "google.golang.org/grpc/resolver" "google.golang.org/grpc/serviceconfig" "google.golang.org/grpc/xds/internal/balancer/balancergroup" @@ -35,12 +36,12 @@ import ( const balancerName = "xds_cluster_manager_experimental" func init() { - balancer.Register(builder{}) + balancer.Register(bb{}) } -type builder struct{} +type bb struct{} -func (builder) Build(cc balancer.ClientConn, opts balancer.BuildOptions) balancer.Balancer { +func (bb) Build(cc balancer.ClientConn, opts balancer.BuildOptions) balancer.Balancer { b := &bal{} b.logger = prefixLogger(b) b.stateAggregator = newBalancerStateAggregator(cc, b.logger) @@ -51,11 +52,11 @@ func (builder) Build(cc balancer.ClientConn, opts balancer.BuildOptions) balance return b } -func (builder) Name() string { +func (bb) Name() string { return balancerName } -func 
(builder) ParseConfig(c json.RawMessage) (serviceconfig.LoadBalancingConfig, error) { +func (bb) ParseConfig(c json.RawMessage) (serviceconfig.LoadBalancingConfig, error) { return parseConfig(c) } @@ -115,7 +116,7 @@ func (b *bal) UpdateClientConnState(s balancer.ClientConnState) error { if !ok { return fmt.Errorf("unexpected balancer config with type: %T", s.BalancerConfig) } - b.logger.Infof("update with config %+v, resolver state %+v", s.BalancerConfig, s.ResolverState) + b.logger.Infof("update with config %+v, resolver state %+v", pretty.ToJSON(s.BalancerConfig), s.ResolverState) b.updateChildren(s, newConfig) return nil @@ -132,6 +133,11 @@ func (b *bal) UpdateSubConnState(sc balancer.SubConn, state balancer.SubConnStat func (b *bal) Close() { b.stateAggregator.close() b.bg.Close() + b.logger.Infof("Shutdown") +} + +func (b *bal) ExitIdle() { + b.bg.ExitIdle() } const prefix = "[xds-cluster-manager-lb %p] " diff --git a/xds/internal/balancer/clustermanager/clustermanager_test.go b/xds/internal/balancer/clustermanager/clustermanager_test.go index a40d954ad64..d3475ea3f5d 100644 --- a/xds/internal/balancer/clustermanager/clustermanager_test.go +++ b/xds/internal/balancer/clustermanager/clustermanager_test.go @@ -565,3 +565,68 @@ func TestClusterManagerForwardsBalancerBuildOptions(t *testing.T) { t.Fatal(err2) } } + +const initIdleBalancerName = "test-init-Idle-balancer" + +var errTestInitIdle = fmt.Errorf("init Idle balancer error 0") + +func init() { + stub.Register(initIdleBalancerName, stub.BalancerFuncs{ + UpdateClientConnState: func(bd *stub.BalancerData, opts balancer.ClientConnState) error { + bd.ClientConn.NewSubConn(opts.ResolverState.Addresses, balancer.NewSubConnOptions{}) + return nil + }, + UpdateSubConnState: func(bd *stub.BalancerData, sc balancer.SubConn, state balancer.SubConnState) { + err := fmt.Errorf("wrong picker error") + if state.ConnectivityState == connectivity.Idle { + err = errTestInitIdle + } + 
bd.ClientConn.UpdateState(balancer.State{ + ConnectivityState: state.ConnectivityState, + Picker: &testutils.TestConstPicker{Err: err}, + }) + }, + }) +} + +// TestInitialIdle covers the case that if the child reports Idle, the overall +// state will be Idle. +func TestInitialIdle(t *testing.T) { + cc := testutils.NewTestClientConn(t) + rtb := rtBuilder.Build(cc, balancer.BuildOptions{}) + + configJSON1 := `{ +"children": { + "cds:cluster_1":{ "childPolicy": [{"test-init-Idle-balancer":""}] } +} +}` + + config1, err := rtParser.ParseConfig([]byte(configJSON1)) + if err != nil { + t.Fatalf("failed to parse balancer config: %v", err) + } + + // Send the config, and an address with hierarchy path ["cluster_1"]. + wantAddrs := []resolver.Address{ + {Addr: testBackendAddrStrs[0], Attributes: nil}, + } + if err := rtb.UpdateClientConnState(balancer.ClientConnState{ + ResolverState: resolver.State{Addresses: []resolver.Address{ + hierarchy.Set(wantAddrs[0], []string{"cds:cluster_1"}), + }}, + BalancerConfig: config1, + }); err != nil { + t.Fatalf("failed to update ClientConn state: %v", err) + } + + // Verify that a subconn is created with the address, and the hierarchy path + // in the address is cleared. + for range wantAddrs { + sc := <-cc.NewSubConnCh + rtb.UpdateSubConnState(sc, balancer.SubConnState{ConnectivityState: connectivity.Idle}) + } + + if state1 := <-cc.NewStateCh; state1 != connectivity.Idle { + t.Fatalf("Received aggregated state: %v, want Idle", state1) + } +} diff --git a/xds/internal/balancer/clusterresolver/clusterresolver.go b/xds/internal/balancer/clusterresolver/clusterresolver.go new file mode 100644 index 00000000000..66a5aab305e --- /dev/null +++ b/xds/internal/balancer/clusterresolver/clusterresolver.go @@ -0,0 +1,378 @@ +/* + * + * Copyright 2019 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// Package clusterresolver contains EDS balancer implementation. +package clusterresolver + +import ( + "encoding/json" + "errors" + "fmt" + + "google.golang.org/grpc/attributes" + "google.golang.org/grpc/balancer" + "google.golang.org/grpc/balancer/base" + "google.golang.org/grpc/connectivity" + "google.golang.org/grpc/internal/buffer" + "google.golang.org/grpc/internal/grpclog" + "google.golang.org/grpc/internal/grpcsync" + "google.golang.org/grpc/internal/pretty" + "google.golang.org/grpc/resolver" + "google.golang.org/grpc/serviceconfig" + "google.golang.org/grpc/xds/internal/balancer/priority" + "google.golang.org/grpc/xds/internal/xdsclient" +) + +// Name is the name of the cluster_resolver balancer. +const Name = "cluster_resolver_experimental" + +var ( + errBalancerClosed = errors.New("cdsBalancer is closed") + newChildBalancer = func(bb balancer.Builder, cc balancer.ClientConn, o balancer.BuildOptions) balancer.Balancer { + return bb.Build(cc, o) + } +) + +func init() { + balancer.Register(bb{}) +} + +type bb struct{} + +// Build helps implement the balancer.Builder interface. 
+func (bb) Build(cc balancer.ClientConn, opts balancer.BuildOptions) balancer.Balancer { + priorityBuilder := balancer.Get(priority.Name) + if priorityBuilder == nil { + logger.Errorf("priority balancer is needed but not registered") + return nil + } + priorityConfigParser, ok := priorityBuilder.(balancer.ConfigParser) + if !ok { + logger.Errorf("priority balancer builder is not a config parser") + return nil + } + + b := &clusterResolverBalancer{ + bOpts: opts, + updateCh: buffer.NewUnbounded(), + closed: grpcsync.NewEvent(), + done: grpcsync.NewEvent(), + + priorityBuilder: priorityBuilder, + priorityConfigParser: priorityConfigParser, + } + b.logger = prefixLogger(b) + b.logger.Infof("Created") + + b.resourceWatcher = newResourceResolver(b) + b.cc = &ccWrapper{ + ClientConn: cc, + resourceWatcher: b.resourceWatcher, + } + + go b.run() + return b +} + +func (bb) Name() string { + return Name +} + +func (bb) ParseConfig(c json.RawMessage) (serviceconfig.LoadBalancingConfig, error) { + var cfg LBConfig + if err := json.Unmarshal(c, &cfg); err != nil { + return nil, fmt.Errorf("unable to unmarshal balancer config %s into cluster-resolver config, error: %v", string(c), err) + } + return &cfg, nil +} + +// ccUpdate wraps a clientConn update received from gRPC (pushed from the +// xdsResolver). +type ccUpdate struct { + state balancer.ClientConnState + err error +} + +// scUpdate wraps a subConn update received from gRPC. This is directly passed +// on to the child balancer. +type scUpdate struct { + subConn balancer.SubConn + state balancer.SubConnState +} + +type exitIdle struct{} + +// clusterResolverBalancer manages xdsClient and the actual EDS balancer implementation that +// does load balancing. +// +// It currently has only an clusterResolverBalancer. Later, we may add fallback. +type clusterResolverBalancer struct { + cc balancer.ClientConn + bOpts balancer.BuildOptions + updateCh *buffer.Unbounded // Channel for updates from gRPC. 
+ resourceWatcher *resourceResolver + logger *grpclog.PrefixLogger + closed *grpcsync.Event + done *grpcsync.Event + + priorityBuilder balancer.Builder + priorityConfigParser balancer.ConfigParser + + config *LBConfig + configRaw *serviceconfig.ParseResult + xdsClient xdsclient.XDSClient // xDS client to watch EDS resource. + attrsWithClient *attributes.Attributes // Attributes with xdsClient attached to be passed to the child policies. + + child balancer.Balancer + priorities []priorityConfig + watchUpdateReceived bool +} + +// handleClientConnUpdate handles a ClientConnUpdate received from gRPC. Good +// updates lead to registration of EDS and DNS watches. Updates with error lead +// to cancellation of existing watch and propagation of the same error to the +// child balancer. +func (b *clusterResolverBalancer) handleClientConnUpdate(update *ccUpdate) { + // We first handle errors, if any, and then proceed with handling the + // update, only if the status quo has changed. + if err := update.err; err != nil { + b.handleErrorFromUpdate(err, true) + return + } + + b.logger.Infof("Receive update from resolver, balancer config: %v", pretty.ToJSON(update.state.BalancerConfig)) + cfg, _ := update.state.BalancerConfig.(*LBConfig) + if cfg == nil { + b.logger.Warningf("xds: unexpected LoadBalancingConfig type: %T", update.state.BalancerConfig) + return + } + + b.config = cfg + b.configRaw = update.state.ResolverState.ServiceConfig + b.resourceWatcher.updateMechanisms(cfg.DiscoveryMechanisms) + + if !b.watchUpdateReceived { + // If update was not received, wait for it. + return + } + // If eds resp was received before this, the child policy was created. We + // need to generate a new balancer config and send it to the child, because + // certain fields (unrelated to EDS watch) might have changed. 
+ if err := b.updateChildConfig(); err != nil { + b.logger.Warningf("failed to update child policy config: %v", err) + } +} + +// handleWatchUpdate handles a watch update from the xDS Client. Good updates +// lead to clientConn updates being invoked on the underlying child balancer. +func (b *clusterResolverBalancer) handleWatchUpdate(update *resourceUpdate) { + if err := update.err; err != nil { + b.logger.Warningf("Watch error from xds-client %p: %v", b.xdsClient, err) + b.handleErrorFromUpdate(err, false) + return + } + + b.logger.Infof("resource update: %+v", pretty.ToJSON(update.priorities)) + b.watchUpdateReceived = true + b.priorities = update.priorities + + // A new EDS update triggers new child configs (e.g. different priorities + // for the priority balancer), and new addresses (the endpoints come from + // the EDS response). + if err := b.updateChildConfig(); err != nil { + b.logger.Warningf("failed to update child policy's balancer config: %v", err) + } +} + +// updateChildConfig builds a balancer config from eb's cached eds resp and +// service config, and sends that to the child balancer. Note that it also +// generates the addresses, because the endpoints come from the EDS resp. +// +// If child balancer doesn't already exist, one will be created. +func (b *clusterResolverBalancer) updateChildConfig() error { + // Child was build when the first EDS resp was received, so we just build + // the config and addresses. 
+ if b.child == nil { + b.child = newChildBalancer(b.priorityBuilder, b.cc, b.bOpts) + } + + childCfgBytes, addrs, err := buildPriorityConfigJSON(b.priorities, b.config.XDSLBPolicy) + if err != nil { + return fmt.Errorf("failed to build priority balancer config: %v", err) + } + childCfg, err := b.priorityConfigParser.ParseConfig(childCfgBytes) + if err != nil { + return fmt.Errorf("failed to parse generated priority balancer config, this should never happen because the config is generated: %v", err) + } + b.logger.Infof("build balancer config: %v", pretty.ToJSON(childCfg)) + return b.child.UpdateClientConnState(balancer.ClientConnState{ + ResolverState: resolver.State{ + Addresses: addrs, + ServiceConfig: b.configRaw, + Attributes: b.attrsWithClient, + }, + BalancerConfig: childCfg, + }) +} + +// handleErrorFromUpdate handles both the error from parent ClientConn (from CDS +// balancer) and the error from xds client (from the watcher). fromParent is +// true if error is from parent ClientConn. +// +// If the error is connection error, it should be handled for fallback purposes. +// +// If the error is resource-not-found: +// - If it's from CDS balancer (shows as a resolver error), it means LDS or CDS +// resources were removed. The EDS watch should be canceled. +// - If it's from xds client, it means EDS resource were removed. The EDS +// watcher should keep watching. +// In both cases, the sub-balancers will be receive the error. +func (b *clusterResolverBalancer) handleErrorFromUpdate(err error, fromParent bool) { + b.logger.Warningf("Received error: %v", err) + if fromParent && xdsclient.ErrType(err) == xdsclient.ErrorTypeResourceNotFound { + // This is an error from the parent ClientConn (can be the parent CDS + // balancer), and is a resource-not-found error. This means the resource + // (can be either LDS or CDS) was removed. Stop the EDS watch. 
+ b.resourceWatcher.stop() + } + if b.child != nil { + b.child.ResolverError(err) + } else { + // If eds balancer was never created, fail the RPCs with errors. + b.cc.UpdateState(balancer.State{ + ConnectivityState: connectivity.TransientFailure, + Picker: base.NewErrPicker(err), + }) + } + +} + +// run is a long-running goroutine which handles all updates from gRPC and +// xdsClient. All methods which are invoked directly by gRPC or xdsClient simply +// push an update onto a channel which is read and acted upon right here. +func (b *clusterResolverBalancer) run() { + for { + select { + case u := <-b.updateCh.Get(): + b.updateCh.Load() + switch update := u.(type) { + case *ccUpdate: + b.handleClientConnUpdate(update) + case *scUpdate: + // SubConn updates are simply handed over to the underlying + // child balancer. + if b.child == nil { + b.logger.Errorf("xds: received scUpdate {%+v} with no child balancer", update) + break + } + b.child.UpdateSubConnState(update.subConn, update.state) + case exitIdle: + if b.child == nil { + b.logger.Errorf("xds: received ExitIdle with no child balancer") + break + } + // This implementation assumes the child balancer supports + // ExitIdle (but still checks for the interface's existence to + // avoid a panic if not). If the child does not, no subconns + // will be connected. + if ei, ok := b.child.(balancer.ExitIdler); ok { + ei.ExitIdle() + } + } + case u := <-b.resourceWatcher.updateChannel: + b.handleWatchUpdate(u) + + // Close results in cancellation of the EDS watch and closing of the + // underlying child policy and is the only way to exit this goroutine. + case <-b.closed.Done(): + b.resourceWatcher.stop() + + if b.child != nil { + b.child.Close() + b.child = nil + } + // This is the *ONLY* point of return from this function. + b.logger.Infof("Shutdown") + b.done.Fire() + return + } + } +} + +// Following are methods to implement the balancer interface. 
+ +// UpdateClientConnState receives the serviceConfig (which contains the +// clusterName to watch for in CDS) and the xdsClient object from the +// xdsResolver. +func (b *clusterResolverBalancer) UpdateClientConnState(state balancer.ClientConnState) error { + if b.closed.HasFired() { + b.logger.Warningf("xds: received ClientConnState {%+v} after clusterResolverBalancer was closed", state) + return errBalancerClosed + } + + if b.xdsClient == nil { + c := xdsclient.FromResolverState(state.ResolverState) + if c == nil { + return balancer.ErrBadResolverState + } + b.xdsClient = c + b.attrsWithClient = state.ResolverState.Attributes + } + + b.updateCh.Put(&ccUpdate{state: state}) + return nil +} + +// ResolverError handles errors reported by the xdsResolver. +func (b *clusterResolverBalancer) ResolverError(err error) { + if b.closed.HasFired() { + b.logger.Warningf("xds: received resolver error {%v} after clusterResolverBalancer was closed", err) + return + } + b.updateCh.Put(&ccUpdate{err: err}) +} + +// UpdateSubConnState handles subConn updates from gRPC. +func (b *clusterResolverBalancer) UpdateSubConnState(sc balancer.SubConn, state balancer.SubConnState) { + if b.closed.HasFired() { + b.logger.Warningf("xds: received subConn update {%v, %v} after clusterResolverBalancer was closed", sc, state) + return + } + b.updateCh.Put(&scUpdate{subConn: sc, state: state}) +} + +// Close closes the cdsBalancer and the underlying child balancer. +func (b *clusterResolverBalancer) Close() { + b.closed.Fire() + <-b.done.Done() +} + +func (b *clusterResolverBalancer) ExitIdle() { + b.updateCh.Put(exitIdle{}) +} + +// ccWrapper overrides ResolveNow(), so that re-resolution from the child +// policies will trigger the DNS resolver in cluster_resolver balancer. 
+type ccWrapper struct { + balancer.ClientConn + resourceWatcher *resourceResolver +} + +func (c *ccWrapper) ResolveNow(resolver.ResolveNowOptions) { + c.resourceWatcher.resolveNow() +} diff --git a/xds/internal/balancer/clusterresolver/clusterresolver_test.go b/xds/internal/balancer/clusterresolver/clusterresolver_test.go new file mode 100644 index 00000000000..6af81f89f1f --- /dev/null +++ b/xds/internal/balancer/clusterresolver/clusterresolver_test.go @@ -0,0 +1,500 @@ +/* + * + * Copyright 2019 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package clusterresolver + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "google.golang.org/grpc/balancer" + "google.golang.org/grpc/connectivity" + "google.golang.org/grpc/internal/grpctest" + "google.golang.org/grpc/internal/testutils" + "google.golang.org/grpc/resolver" + "google.golang.org/grpc/xds/internal" + "google.golang.org/grpc/xds/internal/testutils/fakeclient" + "google.golang.org/grpc/xds/internal/xdsclient" + + _ "google.golang.org/grpc/xds/internal/xdsclient/v2" // V2 client registration. +) + +const ( + defaultTestTimeout = 1 * time.Second + defaultTestShortTimeout = 10 * time.Millisecond + testEDSServcie = "test-eds-service-name" + testClusterName = "test-cluster-name" +) + +var ( + // A non-empty endpoints update which is expected to be accepted by the EDS + // LB policy. 
+ defaultEndpointsUpdate = xdsclient.EndpointsUpdate{ + Localities: []xdsclient.Locality{ + { + Endpoints: []xdsclient.Endpoint{{Address: "endpoint1"}}, + ID: internal.LocalityID{Zone: "zone"}, + Priority: 1, + Weight: 100, + }, + }, + } +) + +func init() { + balancer.Register(bb{}) +} + +type s struct { + grpctest.Tester + + cleanup func() +} + +func (ss s) Teardown(t *testing.T) { + xdsclient.ClearAllCountersForTesting() + ss.Tester.Teardown(t) + if ss.cleanup != nil { + ss.cleanup() + } +} + +func Test(t *testing.T) { + grpctest.RunSubTests(t, s{}) +} + +const testBalancerNameFooBar = "foo.bar" + +func newNoopTestClientConn() *noopTestClientConn { + return &noopTestClientConn{} +} + +// noopTestClientConn is used in EDS balancer config update tests that only +// cover the config update handling, but not SubConn/load-balancing. +type noopTestClientConn struct { + balancer.ClientConn +} + +func (t *noopTestClientConn) NewSubConn([]resolver.Address, balancer.NewSubConnOptions) (balancer.SubConn, error) { + return nil, nil +} + +func (noopTestClientConn) Target() string { return testEDSServcie } + +type scStateChange struct { + sc balancer.SubConn + state balancer.SubConnState +} + +type fakeChildBalancer struct { + cc balancer.ClientConn + subConnState *testutils.Channel + clientConnState *testutils.Channel + resolverError *testutils.Channel +} + +func (f *fakeChildBalancer) UpdateClientConnState(state balancer.ClientConnState) error { + f.clientConnState.Send(state) + return nil +} + +func (f *fakeChildBalancer) ResolverError(err error) { + f.resolverError.Send(err) +} + +func (f *fakeChildBalancer) UpdateSubConnState(sc balancer.SubConn, state balancer.SubConnState) { + f.subConnState.Send(&scStateChange{sc: sc, state: state}) +} + +func (f *fakeChildBalancer) Close() {} + +func (f *fakeChildBalancer) ExitIdle() {} + +func (f *fakeChildBalancer) waitForClientConnStateChange(ctx context.Context) error { + _, err := f.clientConnState.Receive(ctx) + if err != nil { 
+ return err + } + return nil +} + +func (f *fakeChildBalancer) waitForResolverError(ctx context.Context) error { + _, err := f.resolverError.Receive(ctx) + if err != nil { + return err + } + return nil +} + +func (f *fakeChildBalancer) waitForSubConnStateChange(ctx context.Context, wantState *scStateChange) error { + val, err := f.subConnState.Receive(ctx) + if err != nil { + return err + } + gotState := val.(*scStateChange) + if !cmp.Equal(gotState, wantState, cmp.AllowUnexported(scStateChange{})) { + return fmt.Errorf("got subconnStateChange %v, want %v", gotState, wantState) + } + return nil +} + +func newFakeChildBalancer(cc balancer.ClientConn) balancer.Balancer { + return &fakeChildBalancer{ + cc: cc, + subConnState: testutils.NewChannelWithSize(10), + clientConnState: testutils.NewChannelWithSize(10), + resolverError: testutils.NewChannelWithSize(10), + } +} + +type fakeSubConn struct{} + +func (*fakeSubConn) UpdateAddresses([]resolver.Address) { panic("implement me") } +func (*fakeSubConn) Connect() { panic("implement me") } + +// waitForNewChildLB makes sure that a new child LB is created by the top-level +// clusterResolverBalancer. +func waitForNewChildLB(ctx context.Context, ch *testutils.Channel) (*fakeChildBalancer, error) { + val, err := ch.Receive(ctx) + if err != nil { + return nil, fmt.Errorf("error when waiting for a new edsLB: %v", err) + } + return val.(*fakeChildBalancer), nil +} + +// setup overrides the functions which are used to create the xdsClient and the +// edsLB, creates fake version of them and makes them available on the provided +// channels. The returned cancel function should be called by the test for +// cleanup. 
+func setup(childLBCh *testutils.Channel) (*fakeclient.Client, func()) { + xdsC := fakeclient.NewClientWithName(testBalancerNameFooBar) + + origNewChildBalancer := newChildBalancer + newChildBalancer = func(_ balancer.Builder, cc balancer.ClientConn, _ balancer.BuildOptions) balancer.Balancer { + childLB := newFakeChildBalancer(cc) + defer func() { childLBCh.Send(childLB) }() + return childLB + } + return xdsC, func() { + newChildBalancer = origNewChildBalancer + xdsC.Close() + } +} + +// TestSubConnStateChange verifies if the top-level clusterResolverBalancer passes on +// the subConnState to appropriate child balancer. +func (s) TestSubConnStateChange(t *testing.T) { + edsLBCh := testutils.NewChannel() + xdsC, cleanup := setup(edsLBCh) + defer cleanup() + + builder := balancer.Get(Name) + edsB := builder.Build(newNoopTestClientConn(), balancer.BuildOptions{}) + if edsB == nil { + t.Fatalf("builder.Build(%s) failed and returned nil", Name) + } + defer edsB.Close() + + if err := edsB.UpdateClientConnState(balancer.ClientConnState{ + ResolverState: xdsclient.SetClient(resolver.State{}, xdsC), + BalancerConfig: newLBConfigWithOneEDS(testEDSServcie), + }); err != nil { + t.Fatalf("edsB.UpdateClientConnState() failed: %v", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + if _, err := xdsC.WaitForWatchEDS(ctx); err != nil { + t.Fatalf("xdsClient.WatchEndpoints failed with error: %v", err) + } + xdsC.InvokeWatchEDSCallback("", defaultEndpointsUpdate, nil) + edsLB, err := waitForNewChildLB(ctx, edsLBCh) + if err != nil { + t.Fatal(err) + } + + fsc := &fakeSubConn{} + state := balancer.SubConnState{ConnectivityState: connectivity.Ready} + edsB.UpdateSubConnState(fsc, state) + if err := edsLB.waitForSubConnStateChange(ctx, &scStateChange{sc: fsc, state: state}); err != nil { + t.Fatal(err) + } +} + +// TestErrorFromXDSClientUpdate verifies that an error from xdsClient update is +// handled correctly. 
+// +// If it's resource-not-found, watch will NOT be canceled, the EDS impl will +// receive an empty EDS update, and new RPCs will fail. +// +// If it's connection error, nothing will happen. This will need to change to +// handle fallback. +func (s) TestErrorFromXDSClientUpdate(t *testing.T) { + edsLBCh := testutils.NewChannel() + xdsC, cleanup := setup(edsLBCh) + defer cleanup() + + builder := balancer.Get(Name) + edsB := builder.Build(newNoopTestClientConn(), balancer.BuildOptions{}) + if edsB == nil { + t.Fatalf("builder.Build(%s) failed and returned nil", Name) + } + defer edsB.Close() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + if err := edsB.UpdateClientConnState(balancer.ClientConnState{ + ResolverState: xdsclient.SetClient(resolver.State{}, xdsC), + BalancerConfig: newLBConfigWithOneEDS(testEDSServcie), + }); err != nil { + t.Fatal(err) + } + if _, err := xdsC.WaitForWatchEDS(ctx); err != nil { + t.Fatalf("xdsClient.WatchEndpoints failed with error: %v", err) + } + xdsC.InvokeWatchEDSCallback("", xdsclient.EndpointsUpdate{}, nil) + edsLB, err := waitForNewChildLB(ctx, edsLBCh) + if err != nil { + t.Fatal(err) + } + if err := edsLB.waitForClientConnStateChange(ctx); err != nil { + t.Fatalf("EDS impl got unexpected update: %v", err) + } + + connectionErr := xdsclient.NewErrorf(xdsclient.ErrorTypeConnection, "connection error") + xdsC.InvokeWatchEDSCallback("", xdsclient.EndpointsUpdate{}, connectionErr) + + sCtx, sCancel := context.WithTimeout(context.Background(), defaultTestShortTimeout) + defer sCancel() + if _, err := xdsC.WaitForCancelEDSWatch(sCtx); err != context.DeadlineExceeded { + t.Fatal("watch was canceled, want not canceled (timeout error)") + } + + sCtx, sCancel = context.WithTimeout(context.Background(), defaultTestShortTimeout) + defer sCancel() + if err := edsLB.waitForClientConnStateChange(sCtx); err != context.DeadlineExceeded { + t.Fatal(err) + } + if err := 
edsLB.waitForResolverError(ctx); err != nil { + t.Fatalf("want resolver error, got %v", err) + } + + resourceErr := xdsclient.NewErrorf(xdsclient.ErrorTypeResourceNotFound, "clusterResolverBalancer resource not found error") + xdsC.InvokeWatchEDSCallback("", xdsclient.EndpointsUpdate{}, resourceErr) + // Even if error is resource not found, watch shouldn't be canceled, because + // this is an EDS resource removed (and xds client actually never sends this + // error, but we still handles it). + sCtx, sCancel = context.WithTimeout(context.Background(), defaultTestShortTimeout) + defer sCancel() + if _, err := xdsC.WaitForCancelEDSWatch(sCtx); err != context.DeadlineExceeded { + t.Fatal("watch was canceled, want not canceled (timeout error)") + } + if err := edsLB.waitForClientConnStateChange(sCtx); err != context.DeadlineExceeded { + t.Fatal(err) + } + if err := edsLB.waitForResolverError(ctx); err != nil { + t.Fatalf("want resolver error, got %v", err) + } + + // An update with the same service name should not trigger a new watch. + if err := edsB.UpdateClientConnState(balancer.ClientConnState{ + ResolverState: xdsclient.SetClient(resolver.State{}, xdsC), + BalancerConfig: newLBConfigWithOneEDS(testEDSServcie), + }); err != nil { + t.Fatal(err) + } + sCtx, sCancel = context.WithTimeout(context.Background(), defaultTestShortTimeout) + defer sCancel() + if _, err := xdsC.WaitForWatchEDS(sCtx); err != context.DeadlineExceeded { + t.Fatal("got unexpected new EDS watch") + } +} + +// TestErrorFromResolver verifies that resolver errors are handled correctly. +// +// If it's resource-not-found, watch will be canceled, the EDS impl will receive +// an empty EDS update, and new RPCs will fail. +// +// If it's connection error, nothing will happen. This will need to change to +// handle fallback. 
+func (s) TestErrorFromResolver(t *testing.T) { + edsLBCh := testutils.NewChannel() + xdsC, cleanup := setup(edsLBCh) + defer cleanup() + + builder := balancer.Get(Name) + edsB := builder.Build(newNoopTestClientConn(), balancer.BuildOptions{}) + if edsB == nil { + t.Fatalf("builder.Build(%s) failed and returned nil", Name) + } + defer edsB.Close() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + if err := edsB.UpdateClientConnState(balancer.ClientConnState{ + ResolverState: xdsclient.SetClient(resolver.State{}, xdsC), + BalancerConfig: newLBConfigWithOneEDS(testEDSServcie), + }); err != nil { + t.Fatal(err) + } + + if _, err := xdsC.WaitForWatchEDS(ctx); err != nil { + t.Fatalf("xdsClient.WatchEndpoints failed with error: %v", err) + } + xdsC.InvokeWatchEDSCallback("", xdsclient.EndpointsUpdate{}, nil) + edsLB, err := waitForNewChildLB(ctx, edsLBCh) + if err != nil { + t.Fatal(err) + } + if err := edsLB.waitForClientConnStateChange(ctx); err != nil { + t.Fatalf("EDS impl got unexpected update: %v", err) + } + + connectionErr := xdsclient.NewErrorf(xdsclient.ErrorTypeConnection, "connection error") + edsB.ResolverError(connectionErr) + + sCtx, sCancel := context.WithTimeout(context.Background(), defaultTestShortTimeout) + defer sCancel() + if _, err := xdsC.WaitForCancelEDSWatch(sCtx); err != context.DeadlineExceeded { + t.Fatal("watch was canceled, want not canceled (timeout error)") + } + + sCtx, sCancel = context.WithTimeout(context.Background(), defaultTestShortTimeout) + defer sCancel() + if err := edsLB.waitForClientConnStateChange(sCtx); err != context.DeadlineExceeded { + t.Fatal("eds impl got EDS resp, want timeout error") + } + if err := edsLB.waitForResolverError(ctx); err != nil { + t.Fatalf("want resolver error, got %v", err) + } + + resourceErr := xdsclient.NewErrorf(xdsclient.ErrorTypeResourceNotFound, "clusterResolverBalancer resource not found error") + edsB.ResolverError(resourceErr) + if _, err := 
xdsC.WaitForCancelEDSWatch(ctx); err != nil { + t.Fatalf("want watch to be canceled, waitForCancel failed: %v", err) + } + if err := edsLB.waitForClientConnStateChange(sCtx); err != context.DeadlineExceeded { + t.Fatal(err) + } + if err := edsLB.waitForResolverError(ctx); err != nil { + t.Fatalf("want resolver error, got %v", err) + } + + // An update with the same service name should trigger a new watch, because + // the previous watch was canceled. + if err := edsB.UpdateClientConnState(balancer.ClientConnState{ + ResolverState: xdsclient.SetClient(resolver.State{}, xdsC), + BalancerConfig: newLBConfigWithOneEDS(testEDSServcie), + }); err != nil { + t.Fatal(err) + } + if _, err := xdsC.WaitForWatchEDS(ctx); err != nil { + t.Fatalf("xdsClient.WatchEndpoints failed with error: %v", err) + } +} + +// Given a list of resource names, verifies that EDS requests for the same are +// sent by the EDS balancer, through the fake xDS client. +func verifyExpectedRequests(ctx context.Context, fc *fakeclient.Client, resourceNames ...string) error { + for _, name := range resourceNames { + if name == "" { + // ResourceName empty string indicates a cancel. + if _, err := fc.WaitForCancelEDSWatch(ctx); err != nil { + return fmt.Errorf("timed out when expecting resource %q", name) + } + continue + } + + resName, err := fc.WaitForWatchEDS(ctx) + if err != nil { + return fmt.Errorf("timed out when expecting resource %q, %p", name, fc) + } + if resName != name { + return fmt.Errorf("got EDS request for resource %q, expected: %q", resName, name) + } + } + return nil +} + +// TestClientWatchEDS verifies that the xdsClient inside the top-level EDS LB +// policy registers an EDS watch for expected resource upon receiving an update +// from gRPC. 
+func (s) TestClientWatchEDS(t *testing.T) { + edsLBCh := testutils.NewChannel() + xdsC, cleanup := setup(edsLBCh) + defer cleanup() + + builder := balancer.Get(Name) + edsB := builder.Build(newNoopTestClientConn(), balancer.BuildOptions{}) + if edsB == nil { + t.Fatalf("builder.Build(%s) failed and returned nil", Name) + } + defer edsB.Close() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + // If eds service name is not set, should watch for cluster name. + if err := edsB.UpdateClientConnState(balancer.ClientConnState{ + ResolverState: xdsclient.SetClient(resolver.State{}, xdsC), + BalancerConfig: newLBConfigWithOneEDS("cluster-1"), + }); err != nil { + t.Fatal(err) + } + if err := verifyExpectedRequests(ctx, xdsC, "cluster-1"); err != nil { + t.Fatal(err) + } + + // Update with an non-empty edsServiceName should trigger an EDS watch for + // the same. + if err := edsB.UpdateClientConnState(balancer.ClientConnState{ + ResolverState: xdsclient.SetClient(resolver.State{}, xdsC), + BalancerConfig: newLBConfigWithOneEDS("foobar-1"), + }); err != nil { + t.Fatal(err) + } + if err := verifyExpectedRequests(ctx, xdsC, "", "foobar-1"); err != nil { + t.Fatal(err) + } + + // Also test the case where the edsServerName changes from one non-empty + // name to another, and make sure a new watch is registered. The previously + // registered watch will be cancelled, which will result in an EDS request + // with no resource names being sent to the server. 
+ if err := edsB.UpdateClientConnState(balancer.ClientConnState{ + ResolverState: xdsclient.SetClient(resolver.State{}, xdsC), + BalancerConfig: newLBConfigWithOneEDS("foobar-2"), + }); err != nil { + t.Fatal(err) + } + if err := verifyExpectedRequests(ctx, xdsC, "", "foobar-2"); err != nil { + t.Fatal(err) + } +} + +func newLBConfigWithOneEDS(edsServiceName string) *LBConfig { + return &LBConfig{ + DiscoveryMechanisms: []DiscoveryMechanism{{ + Cluster: testClusterName, + Type: DiscoveryMechanismTypeEDS, + EDSServiceName: edsServiceName, + }}, + } +} diff --git a/xds/internal/balancer/clusterresolver/config.go b/xds/internal/balancer/clusterresolver/config.go new file mode 100644 index 00000000000..a6a3cbab804 --- /dev/null +++ b/xds/internal/balancer/clusterresolver/config.go @@ -0,0 +1,185 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package clusterresolver + +import ( + "bytes" + "encoding/json" + "fmt" + "strings" + + "google.golang.org/grpc/balancer/roundrobin" + internalserviceconfig "google.golang.org/grpc/internal/serviceconfig" + "google.golang.org/grpc/serviceconfig" + "google.golang.org/grpc/xds/internal/balancer/ringhash" +) + +// DiscoveryMechanismType is the type of discovery mechanism. +type DiscoveryMechanismType int + +const ( + // DiscoveryMechanismTypeEDS is eds. + DiscoveryMechanismTypeEDS DiscoveryMechanismType = iota // `json:"EDS"` + // DiscoveryMechanismTypeLogicalDNS is DNS. 
+ DiscoveryMechanismTypeLogicalDNS // `json:"LOGICAL_DNS"` +) + +// MarshalJSON marshals a DiscoveryMechanismType to a quoted json string. +// +// This is necessary to handle enum (as strings) from JSON. +// +// Note that this needs to be defined on the type not pointer, otherwise the +// variables of this type will marshal to int not string. +func (t DiscoveryMechanismType) MarshalJSON() ([]byte, error) { + buffer := bytes.NewBufferString(`"`) + switch t { + case DiscoveryMechanismTypeEDS: + buffer.WriteString("EDS") + case DiscoveryMechanismTypeLogicalDNS: + buffer.WriteString("LOGICAL_DNS") + } + buffer.WriteString(`"`) + return buffer.Bytes(), nil +} + +// UnmarshalJSON unmarshals a quoted json string to the DiscoveryMechanismType. +func (t *DiscoveryMechanismType) UnmarshalJSON(b []byte) error { + var s string + err := json.Unmarshal(b, &s) + if err != nil { + return err + } + switch s { + case "EDS": + *t = DiscoveryMechanismTypeEDS + case "LOGICAL_DNS": + *t = DiscoveryMechanismTypeLogicalDNS + default: + return fmt.Errorf("unable to unmarshal string %q to type DiscoveryMechanismType", s) + } + return nil +} + +// DiscoveryMechanism is the discovery mechanism, can be either EDS or DNS. +// +// For DNS, the ClientConn target will be used for name resolution. +// +// For EDS, if EDSServiceName is not empty, it will be used for watching. If +// EDSServiceName is empty, Cluster will be used. +type DiscoveryMechanism struct { + // Cluster is the cluster name. + Cluster string `json:"cluster,omitempty"` + // LoadReportingServerName is the LRS server to send load reports to. If + // not present, load reporting will be disabled. If set to the empty string, + // load reporting will be sent to the same server that we obtained CDS data + // from. + LoadReportingServerName *string `json:"lrsLoadReportingServerName,omitempty"` + // MaxConcurrentRequests is the maximum number of outstanding requests can + // be made to the upstream cluster. Default is 1024. 
+ MaxConcurrentRequests *uint32 `json:"maxConcurrentRequests,omitempty"` + // Type is the discovery mechanism type. + Type DiscoveryMechanismType `json:"type,omitempty"` + // EDSServiceName is the EDS service name, as returned in CDS. May be unset + // if not specified in CDS. For type EDS only. + // + // This is used for EDS watch if set. If unset, Cluster is used for EDS + // watch. + EDSServiceName string `json:"edsServiceName,omitempty"` + // DNSHostname is the DNS name to resolve in "host:port" form. For type + // LOGICAL_DNS only. + DNSHostname string `json:"dnsHostname,omitempty"` +} + +// Equal returns whether the DiscoveryMechanism is the same with the parameter. +func (dm DiscoveryMechanism) Equal(b DiscoveryMechanism) bool { + switch { + case dm.Cluster != b.Cluster: + return false + case !equalStringP(dm.LoadReportingServerName, b.LoadReportingServerName): + return false + case !equalUint32P(dm.MaxConcurrentRequests, b.MaxConcurrentRequests): + return false + case dm.Type != b.Type: + return false + case dm.EDSServiceName != b.EDSServiceName: + return false + case dm.DNSHostname != b.DNSHostname: + return false + } + return true +} + +func equalStringP(a, b *string) bool { + if a == nil && b == nil { + return true + } + if a == nil || b == nil { + return false + } + return *a == *b +} + +func equalUint32P(a, b *uint32) bool { + if a == nil && b == nil { + return true + } + if a == nil || b == nil { + return false + } + return *a == *b +} + +// LBConfig is the config for cluster resolver balancer. +type LBConfig struct { + serviceconfig.LoadBalancingConfig `json:"-"` + // DiscoveryMechanisms is an ordered list of discovery mechanisms. + // + // Must have at least one element. Results from each discovery mechanism are + // concatenated together in successive priorities. + DiscoveryMechanisms []DiscoveryMechanism `json:"discoveryMechanisms,omitempty"` + + // XDSLBPolicy specifies the policy for locality picking and endpoint picking. 
+ // + // Note that it's not normal balancing policy, and it can only be either + // ROUND_ROBIN or RING_HASH. + // + // For ROUND_ROBIN, the policy name will be "ROUND_ROBIN", and the config + // will be empty. This sets the locality-picking policy to weighted_target + // and the endpoint-picking policy to round_robin. + // + // For RING_HASH, the policy name will be "RING_HASH", and the config will + // be lb config for the ring_hash_experimental LB Policy. ring_hash policy + // is responsible for both locality picking and endpoint picking. + XDSLBPolicy *internalserviceconfig.BalancerConfig `json:"xdsLbPolicy,omitempty"` +} + +const ( + rrName = roundrobin.Name + rhName = ringhash.Name +) + +func parseConfig(c json.RawMessage) (*LBConfig, error) { + var cfg LBConfig + if err := json.Unmarshal(c, &cfg); err != nil { + return nil, err + } + if lbp := cfg.XDSLBPolicy; lbp != nil && !strings.EqualFold(lbp.Name, rrName) && !strings.EqualFold(lbp.Name, rhName) { + return nil, fmt.Errorf("unsupported child policy with name %q, not one of {%q,%q}", lbp.Name, rrName, rhName) + } + return &cfg, nil +} diff --git a/xds/internal/balancer/clusterresolver/config_test.go b/xds/internal/balancer/clusterresolver/config_test.go new file mode 100644 index 00000000000..796f8a49372 --- /dev/null +++ b/xds/internal/balancer/clusterresolver/config_test.go @@ -0,0 +1,269 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +package clusterresolver + +import ( + "encoding/json" + "testing" + + "github.com/google/go-cmp/cmp" + "google.golang.org/grpc/internal/balancer/stub" + internalserviceconfig "google.golang.org/grpc/internal/serviceconfig" + "google.golang.org/grpc/xds/internal/balancer/ringhash" +) + +func TestDiscoveryMechanismTypeMarshalJSON(t *testing.T) { + tests := []struct { + name string + typ DiscoveryMechanismType + want string + }{ + { + name: "eds", + typ: DiscoveryMechanismTypeEDS, + want: `"EDS"`, + }, + { + name: "dns", + typ: DiscoveryMechanismTypeLogicalDNS, + want: `"LOGICAL_DNS"`, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got, err := json.Marshal(tt.typ); err != nil || string(got) != tt.want { + t.Fatalf("DiscoveryMechanismTypeEDS.MarshalJSON() = (%v, %v), want (%s, nil)", string(got), err, tt.want) + } + }) + } +} +func TestDiscoveryMechanismTypeUnmarshalJSON(t *testing.T) { + tests := []struct { + name string + js string + want DiscoveryMechanismType + wantErr bool + }{ + { + name: "eds", + js: `"EDS"`, + want: DiscoveryMechanismTypeEDS, + }, + { + name: "dns", + js: `"LOGICAL_DNS"`, + want: DiscoveryMechanismTypeLogicalDNS, + }, + { + name: "error", + js: `"1234"`, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var got DiscoveryMechanismType + err := json.Unmarshal([]byte(tt.js), &got) + if (err != nil) != tt.wantErr { + t.Fatalf("DiscoveryMechanismTypeEDS.UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr) + } + if diff := cmp.Diff(got, tt.want); diff != "" { + t.Fatalf("DiscoveryMechanismTypeEDS.UnmarshalJSON() got unexpected output, diff (-got +want): %v", diff) + } + }) + } +} + +func init() { + // This is needed now for the config parsing tests to pass. Otherwise they + // will fail with "RING_HASH unsupported". + // + // TODO: delete this once ring-hash policy is implemented and imported. 
+ stub.Register(rhName, stub.BalancerFuncs{}) +} + +const ( + testJSONConfig1 = `{ + "discoveryMechanisms": [{ + "cluster": "test-cluster-name", + "lrsLoadReportingServerName": "test-lrs-server", + "maxConcurrentRequests": 314, + "type": "EDS", + "edsServiceName": "test-eds-service-name" + }] +}` + testJSONConfig2 = `{ + "discoveryMechanisms": [{ + "cluster": "test-cluster-name", + "lrsLoadReportingServerName": "test-lrs-server", + "maxConcurrentRequests": 314, + "type": "EDS", + "edsServiceName": "test-eds-service-name" + },{ + "type": "LOGICAL_DNS" + }] +}` + testJSONConfig3 = `{ + "discoveryMechanisms": [{ + "cluster": "test-cluster-name", + "lrsLoadReportingServerName": "test-lrs-server", + "maxConcurrentRequests": 314, + "type": "EDS", + "edsServiceName": "test-eds-service-name" + }], + "xdsLbPolicy":[{"ROUND_ROBIN":{}}] +}` + testJSONConfig4 = `{ + "discoveryMechanisms": [{ + "cluster": "test-cluster-name", + "lrsLoadReportingServerName": "test-lrs-server", + "maxConcurrentRequests": 314, + "type": "EDS", + "edsServiceName": "test-eds-service-name" + }], + "xdsLbPolicy":[{"ring_hash_experimental":{}}] +}` + testJSONConfig5 = `{ + "discoveryMechanisms": [{ + "cluster": "test-cluster-name", + "lrsLoadReportingServerName": "test-lrs-server", + "maxConcurrentRequests": 314, + "type": "EDS", + "edsServiceName": "test-eds-service-name" + }], + "xdsLbPolicy":[{"pick_first":{}}] +}` +) + +func TestParseConfig(t *testing.T) { + tests := []struct { + name string + js string + want *LBConfig + wantErr bool + }{ + { + name: "empty json", + js: "", + want: nil, + wantErr: true, + }, + { + name: "OK with one discovery mechanism", + js: testJSONConfig1, + want: &LBConfig{ + DiscoveryMechanisms: []DiscoveryMechanism{ + { + Cluster: testClusterName, + LoadReportingServerName: newString(testLRSServer), + MaxConcurrentRequests: newUint32(testMaxRequests), + Type: DiscoveryMechanismTypeEDS, + EDSServiceName: testEDSServcie, + }, + }, + XDSLBPolicy: nil, + }, + wantErr: false, + 
}, + { + name: "OK with multiple discovery mechanisms", + js: testJSONConfig2, + want: &LBConfig{ + DiscoveryMechanisms: []DiscoveryMechanism{ + { + Cluster: testClusterName, + LoadReportingServerName: newString(testLRSServer), + MaxConcurrentRequests: newUint32(testMaxRequests), + Type: DiscoveryMechanismTypeEDS, + EDSServiceName: testEDSServcie, + }, + { + Type: DiscoveryMechanismTypeLogicalDNS, + }, + }, + XDSLBPolicy: nil, + }, + wantErr: false, + }, + { + name: "OK with picking policy round_robin", + js: testJSONConfig3, + want: &LBConfig{ + DiscoveryMechanisms: []DiscoveryMechanism{ + { + Cluster: testClusterName, + LoadReportingServerName: newString(testLRSServer), + MaxConcurrentRequests: newUint32(testMaxRequests), + Type: DiscoveryMechanismTypeEDS, + EDSServiceName: testEDSServcie, + }, + }, + XDSLBPolicy: &internalserviceconfig.BalancerConfig{ + Name: "ROUND_ROBIN", + Config: nil, + }, + }, + wantErr: false, + }, + { + name: "OK with picking policy ring_hash", + js: testJSONConfig4, + want: &LBConfig{ + DiscoveryMechanisms: []DiscoveryMechanism{ + { + Cluster: testClusterName, + LoadReportingServerName: newString(testLRSServer), + MaxConcurrentRequests: newUint32(testMaxRequests), + Type: DiscoveryMechanismTypeEDS, + EDSServiceName: testEDSServcie, + }, + }, + XDSLBPolicy: &internalserviceconfig.BalancerConfig{ + Name: ringhash.Name, + Config: nil, + }, + }, + wantErr: false, + }, + { + name: "unsupported picking policy", + js: testJSONConfig5, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := parseConfig([]byte(tt.js)) + if (err != nil) != tt.wantErr { + t.Fatalf("parseConfig() error = %v, wantErr %v", err, tt.wantErr) + } + if diff := cmp.Diff(got, tt.want); diff != "" { + t.Errorf("parseConfig() got unexpected output, diff (-got +want): %v", diff) + } + }) + } +} + +func newString(s string) *string { + return &s +} + +func newUint32(i uint32) *uint32 { + return &i +} diff --git 
a/xds/internal/balancer/clusterresolver/configbuilder.go b/xds/internal/balancer/clusterresolver/configbuilder.go new file mode 100644 index 00000000000..475497d4895 --- /dev/null +++ b/xds/internal/balancer/clusterresolver/configbuilder.go @@ -0,0 +1,364 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package clusterresolver + +import ( + "encoding/json" + "fmt" + "sort" + + "google.golang.org/grpc/balancer/roundrobin" + "google.golang.org/grpc/balancer/weightedroundrobin" + "google.golang.org/grpc/internal/hierarchy" + internalserviceconfig "google.golang.org/grpc/internal/serviceconfig" + "google.golang.org/grpc/resolver" + "google.golang.org/grpc/xds/internal" + "google.golang.org/grpc/xds/internal/balancer/clusterimpl" + "google.golang.org/grpc/xds/internal/balancer/priority" + "google.golang.org/grpc/xds/internal/balancer/ringhash" + "google.golang.org/grpc/xds/internal/balancer/weightedtarget" + "google.golang.org/grpc/xds/internal/xdsclient" +) + +const million = 1000000 + +// priorityConfig is config for one priority. For example, if there an EDS and a +// DNS, the priority list will be [priorityConfig{EDS}, priorityConfig{DNS}]. +// +// Each priorityConfig corresponds to one discovery mechanism from the LBConfig +// generated by the CDS balancer. 
The CDS balancer resolves the cluster name to +// an ordered list of discovery mechanisms (if the top cluster is an aggregated +// cluster), one for each underlying cluster. +type priorityConfig struct { + mechanism DiscoveryMechanism + // edsResp is set only if type is EDS. + edsResp xdsclient.EndpointsUpdate + // addresses is set only if type is DNS. + addresses []string +} + +// buildPriorityConfigJSON builds balancer config for the passed in +// priorities. +// +// The built tree of balancers (see test for the output struct). +// +// If xds lb policy is ROUND_ROBIN, the children will be weighted_target for +// locality picking, and round_robin for endpoint picking. +// +// ┌────────┐ +// │priority│ +// └┬──────┬┘ +// │ │ +// ┌───────────▼┐ ┌▼───────────┐ +// │cluster_impl│ │cluster_impl│ +// └─┬──────────┘ └──────────┬─┘ +// │ │ +// ┌──────────────▼─┐ ┌─▼──────────────┐ +// │locality_picking│ │locality_picking│ +// └┬──────────────┬┘ └┬──────────────┬┘ +// │ │ │ │ +// ┌─▼─┐ ┌─▼─┐ ┌─▼─┐ ┌─▼─┐ +// │LRS│ │LRS│ │LRS│ │LRS│ +// └─┬─┘ └─┬─┘ └─┬─┘ └─┬─┘ +// │ │ │ │ +// ┌──────────▼─────┐ ┌─────▼──────────┐ ┌──────────▼─────┐ ┌─────▼──────────┐ +// │endpoint_picking│ │endpoint_picking│ │endpoint_picking│ │endpoint_picking│ +// └────────────────┘ └────────────────┘ └────────────────┘ └────────────────┘ +// +// If xds lb policy is RING_HASH, the children will be just a ring_hash policy. +// The endpoints from all localities will be flattened to one addresses list, +// and the ring_hash policy will pick endpoints from it. +// +// ┌────────┐ +// │priority│ +// └┬──────┬┘ +// │ │ +// ┌──────────▼─┐ ┌─▼──────────┐ +// │cluster_impl│ │cluster_impl│ +// └──────┬─────┘ └─────┬──────┘ +// │ │ +// ┌──────▼─────┐ ┌─────▼──────┐ +// │ ring_hash │ │ ring_hash │ +// └────────────┘ └────────────┘ +// +// If endpointPickingPolicy is nil, roundrobin will be used. +// +// Custom locality picking policy isn't support, and weighted_target is always +// used. 
+func buildPriorityConfigJSON(priorities []priorityConfig, xdsLBPolicy *internalserviceconfig.BalancerConfig) ([]byte, []resolver.Address, error) { + pc, addrs, err := buildPriorityConfig(priorities, xdsLBPolicy) + if err != nil { + return nil, nil, fmt.Errorf("failed to build priority config: %v", err) + } + ret, err := json.Marshal(pc) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal built priority config struct into json: %v", err) + } + return ret, addrs, nil +} + +func buildPriorityConfig(priorities []priorityConfig, xdsLBPolicy *internalserviceconfig.BalancerConfig) (*priority.LBConfig, []resolver.Address, error) { + var ( + retConfig = &priority.LBConfig{Children: make(map[string]*priority.Child)} + retAddrs []resolver.Address + ) + for i, p := range priorities { + switch p.mechanism.Type { + case DiscoveryMechanismTypeEDS: + names, configs, addrs, err := buildClusterImplConfigForEDS(i, p.edsResp, p.mechanism, xdsLBPolicy) + if err != nil { + return nil, nil, err + } + retConfig.Priorities = append(retConfig.Priorities, names...) + for n, c := range configs { + retConfig.Children[n] = &priority.Child{ + Config: &internalserviceconfig.BalancerConfig{Name: clusterimpl.Name, Config: c}, + // Ignore all re-resolution from EDS children. + IgnoreReresolutionRequests: true, + } + } + retAddrs = append(retAddrs, addrs...) + case DiscoveryMechanismTypeLogicalDNS: + name, config, addrs := buildClusterImplConfigForDNS(i, p.addresses) + retConfig.Priorities = append(retConfig.Priorities, name) + retConfig.Children[name] = &priority.Child{ + Config: &internalserviceconfig.BalancerConfig{Name: clusterimpl.Name, Config: config}, + // Not ignore re-resolution from DNS children, they will trigger + // DNS to re-resolve. + IgnoreReresolutionRequests: false, + } + retAddrs = append(retAddrs, addrs...) 
+ } + } + return retConfig, retAddrs, nil +} + +func buildClusterImplConfigForDNS(parentPriority int, addrStrs []string) (string, *clusterimpl.LBConfig, []resolver.Address) { + // Endpoint picking policy for DNS is hardcoded to pick_first. + const childPolicy = "pick_first" + retAddrs := make([]resolver.Address, 0, len(addrStrs)) + pName := fmt.Sprintf("priority-%v", parentPriority) + for _, addrStr := range addrStrs { + retAddrs = append(retAddrs, hierarchy.Set(resolver.Address{Addr: addrStr}, []string{pName})) + } + return pName, &clusterimpl.LBConfig{ChildPolicy: &internalserviceconfig.BalancerConfig{Name: childPolicy}}, retAddrs +} + +// buildClusterImplConfigForEDS returns a list of cluster_impl configs, one for +// each priority, sorted by priority, and the addresses for each priority (with +// hierarchy attributes set). +// +// For example, if there are two priorities, the returned values will be +// - ["p0", "p1"] +// - map{"p0":p0_config, "p1":p1_config} +// - [p0_address_0, p0_address_1, p1_address_0, p1_address_1] +// - p0 addresses' hierarchy attributes are set to p0 +func buildClusterImplConfigForEDS(parentPriority int, edsResp xdsclient.EndpointsUpdate, mechanism DiscoveryMechanism, xdsLBPolicy *internalserviceconfig.BalancerConfig) ([]string, map[string]*clusterimpl.LBConfig, []resolver.Address, error) { + drops := make([]clusterimpl.DropConfig, 0, len(edsResp.Drops)) + for _, d := range edsResp.Drops { + drops = append(drops, clusterimpl.DropConfig{ + Category: d.Category, + RequestsPerMillion: d.Numerator * million / d.Denominator, + }) + } + + priorityChildNames, priorities := groupLocalitiesByPriority(edsResp.Localities) + retNames := make([]string, 0, len(priorityChildNames)) + retAddrs := make([]resolver.Address, 0, len(priorityChildNames)) + retConfigs := make(map[string]*clusterimpl.LBConfig, len(priorityChildNames)) + for _, priorityName := range priorityChildNames { + priorityLocalities := priorities[priorityName] + // Prepend parent 
priority to the priority names, to avoid duplicates. + pName := fmt.Sprintf("priority-%v-%v", parentPriority, priorityName) + retNames = append(retNames, pName) + cfg, addrs, err := priorityLocalitiesToClusterImpl(priorityLocalities, pName, mechanism, drops, xdsLBPolicy) + if err != nil { + return nil, nil, nil, err + } + retConfigs[pName] = cfg + retAddrs = append(retAddrs, addrs...) + } + return retNames, retConfigs, retAddrs, nil +} + +// groupLocalitiesByPriority returns the localities grouped by priority. +// +// It also returns a list of strings where each string represents a priority, +// and the list is sorted from higher priority to lower priority. +// +// For example, for L0-p0, L1-p0, L2-p1, results will be +// - ["p0", "p1"] +// - map{"p0":[L0, L1], "p1":[L2]} +func groupLocalitiesByPriority(localities []xdsclient.Locality) ([]string, map[string][]xdsclient.Locality) { + var priorityIntSlice []int + priorities := make(map[string][]xdsclient.Locality) + for _, locality := range localities { + if locality.Weight == 0 { + continue + } + priorityName := fmt.Sprintf("%v", locality.Priority) + priorities[priorityName] = append(priorities[priorityName], locality) + priorityIntSlice = append(priorityIntSlice, int(locality.Priority)) + } + // Sort the priorities based on the int value, deduplicate, and then turn + // the sorted list into a string list. This will be child names, in priority + // order. 
+ sort.Ints(priorityIntSlice) + priorityIntSliceDeduped := dedupSortedIntSlice(priorityIntSlice) + priorityNameSlice := make([]string, 0, len(priorityIntSliceDeduped)) + for _, p := range priorityIntSliceDeduped { + priorityNameSlice = append(priorityNameSlice, fmt.Sprintf("%v", p)) + } + return priorityNameSlice, priorities +} + +func dedupSortedIntSlice(a []int) []int { + if len(a) == 0 { + return a + } + i, j := 0, 1 + for ; j < len(a); j++ { + if a[i] == a[j] { + continue + } + i++ + if i != j { + a[i] = a[j] + } + } + return a[:i+1] +} + +// rrBalancerConfig is a const roundrobin config, used as child of +// weighted-roundrobin. To avoid allocating memory everytime. +var rrBalancerConfig = &internalserviceconfig.BalancerConfig{Name: roundrobin.Name} + +// priorityLocalitiesToClusterImpl takes a list of localities (with the same +// priority), and generates a cluster impl policy config, and a list of +// addresses. +func priorityLocalitiesToClusterImpl(localities []xdsclient.Locality, priorityName string, mechanism DiscoveryMechanism, drops []clusterimpl.DropConfig, xdsLBPolicy *internalserviceconfig.BalancerConfig) (*clusterimpl.LBConfig, []resolver.Address, error) { + clusterImplCfg := &clusterimpl.LBConfig{ + Cluster: mechanism.Cluster, + EDSServiceName: mechanism.EDSServiceName, + LoadReportingServerName: mechanism.LoadReportingServerName, + MaxConcurrentRequests: mechanism.MaxConcurrentRequests, + DropCategories: drops, + // ChildPolicy is not set. Will be set based on xdsLBPolicy + } + + if xdsLBPolicy == nil || xdsLBPolicy.Name == rrName { + // If lb policy is ROUND_ROBIN: + // - locality-picking policy is weighted_target + // - endpoint-picking policy is round_robin + logger.Infof("xds lb policy is %q, building config with weighted_target + round_robin", rrName) + // Child of weighted_target is hardcoded to round_robin. 
+ wtConfig, addrs := localitiesToWeightedTarget(localities, priorityName, rrBalancerConfig) + clusterImplCfg.ChildPolicy = &internalserviceconfig.BalancerConfig{Name: weightedtarget.Name, Config: wtConfig} + return clusterImplCfg, addrs, nil + } + + if xdsLBPolicy.Name == rhName { + // If lb policy is RIHG_HASH, will build one ring_hash policy as child. + // The endpoints from all localities will be flattened to one addresses + // list, and the ring_hash policy will pick endpoints from it. + logger.Infof("xds lb policy is %q, building config with ring_hash", rhName) + addrs := localitiesToRingHash(localities, priorityName) + // Set child to ring_hash, note that the ring_hash config is from + // xdsLBPolicy. + clusterImplCfg.ChildPolicy = &internalserviceconfig.BalancerConfig{Name: ringhash.Name, Config: xdsLBPolicy.Config} + return clusterImplCfg, addrs, nil + } + + return nil, nil, fmt.Errorf("unsupported xds LB policy %q, not one of {%q,%q}", xdsLBPolicy.Name, rrName, rhName) +} + +// localitiesToRingHash takes a list of localities (with the same priority), and +// generates a list of addresses. +// +// The addresses have path hierarchy set to [priority-name], so priority knows +// which child policy they are for. +func localitiesToRingHash(localities []xdsclient.Locality, priorityName string) []resolver.Address { + var addrs []resolver.Address + for _, locality := range localities { + var lw uint32 = 1 + if locality.Weight != 0 { + lw = locality.Weight + } + localityStr, err := locality.ID.ToString() + if err != nil { + localityStr = fmt.Sprintf("%+v", locality.ID) + } + for _, endpoint := range locality.Endpoints { + // Filter out all "unhealthy" endpoints (unknown and healthy are + // both considered to be healthy: + // https://www.envoyproxy.io/docs/envoy/latest/api-v2/api/v2/core/health_check.proto#envoy-api-enum-core-healthstatus). 
+ if endpoint.HealthStatus != xdsclient.EndpointHealthStatusHealthy && endpoint.HealthStatus != xdsclient.EndpointHealthStatusUnknown { + continue + } + + var ew uint32 = 1 + if endpoint.Weight != 0 { + ew = endpoint.Weight + } + + // The weight of each endpoint is locality_weight * endpoint_weight. + ai := weightedroundrobin.AddrInfo{Weight: lw * ew} + addr := weightedroundrobin.SetAddrInfo(resolver.Address{Addr: endpoint.Address}, ai) + addr = hierarchy.Set(addr, []string{priorityName, localityStr}) + addr = internal.SetLocalityID(addr, locality.ID) + addrs = append(addrs, addr) + } + } + return addrs +} + +// localitiesToWeightedTarget takes a list of localities (with the same +// priority), and generates a weighted target config, and list of addresses. +// +// The addresses have path hierarchy set to [priority-name, locality-name], so +// priority and weighted target know which child policy they are for. +func localitiesToWeightedTarget(localities []xdsclient.Locality, priorityName string, childPolicy *internalserviceconfig.BalancerConfig) (*weightedtarget.LBConfig, []resolver.Address) { + weightedTargets := make(map[string]weightedtarget.Target) + var addrs []resolver.Address + for _, locality := range localities { + localityStr, err := locality.ID.ToString() + if err != nil { + localityStr = fmt.Sprintf("%+v", locality.ID) + } + weightedTargets[localityStr] = weightedtarget.Target{Weight: locality.Weight, ChildPolicy: childPolicy} + for _, endpoint := range locality.Endpoints { + // Filter out all "unhealthy" endpoints (unknown and healthy are + // both considered to be healthy: + // https://www.envoyproxy.io/docs/envoy/latest/api-v2/api/v2/core/health_check.proto#envoy-api-enum-core-healthstatus). 
+ if endpoint.HealthStatus != xdsclient.EndpointHealthStatusHealthy && endpoint.HealthStatus != xdsclient.EndpointHealthStatusUnknown { + continue + } + + addr := resolver.Address{Addr: endpoint.Address} + if childPolicy.Name == weightedroundrobin.Name && endpoint.Weight != 0 { + ai := weightedroundrobin.AddrInfo{Weight: endpoint.Weight} + addr = weightedroundrobin.SetAddrInfo(addr, ai) + } + addr = hierarchy.Set(addr, []string{priorityName, localityStr}) + addr = internal.SetLocalityID(addr, locality.ID) + addrs = append(addrs, addr) + } + } + return &weightedtarget.LBConfig{Targets: weightedTargets}, addrs +} diff --git a/xds/internal/balancer/clusterresolver/configbuilder_test.go b/xds/internal/balancer/clusterresolver/configbuilder_test.go new file mode 100644 index 00000000000..3e2ad8a2e64 --- /dev/null +++ b/xds/internal/balancer/clusterresolver/configbuilder_test.go @@ -0,0 +1,979 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +package clusterresolver + +import ( + "bytes" + "encoding/json" + "fmt" + "sort" + "testing" + + "github.com/google/go-cmp/cmp" + "google.golang.org/grpc/attributes" + "google.golang.org/grpc/balancer" + "google.golang.org/grpc/balancer/roundrobin" + "google.golang.org/grpc/balancer/weightedroundrobin" + "google.golang.org/grpc/internal/hierarchy" + internalserviceconfig "google.golang.org/grpc/internal/serviceconfig" + "google.golang.org/grpc/resolver" + "google.golang.org/grpc/xds/internal" + "google.golang.org/grpc/xds/internal/balancer/clusterimpl" + "google.golang.org/grpc/xds/internal/balancer/priority" + "google.golang.org/grpc/xds/internal/balancer/ringhash" + "google.golang.org/grpc/xds/internal/balancer/weightedtarget" + "google.golang.org/grpc/xds/internal/xdsclient" +) + +const ( + testLRSServer = "test-lrs-server" + testMaxRequests = 314 + testEDSServiceName = "service-name-from-parent" + testDropCategory = "test-drops" + testDropOverMillion = 1 + + localityCount = 5 + addressPerLocality = 2 +) + +var ( + testLocalityIDs []internal.LocalityID + testAddressStrs [][]string + testEndpoints [][]xdsclient.Endpoint + + testLocalitiesP0, testLocalitiesP1 []xdsclient.Locality + + addrCmpOpts = cmp.Options{ + cmp.AllowUnexported(attributes.Attributes{}), + cmp.Transformer("SortAddrs", func(in []resolver.Address) []resolver.Address { + out := append([]resolver.Address(nil), in...) 
// Copy input to avoid mutating it + sort.Slice(out, func(i, j int) bool { + return out[i].Addr < out[j].Addr + }) + return out + })} +) + +func init() { + for i := 0; i < localityCount; i++ { + testLocalityIDs = append(testLocalityIDs, internal.LocalityID{Zone: fmt.Sprintf("test-zone-%d", i)}) + var ( + addrs []string + ends []xdsclient.Endpoint + ) + for j := 0; j < addressPerLocality; j++ { + addr := fmt.Sprintf("addr-%d-%d", i, j) + addrs = append(addrs, addr) + ends = append(ends, xdsclient.Endpoint{ + Address: addr, + HealthStatus: xdsclient.EndpointHealthStatusHealthy, + }) + } + testAddressStrs = append(testAddressStrs, addrs) + testEndpoints = append(testEndpoints, ends) + } + + testLocalitiesP0 = []xdsclient.Locality{ + { + Endpoints: testEndpoints[0], + ID: testLocalityIDs[0], + Weight: 20, + Priority: 0, + }, + { + Endpoints: testEndpoints[1], + ID: testLocalityIDs[1], + Weight: 80, + Priority: 0, + }, + } + testLocalitiesP1 = []xdsclient.Locality{ + { + Endpoints: testEndpoints[2], + ID: testLocalityIDs[2], + Weight: 20, + Priority: 1, + }, + { + Endpoints: testEndpoints[3], + ID: testLocalityIDs[3], + Weight: 80, + Priority: 1, + }, + } +} + +// TestBuildPriorityConfigJSON is a sanity check that the built balancer config +// can be parsed. The behavior test is covered by TestBuildPriorityConfig. 
+func TestBuildPriorityConfigJSON(t *testing.T) { + gotConfig, _, err := buildPriorityConfigJSON([]priorityConfig{ + { + mechanism: DiscoveryMechanism{ + Cluster: testClusterName, + LoadReportingServerName: newString(testLRSServer), + MaxConcurrentRequests: newUint32(testMaxRequests), + Type: DiscoveryMechanismTypeEDS, + EDSServiceName: testEDSServiceName, + }, + edsResp: xdsclient.EndpointsUpdate{ + Drops: []xdsclient.OverloadDropConfig{ + { + Category: testDropCategory, + Numerator: testDropOverMillion, + Denominator: million, + }, + }, + Localities: []xdsclient.Locality{ + testLocalitiesP0[0], + testLocalitiesP0[1], + testLocalitiesP1[0], + testLocalitiesP1[1], + }, + }, + }, + { + mechanism: DiscoveryMechanism{ + Type: DiscoveryMechanismTypeLogicalDNS, + }, + addresses: testAddressStrs[4], + }, + }, nil) + if err != nil { + t.Fatalf("buildPriorityConfigJSON(...) failed: %v", err) + } + + var prettyGot bytes.Buffer + if err := json.Indent(&prettyGot, gotConfig, ">>> ", " "); err != nil { + t.Fatalf("json.Indent() failed: %v", err) + } + // Print the indented json if this test fails. 
+ t.Log(prettyGot.String()) + + priorityB := balancer.Get(priority.Name) + if _, err = priorityB.(balancer.ConfigParser).ParseConfig(gotConfig); err != nil { + t.Fatalf("ParseConfig(%+v) failed: %v", gotConfig, err) + } +} + +func TestBuildPriorityConfig(t *testing.T) { + gotConfig, gotAddrs, _ := buildPriorityConfig([]priorityConfig{ + { + mechanism: DiscoveryMechanism{ + Cluster: testClusterName, + LoadReportingServerName: newString(testLRSServer), + MaxConcurrentRequests: newUint32(testMaxRequests), + Type: DiscoveryMechanismTypeEDS, + EDSServiceName: testEDSServiceName, + }, + edsResp: xdsclient.EndpointsUpdate{ + Drops: []xdsclient.OverloadDropConfig{ + { + Category: testDropCategory, + Numerator: testDropOverMillion, + Denominator: million, + }, + }, + Localities: []xdsclient.Locality{ + testLocalitiesP0[0], + testLocalitiesP0[1], + testLocalitiesP1[0], + testLocalitiesP1[1], + }, + }, + }, + { + mechanism: DiscoveryMechanism{ + Type: DiscoveryMechanismTypeLogicalDNS, + }, + addresses: testAddressStrs[4], + }, + }, nil) + + wantConfig := &priority.LBConfig{ + Children: map[string]*priority.Child{ + "priority-0-0": { + Config: &internalserviceconfig.BalancerConfig{ + Name: clusterimpl.Name, + Config: &clusterimpl.LBConfig{ + Cluster: testClusterName, + EDSServiceName: testEDSServiceName, + LoadReportingServerName: newString(testLRSServer), + MaxConcurrentRequests: newUint32(testMaxRequests), + DropCategories: []clusterimpl.DropConfig{ + { + Category: testDropCategory, + RequestsPerMillion: testDropOverMillion, + }, + }, + ChildPolicy: &internalserviceconfig.BalancerConfig{ + Name: weightedtarget.Name, + Config: &weightedtarget.LBConfig{ + Targets: map[string]weightedtarget.Target{ + assertString(testLocalityIDs[0].ToString): { + Weight: 20, + ChildPolicy: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}, + }, + assertString(testLocalityIDs[1].ToString): { + Weight: 80, + ChildPolicy: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}, + 
}, + }, + }, + }, + }, + }, + IgnoreReresolutionRequests: true, + }, + "priority-0-1": { + Config: &internalserviceconfig.BalancerConfig{ + Name: clusterimpl.Name, + Config: &clusterimpl.LBConfig{ + Cluster: testClusterName, + EDSServiceName: testEDSServiceName, + LoadReportingServerName: newString(testLRSServer), + MaxConcurrentRequests: newUint32(testMaxRequests), + DropCategories: []clusterimpl.DropConfig{ + { + Category: testDropCategory, + RequestsPerMillion: testDropOverMillion, + }, + }, + ChildPolicy: &internalserviceconfig.BalancerConfig{ + Name: weightedtarget.Name, + Config: &weightedtarget.LBConfig{ + Targets: map[string]weightedtarget.Target{ + assertString(testLocalityIDs[2].ToString): { + Weight: 20, + ChildPolicy: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}, + }, + assertString(testLocalityIDs[3].ToString): { + Weight: 80, + ChildPolicy: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}, + }, + }, + }, + }, + }, + }, + IgnoreReresolutionRequests: true, + }, + "priority-1": { + Config: &internalserviceconfig.BalancerConfig{ + Name: clusterimpl.Name, + Config: &clusterimpl.LBConfig{ + ChildPolicy: &internalserviceconfig.BalancerConfig{Name: "pick_first"}, + }, + }, + IgnoreReresolutionRequests: false, + }, + }, + Priorities: []string{"priority-0-0", "priority-0-1", "priority-1"}, + } + wantAddrs := []resolver.Address{ + testAddrWithAttrs(testAddressStrs[0][0], nil, "priority-0-0", &testLocalityIDs[0]), + testAddrWithAttrs(testAddressStrs[0][1], nil, "priority-0-0", &testLocalityIDs[0]), + testAddrWithAttrs(testAddressStrs[1][0], nil, "priority-0-0", &testLocalityIDs[1]), + testAddrWithAttrs(testAddressStrs[1][1], nil, "priority-0-0", &testLocalityIDs[1]), + testAddrWithAttrs(testAddressStrs[2][0], nil, "priority-0-1", &testLocalityIDs[2]), + testAddrWithAttrs(testAddressStrs[2][1], nil, "priority-0-1", &testLocalityIDs[2]), + testAddrWithAttrs(testAddressStrs[3][0], nil, "priority-0-1", &testLocalityIDs[3]), + 
testAddrWithAttrs(testAddressStrs[3][1], nil, "priority-0-1", &testLocalityIDs[3]), + testAddrWithAttrs(testAddressStrs[4][0], nil, "priority-1", nil), + testAddrWithAttrs(testAddressStrs[4][1], nil, "priority-1", nil), + } + + if diff := cmp.Diff(gotConfig, wantConfig); diff != "" { + t.Errorf("buildPriorityConfig() diff (-got +want) %v", diff) + } + if diff := cmp.Diff(gotAddrs, wantAddrs, addrCmpOpts); diff != "" { + t.Errorf("buildPriorityConfig() diff (-got +want) %v", diff) + } +} + +func TestBuildClusterImplConfigForDNS(t *testing.T) { + gotName, gotConfig, gotAddrs := buildClusterImplConfigForDNS(3, testAddressStrs[0]) + wantName := "priority-3" + wantConfig := &clusterimpl.LBConfig{ + ChildPolicy: &internalserviceconfig.BalancerConfig{ + Name: "pick_first", + }, + } + wantAddrs := []resolver.Address{ + hierarchy.Set(resolver.Address{Addr: testAddressStrs[0][0]}, []string{"priority-3"}), + hierarchy.Set(resolver.Address{Addr: testAddressStrs[0][1]}, []string{"priority-3"}), + } + + if diff := cmp.Diff(gotName, wantName); diff != "" { + t.Errorf("buildClusterImplConfigForDNS() diff (-got +want) %v", diff) + } + if diff := cmp.Diff(gotConfig, wantConfig); diff != "" { + t.Errorf("buildClusterImplConfigForDNS() diff (-got +want) %v", diff) + } + if diff := cmp.Diff(gotAddrs, wantAddrs, addrCmpOpts); diff != "" { + t.Errorf("buildClusterImplConfigForDNS() diff (-got +want) %v", diff) + } +} + +func TestBuildClusterImplConfigForEDS(t *testing.T) { + gotNames, gotConfigs, gotAddrs, _ := buildClusterImplConfigForEDS( + 2, + xdsclient.EndpointsUpdate{ + Drops: []xdsclient.OverloadDropConfig{ + { + Category: testDropCategory, + Numerator: testDropOverMillion, + Denominator: million, + }, + }, + Localities: []xdsclient.Locality{ + { + Endpoints: testEndpoints[3], + ID: testLocalityIDs[3], + Weight: 80, + Priority: 1, + }, { + Endpoints: testEndpoints[1], + ID: testLocalityIDs[1], + Weight: 80, + Priority: 0, + }, { + Endpoints: testEndpoints[2], + ID: 
testLocalityIDs[2], + Weight: 20, + Priority: 1, + }, { + Endpoints: testEndpoints[0], + ID: testLocalityIDs[0], + Weight: 20, + Priority: 0, + }, + }, + }, + DiscoveryMechanism{ + Cluster: testClusterName, + MaxConcurrentRequests: newUint32(testMaxRequests), + LoadReportingServerName: newString(testLRSServer), + Type: DiscoveryMechanismTypeEDS, + EDSServiceName: testEDSServiceName, + }, + nil, + ) + + wantNames := []string{ + fmt.Sprintf("priority-%v-%v", 2, 0), + fmt.Sprintf("priority-%v-%v", 2, 1), + } + wantConfigs := map[string]*clusterimpl.LBConfig{ + "priority-2-0": { + Cluster: testClusterName, + EDSServiceName: testEDSServiceName, + LoadReportingServerName: newString(testLRSServer), + MaxConcurrentRequests: newUint32(testMaxRequests), + DropCategories: []clusterimpl.DropConfig{ + { + Category: testDropCategory, + RequestsPerMillion: testDropOverMillion, + }, + }, + ChildPolicy: &internalserviceconfig.BalancerConfig{ + Name: weightedtarget.Name, + Config: &weightedtarget.LBConfig{ + Targets: map[string]weightedtarget.Target{ + assertString(testLocalityIDs[0].ToString): { + Weight: 20, + ChildPolicy: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}, + }, + assertString(testLocalityIDs[1].ToString): { + Weight: 80, + ChildPolicy: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}, + }, + }, + }, + }, + }, + "priority-2-1": { + Cluster: testClusterName, + EDSServiceName: testEDSServiceName, + LoadReportingServerName: newString(testLRSServer), + MaxConcurrentRequests: newUint32(testMaxRequests), + DropCategories: []clusterimpl.DropConfig{ + { + Category: testDropCategory, + RequestsPerMillion: testDropOverMillion, + }, + }, + ChildPolicy: &internalserviceconfig.BalancerConfig{ + Name: weightedtarget.Name, + Config: &weightedtarget.LBConfig{ + Targets: map[string]weightedtarget.Target{ + assertString(testLocalityIDs[2].ToString): { + Weight: 20, + ChildPolicy: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}, + }, + 
assertString(testLocalityIDs[3].ToString): { + Weight: 80, + ChildPolicy: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}, + }, + }, + }, + }, + }, + } + wantAddrs := []resolver.Address{ + testAddrWithAttrs(testAddressStrs[0][0], nil, "priority-2-0", &testLocalityIDs[0]), + testAddrWithAttrs(testAddressStrs[0][1], nil, "priority-2-0", &testLocalityIDs[0]), + testAddrWithAttrs(testAddressStrs[1][0], nil, "priority-2-0", &testLocalityIDs[1]), + testAddrWithAttrs(testAddressStrs[1][1], nil, "priority-2-0", &testLocalityIDs[1]), + testAddrWithAttrs(testAddressStrs[2][0], nil, "priority-2-1", &testLocalityIDs[2]), + testAddrWithAttrs(testAddressStrs[2][1], nil, "priority-2-1", &testLocalityIDs[2]), + testAddrWithAttrs(testAddressStrs[3][0], nil, "priority-2-1", &testLocalityIDs[3]), + testAddrWithAttrs(testAddressStrs[3][1], nil, "priority-2-1", &testLocalityIDs[3]), + } + + if diff := cmp.Diff(gotNames, wantNames); diff != "" { + t.Errorf("buildClusterImplConfigForEDS() diff (-got +want) %v", diff) + } + if diff := cmp.Diff(gotConfigs, wantConfigs); diff != "" { + t.Errorf("buildClusterImplConfigForEDS() diff (-got +want) %v", diff) + } + if diff := cmp.Diff(gotAddrs, wantAddrs, addrCmpOpts); diff != "" { + t.Errorf("buildClusterImplConfigForEDS() diff (-got +want) %v", diff) + } + +} + +func TestGroupLocalitiesByPriority(t *testing.T) { + tests := []struct { + name string + localities []xdsclient.Locality + wantPriorities []string + wantLocalities map[string][]xdsclient.Locality + }{ + { + name: "1 locality 1 priority", + localities: []xdsclient.Locality{testLocalitiesP0[0]}, + wantPriorities: []string{"0"}, + wantLocalities: map[string][]xdsclient.Locality{ + "0": {testLocalitiesP0[0]}, + }, + }, + { + name: "2 locality 1 priority", + localities: []xdsclient.Locality{testLocalitiesP0[0], testLocalitiesP0[1]}, + wantPriorities: []string{"0"}, + wantLocalities: map[string][]xdsclient.Locality{ + "0": {testLocalitiesP0[0], testLocalitiesP0[1]}, + }, + }, + 
{ + name: "1 locality in each", + localities: []xdsclient.Locality{testLocalitiesP0[0], testLocalitiesP1[0]}, + wantPriorities: []string{"0", "1"}, + wantLocalities: map[string][]xdsclient.Locality{ + "0": {testLocalitiesP0[0]}, + "1": {testLocalitiesP1[0]}, + }, + }, + { + name: "2 localities in each sorted", + localities: []xdsclient.Locality{ + testLocalitiesP0[0], testLocalitiesP0[1], + testLocalitiesP1[0], testLocalitiesP1[1]}, + wantPriorities: []string{"0", "1"}, + wantLocalities: map[string][]xdsclient.Locality{ + "0": {testLocalitiesP0[0], testLocalitiesP0[1]}, + "1": {testLocalitiesP1[0], testLocalitiesP1[1]}, + }, + }, + { + // The localities are given in order [p1, p0, p1, p0], but the + // returned priority list must be sorted [p0, p1], because the list + // order is the priority order. + name: "2 localities in each needs to sort", + localities: []xdsclient.Locality{ + testLocalitiesP1[1], testLocalitiesP0[1], + testLocalitiesP1[0], testLocalitiesP0[0]}, + wantPriorities: []string{"0", "1"}, + wantLocalities: map[string][]xdsclient.Locality{ + "0": {testLocalitiesP0[1], testLocalitiesP0[0]}, + "1": {testLocalitiesP1[1], testLocalitiesP1[0]}, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotPriorities, gotLocalities := groupLocalitiesByPriority(tt.localities) + if diff := cmp.Diff(gotPriorities, tt.wantPriorities); diff != "" { + t.Errorf("groupLocalitiesByPriority() diff(-got +want) %v", diff) + } + if diff := cmp.Diff(gotLocalities, tt.wantLocalities); diff != "" { + t.Errorf("groupLocalitiesByPriority() diff(-got +want) %v", diff) + } + }) + } +} + +func TestDedupSortedIntSlice(t *testing.T) { + tests := []struct { + name string + a []int + want []int + }{ + { + name: "empty", + a: []int{}, + want: []int{}, + }, + { + name: "no dup", + a: []int{0, 1, 2, 3}, + want: []int{0, 1, 2, 3}, + }, + { + name: "with dup", + a: []int{0, 0, 1, 1, 1, 2, 3}, + want: []int{0, 1, 2, 3}, + }, + } + for _, tt := range tests { + 
t.Run(tt.name, func(t *testing.T) { + if got := dedupSortedIntSlice(tt.a); !cmp.Equal(got, tt.want) { + t.Errorf("dedupSortedIntSlice() = %v, want %v, diff %v", got, tt.want, cmp.Diff(got, tt.want)) + } + }) + } +} + +func TestPriorityLocalitiesToClusterImpl(t *testing.T) { + tests := []struct { + name string + localities []xdsclient.Locality + priorityName string + mechanism DiscoveryMechanism + childPolicy *internalserviceconfig.BalancerConfig + wantConfig *clusterimpl.LBConfig + wantAddrs []resolver.Address + wantErr bool + }{{ + name: "round robin as child, no LRS", + localities: []xdsclient.Locality{ + { + Endpoints: []xdsclient.Endpoint{ + {Address: "addr-1-1", HealthStatus: xdsclient.EndpointHealthStatusHealthy, Weight: 90}, + {Address: "addr-1-2", HealthStatus: xdsclient.EndpointHealthStatusHealthy, Weight: 10}, + }, + ID: internal.LocalityID{Zone: "test-zone-1"}, + Weight: 20, + }, + { + Endpoints: []xdsclient.Endpoint{ + {Address: "addr-2-1", HealthStatus: xdsclient.EndpointHealthStatusHealthy, Weight: 90}, + {Address: "addr-2-2", HealthStatus: xdsclient.EndpointHealthStatusHealthy, Weight: 10}, + }, + ID: internal.LocalityID{Zone: "test-zone-2"}, + Weight: 80, + }, + }, + priorityName: "test-priority", + childPolicy: &internalserviceconfig.BalancerConfig{Name: rrName}, + mechanism: DiscoveryMechanism{ + Cluster: testClusterName, + Type: DiscoveryMechanismTypeEDS, + EDSServiceName: testEDSServcie, + }, + // lrsServer is nil, so LRS policy will not be used. 
+ wantConfig: &clusterimpl.LBConfig{ + Cluster: testClusterName, + EDSServiceName: testEDSServcie, + ChildPolicy: &internalserviceconfig.BalancerConfig{ + Name: weightedtarget.Name, + Config: &weightedtarget.LBConfig{ + Targets: map[string]weightedtarget.Target{ + assertString(internal.LocalityID{Zone: "test-zone-1"}.ToString): { + Weight: 20, + ChildPolicy: &internalserviceconfig.BalancerConfig{ + Name: roundrobin.Name, + }, + }, + assertString(internal.LocalityID{Zone: "test-zone-2"}.ToString): { + Weight: 80, + ChildPolicy: &internalserviceconfig.BalancerConfig{ + Name: roundrobin.Name, + }, + }, + }, + }, + }, + }, + wantAddrs: []resolver.Address{ + testAddrWithAttrs("addr-1-1", nil, "test-priority", &internal.LocalityID{Zone: "test-zone-1"}), + testAddrWithAttrs("addr-1-2", nil, "test-priority", &internal.LocalityID{Zone: "test-zone-1"}), + testAddrWithAttrs("addr-2-1", nil, "test-priority", &internal.LocalityID{Zone: "test-zone-2"}), + testAddrWithAttrs("addr-2-2", nil, "test-priority", &internal.LocalityID{Zone: "test-zone-2"}), + }, + }, + { + name: "ring_hash as child", + localities: []xdsclient.Locality{ + { + Endpoints: []xdsclient.Endpoint{ + {Address: "addr-1-1", HealthStatus: xdsclient.EndpointHealthStatusHealthy, Weight: 90}, + {Address: "addr-1-2", HealthStatus: xdsclient.EndpointHealthStatusHealthy, Weight: 10}, + }, + ID: internal.LocalityID{Zone: "test-zone-1"}, + Weight: 20, + }, + { + Endpoints: []xdsclient.Endpoint{ + {Address: "addr-2-1", HealthStatus: xdsclient.EndpointHealthStatusHealthy, Weight: 90}, + {Address: "addr-2-2", HealthStatus: xdsclient.EndpointHealthStatusHealthy, Weight: 10}, + }, + ID: internal.LocalityID{Zone: "test-zone-2"}, + Weight: 80, + }, + }, + priorityName: "test-priority", + childPolicy: &internalserviceconfig.BalancerConfig{Name: rhName, Config: &ringhash.LBConfig{MinRingSize: 1, MaxRingSize: 2}}, + // lrsServer is nil, so LRS policy will not be used. 
+ wantConfig: &clusterimpl.LBConfig{ + ChildPolicy: &internalserviceconfig.BalancerConfig{ + Name: ringhash.Name, + Config: &ringhash.LBConfig{MinRingSize: 1, MaxRingSize: 2}, + }, + }, + wantAddrs: []resolver.Address{ + testAddrWithAttrs("addr-1-1", newUint32(1800), "test-priority", &internal.LocalityID{Zone: "test-zone-1"}), + testAddrWithAttrs("addr-1-2", newUint32(200), "test-priority", &internal.LocalityID{Zone: "test-zone-1"}), + testAddrWithAttrs("addr-2-1", newUint32(7200), "test-priority", &internal.LocalityID{Zone: "test-zone-2"}), + testAddrWithAttrs("addr-2-2", newUint32(800), "test-priority", &internal.LocalityID{Zone: "test-zone-2"}), + }, + }, + { + name: "unsupported child", + localities: []xdsclient.Locality{{ + Endpoints: []xdsclient.Endpoint{ + {Address: "addr-1-1", HealthStatus: xdsclient.EndpointHealthStatusHealthy, Weight: 90}, + {Address: "addr-1-2", HealthStatus: xdsclient.EndpointHealthStatusHealthy, Weight: 10}, + }, + ID: internal.LocalityID{Zone: "test-zone-1"}, + Weight: 20, + }}, + priorityName: "test-priority", + childPolicy: &internalserviceconfig.BalancerConfig{Name: "some-child"}, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, got1, err := priorityLocalitiesToClusterImpl(tt.localities, tt.priorityName, tt.mechanism, nil, tt.childPolicy) + if (err != nil) != tt.wantErr { + t.Fatalf("priorityLocalitiesToClusterImpl() error = %v, wantErr %v", err, tt.wantErr) + } + if diff := cmp.Diff(got, tt.wantConfig); diff != "" { + t.Errorf("localitiesToWeightedTarget() diff (-got +want) %v", diff) + } + if diff := cmp.Diff(got1, tt.wantAddrs, cmp.AllowUnexported(attributes.Attributes{})); diff != "" { + t.Errorf("localitiesToWeightedTarget() diff (-got +want) %v", diff) + } + }) + } +} + +func TestLocalitiesToWeightedTarget(t *testing.T) { + tests := []struct { + name string + localities []xdsclient.Locality + priorityName string + childPolicy *internalserviceconfig.BalancerConfig + lrsServer 
*string + wantConfig *weightedtarget.LBConfig + wantAddrs []resolver.Address + }{ + { + name: "roundrobin as child, with LRS", + localities: []xdsclient.Locality{ + { + Endpoints: []xdsclient.Endpoint{ + {Address: "addr-1-1", HealthStatus: xdsclient.EndpointHealthStatusHealthy}, + {Address: "addr-1-2", HealthStatus: xdsclient.EndpointHealthStatusHealthy}, + }, + ID: internal.LocalityID{Zone: "test-zone-1"}, + Weight: 20, + }, + { + Endpoints: []xdsclient.Endpoint{ + {Address: "addr-2-1", HealthStatus: xdsclient.EndpointHealthStatusHealthy}, + {Address: "addr-2-2", HealthStatus: xdsclient.EndpointHealthStatusHealthy}, + }, + ID: internal.LocalityID{Zone: "test-zone-2"}, + Weight: 80, + }, + }, + priorityName: "test-priority", + childPolicy: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}, + lrsServer: newString("test-lrs-server"), + wantConfig: &weightedtarget.LBConfig{ + Targets: map[string]weightedtarget.Target{ + assertString(internal.LocalityID{Zone: "test-zone-1"}.ToString): { + Weight: 20, + ChildPolicy: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}, + }, + assertString(internal.LocalityID{Zone: "test-zone-2"}.ToString): { + Weight: 80, + ChildPolicy: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}, + }, + }, + }, + wantAddrs: []resolver.Address{ + testAddrWithAttrs("addr-1-1", nil, "test-priority", &internal.LocalityID{Zone: "test-zone-1"}), + testAddrWithAttrs("addr-1-2", nil, "test-priority", &internal.LocalityID{Zone: "test-zone-1"}), + testAddrWithAttrs("addr-2-1", nil, "test-priority", &internal.LocalityID{Zone: "test-zone-2"}), + testAddrWithAttrs("addr-2-2", nil, "test-priority", &internal.LocalityID{Zone: "test-zone-2"}), + }, + }, + { + name: "roundrobin as child, no LRS", + localities: []xdsclient.Locality{ + { + Endpoints: []xdsclient.Endpoint{ + {Address: "addr-1-1", HealthStatus: xdsclient.EndpointHealthStatusHealthy}, + {Address: "addr-1-2", HealthStatus: xdsclient.EndpointHealthStatusHealthy}, + }, + ID: 
internal.LocalityID{Zone: "test-zone-1"}, + Weight: 20, + }, + { + Endpoints: []xdsclient.Endpoint{ + {Address: "addr-2-1", HealthStatus: xdsclient.EndpointHealthStatusHealthy}, + {Address: "addr-2-2", HealthStatus: xdsclient.EndpointHealthStatusHealthy}, + }, + ID: internal.LocalityID{Zone: "test-zone-2"}, + Weight: 80, + }, + }, + priorityName: "test-priority", + childPolicy: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}, + // lrsServer is nil, so LRS policy will not be used. + wantConfig: &weightedtarget.LBConfig{ + Targets: map[string]weightedtarget.Target{ + assertString(internal.LocalityID{Zone: "test-zone-1"}.ToString): { + Weight: 20, + ChildPolicy: &internalserviceconfig.BalancerConfig{ + Name: roundrobin.Name, + }, + }, + assertString(internal.LocalityID{Zone: "test-zone-2"}.ToString): { + Weight: 80, + ChildPolicy: &internalserviceconfig.BalancerConfig{ + Name: roundrobin.Name, + }, + }, + }, + }, + wantAddrs: []resolver.Address{ + testAddrWithAttrs("addr-1-1", nil, "test-priority", &internal.LocalityID{Zone: "test-zone-1"}), + testAddrWithAttrs("addr-1-2", nil, "test-priority", &internal.LocalityID{Zone: "test-zone-1"}), + testAddrWithAttrs("addr-2-1", nil, "test-priority", &internal.LocalityID{Zone: "test-zone-2"}), + testAddrWithAttrs("addr-2-2", nil, "test-priority", &internal.LocalityID{Zone: "test-zone-2"}), + }, + }, + { + name: "weighted round robin as child, no LRS", + localities: []xdsclient.Locality{ + { + Endpoints: []xdsclient.Endpoint{ + {Address: "addr-1-1", HealthStatus: xdsclient.EndpointHealthStatusHealthy, Weight: 90}, + {Address: "addr-1-2", HealthStatus: xdsclient.EndpointHealthStatusHealthy, Weight: 10}, + }, + ID: internal.LocalityID{Zone: "test-zone-1"}, + Weight: 20, + }, + { + Endpoints: []xdsclient.Endpoint{ + {Address: "addr-2-1", HealthStatus: xdsclient.EndpointHealthStatusHealthy, Weight: 90}, + {Address: "addr-2-2", HealthStatus: xdsclient.EndpointHealthStatusHealthy, Weight: 10}, + }, + ID: 
internal.LocalityID{Zone: "test-zone-2"}, + Weight: 80, + }, + }, + priorityName: "test-priority", + childPolicy: &internalserviceconfig.BalancerConfig{Name: weightedroundrobin.Name}, + // lrsServer is nil, so LRS policy will not be used. + wantConfig: &weightedtarget.LBConfig{ + Targets: map[string]weightedtarget.Target{ + assertString(internal.LocalityID{Zone: "test-zone-1"}.ToString): { + Weight: 20, + ChildPolicy: &internalserviceconfig.BalancerConfig{ + Name: weightedroundrobin.Name, + }, + }, + assertString(internal.LocalityID{Zone: "test-zone-2"}.ToString): { + Weight: 80, + ChildPolicy: &internalserviceconfig.BalancerConfig{ + Name: weightedroundrobin.Name, + }, + }, + }, + }, + wantAddrs: []resolver.Address{ + testAddrWithAttrs("addr-1-1", newUint32(90), "test-priority", &internal.LocalityID{Zone: "test-zone-1"}), + testAddrWithAttrs("addr-1-2", newUint32(10), "test-priority", &internal.LocalityID{Zone: "test-zone-1"}), + testAddrWithAttrs("addr-2-1", newUint32(90), "test-priority", &internal.LocalityID{Zone: "test-zone-2"}), + testAddrWithAttrs("addr-2-2", newUint32(10), "test-priority", &internal.LocalityID{Zone: "test-zone-2"}), + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, got1 := localitiesToWeightedTarget(tt.localities, tt.priorityName, tt.childPolicy) + if diff := cmp.Diff(got, tt.wantConfig); diff != "" { + t.Errorf("localitiesToWeightedTarget() diff (-got +want) %v", diff) + } + if diff := cmp.Diff(got1, tt.wantAddrs, cmp.AllowUnexported(attributes.Attributes{})); diff != "" { + t.Errorf("localitiesToWeightedTarget() diff (-got +want) %v", diff) + } + }) + } +} + +func TestLocalitiesToRingHash(t *testing.T) { + tests := []struct { + name string + localities []xdsclient.Locality + priorityName string + wantAddrs []resolver.Address + }{ + { + // Check that address weights are locality_weight * endpoint_weight. 
+ name: "with locality and endpoint weight", + localities: []xdsclient.Locality{ + { + Endpoints: []xdsclient.Endpoint{ + {Address: "addr-1-1", HealthStatus: xdsclient.EndpointHealthStatusHealthy, Weight: 90}, + {Address: "addr-1-2", HealthStatus: xdsclient.EndpointHealthStatusHealthy, Weight: 10}, + }, + ID: internal.LocalityID{Zone: "test-zone-1"}, + Weight: 20, + }, + { + Endpoints: []xdsclient.Endpoint{ + {Address: "addr-2-1", HealthStatus: xdsclient.EndpointHealthStatusHealthy, Weight: 90}, + {Address: "addr-2-2", HealthStatus: xdsclient.EndpointHealthStatusHealthy, Weight: 10}, + }, + ID: internal.LocalityID{Zone: "test-zone-2"}, + Weight: 80, + }, + }, + priorityName: "test-priority", + wantAddrs: []resolver.Address{ + testAddrWithAttrs("addr-1-1", newUint32(1800), "test-priority", &internal.LocalityID{Zone: "test-zone-1"}), + testAddrWithAttrs("addr-1-2", newUint32(200), "test-priority", &internal.LocalityID{Zone: "test-zone-1"}), + testAddrWithAttrs("addr-2-1", newUint32(7200), "test-priority", &internal.LocalityID{Zone: "test-zone-2"}), + testAddrWithAttrs("addr-2-2", newUint32(800), "test-priority", &internal.LocalityID{Zone: "test-zone-2"}), + }, + }, + { + // Check that endpoint_weight is 0, weight is the locality weight. 
+ name: "locality weight only", + localities: []xdsclient.Locality{ + { + Endpoints: []xdsclient.Endpoint{ + {Address: "addr-1-1", HealthStatus: xdsclient.EndpointHealthStatusHealthy}, + {Address: "addr-1-2", HealthStatus: xdsclient.EndpointHealthStatusHealthy}, + }, + ID: internal.LocalityID{Zone: "test-zone-1"}, + Weight: 20, + }, + { + Endpoints: []xdsclient.Endpoint{ + {Address: "addr-2-1", HealthStatus: xdsclient.EndpointHealthStatusHealthy}, + {Address: "addr-2-2", HealthStatus: xdsclient.EndpointHealthStatusHealthy}, + }, + ID: internal.LocalityID{Zone: "test-zone-2"}, + Weight: 80, + }, + }, + priorityName: "test-priority", + wantAddrs: []resolver.Address{ + testAddrWithAttrs("addr-1-1", newUint32(20), "test-priority", &internal.LocalityID{Zone: "test-zone-1"}), + testAddrWithAttrs("addr-1-2", newUint32(20), "test-priority", &internal.LocalityID{Zone: "test-zone-1"}), + testAddrWithAttrs("addr-2-1", newUint32(80), "test-priority", &internal.LocalityID{Zone: "test-zone-2"}), + testAddrWithAttrs("addr-2-2", newUint32(80), "test-priority", &internal.LocalityID{Zone: "test-zone-2"}), + }, + }, + { + // Check that locality_weight is 0, weight is the endpoint weight. 
+ name: "endpoint weight only", + localities: []xdsclient.Locality{ + { + Endpoints: []xdsclient.Endpoint{ + {Address: "addr-1-1", HealthStatus: xdsclient.EndpointHealthStatusHealthy, Weight: 90}, + {Address: "addr-1-2", HealthStatus: xdsclient.EndpointHealthStatusHealthy, Weight: 10}, + }, + ID: internal.LocalityID{Zone: "test-zone-1"}, + }, + { + Endpoints: []xdsclient.Endpoint{ + {Address: "addr-2-1", HealthStatus: xdsclient.EndpointHealthStatusHealthy, Weight: 90}, + {Address: "addr-2-2", HealthStatus: xdsclient.EndpointHealthStatusHealthy, Weight: 10}, + }, + ID: internal.LocalityID{Zone: "test-zone-2"}, + }, + }, + priorityName: "test-priority", + wantAddrs: []resolver.Address{ + testAddrWithAttrs("addr-1-1", newUint32(90), "test-priority", &internal.LocalityID{Zone: "test-zone-1"}), + testAddrWithAttrs("addr-1-2", newUint32(10), "test-priority", &internal.LocalityID{Zone: "test-zone-1"}), + testAddrWithAttrs("addr-2-1", newUint32(90), "test-priority", &internal.LocalityID{Zone: "test-zone-2"}), + testAddrWithAttrs("addr-2-2", newUint32(10), "test-priority", &internal.LocalityID{Zone: "test-zone-2"}), + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := localitiesToRingHash(tt.localities, tt.priorityName) + if diff := cmp.Diff(got, tt.wantAddrs, cmp.AllowUnexported(attributes.Attributes{})); diff != "" { + t.Errorf("localitiesToWeightedTarget() diff (-got +want) %v", diff) + } + }) + } +} + +func assertString(f func() (string, error)) string { + s, err := f() + if err != nil { + panic(err.Error()) + } + return s +} + +func testAddrWithAttrs(addrStr string, weight *uint32, priority string, lID *internal.LocalityID) resolver.Address { + addr := resolver.Address{Addr: addrStr} + if weight != nil { + addr = weightedroundrobin.SetAddrInfo(addr, weightedroundrobin.AddrInfo{Weight: *weight}) + } + path := []string{priority} + if lID != nil { + path = append(path, assertString(lID.ToString)) + addr = internal.SetLocalityID(addr, 
*lID) + } + addr = hierarchy.Set(addr, path) + return addr +} diff --git a/xds/internal/balancer/clusterresolver/eds_impl_test.go b/xds/internal/balancer/clusterresolver/eds_impl_test.go new file mode 100644 index 00000000000..00814a6212b --- /dev/null +++ b/xds/internal/balancer/clusterresolver/eds_impl_test.go @@ -0,0 +1,575 @@ +/* + * Copyright 2019 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package clusterresolver + +import ( + "context" + "fmt" + "sort" + "testing" + "time" + + corepb "github.com/envoyproxy/go-control-plane/envoy/api/v2/core" + "github.com/google/go-cmp/cmp" + "google.golang.org/grpc/balancer" + "google.golang.org/grpc/connectivity" + internalserviceconfig "google.golang.org/grpc/internal/serviceconfig" + "google.golang.org/grpc/resolver" + "google.golang.org/grpc/xds/internal/balancer/balancergroup" + "google.golang.org/grpc/xds/internal/balancer/clusterimpl" + "google.golang.org/grpc/xds/internal/balancer/priority" + "google.golang.org/grpc/xds/internal/balancer/weightedtarget" + "google.golang.org/grpc/xds/internal/testutils" + "google.golang.org/grpc/xds/internal/testutils/fakeclient" + "google.golang.org/grpc/xds/internal/xdsclient" +) + +var ( + testClusterNames = []string{"test-cluster-1", "test-cluster-2"} + testSubZones = []string{"I", "II", "III", "IV"} + testEndpointAddrs []string +) + +const testBackendAddrsCount = 12 + +func init() { + for i := 0; i < testBackendAddrsCount; i++ { + 
testEndpointAddrs = append(testEndpointAddrs, fmt.Sprintf("%d.%d.%d.%d:%d", i, i, i, i, i)) + } + balancergroup.DefaultSubBalancerCloseTimeout = time.Millisecond + clusterimpl.NewRandomWRR = testutils.NewTestWRR + weightedtarget.NewRandomWRR = testutils.NewTestWRR + balancergroup.DefaultSubBalancerCloseTimeout = time.Millisecond * 100 +} + +func setupTestEDS(t *testing.T, initChild *internalserviceconfig.BalancerConfig) (balancer.Balancer, *testutils.TestClientConn, *fakeclient.Client, func()) { + xdsC := fakeclient.NewClientWithName(testBalancerNameFooBar) + cc := testutils.NewTestClientConn(t) + builder := balancer.Get(Name) + edsb := builder.Build(cc, balancer.BuildOptions{Target: resolver.Target{Endpoint: testEDSServcie}}) + if edsb == nil { + t.Fatalf("builder.Build(%s) failed and returned nil", Name) + } + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + if err := edsb.UpdateClientConnState(balancer.ClientConnState{ + ResolverState: xdsclient.SetClient(resolver.State{}, xdsC), + BalancerConfig: &LBConfig{ + DiscoveryMechanisms: []DiscoveryMechanism{{ + Cluster: testClusterName, + Type: DiscoveryMechanismTypeEDS, + }}, + }, + }); err != nil { + edsb.Close() + xdsC.Close() + t.Fatal(err) + } + if _, err := xdsC.WaitForWatchEDS(ctx); err != nil { + edsb.Close() + xdsC.Close() + t.Fatalf("xdsClient.WatchEndpoints failed with error: %v", err) + } + return edsb, cc, xdsC, func() { + edsb.Close() + xdsC.Close() + } +} + +// One locality +// - add backend +// - remove backend +// - replace backend +// - change drop rate +func (s) TestEDS_OneLocality(t *testing.T) { + edsb, cc, xdsC, cleanup := setupTestEDS(t, nil) + defer cleanup() + + // One locality with one backend. 
+ clab1 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil) + clab1.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[:1], nil) + xdsC.InvokeWatchEDSCallback("", parseEDSRespProtoForTesting(clab1.Build()), nil) + + sc1 := <-cc.NewSubConnCh + edsb.UpdateSubConnState(sc1, balancer.SubConnState{ConnectivityState: connectivity.Connecting}) + edsb.UpdateSubConnState(sc1, balancer.SubConnState{ConnectivityState: connectivity.Ready}) + + // Pick with only the first backend. + if err := testRoundRobinPickerFromCh(cc.NewPickerCh, []balancer.SubConn{sc1}); err != nil { + t.Fatal(err) + } + + // The same locality, add one more backend. + clab2 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil) + clab2.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[:2], nil) + xdsC.InvokeWatchEDSCallback("", parseEDSRespProtoForTesting(clab2.Build()), nil) + + sc2 := <-cc.NewSubConnCh + edsb.UpdateSubConnState(sc2, balancer.SubConnState{ConnectivityState: connectivity.Connecting}) + edsb.UpdateSubConnState(sc2, balancer.SubConnState{ConnectivityState: connectivity.Ready}) + + // Test roundrobin with two subconns. + if err := testRoundRobinPickerFromCh(cc.NewPickerCh, []balancer.SubConn{sc1, sc2}); err != nil { + t.Fatal(err) + } + + // The same locality, delete first backend. + clab3 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil) + clab3.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[1:2], nil) + xdsC.InvokeWatchEDSCallback("", parseEDSRespProtoForTesting(clab3.Build()), nil) + + scToRemove := <-cc.RemoveSubConnCh + if !cmp.Equal(scToRemove, sc1, cmp.AllowUnexported(testutils.TestSubConn{})) { + t.Fatalf("RemoveSubConn, want %v, got %v", sc1, scToRemove) + } + edsb.UpdateSubConnState(scToRemove, balancer.SubConnState{ConnectivityState: connectivity.Shutdown}) + + // Test pick with only the second subconn. 
+ if err := testRoundRobinPickerFromCh(cc.NewPickerCh, []balancer.SubConn{sc2}); err != nil { + t.Fatal(err) + } + + // The same locality, replace backend. + clab4 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil) + clab4.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[2:3], nil) + xdsC.InvokeWatchEDSCallback("", parseEDSRespProtoForTesting(clab4.Build()), nil) + + sc3 := <-cc.NewSubConnCh + edsb.UpdateSubConnState(sc3, balancer.SubConnState{ConnectivityState: connectivity.Connecting}) + edsb.UpdateSubConnState(sc3, balancer.SubConnState{ConnectivityState: connectivity.Ready}) + scToRemove = <-cc.RemoveSubConnCh + if !cmp.Equal(scToRemove, sc2, cmp.AllowUnexported(testutils.TestSubConn{})) { + t.Fatalf("RemoveSubConn, want %v, got %v", sc2, scToRemove) + } + edsb.UpdateSubConnState(scToRemove, balancer.SubConnState{ConnectivityState: connectivity.Shutdown}) + + // Test pick with only the third subconn. + if err := testRoundRobinPickerFromCh(cc.NewPickerCh, []balancer.SubConn{sc3}); err != nil { + t.Fatal(err) + } + + // The same locality, different drop rate, dropping 50%. + clab5 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], map[string]uint32{"test-drop": 50}) + clab5.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[2:3], nil) + xdsC.InvokeWatchEDSCallback("", parseEDSRespProtoForTesting(clab5.Build()), nil) + + // Picks with drops. + if err := testPickerFromCh(cc.NewPickerCh, func(p balancer.Picker) error { + for i := 0; i < 100; i++ { + _, err := p.Pick(balancer.PickInfo{}) + // TODO: the dropping algorithm needs a design. When the dropping algorithm + // is fixed, this test also needs fix. 
+ if i%2 == 0 && err == nil { + return fmt.Errorf("%d - the even number picks should be drops, got error ", i) + } else if i%2 != 0 && err != nil { + return fmt.Errorf("%d - the odd number picks should be non-drops, got error %v", i, err) + } + } + return nil + }); err != nil { + t.Fatal(err) + } + + // The same locality, remove drops. + clab6 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil) + clab6.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[2:3], nil) + xdsC.InvokeWatchEDSCallback("", parseEDSRespProtoForTesting(clab6.Build()), nil) + + // Pick without drops. + if err := testRoundRobinPickerFromCh(cc.NewPickerCh, []balancer.SubConn{sc3}); err != nil { + t.Fatal(err) + } +} + +// 2 locality +// - start with 2 locality +// - add locality +// - remove locality +// - address change for the locality +// - update locality weight +func (s) TestEDS_TwoLocalities(t *testing.T) { + edsb, cc, xdsC, cleanup := setupTestEDS(t, nil) + defer cleanup() + + // Two localities, each with one backend. + clab1 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil) + clab1.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[:1], nil) + xdsC.InvokeWatchEDSCallback("", parseEDSRespProtoForTesting(clab1.Build()), nil) + sc1 := <-cc.NewSubConnCh + edsb.UpdateSubConnState(sc1, balancer.SubConnState{ConnectivityState: connectivity.Connecting}) + edsb.UpdateSubConnState(sc1, balancer.SubConnState{ConnectivityState: connectivity.Ready}) + + // Add the second locality later to make sure sc2 belongs to the second + // locality. Otherwise the test is flaky because of a map is used in EDS to + // keep localities. 
+ clab1.AddLocality(testSubZones[1], 1, 0, testEndpointAddrs[1:2], nil) + xdsC.InvokeWatchEDSCallback("", parseEDSRespProtoForTesting(clab1.Build()), nil) + sc2 := <-cc.NewSubConnCh + edsb.UpdateSubConnState(sc2, balancer.SubConnState{ConnectivityState: connectivity.Connecting}) + edsb.UpdateSubConnState(sc2, balancer.SubConnState{ConnectivityState: connectivity.Ready}) + + // Test roundrobin with two subconns. + if err := testRoundRobinPickerFromCh(cc.NewPickerCh, []balancer.SubConn{sc1, sc2}); err != nil { + t.Fatal(err) + } + + // Add another locality, with one backend. + clab2 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil) + clab2.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[:1], nil) + clab2.AddLocality(testSubZones[1], 1, 0, testEndpointAddrs[1:2], nil) + clab2.AddLocality(testSubZones[2], 1, 0, testEndpointAddrs[2:3], nil) + xdsC.InvokeWatchEDSCallback("", parseEDSRespProtoForTesting(clab2.Build()), nil) + + sc3 := <-cc.NewSubConnCh + edsb.UpdateSubConnState(sc3, balancer.SubConnState{ConnectivityState: connectivity.Connecting}) + edsb.UpdateSubConnState(sc3, balancer.SubConnState{ConnectivityState: connectivity.Ready}) + + // Test roundrobin with three subconns. + if err := testRoundRobinPickerFromCh(cc.NewPickerCh, []balancer.SubConn{sc1, sc2, sc3}); err != nil { + t.Fatal(err) + } + + // Remove first locality. 
+ clab3 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil) + clab3.AddLocality(testSubZones[1], 1, 0, testEndpointAddrs[1:2], nil) + clab3.AddLocality(testSubZones[2], 1, 0, testEndpointAddrs[2:3], nil) + xdsC.InvokeWatchEDSCallback("", parseEDSRespProtoForTesting(clab3.Build()), nil) + + scToRemove := <-cc.RemoveSubConnCh + if !cmp.Equal(scToRemove, sc1, cmp.AllowUnexported(testutils.TestSubConn{})) { + t.Fatalf("RemoveSubConn, want %v, got %v", sc1, scToRemove) + } + edsb.UpdateSubConnState(scToRemove, balancer.SubConnState{ConnectivityState: connectivity.Shutdown}) + + // Test pick with two subconns (without the first one). + if err := testRoundRobinPickerFromCh(cc.NewPickerCh, []balancer.SubConn{sc2, sc3}); err != nil { + t.Fatal(err) + } + + // Add a backend to the last locality. + clab4 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil) + clab4.AddLocality(testSubZones[1], 1, 0, testEndpointAddrs[1:2], nil) + clab4.AddLocality(testSubZones[2], 1, 0, testEndpointAddrs[2:4], nil) + xdsC.InvokeWatchEDSCallback("", parseEDSRespProtoForTesting(clab4.Build()), nil) + + sc4 := <-cc.NewSubConnCh + edsb.UpdateSubConnState(sc4, balancer.SubConnState{ConnectivityState: connectivity.Connecting}) + edsb.UpdateSubConnState(sc4, balancer.SubConnState{ConnectivityState: connectivity.Ready}) + + // Test pick with two subconns (without the first one). + // + // Locality-1 will be picked twice, and locality-2 will be picked twice. + // Locality-1 contains only sc2, locality-2 contains sc3 and sc4. So expect + // two sc2's and sc3, sc4. + if err := testRoundRobinPickerFromCh(cc.NewPickerCh, []balancer.SubConn{sc2, sc2, sc3, sc4}); err != nil { + t.Fatal(err) + } + + // Change weight of the locality[1]. 
+ clab5 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil) + clab5.AddLocality(testSubZones[1], 2, 0, testEndpointAddrs[1:2], nil) + clab5.AddLocality(testSubZones[2], 1, 0, testEndpointAddrs[2:4], nil) + xdsC.InvokeWatchEDSCallback("", parseEDSRespProtoForTesting(clab5.Build()), nil) + + // Test pick with two subconns different locality weight. + // + // Locality-1 will be picked four times, and locality-2 will be picked twice + // (weight 2 and 1). Locality-1 contains only sc2, locality-2 contains sc3 and + // sc4. So expect four sc2's and sc3, sc4. + if err := testRoundRobinPickerFromCh(cc.NewPickerCh, []balancer.SubConn{sc2, sc2, sc2, sc2, sc3, sc4}); err != nil { + t.Fatal(err) + } + + // Change weight of the locality[1] to 0, it should never be picked. + clab6 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil) + clab6.AddLocality(testSubZones[1], 0, 0, testEndpointAddrs[1:2], nil) + clab6.AddLocality(testSubZones[2], 1, 0, testEndpointAddrs[2:4], nil) + xdsC.InvokeWatchEDSCallback("", parseEDSRespProtoForTesting(clab6.Build()), nil) + + // Changing weight of locality[1] to 0 caused it to be removed. It's subconn + // should also be removed. + // + // NOTE: this is because we handle locality with weight 0 same as the + // locality doesn't exist. If this changes in the future, this removeSubConn + // behavior will also change. + scToRemove2 := <-cc.RemoveSubConnCh + if !cmp.Equal(scToRemove2, sc2, cmp.AllowUnexported(testutils.TestSubConn{})) { + t.Fatalf("RemoveSubConn, want %v, got %v", sc2, scToRemove2) + } + + // Test pick with two subconns different locality weight. + // + // Locality-1 will be not be picked, and locality-2 will be picked. + // Locality-2 contains sc3 and sc4. So expect sc3, sc4. + if err := testRoundRobinPickerFromCh(cc.NewPickerCh, []balancer.SubConn{sc3, sc4}); err != nil { + t.Fatal(err) + } +} + +// The EDS balancer gets EDS resp with unhealthy endpoints. 
Test that only +// healthy ones are used. +func (s) TestEDS_EndpointsHealth(t *testing.T) { + edsb, cc, xdsC, cleanup := setupTestEDS(t, nil) + defer cleanup() + + // Two localities, each 3 backend, one Healthy, one Unhealthy, one Unknown. + clab1 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil) + clab1.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[:6], &testutils.AddLocalityOptions{ + Health: []corepb.HealthStatus{ + corepb.HealthStatus_HEALTHY, + corepb.HealthStatus_UNHEALTHY, + corepb.HealthStatus_UNKNOWN, + corepb.HealthStatus_DRAINING, + corepb.HealthStatus_TIMEOUT, + corepb.HealthStatus_DEGRADED, + }, + }) + clab1.AddLocality(testSubZones[1], 1, 0, testEndpointAddrs[6:12], &testutils.AddLocalityOptions{ + Health: []corepb.HealthStatus{ + corepb.HealthStatus_HEALTHY, + corepb.HealthStatus_UNHEALTHY, + corepb.HealthStatus_UNKNOWN, + corepb.HealthStatus_DRAINING, + corepb.HealthStatus_TIMEOUT, + corepb.HealthStatus_DEGRADED, + }, + }) + xdsC.InvokeWatchEDSCallback("", parseEDSRespProtoForTesting(clab1.Build()), nil) + + var ( + readySCs []balancer.SubConn + newSubConnAddrStrs []string + ) + for i := 0; i < 4; i++ { + addr := <-cc.NewSubConnAddrsCh + newSubConnAddrStrs = append(newSubConnAddrStrs, addr[0].Addr) + sc := <-cc.NewSubConnCh + edsb.UpdateSubConnState(sc, balancer.SubConnState{ConnectivityState: connectivity.Connecting}) + edsb.UpdateSubConnState(sc, balancer.SubConnState{ConnectivityState: connectivity.Ready}) + readySCs = append(readySCs, sc) + } + + wantNewSubConnAddrStrs := []string{ + testEndpointAddrs[0], + testEndpointAddrs[2], + testEndpointAddrs[6], + testEndpointAddrs[8], + } + sortStrTrans := cmp.Transformer("Sort", func(in []string) []string { + out := append([]string(nil), in...) // Copy input to avoid mutating it. 
+ sort.Strings(out) + return out + }) + if !cmp.Equal(newSubConnAddrStrs, wantNewSubConnAddrStrs, sortStrTrans) { + t.Fatalf("want newSubConn with address %v, got %v", wantNewSubConnAddrStrs, newSubConnAddrStrs) + } + + // There should be exactly 4 new SubConns. Check to make sure there's no + // more subconns being created. + select { + case <-cc.NewSubConnCh: + t.Fatalf("Got unexpected new subconn") + case <-time.After(time.Microsecond * 100): + } + + // Test roundrobin with the subconns. + if err := testRoundRobinPickerFromCh(cc.NewPickerCh, readySCs); err != nil { + t.Fatal(err) + } +} + +// TestEDS_EmptyUpdate covers the cases when eds impl receives an empty update. +// +// It should send an error picker with transient failure to the parent. +func (s) TestEDS_EmptyUpdate(t *testing.T) { + edsb, cc, xdsC, cleanup := setupTestEDS(t, nil) + defer cleanup() + + const cacheTimeout = 100 * time.Microsecond + oldCacheTimeout := balancergroup.DefaultSubBalancerCloseTimeout + balancergroup.DefaultSubBalancerCloseTimeout = cacheTimeout + defer func() { balancergroup.DefaultSubBalancerCloseTimeout = oldCacheTimeout }() + + // The first update is an empty update. + xdsC.InvokeWatchEDSCallback("", xdsclient.EndpointsUpdate{}, nil) + // Pick should fail with transient failure, and all priority removed error. + if err := testErrPickerFromCh(cc.NewPickerCh, priority.ErrAllPrioritiesRemoved); err != nil { + t.Fatal(err) + } + + // One locality with one backend. + clab1 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil) + clab1.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[:1], nil) + xdsC.InvokeWatchEDSCallback("", parseEDSRespProtoForTesting(clab1.Build()), nil) + + sc1 := <-cc.NewSubConnCh + edsb.UpdateSubConnState(sc1, balancer.SubConnState{ConnectivityState: connectivity.Connecting}) + edsb.UpdateSubConnState(sc1, balancer.SubConnState{ConnectivityState: connectivity.Ready}) + + // Pick with only the first backend. 
+ if err := testRoundRobinPickerFromCh(cc.NewPickerCh, []balancer.SubConn{sc1}); err != nil { + t.Fatal(err) + } + + xdsC.InvokeWatchEDSCallback("", xdsclient.EndpointsUpdate{}, nil) + // Pick should fail with transient failure, and all priority removed error. + if err := testErrPickerFromCh(cc.NewPickerCh, priority.ErrAllPrioritiesRemoved); err != nil { + t.Fatal(err) + } + + // Wait for the old SubConn to be removed (which happens when the child + // policy is closed), so a new update would trigger a new SubConn (we need + // this new SubConn to tell if the next picker is newly created). + scToRemove := <-cc.RemoveSubConnCh + if !cmp.Equal(scToRemove, sc1, cmp.AllowUnexported(testutils.TestSubConn{})) { + t.Fatalf("RemoveSubConn, want %v, got %v", sc1, scToRemove) + } + edsb.UpdateSubConnState(scToRemove, balancer.SubConnState{ConnectivityState: connectivity.Shutdown}) + + // Handle another update with priorities and localities. + xdsC.InvokeWatchEDSCallback("", parseEDSRespProtoForTesting(clab1.Build()), nil) + + sc2 := <-cc.NewSubConnCh + edsb.UpdateSubConnState(sc2, balancer.SubConnState{ConnectivityState: connectivity.Connecting}) + edsb.UpdateSubConnState(sc2, balancer.SubConnState{ConnectivityState: connectivity.Ready}) + + // Pick with only the first backend. + if err := testRoundRobinPickerFromCh(cc.NewPickerCh, []balancer.SubConn{sc2}); err != nil { + t.Fatal(err) + } +} + +func (s) TestEDS_CircuitBreaking(t *testing.T) { + edsb, cc, xdsC, cleanup := setupTestEDS(t, nil) + defer cleanup() + + var maxRequests uint32 = 50 + if err := edsb.UpdateClientConnState(balancer.ClientConnState{ + BalancerConfig: &LBConfig{ + DiscoveryMechanisms: []DiscoveryMechanism{{ + Cluster: testClusterName, + MaxConcurrentRequests: &maxRequests, + Type: DiscoveryMechanismTypeEDS, + }}, + }, + }); err != nil { + t.Fatal(err) + } + + // One locality with one backend. 
+ clab1 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil) + clab1.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[:1], nil) + xdsC.InvokeWatchEDSCallback("", parseEDSRespProtoForTesting(clab1.Build()), nil) + sc1 := <-cc.NewSubConnCh + edsb.UpdateSubConnState(sc1, balancer.SubConnState{ConnectivityState: connectivity.Connecting}) + edsb.UpdateSubConnState(sc1, balancer.SubConnState{ConnectivityState: connectivity.Ready}) + + // Picks with drops. + dones := []func(){} + p := <-cc.NewPickerCh + for i := 0; i < 100; i++ { + pr, err := p.Pick(balancer.PickInfo{}) + if i < 50 && err != nil { + t.Errorf("The first 50%% picks should be non-drops, got error %v", err) + } else if i > 50 && err == nil { + t.Errorf("The second 50%% picks should be drops, got error ") + } + dones = append(dones, func() { + if pr.Done != nil { + pr.Done(balancer.DoneInfo{}) + } + }) + } + + for _, done := range dones { + done() + } + dones = []func(){} + + // Pick without drops. + for i := 0; i < 50; i++ { + pr, err := p.Pick(balancer.PickInfo{}) + if err != nil { + t.Errorf("The third 50%% picks should be non-drops, got error %v", err) + } + dones = append(dones, func() { + if pr.Done != nil { + pr.Done(balancer.DoneInfo{}) + } + }) + } + + // Without this, future tests with the same service name will fail. + for _, done := range dones { + done() + } + + // Send another update, with only circuit breaking update (and no picker + // update afterwards). Make sure the new picker uses the new configs. + var maxRequests2 uint32 = 10 + if err := edsb.UpdateClientConnState(balancer.ClientConnState{ + BalancerConfig: &LBConfig{ + DiscoveryMechanisms: []DiscoveryMechanism{{ + Cluster: testClusterName, + MaxConcurrentRequests: &maxRequests2, + Type: DiscoveryMechanismTypeEDS, + }}, + }, + }); err != nil { + t.Fatal(err) + } + + // Picks with drops. 
+ dones = []func(){} + p2 := <-cc.NewPickerCh + for i := 0; i < 100; i++ { + pr, err := p2.Pick(balancer.PickInfo{}) + if i < 10 && err != nil { + t.Errorf("The first 10%% picks should be non-drops, got error %v", err) + } else if i > 10 && err == nil { + t.Errorf("The next 90%% picks should be drops, got error ") + } + dones = append(dones, func() { + if pr.Done != nil { + pr.Done(balancer.DoneInfo{}) + } + }) + } + + for _, done := range dones { + done() + } + dones = []func(){} + + // Pick without drops. + for i := 0; i < 10; i++ { + pr, err := p2.Pick(balancer.PickInfo{}) + if err != nil { + t.Errorf("The next 10%% picks should be non-drops, got error %v", err) + } + dones = append(dones, func() { + if pr.Done != nil { + pr.Done(balancer.DoneInfo{}) + } + }) + } + + // Without this, future tests with the same service name will fail. + for _, done := range dones { + done() + } +} diff --git a/xds/internal/balancer/lrs/logging.go b/xds/internal/balancer/clusterresolver/logging.go similarity index 84% rename from xds/internal/balancer/lrs/logging.go rename to xds/internal/balancer/clusterresolver/logging.go index 602dac09959..728f1f709c2 100644 --- a/xds/internal/balancer/lrs/logging.go +++ b/xds/internal/balancer/clusterresolver/logging.go @@ -16,7 +16,7 @@ * */ -package lrs +package clusterresolver import ( "fmt" @@ -25,10 +25,10 @@ import ( internalgrpclog "google.golang.org/grpc/internal/grpclog" ) -const prefix = "[lrs-lb %p] " +const prefix = "[xds-cluster-resolver-lb %p] " var logger = grpclog.Component("xds") -func prefixLogger(p *lrsBalancer) *internalgrpclog.PrefixLogger { +func prefixLogger(p *clusterResolverBalancer) *internalgrpclog.PrefixLogger { return internalgrpclog.NewPrefixLogger(logger, fmt.Sprintf(prefix, p)) } diff --git a/xds/internal/balancer/edsbalancer/eds_impl_priority_test.go b/xds/internal/balancer/clusterresolver/priority_test.go similarity index 54% rename from xds/internal/balancer/edsbalancer/eds_impl_priority_test.go rename to 
xds/internal/balancer/clusterresolver/priority_test.go index 7696feb5bd0..8438a373d9d 100644 --- a/xds/internal/balancer/edsbalancer/eds_impl_priority_test.go +++ b/xds/internal/balancer/clusterresolver/priority_test.go @@ -15,7 +15,7 @@ * limitations under the License. */ -package edsbalancer +package clusterresolver import ( "context" @@ -26,6 +26,8 @@ import ( "github.com/google/go-cmp/cmp" "google.golang.org/grpc/balancer" "google.golang.org/grpc/connectivity" + "google.golang.org/grpc/resolver" + "google.golang.org/grpc/xds/internal/balancer/priority" "google.golang.org/grpc/xds/internal/testutils" ) @@ -34,15 +36,14 @@ import ( // // Init 0 and 1; 0 is up, use 0; add 2, use 0; remove 2, use 0. func (s) TestEDSPriority_HighPriorityReady(t *testing.T) { - cc := testutils.NewTestClientConn(t) - edsb := newEDSBalancerImpl(cc, balancer.BuildOptions{}, nil, nil, nil) - edsb.enqueueChildBalancerStateUpdate = edsb.updateState + edsb, cc, xdsC, cleanup := setupTestEDS(t, nil) + defer cleanup() // Two localities, with priorities [0, 1], each with one backend. clab1 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil) clab1.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[:1], nil) clab1.AddLocality(testSubZones[1], 1, 1, testEndpointAddrs[1:2], nil) - edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab1.Build())) + xdsC.InvokeWatchEDSCallback("", parseEDSRespProtoForTesting(clab1.Build()), nil) addrs1 := <-cc.NewSubConnAddrsCh if got, want := addrs1[0].Addr, testEndpointAddrs[0]; got != want { @@ -51,22 +52,20 @@ func (s) TestEDSPriority_HighPriorityReady(t *testing.T) { sc1 := <-cc.NewSubConnCh // p0 is ready. 
- edsb.handleSubConnStateChange(sc1, connectivity.Connecting) - edsb.handleSubConnStateChange(sc1, connectivity.Ready) + edsb.UpdateSubConnState(sc1, balancer.SubConnState{ConnectivityState: connectivity.Connecting}) + edsb.UpdateSubConnState(sc1, balancer.SubConnState{ConnectivityState: connectivity.Ready}) // Test roundrobin with only p0 subconns. - p1 := <-cc.NewPickerCh - want := []balancer.SubConn{sc1} - if err := testutils.IsRoundRobin(want, subConnFromPicker(p1)); err != nil { - t.Fatalf("want %v, got %v", want, err) + if err := testRoundRobinPickerFromCh(cc.NewPickerCh, []balancer.SubConn{sc1}); err != nil { + t.Fatal(err) } - // Add p2, it shouldn't cause any udpates. + // Add p2, it shouldn't cause any updates. clab2 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil) clab2.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[:1], nil) clab2.AddLocality(testSubZones[1], 1, 1, testEndpointAddrs[1:2], nil) clab2.AddLocality(testSubZones[2], 1, 2, testEndpointAddrs[2:3], nil) - edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab2.Build())) + xdsC.InvokeWatchEDSCallback("", parseEDSRespProtoForTesting(clab2.Build()), nil) select { case <-cc.NewPickerCh: @@ -82,7 +81,7 @@ func (s) TestEDSPriority_HighPriorityReady(t *testing.T) { clab3 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil) clab3.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[:1], nil) clab3.AddLocality(testSubZones[1], 1, 1, testEndpointAddrs[1:2], nil) - edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab3.Build())) + xdsC.InvokeWatchEDSCallback("", parseEDSRespProtoForTesting(clab3.Build()), nil) select { case <-cc.NewPickerCh: @@ -100,15 +99,14 @@ func (s) TestEDSPriority_HighPriorityReady(t *testing.T) { // Init 0 and 1; 0 is up, use 0; 0 is down, 1 is up, use 1; add 2, use 1; 1 is // down, use 2; remove 2, use 1. 
func (s) TestEDSPriority_SwitchPriority(t *testing.T) { - cc := testutils.NewTestClientConn(t) - edsb := newEDSBalancerImpl(cc, balancer.BuildOptions{}, nil, nil, nil) - edsb.enqueueChildBalancerStateUpdate = edsb.updateState + edsb, cc, xdsC, cleanup := setupTestEDS(t, nil) + defer cleanup() // Two localities, with priorities [0, 1], each with one backend. clab1 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil) clab1.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[:1], nil) clab1.AddLocality(testSubZones[1], 1, 1, testEndpointAddrs[1:2], nil) - edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab1.Build())) + xdsC.InvokeWatchEDSCallback("", parseEDSRespProtoForTesting(clab1.Build()), nil) addrs0 := <-cc.NewSubConnAddrsCh if got, want := addrs0[0].Addr, testEndpointAddrs[0]; got != want { @@ -117,41 +115,35 @@ func (s) TestEDSPriority_SwitchPriority(t *testing.T) { sc0 := <-cc.NewSubConnCh // p0 is ready. - edsb.handleSubConnStateChange(sc0, connectivity.Connecting) - edsb.handleSubConnStateChange(sc0, connectivity.Ready) + edsb.UpdateSubConnState(sc0, balancer.SubConnState{ConnectivityState: connectivity.Connecting}) + edsb.UpdateSubConnState(sc0, balancer.SubConnState{ConnectivityState: connectivity.Ready}) // Test roundrobin with only p0 subconns. - p0 := <-cc.NewPickerCh - want := []balancer.SubConn{sc0} - if err := testutils.IsRoundRobin(want, subConnFromPicker(p0)); err != nil { - t.Fatalf("want %v, got %v", want, err) + if err := testRoundRobinPickerFromCh(cc.NewPickerCh, []balancer.SubConn{sc0}); err != nil { + t.Fatal(err) } // Turn down 0, 1 is used. 
- edsb.handleSubConnStateChange(sc0, connectivity.TransientFailure) + edsb.UpdateSubConnState(sc0, balancer.SubConnState{ConnectivityState: connectivity.TransientFailure}) addrs1 := <-cc.NewSubConnAddrsCh if got, want := addrs1[0].Addr, testEndpointAddrs[1]; got != want { t.Fatalf("sc is created with addr %v, want %v", got, want) } sc1 := <-cc.NewSubConnCh - edsb.handleSubConnStateChange(sc1, connectivity.Connecting) - edsb.handleSubConnStateChange(sc1, connectivity.Ready) + edsb.UpdateSubConnState(sc1, balancer.SubConnState{ConnectivityState: connectivity.Connecting}) + edsb.UpdateSubConnState(sc1, balancer.SubConnState{ConnectivityState: connectivity.Ready}) // Test pick with 1. - p1 := <-cc.NewPickerCh - for i := 0; i < 5; i++ { - gotSCSt, _ := p1.Pick(balancer.PickInfo{}) - if !cmp.Equal(gotSCSt.SubConn, sc1, cmp.AllowUnexported(testutils.TestSubConn{})) { - t.Fatalf("picker.Pick, got %v, want SubConn=%v", gotSCSt, sc1) - } + if err := testRoundRobinPickerFromCh(cc.NewPickerCh, []balancer.SubConn{sc1}); err != nil { + t.Fatal(err) } - // Add p2, it shouldn't cause any udpates. + // Add p2, it shouldn't cause any updates. 
clab2 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil) clab2.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[:1], nil) clab2.AddLocality(testSubZones[1], 1, 1, testEndpointAddrs[1:2], nil) clab2.AddLocality(testSubZones[2], 1, 2, testEndpointAddrs[2:3], nil) - edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab2.Build())) + xdsC.InvokeWatchEDSCallback("", parseEDSRespProtoForTesting(clab2.Build()), nil) select { case <-cc.NewPickerCh: @@ -164,29 +156,25 @@ func (s) TestEDSPriority_SwitchPriority(t *testing.T) { } // Turn down 1, use 2 - edsb.handleSubConnStateChange(sc1, connectivity.TransientFailure) + edsb.UpdateSubConnState(sc1, balancer.SubConnState{ConnectivityState: connectivity.TransientFailure}) addrs2 := <-cc.NewSubConnAddrsCh if got, want := addrs2[0].Addr, testEndpointAddrs[2]; got != want { t.Fatalf("sc is created with addr %v, want %v", got, want) } sc2 := <-cc.NewSubConnCh - edsb.handleSubConnStateChange(sc2, connectivity.Connecting) - edsb.handleSubConnStateChange(sc2, connectivity.Ready) + edsb.UpdateSubConnState(sc2, balancer.SubConnState{ConnectivityState: connectivity.Connecting}) + edsb.UpdateSubConnState(sc2, balancer.SubConnState{ConnectivityState: connectivity.Ready}) // Test pick with 2. - p2 := <-cc.NewPickerCh - for i := 0; i < 5; i++ { - gotSCSt, _ := p2.Pick(balancer.PickInfo{}) - if !cmp.Equal(gotSCSt.SubConn, sc2, cmp.AllowUnexported(testutils.TestSubConn{})) { - t.Fatalf("picker.Pick, got %v, want SubConn=%v", gotSCSt, sc2) - } + if err := testRoundRobinPickerFromCh(cc.NewPickerCh, []balancer.SubConn{sc2}); err != nil { + t.Fatal(err) } // Remove 2, use 1. 
clab3 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil) clab3.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[:1], nil) clab3.AddLocality(testSubZones[1], 1, 1, testEndpointAddrs[1:2], nil) - edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab3.Build())) + xdsC.InvokeWatchEDSCallback("", parseEDSRespProtoForTesting(clab3.Build()), nil) // p2 SubConns are removed. scToRemove := <-cc.RemoveSubConnCh @@ -195,28 +183,23 @@ func (s) TestEDSPriority_SwitchPriority(t *testing.T) { } // Should get an update with 1's old picker, to override 2's old picker. - p3 := <-cc.NewPickerCh - for i := 0; i < 5; i++ { - if _, err := p3.Pick(balancer.PickInfo{}); err != balancer.ErrTransientFailure { - t.Fatalf("want pick error %v, got %v", balancer.ErrTransientFailure, err) - } + if err := testErrPickerFromCh(cc.NewPickerCh, balancer.ErrTransientFailure); err != nil { + t.Fatal(err) } + } // Add a lower priority while the higher priority is down. // // Init 0 and 1; 0 and 1 both down; add 2, use 2. func (s) TestEDSPriority_HigherDownWhileAddingLower(t *testing.T) { - cc := testutils.NewTestClientConn(t) - edsb := newEDSBalancerImpl(cc, balancer.BuildOptions{}, nil, nil, nil) - edsb.enqueueChildBalancerStateUpdate = edsb.updateState - + edsb, cc, xdsC, cleanup := setupTestEDS(t, nil) + defer cleanup() // Two localities, with different priorities, each with one backend. 
clab1 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil) clab1.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[:1], nil) clab1.AddLocality(testSubZones[1], 1, 1, testEndpointAddrs[1:2], nil) - edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab1.Build())) - + xdsC.InvokeWatchEDSCallback("", parseEDSRespProtoForTesting(clab1.Build()), nil) addrs0 := <-cc.NewSubConnAddrsCh if got, want := addrs0[0].Addr, testEndpointAddrs[0]; got != want { t.Fatalf("sc is created with addr %v, want %v", got, want) @@ -224,21 +207,18 @@ func (s) TestEDSPriority_HigherDownWhileAddingLower(t *testing.T) { sc0 := <-cc.NewSubConnCh // Turn down 0, 1 is used. - edsb.handleSubConnStateChange(sc0, connectivity.TransientFailure) + edsb.UpdateSubConnState(sc0, balancer.SubConnState{ConnectivityState: connectivity.TransientFailure}) addrs1 := <-cc.NewSubConnAddrsCh if got, want := addrs1[0].Addr, testEndpointAddrs[1]; got != want { t.Fatalf("sc is created with addr %v, want %v", got, want) } sc1 := <-cc.NewSubConnCh // Turn down 1, pick should error. - edsb.handleSubConnStateChange(sc1, connectivity.TransientFailure) + edsb.UpdateSubConnState(sc1, balancer.SubConnState{ConnectivityState: connectivity.TransientFailure}) // Test pick failure. - pFail := <-cc.NewPickerCh - for i := 0; i < 5; i++ { - if _, err := pFail.Pick(balancer.PickInfo{}); err != balancer.ErrTransientFailure { - t.Fatalf("want pick error %v, got %v", balancer.ErrTransientFailure, err) - } + if err := testErrPickerFromCh(cc.NewPickerCh, balancer.ErrTransientFailure); err != nil { + t.Fatal(err) } // Add p2, it should create a new SubConn. 
@@ -246,41 +226,34 @@ func (s) TestEDSPriority_HigherDownWhileAddingLower(t *testing.T) { clab2.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[:1], nil) clab2.AddLocality(testSubZones[1], 1, 1, testEndpointAddrs[1:2], nil) clab2.AddLocality(testSubZones[2], 1, 2, testEndpointAddrs[2:3], nil) - edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab2.Build())) - + xdsC.InvokeWatchEDSCallback("", parseEDSRespProtoForTesting(clab2.Build()), nil) addrs2 := <-cc.NewSubConnAddrsCh if got, want := addrs2[0].Addr, testEndpointAddrs[2]; got != want { t.Fatalf("sc is created with addr %v, want %v", got, want) } sc2 := <-cc.NewSubConnCh - edsb.handleSubConnStateChange(sc2, connectivity.Connecting) - edsb.handleSubConnStateChange(sc2, connectivity.Ready) + edsb.UpdateSubConnState(sc2, balancer.SubConnState{ConnectivityState: connectivity.Connecting}) + edsb.UpdateSubConnState(sc2, balancer.SubConnState{ConnectivityState: connectivity.Ready}) // Test pick with 2. - p2 := <-cc.NewPickerCh - for i := 0; i < 5; i++ { - gotSCSt, _ := p2.Pick(balancer.PickInfo{}) - if !cmp.Equal(gotSCSt.SubConn, sc2, cmp.AllowUnexported(testutils.TestSubConn{})) { - t.Fatalf("picker.Pick, got %v, want SubConn=%v", gotSCSt, sc2) - } + if err := testRoundRobinPickerFromCh(cc.NewPickerCh, []balancer.SubConn{sc2}); err != nil { + t.Fatal(err) } + } // When a higher priority becomes available, all lower priorities are closed. // // Init 0,1,2; 0 and 1 down, use 2; 0 up, close 1 and 2. func (s) TestEDSPriority_HigherReadyCloseAllLower(t *testing.T) { - cc := testutils.NewTestClientConn(t) - edsb := newEDSBalancerImpl(cc, balancer.BuildOptions{}, nil, nil, nil) - edsb.enqueueChildBalancerStateUpdate = edsb.updateState - + edsb, cc, xdsC, cleanup := setupTestEDS(t, nil) + defer cleanup() // Two localities, with priorities [0,1,2], each with one backend. 
clab1 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil) clab1.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[:1], nil) clab1.AddLocality(testSubZones[1], 1, 1, testEndpointAddrs[1:2], nil) clab1.AddLocality(testSubZones[2], 1, 2, testEndpointAddrs[2:3], nil) - edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab1.Build())) - + xdsC.InvokeWatchEDSCallback("", parseEDSRespProtoForTesting(clab1.Build()), nil) addrs0 := <-cc.NewSubConnAddrsCh if got, want := addrs0[0].Addr, testEndpointAddrs[0]; got != want { t.Fatalf("sc is created with addr %v, want %v", got, want) @@ -288,39 +261,55 @@ func (s) TestEDSPriority_HigherReadyCloseAllLower(t *testing.T) { sc0 := <-cc.NewSubConnCh // Turn down 0, 1 is used. - edsb.handleSubConnStateChange(sc0, connectivity.TransientFailure) + edsb.UpdateSubConnState(sc0, balancer.SubConnState{ConnectivityState: connectivity.TransientFailure}) addrs1 := <-cc.NewSubConnAddrsCh if got, want := addrs1[0].Addr, testEndpointAddrs[1]; got != want { t.Fatalf("sc is created with addr %v, want %v", got, want) } sc1 := <-cc.NewSubConnCh // Turn down 1, 2 is used. - edsb.handleSubConnStateChange(sc1, connectivity.TransientFailure) + edsb.UpdateSubConnState(sc1, balancer.SubConnState{ConnectivityState: connectivity.TransientFailure}) addrs2 := <-cc.NewSubConnAddrsCh if got, want := addrs2[0].Addr, testEndpointAddrs[2]; got != want { t.Fatalf("sc is created with addr %v, want %v", got, want) } sc2 := <-cc.NewSubConnCh - edsb.handleSubConnStateChange(sc2, connectivity.Connecting) - edsb.handleSubConnStateChange(sc2, connectivity.Ready) + edsb.UpdateSubConnState(sc2, balancer.SubConnState{ConnectivityState: connectivity.Connecting}) + edsb.UpdateSubConnState(sc2, balancer.SubConnState{ConnectivityState: connectivity.Ready}) // Test pick with 2. 
- p2 := <-cc.NewPickerCh - for i := 0; i < 5; i++ { - gotSCSt, _ := p2.Pick(balancer.PickInfo{}) - if !cmp.Equal(gotSCSt.SubConn, sc2, cmp.AllowUnexported(testutils.TestSubConn{})) { - t.Fatalf("picker.Pick, got %v, want SubConn=%v", gotSCSt, sc2) - } + if err := testRoundRobinPickerFromCh(cc.NewPickerCh, []balancer.SubConn{sc2}); err != nil { + t.Fatal(err) } // When 0 becomes ready, 0 should be used, 1 and 2 should all be closed. - edsb.handleSubConnStateChange(sc0, connectivity.Ready) + edsb.UpdateSubConnState(sc0, balancer.SubConnState{ConnectivityState: connectivity.Ready}) + var ( + scToRemove []balancer.SubConn + scToRemoveMap = make(map[balancer.SubConn]struct{}) + ) + // Each subconn is removed twice. This is OK in production, but it makes + // testing harder. + // + // The sub-balancer to be closed is priority's child, clusterimpl, who has + // weightedtarget as children. + // + // - When clusterimpl is removed from priority's balancergroup, all its + // subconns are removed once. + // - When clusterimpl is closed, it closes weightedtarget, and this + // weightedtarget's balancer removes all the same subconns again. + for i := 0; i < 4; i++ { + // We expect 2 subconns, so we recv from channel 4 times. + scToRemoveMap[<-cc.RemoveSubConnCh] = struct{}{} + } + for sc := range scToRemoveMap { + scToRemove = append(scToRemove, sc) + } // sc1 and sc2 should be removed. // // With localities caching, the lower priorities are closed after a timeout, // in goroutines. The order is no longer guaranteed. - scToRemove := []balancer.SubConn{<-cc.RemoveSubConnCh, <-cc.RemoveSubConnCh} if !(cmp.Equal(scToRemove[0], sc1, cmp.AllowUnexported(testutils.TestSubConn{})) && cmp.Equal(scToRemove[1], sc2, cmp.AllowUnexported(testutils.TestSubConn{}))) && !(cmp.Equal(scToRemove[0], sc2, cmp.AllowUnexported(testutils.TestSubConn{})) && @@ -329,12 +318,8 @@ func (s) TestEDSPriority_HigherReadyCloseAllLower(t *testing.T) { } // Test pick with 0. 
- p0 := <-cc.NewPickerCh - for i := 0; i < 5; i++ { - gotSCSt, _ := p0.Pick(balancer.PickInfo{}) - if !cmp.Equal(gotSCSt.SubConn, sc0, cmp.AllowUnexported(testutils.TestSubConn{})) { - t.Fatalf("picker.Pick, got %v, want SubConn=%v", gotSCSt, sc0) - } + if err := testRoundRobinPickerFromCh(cc.NewPickerCh, []balancer.SubConn{sc0}); err != nil { + t.Fatal(err) } } @@ -345,23 +330,20 @@ func (s) TestEDSPriority_HigherReadyCloseAllLower(t *testing.T) { func (s) TestEDSPriority_InitTimeout(t *testing.T) { const testPriorityInitTimeout = time.Second defer func() func() { - old := defaultPriorityInitTimeout - defaultPriorityInitTimeout = testPriorityInitTimeout + old := priority.DefaultPriorityInitTimeout + priority.DefaultPriorityInitTimeout = testPriorityInitTimeout return func() { - defaultPriorityInitTimeout = old + priority.DefaultPriorityInitTimeout = old } }()() - cc := testutils.NewTestClientConn(t) - edsb := newEDSBalancerImpl(cc, balancer.BuildOptions{}, nil, nil, nil) - edsb.enqueueChildBalancerStateUpdate = edsb.updateState - + edsb, cc, xdsC, cleanup := setupTestEDS(t, nil) + defer cleanup() // Two localities, with different priorities, each with one backend. clab1 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil) clab1.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[:1], nil) clab1.AddLocality(testSubZones[1], 1, 1, testEndpointAddrs[1:2], nil) - edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab1.Build())) - + xdsC.InvokeWatchEDSCallback("", parseEDSRespProtoForTesting(clab1.Build()), nil) addrs0 := <-cc.NewSubConnAddrsCh if got, want := addrs0[0].Addr, testEndpointAddrs[0]; got != want { t.Fatalf("sc is created with addr %v, want %v", got, want) @@ -369,7 +351,7 @@ func (s) TestEDSPriority_InitTimeout(t *testing.T) { sc0 := <-cc.NewSubConnCh // Keep 0 in connecting, 1 will be used after init timeout. 
- edsb.handleSubConnStateChange(sc0, connectivity.Connecting) + edsb.UpdateSubConnState(sc0, balancer.SubConnState{ConnectivityState: connectivity.Connecting}) // Make sure new SubConn is created before timeout. select { @@ -384,16 +366,12 @@ func (s) TestEDSPriority_InitTimeout(t *testing.T) { } sc1 := <-cc.NewSubConnCh - edsb.handleSubConnStateChange(sc1, connectivity.Connecting) - edsb.handleSubConnStateChange(sc1, connectivity.Ready) + edsb.UpdateSubConnState(sc1, balancer.SubConnState{ConnectivityState: connectivity.Connecting}) + edsb.UpdateSubConnState(sc1, balancer.SubConnState{ConnectivityState: connectivity.Ready}) // Test pick with 1. - p1 := <-cc.NewPickerCh - for i := 0; i < 5; i++ { - gotSCSt, _ := p1.Pick(balancer.PickInfo{}) - if !cmp.Equal(gotSCSt.SubConn, sc1, cmp.AllowUnexported(testutils.TestSubConn{})) { - t.Fatalf("picker.Pick, got %v, want SubConn=%v", gotSCSt, sc1) - } + if err := testRoundRobinPickerFromCh(cc.NewPickerCh, []balancer.SubConn{sc1}); err != nil { + t.Fatal(err) } } @@ -402,51 +380,44 @@ func (s) TestEDSPriority_InitTimeout(t *testing.T) { // - start with 2 locality with p0 and p1 // - add localities to existing p0 and p1 func (s) TestEDSPriority_MultipleLocalities(t *testing.T) { - cc := testutils.NewTestClientConn(t) - edsb := newEDSBalancerImpl(cc, balancer.BuildOptions{}, nil, nil, nil) - edsb.enqueueChildBalancerStateUpdate = edsb.updateState - + edsb, cc, xdsC, cleanup := setupTestEDS(t, nil) + defer cleanup() // Two localities, with different priorities, each with one backend. 
clab0 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil) clab0.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[:1], nil) clab0.AddLocality(testSubZones[1], 1, 1, testEndpointAddrs[1:2], nil) - edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab0.Build())) - + xdsC.InvokeWatchEDSCallback("", parseEDSRespProtoForTesting(clab0.Build()), nil) addrs0 := <-cc.NewSubConnAddrsCh if got, want := addrs0[0].Addr, testEndpointAddrs[0]; got != want { t.Fatalf("sc is created with addr %v, want %v", got, want) } sc0 := <-cc.NewSubConnCh - edsb.handleSubConnStateChange(sc0, connectivity.Connecting) - edsb.handleSubConnStateChange(sc0, connectivity.Ready) + edsb.UpdateSubConnState(sc0, balancer.SubConnState{ConnectivityState: connectivity.Connecting}) + edsb.UpdateSubConnState(sc0, balancer.SubConnState{ConnectivityState: connectivity.Ready}) // Test roundrobin with only p0 subconns. - p0 := <-cc.NewPickerCh - want := []balancer.SubConn{sc0} - if err := testutils.IsRoundRobin(want, subConnFromPicker(p0)); err != nil { - t.Fatalf("want %v, got %v", want, err) + if err := testRoundRobinPickerFromCh(cc.NewPickerCh, []balancer.SubConn{sc0}); err != nil { + t.Fatal(err) } // Turn down p0 subconns, p1 subconns will be created. - edsb.handleSubConnStateChange(sc0, connectivity.TransientFailure) + edsb.UpdateSubConnState(sc0, balancer.SubConnState{ConnectivityState: connectivity.TransientFailure}) addrs1 := <-cc.NewSubConnAddrsCh if got, want := addrs1[0].Addr, testEndpointAddrs[1]; got != want { t.Fatalf("sc is created with addr %v, want %v", got, want) } sc1 := <-cc.NewSubConnCh - edsb.handleSubConnStateChange(sc1, connectivity.Connecting) - edsb.handleSubConnStateChange(sc1, connectivity.Ready) + edsb.UpdateSubConnState(sc1, balancer.SubConnState{ConnectivityState: connectivity.Connecting}) + edsb.UpdateSubConnState(sc1, balancer.SubConnState{ConnectivityState: connectivity.Ready}) // Test roundrobin with only p1 subconns. 
- p1 := <-cc.NewPickerCh - want = []balancer.SubConn{sc1} - if err := testutils.IsRoundRobin(want, subConnFromPicker(p1)); err != nil { - t.Fatalf("want %v, got %v", want, err) + if err := testRoundRobinPickerFromCh(cc.NewPickerCh, []balancer.SubConn{sc1}); err != nil { + t.Fatal(err) } // Reconnect p0 subconns, p1 subconn will be closed. - edsb.handleSubConnStateChange(sc0, connectivity.Ready) + edsb.UpdateSubConnState(sc0, balancer.SubConnState{ConnectivityState: connectivity.Ready}) scToRemove := <-cc.RemoveSubConnCh if !cmp.Equal(scToRemove, sc1, cmp.AllowUnexported(testutils.TestSubConn{})) { @@ -454,10 +425,8 @@ func (s) TestEDSPriority_MultipleLocalities(t *testing.T) { } // Test roundrobin with only p0 subconns. - p2 := <-cc.NewPickerCh - want = []balancer.SubConn{sc0} - if err := testutils.IsRoundRobin(want, subConnFromPicker(p2)); err != nil { - t.Fatalf("want %v, got %v", want, err) + if err := testRoundRobinPickerFromCh(cc.NewPickerCh, []balancer.SubConn{sc0}); err != nil { + t.Fatal(err) } // Add two localities, with two priorities, with one backend. 
@@ -466,39 +435,34 @@ func (s) TestEDSPriority_MultipleLocalities(t *testing.T) { clab1.AddLocality(testSubZones[1], 1, 1, testEndpointAddrs[1:2], nil) clab1.AddLocality(testSubZones[2], 1, 0, testEndpointAddrs[2:3], nil) clab1.AddLocality(testSubZones[3], 1, 1, testEndpointAddrs[3:4], nil) - edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab1.Build())) - + xdsC.InvokeWatchEDSCallback("", parseEDSRespProtoForTesting(clab1.Build()), nil) addrs2 := <-cc.NewSubConnAddrsCh if got, want := addrs2[0].Addr, testEndpointAddrs[2]; got != want { t.Fatalf("sc is created with addr %v, want %v", got, want) } sc2 := <-cc.NewSubConnCh - edsb.handleSubConnStateChange(sc2, connectivity.Connecting) - edsb.handleSubConnStateChange(sc2, connectivity.Ready) + edsb.UpdateSubConnState(sc2, balancer.SubConnState{ConnectivityState: connectivity.Connecting}) + edsb.UpdateSubConnState(sc2, balancer.SubConnState{ConnectivityState: connectivity.Ready}) // Test roundrobin with only two p0 subconns. - p3 := <-cc.NewPickerCh - want = []balancer.SubConn{sc0, sc2} - if err := testutils.IsRoundRobin(want, subConnFromPicker(p3)); err != nil { - t.Fatalf("want %v, got %v", want, err) + if err := testRoundRobinPickerFromCh(cc.NewPickerCh, []balancer.SubConn{sc0, sc2}); err != nil { + t.Fatal(err) } // Turn down p0 subconns, p1 subconns will be created. 
- edsb.handleSubConnStateChange(sc0, connectivity.TransientFailure) - edsb.handleSubConnStateChange(sc2, connectivity.TransientFailure) + edsb.UpdateSubConnState(sc0, balancer.SubConnState{ConnectivityState: connectivity.TransientFailure}) + edsb.UpdateSubConnState(sc2, balancer.SubConnState{ConnectivityState: connectivity.TransientFailure}) sc3 := <-cc.NewSubConnCh - edsb.handleSubConnStateChange(sc3, connectivity.Connecting) - edsb.handleSubConnStateChange(sc3, connectivity.Ready) + edsb.UpdateSubConnState(sc3, balancer.SubConnState{ConnectivityState: connectivity.Connecting}) + edsb.UpdateSubConnState(sc3, balancer.SubConnState{ConnectivityState: connectivity.Ready}) sc4 := <-cc.NewSubConnCh - edsb.handleSubConnStateChange(sc4, connectivity.Connecting) - edsb.handleSubConnStateChange(sc4, connectivity.Ready) + edsb.UpdateSubConnState(sc4, balancer.SubConnState{ConnectivityState: connectivity.Connecting}) + edsb.UpdateSubConnState(sc4, balancer.SubConnState{ConnectivityState: connectivity.Ready}) // Test roundrobin with only p1 subconns. 
- p4 := <-cc.NewPickerCh - want = []balancer.SubConn{sc3, sc4} - if err := testutils.IsRoundRobin(want, subConnFromPicker(p4)); err != nil { - t.Fatalf("want %v, got %v", want, err) + if err := testRoundRobinPickerFromCh(cc.NewPickerCh, []balancer.SubConn{sc3, sc4}); err != nil { + t.Fatal(err) } } @@ -506,62 +470,55 @@ func (s) TestEDSPriority_MultipleLocalities(t *testing.T) { func (s) TestEDSPriority_RemovesAllLocalities(t *testing.T) { const testPriorityInitTimeout = time.Second defer func() func() { - old := defaultPriorityInitTimeout - defaultPriorityInitTimeout = testPriorityInitTimeout + old := priority.DefaultPriorityInitTimeout + priority.DefaultPriorityInitTimeout = testPriorityInitTimeout return func() { - defaultPriorityInitTimeout = old + priority.DefaultPriorityInitTimeout = old } }()() - cc := testutils.NewTestClientConn(t) - edsb := newEDSBalancerImpl(cc, balancer.BuildOptions{}, nil, nil, nil) - edsb.enqueueChildBalancerStateUpdate = edsb.updateState - + edsb, cc, xdsC, cleanup := setupTestEDS(t, nil) + defer cleanup() // Two localities, with different priorities, each with one backend. 
clab0 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil) clab0.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[:1], nil) clab0.AddLocality(testSubZones[1], 1, 1, testEndpointAddrs[1:2], nil) - edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab0.Build())) - + xdsC.InvokeWatchEDSCallback("", parseEDSRespProtoForTesting(clab0.Build()), nil) addrs0 := <-cc.NewSubConnAddrsCh if got, want := addrs0[0].Addr, testEndpointAddrs[0]; got != want { t.Fatalf("sc is created with addr %v, want %v", got, want) } sc0 := <-cc.NewSubConnCh - edsb.handleSubConnStateChange(sc0, connectivity.Connecting) - edsb.handleSubConnStateChange(sc0, connectivity.Ready) + edsb.UpdateSubConnState(sc0, balancer.SubConnState{ConnectivityState: connectivity.Connecting}) + edsb.UpdateSubConnState(sc0, balancer.SubConnState{ConnectivityState: connectivity.Ready}) // Test roundrobin with only p0 subconns. - p0 := <-cc.NewPickerCh - want := []balancer.SubConn{sc0} - if err := testutils.IsRoundRobin(want, subConnFromPicker(p0)); err != nil { - t.Fatalf("want %v, got %v", want, err) + if err := testRoundRobinPickerFromCh(cc.NewPickerCh, []balancer.SubConn{sc0}); err != nil { + t.Fatal(err) } // Remove all priorities. clab1 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil) - edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab1.Build())) - + xdsC.InvokeWatchEDSCallback("", parseEDSRespProtoForTesting(clab1.Build()), nil) // p0 subconn should be removed. scToRemove := <-cc.RemoveSubConnCh + <-cc.RemoveSubConnCh // Drain the duplicate subconn removed. if !cmp.Equal(scToRemove, sc0, cmp.AllowUnexported(testutils.TestSubConn{})) { t.Fatalf("RemoveSubConn, want %v, got %v", sc0, scToRemove) } + // time.Sleep(time.Second) + // Test pick return TransientFailure. 
- pFail := <-cc.NewPickerCh - for i := 0; i < 5; i++ { - if _, err := pFail.Pick(balancer.PickInfo{}); err != errAllPrioritiesRemoved { - t.Fatalf("want pick error %v, got %v", errAllPrioritiesRemoved, err) - } + if err := testErrPickerFromCh(cc.NewPickerCh, priority.ErrAllPrioritiesRemoved); err != nil { + t.Fatal(err) } // Re-add two localities, with previous priorities, but different backends. clab2 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil) clab2.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[2:3], nil) clab2.AddLocality(testSubZones[1], 1, 1, testEndpointAddrs[3:4], nil) - edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab2.Build())) - + xdsC.InvokeWatchEDSCallback("", parseEDSRespProtoForTesting(clab2.Build()), nil) addrs01 := <-cc.NewSubConnAddrsCh if got, want := addrs01[0].Addr, testEndpointAddrs[2]; got != want { t.Fatalf("sc is created with addr %v, want %v", got, want) @@ -578,45 +535,39 @@ func (s) TestEDSPriority_RemovesAllLocalities(t *testing.T) { t.Fatalf("sc is created with addr %v, want %v", got, want) } sc11 := <-cc.NewSubConnCh - edsb.handleSubConnStateChange(sc11, connectivity.Connecting) - edsb.handleSubConnStateChange(sc11, connectivity.Ready) + edsb.UpdateSubConnState(sc11, balancer.SubConnState{ConnectivityState: connectivity.Connecting}) + edsb.UpdateSubConnState(sc11, balancer.SubConnState{ConnectivityState: connectivity.Ready}) // Test roundrobin with only p1 subconns. - p1 := <-cc.NewPickerCh - want = []balancer.SubConn{sc11} - if err := testutils.IsRoundRobin(want, subConnFromPicker(p1)); err != nil { - t.Fatalf("want %v, got %v", want, err) + if err := testRoundRobinPickerFromCh(cc.NewPickerCh, []balancer.SubConn{sc11}); err != nil { + t.Fatal(err) } // Remove p1 from EDS, to fallback to p0. 
clab3 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil) clab3.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[2:3], nil) - edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab3.Build())) + xdsC.InvokeWatchEDSCallback("", parseEDSRespProtoForTesting(clab3.Build()), nil) // p1 subconn should be removed. scToRemove1 := <-cc.RemoveSubConnCh + <-cc.RemoveSubConnCh // Drain the duplicate subconn removed. if !cmp.Equal(scToRemove1, sc11, cmp.AllowUnexported(testutils.TestSubConn{})) { t.Fatalf("RemoveSubConn, want %v, got %v", sc11, scToRemove1) } // Test pick return TransientFailure. - pFail1 := <-cc.NewPickerCh - for i := 0; i < 5; i++ { - if scst, err := pFail1.Pick(balancer.PickInfo{}); err != balancer.ErrNoSubConnAvailable { - t.Fatalf("want pick error _, %v, got %v, _ ,%v", balancer.ErrTransientFailure, scst, err) - } + if err := testErrPickerFromCh(cc.NewPickerCh, balancer.ErrNoSubConnAvailable); err != nil { + t.Fatal(err) } // Send an ready update for the p0 sc that was received when re-adding // localities to EDS. - edsb.handleSubConnStateChange(sc01, connectivity.Connecting) - edsb.handleSubConnStateChange(sc01, connectivity.Ready) + edsb.UpdateSubConnState(sc01, balancer.SubConnState{ConnectivityState: connectivity.Connecting}) + edsb.UpdateSubConnState(sc01, balancer.SubConnState{ConnectivityState: connectivity.Ready}) // Test roundrobin with only p0 subconns. 
- p2 := <-cc.NewPickerCh - want = []balancer.SubConn{sc01} - if err := testutils.IsRoundRobin(want, subConnFromPicker(p2)); err != nil { - t.Fatalf("want %v, got %v", want, err) + if err := testRoundRobinPickerFromCh(cc.NewPickerCh, []balancer.SubConn{sc01}); err != nil { + t.Fatal(err) } select { @@ -630,83 +581,16 @@ func (s) TestEDSPriority_RemovesAllLocalities(t *testing.T) { } } -func (s) TestPriorityType(t *testing.T) { - p0 := newPriorityType(0) - p1 := newPriorityType(1) - p2 := newPriorityType(2) - - if !p0.higherThan(p1) || !p0.higherThan(p2) { - t.Errorf("want p0 to be higher than p1 and p2, got p0>p1: %v, p0>p2: %v", !p0.higherThan(p1), !p0.higherThan(p2)) - } - if !p1.lowerThan(p0) || !p1.higherThan(p2) { - t.Errorf("want p1 to be between p0 and p2, got p1p2: %v", !p1.lowerThan(p0), !p1.higherThan(p2)) - } - if !p2.lowerThan(p0) || !p2.lowerThan(p1) { - t.Errorf("want p2 to be lower than p0 and p1, got p2 subConnMu, but this is implicit via - // balancers (starting balancer with next priority while holding priorityMu, - // and the balancer may create new SubConn). - - priorityMu sync.Mutex - // priorities are pointers, and will be nil when EDS returns empty result. - priorityInUse priorityType - priorityLowest priorityType - priorityToState map[priorityType]*balancer.State - // The timer to give a priority 10 seconds to connect. And if the priority - // doesn't go into Ready/Failure, start the next priority. - // - // One timer is enough because there can be at most one priority in init - // state. - priorityInitTimer *time.Timer - - subConnMu sync.Mutex - subConnToPriority map[balancer.SubConn]priorityType - - pickerMu sync.Mutex - dropConfig []xdsclient.OverloadDropConfig - drops []*dropper - innerState balancer.State // The state of the picker without drop support. - serviceRequestsCounter *client.ServiceRequestsCounter - serviceRequestCountMax uint32 -} - -// newEDSBalancerImpl create a new edsBalancerImpl. 
-func newEDSBalancerImpl(cc balancer.ClientConn, bOpts balancer.BuildOptions, enqueueState func(priorityType, balancer.State), lr load.PerClusterReporter, logger *grpclog.PrefixLogger) *edsBalancerImpl { - edsImpl := &edsBalancerImpl{ - cc: cc, - buildOpts: bOpts, - logger: logger, - subBalancerBuilder: balancer.Get(roundrobin.Name), - loadReporter: lr, - - enqueueChildBalancerStateUpdate: enqueueState, - - priorityToLocalities: make(map[priorityType]*balancerGroupWithConfig), - priorityToState: make(map[priorityType]*balancer.State), - subConnToPriority: make(map[balancer.SubConn]priorityType), - serviceRequestCountMax: defaultServiceRequestCountMax, - } - // Don't start balancer group here. Start it when handling the first EDS - // response. Otherwise the balancer group will be started with round-robin, - // and if users specify a different sub-balancer, all balancers in balancer - // group will be closed and recreated when sub-balancer update happens. - return edsImpl -} - -// handleChildPolicy updates the child balancers handling endpoints. Child -// policy is roundrobin by default. If the specified balancer is not installed, -// the old child balancer will be used. -// -// HandleChildPolicy and HandleEDSResponse must be called by the same goroutine. 
-func (edsImpl *edsBalancerImpl) handleChildPolicy(name string, config json.RawMessage) { - if edsImpl.subBalancerBuilder.Name() == name { - return - } - newSubBalancerBuilder := balancer.Get(name) - if newSubBalancerBuilder == nil { - edsImpl.logger.Infof("edsBalancerImpl: failed to find balancer with name %q, keep using %q", name, edsImpl.subBalancerBuilder.Name()) - return - } - edsImpl.subBalancerBuilder = newSubBalancerBuilder - for _, bgwc := range edsImpl.priorityToLocalities { - if bgwc == nil { - continue - } - for lid, config := range bgwc.configs { - lidJSON, err := lid.ToString() - if err != nil { - edsImpl.logger.Errorf("failed to marshal LocalityID: %#v, skipping this locality", lid) - continue - } - // TODO: (eds) add support to balancer group to support smoothly - // switching sub-balancers (keep old balancer around until new - // balancer becomes ready). - bgwc.bg.Remove(lidJSON) - bgwc.bg.Add(lidJSON, edsImpl.subBalancerBuilder) - bgwc.bg.UpdateClientConnState(lidJSON, balancer.ClientConnState{ - ResolverState: resolver.State{Addresses: config.addrs}, - }) - // This doesn't need to manually update picker, because the new - // sub-balancer will send it's picker later. - } - } -} - -// updateDrops compares new drop policies with the old. If they are different, -// it updates the drop policies and send ClientConn an updated picker. -func (edsImpl *edsBalancerImpl) updateDrops(dropConfig []xdsclient.OverloadDropConfig) { - if cmp.Equal(dropConfig, edsImpl.dropConfig) { - return - } - edsImpl.pickerMu.Lock() - edsImpl.dropConfig = dropConfig - var newDrops []*dropper - for _, c := range edsImpl.dropConfig { - newDrops = append(newDrops, newDropper(c)) - } - edsImpl.drops = newDrops - if edsImpl.innerState.Picker != nil { - // Update picker with old inner picker, new drops. 
- edsImpl.cc.UpdateState(balancer.State{ - ConnectivityState: edsImpl.innerState.ConnectivityState, - Picker: newDropPicker(edsImpl.innerState.Picker, newDrops, edsImpl.loadReporter, edsImpl.serviceRequestsCounter, edsImpl.serviceRequestCountMax)}, - ) - } - edsImpl.pickerMu.Unlock() -} - -// handleEDSResponse handles the EDS response and creates/deletes localities and -// SubConns. It also handles drops. -// -// HandleChildPolicy and HandleEDSResponse must be called by the same goroutine. -func (edsImpl *edsBalancerImpl) handleEDSResponse(edsResp xdsclient.EndpointsUpdate) { - // TODO: Unhandled fields from EDS response: - // - edsResp.GetPolicy().GetOverprovisioningFactor() - // - locality.GetPriority() - // - lbEndpoint.GetMetadata(): contains BNS name, send to sub-balancers - // - as service config or as resolved address - // - if socketAddress is not ip:port - // - socketAddress.GetNamedPort(), socketAddress.GetResolverName() - // - resolve endpoint's name with another resolver - - // If the first EDS update is an empty update, nothing is changing from the - // previous update (which is the default empty value). We need to explicitly - // handle first update being empty, and send a transient failure picker. - // - // TODO: define Equal() on type EndpointUpdate to avoid DeepEqual. And do - // the same for the other types. - if !edsImpl.respReceived && reflect.DeepEqual(edsResp, xdsclient.EndpointsUpdate{}) { - edsImpl.cc.UpdateState(balancer.State{ConnectivityState: connectivity.TransientFailure, Picker: base.NewErrPicker(errAllPrioritiesRemoved)}) - } - edsImpl.respReceived = true - - edsImpl.updateDrops(edsResp.Drops) - - // Filter out all localities with weight 0. - // - // Locality weighted load balancer can be enabled by setting an option in - // CDS, and the weight of each locality. Currently, without the guarantee - // that CDS is always sent, we assume locality weighted load balance is - // always enabled, and ignore all weight 0 localities. 
- // - // In the future, we should look at the config in CDS response and decide - // whether locality weight matters. - newLocalitiesWithPriority := make(map[priorityType][]xdsclient.Locality) - for _, locality := range edsResp.Localities { - if locality.Weight == 0 { - continue - } - priority := newPriorityType(locality.Priority) - newLocalitiesWithPriority[priority] = append(newLocalitiesWithPriority[priority], locality) - } - - var ( - priorityLowest priorityType - priorityChanged bool - ) - - for priority, newLocalities := range newLocalitiesWithPriority { - if !priorityLowest.isSet() || priorityLowest.higherThan(priority) { - priorityLowest = priority - } - - bgwc, ok := edsImpl.priorityToLocalities[priority] - if !ok { - // Create balancer group if it's never created (this is the first - // time this priority is received). We don't start it here. It may - // be started when necessary (e.g. when higher is down, or if it's a - // new lowest priority). - ccPriorityWrapper := edsImpl.ccWrapperWithPriority(priority) - stateAggregator := weightedaggregator.New(ccPriorityWrapper, edsImpl.logger, newRandomWRR) - bgwc = &balancerGroupWithConfig{ - bg: balancergroup.New(ccPriorityWrapper, edsImpl.buildOpts, stateAggregator, edsImpl.loadReporter, edsImpl.logger), - stateAggregator: stateAggregator, - configs: make(map[internal.LocalityID]*localityConfig), - } - edsImpl.priorityToLocalities[priority] = bgwc - priorityChanged = true - edsImpl.logger.Infof("New priority %v added", priority) - } - edsImpl.handleEDSResponsePerPriority(bgwc, newLocalities) - } - edsImpl.priorityLowest = priorityLowest - - // Delete priorities that are removed in the latest response, and also close - // the balancer group. 
- for p, bgwc := range edsImpl.priorityToLocalities { - if _, ok := newLocalitiesWithPriority[p]; !ok { - delete(edsImpl.priorityToLocalities, p) - bgwc.bg.Close() - delete(edsImpl.priorityToState, p) - priorityChanged = true - edsImpl.logger.Infof("Priority %v deleted", p) - } - } - - // If priority was added/removed, it may affect the balancer group to use. - // E.g. priorityInUse was removed, or all priorities are down, and a new - // lower priority was added. - if priorityChanged { - edsImpl.handlePriorityChange() - } -} - -func (edsImpl *edsBalancerImpl) handleEDSResponsePerPriority(bgwc *balancerGroupWithConfig, newLocalities []xdsclient.Locality) { - // newLocalitiesSet contains all names of localities in the new EDS response - // for the same priority. It's used to delete localities that are removed in - // the new EDS response. - newLocalitiesSet := make(map[internal.LocalityID]struct{}) - var rebuildStateAndPicker bool - for _, locality := range newLocalities { - // One balancer for each locality. - - lid := locality.ID - lidJSON, err := lid.ToString() - if err != nil { - edsImpl.logger.Errorf("failed to marshal LocalityID: %#v, skipping this locality", lid) - continue - } - newLocalitiesSet[lid] = struct{}{} - - newWeight := locality.Weight - var newAddrs []resolver.Address - for _, lbEndpoint := range locality.Endpoints { - // Filter out all "unhealthy" endpoints (unknown and - // healthy are both considered to be healthy: - // https://www.envoyproxy.io/docs/envoy/latest/api-v2/api/v2/core/health_check.proto#envoy-api-enum-core-healthstatus). 
- if lbEndpoint.HealthStatus != xdsclient.EndpointHealthStatusHealthy && - lbEndpoint.HealthStatus != xdsclient.EndpointHealthStatusUnknown { - continue - } - - address := resolver.Address{ - Addr: lbEndpoint.Address, - } - if edsImpl.subBalancerBuilder.Name() == weightedroundrobin.Name && lbEndpoint.Weight != 0 { - ai := weightedroundrobin.AddrInfo{Weight: lbEndpoint.Weight} - address = weightedroundrobin.SetAddrInfo(address, ai) - // Metadata field in resolver.Address is deprecated. The - // attributes field should be used to specify arbitrary - // attributes about the address. We still need to populate the - // Metadata field here to allow users of this field to migrate - // to the new one. - // TODO(easwars): Remove this once all users have migrated. - // See https://github.com/grpc/grpc-go/issues/3563. - address.Metadata = &ai - } - newAddrs = append(newAddrs, address) - } - var weightChanged, addrsChanged bool - config, ok := bgwc.configs[lid] - if !ok { - // A new balancer, add it to balancer group and balancer map. - bgwc.stateAggregator.Add(lidJSON, newWeight) - bgwc.bg.Add(lidJSON, edsImpl.subBalancerBuilder) - config = &localityConfig{ - weight: newWeight, - } - bgwc.configs[lid] = config - - // weightChanged is false for new locality, because there's no need - // to update weight in bg. - addrsChanged = true - edsImpl.logger.Infof("New locality %v added", lid) - } else { - // Compare weight and addrs. 
- if config.weight != newWeight { - weightChanged = true - } - if !cmp.Equal(config.addrs, newAddrs) { - addrsChanged = true - } - edsImpl.logger.Infof("Locality %v updated, weightedChanged: %v, addrsChanged: %v", lid, weightChanged, addrsChanged) - } - - if weightChanged { - config.weight = newWeight - bgwc.stateAggregator.UpdateWeight(lidJSON, newWeight) - rebuildStateAndPicker = true - } - - if addrsChanged { - config.addrs = newAddrs - bgwc.bg.UpdateClientConnState(lidJSON, balancer.ClientConnState{ - ResolverState: resolver.State{Addresses: newAddrs}, - }) - } - } - - // Delete localities that are removed in the latest response. - for lid := range bgwc.configs { - lidJSON, err := lid.ToString() - if err != nil { - edsImpl.logger.Errorf("failed to marshal LocalityID: %#v, skipping this locality", lid) - continue - } - if _, ok := newLocalitiesSet[lid]; !ok { - bgwc.stateAggregator.Remove(lidJSON) - bgwc.bg.Remove(lidJSON) - delete(bgwc.configs, lid) - edsImpl.logger.Infof("Locality %v deleted", lid) - rebuildStateAndPicker = true - } - } - - if rebuildStateAndPicker { - bgwc.stateAggregator.BuildAndUpdate() - } -} - -// handleSubConnStateChange handles the state change and update pickers accordingly. -func (edsImpl *edsBalancerImpl) handleSubConnStateChange(sc balancer.SubConn, s connectivity.State) { - edsImpl.subConnMu.Lock() - var bgwc *balancerGroupWithConfig - if p, ok := edsImpl.subConnToPriority[sc]; ok { - if s == connectivity.Shutdown { - // Only delete sc from the map when state changed to Shutdown. - delete(edsImpl.subConnToPriority, sc) - } - bgwc = edsImpl.priorityToLocalities[p] - } - edsImpl.subConnMu.Unlock() - if bgwc == nil { - edsImpl.logger.Infof("edsBalancerImpl: priority not found for sc state change") - return - } - if bg := bgwc.bg; bg != nil { - bg.UpdateSubConnState(sc, balancer.SubConnState{ConnectivityState: s}) - } -} - -// updateServiceRequestsConfig handles changes to the circuit breaking configuration. 
-func (edsImpl *edsBalancerImpl) updateServiceRequestsConfig(serviceName string, max *uint32) { - if !env.CircuitBreakingSupport { - return - } - edsImpl.pickerMu.Lock() - var updatePicker bool - if edsImpl.serviceRequestsCounter == nil || edsImpl.serviceRequestsCounter.ServiceName != serviceName { - edsImpl.serviceRequestsCounter = client.GetServiceRequestsCounter(serviceName) - updatePicker = true - } - - var newMax uint32 = defaultServiceRequestCountMax - if max != nil { - newMax = *max - } - if edsImpl.serviceRequestCountMax != newMax { - edsImpl.serviceRequestCountMax = newMax - updatePicker = true - } - if updatePicker && edsImpl.innerState.Picker != nil { - // Update picker with old inner picker, new counter and counterMax. - edsImpl.cc.UpdateState(balancer.State{ - ConnectivityState: edsImpl.innerState.ConnectivityState, - Picker: newDropPicker(edsImpl.innerState.Picker, edsImpl.drops, edsImpl.loadReporter, edsImpl.serviceRequestsCounter, edsImpl.serviceRequestCountMax)}, - ) - } - edsImpl.pickerMu.Unlock() -} - -// updateState first handles priority, and then wraps picker in a drop picker -// before forwarding the update. -func (edsImpl *edsBalancerImpl) updateState(priority priorityType, s balancer.State) { - _, ok := edsImpl.priorityToLocalities[priority] - if !ok { - edsImpl.logger.Infof("eds: received picker update from unknown priority") - return - } - - if edsImpl.handlePriorityWithNewState(priority, s) { - edsImpl.pickerMu.Lock() - defer edsImpl.pickerMu.Unlock() - edsImpl.innerState = s - // Don't reset drops when it's a state change. 
- edsImpl.cc.UpdateState(balancer.State{ConnectivityState: s.ConnectivityState, Picker: newDropPicker(s.Picker, edsImpl.drops, edsImpl.loadReporter, edsImpl.serviceRequestsCounter, edsImpl.serviceRequestCountMax)}) - } -} - -func (edsImpl *edsBalancerImpl) ccWrapperWithPriority(priority priorityType) *edsBalancerWrapperCC { - return &edsBalancerWrapperCC{ - ClientConn: edsImpl.cc, - priority: priority, - parent: edsImpl, - } -} - -// edsBalancerWrapperCC implements the balancer.ClientConn API and get passed to -// each balancer group. It contains the locality priority. -type edsBalancerWrapperCC struct { - balancer.ClientConn - priority priorityType - parent *edsBalancerImpl -} - -func (ebwcc *edsBalancerWrapperCC) NewSubConn(addrs []resolver.Address, opts balancer.NewSubConnOptions) (balancer.SubConn, error) { - return ebwcc.parent.newSubConn(ebwcc.priority, addrs, opts) -} -func (ebwcc *edsBalancerWrapperCC) UpdateState(state balancer.State) { - ebwcc.parent.enqueueChildBalancerStateUpdate(ebwcc.priority, state) -} - -func (edsImpl *edsBalancerImpl) newSubConn(priority priorityType, addrs []resolver.Address, opts balancer.NewSubConnOptions) (balancer.SubConn, error) { - sc, err := edsImpl.cc.NewSubConn(addrs, opts) - if err != nil { - return nil, err - } - edsImpl.subConnMu.Lock() - edsImpl.subConnToPriority[sc] = priority - edsImpl.subConnMu.Unlock() - return sc, nil -} - -// close closes the balancer. 
-func (edsImpl *edsBalancerImpl) close() { - for _, bgwc := range edsImpl.priorityToLocalities { - if bg := bgwc.bg; bg != nil { - bgwc.stateAggregator.Stop() - bg.Close() - } - } -} - -type dropPicker struct { - drops []*dropper - p balancer.Picker - loadStore load.PerClusterReporter - counter *client.ServiceRequestsCounter - countMax uint32 -} - -func newDropPicker(p balancer.Picker, drops []*dropper, loadStore load.PerClusterReporter, counter *client.ServiceRequestsCounter, countMax uint32) *dropPicker { - return &dropPicker{ - drops: drops, - p: p, - loadStore: loadStore, - counter: counter, - countMax: countMax, - } -} - -func (d *dropPicker) Pick(info balancer.PickInfo) (balancer.PickResult, error) { - var ( - drop bool - category string - ) - for _, dp := range d.drops { - if dp.drop() { - drop = true - category = dp.c.Category - break - } - } - if drop { - if d.loadStore != nil { - d.loadStore.CallDropped(category) - } - return balancer.PickResult{}, status.Errorf(codes.Unavailable, "RPC is dropped") - } - if d.counter != nil { - if err := d.counter.StartRequest(d.countMax); err != nil { - // Drops by circuit breaking are reported with empty category. They - // will be reported only in total drops, but not in per category. - if d.loadStore != nil { - d.loadStore.CallDropped("") - } - return balancer.PickResult{}, status.Errorf(codes.Unavailable, err.Error()) - } - pr, err := d.p.Pick(info) - if err != nil { - d.counter.EndRequest() - return pr, err - } - oldDone := pr.Done - pr.Done = func(doneInfo balancer.DoneInfo) { - d.counter.EndRequest() - if oldDone != nil { - oldDone(doneInfo) - } - } - return pr, err - } - // TODO: (eds) don't drop unless the inner picker is READY. Similar to - // https://github.com/grpc/grpc-go/issues/2622. 
- return d.p.Pick(info) -} diff --git a/xds/internal/balancer/edsbalancer/eds_impl_priority.go b/xds/internal/balancer/edsbalancer/eds_impl_priority.go deleted file mode 100644 index 53ac6ef5e87..00000000000 --- a/xds/internal/balancer/edsbalancer/eds_impl_priority.go +++ /dev/null @@ -1,358 +0,0 @@ -/* - * - * Copyright 2019 gRPC authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package edsbalancer - -import ( - "errors" - "fmt" - "time" - - "google.golang.org/grpc/balancer" - "google.golang.org/grpc/balancer/base" - "google.golang.org/grpc/connectivity" -) - -var errAllPrioritiesRemoved = errors.New("eds: no locality is provided, all priorities are removed") - -// handlePriorityChange handles priority after EDS adds/removes a -// priority. -// -// - If all priorities were deleted, unset priorityInUse, and set parent -// ClientConn to TransientFailure -// - If priorityInUse wasn't set, this is either the first EDS resp, or the -// previous EDS resp deleted everything. Set priorityInUse to 0, and start 0. -// - If priorityInUse was deleted, send the picker from the new lowest priority -// to parent ClientConn, and set priorityInUse to the new lowest. -// - If priorityInUse has a non-Ready state, and also there's a priority lower -// than priorityInUse (which means a lower priority was added), set the next -// priority as new priorityInUse, and start the bg. 
-func (edsImpl *edsBalancerImpl) handlePriorityChange() { - edsImpl.priorityMu.Lock() - defer edsImpl.priorityMu.Unlock() - - // Everything was removed by EDS. - if !edsImpl.priorityLowest.isSet() { - edsImpl.priorityInUse = newPriorityTypeUnset() - // Stop the init timer. This can happen if the only priority is removed - // shortly after it's added. - if timer := edsImpl.priorityInitTimer; timer != nil { - timer.Stop() - edsImpl.priorityInitTimer = nil - } - edsImpl.cc.UpdateState(balancer.State{ConnectivityState: connectivity.TransientFailure, Picker: base.NewErrPicker(errAllPrioritiesRemoved)}) - return - } - - // priorityInUse wasn't set, use 0. - if !edsImpl.priorityInUse.isSet() { - edsImpl.logger.Infof("Switching priority from unset to %v", 0) - edsImpl.startPriority(newPriorityType(0)) - return - } - - // priorityInUse was deleted, use the new lowest. - if _, ok := edsImpl.priorityToLocalities[edsImpl.priorityInUse]; !ok { - oldP := edsImpl.priorityInUse - edsImpl.priorityInUse = edsImpl.priorityLowest - edsImpl.logger.Infof("Switching priority from %v to %v, because former was deleted", oldP, edsImpl.priorityInUse) - if s, ok := edsImpl.priorityToState[edsImpl.priorityLowest]; ok { - edsImpl.cc.UpdateState(*s) - } else { - // If state for priorityLowest is not found, this means priorityLowest was - // started, but never sent any update. The init timer fired and - // triggered the next priority. The old_priorityInUse (that was just - // deleted EDS) was picked later. - // - // We don't have an old state to send to parent, but we also don't - // want parent to keep using picker from old_priorityInUse. Send an - // update to trigger block picks until a new picker is ready. - edsImpl.cc.UpdateState(balancer.State{ConnectivityState: connectivity.Connecting, Picker: base.NewErrPicker(balancer.ErrNoSubConnAvailable)}) - } - return - } - - // priorityInUse is not ready, look for next priority, and use if found. 
- if s, ok := edsImpl.priorityToState[edsImpl.priorityInUse]; ok && s.ConnectivityState != connectivity.Ready { - pNext := edsImpl.priorityInUse.nextLower() - if _, ok := edsImpl.priorityToLocalities[pNext]; ok { - edsImpl.logger.Infof("Switching priority from %v to %v, because latter was added, and former wasn't Ready") - edsImpl.startPriority(pNext) - } - } -} - -// startPriority sets priorityInUse to p, and starts the balancer group for p. -// It also starts a timer to fall to next priority after timeout. -// -// Caller must hold priorityMu, priority must exist, and edsImpl.priorityInUse -// must be non-nil. -func (edsImpl *edsBalancerImpl) startPriority(priority priorityType) { - edsImpl.priorityInUse = priority - p := edsImpl.priorityToLocalities[priority] - // NOTE: this will eventually send addresses to sub-balancers. If the - // sub-balancer tries to update picker, it will result in a deadlock on - // priorityMu in the update is handled synchronously. The deadlock is - // currently avoided by handling balancer update in a goroutine (the run - // goroutine in the parent eds balancer). When priority balancer is split - // into its own, this asynchronous state handling needs to be copied. - p.stateAggregator.Start() - p.bg.Start() - // startPriority can be called when - // 1. first EDS resp, start p0 - // 2. a high priority goes Failure, start next - // 3. a high priority init timeout, start next - // - // In all the cases, the existing init timer is either closed, also already - // expired. There's no need to close the old timer. 
- edsImpl.priorityInitTimer = time.AfterFunc(defaultPriorityInitTimeout, func() { - edsImpl.priorityMu.Lock() - defer edsImpl.priorityMu.Unlock() - if !edsImpl.priorityInUse.isSet() || !edsImpl.priorityInUse.equal(priority) { - return - } - edsImpl.priorityInitTimer = nil - pNext := priority.nextLower() - if _, ok := edsImpl.priorityToLocalities[pNext]; ok { - edsImpl.startPriority(pNext) - } - }) -} - -// handlePriorityWithNewState start/close priorities based on the connectivity -// state. It returns whether the state should be forwarded to parent ClientConn. -func (edsImpl *edsBalancerImpl) handlePriorityWithNewState(priority priorityType, s balancer.State) bool { - edsImpl.priorityMu.Lock() - defer edsImpl.priorityMu.Unlock() - - if !edsImpl.priorityInUse.isSet() { - edsImpl.logger.Infof("eds: received picker update when no priority is in use (EDS returned an empty list)") - return false - } - - if edsImpl.priorityInUse.higherThan(priority) { - // Lower priorities should all be closed, this is an unexpected update. - edsImpl.logger.Infof("eds: received picker update from priority lower then priorityInUse") - return false - } - - bState, ok := edsImpl.priorityToState[priority] - if !ok { - bState = &balancer.State{} - edsImpl.priorityToState[priority] = bState - } - oldState := bState.ConnectivityState - *bState = s - - switch s.ConnectivityState { - case connectivity.Ready: - return edsImpl.handlePriorityWithNewStateReady(priority) - case connectivity.TransientFailure: - return edsImpl.handlePriorityWithNewStateTransientFailure(priority) - case connectivity.Connecting: - return edsImpl.handlePriorityWithNewStateConnecting(priority, oldState) - default: - // New state is Idle, should never happen. Don't forward. - return false - } -} - -// handlePriorityWithNewStateReady handles state Ready and decides whether to -// forward update or not. 
-// -// An update with state Ready: -// - If it's from higher priority: -// - Forward the update -// - Set the priority as priorityInUse -// - Close all priorities lower than this one -// - If it's from priorityInUse: -// - Forward and do nothing else -// -// Caller must make sure priorityInUse is not higher than priority. -// -// Caller must hold priorityMu. -func (edsImpl *edsBalancerImpl) handlePriorityWithNewStateReady(priority priorityType) bool { - // If one priority higher or equal to priorityInUse goes Ready, stop the - // init timer. If update is from higher than priorityInUse, - // priorityInUse will be closed, and the init timer will become useless. - if timer := edsImpl.priorityInitTimer; timer != nil { - timer.Stop() - edsImpl.priorityInitTimer = nil - } - - if edsImpl.priorityInUse.lowerThan(priority) { - edsImpl.logger.Infof("Switching priority from %v to %v, because latter became Ready", edsImpl.priorityInUse, priority) - edsImpl.priorityInUse = priority - for i := priority.nextLower(); !i.lowerThan(edsImpl.priorityLowest); i = i.nextLower() { - bgwc := edsImpl.priorityToLocalities[i] - bgwc.stateAggregator.Stop() - bgwc.bg.Close() - } - return true - } - return true -} - -// handlePriorityWithNewStateTransientFailure handles state TransientFailure and -// decides whether to forward update or not. -// -// An update with state Failure: -// - If it's from a higher priority: -// - Do not forward, and do nothing -// - If it's from priorityInUse: -// - If there's no lower: -// - Forward and do nothing else -// - If there's a lower priority: -// - Forward -// - Set lower as priorityInUse -// - Start lower -// -// Caller must make sure priorityInUse is not higher than priority. -// -// Caller must hold priorityMu. -func (edsImpl *edsBalancerImpl) handlePriorityWithNewStateTransientFailure(priority priorityType) bool { - if edsImpl.priorityInUse.lowerThan(priority) { - return false - } - // priorityInUse sends a failure. Stop its init timer. 
- if timer := edsImpl.priorityInitTimer; timer != nil { - timer.Stop() - edsImpl.priorityInitTimer = nil - } - pNext := priority.nextLower() - if _, okNext := edsImpl.priorityToLocalities[pNext]; !okNext { - return true - } - edsImpl.logger.Infof("Switching priority from %v to %v, because former became TransientFailure", priority, pNext) - edsImpl.startPriority(pNext) - return true -} - -// handlePriorityWithNewStateConnecting handles state Connecting and decides -// whether to forward update or not. -// -// An update with state Connecting: -// - If it's from a higher priority -// - Do nothing -// - If it's from priorityInUse, the behavior depends on previous state. -// -// When new state is Connecting, the behavior depends on previous state. If the -// previous state was Ready, this is a transition out from Ready to Connecting. -// Assuming there are multiple backends in the same priority, this mean we are -// in a bad situation and we should failover to the next priority (Side note: -// the current connectivity state aggregating algorhtim (e.g. round-robin) is -// not handling this right, because if many backends all go from Ready to -// Connecting, the overall situation is more like TransientFailure, not -// Connecting). -// -// If the previous state was Idle, we don't do anything special with failure, -// and simply forward the update. The init timer should be in process, will -// handle failover if it timeouts. If the previous state was TransientFailure, -// we do not forward, because the lower priority is in use. -// -// Caller must make sure priorityInUse is not higher than priority. -// -// Caller must hold priorityMu. 
-func (edsImpl *edsBalancerImpl) handlePriorityWithNewStateConnecting(priority priorityType, oldState connectivity.State) bool { - if edsImpl.priorityInUse.lowerThan(priority) { - return false - } - - switch oldState { - case connectivity.Ready: - pNext := priority.nextLower() - if _, okNext := edsImpl.priorityToLocalities[pNext]; !okNext { - return true - } - edsImpl.logger.Infof("Switching priority from %v to %v, because former became Connecting from Ready", priority, pNext) - edsImpl.startPriority(pNext) - return true - case connectivity.Idle: - return true - case connectivity.TransientFailure: - return false - default: - // Old state is Connecting or Shutdown. Don't forward. - return false - } -} - -// priorityType represents the priority from EDS response. -// -// 0 is the highest priority. The bigger the number, the lower the priority. -type priorityType struct { - set bool - p uint32 -} - -func newPriorityType(p uint32) priorityType { - return priorityType{ - set: true, - p: p, - } -} - -func newPriorityTypeUnset() priorityType { - return priorityType{} -} - -func (p priorityType) isSet() bool { - return p.set -} - -func (p priorityType) equal(p2 priorityType) bool { - if !p.isSet() && !p2.isSet() { - return true - } - if !p.isSet() || !p2.isSet() { - return false - } - return p == p2 -} - -func (p priorityType) higherThan(p2 priorityType) bool { - if !p.isSet() || !p2.isSet() { - // TODO(menghanl): return an appropriate value instead of panic. - panic("priority unset") - } - return p.p < p2.p -} - -func (p priorityType) lowerThan(p2 priorityType) bool { - if !p.isSet() || !p2.isSet() { - // TODO(menghanl): return an appropriate value instead of panic. 
- panic("priority unset") - } - return p.p > p2.p -} - -func (p priorityType) nextLower() priorityType { - if !p.isSet() { - panic("priority unset") - } - return priorityType{ - set: true, - p: p.p + 1, - } -} - -func (p priorityType) String() string { - if !p.set { - return "Nil" - } - return fmt.Sprint(p.p) -} diff --git a/xds/internal/balancer/edsbalancer/eds_impl_test.go b/xds/internal/balancer/edsbalancer/eds_impl_test.go deleted file mode 100644 index 3334376f4a9..00000000000 --- a/xds/internal/balancer/edsbalancer/eds_impl_test.go +++ /dev/null @@ -1,935 +0,0 @@ -/* - * Copyright 2019 gRPC authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package edsbalancer - -import ( - "fmt" - "reflect" - "sort" - "testing" - "time" - - corepb "github.com/envoyproxy/go-control-plane/envoy/api/v2/core" - "github.com/google/go-cmp/cmp" - "github.com/google/go-cmp/cmp/cmpopts" - - "google.golang.org/grpc/balancer" - "google.golang.org/grpc/balancer/roundrobin" - "google.golang.org/grpc/connectivity" - "google.golang.org/grpc/internal/balancer/stub" - "google.golang.org/grpc/xds/internal" - "google.golang.org/grpc/xds/internal/balancer/balancergroup" - "google.golang.org/grpc/xds/internal/client" - xdsclient "google.golang.org/grpc/xds/internal/client" - "google.golang.org/grpc/xds/internal/client/load" - "google.golang.org/grpc/xds/internal/env" - "google.golang.org/grpc/xds/internal/testutils" -) - -var ( - testClusterNames = []string{"test-cluster-1", "test-cluster-2"} - testSubZones = []string{"I", "II", "III", "IV"} - testEndpointAddrs []string -) - -const testBackendAddrsCount = 12 - -func init() { - for i := 0; i < testBackendAddrsCount; i++ { - testEndpointAddrs = append(testEndpointAddrs, fmt.Sprintf("%d.%d.%d.%d:%d", i, i, i, i, i)) - } - balancergroup.DefaultSubBalancerCloseTimeout = time.Millisecond -} - -// One locality -// - add backend -// - remove backend -// - replace backend -// - change drop rate -func (s) TestEDS_OneLocality(t *testing.T) { - cc := testutils.NewTestClientConn(t) - edsb := newEDSBalancerImpl(cc, balancer.BuildOptions{}, nil, nil, nil) - edsb.enqueueChildBalancerStateUpdate = edsb.updateState - - // One locality with one backend. - clab1 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil) - clab1.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[:1], nil) - edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab1.Build())) - - sc1 := <-cc.NewSubConnCh - edsb.handleSubConnStateChange(sc1, connectivity.Connecting) - edsb.handleSubConnStateChange(sc1, connectivity.Ready) - - // Pick with only the first backend. 
- p1 := <-cc.NewPickerCh - for i := 0; i < 5; i++ { - gotSCSt, _ := p1.Pick(balancer.PickInfo{}) - if !cmp.Equal(gotSCSt.SubConn, sc1, cmp.AllowUnexported(testutils.TestSubConn{})) { - t.Fatalf("picker.Pick, got %v, want SubConn=%v", gotSCSt, sc1) - } - } - - // The same locality, add one more backend. - clab2 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil) - clab2.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[:2], nil) - edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab2.Build())) - - sc2 := <-cc.NewSubConnCh - edsb.handleSubConnStateChange(sc2, connectivity.Connecting) - edsb.handleSubConnStateChange(sc2, connectivity.Ready) - - // Test roundrobin with two subconns. - p2 := <-cc.NewPickerCh - want := []balancer.SubConn{sc1, sc2} - if err := testutils.IsRoundRobin(want, subConnFromPicker(p2)); err != nil { - t.Fatalf("want %v, got %v", want, err) - } - - // The same locality, delete first backend. - clab3 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil) - clab3.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[1:2], nil) - edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab3.Build())) - - scToRemove := <-cc.RemoveSubConnCh - if !cmp.Equal(scToRemove, sc1, cmp.AllowUnexported(testutils.TestSubConn{})) { - t.Fatalf("RemoveSubConn, want %v, got %v", sc1, scToRemove) - } - edsb.handleSubConnStateChange(scToRemove, connectivity.Shutdown) - - // Test pick with only the second subconn. - p3 := <-cc.NewPickerCh - for i := 0; i < 5; i++ { - gotSCSt, _ := p3.Pick(balancer.PickInfo{}) - if !cmp.Equal(gotSCSt.SubConn, sc2, cmp.AllowUnexported(testutils.TestSubConn{})) { - t.Fatalf("picker.Pick, got %v, want SubConn=%v", gotSCSt, sc2) - } - } - - // The same locality, replace backend. 
- clab4 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil) - clab4.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[2:3], nil) - edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab4.Build())) - - sc3 := <-cc.NewSubConnCh - edsb.handleSubConnStateChange(sc3, connectivity.Connecting) - edsb.handleSubConnStateChange(sc3, connectivity.Ready) - scToRemove = <-cc.RemoveSubConnCh - if !cmp.Equal(scToRemove, sc2, cmp.AllowUnexported(testutils.TestSubConn{})) { - t.Fatalf("RemoveSubConn, want %v, got %v", sc2, scToRemove) - } - edsb.handleSubConnStateChange(scToRemove, connectivity.Shutdown) - - // Test pick with only the third subconn. - p4 := <-cc.NewPickerCh - for i := 0; i < 5; i++ { - gotSCSt, _ := p4.Pick(balancer.PickInfo{}) - if !cmp.Equal(gotSCSt.SubConn, sc3, cmp.AllowUnexported(testutils.TestSubConn{})) { - t.Fatalf("picker.Pick, got %v, want SubConn=%v", gotSCSt, sc3) - } - } - - // The same locality, different drop rate, dropping 50%. - clab5 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], map[string]uint32{"test-drop": 50}) - clab5.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[2:3], nil) - edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab5.Build())) - - // Picks with drops. - p5 := <-cc.NewPickerCh - for i := 0; i < 100; i++ { - _, err := p5.Pick(balancer.PickInfo{}) - // TODO: the dropping algorithm needs a design. When the dropping algorithm - // is fixed, this test also needs fix. - if i < 50 && err == nil { - t.Errorf("The first 50%% picks should be drops, got error ") - } else if i > 50 && err != nil { - t.Errorf("The second 50%% picks should be non-drops, got error %v", err) - } - } - - // The same locality, remove drops. - clab6 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil) - clab6.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[2:3], nil) - edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab6.Build())) - - // Pick without drops. 
- p6 := <-cc.NewPickerCh - for i := 0; i < 5; i++ { - gotSCSt, _ := p6.Pick(balancer.PickInfo{}) - if !cmp.Equal(gotSCSt.SubConn, sc3, cmp.AllowUnexported(testutils.TestSubConn{})) { - t.Fatalf("picker.Pick, got %v, want SubConn=%v", gotSCSt, sc3) - } - } -} - -// 2 locality -// - start with 2 locality -// - add locality -// - remove locality -// - address change for the locality -// - update locality weight -func (s) TestEDS_TwoLocalities(t *testing.T) { - cc := testutils.NewTestClientConn(t) - edsb := newEDSBalancerImpl(cc, balancer.BuildOptions{}, nil, nil, nil) - edsb.enqueueChildBalancerStateUpdate = edsb.updateState - - // Two localities, each with one backend. - clab1 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil) - clab1.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[:1], nil) - edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab1.Build())) - sc1 := <-cc.NewSubConnCh - edsb.handleSubConnStateChange(sc1, connectivity.Connecting) - edsb.handleSubConnStateChange(sc1, connectivity.Ready) - - // Add the second locality later to make sure sc2 belongs to the second - // locality. Otherwise the test is flaky because of a map is used in EDS to - // keep localities. - clab1.AddLocality(testSubZones[1], 1, 0, testEndpointAddrs[1:2], nil) - edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab1.Build())) - sc2 := <-cc.NewSubConnCh - edsb.handleSubConnStateChange(sc2, connectivity.Connecting) - edsb.handleSubConnStateChange(sc2, connectivity.Ready) - - // Test roundrobin with two subconns. - p1 := <-cc.NewPickerCh - want := []balancer.SubConn{sc1, sc2} - if err := testutils.IsRoundRobin(want, subConnFromPicker(p1)); err != nil { - t.Fatalf("want %v, got %v", want, err) - } - - // Add another locality, with one backend. 
- clab2 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil) - clab2.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[:1], nil) - clab2.AddLocality(testSubZones[1], 1, 0, testEndpointAddrs[1:2], nil) - clab2.AddLocality(testSubZones[2], 1, 0, testEndpointAddrs[2:3], nil) - edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab2.Build())) - - sc3 := <-cc.NewSubConnCh - edsb.handleSubConnStateChange(sc3, connectivity.Connecting) - edsb.handleSubConnStateChange(sc3, connectivity.Ready) - - // Test roundrobin with three subconns. - p2 := <-cc.NewPickerCh - want = []balancer.SubConn{sc1, sc2, sc3} - if err := testutils.IsRoundRobin(want, subConnFromPicker(p2)); err != nil { - t.Fatalf("want %v, got %v", want, err) - } - - // Remove first locality. - clab3 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil) - clab3.AddLocality(testSubZones[1], 1, 0, testEndpointAddrs[1:2], nil) - clab3.AddLocality(testSubZones[2], 1, 0, testEndpointAddrs[2:3], nil) - edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab3.Build())) - - scToRemove := <-cc.RemoveSubConnCh - if !cmp.Equal(scToRemove, sc1, cmp.AllowUnexported(testutils.TestSubConn{})) { - t.Fatalf("RemoveSubConn, want %v, got %v", sc1, scToRemove) - } - edsb.handleSubConnStateChange(scToRemove, connectivity.Shutdown) - - // Test pick with two subconns (without the first one). - p3 := <-cc.NewPickerCh - want = []balancer.SubConn{sc2, sc3} - if err := testutils.IsRoundRobin(want, subConnFromPicker(p3)); err != nil { - t.Fatalf("want %v, got %v", want, err) - } - - // Add a backend to the last locality. 
- clab4 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil) - clab4.AddLocality(testSubZones[1], 1, 0, testEndpointAddrs[1:2], nil) - clab4.AddLocality(testSubZones[2], 1, 0, testEndpointAddrs[2:4], nil) - edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab4.Build())) - - sc4 := <-cc.NewSubConnCh - edsb.handleSubConnStateChange(sc4, connectivity.Connecting) - edsb.handleSubConnStateChange(sc4, connectivity.Ready) - - // Test pick with two subconns (without the first one). - p4 := <-cc.NewPickerCh - // Locality-1 will be picked twice, and locality-2 will be picked twice. - // Locality-1 contains only sc2, locality-2 contains sc3 and sc4. So expect - // two sc2's and sc3, sc4. - want = []balancer.SubConn{sc2, sc2, sc3, sc4} - if err := testutils.IsRoundRobin(want, subConnFromPicker(p4)); err != nil { - t.Fatalf("want %v, got %v", want, err) - } - - // Change weight of the locality[1]. - clab5 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil) - clab5.AddLocality(testSubZones[1], 2, 0, testEndpointAddrs[1:2], nil) - clab5.AddLocality(testSubZones[2], 1, 0, testEndpointAddrs[2:4], nil) - edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab5.Build())) - - // Test pick with two subconns different locality weight. - p5 := <-cc.NewPickerCh - // Locality-1 will be picked four times, and locality-2 will be picked twice - // (weight 2 and 1). Locality-1 contains only sc2, locality-2 contains sc3 and - // sc4. So expect four sc2's and sc3, sc4. - want = []balancer.SubConn{sc2, sc2, sc2, sc2, sc3, sc4} - if err := testutils.IsRoundRobin(want, subConnFromPicker(p5)); err != nil { - t.Fatalf("want %v, got %v", want, err) - } - - // Change weight of the locality[1] to 0, it should never be picked. 
- clab6 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil) - clab6.AddLocality(testSubZones[1], 0, 0, testEndpointAddrs[1:2], nil) - clab6.AddLocality(testSubZones[2], 1, 0, testEndpointAddrs[2:4], nil) - edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab6.Build())) - - // Changing weight of locality[1] to 0 caused it to be removed. It's subconn - // should also be removed. - // - // NOTE: this is because we handle locality with weight 0 same as the - // locality doesn't exist. If this changes in the future, this removeSubConn - // behavior will also change. - scToRemove2 := <-cc.RemoveSubConnCh - if !cmp.Equal(scToRemove2, sc2, cmp.AllowUnexported(testutils.TestSubConn{})) { - t.Fatalf("RemoveSubConn, want %v, got %v", sc2, scToRemove2) - } - - // Test pick with two subconns different locality weight. - p6 := <-cc.NewPickerCh - // Locality-1 will be not be picked, and locality-2 will be picked. - // Locality-2 contains sc3 and sc4. So expect sc3, sc4. - want = []balancer.SubConn{sc3, sc4} - if err := testutils.IsRoundRobin(want, subConnFromPicker(p6)); err != nil { - t.Fatalf("want %v, got %v", want, err) - } -} - -// The EDS balancer gets EDS resp with unhealthy endpoints. Test that only -// healthy ones are used. -func (s) TestEDS_EndpointsHealth(t *testing.T) { - cc := testutils.NewTestClientConn(t) - edsb := newEDSBalancerImpl(cc, balancer.BuildOptions{}, nil, nil, nil) - edsb.enqueueChildBalancerStateUpdate = edsb.updateState - - // Two localities, each 3 backend, one Healthy, one Unhealthy, one Unknown. 
- clab1 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil) - clab1.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[:6], &testutils.AddLocalityOptions{ - Health: []corepb.HealthStatus{ - corepb.HealthStatus_HEALTHY, - corepb.HealthStatus_UNHEALTHY, - corepb.HealthStatus_UNKNOWN, - corepb.HealthStatus_DRAINING, - corepb.HealthStatus_TIMEOUT, - corepb.HealthStatus_DEGRADED, - }, - }) - clab1.AddLocality(testSubZones[1], 1, 0, testEndpointAddrs[6:12], &testutils.AddLocalityOptions{ - Health: []corepb.HealthStatus{ - corepb.HealthStatus_HEALTHY, - corepb.HealthStatus_UNHEALTHY, - corepb.HealthStatus_UNKNOWN, - corepb.HealthStatus_DRAINING, - corepb.HealthStatus_TIMEOUT, - corepb.HealthStatus_DEGRADED, - }, - }) - edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab1.Build())) - - var ( - readySCs []balancer.SubConn - newSubConnAddrStrs []string - ) - for i := 0; i < 4; i++ { - addr := <-cc.NewSubConnAddrsCh - newSubConnAddrStrs = append(newSubConnAddrStrs, addr[0].Addr) - sc := <-cc.NewSubConnCh - edsb.handleSubConnStateChange(sc, connectivity.Connecting) - edsb.handleSubConnStateChange(sc, connectivity.Ready) - readySCs = append(readySCs, sc) - } - - wantNewSubConnAddrStrs := []string{ - testEndpointAddrs[0], - testEndpointAddrs[2], - testEndpointAddrs[6], - testEndpointAddrs[8], - } - sortStrTrans := cmp.Transformer("Sort", func(in []string) []string { - out := append([]string(nil), in...) // Copy input to avoid mutating it. - sort.Strings(out) - return out - }) - if !cmp.Equal(newSubConnAddrStrs, wantNewSubConnAddrStrs, sortStrTrans) { - t.Fatalf("want newSubConn with address %v, got %v", wantNewSubConnAddrStrs, newSubConnAddrStrs) - } - - // There should be exactly 4 new SubConns. Check to make sure there's no - // more subconns being created. - select { - case <-cc.NewSubConnCh: - t.Fatalf("Got unexpected new subconn") - case <-time.After(time.Microsecond * 100): - } - - // Test roundrobin with the subconns. 
- p1 := <-cc.NewPickerCh - want := readySCs - if err := testutils.IsRoundRobin(want, subConnFromPicker(p1)); err != nil { - t.Fatalf("want %v, got %v", want, err) - } -} - -func (s) TestClose(t *testing.T) { - edsb := newEDSBalancerImpl(nil, balancer.BuildOptions{}, nil, nil, nil) - // This is what could happen when switching between fallback and eds. This - // make sure it doesn't panic. - edsb.close() -} - -// TestEDS_EmptyUpdate covers the cases when eds impl receives an empty update. -// -// It should send an error picker with transient failure to the parent. -func (s) TestEDS_EmptyUpdate(t *testing.T) { - cc := testutils.NewTestClientConn(t) - edsb := newEDSBalancerImpl(cc, balancer.BuildOptions{}, nil, nil, nil) - edsb.enqueueChildBalancerStateUpdate = edsb.updateState - - // The first update is an empty update. - edsb.handleEDSResponse(xdsclient.EndpointsUpdate{}) - // Pick should fail with transient failure, and all priority removed error. - perr0 := <-cc.NewPickerCh - for i := 0; i < 5; i++ { - _, err := perr0.Pick(balancer.PickInfo{}) - if !reflect.DeepEqual(err, errAllPrioritiesRemoved) { - t.Fatalf("picker.Pick, got error %v, want error %v", err, errAllPrioritiesRemoved) - } - } - - // One locality with one backend. - clab1 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil) - clab1.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[:1], nil) - edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab1.Build())) - - sc1 := <-cc.NewSubConnCh - edsb.handleSubConnStateChange(sc1, connectivity.Connecting) - edsb.handleSubConnStateChange(sc1, connectivity.Ready) - - // Pick with only the first backend. 
- p1 := <-cc.NewPickerCh - for i := 0; i < 5; i++ { - gotSCSt, _ := p1.Pick(balancer.PickInfo{}) - if !reflect.DeepEqual(gotSCSt.SubConn, sc1) { - t.Fatalf("picker.Pick, got %v, want SubConn=%v", gotSCSt, sc1) - } - } - - edsb.handleEDSResponse(xdsclient.EndpointsUpdate{}) - // Pick should fail with transient failure, and all priority removed error. - perr1 := <-cc.NewPickerCh - for i := 0; i < 5; i++ { - _, err := perr1.Pick(balancer.PickInfo{}) - if !reflect.DeepEqual(err, errAllPrioritiesRemoved) { - t.Fatalf("picker.Pick, got error %v, want error %v", err, errAllPrioritiesRemoved) - } - } - - // Handle another update with priorities and localities. - edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab1.Build())) - - sc2 := <-cc.NewSubConnCh - edsb.handleSubConnStateChange(sc2, connectivity.Connecting) - edsb.handleSubConnStateChange(sc2, connectivity.Ready) - - // Pick with only the first backend. - p2 := <-cc.NewPickerCh - for i := 0; i < 5; i++ { - gotSCSt, _ := p2.Pick(balancer.PickInfo{}) - if !reflect.DeepEqual(gotSCSt.SubConn, sc2) { - t.Fatalf("picker.Pick, got %v, want SubConn=%v", gotSCSt, sc2) - } - } -} - -// Create XDS balancer, and update sub-balancer before handling eds responses. -// Then switch between round-robin and a test stub-balancer after handling first -// eds response. 
-func (s) TestEDS_UpdateSubBalancerName(t *testing.T) { - const balancerName = "stubBalancer-TestEDS_UpdateSubBalancerName" - - cc := testutils.NewTestClientConn(t) - stub.Register(balancerName, stub.BalancerFuncs{ - UpdateClientConnState: func(bd *stub.BalancerData, s balancer.ClientConnState) error { - if len(s.ResolverState.Addresses) == 0 { - return nil - } - bd.ClientConn.NewSubConn(s.ResolverState.Addresses, balancer.NewSubConnOptions{}) - return nil - }, - UpdateSubConnState: func(bd *stub.BalancerData, sc balancer.SubConn, state balancer.SubConnState) { - bd.ClientConn.UpdateState(balancer.State{ - ConnectivityState: state.ConnectivityState, - Picker: &testutils.TestConstPicker{Err: testutils.ErrTestConstPicker}, - }) - }, - }) - - edsb := newEDSBalancerImpl(cc, balancer.BuildOptions{}, nil, nil, nil) - edsb.enqueueChildBalancerStateUpdate = edsb.updateState - - t.Logf("update sub-balancer to stub-balancer") - edsb.handleChildPolicy(balancerName, nil) - - // Two localities, each with one backend. 
- clab1 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil) - clab1.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[:1], nil) - clab1.AddLocality(testSubZones[1], 1, 0, testEndpointAddrs[1:2], nil) - edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab1.Build())) - - for i := 0; i < 2; i++ { - sc := <-cc.NewSubConnCh - edsb.handleSubConnStateChange(sc, connectivity.Ready) - } - - p0 := <-cc.NewPickerCh - for i := 0; i < 5; i++ { - _, err := p0.Pick(balancer.PickInfo{}) - if err != testutils.ErrTestConstPicker { - t.Fatalf("picker.Pick, got err %+v, want err %+v", err, testutils.ErrTestConstPicker) - } - } - - t.Logf("update sub-balancer to round-robin") - edsb.handleChildPolicy(roundrobin.Name, nil) - - for i := 0; i < 2; i++ { - <-cc.RemoveSubConnCh - } - - sc1 := <-cc.NewSubConnCh - edsb.handleSubConnStateChange(sc1, connectivity.Connecting) - edsb.handleSubConnStateChange(sc1, connectivity.Ready) - sc2 := <-cc.NewSubConnCh - edsb.handleSubConnStateChange(sc2, connectivity.Connecting) - edsb.handleSubConnStateChange(sc2, connectivity.Ready) - - // Test roundrobin with two subconns. 
- p1 := <-cc.NewPickerCh - want := []balancer.SubConn{sc1, sc2} - if err := testutils.IsRoundRobin(want, subConnFromPicker(p1)); err != nil { - t.Fatalf("want %v, got %v", want, err) - } - - t.Logf("update sub-balancer to stub-balancer") - edsb.handleChildPolicy(balancerName, nil) - - for i := 0; i < 2; i++ { - scToRemove := <-cc.RemoveSubConnCh - if !cmp.Equal(scToRemove, sc1, cmp.AllowUnexported(testutils.TestSubConn{})) && - !cmp.Equal(scToRemove, sc2, cmp.AllowUnexported(testutils.TestSubConn{})) { - t.Fatalf("RemoveSubConn, want (%v or %v), got %v", sc1, sc2, scToRemove) - } - edsb.handleSubConnStateChange(scToRemove, connectivity.Shutdown) - } - - for i := 0; i < 2; i++ { - sc := <-cc.NewSubConnCh - edsb.handleSubConnStateChange(sc, connectivity.Ready) - } - - p2 := <-cc.NewPickerCh - for i := 0; i < 5; i++ { - _, err := p2.Pick(balancer.PickInfo{}) - if err != testutils.ErrTestConstPicker { - t.Fatalf("picker.Pick, got err %q, want err %q", err, testutils.ErrTestConstPicker) - } - } - - t.Logf("update sub-balancer to round-robin") - edsb.handleChildPolicy(roundrobin.Name, nil) - - for i := 0; i < 2; i++ { - <-cc.RemoveSubConnCh - } - - sc3 := <-cc.NewSubConnCh - edsb.handleSubConnStateChange(sc3, connectivity.Connecting) - edsb.handleSubConnStateChange(sc3, connectivity.Ready) - sc4 := <-cc.NewSubConnCh - edsb.handleSubConnStateChange(sc4, connectivity.Connecting) - edsb.handleSubConnStateChange(sc4, connectivity.Ready) - - p3 := <-cc.NewPickerCh - want = []balancer.SubConn{sc3, sc4} - if err := testutils.IsRoundRobin(want, subConnFromPicker(p3)); err != nil { - t.Fatalf("want %v, got %v", want, err) - } -} - -func (s) TestEDS_CircuitBreaking(t *testing.T) { - origCircuitBreakingSupport := env.CircuitBreakingSupport - env.CircuitBreakingSupport = true - defer func() { env.CircuitBreakingSupport = origCircuitBreakingSupport }() - - cc := testutils.NewTestClientConn(t) - edsb := newEDSBalancerImpl(cc, balancer.BuildOptions{}, nil, nil, nil) - 
edsb.enqueueChildBalancerStateUpdate = edsb.updateState - var maxRequests uint32 = 50 - edsb.updateServiceRequestsConfig("test", &maxRequests) - - // One locality with one backend. - clab1 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil) - clab1.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[:1], nil) - edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab1.Build())) - sc1 := <-cc.NewSubConnCh - edsb.handleSubConnStateChange(sc1, connectivity.Connecting) - edsb.handleSubConnStateChange(sc1, connectivity.Ready) - - // Picks with drops. - dones := []func(){} - p := <-cc.NewPickerCh - for i := 0; i < 100; i++ { - pr, err := p.Pick(balancer.PickInfo{}) - if i < 50 && err != nil { - t.Errorf("The first 50%% picks should be non-drops, got error %v", err) - } else if i > 50 && err == nil { - t.Errorf("The second 50%% picks should be drops, got error ") - } - dones = append(dones, func() { - if pr.Done != nil { - pr.Done(balancer.DoneInfo{}) - } - }) - } - - for _, done := range dones { - done() - } - dones = []func(){} - - // Pick without drops. - for i := 0; i < 50; i++ { - pr, err := p.Pick(balancer.PickInfo{}) - if err != nil { - t.Errorf("The third 50%% picks should be non-drops, got error %v", err) - } - dones = append(dones, func() { - if pr.Done != nil { - pr.Done(balancer.DoneInfo{}) - } - }) - } - - // Without this, future tests with the same service name will fail. - for _, done := range dones { - done() - } - - // Send another update, with only circuit breaking update (and no picker - // update afterwards). Make sure the new picker uses the new configs. - var maxRequests2 uint32 = 10 - edsb.updateServiceRequestsConfig("test", &maxRequests2) - - // Picks with drops. 
- dones = []func(){} - p2 := <-cc.NewPickerCh - for i := 0; i < 100; i++ { - pr, err := p2.Pick(balancer.PickInfo{}) - if i < 10 && err != nil { - t.Errorf("The first 10%% picks should be non-drops, got error %v", err) - } else if i > 10 && err == nil { - t.Errorf("The next 90%% picks should be drops, got error ") - } - dones = append(dones, func() { - if pr.Done != nil { - pr.Done(balancer.DoneInfo{}) - } - }) - } - - for _, done := range dones { - done() - } - dones = []func(){} - - // Pick without drops. - for i := 0; i < 10; i++ { - pr, err := p2.Pick(balancer.PickInfo{}) - if err != nil { - t.Errorf("The next 10%% picks should be non-drops, got error %v", err) - } - dones = append(dones, func() { - if pr.Done != nil { - pr.Done(balancer.DoneInfo{}) - } - }) - } - - // Without this, future tests with the same service name will fail. - for _, done := range dones { - done() - } -} - -func init() { - balancer.Register(&testInlineUpdateBalancerBuilder{}) -} - -// A test balancer that updates balancer.State inline when handling ClientConn -// state. 
-type testInlineUpdateBalancerBuilder struct{} - -func (*testInlineUpdateBalancerBuilder) Build(cc balancer.ClientConn, opts balancer.BuildOptions) balancer.Balancer { - return &testInlineUpdateBalancer{cc: cc} -} - -func (*testInlineUpdateBalancerBuilder) Name() string { - return "test-inline-update-balancer" -} - -type testInlineUpdateBalancer struct { - cc balancer.ClientConn -} - -func (tb *testInlineUpdateBalancer) ResolverError(error) { - panic("not implemented") -} - -func (tb *testInlineUpdateBalancer) UpdateSubConnState(balancer.SubConn, balancer.SubConnState) { -} - -var errTestInlineStateUpdate = fmt.Errorf("don't like addresses, empty or not") - -func (tb *testInlineUpdateBalancer) UpdateClientConnState(balancer.ClientConnState) error { - tb.cc.UpdateState(balancer.State{ - ConnectivityState: connectivity.Ready, - Picker: &testutils.TestConstPicker{Err: errTestInlineStateUpdate}, - }) - return nil -} - -func (*testInlineUpdateBalancer) Close() { -} - -// When the child policy update picker inline in a handleClientUpdate call -// (e.g., roundrobin handling empty addresses). There could be deadlock caused -// by acquiring a locked mutex. -func (s) TestEDS_ChildPolicyUpdatePickerInline(t *testing.T) { - cc := testutils.NewTestClientConn(t) - edsb := newEDSBalancerImpl(cc, balancer.BuildOptions{}, nil, nil, nil) - edsb.enqueueChildBalancerStateUpdate = func(p priorityType, state balancer.State) { - // For this test, euqueue needs to happen asynchronously (like in the - // real implementation). 
- go edsb.updateState(p, state) - } - - edsb.handleChildPolicy("test-inline-update-balancer", nil) - - clab1 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil) - clab1.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[:1], nil) - edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab1.Build())) - - p0 := <-cc.NewPickerCh - for i := 0; i < 5; i++ { - _, err := p0.Pick(balancer.PickInfo{}) - if err != errTestInlineStateUpdate { - t.Fatalf("picker.Pick, got err %q, want err %q", err, errTestInlineStateUpdate) - } - } -} - -func (s) TestDropPicker(t *testing.T) { - const pickCount = 12 - var constPicker = &testutils.TestConstPicker{ - SC: testutils.TestSubConns[0], - } - - tests := []struct { - name string - drops []*dropper - }{ - { - name: "no drop", - drops: nil, - }, - { - name: "one drop", - drops: []*dropper{ - newDropper(xdsclient.OverloadDropConfig{Numerator: 1, Denominator: 2}), - }, - }, - { - name: "two drops", - drops: []*dropper{ - newDropper(xdsclient.OverloadDropConfig{Numerator: 1, Denominator: 3}), - newDropper(xdsclient.OverloadDropConfig{Numerator: 1, Denominator: 2}), - }, - }, - { - name: "three drops", - drops: []*dropper{ - newDropper(xdsclient.OverloadDropConfig{Numerator: 1, Denominator: 3}), - newDropper(xdsclient.OverloadDropConfig{Numerator: 1, Denominator: 4}), - newDropper(xdsclient.OverloadDropConfig{Numerator: 1, Denominator: 2}), - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - - p := newDropPicker(constPicker, tt.drops, nil, nil, defaultServiceRequestCountMax) - - // scCount is the number of sc's returned by pick. The opposite of - // drop-count. 
- var ( - scCount int - wantCount = pickCount - ) - for _, dp := range tt.drops { - wantCount = wantCount * int(dp.c.Denominator-dp.c.Numerator) / int(dp.c.Denominator) - } - - for i := 0; i < pickCount; i++ { - _, err := p.Pick(balancer.PickInfo{}) - if err == nil { - scCount++ - } - } - - if scCount != (wantCount) { - t.Errorf("drops: %+v, scCount %v, wantCount %v", tt.drops, scCount, wantCount) - } - }) - } -} - -func (s) TestEDS_LoadReport(t *testing.T) { - origCircuitBreakingSupport := env.CircuitBreakingSupport - env.CircuitBreakingSupport = true - defer func() { env.CircuitBreakingSupport = origCircuitBreakingSupport }() - - // We create an xdsClientWrapper with a dummy xdsClientInterface which only - // implements the LoadStore() method to return the underlying load.Store to - // be used. - loadStore := load.NewStore() - lsWrapper := &loadStoreWrapper{} - lsWrapper.updateServiceName(testClusterNames[0]) - lsWrapper.updateLoadStore(loadStore) - - cc := testutils.NewTestClientConn(t) - edsb := newEDSBalancerImpl(cc, balancer.BuildOptions{}, nil, lsWrapper, nil) - edsb.enqueueChildBalancerStateUpdate = edsb.updateState - - const ( - testServiceName = "test-service" - cbMaxRequests = 20 - ) - var maxRequestsTemp uint32 = cbMaxRequests - edsb.updateServiceRequestsConfig(testServiceName, &maxRequestsTemp) - defer client.ClearCounterForTesting(testServiceName) - - backendToBalancerID := make(map[balancer.SubConn]internal.LocalityID) - - const testDropCategory = "test-drop" - // Two localities, each with one backend. 
- clab1 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], map[string]uint32{testDropCategory: 50}) - clab1.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[:1], nil) - edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab1.Build())) - sc1 := <-cc.NewSubConnCh - edsb.handleSubConnStateChange(sc1, connectivity.Connecting) - edsb.handleSubConnStateChange(sc1, connectivity.Ready) - locality1 := internal.LocalityID{SubZone: testSubZones[0]} - backendToBalancerID[sc1] = locality1 - - // Add the second locality later to make sure sc2 belongs to the second - // locality. Otherwise the test is flaky because of a map is used in EDS to - // keep localities. - clab1.AddLocality(testSubZones[1], 1, 0, testEndpointAddrs[1:2], nil) - edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab1.Build())) - sc2 := <-cc.NewSubConnCh - edsb.handleSubConnStateChange(sc2, connectivity.Connecting) - edsb.handleSubConnStateChange(sc2, connectivity.Ready) - locality2 := internal.LocalityID{SubZone: testSubZones[1]} - backendToBalancerID[sc2] = locality2 - - // Test roundrobin with two subconns. - p1 := <-cc.NewPickerCh - // We expect the 10 picks to be split between the localities since they are - // of equal weight. And since we only mark the picks routed to sc2 as done, - // the picks on sc1 should show up as inProgress. - locality1JSON, _ := locality1.ToString() - locality2JSON, _ := locality2.ToString() - const ( - rpcCount = 100 - // 50% will be dropped with category testDropCategory. - dropWithCategory = rpcCount / 2 - // In the remaining RPCs, only cbMaxRequests are allowed by circuit - // breaking. Others will be dropped by CB. - dropWithCB = rpcCount - dropWithCategory - cbMaxRequests - - rpcInProgress = cbMaxRequests / 2 // 50% of RPCs will be never done. - rpcSucceeded = cbMaxRequests / 2 // 50% of RPCs will succeed. 
- ) - wantStoreData := []*load.Data{{ - Cluster: testClusterNames[0], - Service: "", - LocalityStats: map[string]load.LocalityData{ - locality1JSON: {RequestStats: load.RequestData{InProgress: rpcInProgress}}, - locality2JSON: {RequestStats: load.RequestData{Succeeded: rpcSucceeded}}, - }, - TotalDrops: dropWithCategory + dropWithCB, - Drops: map[string]uint64{ - testDropCategory: dropWithCategory, - }, - }} - - var rpcsToBeDone []balancer.PickResult - // Run the picks, but only pick with sc1 will be done later. - for i := 0; i < rpcCount; i++ { - scst, _ := p1.Pick(balancer.PickInfo{}) - if scst.Done != nil && scst.SubConn != sc1 { - rpcsToBeDone = append(rpcsToBeDone, scst) - } - } - // Call done on those sc1 picks. - for _, scst := range rpcsToBeDone { - scst.Done(balancer.DoneInfo{}) - } - - gotStoreData := loadStore.Stats(testClusterNames[0:1]) - if diff := cmp.Diff(wantStoreData, gotStoreData, cmpopts.EquateEmpty(), cmpopts.IgnoreFields(load.Data{}, "ReportInterval")); diff != "" { - t.Errorf("store.stats() returned unexpected diff (-want +got):\n%s", diff) - } -} - -// TestEDS_LoadReportDisabled covers the case that LRS is disabled. It makes -// sure the EDS implementation isn't broken (doesn't panic). -func (s) TestEDS_LoadReportDisabled(t *testing.T) { - lsWrapper := &loadStoreWrapper{} - lsWrapper.updateServiceName(testClusterNames[0]) - // Not calling lsWrapper.updateLoadStore(loadStore) because LRS is disabled. - - cc := testutils.NewTestClientConn(t) - edsb := newEDSBalancerImpl(cc, balancer.BuildOptions{}, nil, lsWrapper, nil) - edsb.enqueueChildBalancerStateUpdate = edsb.updateState - - // One localities, with one backend. 
- clab1 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil) - clab1.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[:1], nil) - edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab1.Build())) - sc1 := <-cc.NewSubConnCh - edsb.handleSubConnStateChange(sc1, connectivity.Connecting) - edsb.handleSubConnStateChange(sc1, connectivity.Ready) - - // Test roundrobin with two subconns. - p1 := <-cc.NewPickerCh - // We call picks to make sure they don't panic. - for i := 0; i < 10; i++ { - p1.Pick(balancer.PickInfo{}) - } -} diff --git a/xds/internal/balancer/edsbalancer/eds_test.go b/xds/internal/balancer/edsbalancer/eds_test.go deleted file mode 100644 index 5fe1f2ef6b9..00000000000 --- a/xds/internal/balancer/edsbalancer/eds_test.go +++ /dev/null @@ -1,825 +0,0 @@ -/* - * - * Copyright 2019 gRPC authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- * - */ - -package edsbalancer - -import ( - "bytes" - "context" - "encoding/json" - "fmt" - "reflect" - "testing" - "time" - - "github.com/golang/protobuf/jsonpb" - wrapperspb "github.com/golang/protobuf/ptypes/wrappers" - "github.com/google/go-cmp/cmp" - "google.golang.org/grpc/balancer" - "google.golang.org/grpc/connectivity" - "google.golang.org/grpc/internal/grpclog" - "google.golang.org/grpc/internal/grpctest" - scpb "google.golang.org/grpc/internal/proto/grpc_service_config" - "google.golang.org/grpc/internal/testutils" - "google.golang.org/grpc/resolver" - "google.golang.org/grpc/serviceconfig" - "google.golang.org/grpc/xds/internal" - xdsclient "google.golang.org/grpc/xds/internal/client" - "google.golang.org/grpc/xds/internal/client/load" - "google.golang.org/grpc/xds/internal/testutils/fakeclient" - - _ "google.golang.org/grpc/xds/internal/client/v2" // V2 client registration. -) - -const ( - defaultTestTimeout = 1 * time.Second - defaultTestShortTimeout = 10 * time.Millisecond - testServiceName = "test/foo" - testEDSClusterName = "test/service/eds" -) - -var ( - // A non-empty endpoints update which is expected to be accepted by the EDS - // LB policy. 
- defaultEndpointsUpdate = xdsclient.EndpointsUpdate{ - Localities: []xdsclient.Locality{ - { - Endpoints: []xdsclient.Endpoint{{Address: "endpoint1"}}, - ID: internal.LocalityID{Zone: "zone"}, - Priority: 1, - Weight: 100, - }, - }, - } -) - -func init() { - balancer.Register(&edsBalancerBuilder{}) -} - -func subConnFromPicker(p balancer.Picker) func() balancer.SubConn { - return func() balancer.SubConn { - scst, _ := p.Pick(balancer.PickInfo{}) - return scst.SubConn - } -} - -type s struct { - grpctest.Tester -} - -func Test(t *testing.T) { - grpctest.RunSubTests(t, s{}) -} - -const testBalancerNameFooBar = "foo.bar" - -func newNoopTestClientConn() *noopTestClientConn { - return &noopTestClientConn{} -} - -// noopTestClientConn is used in EDS balancer config update tests that only -// cover the config update handling, but not SubConn/load-balancing. -type noopTestClientConn struct { - balancer.ClientConn -} - -func (t *noopTestClientConn) NewSubConn([]resolver.Address, balancer.NewSubConnOptions) (balancer.SubConn, error) { - return nil, nil -} - -func (noopTestClientConn) Target() string { return testServiceName } - -type scStateChange struct { - sc balancer.SubConn - state connectivity.State -} - -type fakeEDSBalancer struct { - cc balancer.ClientConn - childPolicy *testutils.Channel - subconnStateChange *testutils.Channel - edsUpdate *testutils.Channel - serviceName *testutils.Channel - serviceRequestMax *testutils.Channel -} - -func (f *fakeEDSBalancer) handleSubConnStateChange(sc balancer.SubConn, state connectivity.State) { - f.subconnStateChange.Send(&scStateChange{sc: sc, state: state}) -} - -func (f *fakeEDSBalancer) handleChildPolicy(name string, config json.RawMessage) { - f.childPolicy.Send(&loadBalancingConfig{Name: name, Config: config}) -} - -func (f *fakeEDSBalancer) handleEDSResponse(edsResp xdsclient.EndpointsUpdate) { - f.edsUpdate.Send(edsResp) -} - -func (f *fakeEDSBalancer) updateState(priority priorityType, s balancer.State) {} - -func (f 
*fakeEDSBalancer) updateServiceRequestsConfig(serviceName string, max *uint32) { - f.serviceName.Send(serviceName) - f.serviceRequestMax.Send(max) -} - -func (f *fakeEDSBalancer) close() {} - -func (f *fakeEDSBalancer) waitForChildPolicy(ctx context.Context, wantPolicy *loadBalancingConfig) error { - val, err := f.childPolicy.Receive(ctx) - if err != nil { - return err - } - gotPolicy := val.(*loadBalancingConfig) - if !cmp.Equal(gotPolicy, wantPolicy) { - return fmt.Errorf("got childPolicy %v, want %v", gotPolicy, wantPolicy) - } - return nil -} - -func (f *fakeEDSBalancer) waitForSubConnStateChange(ctx context.Context, wantState *scStateChange) error { - val, err := f.subconnStateChange.Receive(ctx) - if err != nil { - return err - } - gotState := val.(*scStateChange) - if !cmp.Equal(gotState, wantState, cmp.AllowUnexported(scStateChange{})) { - return fmt.Errorf("got subconnStateChange %v, want %v", gotState, wantState) - } - return nil -} - -func (f *fakeEDSBalancer) waitForEDSResponse(ctx context.Context, wantUpdate xdsclient.EndpointsUpdate) error { - val, err := f.edsUpdate.Receive(ctx) - if err != nil { - return err - } - gotUpdate := val.(xdsclient.EndpointsUpdate) - if !reflect.DeepEqual(gotUpdate, wantUpdate) { - return fmt.Errorf("got edsUpdate %+v, want %+v", gotUpdate, wantUpdate) - } - return nil -} - -func (f *fakeEDSBalancer) waitForCounterUpdate(ctx context.Context, wantServiceName string) error { - val, err := f.serviceName.Receive(ctx) - if err != nil { - return err - } - gotServiceName := val.(string) - if gotServiceName != wantServiceName { - return fmt.Errorf("got serviceName %v, want %v", gotServiceName, wantServiceName) - } - return nil -} - -func (f *fakeEDSBalancer) waitForCountMaxUpdate(ctx context.Context, want *uint32) error { - val, err := f.serviceRequestMax.Receive(ctx) - if err != nil { - return err - } - got := val.(*uint32) - - if got == nil && want == nil { - return nil - } - if got != nil && want != nil { - if *got != *want { - 
return fmt.Errorf("got countMax %v, want %v", *got, *want) - } - return nil - } - return fmt.Errorf("got countMax %+v, want %+v", got, want) -} - -func newFakeEDSBalancer(cc balancer.ClientConn) edsBalancerImplInterface { - return &fakeEDSBalancer{ - cc: cc, - childPolicy: testutils.NewChannelWithSize(10), - subconnStateChange: testutils.NewChannelWithSize(10), - edsUpdate: testutils.NewChannelWithSize(10), - serviceName: testutils.NewChannelWithSize(10), - serviceRequestMax: testutils.NewChannelWithSize(10), - } -} - -type fakeSubConn struct{} - -func (*fakeSubConn) UpdateAddresses([]resolver.Address) { panic("implement me") } -func (*fakeSubConn) Connect() { panic("implement me") } - -// waitForNewEDSLB makes sure that a new edsLB is created by the top-level -// edsBalancer. -func waitForNewEDSLB(ctx context.Context, ch *testutils.Channel) (*fakeEDSBalancer, error) { - val, err := ch.Receive(ctx) - if err != nil { - return nil, fmt.Errorf("error when waiting for a new edsLB: %v", err) - } - return val.(*fakeEDSBalancer), nil -} - -// setup overrides the functions which are used to create the xdsClient and the -// edsLB, creates fake version of them and makes them available on the provided -// channels. The returned cancel function should be called by the test for -// cleanup. 
-func setup(edsLBCh *testutils.Channel) (*fakeclient.Client, func()) { - xdsC := fakeclient.NewClientWithName(testBalancerNameFooBar) - oldNewXDSClient := newXDSClient - newXDSClient = func() (xdsClientInterface, error) { return xdsC, nil } - - origNewEDSBalancer := newEDSBalancer - newEDSBalancer = func(cc balancer.ClientConn, _ balancer.BuildOptions, _ func(priorityType, balancer.State), _ load.PerClusterReporter, _ *grpclog.PrefixLogger) edsBalancerImplInterface { - edsLB := newFakeEDSBalancer(cc) - defer func() { edsLBCh.Send(edsLB) }() - return edsLB - } - return xdsC, func() { - newEDSBalancer = origNewEDSBalancer - newXDSClient = oldNewXDSClient - } -} - -const ( - fakeBalancerA = "fake_balancer_A" - fakeBalancerB = "fake_balancer_B" -) - -// Install two fake balancers for service config update tests. -// -// ParseConfig only accepts the json if the balancer specified is registered. -func init() { - balancer.Register(&fakeBalancerBuilder{name: fakeBalancerA}) - balancer.Register(&fakeBalancerBuilder{name: fakeBalancerB}) -} - -type fakeBalancerBuilder struct { - name string -} - -func (b *fakeBalancerBuilder) Build(cc balancer.ClientConn, opts balancer.BuildOptions) balancer.Balancer { - return &fakeBalancer{cc: cc} -} - -func (b *fakeBalancerBuilder) Name() string { - return b.name -} - -type fakeBalancer struct { - cc balancer.ClientConn -} - -func (b *fakeBalancer) ResolverError(error) { - panic("implement me") -} - -func (b *fakeBalancer) UpdateClientConnState(balancer.ClientConnState) error { - panic("implement me") -} - -func (b *fakeBalancer) UpdateSubConnState(balancer.SubConn, balancer.SubConnState) { - panic("implement me") -} - -func (b *fakeBalancer) Close() {} - -// TestConfigChildPolicyUpdate verifies scenarios where the childPolicy -// section of the lbConfig is updated. -// -// The test does the following: -// * Builds a new EDS balancer. -// * Pushes a new ClientConnState with a childPolicy set to fakeBalancerA. 
-// Verifies that an EDS watch is registered. It then pushes a new edsUpdate -// through the fakexds client. Verifies that a new edsLB is created and it -// receives the expected childPolicy. -// * Pushes a new ClientConnState with a childPolicy set to fakeBalancerB. -// Verifies that the existing edsLB receives the new child policy. -func (s) TestConfigChildPolicyUpdate(t *testing.T) { - edsLBCh := testutils.NewChannel() - xdsC, cleanup := setup(edsLBCh) - defer cleanup() - - builder := balancer.Get(edsName) - edsB := builder.Build(newNoopTestClientConn(), balancer.BuildOptions{Target: resolver.Target{Endpoint: testServiceName}}) - if edsB == nil { - t.Fatalf("builder.Build(%s) failed and returned nil", edsName) - } - defer edsB.Close() - - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() - edsLB, err := waitForNewEDSLB(ctx, edsLBCh) - if err != nil { - t.Fatal(err) - } - - lbCfgA := &loadBalancingConfig{ - Name: fakeBalancerA, - Config: json.RawMessage("{}"), - } - if err := edsB.UpdateClientConnState(balancer.ClientConnState{ - BalancerConfig: &EDSConfig{ - ChildPolicy: lbCfgA, - EDSServiceName: testServiceName, - }, - }); err != nil { - t.Fatalf("edsB.UpdateClientConnState() failed: %v", err) - } - - if _, err := xdsC.WaitForWatchEDS(ctx); err != nil { - t.Fatalf("xdsClient.WatchEndpoints failed with error: %v", err) - } - xdsC.InvokeWatchEDSCallback(defaultEndpointsUpdate, nil) - if err := edsLB.waitForChildPolicy(ctx, lbCfgA); err != nil { - t.Fatal(err) - } - if err := edsLB.waitForCounterUpdate(ctx, testServiceName); err != nil { - t.Fatal(err) - } - if err := edsLB.waitForCountMaxUpdate(ctx, nil); err != nil { - t.Fatal(err) - } - - var testCountMax uint32 = 100 - lbCfgB := &loadBalancingConfig{ - Name: fakeBalancerB, - Config: json.RawMessage("{}"), - } - if err := edsB.UpdateClientConnState(balancer.ClientConnState{ - BalancerConfig: &EDSConfig{ - ChildPolicy: lbCfgB, - EDSServiceName: testServiceName, - 
MaxConcurrentRequests: &testCountMax, - }, - }); err != nil { - t.Fatalf("edsB.UpdateClientConnState() failed: %v", err) - } - if err := edsLB.waitForChildPolicy(ctx, lbCfgB); err != nil { - t.Fatal(err) - } - if err := edsLB.waitForCounterUpdate(ctx, testServiceName); err != nil { - // Counter is updated even though the service name didn't change. The - // eds_impl will compare the service names, and skip if it didn't change. - t.Fatal(err) - } - if err := edsLB.waitForCountMaxUpdate(ctx, &testCountMax); err != nil { - t.Fatal(err) - } -} - -// TestSubConnStateChange verifies if the top-level edsBalancer passes on -// the subConnStateChange to appropriate child balancer. -func (s) TestSubConnStateChange(t *testing.T) { - edsLBCh := testutils.NewChannel() - xdsC, cleanup := setup(edsLBCh) - defer cleanup() - - builder := balancer.Get(edsName) - edsB := builder.Build(newNoopTestClientConn(), balancer.BuildOptions{Target: resolver.Target{Endpoint: testServiceName}}) - if edsB == nil { - t.Fatalf("builder.Build(%s) failed and returned nil", edsName) - } - defer edsB.Close() - - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() - edsLB, err := waitForNewEDSLB(ctx, edsLBCh) - if err != nil { - t.Fatal(err) - } - - if err := edsB.UpdateClientConnState(balancer.ClientConnState{ - BalancerConfig: &EDSConfig{EDSServiceName: testServiceName}, - }); err != nil { - t.Fatalf("edsB.UpdateClientConnState() failed: %v", err) - } - - if _, err := xdsC.WaitForWatchEDS(ctx); err != nil { - t.Fatalf("xdsClient.WatchEndpoints failed with error: %v", err) - } - xdsC.InvokeWatchEDSCallback(defaultEndpointsUpdate, nil) - - fsc := &fakeSubConn{} - state := connectivity.Ready - edsB.UpdateSubConnState(fsc, balancer.SubConnState{ConnectivityState: state}) - if err := edsLB.waitForSubConnStateChange(ctx, &scStateChange{sc: fsc, state: state}); err != nil { - t.Fatal(err) - } -} - -// TestErrorFromXDSClientUpdate verifies that an error from xdsClient 
update is -// handled correctly. -// -// If it's resource-not-found, watch will NOT be canceled, the EDS impl will -// receive an empty EDS update, and new RPCs will fail. -// -// If it's connection error, nothing will happen. This will need to change to -// handle fallback. -func (s) TestErrorFromXDSClientUpdate(t *testing.T) { - edsLBCh := testutils.NewChannel() - xdsC, cleanup := setup(edsLBCh) - defer cleanup() - - builder := balancer.Get(edsName) - edsB := builder.Build(newNoopTestClientConn(), balancer.BuildOptions{Target: resolver.Target{Endpoint: testServiceName}}) - if edsB == nil { - t.Fatalf("builder.Build(%s) failed and returned nil", edsName) - } - defer edsB.Close() - - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() - edsLB, err := waitForNewEDSLB(ctx, edsLBCh) - if err != nil { - t.Fatal(err) - } - - if err := edsB.UpdateClientConnState(balancer.ClientConnState{ - BalancerConfig: &EDSConfig{EDSServiceName: testServiceName}, - }); err != nil { - t.Fatal(err) - } - - if _, err := xdsC.WaitForWatchEDS(ctx); err != nil { - t.Fatalf("xdsClient.WatchEndpoints failed with error: %v", err) - } - xdsC.InvokeWatchEDSCallback(xdsclient.EndpointsUpdate{}, nil) - if err := edsLB.waitForEDSResponse(ctx, xdsclient.EndpointsUpdate{}); err != nil { - t.Fatalf("EDS impl got unexpected EDS response: %v", err) - } - - connectionErr := xdsclient.NewErrorf(xdsclient.ErrorTypeConnection, "connection error") - xdsC.InvokeWatchEDSCallback(xdsclient.EndpointsUpdate{}, connectionErr) - - sCtx, sCancel := context.WithTimeout(context.Background(), defaultTestShortTimeout) - defer sCancel() - if err := xdsC.WaitForCancelEDSWatch(sCtx); err != context.DeadlineExceeded { - t.Fatal("watch was canceled, want not canceled (timeout error)") - } - - sCtx, sCancel = context.WithTimeout(context.Background(), defaultTestShortTimeout) - defer sCancel() - if err := edsLB.waitForEDSResponse(sCtx, xdsclient.EndpointsUpdate{}); err != 
context.DeadlineExceeded { - t.Fatal(err) - } - - resourceErr := xdsclient.NewErrorf(xdsclient.ErrorTypeResourceNotFound, "edsBalancer resource not found error") - xdsC.InvokeWatchEDSCallback(xdsclient.EndpointsUpdate{}, resourceErr) - // Even if error is resource not found, watch shouldn't be canceled, because - // this is an EDS resource removed (and xds client actually never sends this - // error, but we still handles it). - sCtx, sCancel = context.WithTimeout(context.Background(), defaultTestShortTimeout) - defer sCancel() - if err := xdsC.WaitForCancelEDSWatch(sCtx); err != context.DeadlineExceeded { - t.Fatal("watch was canceled, want not canceled (timeout error)") - } - if err := edsLB.waitForEDSResponse(ctx, xdsclient.EndpointsUpdate{}); err != nil { - t.Fatalf("eds impl expecting empty update, got %v", err) - } -} - -// TestErrorFromResolver verifies that resolver errors are handled correctly. -// -// If it's resource-not-found, watch will be canceled, the EDS impl will receive -// an empty EDS update, and new RPCs will fail. -// -// If it's connection error, nothing will happen. This will need to change to -// handle fallback. 
-func (s) TestErrorFromResolver(t *testing.T) { - edsLBCh := testutils.NewChannel() - xdsC, cleanup := setup(edsLBCh) - defer cleanup() - - builder := balancer.Get(edsName) - edsB := builder.Build(newNoopTestClientConn(), balancer.BuildOptions{Target: resolver.Target{Endpoint: testServiceName}}) - if edsB == nil { - t.Fatalf("builder.Build(%s) failed and returned nil", edsName) - } - defer edsB.Close() - - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() - edsLB, err := waitForNewEDSLB(ctx, edsLBCh) - if err != nil { - t.Fatal(err) - } - - if err := edsB.UpdateClientConnState(balancer.ClientConnState{ - BalancerConfig: &EDSConfig{EDSServiceName: testServiceName}, - }); err != nil { - t.Fatal(err) - } - - if _, err := xdsC.WaitForWatchEDS(ctx); err != nil { - t.Fatalf("xdsClient.WatchEndpoints failed with error: %v", err) - } - xdsC.InvokeWatchEDSCallback(xdsclient.EndpointsUpdate{}, nil) - if err := edsLB.waitForEDSResponse(ctx, xdsclient.EndpointsUpdate{}); err != nil { - t.Fatalf("EDS impl got unexpected EDS response: %v", err) - } - - connectionErr := xdsclient.NewErrorf(xdsclient.ErrorTypeConnection, "connection error") - edsB.ResolverError(connectionErr) - - sCtx, sCancel := context.WithTimeout(context.Background(), defaultTestShortTimeout) - defer sCancel() - if err := xdsC.WaitForCancelEDSWatch(sCtx); err != context.DeadlineExceeded { - t.Fatal("watch was canceled, want not canceled (timeout error)") - } - - sCtx, sCancel = context.WithTimeout(context.Background(), defaultTestShortTimeout) - defer sCancel() - if err := edsLB.waitForEDSResponse(sCtx, xdsclient.EndpointsUpdate{}); err != context.DeadlineExceeded { - t.Fatal("eds impl got EDS resp, want timeout error") - } - - resourceErr := xdsclient.NewErrorf(xdsclient.ErrorTypeResourceNotFound, "edsBalancer resource not found error") - edsB.ResolverError(resourceErr) - if err := xdsC.WaitForCancelEDSWatch(ctx); err != nil { - t.Fatalf("want watch to be canceled, 
waitForCancel failed: %v", err) - } - if err := edsLB.waitForEDSResponse(ctx, xdsclient.EndpointsUpdate{}); err != nil { - t.Fatalf("EDS impl got unexpected EDS response: %v", err) - } -} - -// Given a list of resource names, verifies that EDS requests for the same are -// sent by the EDS balancer, through the fake xDS client. -func verifyExpectedRequests(ctx context.Context, fc *fakeclient.Client, resourceNames ...string) error { - for _, name := range resourceNames { - if name == "" { - // ResourceName empty string indicates a cancel. - if err := fc.WaitForCancelEDSWatch(ctx); err != nil { - return fmt.Errorf("timed out when expecting resource %q", name) - } - return nil - } - - resName, err := fc.WaitForWatchEDS(ctx) - if err != nil { - return fmt.Errorf("timed out when expecting resource %q, %p", name, fc) - } - if resName != name { - return fmt.Errorf("got EDS request for resource %q, expected: %q", resName, name) - } - } - return nil -} - -// TestClientWatchEDS verifies that the xdsClient inside the top-level EDS LB -// policy registers an EDS watch for expected resource upon receiving an update -// from gRPC. -func (s) TestClientWatchEDS(t *testing.T) { - edsLBCh := testutils.NewChannel() - xdsC, cleanup := setup(edsLBCh) - defer cleanup() - - builder := balancer.Get(edsName) - edsB := builder.Build(newNoopTestClientConn(), balancer.BuildOptions{Target: resolver.Target{Endpoint: testServiceName}}) - if edsB == nil { - t.Fatalf("builder.Build(%s) failed and returned nil", edsName) - } - defer edsB.Close() - - // Update with an non-empty edsServiceName should trigger an EDS watch for - // the same. 
- if err := edsB.UpdateClientConnState(balancer.ClientConnState{ - BalancerConfig: &EDSConfig{EDSServiceName: "foobar-1"}, - }); err != nil { - t.Fatal(err) - } - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() - if err := verifyExpectedRequests(ctx, xdsC, "foobar-1"); err != nil { - t.Fatal(err) - } - - // Also test the case where the edsServerName changes from one non-empty - // name to another, and make sure a new watch is registered. The previously - // registered watch will be cancelled, which will result in an EDS request - // with no resource names being sent to the server. - if err := edsB.UpdateClientConnState(balancer.ClientConnState{ - BalancerConfig: &EDSConfig{EDSServiceName: "foobar-2"}, - }); err != nil { - t.Fatal(err) - } - if err := verifyExpectedRequests(ctx, xdsC, "", "foobar-2"); err != nil { - t.Fatal(err) - } -} - -// TestCounterUpdate verifies that the counter update is triggered with the -// service name from an update's config. -func (s) TestCounterUpdate(t *testing.T) { - edsLBCh := testutils.NewChannel() - _, cleanup := setup(edsLBCh) - defer cleanup() - - builder := balancer.Get(edsName) - edsB := builder.Build(newNoopTestClientConn(), balancer.BuildOptions{Target: resolver.Target{Endpoint: testServiceName}}) - if edsB == nil { - t.Fatalf("builder.Build(%s) failed and returned nil", edsName) - } - defer edsB.Close() - - var testCountMax uint32 = 100 - // Update should trigger counter update with provided service name. 
- if err := edsB.UpdateClientConnState(balancer.ClientConnState{ - BalancerConfig: &EDSConfig{ - EDSServiceName: "foobar-1", - MaxConcurrentRequests: &testCountMax, - }, - }); err != nil { - t.Fatal(err) - } - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() - edsI := edsB.(*edsBalancer).edsImpl.(*fakeEDSBalancer) - if err := edsI.waitForCounterUpdate(ctx, "foobar-1"); err != nil { - t.Fatal(err) - } - if err := edsI.waitForCountMaxUpdate(ctx, &testCountMax); err != nil { - t.Fatal(err) - } -} - -func (s) TestBalancerConfigParsing(t *testing.T) { - const testEDSName = "eds.service" - var testLRSName = "lrs.server" - b := bytes.NewBuffer(nil) - if err := (&jsonpb.Marshaler{}).Marshal(b, &scpb.XdsConfig{ - ChildPolicy: []*scpb.LoadBalancingConfig{ - {Policy: &scpb.LoadBalancingConfig_Xds{}}, - {Policy: &scpb.LoadBalancingConfig_RoundRobin{ - RoundRobin: &scpb.RoundRobinConfig{}, - }}, - }, - FallbackPolicy: []*scpb.LoadBalancingConfig{ - {Policy: &scpb.LoadBalancingConfig_Xds{}}, - {Policy: &scpb.LoadBalancingConfig_PickFirst{ - PickFirst: &scpb.PickFirstConfig{}, - }}, - }, - EdsServiceName: testEDSName, - LrsLoadReportingServerName: &wrapperspb.StringValue{Value: testLRSName}, - }); err != nil { - t.Fatalf("%v", err) - } - - var testMaxConcurrentRequests uint32 = 123 - tests := []struct { - name string - js json.RawMessage - want serviceconfig.LoadBalancingConfig - wantErr bool - }{ - { - name: "bad json", - js: json.RawMessage(`i am not JSON`), - wantErr: true, - }, - { - name: "empty", - js: json.RawMessage(`{}`), - want: &EDSConfig{}, - }, - { - name: "jsonpb-generated", - js: b.Bytes(), - want: &EDSConfig{ - ChildPolicy: &loadBalancingConfig{ - Name: "round_robin", - Config: json.RawMessage("{}"), - }, - FallBackPolicy: &loadBalancingConfig{ - Name: "pick_first", - Config: json.RawMessage("{}"), - }, - EDSServiceName: testEDSName, - LrsLoadReportingServerName: &testLRSName, - }, - }, - { - // json with random 
balancers, and the first is not registered. - name: "manually-generated", - js: json.RawMessage(` -{ - "childPolicy": [ - {"fake_balancer_C": {}}, - {"fake_balancer_A": {}}, - {"fake_balancer_B": {}} - ], - "fallbackPolicy": [ - {"fake_balancer_C": {}}, - {"fake_balancer_B": {}}, - {"fake_balancer_A": {}} - ], - "edsServiceName": "eds.service", - "maxConcurrentRequests": 123, - "lrsLoadReportingServerName": "lrs.server" -}`), - want: &EDSConfig{ - ChildPolicy: &loadBalancingConfig{ - Name: "fake_balancer_A", - Config: json.RawMessage("{}"), - }, - FallBackPolicy: &loadBalancingConfig{ - Name: "fake_balancer_B", - Config: json.RawMessage("{}"), - }, - EDSServiceName: testEDSName, - MaxConcurrentRequests: &testMaxConcurrentRequests, - LrsLoadReportingServerName: &testLRSName, - }, - }, - { - // json with no lrs server name, LrsLoadReportingServerName should - // be nil (not an empty string). - name: "no-lrs-server-name", - js: json.RawMessage(` -{ - "edsServiceName": "eds.service" -}`), - want: &EDSConfig{ - EDSServiceName: testEDSName, - LrsLoadReportingServerName: nil, - }, - }, - { - name: "good child policy", - js: json.RawMessage(`{"childPolicy":[{"pick_first":{}}]}`), - want: &EDSConfig{ - ChildPolicy: &loadBalancingConfig{ - Name: "pick_first", - Config: json.RawMessage(`{}`), - }, - }, - }, - { - name: "multiple good child policies", - js: json.RawMessage(`{"childPolicy":[{"round_robin":{}},{"pick_first":{}}]}`), - want: &EDSConfig{ - ChildPolicy: &loadBalancingConfig{ - Name: "round_robin", - Config: json.RawMessage(`{}`), - }, - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - b := &edsBalancerBuilder{} - got, err := b.ParseConfig(tt.js) - if (err != nil) != tt.wantErr { - t.Fatalf("edsBalancerBuilder.ParseConfig() error = %v, wantErr %v", err, tt.wantErr) - } - if tt.wantErr { - return - } - if !cmp.Equal(got, tt.want) { - t.Errorf(cmp.Diff(got, tt.want)) - } - }) - } -} - -func (s) TestEqualStringPointers(t *testing.T) { - 
var ( - ta1 = "test-a" - ta2 = "test-a" - tb = "test-b" - ) - tests := []struct { - name string - a *string - b *string - want bool - }{ - {"both-nil", nil, nil, true}, - {"a-non-nil", &ta1, nil, false}, - {"b-non-nil", nil, &tb, false}, - {"equal", &ta1, &ta2, true}, - {"different", &ta1, &tb, false}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := equalStringPointers(tt.a, tt.b); got != tt.want { - t.Errorf("equalStringPointers() = %v, want %v", got, tt.want) - } - }) - } -} diff --git a/xds/internal/balancer/edsbalancer/load_store_wrapper.go b/xds/internal/balancer/edsbalancer/load_store_wrapper.go deleted file mode 100644 index 18904e47a42..00000000000 --- a/xds/internal/balancer/edsbalancer/load_store_wrapper.go +++ /dev/null @@ -1,88 +0,0 @@ -/* - * - * Copyright 2019 gRPC authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - */ - -package edsbalancer - -import ( - "sync" - - "google.golang.org/grpc/xds/internal/client/load" -) - -type loadStoreWrapper struct { - mu sync.RWMutex - service string - // Both store and perCluster will be nil if load reporting is disabled (EDS - // response doesn't have LRS server name). Note that methods on Store and - // perCluster all handle nil, so there's no need to check nil before calling - // them. 
- store *load.Store - perCluster load.PerClusterReporter -} - -func (lsw *loadStoreWrapper) updateServiceName(service string) { - lsw.mu.Lock() - defer lsw.mu.Unlock() - if lsw.service == service { - return - } - lsw.service = service - lsw.perCluster = lsw.store.PerCluster(lsw.service, "") -} - -func (lsw *loadStoreWrapper) updateLoadStore(store *load.Store) { - lsw.mu.Lock() - defer lsw.mu.Unlock() - if store == lsw.store { - return - } - lsw.store = store - lsw.perCluster = lsw.store.PerCluster(lsw.service, "") -} - -func (lsw *loadStoreWrapper) CallStarted(locality string) { - lsw.mu.RLock() - defer lsw.mu.RUnlock() - if lsw.perCluster != nil { - lsw.perCluster.CallStarted(locality) - } -} - -func (lsw *loadStoreWrapper) CallFinished(locality string, err error) { - lsw.mu.RLock() - defer lsw.mu.RUnlock() - if lsw.perCluster != nil { - lsw.perCluster.CallFinished(locality, err) - } -} - -func (lsw *loadStoreWrapper) CallServerLoad(locality, name string, val float64) { - lsw.mu.RLock() - defer lsw.mu.RUnlock() - if lsw.perCluster != nil { - lsw.perCluster.CallServerLoad(locality, name, val) - } -} - -func (lsw *loadStoreWrapper) CallDropped(category string) { - lsw.mu.RLock() - defer lsw.mu.RUnlock() - if lsw.perCluster != nil { - lsw.perCluster.CallDropped(category) - } -} diff --git a/xds/internal/balancer/edsbalancer/util.go b/xds/internal/balancer/edsbalancer/util.go deleted file mode 100644 index 13295042646..00000000000 --- a/xds/internal/balancer/edsbalancer/util.go +++ /dev/null @@ -1,44 +0,0 @@ -/* - * Copyright 2019 gRPC authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package edsbalancer - -import ( - "google.golang.org/grpc/internal/wrr" - xdsclient "google.golang.org/grpc/xds/internal/client" -) - -var newRandomWRR = wrr.NewRandom - -type dropper struct { - c xdsclient.OverloadDropConfig - w wrr.WRR -} - -func newDropper(c xdsclient.OverloadDropConfig) *dropper { - w := newRandomWRR() - w.Add(true, int64(c.Numerator)) - w.Add(false, int64(c.Denominator-c.Numerator)) - - return &dropper{ - c: c, - w: w, - } -} - -func (d *dropper) drop() (ret bool) { - return d.w.Next().(bool) -} diff --git a/xds/internal/balancer/edsbalancer/util_test.go b/xds/internal/balancer/edsbalancer/util_test.go deleted file mode 100644 index 748aeffe2bb..00000000000 --- a/xds/internal/balancer/edsbalancer/util_test.go +++ /dev/null @@ -1,88 +0,0 @@ -/* - * Copyright 2019 gRPC authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package edsbalancer - -import ( - "testing" - - xdsclient "google.golang.org/grpc/xds/internal/client" - "google.golang.org/grpc/xds/internal/testutils" -) - -func init() { - newRandomWRR = testutils.NewTestWRR -} - -func (s) TestDropper(t *testing.T) { - const repeat = 2 - - type args struct { - numerator uint32 - denominator uint32 - } - tests := []struct { - name string - args args - }{ - { - name: "2_3", - args: args{ - numerator: 2, - denominator: 3, - }, - }, - { - name: "4_8", - args: args{ - numerator: 4, - denominator: 8, - }, - }, - { - name: "7_20", - args: args{ - numerator: 7, - denominator: 20, - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - d := newDropper(xdsclient.OverloadDropConfig{ - Category: "", - Numerator: tt.args.numerator, - Denominator: tt.args.denominator, - }) - - var ( - dCount int - wantCount = int(tt.args.numerator) * repeat - loopCount = int(tt.args.denominator) * repeat - ) - for i := 0; i < loopCount; i++ { - if d.drop() { - dCount++ - } - } - - if dCount != (wantCount) { - t.Errorf("with numerator %v, denominator %v repeat %v, got drop count: %v, want %v", - tt.args.numerator, tt.args.denominator, repeat, dCount, wantCount) - } - }) - } -} diff --git a/xds/internal/balancer/edsbalancer/xds_lrs_test.go b/xds/internal/balancer/edsbalancer/xds_lrs_test.go deleted file mode 100644 index 9f93e0b42f0..00000000000 --- a/xds/internal/balancer/edsbalancer/xds_lrs_test.go +++ /dev/null @@ -1,71 +0,0 @@ -/* - * - * Copyright 2019 gRPC authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - * - */ - -package edsbalancer - -import ( - "context" - "testing" - - "google.golang.org/grpc/balancer" - "google.golang.org/grpc/xds/internal/testutils/fakeclient" -) - -// TestXDSLoadReporting verifies that the edsBalancer starts the loadReport -// stream when the lbConfig passed to it contains a valid value for the LRS -// server (empty string). -func (s) TestXDSLoadReporting(t *testing.T) { - xdsC := fakeclient.NewClient() - oldNewXDSClient := newXDSClient - newXDSClient = func() (xdsClientInterface, error) { return xdsC, nil } - defer func() { newXDSClient = oldNewXDSClient }() - - builder := balancer.Get(edsName) - edsB := builder.Build(newNoopTestClientConn(), balancer.BuildOptions{}) - if edsB == nil { - t.Fatalf("builder.Build(%s) failed and returned nil", edsName) - } - defer edsB.Close() - - if err := edsB.UpdateClientConnState(balancer.ClientConnState{ - BalancerConfig: &EDSConfig{ - EDSServiceName: testEDSClusterName, - LrsLoadReportingServerName: new(string), - }, - }); err != nil { - t.Fatal(err) - } - - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() - gotCluster, err := xdsC.WaitForWatchEDS(ctx) - if err != nil { - t.Fatalf("xdsClient.WatchEndpoints failed with error: %v", err) - } - if gotCluster != testEDSClusterName { - t.Fatalf("xdsClient.WatchEndpoints() called with cluster: %v, want %v", gotCluster, testEDSClusterName) - } - - got, err := xdsC.WaitForReportLoad(ctx) - if err != nil { - t.Fatalf("xdsClient.ReportLoad failed with error: %v", err) - } - if got.Server != "" { - t.Fatalf("xdsClient.ReportLoad called with {%v}: want {\"\"}", got.Server) - } -} diff --git a/xds/internal/balancer/edsbalancer/xds_old.go b/xds/internal/balancer/edsbalancer/xds_old.go deleted file mode 100644 index 6729e6801f1..00000000000 --- a/xds/internal/balancer/edsbalancer/xds_old.go +++ /dev/null @@ -1,46 +0,0 
@@ -/* - * - * Copyright 2019 gRPC authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package edsbalancer - -import "google.golang.org/grpc/balancer" - -// The old xds balancer implements logic for both CDS and EDS. With the new -// design, CDS is split and moved to a separate balancer, and the xds balancer -// becomes the EDS balancer. -// -// To keep the existing tests working, this file regisger EDS balancer under the -// old xds balancer name. -// -// TODO: delete this file when migration to new workflow (LDS, RDS, CDS, EDS) is -// done. - -const xdsName = "xds_experimental" - -func init() { - balancer.Register(&xdsBalancerBuilder{}) -} - -// xdsBalancerBuilder register edsBalancerBuilder (now with name -// "eds_experimental") under the old name "xds_experimental". -type xdsBalancerBuilder struct { - edsBalancerBuilder -} - -func (b *xdsBalancerBuilder) Name() string { - return xdsName -} diff --git a/xds/internal/balancer/loadstore/load_store_wrapper.go b/xds/internal/balancer/loadstore/load_store_wrapper.go index 88fa344118c..8ce958d71ca 100644 --- a/xds/internal/balancer/loadstore/load_store_wrapper.go +++ b/xds/internal/balancer/loadstore/load_store_wrapper.go @@ -22,7 +22,7 @@ package loadstore import ( "sync" - "google.golang.org/grpc/xds/internal/client/load" + "google.golang.org/grpc/xds/internal/xdsclient/load" ) // NewWrapper creates a Wrapper. 
diff --git a/xds/internal/balancer/lrs/balancer.go b/xds/internal/balancer/lrs/balancer.go deleted file mode 100644 index ab9ee7109db..00000000000 --- a/xds/internal/balancer/lrs/balancer.go +++ /dev/null @@ -1,246 +0,0 @@ -/* - * - * Copyright 2020 gRPC authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - */ - -// Package lrs implements load reporting balancer for xds. -package lrs - -import ( - "encoding/json" - "fmt" - - "google.golang.org/grpc/balancer" - "google.golang.org/grpc/internal/grpclog" - "google.golang.org/grpc/serviceconfig" - "google.golang.org/grpc/xds/internal/balancer/loadstore" - xdsclient "google.golang.org/grpc/xds/internal/client" - "google.golang.org/grpc/xds/internal/client/load" -) - -func init() { - balancer.Register(&lrsBB{}) -} - -var newXDSClient = func() (xdsClientInterface, error) { return xdsclient.New() } - -const lrsBalancerName = "lrs_experimental" - -type lrsBB struct{} - -func (l *lrsBB) Build(cc balancer.ClientConn, opts balancer.BuildOptions) balancer.Balancer { - b := &lrsBalancer{ - cc: cc, - buildOpts: opts, - } - b.logger = prefixLogger(b) - b.logger.Infof("Created") - - client, err := newXDSClient() - if err != nil { - b.logger.Errorf("failed to create xds-client: %v", err) - return nil - } - b.client = newXDSClientWrapper(client) - - return b -} - -func (l *lrsBB) Name() string { - return lrsBalancerName -} - -func (l *lrsBB) ParseConfig(c json.RawMessage) (serviceconfig.LoadBalancingConfig, error) { 
- return parseConfig(c) -} - -type lrsBalancer struct { - cc balancer.ClientConn - buildOpts balancer.BuildOptions - - logger *grpclog.PrefixLogger - client *xdsClientWrapper - - config *lbConfig - lb balancer.Balancer // The sub balancer. -} - -func (b *lrsBalancer) UpdateClientConnState(s balancer.ClientConnState) error { - newConfig, ok := s.BalancerConfig.(*lbConfig) - if !ok { - return fmt.Errorf("unexpected balancer config with type: %T", s.BalancerConfig) - } - - // Update load reporting config or xds client. This needs to be done before - // updating the child policy because we need the loadStore from the updated - // client to be passed to the ccWrapper. - if err := b.client.update(newConfig); err != nil { - return err - } - - // If child policy is a different type, recreate the sub-balancer. - if b.config == nil || b.config.ChildPolicy.Name != newConfig.ChildPolicy.Name { - bb := balancer.Get(newConfig.ChildPolicy.Name) - if bb == nil { - return fmt.Errorf("balancer %q not registered", newConfig.ChildPolicy.Name) - } - if b.lb != nil { - b.lb.Close() - } - lidJSON, err := newConfig.Locality.ToString() - if err != nil { - return fmt.Errorf("failed to marshal LocalityID: %#v", newConfig.Locality) - } - ccWrapper := newCCWrapper(b.cc, b.client.loadStore(), lidJSON) - b.lb = bb.Build(ccWrapper, b.buildOpts) - } - b.config = newConfig - - // Addresses and sub-balancer config are sent to sub-balancer. 
- return b.lb.UpdateClientConnState(balancer.ClientConnState{ - ResolverState: s.ResolverState, - BalancerConfig: b.config.ChildPolicy.Config, - }) -} - -func (b *lrsBalancer) ResolverError(err error) { - if b.lb != nil { - b.lb.ResolverError(err) - } -} - -func (b *lrsBalancer) UpdateSubConnState(sc balancer.SubConn, s balancer.SubConnState) { - if b.lb != nil { - b.lb.UpdateSubConnState(sc, s) - } -} - -func (b *lrsBalancer) Close() { - if b.lb != nil { - b.lb.Close() - b.lb = nil - } - b.client.close() -} - -type ccWrapper struct { - balancer.ClientConn - loadStore load.PerClusterReporter - localityIDJSON string -} - -func newCCWrapper(cc balancer.ClientConn, loadStore load.PerClusterReporter, localityIDJSON string) *ccWrapper { - return &ccWrapper{ - ClientConn: cc, - loadStore: loadStore, - localityIDJSON: localityIDJSON, - } -} - -func (ccw *ccWrapper) UpdateState(s balancer.State) { - s.Picker = newLoadReportPicker(s.Picker, ccw.localityIDJSON, ccw.loadStore) - ccw.ClientConn.UpdateState(s) -} - -// xdsClientInterface contains only the xds_client methods needed by LRS -// balancer. It's defined so we can override xdsclient in tests. -type xdsClientInterface interface { - ReportLoad(server string) (*load.Store, func()) - Close() -} - -type xdsClientWrapper struct { - c xdsClientInterface - cancelLoadReport func() - clusterName string - edsServiceName string - lrsServerName string - // loadWrapper is a wrapper with loadOriginal, with clusterName and - // edsServiceName. It's used children to report loads. - loadWrapper *loadstore.Wrapper -} - -func newXDSClientWrapper(c xdsClientInterface) *xdsClientWrapper { - return &xdsClientWrapper{ - c: c, - loadWrapper: loadstore.NewWrapper(), - } -} - -// update checks the config and xdsclient, and decides whether it needs to -// restart the load reporting stream. 
-func (w *xdsClientWrapper) update(newConfig *lbConfig) error { - var ( - restartLoadReport bool - updateLoadClusterAndService bool - ) - - // ClusterName is different, restart. ClusterName is from ClusterName and - // EdsServiceName. - if w.clusterName != newConfig.ClusterName { - updateLoadClusterAndService = true - w.clusterName = newConfig.ClusterName - } - if w.edsServiceName != newConfig.EdsServiceName { - updateLoadClusterAndService = true - w.edsServiceName = newConfig.EdsServiceName - } - - if updateLoadClusterAndService { - // This updates the clusterName and serviceName that will reported for the - // loads. The update here is too early, the perfect timing is when the - // picker is updated with the new connection. But from this balancer's point - // of view, it's impossible to tell. - // - // On the other hand, this will almost never happen. Each LRS policy - // shouldn't get updated config. The parent should do a graceful switch when - // the clusterName or serviceName is changed. - w.loadWrapper.UpdateClusterAndService(w.clusterName, w.edsServiceName) - } - - if w.lrsServerName != newConfig.LrsLoadReportingServerName { - // LrsLoadReportingServerName is different, load should be report to a - // different server, restart. 
- restartLoadReport = true - w.lrsServerName = newConfig.LrsLoadReportingServerName - } - - if restartLoadReport { - if w.cancelLoadReport != nil { - w.cancelLoadReport() - w.cancelLoadReport = nil - } - var loadStore *load.Store - if w.c != nil { - loadStore, w.cancelLoadReport = w.c.ReportLoad(w.lrsServerName) - } - w.loadWrapper.UpdateLoadStore(loadStore) - } - - return nil -} - -func (w *xdsClientWrapper) loadStore() load.PerClusterReporter { - return w.loadWrapper -} - -func (w *xdsClientWrapper) close() { - if w.cancelLoadReport != nil { - w.cancelLoadReport() - w.cancelLoadReport = nil - } - w.c.Close() -} diff --git a/xds/internal/balancer/lrs/balancer_test.go b/xds/internal/balancer/lrs/balancer_test.go deleted file mode 100644 index 0b575b11210..00000000000 --- a/xds/internal/balancer/lrs/balancer_test.go +++ /dev/null @@ -1,144 +0,0 @@ -/* - * - * Copyright 2019 gRPC authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- * - */ - -package lrs - -import ( - "context" - "fmt" - "testing" - "time" - - "github.com/google/go-cmp/cmp" - "google.golang.org/grpc/balancer" - "google.golang.org/grpc/balancer/roundrobin" - "google.golang.org/grpc/connectivity" - internalserviceconfig "google.golang.org/grpc/internal/serviceconfig" - "google.golang.org/grpc/resolver" - xdsinternal "google.golang.org/grpc/xds/internal" - "google.golang.org/grpc/xds/internal/testutils" - "google.golang.org/grpc/xds/internal/testutils/fakeclient" -) - -const defaultTestTimeout = 1 * time.Second - -var ( - testBackendAddrs = []resolver.Address{ - {Addr: "1.1.1.1:1"}, - } - testLocality = &xdsinternal.LocalityID{ - Region: "test-region", - Zone: "test-zone", - SubZone: "test-sub-zone", - } -) - -// TestLoadReporting verifies that the lrs balancer starts the loadReport -// stream when the lbConfig passed to it contains a valid value for the LRS -// server (empty string). -func TestLoadReporting(t *testing.T) { - xdsC := fakeclient.NewClient() - oldNewXDSClient := newXDSClient - newXDSClient = func() (xdsClientInterface, error) { return xdsC, nil } - defer func() { newXDSClient = oldNewXDSClient }() - - builder := balancer.Get(lrsBalancerName) - cc := testutils.NewTestClientConn(t) - lrsB := builder.Build(cc, balancer.BuildOptions{}) - defer lrsB.Close() - - if err := lrsB.UpdateClientConnState(balancer.ClientConnState{ - ResolverState: resolver.State{ - Addresses: testBackendAddrs, - }, - BalancerConfig: &lbConfig{ - ClusterName: testClusterName, - EdsServiceName: testServiceName, - LrsLoadReportingServerName: testLRSServerName, - Locality: testLocality, - ChildPolicy: &internalserviceconfig.BalancerConfig{ - Name: roundrobin.Name, - }, - }, - }); err != nil { - t.Fatalf("unexpected error from UpdateClientConnState: %v", err) - } - - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() - - got, err := xdsC.WaitForReportLoad(ctx) - if err != nil { - 
t.Fatalf("xdsClient.ReportLoad failed with error: %v", err) - } - if got.Server != testLRSServerName { - t.Fatalf("xdsClient.ReportLoad called with {%q}: want {%q}", got.Server, testLRSServerName) - } - - sc1 := <-cc.NewSubConnCh - lrsB.UpdateSubConnState(sc1, balancer.SubConnState{ConnectivityState: connectivity.Connecting}) - lrsB.UpdateSubConnState(sc1, balancer.SubConnState{ConnectivityState: connectivity.Ready}) - - // Test pick with one backend. - p1 := <-cc.NewPickerCh - const successCount = 5 - for i := 0; i < successCount; i++ { - gotSCSt, _ := p1.Pick(balancer.PickInfo{}) - if !cmp.Equal(gotSCSt.SubConn, sc1, cmp.AllowUnexported(testutils.TestSubConn{})) { - t.Fatalf("picker.Pick, got %v, want SubConn=%v", gotSCSt, sc1) - } - gotSCSt.Done(balancer.DoneInfo{}) - } - const errorCount = 5 - for i := 0; i < errorCount; i++ { - gotSCSt, _ := p1.Pick(balancer.PickInfo{}) - if !cmp.Equal(gotSCSt.SubConn, sc1, cmp.AllowUnexported(testutils.TestSubConn{})) { - t.Fatalf("picker.Pick, got %v, want SubConn=%v", gotSCSt, sc1) - } - gotSCSt.Done(balancer.DoneInfo{Err: fmt.Errorf("error")}) - } - - // Dump load data from the store and compare with expected counts. 
- loadStore := xdsC.LoadStore() - if loadStore == nil { - t.Fatal("loadStore is nil in xdsClient") - } - sds := loadStore.Stats([]string{testClusterName}) - if len(sds) == 0 { - t.Fatalf("loads for cluster %v not found in store", testClusterName) - } - sd := sds[0] - if sd.Cluster != testClusterName || sd.Service != testServiceName { - t.Fatalf("got unexpected load for %q, %q, want %q, %q", sd.Cluster, sd.Service, testClusterName, testServiceName) - } - testLocalityJSON, _ := testLocality.ToString() - localityData, ok := sd.LocalityStats[testLocalityJSON] - if !ok { - t.Fatalf("loads for %v not found in store", testLocality) - } - reqStats := localityData.RequestStats - if reqStats.Succeeded != successCount { - t.Errorf("got succeeded %v, want %v", reqStats.Succeeded, successCount) - } - if reqStats.Errored != errorCount { - t.Errorf("got errord %v, want %v", reqStats.Errored, errorCount) - } - if reqStats.InProgress != 0 { - t.Errorf("got inProgress %v, want %v", reqStats.InProgress, 0) - } -} diff --git a/xds/internal/balancer/lrs/config.go b/xds/internal/balancer/lrs/config.go deleted file mode 100644 index 3d39961401b..00000000000 --- a/xds/internal/balancer/lrs/config.go +++ /dev/null @@ -1,54 +0,0 @@ -/* - * - * Copyright 2020 gRPC authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- * - */ - -package lrs - -import ( - "encoding/json" - "fmt" - - internalserviceconfig "google.golang.org/grpc/internal/serviceconfig" - "google.golang.org/grpc/serviceconfig" - "google.golang.org/grpc/xds/internal" -) - -type lbConfig struct { - serviceconfig.LoadBalancingConfig - ClusterName string - EdsServiceName string - LrsLoadReportingServerName string - Locality *internal.LocalityID - ChildPolicy *internalserviceconfig.BalancerConfig -} - -func parseConfig(c json.RawMessage) (*lbConfig, error) { - var cfg lbConfig - if err := json.Unmarshal(c, &cfg); err != nil { - return nil, err - } - if cfg.ClusterName == "" { - return nil, fmt.Errorf("required ClusterName is not set in %+v", cfg) - } - if cfg.LrsLoadReportingServerName == "" { - return nil, fmt.Errorf("required LrsLoadReportingServerName is not set in %+v", cfg) - } - if cfg.Locality == nil { - return nil, fmt.Errorf("required Locality is not set in %+v", cfg) - } - return &cfg, nil -} diff --git a/xds/internal/balancer/lrs/config_test.go b/xds/internal/balancer/lrs/config_test.go deleted file mode 100644 index f49430569fe..00000000000 --- a/xds/internal/balancer/lrs/config_test.go +++ /dev/null @@ -1,127 +0,0 @@ -/* - * - * Copyright 2020 gRPC authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- * - */ - -package lrs - -import ( - "testing" - - "github.com/google/go-cmp/cmp" - "google.golang.org/grpc/balancer/roundrobin" - internalserviceconfig "google.golang.org/grpc/internal/serviceconfig" - xdsinternal "google.golang.org/grpc/xds/internal" -) - -const ( - testClusterName = "test-cluster" - testServiceName = "test-eds-service" - testLRSServerName = "test-lrs-name" -) - -func TestParseConfig(t *testing.T) { - tests := []struct { - name string - js string - want *lbConfig - wantErr bool - }{ - { - name: "no cluster name", - js: `{ - "edsServiceName": "test-eds-service", - "lrsLoadReportingServerName": "test-lrs-name", - "locality": { - "region": "test-region", - "zone": "test-zone", - "subZone": "test-sub-zone" - }, - "childPolicy":[{"round_robin":{}}] -} - `, - wantErr: true, - }, - { - name: "no LRS server name", - js: `{ - "clusterName": "test-cluster", - "edsServiceName": "test-eds-service", - "locality": { - "region": "test-region", - "zone": "test-zone", - "subZone": "test-sub-zone" - }, - "childPolicy":[{"round_robin":{}}] -} - `, - wantErr: true, - }, - { - name: "no locality", - js: `{ - "clusterName": "test-cluster", - "edsServiceName": "test-eds-service", - "lrsLoadReportingServerName": "test-lrs-name", - "childPolicy":[{"round_robin":{}}] -} - `, - wantErr: true, - }, - { - name: "good", - js: `{ - "clusterName": "test-cluster", - "edsServiceName": "test-eds-service", - "lrsLoadReportingServerName": "test-lrs-name", - "locality": { - "region": "test-region", - "zone": "test-zone", - "subZone": "test-sub-zone" - }, - "childPolicy":[{"round_robin":{}}] -} - `, - want: &lbConfig{ - ClusterName: testClusterName, - EdsServiceName: testServiceName, - LrsLoadReportingServerName: testLRSServerName, - Locality: &xdsinternal.LocalityID{ - Region: "test-region", - Zone: "test-zone", - SubZone: "test-sub-zone", - }, - ChildPolicy: &internalserviceconfig.BalancerConfig{ - Name: roundrobin.Name, - Config: nil, - }, - }, - wantErr: false, - }, - } - for _, 
tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := parseConfig([]byte(tt.js)) - if (err != nil) != tt.wantErr { - t.Errorf("parseConfig() error = %v, wantErr %v", err, tt.wantErr) - return - } - if diff := cmp.Diff(got, tt.want); diff != "" { - t.Errorf("parseConfig() got = %v, want %v, diff: %s", got, tt.want, diff) - } - }) - } -} diff --git a/xds/internal/balancer/lrs/picker.go b/xds/internal/balancer/lrs/picker.go deleted file mode 100644 index 1e4ad156e5b..00000000000 --- a/xds/internal/balancer/lrs/picker.go +++ /dev/null @@ -1,85 +0,0 @@ -/* - * - * Copyright 2020 gRPC authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - */ - -package lrs - -import ( - orcapb "github.com/cncf/udpa/go/udpa/data/orca/v1" - "google.golang.org/grpc/balancer" -) - -const ( - serverLoadCPUName = "cpu_utilization" - serverLoadMemoryName = "mem_utilization" -) - -// loadReporter wraps the methods from the loadStore that are used here. 
-type loadReporter interface { - CallStarted(locality string) - CallFinished(locality string, err error) - CallServerLoad(locality, name string, val float64) -} - -type loadReportPicker struct { - p balancer.Picker - - locality string - loadStore loadReporter -} - -func newLoadReportPicker(p balancer.Picker, id string, loadStore loadReporter) *loadReportPicker { - return &loadReportPicker{ - p: p, - locality: id, - loadStore: loadStore, - } -} - -func (lrp *loadReportPicker) Pick(info balancer.PickInfo) (balancer.PickResult, error) { - res, err := lrp.p.Pick(info) - if err != nil { - return res, err - } - - if lrp.loadStore == nil { - return res, err - } - - lrp.loadStore.CallStarted(lrp.locality) - oldDone := res.Done - res.Done = func(info balancer.DoneInfo) { - if oldDone != nil { - oldDone(info) - } - lrp.loadStore.CallFinished(lrp.locality, info.Err) - - load, ok := info.ServerLoad.(*orcapb.OrcaLoadReport) - if !ok { - return - } - lrp.loadStore.CallServerLoad(lrp.locality, serverLoadCPUName, load.CpuUtilization) - lrp.loadStore.CallServerLoad(lrp.locality, serverLoadMemoryName, load.MemUtilization) - for n, d := range load.RequestCost { - lrp.loadStore.CallServerLoad(lrp.locality, n, d) - } - for n, d := range load.Utilization { - lrp.loadStore.CallServerLoad(lrp.locality, n, d) - } - } - return res, err -} diff --git a/xds/internal/balancer/priority/balancer.go b/xds/internal/balancer/priority/balancer.go index 6c4ff08378e..23e8aa77503 100644 --- a/xds/internal/balancer/priority/balancer.go +++ b/xds/internal/balancer/priority/balancer.go @@ -24,6 +24,7 @@ package priority import ( + "encoding/json" "fmt" "sync" "time" @@ -33,19 +34,22 @@ import ( "google.golang.org/grpc/internal/grpclog" "google.golang.org/grpc/internal/grpcsync" "google.golang.org/grpc/internal/hierarchy" + "google.golang.org/grpc/internal/pretty" "google.golang.org/grpc/resolver" + "google.golang.org/grpc/serviceconfig" "google.golang.org/grpc/xds/internal/balancer/balancergroup" ) -const 
priorityBalancerName = "priority_experimental" +// Name is the name of the priority balancer. +const Name = "priority_experimental" func init() { - balancer.Register(priorityBB{}) + balancer.Register(bb{}) } -type priorityBB struct{} +type bb struct{} -func (priorityBB) Build(cc balancer.ClientConn, bOpts balancer.BuildOptions) balancer.Balancer { +func (bb) Build(cc balancer.ClientConn, bOpts balancer.BuildOptions) balancer.Balancer { b := &priorityBalancer{ cc: cc, done: grpcsync.NewEvent(), @@ -60,11 +64,14 @@ func (priorityBB) Build(cc balancer.ClientConn, bOpts balancer.BuildOptions) bal go b.run() b.logger.Infof("Created") return b +} +func (b bb) ParseConfig(s json.RawMessage) (serviceconfig.LoadBalancingConfig, error) { + return parseConfig(s) } -func (priorityBB) Name() string { - return priorityBalancerName +func (bb) Name() string { + return Name } // timerWrapper wraps a timer with a boolean. So that when a race happens @@ -102,7 +109,8 @@ type priorityBalancer struct { } func (b *priorityBalancer) UpdateClientConnState(s balancer.ClientConnState) error { - newConfig, ok := s.BalancerConfig.(*lbConfig) + b.logger.Infof("Received update from resolver, balancer config: %+v", pretty.ToJSON(s.BalancerConfig)) + newConfig, ok := s.BalancerConfig.(*LBConfig) if !ok { return fmt.Errorf("unexpected balancer config with type: %T", s.BalancerConfig) } @@ -125,7 +133,7 @@ func (b *priorityBalancer) UpdateClientConnState(s balancer.ClientConnState) err // the balancer isn't built, because this child can be a low // priority. If necessary, it will be built when syncing priorities. 
cb := newChildBalancer(name, b, bb) - cb.updateConfig(newSubConfig.Config.Config, resolver.State{ + cb.updateConfig(newSubConfig, resolver.State{ Addresses: addressesSplit[name], ServiceConfig: s.ResolverState.ServiceConfig, Attributes: s.ResolverState.Attributes, @@ -140,13 +148,13 @@ func (b *priorityBalancer) UpdateClientConnState(s balancer.ClientConnState) err // rebuild, rebuild will happen when syncing priorities. if currentChild.bb.Name() != bb.Name() { currentChild.stop() - currentChild.bb = bb + currentChild.updateBuilder(bb) } // Update config and address, but note that this doesn't send the // updates to child balancer (the child balancer might not be built, if // it's a low priority). - currentChild.updateConfig(newSubConfig.Config.Config, resolver.State{ + currentChild.updateConfig(newSubConfig, resolver.State{ Addresses: addressesSplit[name], ServiceConfig: s.ResolverState.ServiceConfig, Attributes: s.ResolverState.Attributes, @@ -193,6 +201,10 @@ func (b *priorityBalancer) Close() { b.stopPriorityInitTimer() } +func (b *priorityBalancer) ExitIdle() { + b.bg.ExitIdle() +} + // stopPriorityInitTimer stops the priorityInitTimer if it's not nil, and set it // to nil. 
// diff --git a/xds/internal/balancer/priority/balancer_child.go b/xds/internal/balancer/priority/balancer_child.go index d012ad4e459..600705da01a 100644 --- a/xds/internal/balancer/priority/balancer_child.go +++ b/xds/internal/balancer/priority/balancer_child.go @@ -29,10 +29,11 @@ import ( type childBalancer struct { name string parent *priorityBalancer - bb balancer.Builder + bb *ignoreResolveNowBalancerBuilder - config serviceconfig.LoadBalancingConfig - rState resolver.State + ignoreReresolutionRequests bool + config serviceconfig.LoadBalancingConfig + rState resolver.State started bool state balancer.State @@ -44,7 +45,7 @@ func newChildBalancer(name string, parent *priorityBalancer, bb balancer.Builder return &childBalancer{ name: name, parent: parent, - bb: bb, + bb: newIgnoreResolveNowBalancerBuilder(bb, false), started: false, // Start with the connecting state and picker with re-pick error, so // that when a priority switch causes this child picked before it's @@ -56,10 +57,16 @@ func newChildBalancer(name string, parent *priorityBalancer, bb balancer.Builder } } +// updateBuilder updates builder for the child, but doesn't build. +func (cb *childBalancer) updateBuilder(bb balancer.Builder) { + cb.bb = newIgnoreResolveNowBalancerBuilder(bb, cb.ignoreReresolutionRequests) +} + // updateConfig sets childBalancer's config and state, but doesn't send update to // the child balancer. -func (cb *childBalancer) updateConfig(config serviceconfig.LoadBalancingConfig, rState resolver.State) { - cb.config = config +func (cb *childBalancer) updateConfig(child *Child, rState resolver.State) { + cb.ignoreReresolutionRequests = child.IgnoreReresolutionRequests + cb.config = child.Config.Config cb.rState = rState } @@ -76,6 +83,7 @@ func (cb *childBalancer) start() { // sendUpdate sends the addresses and config to the child balancer. 
func (cb *childBalancer) sendUpdate() { + cb.bb.updateIgnoreResolveNow(cb.ignoreReresolutionRequests) // TODO: return and aggregate the returned error in the parent. err := cb.parent.bg.UpdateClientConnState(cb.name, balancer.ClientConnState{ ResolverState: cb.rState, diff --git a/xds/internal/balancer/priority/balancer_priority.go b/xds/internal/balancer/priority/balancer_priority.go index ea2f4f04184..bd2c6724ea5 100644 --- a/xds/internal/balancer/priority/balancer_priority.go +++ b/xds/internal/balancer/priority/balancer_priority.go @@ -28,8 +28,12 @@ import ( ) var ( - errAllPrioritiesRemoved = errors.New("no locality is provided, all priorities are removed") - defaultPriorityInitTimeout = 10 * time.Second + // ErrAllPrioritiesRemoved is returned by the picker when there's no priority available. + ErrAllPrioritiesRemoved = errors.New("no priority is provided, all priorities are removed") + // DefaultPriorityInitTimeout is the timeout after which if a priority is + // not READY, the next will be started. It's exported to be overridden by + // tests. + DefaultPriorityInitTimeout = 10 * time.Second ) // syncPriority handles priority after a config update. It makes sure the @@ -70,7 +74,7 @@ func (b *priorityBalancer) syncPriority() { b.stopPriorityInitTimer() b.cc.UpdateState(balancer.State{ ConnectivityState: connectivity.TransientFailure, - Picker: base.NewErrPicker(errAllPrioritiesRemoved), + Picker: base.NewErrPicker(ErrAllPrioritiesRemoved), }) return } @@ -162,7 +166,7 @@ func (b *priorityBalancer) switchToChild(child *childBalancer, priority int) { // to check the stopped boolean. 
timerW := &timerWrapper{} b.priorityInitTimer = timerW - timerW.timer = time.AfterFunc(defaultPriorityInitTimeout, func() { + timerW.timer = time.AfterFunc(DefaultPriorityInitTimeout, func() { b.mu.Lock() defer b.mu.Unlock() if timerW.stopped { @@ -221,14 +225,17 @@ func (b *priorityBalancer) handleChildStateUpdate(childName string, s balancer.S child.state = s switch s.ConnectivityState { - case connectivity.Ready: + case connectivity.Ready, connectivity.Idle: + // Note that idle is also handled as if it's Ready. It will close the + // lower priorities (which will be kept in a cache, not deleted), and + // new picks will use the Idle picker. b.handlePriorityWithNewStateReady(child, priority) case connectivity.TransientFailure: b.handlePriorityWithNewStateTransientFailure(child, priority) case connectivity.Connecting: b.handlePriorityWithNewStateConnecting(child, priority, oldState) default: - // New state is Idle, should never happen. Don't forward. + // New state is Shutdown, should never happen. Don't forward. } } diff --git a/xds/internal/balancer/priority/balancer_test.go b/xds/internal/balancer/priority/balancer_test.go index be14231dcb3..b884035442e 100644 --- a/xds/internal/balancer/priority/balancer_test.go +++ b/xds/internal/balancer/priority/balancer_test.go @@ -19,6 +19,7 @@ package priority import ( + "context" "fmt" "testing" "time" @@ -83,7 +84,7 @@ func subConnFromPicker(t *testing.T, p balancer.Picker) func() balancer.SubConn // Init 0 and 1; 0 is up, use 0; add 2, use 0; remove 2, use 0. 
func (s) TestPriority_HighPriorityReady(t *testing.T) { cc := testutils.NewTestClientConn(t) - bb := balancer.Get(priorityBalancerName) + bb := balancer.Get(Name) pb := bb.Build(cc, balancer.BuildOptions{}) defer pb.Close() @@ -95,10 +96,10 @@ func (s) TestPriority_HighPriorityReady(t *testing.T) { hierarchy.Set(resolver.Address{Addr: testBackendAddrStrs[1]}, []string{"child-1"}), }, }, - BalancerConfig: &lbConfig{ - Children: map[string]*child{ - "child-0": {&internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, - "child-1": {&internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, + BalancerConfig: &LBConfig{ + Children: map[string]*Child{ + "child-0": {Config: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, + "child-1": {Config: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, }, Priorities: []string{"child-0", "child-1"}, }, @@ -132,11 +133,11 @@ func (s) TestPriority_HighPriorityReady(t *testing.T) { hierarchy.Set(resolver.Address{Addr: testBackendAddrStrs[2]}, []string{"child-2"}), }, }, - BalancerConfig: &lbConfig{ - Children: map[string]*child{ - "child-0": {&internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, - "child-1": {&internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, - "child-2": {&internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, + BalancerConfig: &LBConfig{ + Children: map[string]*Child{ + "child-0": {Config: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, + "child-1": {Config: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, + "child-2": {Config: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, }, Priorities: []string{"child-0", "child-1", "child-2"}, }, @@ -162,10 +163,10 @@ func (s) TestPriority_HighPriorityReady(t *testing.T) { hierarchy.Set(resolver.Address{Addr: testBackendAddrStrs[1]}, []string{"child-1"}), }, }, - BalancerConfig: &lbConfig{ - Children: map[string]*child{ - "child-0": 
{&internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, - "child-1": {&internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, + BalancerConfig: &LBConfig{ + Children: map[string]*Child{ + "child-0": {Config: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, + "child-1": {Config: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, }, Priorities: []string{"child-0", "child-1"}, }, @@ -190,7 +191,7 @@ func (s) TestPriority_HighPriorityReady(t *testing.T) { // down, use 2; remove 2, use 1. func (s) TestPriority_SwitchPriority(t *testing.T) { cc := testutils.NewTestClientConn(t) - bb := balancer.Get(priorityBalancerName) + bb := balancer.Get(Name) pb := bb.Build(cc, balancer.BuildOptions{}) defer pb.Close() @@ -202,10 +203,10 @@ func (s) TestPriority_SwitchPriority(t *testing.T) { hierarchy.Set(resolver.Address{Addr: testBackendAddrStrs[1]}, []string{"child-1"}), }, }, - BalancerConfig: &lbConfig{ - Children: map[string]*child{ - "child-0": {&internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, - "child-1": {&internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, + BalancerConfig: &LBConfig{ + Children: map[string]*Child{ + "child-0": {Config: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, + "child-1": {Config: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, }, Priorities: []string{"child-0", "child-1"}, }, @@ -269,11 +270,11 @@ func (s) TestPriority_SwitchPriority(t *testing.T) { hierarchy.Set(resolver.Address{Addr: testBackendAddrStrs[2]}, []string{"child-2"}), }, }, - BalancerConfig: &lbConfig{ - Children: map[string]*child{ - "child-0": {&internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, - "child-1": {&internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, - "child-2": {&internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, + BalancerConfig: &LBConfig{ + Children: map[string]*Child{ + "child-0": {Config: &internalserviceconfig.BalancerConfig{Name: 
roundrobin.Name}}, + "child-1": {Config: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, + "child-2": {Config: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, }, Priorities: []string{"child-0", "child-1", "child-2"}, }, @@ -328,10 +329,10 @@ func (s) TestPriority_SwitchPriority(t *testing.T) { hierarchy.Set(resolver.Address{Addr: testBackendAddrStrs[1]}, []string{"child-1"}), }, }, - BalancerConfig: &lbConfig{ - Children: map[string]*child{ - "child-0": {&internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, - "child-1": {&internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, + BalancerConfig: &LBConfig{ + Children: map[string]*Child{ + "child-0": {Config: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, + "child-1": {Config: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, }, Priorities: []string{"child-0", "child-1"}, }, @@ -373,7 +374,7 @@ func (s) TestPriority_SwitchPriority(t *testing.T) { // use 0. func (s) TestPriority_HighPriorityToConnectingFromReady(t *testing.T) { cc := testutils.NewTestClientConn(t) - bb := balancer.Get(priorityBalancerName) + bb := balancer.Get(Name) pb := bb.Build(cc, balancer.BuildOptions{}) defer pb.Close() @@ -385,10 +386,10 @@ func (s) TestPriority_HighPriorityToConnectingFromReady(t *testing.T) { hierarchy.Set(resolver.Address{Addr: testBackendAddrStrs[1]}, []string{"child-1"}), }, }, - BalancerConfig: &lbConfig{ - Children: map[string]*child{ - "child-0": {&internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, - "child-1": {&internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, + BalancerConfig: &LBConfig{ + Children: map[string]*Child{ + "child-0": {Config: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, + "child-1": {Config: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, }, Priorities: []string{"child-0", "child-1"}, }, @@ -468,7 +469,7 @@ func (s) TestPriority_HighPriorityToConnectingFromReady(t *testing.T) 
{ // Init 0 and 1; 0 and 1 both down; add 2, use 2. func (s) TestPriority_HigherDownWhileAddingLower(t *testing.T) { cc := testutils.NewTestClientConn(t) - bb := balancer.Get(priorityBalancerName) + bb := balancer.Get(Name) pb := bb.Build(cc, balancer.BuildOptions{}) defer pb.Close() @@ -480,10 +481,10 @@ func (s) TestPriority_HigherDownWhileAddingLower(t *testing.T) { hierarchy.Set(resolver.Address{Addr: testBackendAddrStrs[1]}, []string{"child-1"}), }, }, - BalancerConfig: &lbConfig{ - Children: map[string]*child{ - "child-0": {&internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, - "child-1": {&internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, + BalancerConfig: &LBConfig{ + Children: map[string]*Child{ + "child-0": {Config: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, + "child-1": {Config: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, }, Priorities: []string{"child-0", "child-1"}, }, @@ -534,11 +535,11 @@ func (s) TestPriority_HigherDownWhileAddingLower(t *testing.T) { hierarchy.Set(resolver.Address{Addr: testBackendAddrStrs[2]}, []string{"child-2"}), }, }, - BalancerConfig: &lbConfig{ - Children: map[string]*child{ - "child-0": {&internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, - "child-1": {&internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, - "child-2": {&internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, + BalancerConfig: &LBConfig{ + Children: map[string]*Child{ + "child-0": {Config: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, + "child-1": {Config: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, + "child-2": {Config: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, }, Priorities: []string{"child-0", "child-1", "child-2"}, }, @@ -579,7 +580,7 @@ func (s) TestPriority_HigherReadyCloseAllLower(t *testing.T) { // defer time.Sleep(10 * time.Millisecond) cc := testutils.NewTestClientConn(t) - bb := 
balancer.Get(priorityBalancerName) + bb := balancer.Get(Name) pb := bb.Build(cc, balancer.BuildOptions{}) defer pb.Close() @@ -592,11 +593,11 @@ func (s) TestPriority_HigherReadyCloseAllLower(t *testing.T) { hierarchy.Set(resolver.Address{Addr: testBackendAddrStrs[2]}, []string{"child-2"}), }, }, - BalancerConfig: &lbConfig{ - Children: map[string]*child{ - "child-0": {&internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, - "child-1": {&internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, - "child-2": {&internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, + BalancerConfig: &LBConfig{ + Children: map[string]*Child{ + "child-0": {Config: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, + "child-1": {Config: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, + "child-2": {Config: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, }, Priorities: []string{"child-0", "child-1", "child-2"}, }, @@ -687,15 +688,15 @@ func (s) TestPriority_HigherReadyCloseAllLower(t *testing.T) { func (s) TestPriority_InitTimeout(t *testing.T) { const testPriorityInitTimeout = time.Second defer func() func() { - old := defaultPriorityInitTimeout - defaultPriorityInitTimeout = testPriorityInitTimeout + old := DefaultPriorityInitTimeout + DefaultPriorityInitTimeout = testPriorityInitTimeout return func() { - defaultPriorityInitTimeout = old + DefaultPriorityInitTimeout = old } }()() cc := testutils.NewTestClientConn(t) - bb := balancer.Get(priorityBalancerName) + bb := balancer.Get(Name) pb := bb.Build(cc, balancer.BuildOptions{}) defer pb.Close() @@ -707,10 +708,10 @@ func (s) TestPriority_InitTimeout(t *testing.T) { hierarchy.Set(resolver.Address{Addr: testBackendAddrStrs[1]}, []string{"child-1"}), }, }, - BalancerConfig: &lbConfig{ - Children: map[string]*child{ - "child-0": {&internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, - "child-1": {&internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, + 
BalancerConfig: &LBConfig{ + Children: map[string]*Child{ + "child-0": {Config: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, + "child-1": {Config: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, }, Priorities: []string{"child-0", "child-1"}, }, @@ -757,15 +758,15 @@ func (s) TestPriority_InitTimeout(t *testing.T) { func (s) TestPriority_RemovesAllPriorities(t *testing.T) { const testPriorityInitTimeout = time.Second defer func() func() { - old := defaultPriorityInitTimeout - defaultPriorityInitTimeout = testPriorityInitTimeout + old := DefaultPriorityInitTimeout + DefaultPriorityInitTimeout = testPriorityInitTimeout return func() { - defaultPriorityInitTimeout = old + DefaultPriorityInitTimeout = old } }()() cc := testutils.NewTestClientConn(t) - bb := balancer.Get(priorityBalancerName) + bb := balancer.Get(Name) pb := bb.Build(cc, balancer.BuildOptions{}) defer pb.Close() @@ -777,10 +778,10 @@ func (s) TestPriority_RemovesAllPriorities(t *testing.T) { hierarchy.Set(resolver.Address{Addr: testBackendAddrStrs[1]}, []string{"child-1"}), }, }, - BalancerConfig: &lbConfig{ - Children: map[string]*child{ - "child-0": {&internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, - "child-1": {&internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, + BalancerConfig: &LBConfig{ + Children: map[string]*Child{ + "child-0": {Config: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, + "child-1": {Config: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, }, Priorities: []string{"child-0", "child-1"}, }, @@ -808,7 +809,7 @@ func (s) TestPriority_RemovesAllPriorities(t *testing.T) { ResolverState: resolver.State{ Addresses: nil, }, - BalancerConfig: &lbConfig{ + BalancerConfig: &LBConfig{ Children: nil, Priorities: nil, }, @@ -825,8 +826,8 @@ func (s) TestPriority_RemovesAllPriorities(t *testing.T) { // Test pick return TransientFailure. 
pFail := <-cc.NewPickerCh for i := 0; i < 5; i++ { - if _, err := pFail.Pick(balancer.PickInfo{}); err != errAllPrioritiesRemoved { - t.Fatalf("want pick error %v, got %v", errAllPrioritiesRemoved, err) + if _, err := pFail.Pick(balancer.PickInfo{}); err != ErrAllPrioritiesRemoved { + t.Fatalf("want pick error %v, got %v", ErrAllPrioritiesRemoved, err) } } @@ -838,10 +839,10 @@ func (s) TestPriority_RemovesAllPriorities(t *testing.T) { hierarchy.Set(resolver.Address{Addr: testBackendAddrStrs[3]}, []string{"child-1"}), }, }, - BalancerConfig: &lbConfig{ - Children: map[string]*child{ - "child-0": {&internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, - "child-1": {&internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, + BalancerConfig: &LBConfig{ + Children: map[string]*Child{ + "child-0": {Config: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, + "child-1": {Config: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, }, Priorities: []string{"child-0", "child-1"}, }, @@ -882,9 +883,9 @@ func (s) TestPriority_RemovesAllPriorities(t *testing.T) { hierarchy.Set(resolver.Address{Addr: testBackendAddrStrs[2]}, []string{"child-0"}), }, }, - BalancerConfig: &lbConfig{ - Children: map[string]*child{ - "child-0": {&internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, + BalancerConfig: &LBConfig{ + Children: map[string]*Child{ + "child-0": {Config: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, }, Priorities: []string{"child-0"}, }, @@ -933,7 +934,7 @@ func (s) TestPriority_RemovesAllPriorities(t *testing.T) { // will be used. 
func (s) TestPriority_HighPriorityNoEndpoints(t *testing.T) { cc := testutils.NewTestClientConn(t) - bb := balancer.Get(priorityBalancerName) + bb := balancer.Get(Name) pb := bb.Build(cc, balancer.BuildOptions{}) defer pb.Close() @@ -945,10 +946,10 @@ func (s) TestPriority_HighPriorityNoEndpoints(t *testing.T) { hierarchy.Set(resolver.Address{Addr: testBackendAddrStrs[1]}, []string{"child-1"}), }, }, - BalancerConfig: &lbConfig{ - Children: map[string]*child{ - "child-0": {&internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, - "child-1": {&internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, + BalancerConfig: &LBConfig{ + Children: map[string]*Child{ + "child-0": {Config: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, + "child-1": {Config: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, }, Priorities: []string{"child-0", "child-1"}, }, @@ -980,10 +981,10 @@ func (s) TestPriority_HighPriorityNoEndpoints(t *testing.T) { hierarchy.Set(resolver.Address{Addr: testBackendAddrStrs[1]}, []string{"child-1"}), }, }, - BalancerConfig: &lbConfig{ - Children: map[string]*child{ - "child-0": {&internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, - "child-1": {&internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, + BalancerConfig: &LBConfig{ + Children: map[string]*Child{ + "child-0": {Config: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, + "child-1": {Config: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, }, Priorities: []string{"child-0", "child-1"}, }, @@ -1027,12 +1028,12 @@ func (s) TestPriority_HighPriorityNoEndpoints(t *testing.T) { func (s) TestPriority_FirstPriorityUnavailable(t *testing.T) { const testPriorityInitTimeout = time.Second defer func(t time.Duration) { - defaultPriorityInitTimeout = t - }(defaultPriorityInitTimeout) - defaultPriorityInitTimeout = testPriorityInitTimeout + DefaultPriorityInitTimeout = t + }(DefaultPriorityInitTimeout) + 
DefaultPriorityInitTimeout = testPriorityInitTimeout cc := testutils.NewTestClientConn(t) - bb := balancer.Get(priorityBalancerName) + bb := balancer.Get(Name) pb := bb.Build(cc, balancer.BuildOptions{}) defer pb.Close() @@ -1043,9 +1044,9 @@ func (s) TestPriority_FirstPriorityUnavailable(t *testing.T) { hierarchy.Set(resolver.Address{Addr: testBackendAddrStrs[0]}, []string{"child-0"}), }, }, - BalancerConfig: &lbConfig{ - Children: map[string]*child{ - "child-0": {&internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, + BalancerConfig: &LBConfig{ + Children: map[string]*Child{ + "child-0": {Config: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, }, Priorities: []string{"child-0"}, }, @@ -1058,7 +1059,7 @@ func (s) TestPriority_FirstPriorityUnavailable(t *testing.T) { ResolverState: resolver.State{ Addresses: nil, }, - BalancerConfig: &lbConfig{ + BalancerConfig: &LBConfig{ Children: nil, Priorities: nil, }, @@ -1075,7 +1076,7 @@ func (s) TestPriority_FirstPriorityUnavailable(t *testing.T) { // Init a(p0) and b(p1); a(p0) is up, use a; move b to p0, a to p1, use b. 
func (s) TestPriority_MoveChildToHigherPriority(t *testing.T) { cc := testutils.NewTestClientConn(t) - bb := balancer.Get(priorityBalancerName) + bb := balancer.Get(Name) pb := bb.Build(cc, balancer.BuildOptions{}) defer pb.Close() @@ -1087,10 +1088,10 @@ func (s) TestPriority_MoveChildToHigherPriority(t *testing.T) { hierarchy.Set(resolver.Address{Addr: testBackendAddrStrs[1]}, []string{"child-1"}), }, }, - BalancerConfig: &lbConfig{ - Children: map[string]*child{ - "child-0": {&internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, - "child-1": {&internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, + BalancerConfig: &LBConfig{ + Children: map[string]*Child{ + "child-0": {Config: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, + "child-1": {Config: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, }, Priorities: []string{"child-0", "child-1"}, }, @@ -1124,10 +1125,10 @@ func (s) TestPriority_MoveChildToHigherPriority(t *testing.T) { hierarchy.Set(resolver.Address{Addr: testBackendAddrStrs[1]}, []string{"child-1"}), }, }, - BalancerConfig: &lbConfig{ - Children: map[string]*child{ - "child-0": {&internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, - "child-1": {&internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, + BalancerConfig: &LBConfig{ + Children: map[string]*Child{ + "child-0": {Config: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, + "child-1": {Config: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, }, Priorities: []string{"child-1", "child-0"}, }, @@ -1176,7 +1177,7 @@ func (s) TestPriority_MoveChildToHigherPriority(t *testing.T) { // Init a(p0) and b(p1); a(p0) is down, use b; move b to p0, a to p1, use b. 
func (s) TestPriority_MoveReadyChildToHigherPriority(t *testing.T) { cc := testutils.NewTestClientConn(t) - bb := balancer.Get(priorityBalancerName) + bb := balancer.Get(Name) pb := bb.Build(cc, balancer.BuildOptions{}) defer pb.Close() @@ -1188,10 +1189,10 @@ func (s) TestPriority_MoveReadyChildToHigherPriority(t *testing.T) { hierarchy.Set(resolver.Address{Addr: testBackendAddrStrs[1]}, []string{"child-1"}), }, }, - BalancerConfig: &lbConfig{ - Children: map[string]*child{ - "child-0": {&internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, - "child-1": {&internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, + BalancerConfig: &LBConfig{ + Children: map[string]*Child{ + "child-0": {Config: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, + "child-1": {Config: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, }, Priorities: []string{"child-0", "child-1"}, }, @@ -1240,10 +1241,10 @@ func (s) TestPriority_MoveReadyChildToHigherPriority(t *testing.T) { hierarchy.Set(resolver.Address{Addr: testBackendAddrStrs[1]}, []string{"child-1"}), }, }, - BalancerConfig: &lbConfig{ - Children: map[string]*child{ - "child-0": {&internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, - "child-1": {&internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, + BalancerConfig: &LBConfig{ + Children: map[string]*Child{ + "child-0": {Config: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, + "child-1": {Config: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, }, Priorities: []string{"child-1", "child-0"}, }, @@ -1276,7 +1277,7 @@ func (s) TestPriority_MoveReadyChildToHigherPriority(t *testing.T) { // Init a(p0) and b(p1); a(p0) is down, use b; move b to p0, a to p1, use b. 
func (s) TestPriority_RemoveReadyLowestChild(t *testing.T) { cc := testutils.NewTestClientConn(t) - bb := balancer.Get(priorityBalancerName) + bb := balancer.Get(Name) pb := bb.Build(cc, balancer.BuildOptions{}) defer pb.Close() @@ -1288,10 +1289,10 @@ func (s) TestPriority_RemoveReadyLowestChild(t *testing.T) { hierarchy.Set(resolver.Address{Addr: testBackendAddrStrs[1]}, []string{"child-1"}), }, }, - BalancerConfig: &lbConfig{ - Children: map[string]*child{ - "child-0": {&internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, - "child-1": {&internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, + BalancerConfig: &LBConfig{ + Children: map[string]*Child{ + "child-0": {Config: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, + "child-1": {Config: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, }, Priorities: []string{"child-0", "child-1"}, }, @@ -1338,9 +1339,9 @@ func (s) TestPriority_RemoveReadyLowestChild(t *testing.T) { hierarchy.Set(resolver.Address{Addr: testBackendAddrStrs[0]}, []string{"child-0"}), }, }, - BalancerConfig: &lbConfig{ - Children: map[string]*child{ - "child-0": {&internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, + BalancerConfig: &LBConfig{ + Children: map[string]*Child{ + "child-0": {Config: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, }, Priorities: []string{"child-0"}, }, @@ -1384,7 +1385,7 @@ func (s) TestPriority_ReadyChildRemovedButInCache(t *testing.T) { }()() cc := testutils.NewTestClientConn(t) - bb := balancer.Get(priorityBalancerName) + bb := balancer.Get(Name) pb := bb.Build(cc, balancer.BuildOptions{}) defer pb.Close() @@ -1395,9 +1396,9 @@ func (s) TestPriority_ReadyChildRemovedButInCache(t *testing.T) { hierarchy.Set(resolver.Address{Addr: testBackendAddrStrs[0]}, []string{"child-0"}), }, }, - BalancerConfig: &lbConfig{ - Children: map[string]*child{ - "child-0": {&internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, + BalancerConfig: &LBConfig{ + 
Children: map[string]*Child{ + "child-0": {Config: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, }, Priorities: []string{"child-0"}, }, @@ -1426,15 +1427,15 @@ func (s) TestPriority_ReadyChildRemovedButInCache(t *testing.T) { // be different. if err := pb.UpdateClientConnState(balancer.ClientConnState{ ResolverState: resolver.State{}, - BalancerConfig: &lbConfig{}, + BalancerConfig: &LBConfig{}, }); err != nil { t.Fatalf("failed to update ClientConn state: %v", err) } pFail := <-cc.NewPickerCh for i := 0; i < 5; i++ { - if _, err := pFail.Pick(balancer.PickInfo{}); err != errAllPrioritiesRemoved { - t.Fatalf("want pick error %v, got %v", errAllPrioritiesRemoved, err) + if _, err := pFail.Pick(balancer.PickInfo{}); err != ErrAllPrioritiesRemoved { + t.Fatalf("want pick error %v, got %v", ErrAllPrioritiesRemoved, err) } } @@ -1454,9 +1455,9 @@ func (s) TestPriority_ReadyChildRemovedButInCache(t *testing.T) { hierarchy.Set(resolver.Address{Addr: testBackendAddrStrs[0]}, []string{"child-0"}), }, }, - BalancerConfig: &lbConfig{ - Children: map[string]*child{ - "child-0": {&internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, + BalancerConfig: &LBConfig{ + Children: map[string]*Child{ + "child-0": {Config: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, }, Priorities: []string{"child-0"}, }, @@ -1487,7 +1488,7 @@ func (s) TestPriority_ReadyChildRemovedButInCache(t *testing.T) { // Init 0; 0 is up, use 0; change 0's policy, 0 is used. 
func (s) TestPriority_ChildPolicyChange(t *testing.T) { cc := testutils.NewTestClientConn(t) - bb := balancer.Get(priorityBalancerName) + bb := balancer.Get(Name) pb := bb.Build(cc, balancer.BuildOptions{}) defer pb.Close() @@ -1498,9 +1499,9 @@ func (s) TestPriority_ChildPolicyChange(t *testing.T) { hierarchy.Set(resolver.Address{Addr: testBackendAddrStrs[0]}, []string{"child-0"}), }, }, - BalancerConfig: &lbConfig{ - Children: map[string]*child{ - "child-0": {&internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, + BalancerConfig: &LBConfig{ + Children: map[string]*Child{ + "child-0": {Config: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, }, Priorities: []string{"child-0"}, }, @@ -1533,9 +1534,9 @@ func (s) TestPriority_ChildPolicyChange(t *testing.T) { hierarchy.Set(resolver.Address{Addr: testBackendAddrStrs[0]}, []string{"child-0"}), }, }, - BalancerConfig: &lbConfig{ - Children: map[string]*child{ - "child-0": {&internalserviceconfig.BalancerConfig{Name: testRRBalancerName}}, + BalancerConfig: &LBConfig{ + Children: map[string]*Child{ + "child-0": {Config: &internalserviceconfig.BalancerConfig{Name: testRRBalancerName}}, }, Priorities: []string{"child-0"}, }, @@ -1587,7 +1588,7 @@ func init() { // by acquiring a locked mutex. 
func (s) TestPriority_ChildPolicyUpdatePickerInline(t *testing.T) { cc := testutils.NewTestClientConn(t) - bb := balancer.Get(priorityBalancerName) + bb := balancer.Get(Name) pb := bb.Build(cc, balancer.BuildOptions{}) defer pb.Close() @@ -1598,9 +1599,9 @@ func (s) TestPriority_ChildPolicyUpdatePickerInline(t *testing.T) { hierarchy.Set(resolver.Address{Addr: testBackendAddrStrs[0]}, []string{"child-0"}), }, }, - BalancerConfig: &lbConfig{ - Children: map[string]*child{ - "child-0": {&internalserviceconfig.BalancerConfig{Name: inlineUpdateBalancerName}}, + BalancerConfig: &LBConfig{ + Children: map[string]*Child{ + "child-0": {Config: &internalserviceconfig.BalancerConfig{Name: inlineUpdateBalancerName}}, }, Priorities: []string{"child-0"}, }, @@ -1616,3 +1617,260 @@ func (s) TestPriority_ChildPolicyUpdatePickerInline(t *testing.T) { } } } + +// When the child policy's configured to ignore reresolution requests, the +// ResolveNow() calls from this child should be all ignored. +func (s) TestPriority_IgnoreReresolutionRequest(t *testing.T) { + cc := testutils.NewTestClientConn(t) + bb := balancer.Get(Name) + pb := bb.Build(cc, balancer.BuildOptions{}) + defer pb.Close() + + // One child, with priority [0], with one backend, reresolution is + ignored. 
+ if err := pb.UpdateClientConnState(balancer.ClientConnState{ + ResolverState: resolver.State{ + Addresses: []resolver.Address{ + hierarchy.Set(resolver.Address{Addr: testBackendAddrStrs[0]}, []string{"child-0"}), + }, + }, + BalancerConfig: &LBConfig{ + Children: map[string]*Child{ + "child-0": { + Config: &internalserviceconfig.BalancerConfig{Name: resolveNowBalancerName}, + IgnoreReresolutionRequests: true, + }, + }, + Priorities: []string{"child-0"}, + }, + }); err != nil { + t.Fatalf("failed to update ClientConn state: %v", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + // This is the balancer.ClientConn that the inner resolveNowBalancer is + built with. + balancerCCI, err := resolveNowBalancerCCCh.Receive(ctx) + if err != nil { + t.Fatalf("timeout waiting for ClientConn from balancer builder") + } + balancerCC := balancerCCI.(balancer.ClientConn) + + // Since IgnoreReresolutionRequests was set to true, all ResolveNow() calls + // should be ignored. + for i := 0; i < 5; i++ { + balancerCC.ResolveNow(resolver.ResolveNowOptions{}) + } + select { + case <-cc.ResolveNowCh: + t.Fatalf("got unexpected ResolveNow() call") + case <-time.After(time.Millisecond * 100): + } + + // Send another update to set IgnoreReresolutionRequests to false. + if err := pb.UpdateClientConnState(balancer.ClientConnState{ + ResolverState: resolver.State{ + Addresses: []resolver.Address{ + hierarchy.Set(resolver.Address{Addr: testBackendAddrStrs[0]}, []string{"child-0"}), + }, + }, + BalancerConfig: &LBConfig{ + Children: map[string]*Child{ + "child-0": { + Config: &internalserviceconfig.BalancerConfig{Name: resolveNowBalancerName}, + IgnoreReresolutionRequests: false, + }, + }, + Priorities: []string{"child-0"}, + }, + }); err != nil { + t.Fatalf("failed to update ClientConn state: %v", err) + } + + // Call ResolveNow() on the CC, it should be forwarded. 
+ balancerCC.ResolveNow(resolver.ResolveNowOptions{}) + select { + case <-cc.ResolveNowCh: + case <-time.After(time.Second): + t.Fatalf("timeout waiting for ResolveNow()") + } + +} + +// When the child policy's configured to ignore reresolution requests, the +// ResolveNow() calls from this child should be all ignored, while those from +// the other children are forwarded. +func (s) TestPriority_IgnoreReresolutionRequestTwoChildren(t *testing.T) { + cc := testutils.NewTestClientConn(t) + bb := balancer.Get(Name) + pb := bb.Build(cc, balancer.BuildOptions{}) + defer pb.Close() + + // Two children, with priorities [0, 1], each with one backend. + // Reresolution is ignored for p0. + if err := pb.UpdateClientConnState(balancer.ClientConnState{ + ResolverState: resolver.State{ + Addresses: []resolver.Address{ + hierarchy.Set(resolver.Address{Addr: testBackendAddrStrs[0]}, []string{"child-0"}), + hierarchy.Set(resolver.Address{Addr: testBackendAddrStrs[1]}, []string{"child-1"}), + }, + }, + BalancerConfig: &LBConfig{ + Children: map[string]*Child{ + "child-0": { + Config: &internalserviceconfig.BalancerConfig{Name: resolveNowBalancerName}, + IgnoreReresolutionRequests: true, + }, + "child-1": { + Config: &internalserviceconfig.BalancerConfig{Name: resolveNowBalancerName}, + }, + }, + Priorities: []string{"child-0", "child-1"}, + }, + }); err != nil { + t.Fatalf("failed to update ClientConn state: %v", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + // This is the balancer.ClientConn from p0. + balancerCCI0, err := resolveNowBalancerCCCh.Receive(ctx) + if err != nil { + t.Fatalf("timeout waiting for ClientConn from balancer builder 0") + } + balancerCC0 := balancerCCI0.(balancer.ClientConn) + + // Set p0 to transient failure, p1 will be started. 
+ addrs0 := <-cc.NewSubConnAddrsCh + if got, want := addrs0[0].Addr, testBackendAddrStrs[0]; got != want { + t.Fatalf("sc is created with addr %v, want %v", got, want) + } + sc0 := <-cc.NewSubConnCh + pb.UpdateSubConnState(sc0, balancer.SubConnState{ConnectivityState: connectivity.TransientFailure}) + + // This is the balancer.ClientConn from p1. + ctx1, cancel1 := context.WithTimeout(context.Background(), time.Second) + defer cancel1() + balancerCCI1, err := resolveNowBalancerCCCh.Receive(ctx1) + if err != nil { + t.Fatalf("timeout waiting for ClientConn from balancer builder 1") + } + balancerCC1 := balancerCCI1.(balancer.ClientConn) + + // Since IgnoreReresolutionRequests was set to true for p0, ResolveNow() + // from p0 should all be ignored. + for i := 0; i < 5; i++ { + balancerCC0.ResolveNow(resolver.ResolveNowOptions{}) + } + select { + case <-cc.ResolveNowCh: + t.Fatalf("got unexpected ResolveNow() call") + case <-time.After(time.Millisecond * 100): + } + + // But IgnoreReresolutionRequests was false for p1, ResolveNow() from p1 + // should be forwarded. 
+ balancerCC1.ResolveNow(resolver.ResolveNowOptions{}) + select { + case <-cc.ResolveNowCh: + case <-time.After(time.Second): + t.Fatalf("timeout waiting for ResolveNow()") + } +} + +const initIdleBalancerName = "test-init-Idle-balancer" + +var errsTestInitIdle = []error{ + fmt.Errorf("init Idle balancer error 0"), + fmt.Errorf("init Idle balancer error 1"), +} + +func init() { + for i := 0; i < 2; i++ { + ii := i + stub.Register(fmt.Sprintf("%s-%d", initIdleBalancerName, ii), stub.BalancerFuncs{ + UpdateClientConnState: func(bd *stub.BalancerData, opts balancer.ClientConnState) error { + bd.ClientConn.NewSubConn(opts.ResolverState.Addresses, balancer.NewSubConnOptions{}) + return nil + }, + UpdateSubConnState: func(bd *stub.BalancerData, sc balancer.SubConn, state balancer.SubConnState) { + err := fmt.Errorf("wrong picker error") + if state.ConnectivityState == connectivity.Idle { + err = errsTestInitIdle[ii] + } + bd.ClientConn.UpdateState(balancer.State{ + ConnectivityState: state.ConnectivityState, + Picker: &testutils.TestConstPicker{Err: err}, + }) + }, + }) + } +} + +// If the high priorities send initial pickers with Idle state, their pickers +// should get picks, because policies like ringhash starts in Idle, and doesn't +// connect. +// +// Init 0, 1; 0 is Idle, use 0; 0 is down, start 1; 1 is Idle, use 1. +func (s) TestPriority_HighPriorityInitIdle(t *testing.T) { + cc := testutils.NewTestClientConn(t) + bb := balancer.Get(Name) + pb := bb.Build(cc, balancer.BuildOptions{}) + defer pb.Close() + + // Two children, with priorities [0, 1], each with one backend. 
+ if err := pb.UpdateClientConnState(balancer.ClientConnState{ + ResolverState: resolver.State{ + Addresses: []resolver.Address{ + hierarchy.Set(resolver.Address{Addr: testBackendAddrStrs[0]}, []string{"child-0"}), + hierarchy.Set(resolver.Address{Addr: testBackendAddrStrs[1]}, []string{"child-1"}), + }, + }, + BalancerConfig: &LBConfig{ + Children: map[string]*Child{ + "child-0": {Config: &internalserviceconfig.BalancerConfig{Name: fmt.Sprintf("%s-%d", initIdleBalancerName, 0)}}, + "child-1": {Config: &internalserviceconfig.BalancerConfig{Name: fmt.Sprintf("%s-%d", initIdleBalancerName, 1)}}, + }, + Priorities: []string{"child-0", "child-1"}, + }, + }); err != nil { + t.Fatalf("failed to update ClientConn state: %v", err) + } + + addrs0 := <-cc.NewSubConnAddrsCh + if got, want := addrs0[0].Addr, testBackendAddrStrs[0]; got != want { + t.Fatalf("sc is created with addr %v, want %v", got, want) + } + sc0 := <-cc.NewSubConnCh + + // Send an Idle state update to trigger an Idle picker update. + pb.UpdateSubConnState(sc0, balancer.SubConnState{ConnectivityState: connectivity.Idle}) + p0 := <-cc.NewPickerCh + if pr, err := p0.Pick(balancer.PickInfo{}); err != errsTestInitIdle[0] { + t.Fatalf("pick returned %v, %v, want _, %v", pr, err, errsTestInitIdle[0]) + } + + // Turn p0 down, to start p1. + pb.UpdateSubConnState(sc0, balancer.SubConnState{ConnectivityState: connectivity.TransientFailure}) + // Before 1 gets READY, picker should return NoSubConnAvailable, so RPCs + // will retry. + p1 := <-cc.NewPickerCh + for i := 0; i < 5; i++ { + if _, err := p1.Pick(balancer.PickInfo{}); err != balancer.ErrNoSubConnAvailable { + t.Fatalf("want pick error %v, got %v", balancer.ErrNoSubConnAvailable, err) + } + } + + addrs1 := <-cc.NewSubConnAddrsCh + if got, want := addrs1[0].Addr, testBackendAddrStrs[1]; got != want { + t.Fatalf("sc is created with addr %v, want %v", got, want) + } + sc1 := <-cc.NewSubConnCh + // Idle picker from p1 should also be forwarded. 
+ pb.UpdateSubConnState(sc1, balancer.SubConnState{ConnectivityState: connectivity.Idle}) + p2 := <-cc.NewPickerCh + if pr, err := p2.Pick(balancer.PickInfo{}); err != errsTestInitIdle[1] { + t.Fatalf("pick returned %v, %v, want _, %v", pr, err, errsTestInitIdle[1]) + } +} diff --git a/xds/internal/balancer/priority/config.go b/xds/internal/balancer/priority/config.go index da085908c71..37f1c9a829a 100644 --- a/xds/internal/balancer/priority/config.go +++ b/xds/internal/balancer/priority/config.go @@ -26,24 +26,27 @@ import ( "google.golang.org/grpc/serviceconfig" ) -type child struct { - Config *internalserviceconfig.BalancerConfig +// Child is a child of priority balancer. +type Child struct { + Config *internalserviceconfig.BalancerConfig `json:"config,omitempty"` + IgnoreReresolutionRequests bool `json:"ignoreReresolutionRequests,omitempty"` } -type lbConfig struct { - serviceconfig.LoadBalancingConfig +// LBConfig represents priority balancer's config. +type LBConfig struct { + serviceconfig.LoadBalancingConfig `json:"-"` // Children is a map from the child balancer names to their configs. Child // names can be found in field Priorities. - Children map[string]*child + Children map[string]*Child `json:"children,omitempty"` // Priorities is a list of child balancer names. They are sorted from // highest priority to low. The type/config for each child can be found in // field Children, with the balancer name as the key. 
- Priorities []string + Priorities []string `json:"priorities,omitempty"` } -func parseConfig(c json.RawMessage) (*lbConfig, error) { - var cfg lbConfig +func parseConfig(c json.RawMessage) (*LBConfig, error) { + var cfg LBConfig if err := json.Unmarshal(c, &cfg); err != nil { return nil, err } diff --git a/xds/internal/balancer/priority/config_test.go b/xds/internal/balancer/priority/config_test.go index 15c4069dd1e..8316224c91b 100644 --- a/xds/internal/balancer/priority/config_test.go +++ b/xds/internal/balancer/priority/config_test.go @@ -30,7 +30,7 @@ func TestParseConfig(t *testing.T) { tests := []struct { name string js string - want *lbConfig + want *LBConfig wantErr bool }{ { @@ -63,26 +63,27 @@ func TestParseConfig(t *testing.T) { js: `{ "priorities": ["child-1", "child-2", "child-3"], "children": { - "child-1": {"config": [{"round_robin":{}}]}, + "child-1": {"config": [{"round_robin":{}}], "ignoreReresolutionRequests": true}, "child-2": {"config": [{"round_robin":{}}]}, "child-3": {"config": [{"round_robin":{}}]} } } `, - want: &lbConfig{ - Children: map[string]*child{ + want: &LBConfig{ + Children: map[string]*Child{ "child-1": { - &internalserviceconfig.BalancerConfig{ + Config: &internalserviceconfig.BalancerConfig{ Name: roundrobin.Name, }, + IgnoreReresolutionRequests: true, }, "child-2": { - &internalserviceconfig.BalancerConfig{ + Config: &internalserviceconfig.BalancerConfig{ Name: roundrobin.Name, }, }, "child-3": { - &internalserviceconfig.BalancerConfig{ + Config: &internalserviceconfig.BalancerConfig{ Name: roundrobin.Name, }, }, diff --git a/xds/internal/balancer/priority/ignore_resolve_now.go b/xds/internal/balancer/priority/ignore_resolve_now.go new file mode 100644 index 00000000000..9a9f4777269 --- /dev/null +++ b/xds/internal/balancer/priority/ignore_resolve_now.go @@ -0,0 +1,73 @@ +/* + * + * Copyright 2021 gRPC authors. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package priority + +import ( + "sync/atomic" + + "google.golang.org/grpc/balancer" + "google.golang.org/grpc/resolver" +) + +type ignoreResolveNowBalancerBuilder struct { + balancer.Builder + ignoreResolveNow *uint32 +} + +// If `ignore` is true, all `ResolveNow()` from the balancer built from this +// builder will be ignored. +// +// `ignore` can be updated later by `updateIgnoreResolveNow`, and the update +// will be propagated to all the old and new balancers built with this. 
+func newIgnoreResolveNowBalancerBuilder(bb balancer.Builder, ignore bool) *ignoreResolveNowBalancerBuilder { + ret := &ignoreResolveNowBalancerBuilder{ + Builder: bb, + ignoreResolveNow: new(uint32), + } + ret.updateIgnoreResolveNow(ignore) + return ret +} + +func (irnbb *ignoreResolveNowBalancerBuilder) updateIgnoreResolveNow(b bool) { + if b { + atomic.StoreUint32(irnbb.ignoreResolveNow, 1) + return + } + atomic.StoreUint32(irnbb.ignoreResolveNow, 0) + +} + +func (irnbb *ignoreResolveNowBalancerBuilder) Build(cc balancer.ClientConn, opts balancer.BuildOptions) balancer.Balancer { + return irnbb.Builder.Build(&ignoreResolveNowClientConn{ + ClientConn: cc, + ignoreResolveNow: irnbb.ignoreResolveNow, + }, opts) +} + +type ignoreResolveNowClientConn struct { + balancer.ClientConn + ignoreResolveNow *uint32 +} + +func (i ignoreResolveNowClientConn) ResolveNow(o resolver.ResolveNowOptions) { + if atomic.LoadUint32(i.ignoreResolveNow) != 0 { + return + } + i.ClientConn.ResolveNow(o) +} diff --git a/xds/internal/balancer/priority/ignore_resolve_now_test.go b/xds/internal/balancer/priority/ignore_resolve_now_test.go new file mode 100644 index 00000000000..b7cecd6c1ff --- /dev/null +++ b/xds/internal/balancer/priority/ignore_resolve_now_test.go @@ -0,0 +1,104 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +package priority + +import ( + "context" + "testing" + "time" + + "google.golang.org/grpc/balancer" + "google.golang.org/grpc/balancer/roundrobin" + grpctestutils "google.golang.org/grpc/internal/testutils" + "google.golang.org/grpc/resolver" + "google.golang.org/grpc/xds/internal/testutils" +) + +const resolveNowBalancerName = "test-resolve-now-balancer" + +var resolveNowBalancerCCCh = grpctestutils.NewChannel() + +type resolveNowBalancerBuilder struct { + balancer.Builder +} + +func (r *resolveNowBalancerBuilder) Build(cc balancer.ClientConn, opts balancer.BuildOptions) balancer.Balancer { + resolveNowBalancerCCCh.Send(cc) + return r.Builder.Build(cc, opts) +} + +func (r *resolveNowBalancerBuilder) Name() string { + return resolveNowBalancerName +} + +func init() { + balancer.Register(&resolveNowBalancerBuilder{ + Builder: balancer.Get(roundrobin.Name), + }) +} + +func (s) TestIgnoreResolveNowBalancerBuilder(t *testing.T) { + resolveNowBB := balancer.Get(resolveNowBalancerName) + // Create a build wrapper, but will not ignore ResolveNow(). + ignoreResolveNowBB := newIgnoreResolveNowBalancerBuilder(resolveNowBB, false) + + cc := testutils.NewTestClientConn(t) + tb := ignoreResolveNowBB.Build(cc, balancer.BuildOptions{}) + defer tb.Close() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + // This is the balancer.ClientConn that the inner resolveNowBalancer is + built with. + balancerCCI, err := resolveNowBalancerCCCh.Receive(ctx) + if err != nil { + t.Fatalf("timeout waiting for ClientConn from balancer builder") + } + balancerCC := balancerCCI.(balancer.ClientConn) + + // Call ResolveNow() on the CC, it should be forwarded. + balancerCC.ResolveNow(resolver.ResolveNowOptions{}) + select { + case <-cc.ResolveNowCh: + case <-time.After(time.Second): + t.Fatalf("timeout waiting for ResolveNow()") + } + + // Update ignoreResolveNow to true, call ResolveNow() on the CC, they should + // all be ignored. 
+ ignoreResolveNowBB.updateIgnoreResolveNow(true) + for i := 0; i < 5; i++ { + balancerCC.ResolveNow(resolver.ResolveNowOptions{}) + } + select { + case <-cc.ResolveNowCh: + t.Fatalf("got unexpected ResolveNow() call") + case <-time.After(time.Millisecond * 100): + } + + // Update ignoreResolveNow to false, new ResolveNow() calls should be + // forwarded. + ignoreResolveNowBB.updateIgnoreResolveNow(false) + balancerCC.ResolveNow(resolver.ResolveNowOptions{}) + select { + case <-cc.ResolveNowCh: + case <-time.After(time.Second): + t.Fatalf("timeout waiting for ResolveNow()") + } +} diff --git a/xds/internal/balancer/ringhash/config.go b/xds/internal/balancer/ringhash/config.go new file mode 100644 index 00000000000..5cb4aab3d9c --- /dev/null +++ b/xds/internal/balancer/ringhash/config.go @@ -0,0 +1,56 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package ringhash + +import ( + "encoding/json" + "fmt" + + "google.golang.org/grpc/serviceconfig" +) + +// LBConfig is the balancer config for ring_hash balancer. 
+type LBConfig struct { + serviceconfig.LoadBalancingConfig `json:"-"` + + MinRingSize uint64 `json:"minRingSize,omitempty"` + MaxRingSize uint64 `json:"maxRingSize,omitempty"` +} + +const ( + defaultMinSize = 1024 + defaultMaxSize = 8 * 1024 * 1024 // 8M +) + +func parseConfig(c json.RawMessage) (*LBConfig, error) { + var cfg LBConfig + if err := json.Unmarshal(c, &cfg); err != nil { + return nil, err + } + if cfg.MinRingSize == 0 { + cfg.MinRingSize = defaultMinSize + } + if cfg.MaxRingSize == 0 { + cfg.MaxRingSize = defaultMaxSize + } + if cfg.MinRingSize > cfg.MaxRingSize { + return nil, fmt.Errorf("min %v is greater than max %v", cfg.MinRingSize, cfg.MaxRingSize) + } + return &cfg, nil +} diff --git a/xds/internal/balancer/ringhash/config_test.go b/xds/internal/balancer/ringhash/config_test.go new file mode 100644 index 00000000000..a2a966dc318 --- /dev/null +++ b/xds/internal/balancer/ringhash/config_test.go @@ -0,0 +1,68 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +package ringhash + +import ( + "testing" + + "github.com/google/go-cmp/cmp" +) + +func TestParseConfig(t *testing.T) { + tests := []struct { + name string + js string + want *LBConfig + wantErr bool + }{ + { + name: "OK", + js: `{"minRingSize": 1, "maxRingSize": 2}`, + want: &LBConfig{MinRingSize: 1, MaxRingSize: 2}, + }, + { + name: "OK with default min", + js: `{"maxRingSize": 2000}`, + want: &LBConfig{MinRingSize: defaultMinSize, MaxRingSize: 2000}, + }, + { + name: "OK with default max", + js: `{"minRingSize": 2000}`, + want: &LBConfig{MinRingSize: 2000, MaxRingSize: defaultMaxSize}, + }, + { + name: "min greater than max", + js: `{"minRingSize": 10, "maxRingSize": 2}`, + want: nil, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := parseConfig([]byte(tt.js)) + if (err != nil) != tt.wantErr { + t.Errorf("parseConfig() error = %v, wantErr %v", err, tt.wantErr) + return + } + if diff := cmp.Diff(got, tt.want); diff != "" { + t.Errorf("parseConfig() got unexpected output, diff (-got +want): %v", diff) + } + }) + } +} diff --git a/xds/internal/balancer/edsbalancer/logging.go b/xds/internal/balancer/ringhash/logging.go similarity index 83% rename from xds/internal/balancer/edsbalancer/logging.go rename to xds/internal/balancer/ringhash/logging.go index be4d0a512d1..64a1d467f55 100644 --- a/xds/internal/balancer/edsbalancer/logging.go +++ b/xds/internal/balancer/ringhash/logging.go @@ -1,6 +1,6 @@ /* * - * Copyright 2020 gRPC authors. + * Copyright 2021 gRPC authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -16,7 +16,7 @@ * */ -package edsbalancer +package ringhash import ( "fmt" @@ -25,10 +25,10 @@ import ( internalgrpclog "google.golang.org/grpc/internal/grpclog" ) -const prefix = "[eds-lb %p] " +const prefix = "[ring-hash-lb %p] " var logger = grpclog.Component("xds") -func prefixLogger(p *edsBalancer) *internalgrpclog.PrefixLogger { +func prefixLogger(p *ringhashBalancer) *internalgrpclog.PrefixLogger { return internalgrpclog.NewPrefixLogger(logger, fmt.Sprintf(prefix, p)) } diff --git a/xds/internal/balancer/ringhash/picker.go b/xds/internal/balancer/ringhash/picker.go new file mode 100644 index 00000000000..dcea6d46e51 --- /dev/null +++ b/xds/internal/balancer/ringhash/picker.go @@ -0,0 +1,154 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package ringhash + +import ( + "fmt" + + "google.golang.org/grpc/balancer" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/connectivity" + "google.golang.org/grpc/internal/grpclog" + "google.golang.org/grpc/status" +) + +type picker struct { + ring *ring + logger *grpclog.PrefixLogger +} + +func newPicker(ring *ring, logger *grpclog.PrefixLogger) *picker { + return &picker{ring: ring, logger: logger} +} + +// handleRICSResult is the return type of handleRICS. It's needed to wrap the +// returned error from Pick() in a struct. 
With this, if the return values are +// `balancer.PickResult, error, bool`, linter complains because error is not the +// last return value. +type handleRICSResult struct { + pr balancer.PickResult + err error +} + +// handleRICS generates pick result if the entry is in Ready, Idle, Connecting +// or Shutdown. TransientFailure will be handled specifically after this +// function returns. +// +// The first return value indicates if the state is in Ready, Idle, Connecting +// or Shutdown. If it's true, the PickResult and error should be returned from +// Pick() as is. +func (p *picker) handleRICS(e *ringEntry) (handleRICSResult, bool) { + switch state := e.sc.effectiveState(); state { + case connectivity.Ready: + return handleRICSResult{pr: balancer.PickResult{SubConn: e.sc.sc}}, true + case connectivity.Idle: + // Trigger Connect() and queue the pick. + e.sc.queueConnect() + return handleRICSResult{err: balancer.ErrNoSubConnAvailable}, true + case connectivity.Connecting: + return handleRICSResult{err: balancer.ErrNoSubConnAvailable}, true + case connectivity.TransientFailure: + // Return ok==false, so TransientFailure will be handled afterwards. + return handleRICSResult{}, false + case connectivity.Shutdown: + // Shutdown can happen in a race where the old picker is called. A new + // picker should already be sent. + return handleRICSResult{err: balancer.ErrNoSubConnAvailable}, true + default: + // Should never reach this. All the connectivity states are already + // handled in the cases. + p.logger.Errorf("SubConn has undefined connectivity state: %v", state) + return handleRICSResult{err: status.Errorf(codes.Unavailable, "SubConn has undefined connectivity state: %v", state)}, true + } +} + +func (p *picker) Pick(info balancer.PickInfo) (balancer.PickResult, error) { + e := p.ring.pick(getRequestHash(info.Ctx)) + if hr, ok := p.handleRICS(e); ok { + return hr.pr, hr.err + } + // ok was false, the entry is in transient failure. 
+ return p.handleTransientFailure(e) +} + +func (p *picker) handleTransientFailure(e *ringEntry) (balancer.PickResult, error) { + // Queue a connect on the first picked SubConn. + e.sc.queueConnect() + + // Find next entry in the ring, skipping duplicate SubConns. + e2 := nextSkippingDuplicates(p.ring, e) + if e2 == nil { + // There's no next entry available, fail the pick. + return balancer.PickResult{}, fmt.Errorf("the only SubConn is in Transient Failure") + } + + // For the second SubConn, also check Ready/Idle/Connecting as if it's the + // first entry. + if hr, ok := p.handleRICS(e2); ok { + return hr.pr, hr.err + } + + // The second SubConn is also in TransientFailure. Queue a connect on it. + e2.sc.queueConnect() + + // If it gets here, this is after the second SubConn, and the second SubConn + // was in TransientFailure. + // + // Loop over all other SubConns: + // - If all SubConns so far are all TransientFailure, trigger Connect() on + // the TransientFailure SubConns, and keep going. + // - If there's one SubConn that's not in TransientFailure, keep checking + // the remaining SubConns (in case there's a Ready, which will be returned), + // but do not trigger Connect() on the other SubConns. + var firstNonFailedFound bool + for ee := nextSkippingDuplicates(p.ring, e2); ee != e; ee = nextSkippingDuplicates(p.ring, ee) { + scState := ee.sc.effectiveState() + if scState == connectivity.Ready { + return balancer.PickResult{SubConn: ee.sc.sc}, nil + } + if firstNonFailedFound { + continue + } + if scState == connectivity.TransientFailure { + // This will queue a connect. + ee.sc.queueConnect() + continue + } + // This is a SubConn in a non-failure state. We continue to check the + // other SubConns, but remember that there was a non-failed SubConn + // seen. After this, Pick() will never trigger any SubConn to Connect(). 
+ firstNonFailedFound = true + if scState == connectivity.Idle { + // This is the first non-failed SubConn, and it is in a real Idle + // state. Trigger it to Connect(). + ee.sc.queueConnect() + } + } + return balancer.PickResult{}, fmt.Errorf("no connection is Ready") +} + +func nextSkippingDuplicates(ring *ring, entry *ringEntry) *ringEntry { + for next := ring.next(entry); next != entry; next = ring.next(next) { + if next.sc != entry.sc { + return next + } + } + // There's no qualifying next entry. + return nil +} diff --git a/xds/internal/balancer/ringhash/picker_test.go b/xds/internal/balancer/ringhash/picker_test.go new file mode 100644 index 00000000000..c88698ebbdf --- /dev/null +++ b/xds/internal/balancer/ringhash/picker_test.go @@ -0,0 +1,285 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +package ringhash + +import ( + "context" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "google.golang.org/grpc/balancer" + "google.golang.org/grpc/connectivity" + "google.golang.org/grpc/xds/internal/testutils" +) + +func newTestRing(cStats []connectivity.State) *ring { + var items []*ringEntry + for i, st := range cStats { + testSC := testutils.TestSubConns[i] + items = append(items, &ringEntry{ + idx: i, + hash: uint64((i + 1) * 10), + sc: &subConn{ + addr: testSC.String(), + sc: testSC, + state: st, + }, + }) + } + return &ring{items: items} +} + +func TestPickerPickFirstTwo(t *testing.T) { + tests := []struct { + name string + ring *ring + hash uint64 + wantSC balancer.SubConn + wantErr error + wantSCToConnect balancer.SubConn + }{ + { + name: "picked is Ready", + ring: newTestRing([]connectivity.State{connectivity.Ready, connectivity.Idle}), + hash: 5, + wantSC: testutils.TestSubConns[0], + }, + { + name: "picked is connecting, queue", + ring: newTestRing([]connectivity.State{connectivity.Connecting, connectivity.Idle}), + hash: 5, + wantErr: balancer.ErrNoSubConnAvailable, + }, + { + name: "picked is Idle, connect and queue", + ring: newTestRing([]connectivity.State{connectivity.Idle, connectivity.Idle}), + hash: 5, + wantErr: balancer.ErrNoSubConnAvailable, + wantSCToConnect: testutils.TestSubConns[0], + }, + { + name: "picked is TransientFailure, next is ready, return", + ring: newTestRing([]connectivity.State{connectivity.TransientFailure, connectivity.Ready}), + hash: 5, + wantSC: testutils.TestSubConns[1], + }, + { + name: "picked is TransientFailure, next is connecting, queue", + ring: newTestRing([]connectivity.State{connectivity.TransientFailure, connectivity.Connecting}), + hash: 5, + wantErr: balancer.ErrNoSubConnAvailable, + }, + { + name: "picked is TransientFailure, next is Idle, connect and queue", + ring: newTestRing([]connectivity.State{connectivity.TransientFailure, connectivity.Idle}), + hash: 5, + wantErr: 
balancer.ErrNoSubConnAvailable, + wantSCToConnect: testutils.TestSubConns[1], + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + p := &picker{ring: tt.ring} + got, err := p.Pick(balancer.PickInfo{ + Ctx: SetRequestHash(context.Background(), tt.hash), + }) + if err != tt.wantErr { + t.Errorf("Pick() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !cmp.Equal(got, balancer.PickResult{SubConn: tt.wantSC}, cmpOpts) { + t.Errorf("Pick() got = %v, want picked SubConn: %v", got, tt.wantSC) + } + if sc := tt.wantSCToConnect; sc != nil { + select { + case <-sc.(*testutils.TestSubConn).ConnectCh: + case <-time.After(defaultTestShortTimeout): + t.Errorf("timeout waiting for Connect() from SubConn %v", sc) + } + } + }) + } +} + +// TestPickerPickTriggerTFConnect covers that if the picked SubConn is +// TransientFailures, all SubConns until a non-TransientFailure are queued for +// Connect(). +func TestPickerPickTriggerTFConnect(t *testing.T) { + ring := newTestRing([]connectivity.State{ + connectivity.TransientFailure, connectivity.TransientFailure, connectivity.TransientFailure, connectivity.TransientFailure, + connectivity.Idle, connectivity.TransientFailure, connectivity.TransientFailure, connectivity.TransientFailure, + }) + p := &picker{ring: ring} + _, err := p.Pick(balancer.PickInfo{Ctx: SetRequestHash(context.Background(), 5)}) + if err == nil { + t.Fatalf("Pick() error = %v, want non-nil", err) + } + // The first 4 SubConns, all in TransientFailure, should be queued to + // connect. + for i := 0; i < 4; i++ { + it := ring.items[i] + if !it.sc.connectQueued { + t.Errorf("the %d-th SubConn is not queued for connect", i) + } + } + // The other SubConns, after the first Idle, should not be queued to + // connect. 
+ for i := 5; i < len(ring.items); i++ { + it := ring.items[i] + if it.sc.connectQueued { + t.Errorf("the %d-th SubConn is unexpected queued for connect", i) + } + } +} + +// TestPickerPickTriggerTFReturnReady covers that if the picked SubConn is +// TransientFailure, SubConn 2 and 3 are TransientFailure, 4 is Ready. SubConn 2 +// and 3 will Connect(), and 4 will be returned. +func TestPickerPickTriggerTFReturnReady(t *testing.T) { + ring := newTestRing([]connectivity.State{ + connectivity.TransientFailure, connectivity.TransientFailure, connectivity.TransientFailure, connectivity.Ready, + }) + p := &picker{ring: ring} + pr, err := p.Pick(balancer.PickInfo{Ctx: SetRequestHash(context.Background(), 5)}) + if err != nil { + t.Fatalf("Pick() error = %v, want nil", err) + } + if wantSC := testutils.TestSubConns[3]; pr.SubConn != wantSC { + t.Fatalf("Pick() = %v, want %v", pr.SubConn, wantSC) + } + // The first 3 SubConns, all in TransientFailure, should be queued to + // connect. + for i := 0; i < 3; i++ { + it := ring.items[i] + if !it.sc.connectQueued { + t.Errorf("the %d-th SubConn is not queued for connect", i) + } + } +} + +// TestPickerPickTriggerTFWithIdle covers that if the picked SubConn is +// TransientFailure, SubConn 2 is TransientFailure, 3 is Idle (init Idle). Pick +// will be queued, SubConn 3 will Connect(), SubConn 4 and 5 (in TransientFailure) +// will not queue a Connect. +func TestPickerPickTriggerTFWithIdle(t *testing.T) { + ring := newTestRing([]connectivity.State{ + connectivity.TransientFailure, connectivity.TransientFailure, connectivity.Idle, connectivity.TransientFailure, connectivity.TransientFailure, + }) + p := &picker{ring: ring} + _, err := p.Pick(balancer.PickInfo{Ctx: SetRequestHash(context.Background(), 5)}) + if err != balancer.ErrNoSubConnAvailable { + t.Fatalf("Pick() error = %v, want %v", err, balancer.ErrNoSubConnAvailable) + } + // The first 2 SubConns, all in TransientFailure, should be queued to + // connect. 
+ for i := 0; i < 2; i++ { + it := ring.items[i] + if !it.sc.connectQueued { + t.Errorf("the %d-th SubConn is not queued for connect", i) + } + } + // SubConn 3 was in Idle, so should Connect() + select { + case <-testutils.TestSubConns[2].ConnectCh: + case <-time.After(defaultTestShortTimeout): + t.Errorf("timeout waiting for Connect() from SubConn %v", testutils.TestSubConns[2]) + } + // The other SubConns, after the first Idle, should not be queued to + // connect. + for i := 3; i < len(ring.items); i++ { + it := ring.items[i] + if it.sc.connectQueued { + t.Errorf("the %d-th SubConn is unexpected queued for connect", i) + } + } +} + +func TestNextSkippingDuplicatesNoDup(t *testing.T) { + testRing := newTestRing([]connectivity.State{connectivity.Idle, connectivity.Idle}) + tests := []struct { + name string + ring *ring + cur *ringEntry + want *ringEntry + }{ + { + name: "no dup", + ring: testRing, + cur: testRing.items[0], + want: testRing.items[1], + }, + { + name: "only one entry", + ring: &ring{items: []*ringEntry{testRing.items[0]}}, + cur: testRing.items[0], + want: nil, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := nextSkippingDuplicates(tt.ring, tt.cur); !cmp.Equal(got, tt.want, cmpOpts) { + t.Errorf("nextSkippingDuplicates() = %v, want %v", got, tt.want) + } + }) + } +} + +// addDups adds duplicates of items[0] to the ring. +func addDups(r *ring, count int) *ring { + var ( + items []*ringEntry + idx int + ) + for i, it := range r.items { + itt := *it + itt.idx = idx + items = append(items, &itt) + idx++ + if i == 0 { + // Add duplicate of items[0] to the ring + for j := 0; j < count; j++ { + itt2 := *it + itt2.idx = idx + items = append(items, &itt2) + idx++ + } + } + } + return &ring{items: items} +} + +func TestNextSkippingDuplicatesMoreDup(t *testing.T) { + testRing := newTestRing([]connectivity.State{connectivity.Idle, connectivity.Idle}) + // Make a new ring with duplicate SubConns. 
+ dupTestRing := addDups(testRing, 3) + if got := nextSkippingDuplicates(dupTestRing, dupTestRing.items[0]); !cmp.Equal(got, dupTestRing.items[len(dupTestRing.items)-1], cmpOpts) { + t.Errorf("nextSkippingDuplicates() = %v, want %v", got, dupTestRing.items[len(dupTestRing.items)-1]) + } +} + +func TestNextSkippingDuplicatesOnlyDup(t *testing.T) { + testRing := newTestRing([]connectivity.State{connectivity.Idle}) + // Make a new ring with only duplicate SubConns. + dupTestRing := addDups(testRing, 3) + // This ring only has duplicates of items[0], should return nil. + if got := nextSkippingDuplicates(dupTestRing, dupTestRing.items[0]); got != nil { + t.Errorf("nextSkippingDuplicates() = %v, want nil", got) + } +} diff --git a/xds/internal/balancer/ringhash/ring.go b/xds/internal/balancer/ringhash/ring.go new file mode 100644 index 00000000000..68e844cfb48 --- /dev/null +++ b/xds/internal/balancer/ringhash/ring.go @@ -0,0 +1,163 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package ringhash + +import ( + "fmt" + "math" + "sort" + "strconv" + + xxhash "github.com/cespare/xxhash/v2" + "google.golang.org/grpc/resolver" +) + +type ring struct { + items []*ringEntry +} + +type subConnWithWeight struct { + sc *subConn + weight float64 +} + +type ringEntry struct { + idx int + hash uint64 + sc *subConn +} + +// newRing creates a ring from the subConns. The ring size is limited by the +// passed in max/min. 
+// +// ring entries will be created for each subConn, and subConn with high weight +// (specified by the address) may have multiple entries. +// +// For example, for subConns with weights {a:3, b:3, c:4}, a generated ring of +// size 10 could be: +// - {idx:0 hash:3689675255460411075 b} +// - {idx:1 hash:4262906501694543955 c} +// - {idx:2 hash:5712155492001633497 c} +// - {idx:3 hash:8050519350657643659 b} +// - {idx:4 hash:8723022065838381142 b} +// - {idx:5 hash:11532782514799973195 a} +// - {idx:6 hash:13157034721563383607 c} +// - {idx:7 hash:14468677667651225770 c} +// - {idx:8 hash:17336016884672388720 a} +// - {idx:9 hash:18151002094784932496 a} +// +// To pick from a ring, a binary search will be done for the given target hash, +// and first item with hash >= given hash will be returned. +func newRing(subConns map[resolver.Address]*subConn, minRingSize, maxRingSize uint64) (*ring, error) { + // https://github.com/envoyproxy/envoy/blob/765c970f06a4c962961a0e03a467e165b276d50f/source/common/upstream/ring_hash_lb.cc#L114 + normalizedWeights, minWeight, err := normalizeWeights(subConns) + if err != nil { + return nil, err + } + // Normalized weights for {3,3,4} is {0.3,0.3,0.4}. + + // Scale up the size of the ring such that the least-weighted host gets a + // whole number of hashes on the ring. + // + // Note that size is limited by the input max/min. + scale := math.Min(math.Ceil(minWeight*float64(minRingSize))/minWeight, float64(maxRingSize)) + ringSize := math.Ceil(scale) + items := make([]*ringEntry, 0, int(ringSize)) + + // For each entry, scale*weight nodes are generated in the ring. + // + // Not all of these are whole numbers. E.g. for weights {a:3,b:3,c:4}, if + // ring size is 7, scale is 6.66. The numbers of nodes will be + // {a,a,b,b,c,c,c}. + // + // A hash is generated for each item, and later the results will be sorted + // based on the hash. 
+ var ( + idx int + targetIdx float64 + ) + for _, scw := range normalizedWeights { + targetIdx += scale * scw.weight + for float64(idx) < targetIdx { + h := xxhash.Sum64String(scw.sc.addr + strconv.Itoa(len(items))) + items = append(items, &ringEntry{idx: idx, hash: h, sc: scw.sc}) + idx++ + } + } + + // Sort items based on hash, to prepare for binary search. + sort.Slice(items, func(i, j int) bool { return items[i].hash < items[j].hash }) + for i, ii := range items { + ii.idx = i + } + return &ring{items: items}, nil +} + +// normalizeWeights divides all the weights by the sum, so that the total weight +// is 1. +func normalizeWeights(subConns map[resolver.Address]*subConn) (_ []subConnWithWeight, min float64, _ error) { + if len(subConns) == 0 { + return nil, 0, fmt.Errorf("number of subconns is 0") + } + var weightSum uint32 + for a := range subConns { + // The address weight was moved from attributes to the Metadata field. + // This is necessary (all the attributes need to be stripped) for the + // balancer to detect identical {address+weight} combination. + weightSum += a.Metadata.(uint32) + } + if weightSum == 0 { + return nil, 0, fmt.Errorf("total weight of all subconns is 0") + } + weightSumF := float64(weightSum) + ret := make([]subConnWithWeight, 0, len(subConns)) + min = math.MaxFloat64 + for a, sc := range subConns { + nw := float64(a.Metadata.(uint32)) / weightSumF + ret = append(ret, subConnWithWeight{sc: sc, weight: nw}) + if nw < min { + min = nw + } + } + // Sort the addresses to return consistent results. + // + // Note: this might not be necessary, but this makes sure the ring is + // consistent as long as the addresses are the same, for example, in cases + // where an address is added and then removed, the RPCs will still pick the + // same old SubConn. + sort.Slice(ret, func(i, j int) bool { return ret[i].sc.addr < ret[j].sc.addr }) + return ret, min, nil +} + +// pick does a binary search. 
It returns the item with smallest index i that +// r.items[i].hash >= h. +func (r *ring) pick(h uint64) *ringEntry { + i := sort.Search(len(r.items), func(i int) bool { return r.items[i].hash >= h }) + if i == len(r.items) { + // If not found, and h is greater than the largest hash, return the + // first item. + i = 0 + } + return r.items[i] +} + +// next returns the next entry. +func (r *ring) next(e *ringEntry) *ringEntry { + return r.items[(e.idx+1)%len(r.items)] +} diff --git a/xds/internal/balancer/ringhash/ring_test.go b/xds/internal/balancer/ringhash/ring_test.go new file mode 100644 index 00000000000..2d664e05bb2 --- /dev/null +++ b/xds/internal/balancer/ringhash/ring_test.go @@ -0,0 +1,113 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +package ringhash + +import ( + "fmt" + "math" + "testing" + + xxhash "github.com/cespare/xxhash/v2" + "google.golang.org/grpc/resolver" +) + +func testAddr(addr string, weight uint32) resolver.Address { + return resolver.Address{Addr: addr, Metadata: weight} +} + +func TestRingNew(t *testing.T) { + testAddrs := []resolver.Address{ + testAddr("a", 3), + testAddr("b", 3), + testAddr("c", 4), + } + var totalWeight float64 = 10 + testSubConnMap := map[resolver.Address]*subConn{ + testAddr("a", 3): {addr: "a"}, + testAddr("b", 3): {addr: "b"}, + testAddr("c", 4): {addr: "c"}, + } + for _, min := range []uint64{3, 4, 6, 8} { + for _, max := range []uint64{20, 8} { + t.Run(fmt.Sprintf("size-min-%v-max-%v", min, max), func(t *testing.T) { + r, _ := newRing(testSubConnMap, min, max) + totalCount := len(r.items) + if totalCount < int(min) || totalCount > int(max) { + t.Fatalf("unexpect size %v, want min %v, max %v", totalCount, min, max) + } + for _, a := range testAddrs { + var count int + for _, ii := range r.items { + if ii.sc.addr == a.Addr { + count++ + } + } + got := float64(count) / float64(totalCount) + want := float64(a.Metadata.(uint32)) / totalWeight + if !equalApproximately(got, want) { + t.Fatalf("unexpected item weight in ring: %v != %v", got, want) + } + } + }) + } + } +} + +func equalApproximately(x, y float64) bool { + delta := math.Abs(x - y) + mean := math.Abs(x+y) / 2.0 + return delta/mean < 0.25 +} + +func TestRingPick(t *testing.T) { + r, _ := newRing(map[resolver.Address]*subConn{ + {Addr: "a", Metadata: uint32(3)}: {addr: "a"}, + {Addr: "b", Metadata: uint32(3)}: {addr: "b"}, + {Addr: "c", Metadata: uint32(4)}: {addr: "c"}, + }, 10, 20) + for _, h := range []uint64{xxhash.Sum64String("1"), xxhash.Sum64String("2"), xxhash.Sum64String("3"), xxhash.Sum64String("4")} { + t.Run(fmt.Sprintf("picking-hash-%v", h), func(t *testing.T) { + e := r.pick(h) + var low uint64 + if e.idx > 0 { + low = r.items[e.idx-1].hash + } + high := e.hash + // h 
should be in [low, high). + if h < low || h >= high { + t.Fatalf("unexpected item picked, hash: %v, low: %v, high: %v", h, low, high) + } + }) + } +} + +func TestRingNext(t *testing.T) { + r, _ := newRing(map[resolver.Address]*subConn{ + {Addr: "a", Metadata: uint32(3)}: {addr: "a"}, + {Addr: "b", Metadata: uint32(3)}: {addr: "b"}, + {Addr: "c", Metadata: uint32(4)}: {addr: "c"}, + }, 10, 20) + + for _, e := range r.items { + ne := r.next(e) + if ne.idx != (e.idx+1)%len(r.items) { + t.Fatalf("next(%+v) returned unexpected %+v", e, ne) + } + } +} diff --git a/xds/internal/balancer/ringhash/ringhash.go b/xds/internal/balancer/ringhash/ringhash.go new file mode 100644 index 00000000000..f8a47f165bd --- /dev/null +++ b/xds/internal/balancer/ringhash/ringhash.go @@ -0,0 +1,434 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// Package ringhash implements the ringhash balancer. +package ringhash + +import ( + "encoding/json" + "errors" + "fmt" + "sync" + + "google.golang.org/grpc/balancer" + "google.golang.org/grpc/balancer/base" + "google.golang.org/grpc/balancer/weightedroundrobin" + "google.golang.org/grpc/connectivity" + "google.golang.org/grpc/internal/grpclog" + "google.golang.org/grpc/internal/pretty" + "google.golang.org/grpc/resolver" + "google.golang.org/grpc/serviceconfig" +) + +// Name is the name of the ring_hash balancer. 
+const Name = "ring_hash_experimental" + +func init() { + balancer.Register(bb{}) +} + +type bb struct{} + +func (bb) Build(cc balancer.ClientConn, bOpts balancer.BuildOptions) balancer.Balancer { + b := &ringhashBalancer{ + cc: cc, + subConns: make(map[resolver.Address]*subConn), + scStates: make(map[balancer.SubConn]*subConn), + csEvltr: &connectivityStateEvaluator{}, + } + b.logger = prefixLogger(b) + b.logger.Infof("Created") + return b +} + +func (bb) Name() string { + return Name +} + +func (bb) ParseConfig(c json.RawMessage) (serviceconfig.LoadBalancingConfig, error) { + return parseConfig(c) +} + +type subConn struct { + addr string + sc balancer.SubConn + + mu sync.RWMutex + // This is the actual state of this SubConn (as updated by the ClientConn). + // The effective state can be different, see comment of attemptedToConnect. + state connectivity.State + // failing is whether this SubConn is in a failing state. A subConn is + // considered to be in a failing state if it was previously in + // TransientFailure. + // + // This affects the effective connectivity state of this SubConn, e.g. + // - if the actual state is Idle or Connecting, but this SubConn is failing, + // the effective state is TransientFailure. + // + // This is used in pick(). E.g. if a subConn is Idle, but has failing as + // true, pick() will + // - consider this SubConn as TransientFailure, and check the state of the + // next SubConn. + // - trigger Connect() (note that normally a SubConn in real + // TransientFailure cannot Connect()) + // + // A subConn starts in non-failing (failing is false). A transition to + // TransientFailure sets failing to true (and it stays true). A transition + // to Ready sets failing to false. + failing bool + // connectQueued is true if a Connect() was queued for this SubConn while + // it's not in Idle (most likely was in TransientFailure). A Connect() will + // be triggered on this SubConn when it turns Idle. 
+ // + // When connectivity state is updated to Idle for this SubConn, if + // connectQueued is true, Connect() will be called on the SubConn. + connectQueued bool +} + +// setState updates the state of this SubConn. +// +// It also handles the queued Connect(). If the new state is Idle, and a +// Connect() was queued, this SubConn will be triggered to Connect(). +func (sc *subConn) setState(s connectivity.State) { + sc.mu.Lock() + defer sc.mu.Unlock() + switch s { + case connectivity.Idle: + // Trigger Connect() if new state is Idle, and there is a queued connect. + if sc.connectQueued { + sc.connectQueued = false + sc.sc.Connect() + } + case connectivity.Connecting: + // Clear connectQueued if the SubConn isn't failing. This state + // transition is unlikely to happen, but handle this just in case. + sc.connectQueued = false + case connectivity.Ready: + // Clear connectQueued if the SubConn isn't failing. This state + // transition is unlikely to happen, but handle this just in case. + sc.connectQueued = false + // Set to a non-failing state. + sc.failing = false + case connectivity.TransientFailure: + // Set to a failing state. + sc.failing = true + } + sc.state = s +} + +// effectiveState returns the effective state of this SubConn. It can be +// different from the actual state, e.g. Idle while the subConn is failing is +// considered TransientFailure. Read comment of field failing for other cases. +func (sc *subConn) effectiveState() connectivity.State { + sc.mu.RLock() + defer sc.mu.RUnlock() + if sc.failing && (sc.state == connectivity.Idle || sc.state == connectivity.Connecting) { + return connectivity.TransientFailure + } + return sc.state +} + +// queueConnect sets a boolean so that when the SubConn state changes to Idle, +// it's Connect() will be triggered. If the SubConn state is already Idle, it +// will just call Connect(). 
+func (sc *subConn) queueConnect() { + sc.mu.Lock() + defer sc.mu.Unlock() + if sc.state == connectivity.Idle { + sc.sc.Connect() + return + } + // Queue this connect, and when this SubConn switches back to Idle (happens + // after backoff in TransientFailure), it will Connect(). + sc.connectQueued = true +} + +type ringhashBalancer struct { + cc balancer.ClientConn + logger *grpclog.PrefixLogger + + config *LBConfig + + subConns map[resolver.Address]*subConn // `attributes` is stripped from the keys of this map (the addresses) + scStates map[balancer.SubConn]*subConn + + // ring is always in sync with subConns. When subConns change, a new ring is + // generated. Note that address weights updates (they are keys in the + // subConns map) also regenerates the ring. + ring *ring + picker balancer.Picker + csEvltr *connectivityStateEvaluator + state connectivity.State + + resolverErr error // the last error reported by the resolver; cleared on successful resolution + connErr error // the last connection error; cleared upon leaving TransientFailure +} + +// updateAddresses creates new SubConns and removes SubConns, based on the +// address update. +// +// The return value is whether the new address list is different from the +// previous. True if +// - an address was added +// - an address was removed +// - an address's weight was updated +// +// Note that this function doesn't trigger SubConn connecting, so all the new +// SubConn states are Idle. +func (b *ringhashBalancer) updateAddresses(addrs []resolver.Address) bool { + var addrsUpdated bool + // addrsSet is the set converted from addrs, it's used for quick lookup of + // an address. + // + // Addresses in this map all have attributes stripped, but metadata set to + // the weight. So that weight change can be detected. + // + // TODO: this won't be necessary if there are ways to compare address + // attributes. 
+ addrsSet := make(map[resolver.Address]struct{}) + for _, a := range addrs { + aNoAttrs := a + // Strip attributes but set Metadata to the weight. + aNoAttrs.Attributes = nil + w := weightedroundrobin.GetAddrInfo(a).Weight + if w == 0 { + // If weight is not set, use 1. + w = 1 + } + aNoAttrs.Metadata = w + addrsSet[aNoAttrs] = struct{}{} + if scInfo, ok := b.subConns[aNoAttrs]; !ok { + // When creating SubConn, the original address with attributes is + // passed through. So that connection configurations in attributes + // (like creds) will be used. + sc, err := b.cc.NewSubConn([]resolver.Address{a}, balancer.NewSubConnOptions{HealthCheckEnabled: true}) + if err != nil { + logger.Warningf("base.baseBalancer: failed to create new SubConn: %v", err) + continue + } + scs := &subConn{addr: a.Addr, sc: sc} + scs.setState(connectivity.Idle) + b.state = b.csEvltr.recordTransition(connectivity.Shutdown, connectivity.Idle) + b.subConns[aNoAttrs] = scs + b.scStates[sc] = scs + addrsUpdated = true + } else { + // Always update the subconn's address in case the attributes + // changed. The SubConn does a reflect.DeepEqual of the new and old + // addresses. So this is a noop if the current address is the same + // as the old one (including attributes). + b.subConns[aNoAttrs] = scInfo + b.cc.UpdateAddresses(scInfo.sc, []resolver.Address{a}) + } + } + for a, scInfo := range b.subConns { + // a was removed by resolver. + if _, ok := addrsSet[a]; !ok { + b.cc.RemoveSubConn(scInfo.sc) + delete(b.subConns, a) + addrsUpdated = true + // Keep the state of this sc in b.scStates until sc's state becomes Shutdown. + // The entry will be deleted in UpdateSubConnState. 
+ } + } + return addrsUpdated +} + +func (b *ringhashBalancer) UpdateClientConnState(s balancer.ClientConnState) error { + b.logger.Infof("Received update from resolver, balancer config: %+v", pretty.ToJSON(s.BalancerConfig)) + if b.config == nil { + newConfig, ok := s.BalancerConfig.(*LBConfig) + if !ok { + return fmt.Errorf("unexpected balancer config with type: %T", s.BalancerConfig) + } + b.config = newConfig + } + + // Successful resolution; clear resolver error and ensure we return nil. + b.resolverErr = nil + if b.updateAddresses(s.ResolverState.Addresses) { + // If addresses were updated, no matter whether it resulted in SubConn + // creation/deletion, or just weight update, we will need to regenerate + // the ring. + var err error + b.ring, err = newRing(b.subConns, b.config.MinRingSize, b.config.MaxRingSize) + if err != nil { + panic(err) + } + b.regeneratePicker() + b.cc.UpdateState(balancer.State{ConnectivityState: b.state, Picker: b.picker}) + } + + // If resolver state contains no addresses, return an error so ClientConn + // will trigger re-resolve. Also records this as an resolver error, so when + // the overall state turns transient failure, the error message will have + // the zero address information. + if len(s.ResolverState.Addresses) == 0 { + b.ResolverError(errors.New("produced zero addresses")) + return balancer.ErrBadResolverState + } + return nil +} + +func (b *ringhashBalancer) ResolverError(err error) { + b.resolverErr = err + if len(b.subConns) == 0 { + b.state = connectivity.TransientFailure + } + + if b.state != connectivity.TransientFailure { + // The picker will not change since the balancer does not currently + // report an error. + return + } + b.regeneratePicker() + b.cc.UpdateState(balancer.State{ + ConnectivityState: b.state, + Picker: b.picker, + }) +} + +// UpdateSubConnState updates the per-SubConn state stored in the ring, and also +// the aggregated state. 
+// +// It triggers an update to cc when: +// - the new state is TransientFailure, to update the error message +// - it's possible that this is a noop, but sending an extra update is easier +// than comparing errors +// - the aggregated state is changed +// - the same picker will be sent again, but this update may trigger a re-pick +// for some RPCs. +func (b *ringhashBalancer) UpdateSubConnState(sc balancer.SubConn, state balancer.SubConnState) { + s := state.ConnectivityState + b.logger.Infof("handle SubConn state change: %p, %v", sc, s) + scs, ok := b.scStates[sc] + if !ok { + b.logger.Infof("got state changes for an unknown SubConn: %p, %v", sc, s) + return + } + oldSCState := scs.effectiveState() + scs.setState(s) + newSCState := scs.effectiveState() + + var sendUpdate bool + oldBalancerState := b.state + b.state = b.csEvltr.recordTransition(oldSCState, newSCState) + if oldBalancerState != b.state { + sendUpdate = true + } + + switch s { + case connectivity.Idle: + // When the overall state is TransientFailure, this will never get picks + // if there's a lower priority. Need to keep the SubConns connecting so + // there's a chance it will recover. + if b.state == connectivity.TransientFailure { + scs.queueConnect() + } + // No need to send an update. No queued RPC can be unblocked. If the + // overall state changed because of this, sendUpdate is already true. + case connectivity.Connecting: + // No need to send an update. No queued RPC can be unblocked. If the + // overall state changed because of this, sendUpdate is already true. + case connectivity.Ready: + // Resend the picker, there's no need to regenerate the picker because + // the ring didn't change. + sendUpdate = true + case connectivity.TransientFailure: + // Save error to be reported via picker. + b.connErr = state.ConnectionError + // Regenerate picker to update error message. 
+ b.regeneratePicker() + sendUpdate = true + case connectivity.Shutdown: + // When an address was removed by resolver, b called RemoveSubConn but + // kept the sc's state in scStates. Remove state for this sc here. + delete(b.scStates, sc) + } + + if sendUpdate { + b.cc.UpdateState(balancer.State{ConnectivityState: b.state, Picker: b.picker}) + } +} + +// mergeErrors builds an error from the last connection error and the last +// resolver error. Must only be called if b.state is TransientFailure. +func (b *ringhashBalancer) mergeErrors() error { + // connErr must always be non-nil unless there are no SubConns, in which + // case resolverErr must be non-nil. + if b.connErr == nil { + return fmt.Errorf("last resolver error: %v", b.resolverErr) + } + if b.resolverErr == nil { + return fmt.Errorf("last connection error: %v", b.connErr) + } + return fmt.Errorf("last connection error: %v; last resolver error: %v", b.connErr, b.resolverErr) +} + +func (b *ringhashBalancer) regeneratePicker() { + if b.state == connectivity.TransientFailure { + b.picker = base.NewErrPicker(b.mergeErrors()) + return + } + b.picker = newPicker(b.ring, b.logger) +} + +func (b *ringhashBalancer) Close() {} + +// connectivityStateEvaluator takes the connectivity states of multiple SubConns +// and returns one aggregated connectivity state. +// +// It's not thread safe. +type connectivityStateEvaluator struct { + nums [5]uint64 +} + +// recordTransition records state change happening in subConn and based on that +// it evaluates what aggregated state should be. +// +// - If there is at least one subchannel in READY state, report READY. +// - If there are 2 or more subchannels in TRANSIENT_FAILURE state, report TRANSIENT_FAILURE. +// - If there is at least one subchannel in CONNECTING state, report CONNECTING. +// - If there is at least one subchannel in Idle state, report Idle. +// - Otherwise, report TRANSIENT_FAILURE. 
+// +// Note that if there are 1 connecting, 2 transient failure, the overall state +// is transient failure. This is because the second transient failure is a +// fallback of the first failing SubConn, and we want to report transient +// failure to failover to the lower priority. +func (cse *connectivityStateEvaluator) recordTransition(oldState, newState connectivity.State) connectivity.State { + // Update counters. + for idx, state := range []connectivity.State{oldState, newState} { + updateVal := 2*uint64(idx) - 1 // -1 for oldState and +1 for new. + cse.nums[state] += updateVal + } + + if cse.nums[connectivity.Ready] > 0 { + return connectivity.Ready + } + if cse.nums[connectivity.TransientFailure] > 1 { + return connectivity.TransientFailure + } + if cse.nums[connectivity.Connecting] > 0 { + return connectivity.Connecting + } + if cse.nums[connectivity.Idle] > 0 { + return connectivity.Idle + } + return connectivity.TransientFailure +} diff --git a/xds/internal/balancer/ringhash/ringhash_test.go b/xds/internal/balancer/ringhash/ringhash_test.go new file mode 100644 index 00000000000..fb85367e4a4 --- /dev/null +++ b/xds/internal/balancer/ringhash/ringhash_test.go @@ -0,0 +1,458 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +package ringhash + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "google.golang.org/grpc/attributes" + "google.golang.org/grpc/balancer" + "google.golang.org/grpc/balancer/weightedroundrobin" + "google.golang.org/grpc/connectivity" + "google.golang.org/grpc/resolver" + "google.golang.org/grpc/xds/internal/testutils" +) + +var ( + cmpOpts = cmp.Options{ + cmp.AllowUnexported(testutils.TestSubConn{}, ringEntry{}, subConn{}), + cmpopts.IgnoreFields(subConn{}, "mu"), + } +) + +const ( + defaultTestTimeout = 10 * time.Second + defaultTestShortTimeout = 10 * time.Millisecond + + testBackendAddrsCount = 12 +) + +var ( + testBackendAddrStrs []string + testConfig = &LBConfig{MinRingSize: 1, MaxRingSize: 10} +) + +func init() { + for i := 0; i < testBackendAddrsCount; i++ { + testBackendAddrStrs = append(testBackendAddrStrs, fmt.Sprintf("%d.%d.%d.%d:%d", i, i, i, i, i)) + } +} + +func ctxWithHash(h uint64) context.Context { + return SetRequestHash(context.Background(), h) +} + +// setupTest creates the balancer, and does an initial sanity check. 
+func setupTest(t *testing.T, addrs []resolver.Address) (*testutils.TestClientConn, balancer.Balancer, balancer.Picker) { + t.Helper() + cc := testutils.NewTestClientConn(t) + builder := balancer.Get(Name) + b := builder.Build(cc, balancer.BuildOptions{}) + if b == nil { + t.Fatalf("builder.Build(%s) failed and returned nil", Name) + } + if err := b.UpdateClientConnState(balancer.ClientConnState{ + ResolverState: resolver.State{Addresses: addrs}, + BalancerConfig: testConfig, + }); err != nil { + t.Fatalf("UpdateClientConnState returned err: %v", err) + } + + for _, addr := range addrs { + addr1 := <-cc.NewSubConnAddrsCh + if want := []resolver.Address{addr}; !cmp.Equal(addr1, want, cmp.AllowUnexported(attributes.Attributes{})) { + t.Fatalf("got unexpected new subconn addrs: %v", cmp.Diff(addr1, want, cmp.AllowUnexported(attributes.Attributes{}))) + } + sc1 := <-cc.NewSubConnCh + // All the SubConns start in Idle, and should not Connect(). + select { + case <-sc1.(*testutils.TestSubConn).ConnectCh: + t.Errorf("unexpected Connect() from SubConn %v", sc1) + case <-time.After(defaultTestShortTimeout): + } + } + + // Should also have a picker, with all SubConns in Idle. + p1 := <-cc.NewPickerCh + return cc, b, p1 +} + +func TestOneSubConn(t *testing.T) { + wantAddr1 := resolver.Address{Addr: testBackendAddrStrs[0]} + cc, b, p0 := setupTest(t, []resolver.Address{wantAddr1}) + ring0 := p0.(*picker).ring + + firstHash := ring0.items[0].hash + // firstHash-1 will pick the first (and only) SubConn from the ring. + testHash := firstHash - 1 + // The first pick should be queued, and should trigger Connect() on the only + // SubConn. 
+ if _, err := p0.Pick(balancer.PickInfo{Ctx: ctxWithHash(testHash)}); err != balancer.ErrNoSubConnAvailable { + t.Fatalf("first pick returned err %v, want %v", err, balancer.ErrNoSubConnAvailable) + } + sc0 := ring0.items[0].sc.sc + select { + case <-sc0.(*testutils.TestSubConn).ConnectCh: + case <-time.After(defaultTestTimeout): + t.Errorf("timeout waiting for Connect() from SubConn %v", sc0) + } + + // Send state updates to Ready. + b.UpdateSubConnState(sc0, balancer.SubConnState{ConnectivityState: connectivity.Connecting}) + b.UpdateSubConnState(sc0, balancer.SubConnState{ConnectivityState: connectivity.Ready}) + + // Test pick with one backend. + p1 := <-cc.NewPickerCh + for i := 0; i < 5; i++ { + gotSCSt, _ := p1.Pick(balancer.PickInfo{Ctx: ctxWithHash(testHash)}) + if !cmp.Equal(gotSCSt.SubConn, sc0, cmp.AllowUnexported(testutils.TestSubConn{})) { + t.Fatalf("picker.Pick, got %v, want SubConn=%v", gotSCSt, sc0) + } + } +} + +// TestThreeBackendsAffinity covers that there are 3 SubConns, RPCs with the +// same hash always pick the same SubConn. When the one picked is down, another +// one will be picked. +func TestThreeSubConnsAffinity(t *testing.T) { + wantAddrs := []resolver.Address{ + {Addr: testBackendAddrStrs[0]}, + {Addr: testBackendAddrStrs[1]}, + {Addr: testBackendAddrStrs[2]}, + } + cc, b, p0 := setupTest(t, wantAddrs) + // This test doesn't update addresses, so this ring will be used by all the + // pickers. + ring0 := p0.(*picker).ring + + firstHash := ring0.items[0].hash + // firstHash+1 will pick the second SubConn from the ring. + testHash := firstHash + 1 + // The first pick should be queued, and should trigger Connect() on the only + // SubConn. + if _, err := p0.Pick(balancer.PickInfo{Ctx: ctxWithHash(testHash)}); err != balancer.ErrNoSubConnAvailable { + t.Fatalf("first pick returned err %v, want %v", err, balancer.ErrNoSubConnAvailable) + } + // The picked SubConn should be the second in the ring. 
+ sc0 := ring0.items[1].sc.sc + select { + case <-sc0.(*testutils.TestSubConn).ConnectCh: + case <-time.After(defaultTestTimeout): + t.Errorf("timeout waiting for Connect() from SubConn %v", sc0) + } + + // Send state updates to Ready. + b.UpdateSubConnState(sc0, balancer.SubConnState{ConnectivityState: connectivity.Connecting}) + b.UpdateSubConnState(sc0, balancer.SubConnState{ConnectivityState: connectivity.Ready}) + p1 := <-cc.NewPickerCh + for i := 0; i < 5; i++ { + gotSCSt, _ := p1.Pick(balancer.PickInfo{Ctx: ctxWithHash(testHash)}) + if !cmp.Equal(gotSCSt.SubConn, sc0, cmp.AllowUnexported(testutils.TestSubConn{})) { + t.Fatalf("picker.Pick, got %v, want SubConn=%v", gotSCSt, sc0) + } + } + + // Turn down the subConn in use. + b.UpdateSubConnState(sc0, balancer.SubConnState{ConnectivityState: connectivity.TransientFailure}) + p2 := <-cc.NewPickerCh + // Pick with the same hash should be queued, because the SubConn after the + // first picked is Idle. + if _, err := p2.Pick(balancer.PickInfo{Ctx: ctxWithHash(testHash)}); err != balancer.ErrNoSubConnAvailable { + t.Fatalf("first pick returned err %v, want %v", err, balancer.ErrNoSubConnAvailable) + } + + // The third SubConn in the ring should connect. + sc1 := ring0.items[2].sc.sc + select { + case <-sc1.(*testutils.TestSubConn).ConnectCh: + case <-time.After(defaultTestTimeout): + t.Errorf("timeout waiting for Connect() from SubConn %v", sc1) + } + + // Send state updates to Ready. + b.UpdateSubConnState(sc1, balancer.SubConnState{ConnectivityState: connectivity.Connecting}) + b.UpdateSubConnState(sc1, balancer.SubConnState{ConnectivityState: connectivity.Ready}) + // New picks should all return this SubConn. 
+ p3 := <-cc.NewPickerCh + for i := 0; i < 5; i++ { + gotSCSt, _ := p3.Pick(balancer.PickInfo{Ctx: ctxWithHash(testHash)}) + if !cmp.Equal(gotSCSt.SubConn, sc1, cmp.AllowUnexported(testutils.TestSubConn{})) { + t.Fatalf("picker.Pick, got %v, want SubConn=%v", gotSCSt, sc1) + } + } + + // Now, after backoff, the first picked SubConn will turn Idle. + b.UpdateSubConnState(sc0, balancer.SubConnState{ConnectivityState: connectivity.Idle}) + // The picks above should have queued Connect() for the first picked + // SubConn, so this Idle state change will trigger a Connect(). + select { + case <-sc0.(*testutils.TestSubConn).ConnectCh: + case <-time.After(defaultTestTimeout): + t.Errorf("timeout waiting for Connect() from SubConn %v", sc0) + } + + // After the first picked SubConn turn Ready, new picks should return it + // again (even though the second picked SubConn is also Ready). + b.UpdateSubConnState(sc0, balancer.SubConnState{ConnectivityState: connectivity.Connecting}) + b.UpdateSubConnState(sc0, balancer.SubConnState{ConnectivityState: connectivity.Ready}) + p4 := <-cc.NewPickerCh + for i := 0; i < 5; i++ { + gotSCSt, _ := p4.Pick(balancer.PickInfo{Ctx: ctxWithHash(testHash)}) + if !cmp.Equal(gotSCSt.SubConn, sc0, cmp.AllowUnexported(testutils.TestSubConn{})) { + t.Fatalf("picker.Pick, got %v, want SubConn=%v", gotSCSt, sc0) + } + } +} + +// TestThreeBackendsAffinity covers that there are 3 SubConns, RPCs with the +// same hash always pick the same SubConn. Then try different hash to pick +// another backend, and verify the first hash still picks the first backend. +func TestThreeSubConnsAffinityMultiple(t *testing.T) { + wantAddrs := []resolver.Address{ + {Addr: testBackendAddrStrs[0]}, + {Addr: testBackendAddrStrs[1]}, + {Addr: testBackendAddrStrs[2]}, + } + cc, b, p0 := setupTest(t, wantAddrs) + // This test doesn't update addresses, so this ring will be used by all the + // pickers. 
+ ring0 := p0.(*picker).ring + + firstHash := ring0.items[0].hash + // firstHash+1 will pick the second SubConn from the ring. + testHash := firstHash + 1 + // The first pick should be queued, and should trigger Connect() on the only + // SubConn. + if _, err := p0.Pick(balancer.PickInfo{Ctx: ctxWithHash(testHash)}); err != balancer.ErrNoSubConnAvailable { + t.Fatalf("first pick returned err %v, want %v", err, balancer.ErrNoSubConnAvailable) + } + sc0 := ring0.items[1].sc.sc + select { + case <-sc0.(*testutils.TestSubConn).ConnectCh: + case <-time.After(defaultTestTimeout): + t.Errorf("timeout waiting for Connect() from SubConn %v", sc0) + } + + // Send state updates to Ready. + b.UpdateSubConnState(sc0, balancer.SubConnState{ConnectivityState: connectivity.Connecting}) + b.UpdateSubConnState(sc0, balancer.SubConnState{ConnectivityState: connectivity.Ready}) + + // First hash should always pick sc0. + p1 := <-cc.NewPickerCh + for i := 0; i < 5; i++ { + gotSCSt, _ := p1.Pick(balancer.PickInfo{Ctx: ctxWithHash(testHash)}) + if !cmp.Equal(gotSCSt.SubConn, sc0, cmp.AllowUnexported(testutils.TestSubConn{})) { + t.Fatalf("picker.Pick, got %v, want SubConn=%v", gotSCSt, sc0) + } + } + + secondHash := ring0.items[1].hash + // secondHash+1 will pick the third SubConn from the ring. + testHash2 := secondHash + 1 + if _, err := p0.Pick(balancer.PickInfo{Ctx: ctxWithHash(testHash2)}); err != balancer.ErrNoSubConnAvailable { + t.Fatalf("first pick returned err %v, want %v", err, balancer.ErrNoSubConnAvailable) + } + sc1 := ring0.items[2].sc.sc + select { + case <-sc1.(*testutils.TestSubConn).ConnectCh: + case <-time.After(defaultTestTimeout): + t.Errorf("timeout waiting for Connect() from SubConn %v", sc1) + } + b.UpdateSubConnState(sc1, balancer.SubConnState{ConnectivityState: connectivity.Connecting}) + b.UpdateSubConnState(sc1, balancer.SubConnState{ConnectivityState: connectivity.Ready}) + + // With the new generated picker, hash2 always picks sc1. 
+ p2 := <-cc.NewPickerCh + for i := 0; i < 5; i++ { + gotSCSt, _ := p2.Pick(balancer.PickInfo{Ctx: ctxWithHash(testHash2)}) + if !cmp.Equal(gotSCSt.SubConn, sc1, cmp.AllowUnexported(testutils.TestSubConn{})) { + t.Fatalf("picker.Pick, got %v, want SubConn=%v", gotSCSt, sc1) + } + } + // But the first hash still picks sc0. + for i := 0; i < 5; i++ { + gotSCSt, _ := p2.Pick(balancer.PickInfo{Ctx: ctxWithHash(testHash)}) + if !cmp.Equal(gotSCSt.SubConn, sc0, cmp.AllowUnexported(testutils.TestSubConn{})) { + t.Fatalf("picker.Pick, got %v, want SubConn=%v", gotSCSt, sc0) + } + } +} + +func TestAddrWeightChange(t *testing.T) { + wantAddrs := []resolver.Address{ + {Addr: testBackendAddrStrs[0]}, + {Addr: testBackendAddrStrs[1]}, + {Addr: testBackendAddrStrs[2]}, + } + cc, b, p0 := setupTest(t, wantAddrs) + ring0 := p0.(*picker).ring + + if err := b.UpdateClientConnState(balancer.ClientConnState{ + ResolverState: resolver.State{Addresses: wantAddrs}, + BalancerConfig: nil, + }); err != nil { + t.Fatalf("UpdateClientConnState returned err: %v", err) + } + select { + case <-cc.NewPickerCh: + t.Fatalf("unexpected picker after UpdateClientConn with the same addresses") + case <-time.After(defaultTestShortTimeout): + } + + // Delete an address, should send a new Picker. + if err := b.UpdateClientConnState(balancer.ClientConnState{ + ResolverState: resolver.State{Addresses: []resolver.Address{ + {Addr: testBackendAddrStrs[0]}, + {Addr: testBackendAddrStrs[1]}, + }}, + BalancerConfig: nil, + }); err != nil { + t.Fatalf("UpdateClientConnState returned err: %v", err) + } + var p1 balancer.Picker + select { + case p1 = <-cc.NewPickerCh: + case <-time.After(defaultTestTimeout): + t.Fatalf("timeout waiting for picker after UpdateClientConn with different addresses") + } + ring1 := p1.(*picker).ring + if ring1 == ring0 { + t.Fatalf("new picker after removing address has the same ring as before, want different") + } + + // Another update with the same addresses, but different weight. 
+ if err := b.UpdateClientConnState(balancer.ClientConnState{ + ResolverState: resolver.State{Addresses: []resolver.Address{ + {Addr: testBackendAddrStrs[0]}, + weightedroundrobin.SetAddrInfo( + resolver.Address{Addr: testBackendAddrStrs[1]}, + weightedroundrobin.AddrInfo{Weight: 2}), + }}, + BalancerConfig: nil, + }); err != nil { + t.Fatalf("UpdateClientConnState returned err: %v", err) + } + var p2 balancer.Picker + select { + case p2 = <-cc.NewPickerCh: + case <-time.After(defaultTestTimeout): + t.Fatalf("timeout waiting for picker after UpdateClientConn with different addresses") + } + if p2.(*picker).ring == ring1 { + t.Fatalf("new picker after changing address weight has the same ring as before, want different") + } +} + +// TestSubConnToConnectWhenOverallTransientFailure covers the situation when the +// overall state is TransientFailure, the SubConns turning Idle will be +// triggered to Connect(). But not when the overall state is not +// TransientFailure. +func TestSubConnToConnectWhenOverallTransientFailure(t *testing.T) { + wantAddrs := []resolver.Address{ + {Addr: testBackendAddrStrs[0]}, + {Addr: testBackendAddrStrs[1]}, + {Addr: testBackendAddrStrs[2]}, + } + _, b, p0 := setupTest(t, wantAddrs) + ring0 := p0.(*picker).ring + + // Turn all SubConns to TransientFailure. + for _, it := range ring0.items { + b.UpdateSubConnState(it.sc.sc, balancer.SubConnState{ConnectivityState: connectivity.TransientFailure}) + } + + // The next one turning Idle should Connect(). + sc0 := ring0.items[0].sc.sc + b.UpdateSubConnState(sc0, balancer.SubConnState{ConnectivityState: connectivity.Idle}) + select { + case <-sc0.(*testutils.TestSubConn).ConnectCh: + case <-time.After(defaultTestTimeout): + t.Errorf("timeout waiting for Connect() from SubConn %v", sc0) + } + + // If this SubConn is ready. Other SubConns turning Idle will not Connect(). 
+ b.UpdateSubConnState(sc0, balancer.SubConnState{ConnectivityState: connectivity.Connecting}) + b.UpdateSubConnState(sc0, balancer.SubConnState{ConnectivityState: connectivity.Ready}) + + // The third SubConn in the ring should connect. + sc1 := ring0.items[1].sc.sc + b.UpdateSubConnState(sc1, balancer.SubConnState{ConnectivityState: connectivity.Idle}) + select { + case <-sc1.(*testutils.TestSubConn).ConnectCh: + t.Errorf("unexpected Connect() from SubConn %v", sc1) + case <-time.After(defaultTestShortTimeout): + } +} + +func TestConnectivityStateEvaluatorRecordTransition(t *testing.T) { + tests := []struct { + name string + from, to []connectivity.State + want connectivity.State + }{ + { + name: "one ready", + from: []connectivity.State{connectivity.Idle}, + to: []connectivity.State{connectivity.Ready}, + want: connectivity.Ready, + }, + { + name: "one connecting", + from: []connectivity.State{connectivity.Idle}, + to: []connectivity.State{connectivity.Connecting}, + want: connectivity.Connecting, + }, + { + name: "one ready one transient failure", + from: []connectivity.State{connectivity.Idle, connectivity.Idle}, + to: []connectivity.State{connectivity.Ready, connectivity.TransientFailure}, + want: connectivity.Ready, + }, + { + name: "one connecting one transient failure", + from: []connectivity.State{connectivity.Idle, connectivity.Idle}, + to: []connectivity.State{connectivity.Connecting, connectivity.TransientFailure}, + want: connectivity.Connecting, + }, + { + name: "one connecting two transient failure", + from: []connectivity.State{connectivity.Idle, connectivity.Idle, connectivity.Idle}, + to: []connectivity.State{connectivity.Connecting, connectivity.TransientFailure, connectivity.TransientFailure}, + want: connectivity.TransientFailure, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cse := &connectivityStateEvaluator{} + var got connectivity.State + for i, fff := range tt.from { + ttt := tt.to[i] + got = 
cse.recordTransition(fff, ttt) + } + if got != tt.want { + t.Errorf("recordTransition() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/xds/internal/balancer/ringhash/util.go b/xds/internal/balancer/ringhash/util.go new file mode 100644 index 00000000000..92bb3ae5b79 --- /dev/null +++ b/xds/internal/balancer/ringhash/util.go @@ -0,0 +1,40 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package ringhash + +import "context" + +type clusterKey struct{} + +func getRequestHash(ctx context.Context) uint64 { + requestHash, _ := ctx.Value(clusterKey{}).(uint64) + return requestHash +} + +// GetRequestHashForTesting returns the request hash in the context; to be used +// for testing only. +func GetRequestHashForTesting(ctx context.Context) uint64 { + return getRequestHash(ctx) +} + +// SetRequestHash adds the request hash to the context for use in Ring Hash Load +// Balancing. 
+func SetRequestHash(ctx context.Context, requestHash uint64) context.Context { + return context.WithValue(ctx, clusterKey{}, requestHash) +} diff --git a/xds/internal/balancer/weightedtarget/weightedaggregator/aggregator.go b/xds/internal/balancer/weightedtarget/weightedaggregator/aggregator.go index 6c36e2a69cd..7e1d106e9ff 100644 --- a/xds/internal/balancer/weightedtarget/weightedaggregator/aggregator.go +++ b/xds/internal/balancer/weightedtarget/weightedaggregator/aggregator.go @@ -200,7 +200,9 @@ func (wbsa *Aggregator) BuildAndUpdate() { func (wbsa *Aggregator) build() balancer.State { wbsa.logger.Infof("Child pickers with config: %+v", wbsa.idToPickerState) m := wbsa.idToPickerState - var readyN, connectingN int + // TODO: use balancer.ConnectivityStateEvaluator to calculate the aggregated + // state. + var readyN, connectingN, idleN int readyPickerWithWeights := make([]weightedPickerState, 0, len(m)) for _, ps := range m { switch ps.stateToAggregate { @@ -209,6 +211,8 @@ func (wbsa *Aggregator) build() balancer.State { readyPickerWithWeights = append(readyPickerWithWeights, *ps) case connectivity.Connecting: connectingN++ + case connectivity.Idle: + idleN++ } } var aggregatedState connectivity.State @@ -217,6 +221,8 @@ func (wbsa *Aggregator) build() balancer.State { aggregatedState = connectivity.Ready case connectingN > 0: aggregatedState = connectivity.Connecting + case idleN > 0: + aggregatedState = connectivity.Idle default: aggregatedState = connectivity.TransientFailure } diff --git a/xds/internal/balancer/weightedtarget/weightedtarget.go b/xds/internal/balancer/weightedtarget/weightedtarget.go index 02b199258cd..f05e0aca19f 100644 --- a/xds/internal/balancer/weightedtarget/weightedtarget.go +++ b/xds/internal/balancer/weightedtarget/weightedtarget.go @@ -26,6 +26,7 @@ import ( "google.golang.org/grpc/balancer" "google.golang.org/grpc/internal/grpclog" "google.golang.org/grpc/internal/hierarchy" + "google.golang.org/grpc/internal/pretty" 
"google.golang.org/grpc/internal/wrr" "google.golang.org/grpc/resolver" "google.golang.org/grpc/serviceconfig" @@ -33,22 +34,23 @@ import ( "google.golang.org/grpc/xds/internal/balancer/weightedtarget/weightedaggregator" ) -const weightedTargetName = "weighted_target_experimental" +// Name is the name of the weighted_target balancer. +const Name = "weighted_target_experimental" -// newRandomWRR is the WRR constructor used to pick sub-pickers from +// NewRandomWRR is the WRR constructor used to pick sub-pickers from // sub-balancers. It's to be modified in tests. -var newRandomWRR = wrr.NewRandom +var NewRandomWRR = wrr.NewRandom func init() { - balancer.Register(&weightedTargetBB{}) + balancer.Register(bb{}) } -type weightedTargetBB struct{} +type bb struct{} -func (wt *weightedTargetBB) Build(cc balancer.ClientConn, bOpts balancer.BuildOptions) balancer.Balancer { +func (bb) Build(cc balancer.ClientConn, bOpts balancer.BuildOptions) balancer.Balancer { b := &weightedTargetBalancer{} b.logger = prefixLogger(b) - b.stateAggregator = weightedaggregator.New(cc, b.logger, newRandomWRR) + b.stateAggregator = weightedaggregator.New(cc, b.logger, NewRandomWRR) b.stateAggregator.Start() b.bg = balancergroup.New(cc, bOpts, b.stateAggregator, nil, b.logger) b.bg.Start() @@ -56,11 +58,11 @@ func (wt *weightedTargetBB) Build(cc balancer.ClientConn, bOpts balancer.BuildOp return b } -func (wt *weightedTargetBB) Name() string { - return weightedTargetName +func (bb) Name() string { + return Name } -func (wt *weightedTargetBB) ParseConfig(c json.RawMessage) (serviceconfig.LoadBalancingConfig, error) { +func (bb) ParseConfig(c json.RawMessage) (serviceconfig.LoadBalancingConfig, error) { return parseConfig(c) } @@ -75,14 +77,15 @@ type weightedTargetBalancer struct { bg *balancergroup.BalancerGroup stateAggregator *weightedaggregator.Aggregator - targets map[string]target + targets map[string]Target } // UpdateClientConnState takes the new targets in balancer group, -// 
creates/deletes sub-balancers and sends them update. Addresses are split into +// creates/deletes sub-balancers and sends them update. addresses are split into // groups based on hierarchy path. -func (w *weightedTargetBalancer) UpdateClientConnState(s balancer.ClientConnState) error { - newConfig, ok := s.BalancerConfig.(*lbConfig) +func (b *weightedTargetBalancer) UpdateClientConnState(s balancer.ClientConnState) error { + b.logger.Infof("Received update from resolver, balancer config: %+v", pretty.ToJSON(s.BalancerConfig)) + newConfig, ok := s.BalancerConfig.(*LBConfig) if !ok { return fmt.Errorf("unexpected balancer config with type: %T", s.BalancerConfig) } @@ -91,10 +94,10 @@ func (w *weightedTargetBalancer) UpdateClientConnState(s balancer.ClientConnStat var rebuildStateAndPicker bool // Remove sub-pickers and sub-balancers that are not in the new config. - for name := range w.targets { + for name := range b.targets { if _, ok := newConfig.Targets[name]; !ok { - w.stateAggregator.Remove(name) - w.bg.Remove(name) + b.stateAggregator.Remove(name) + b.bg.Remove(name) // Trigger a state/picker update, because we don't want `ClientConn` // to pick this sub-balancer anymore. rebuildStateAndPicker = true @@ -107,29 +110,39 @@ func (w *weightedTargetBalancer) UpdateClientConnState(s balancer.ClientConnStat // // For all sub-balancers, forward the address/balancer config update. for name, newT := range newConfig.Targets { - oldT, ok := w.targets[name] + oldT, ok := b.targets[name] if !ok { // If this is a new sub-balancer, add weights to the picker map. - w.stateAggregator.Add(name, newT.Weight) + b.stateAggregator.Add(name, newT.Weight) // Then add to the balancer group. - w.bg.Add(name, balancer.Get(newT.ChildPolicy.Name)) + b.bg.Add(name, balancer.Get(newT.ChildPolicy.Name)) // Not trigger a state/picker update. Wait for the new sub-balancer // to send its updates. 
+ } else if newT.ChildPolicy.Name != oldT.ChildPolicy.Name { + // If the child policy name is differet, remove from balancer group + // and re-add. + b.stateAggregator.Remove(name) + b.bg.Remove(name) + b.stateAggregator.Add(name, newT.Weight) + b.bg.Add(name, balancer.Get(newT.ChildPolicy.Name)) + // Trigger a state/picker update, because we don't want `ClientConn` + // to pick this sub-balancer anymore. + rebuildStateAndPicker = true } else if newT.Weight != oldT.Weight { // If this is an existing sub-balancer, update weight if necessary. - w.stateAggregator.UpdateWeight(name, newT.Weight) + b.stateAggregator.UpdateWeight(name, newT.Weight) // Trigger a state/picker update, because we don't want `ClientConn` // should do picks with the new weights now. rebuildStateAndPicker = true } // Forwards all the update: - // - Addresses are from the map after splitting with hierarchy path, + // - addresses are from the map after splitting with hierarchy path, // - Top level service config and attributes are the same, // - Balancer config comes from the targets map. // // TODO: handle error? How to aggregate errors and return? 
- _ = w.bg.UpdateClientConnState(name, balancer.ClientConnState{ + _ = b.bg.UpdateClientConnState(name, balancer.ClientConnState{ ResolverState: resolver.State{ Addresses: addressesSplit[name], ServiceConfig: s.ResolverState.ServiceConfig, @@ -139,23 +152,27 @@ func (w *weightedTargetBalancer) UpdateClientConnState(s balancer.ClientConnStat }) } - w.targets = newConfig.Targets + b.targets = newConfig.Targets if rebuildStateAndPicker { - w.stateAggregator.BuildAndUpdate() + b.stateAggregator.BuildAndUpdate() } return nil } -func (w *weightedTargetBalancer) ResolverError(err error) { - w.bg.ResolverError(err) +func (b *weightedTargetBalancer) ResolverError(err error) { + b.bg.ResolverError(err) +} + +func (b *weightedTargetBalancer) UpdateSubConnState(sc balancer.SubConn, state balancer.SubConnState) { + b.bg.UpdateSubConnState(sc, state) } -func (w *weightedTargetBalancer) UpdateSubConnState(sc balancer.SubConn, state balancer.SubConnState) { - w.bg.UpdateSubConnState(sc, state) +func (b *weightedTargetBalancer) Close() { + b.stateAggregator.Stop() + b.bg.Close() } -func (w *weightedTargetBalancer) Close() { - w.stateAggregator.Stop() - w.bg.Close() +func (b *weightedTargetBalancer) ExitIdle() { + b.bg.ExitIdle() } diff --git a/xds/internal/balancer/weightedtarget/weightedtarget_config.go b/xds/internal/balancer/weightedtarget/weightedtarget_config.go index 747ce918bc6..52090cd67b0 100644 --- a/xds/internal/balancer/weightedtarget/weightedtarget_config.go +++ b/xds/internal/balancer/weightedtarget/weightedtarget_config.go @@ -25,30 +25,23 @@ import ( "google.golang.org/grpc/serviceconfig" ) -type target struct { +// Target represents one target with the weight and the child policy. +type Target struct { // Weight is the weight of the child policy. - Weight uint32 + Weight uint32 `json:"weight,omitempty"` // ChildPolicy is the child policy and it's config. 
- ChildPolicy *internalserviceconfig.BalancerConfig + ChildPolicy *internalserviceconfig.BalancerConfig `json:"childPolicy,omitempty"` } -// lbConfig is the balancer config for weighted_target. The proto representation -// is: -// -// message WeightedTargetConfig { -// message Target { -// uint32 weight = 1; -// repeated LoadBalancingConfig child_policy = 2; -// } -// map targets = 1; -// } -type lbConfig struct { - serviceconfig.LoadBalancingConfig - Targets map[string]target +// LBConfig is the balancer config for weighted_target. +type LBConfig struct { + serviceconfig.LoadBalancingConfig `json:"-"` + + Targets map[string]Target `json:"targets,omitempty"` } -func parseConfig(c json.RawMessage) (*lbConfig, error) { - var cfg lbConfig +func parseConfig(c json.RawMessage) (*LBConfig, error) { + var cfg LBConfig if err := json.Unmarshal(c, &cfg); err != nil { return nil, err } diff --git a/xds/internal/balancer/weightedtarget/weightedtarget_config_test.go b/xds/internal/balancer/weightedtarget/weightedtarget_config_test.go index 2208117f60e..c239a3ae5a4 100644 --- a/xds/internal/balancer/weightedtarget/weightedtarget_config_test.go +++ b/xds/internal/balancer/weightedtarget/weightedtarget_config_test.go @@ -24,7 +24,7 @@ import ( "github.com/google/go-cmp/cmp" "google.golang.org/grpc/balancer" internalserviceconfig "google.golang.org/grpc/internal/serviceconfig" - _ "google.golang.org/grpc/xds/internal/balancer/cdsbalancer" + "google.golang.org/grpc/xds/internal/balancer/priority" ) const ( @@ -32,31 +32,29 @@ const ( "targets": { "cluster_1" : { "weight":75, - "childPolicy":[{"cds_experimental":{"cluster":"cluster_1"}}] + "childPolicy":[{"priority_experimental":{"priorities": ["child-1"], "children": {"child-1": {"config": [{"round_robin":{}}]}}}}] }, "cluster_2" : { "weight":25, - "childPolicy":[{"cds_experimental":{"cluster":"cluster_2"}}] + "childPolicy":[{"priority_experimental":{"priorities": ["child-2"], "children": {"child-2": {"config": 
[{"round_robin":{}}]}}}}] } } }` - - cdsName = "cds_experimental" ) var ( - cdsConfigParser = balancer.Get(cdsName).(balancer.ConfigParser) - cdsConfigJSON1 = `{"cluster":"cluster_1"}` - cdsConfig1, _ = cdsConfigParser.ParseConfig([]byte(cdsConfigJSON1)) - cdsConfigJSON2 = `{"cluster":"cluster_2"}` - cdsConfig2, _ = cdsConfigParser.ParseConfig([]byte(cdsConfigJSON2)) + testConfigParser = balancer.Get(priority.Name).(balancer.ConfigParser) + testConfigJSON1 = `{"priorities": ["child-1"], "children": {"child-1": {"config": [{"round_robin":{}}]}}}` + testConfig1, _ = testConfigParser.ParseConfig([]byte(testConfigJSON1)) + testConfigJSON2 = `{"priorities": ["child-2"], "children": {"child-2": {"config": [{"round_robin":{}}]}}}` + testConfig2, _ = testConfigParser.ParseConfig([]byte(testConfigJSON2)) ) func Test_parseConfig(t *testing.T) { tests := []struct { name string js string - want *lbConfig + want *LBConfig wantErr bool }{ { @@ -68,20 +66,20 @@ func Test_parseConfig(t *testing.T) { { name: "OK", js: testJSONConfig, - want: &lbConfig{ - Targets: map[string]target{ + want: &LBConfig{ + Targets: map[string]Target{ "cluster_1": { Weight: 75, ChildPolicy: &internalserviceconfig.BalancerConfig{ - Name: cdsName, - Config: cdsConfig1, + Name: priority.Name, + Config: testConfig1, }, }, "cluster_2": { Weight: 25, ChildPolicy: &internalserviceconfig.BalancerConfig{ - Name: cdsName, - Config: cdsConfig2, + Name: priority.Name, + Config: testConfig2, }, }, }, diff --git a/xds/internal/balancer/weightedtarget/weightedtarget_test.go b/xds/internal/balancer/weightedtarget/weightedtarget_test.go index 7f9e566ca5b..b0e4df89588 100644 --- a/xds/internal/balancer/weightedtarget/weightedtarget_test.go +++ b/xds/internal/balancer/weightedtarget/weightedtarget_test.go @@ -29,6 +29,7 @@ import ( "google.golang.org/grpc/balancer" "google.golang.org/grpc/balancer/roundrobin" "google.golang.org/grpc/connectivity" + "google.golang.org/grpc/internal/balancer/stub" 
"google.golang.org/grpc/internal/hierarchy" "google.golang.org/grpc/resolver" "google.golang.org/grpc/serviceconfig" @@ -103,7 +104,7 @@ func init() { for i := 0; i < testBackendAddrsCount; i++ { testBackendAddrStrs = append(testBackendAddrStrs, fmt.Sprintf("%d.%d.%d.%d:%d", i, i, i, i, i)) } - wtbBuilder = balancer.Get(weightedTargetName) + wtbBuilder = balancer.Get(Name) wtbParser = wtbBuilder.(balancer.ConfigParser) balancergroup.DefaultSubBalancerCloseTimeout = time.Millisecond @@ -215,6 +216,46 @@ func TestWeightedTarget(t *testing.T) { if err := testutils.IsRoundRobin(want, subConnFromPicker(p2)); err != nil { t.Fatalf("want %v, got %v", want, err) } + + // Replace child policy of "cluster_1" to "round_robin". + config3, err := wtbParser.ParseConfig([]byte(`{"targets":{"cluster_2":{"weight":1,"childPolicy":[{"round_robin":""}]}}}`)) + if err != nil { + t.Fatalf("failed to parse balancer config: %v", err) + } + + // Send the config, and an address with hierarchy path ["cluster_1"]. + wantAddr4 := resolver.Address{Addr: testBackendAddrStrs[0], Attributes: nil} + if err := wtb.UpdateClientConnState(balancer.ClientConnState{ + ResolverState: resolver.State{Addresses: []resolver.Address{ + hierarchy.Set(wantAddr4, []string{"cluster_2"}), + }}, + BalancerConfig: config3, + }); err != nil { + t.Fatalf("failed to update ClientConn state: %v", err) + } + + // Verify that a subconn is created with the address, and the hierarchy path + // in the address is cleared. + addr4 := <-cc.NewSubConnAddrsCh + if want := []resolver.Address{ + hierarchy.Set(wantAddr4, []string{}), + }; !cmp.Equal(addr4, want, cmp.AllowUnexported(attributes.Attributes{})) { + t.Fatalf("got unexpected new subconn addrs: %v", cmp.Diff(addr4, want, cmp.AllowUnexported(attributes.Attributes{}))) + } + + // Send subconn state change. 
+ sc4 := <-cc.NewSubConnCh + wtb.UpdateSubConnState(sc4, balancer.SubConnState{ConnectivityState: connectivity.Connecting}) + wtb.UpdateSubConnState(sc4, balancer.SubConnState{ConnectivityState: connectivity.Ready}) + + // Test pick with one backend. + p3 := <-cc.NewPickerCh + for i := 0; i < 5; i++ { + gotSCSt, _ := p3.Pick(balancer.PickInfo{}) + if !cmp.Equal(gotSCSt.SubConn, sc4, cmp.AllowUnexported(testutils.TestSubConn{})) { + t.Fatalf("picker.Pick, got %v, want SubConn=%v", gotSCSt, sc4) + } + } } func subConnFromPicker(p balancer.Picker) func() balancer.SubConn { @@ -223,3 +264,63 @@ func subConnFromPicker(p balancer.Picker) func() balancer.SubConn { return scst.SubConn } } + +const initIdleBalancerName = "test-init-Idle-balancer" + +var errTestInitIdle = fmt.Errorf("init Idle balancer error 0") + +func init() { + stub.Register(initIdleBalancerName, stub.BalancerFuncs{ + UpdateClientConnState: func(bd *stub.BalancerData, opts balancer.ClientConnState) error { + bd.ClientConn.NewSubConn(opts.ResolverState.Addresses, balancer.NewSubConnOptions{}) + return nil + }, + UpdateSubConnState: func(bd *stub.BalancerData, sc balancer.SubConn, state balancer.SubConnState) { + err := fmt.Errorf("wrong picker error") + if state.ConnectivityState == connectivity.Idle { + err = errTestInitIdle + } + bd.ClientConn.UpdateState(balancer.State{ + ConnectivityState: state.ConnectivityState, + Picker: &testutils.TestConstPicker{Err: err}, + }) + }, + }) +} + +// TestInitialIdle covers the case that if the child reports Idle, the overall +// state will be Idle. +func TestInitialIdle(t *testing.T) { + cc := testutils.NewTestClientConn(t) + wtb := wtbBuilder.Build(cc, balancer.BuildOptions{}) + + // Start with "cluster_1: round_robin". 
+ config1, err := wtbParser.ParseConfig([]byte(`{"targets":{"cluster_1":{"weight":1,"childPolicy":[{"test-init-Idle-balancer":""}]}}}`)) + if err != nil { + t.Fatalf("failed to parse balancer config: %v", err) + } + + // Send the config, and an address with hierarchy path ["cluster_1"]. + wantAddrs := []resolver.Address{ + {Addr: testBackendAddrStrs[0], Attributes: nil}, + } + if err := wtb.UpdateClientConnState(balancer.ClientConnState{ + ResolverState: resolver.State{Addresses: []resolver.Address{ + hierarchy.Set(wantAddrs[0], []string{"cds:cluster_1"}), + }}, + BalancerConfig: config1, + }); err != nil { + t.Fatalf("failed to update ClientConn state: %v", err) + } + + // Verify that a subconn is created with the address, and the hierarchy path + // in the address is cleared. + for range wantAddrs { + sc := <-cc.NewSubConnCh + wtb.UpdateSubConnState(sc, balancer.SubConnState{ConnectivityState: connectivity.Idle}) + } + + if state1 := <-cc.NewStateCh; state1 != connectivity.Idle { + t.Fatalf("Received aggregated state: %v, want Idle", state1) + } +} diff --git a/xds/internal/client/cds_test.go b/xds/internal/client/cds_test.go deleted file mode 100644 index c5f1d76d32c..00000000000 --- a/xds/internal/client/cds_test.go +++ /dev/null @@ -1,833 +0,0 @@ -/* - * - * Copyright 2020 gRPC authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- * - */ - -package client - -import ( - "testing" - - v2xdspb "github.com/envoyproxy/go-control-plane/envoy/api/v2" - v2corepb "github.com/envoyproxy/go-control-plane/envoy/api/v2/core" - v3clusterpb "github.com/envoyproxy/go-control-plane/envoy/config/cluster/v3" - v3corepb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" - v3tlspb "github.com/envoyproxy/go-control-plane/envoy/extensions/transport_sockets/tls/v3" - v3matcherpb "github.com/envoyproxy/go-control-plane/envoy/type/matcher/v3" - "github.com/golang/protobuf/proto" - anypb "github.com/golang/protobuf/ptypes/any" - "github.com/google/go-cmp/cmp" - "github.com/google/go-cmp/cmp/cmpopts" - "google.golang.org/grpc/xds/internal/env" - "google.golang.org/grpc/xds/internal/version" - "google.golang.org/protobuf/types/known/wrapperspb" -) - -const ( - clusterName = "clusterName" - serviceName = "service" -) - -var emptyUpdate = ClusterUpdate{ServiceName: "", EnableLRS: false} - -func (s) TestValidateCluster_Failure(t *testing.T) { - tests := []struct { - name string - cluster *v3clusterpb.Cluster - wantUpdate ClusterUpdate - wantErr bool - }{ - { - name: "non-eds-cluster-type", - cluster: &v3clusterpb.Cluster{ - ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_STATIC}, - EdsClusterConfig: &v3clusterpb.Cluster_EdsClusterConfig{ - EdsConfig: &v3corepb.ConfigSource{ - ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{ - Ads: &v3corepb.AggregatedConfigSource{}, - }, - }, - }, - LbPolicy: v3clusterpb.Cluster_LEAST_REQUEST, - }, - wantUpdate: emptyUpdate, - wantErr: true, - }, - { - name: "no-eds-config", - cluster: &v3clusterpb.Cluster{ - ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_EDS}, - LbPolicy: v3clusterpb.Cluster_ROUND_ROBIN, - }, - wantUpdate: emptyUpdate, - wantErr: true, - }, - { - name: "no-ads-config-source", - cluster: &v3clusterpb.Cluster{ - ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_EDS}, - 
EdsClusterConfig: &v3clusterpb.Cluster_EdsClusterConfig{}, - LbPolicy: v3clusterpb.Cluster_ROUND_ROBIN, - }, - wantUpdate: emptyUpdate, - wantErr: true, - }, - { - name: "non-round-robin-lb-policy", - cluster: &v3clusterpb.Cluster{ - ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_EDS}, - EdsClusterConfig: &v3clusterpb.Cluster_EdsClusterConfig{ - EdsConfig: &v3corepb.ConfigSource{ - ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{ - Ads: &v3corepb.AggregatedConfigSource{}, - }, - }, - }, - LbPolicy: v3clusterpb.Cluster_LEAST_REQUEST, - }, - wantUpdate: emptyUpdate, - wantErr: true, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - if update, err := validateCluster(test.cluster); err == nil { - t.Errorf("validateCluster(%+v) = %v, wanted error", test.cluster, update) - } - }) - } -} - -func (s) TestValidateCluster_Success(t *testing.T) { - tests := []struct { - name string - cluster *v3clusterpb.Cluster - wantUpdate ClusterUpdate - }{ - { - name: "happy-case-no-service-name-no-lrs", - cluster: &v3clusterpb.Cluster{ - ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_EDS}, - EdsClusterConfig: &v3clusterpb.Cluster_EdsClusterConfig{ - EdsConfig: &v3corepb.ConfigSource{ - ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{ - Ads: &v3corepb.AggregatedConfigSource{}, - }, - }, - }, - LbPolicy: v3clusterpb.Cluster_ROUND_ROBIN, - }, - wantUpdate: emptyUpdate, - }, - { - name: "happy-case-no-lrs", - cluster: &v3clusterpb.Cluster{ - ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_EDS}, - EdsClusterConfig: &v3clusterpb.Cluster_EdsClusterConfig{ - EdsConfig: &v3corepb.ConfigSource{ - ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{ - Ads: &v3corepb.AggregatedConfigSource{}, - }, - }, - ServiceName: serviceName, - }, - LbPolicy: v3clusterpb.Cluster_ROUND_ROBIN, - }, - wantUpdate: ClusterUpdate{ServiceName: serviceName, EnableLRS: false}, - }, - { - name: 
"happiest-case", - cluster: &v3clusterpb.Cluster{ - Name: clusterName, - ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_EDS}, - EdsClusterConfig: &v3clusterpb.Cluster_EdsClusterConfig{ - EdsConfig: &v3corepb.ConfigSource{ - ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{ - Ads: &v3corepb.AggregatedConfigSource{}, - }, - }, - ServiceName: serviceName, - }, - LbPolicy: v3clusterpb.Cluster_ROUND_ROBIN, - LrsServer: &v3corepb.ConfigSource{ - ConfigSourceSpecifier: &v3corepb.ConfigSource_Self{ - Self: &v3corepb.SelfConfigSource{}, - }, - }, - }, - wantUpdate: ClusterUpdate{ServiceName: serviceName, EnableLRS: true}, - }, - { - name: "happiest-case-with-circuitbreakers", - cluster: &v3clusterpb.Cluster{ - Name: clusterName, - ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_EDS}, - EdsClusterConfig: &v3clusterpb.Cluster_EdsClusterConfig{ - EdsConfig: &v3corepb.ConfigSource{ - ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{ - Ads: &v3corepb.AggregatedConfigSource{}, - }, - }, - ServiceName: serviceName, - }, - LbPolicy: v3clusterpb.Cluster_ROUND_ROBIN, - CircuitBreakers: &v3clusterpb.CircuitBreakers{ - Thresholds: []*v3clusterpb.CircuitBreakers_Thresholds{ - { - Priority: v3corepb.RoutingPriority_DEFAULT, - MaxRequests: wrapperspb.UInt32(512), - }, - { - Priority: v3corepb.RoutingPriority_HIGH, - MaxRequests: nil, - }, - }, - }, - LrsServer: &v3corepb.ConfigSource{ - ConfigSourceSpecifier: &v3corepb.ConfigSource_Self{ - Self: &v3corepb.SelfConfigSource{}, - }, - }, - }, - wantUpdate: ClusterUpdate{ServiceName: serviceName, EnableLRS: true, MaxRequests: func() *uint32 { i := uint32(512); return &i }()}, - }, - } - - origCircuitBreakingSupport := env.CircuitBreakingSupport - env.CircuitBreakingSupport = true - defer func() { env.CircuitBreakingSupport = origCircuitBreakingSupport }() - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - update, err := validateCluster(test.cluster) - if err != nil { 
- t.Errorf("validateCluster(%+v) failed: %v", test.cluster, err) - } - if !cmp.Equal(update, test.wantUpdate, cmpopts.EquateEmpty()) { - t.Errorf("validateCluster(%+v) = %v, want: %v", test.cluster, update, test.wantUpdate) - } - }) - } -} - -func (s) TestValidateClusterWithSecurityConfig_EnvVarOff(t *testing.T) { - // Turn off the env var protection for client-side security. - origClientSideSecurityEnvVar := env.ClientSideSecuritySupport - env.ClientSideSecuritySupport = false - defer func() { env.ClientSideSecuritySupport = origClientSideSecurityEnvVar }() - - cluster := &v3clusterpb.Cluster{ - ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_EDS}, - EdsClusterConfig: &v3clusterpb.Cluster_EdsClusterConfig{ - EdsConfig: &v3corepb.ConfigSource{ - ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{ - Ads: &v3corepb.AggregatedConfigSource{}, - }, - }, - ServiceName: serviceName, - }, - LbPolicy: v3clusterpb.Cluster_ROUND_ROBIN, - TransportSocket: &v3corepb.TransportSocket{ - Name: "envoy.transport_sockets.tls", - ConfigType: &v3corepb.TransportSocket_TypedConfig{ - TypedConfig: &anypb.Any{ - TypeUrl: version.V3UpstreamTLSContextURL, - Value: func() []byte { - tls := &v3tlspb.UpstreamTlsContext{ - CommonTlsContext: &v3tlspb.CommonTlsContext{ - ValidationContextType: &v3tlspb.CommonTlsContext_ValidationContextCertificateProviderInstance{ - ValidationContextCertificateProviderInstance: &v3tlspb.CommonTlsContext_CertificateProviderInstance{ - InstanceName: "rootInstance", - CertificateName: "rootCert", - }, - }, - }, - } - mtls, _ := proto.Marshal(tls) - return mtls - }(), - }, - }, - }, - } - wantUpdate := ClusterUpdate{ - ServiceName: serviceName, - EnableLRS: false, - } - gotUpdate, err := validateCluster(cluster) - if err != nil { - t.Errorf("validateCluster() failed: %v", err) - } - if diff := cmp.Diff(wantUpdate, gotUpdate); diff != "" { - t.Errorf("validateCluster() returned unexpected diff (-want, got):\n%s", diff) - } -} - -func (s) 
TestValidateClusterWithSecurityConfig(t *testing.T) { - // Turn on the env var protection for client-side security. - origClientSideSecurityEnvVar := env.ClientSideSecuritySupport - env.ClientSideSecuritySupport = true - defer func() { env.ClientSideSecuritySupport = origClientSideSecurityEnvVar }() - - const ( - identityPluginInstance = "identityPluginInstance" - identityCertName = "identityCert" - rootPluginInstance = "rootPluginInstance" - rootCertName = "rootCert" - serviceName = "service" - san1 = "san1" - san2 = "san2" - ) - - tests := []struct { - name string - cluster *v3clusterpb.Cluster - wantUpdate ClusterUpdate - wantErr bool - }{ - { - name: "transport-socket-unsupported-name", - cluster: &v3clusterpb.Cluster{ - ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_EDS}, - EdsClusterConfig: &v3clusterpb.Cluster_EdsClusterConfig{ - EdsConfig: &v3corepb.ConfigSource{ - ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{ - Ads: &v3corepb.AggregatedConfigSource{}, - }, - }, - ServiceName: serviceName, - }, - LbPolicy: v3clusterpb.Cluster_ROUND_ROBIN, - TransportSocket: &v3corepb.TransportSocket{ - Name: "unsupported-foo", - ConfigType: &v3corepb.TransportSocket_TypedConfig{ - TypedConfig: &anypb.Any{ - TypeUrl: version.V3UpstreamTLSContextURL, - }, - }, - }, - }, - wantErr: true, - }, - { - name: "transport-socket-unsupported-typeURL", - cluster: &v3clusterpb.Cluster{ - ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_EDS}, - EdsClusterConfig: &v3clusterpb.Cluster_EdsClusterConfig{ - EdsConfig: &v3corepb.ConfigSource{ - ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{ - Ads: &v3corepb.AggregatedConfigSource{}, - }, - }, - ServiceName: serviceName, - }, - LbPolicy: v3clusterpb.Cluster_ROUND_ROBIN, - TransportSocket: &v3corepb.TransportSocket{ - ConfigType: &v3corepb.TransportSocket_TypedConfig{ - TypedConfig: &anypb.Any{ - TypeUrl: version.V3HTTPConnManagerURL, - }, - }, - }, - }, - wantErr: true, - }, - { - 
name: "transport-socket-unsupported-type", - cluster: &v3clusterpb.Cluster{ - ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_EDS}, - EdsClusterConfig: &v3clusterpb.Cluster_EdsClusterConfig{ - EdsConfig: &v3corepb.ConfigSource{ - ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{ - Ads: &v3corepb.AggregatedConfigSource{}, - }, - }, - ServiceName: serviceName, - }, - LbPolicy: v3clusterpb.Cluster_ROUND_ROBIN, - TransportSocket: &v3corepb.TransportSocket{ - ConfigType: &v3corepb.TransportSocket_TypedConfig{ - TypedConfig: &anypb.Any{ - TypeUrl: version.V3UpstreamTLSContextURL, - Value: []byte{1, 2, 3, 4}, - }, - }, - }, - }, - wantErr: true, - }, - { - name: "transport-socket-unsupported-validation-context", - cluster: &v3clusterpb.Cluster{ - ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_EDS}, - EdsClusterConfig: &v3clusterpb.Cluster_EdsClusterConfig{ - EdsConfig: &v3corepb.ConfigSource{ - ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{ - Ads: &v3corepb.AggregatedConfigSource{}, - }, - }, - ServiceName: serviceName, - }, - LbPolicy: v3clusterpb.Cluster_ROUND_ROBIN, - TransportSocket: &v3corepb.TransportSocket{ - ConfigType: &v3corepb.TransportSocket_TypedConfig{ - TypedConfig: &anypb.Any{ - TypeUrl: version.V3UpstreamTLSContextURL, - Value: func() []byte { - tls := &v3tlspb.UpstreamTlsContext{ - CommonTlsContext: &v3tlspb.CommonTlsContext{ - ValidationContextType: &v3tlspb.CommonTlsContext_ValidationContextSdsSecretConfig{ - ValidationContextSdsSecretConfig: &v3tlspb.SdsSecretConfig{ - Name: "foo-sds-secret", - }, - }, - }, - } - mtls, _ := proto.Marshal(tls) - return mtls - }(), - }, - }, - }, - }, - wantErr: true, - }, - { - name: "transport-socket-without-validation-context", - cluster: &v3clusterpb.Cluster{ - ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_EDS}, - EdsClusterConfig: &v3clusterpb.Cluster_EdsClusterConfig{ - EdsConfig: &v3corepb.ConfigSource{ - ConfigSourceSpecifier: 
&v3corepb.ConfigSource_Ads{ - Ads: &v3corepb.AggregatedConfigSource{}, - }, - }, - ServiceName: serviceName, - }, - LbPolicy: v3clusterpb.Cluster_ROUND_ROBIN, - TransportSocket: &v3corepb.TransportSocket{ - ConfigType: &v3corepb.TransportSocket_TypedConfig{ - TypedConfig: &anypb.Any{ - TypeUrl: version.V3UpstreamTLSContextURL, - Value: func() []byte { - tls := &v3tlspb.UpstreamTlsContext{ - CommonTlsContext: &v3tlspb.CommonTlsContext{}, - } - mtls, _ := proto.Marshal(tls) - return mtls - }(), - }, - }, - }, - }, - wantErr: true, - }, - { - name: "happy-case-with-no-identity-certs", - cluster: &v3clusterpb.Cluster{ - ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_EDS}, - EdsClusterConfig: &v3clusterpb.Cluster_EdsClusterConfig{ - EdsConfig: &v3corepb.ConfigSource{ - ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{ - Ads: &v3corepb.AggregatedConfigSource{}, - }, - }, - ServiceName: serviceName, - }, - LbPolicy: v3clusterpb.Cluster_ROUND_ROBIN, - TransportSocket: &v3corepb.TransportSocket{ - Name: "envoy.transport_sockets.tls", - ConfigType: &v3corepb.TransportSocket_TypedConfig{ - TypedConfig: &anypb.Any{ - TypeUrl: version.V3UpstreamTLSContextURL, - Value: func() []byte { - tls := &v3tlspb.UpstreamTlsContext{ - CommonTlsContext: &v3tlspb.CommonTlsContext{ - ValidationContextType: &v3tlspb.CommonTlsContext_ValidationContextCertificateProviderInstance{ - ValidationContextCertificateProviderInstance: &v3tlspb.CommonTlsContext_CertificateProviderInstance{ - InstanceName: rootPluginInstance, - CertificateName: rootCertName, - }, - }, - }, - } - mtls, _ := proto.Marshal(tls) - return mtls - }(), - }, - }, - }, - }, - wantUpdate: ClusterUpdate{ - ServiceName: serviceName, - EnableLRS: false, - SecurityCfg: &SecurityConfig{ - RootInstanceName: rootPluginInstance, - RootCertName: rootCertName, - }, - }, - }, - { - name: "happy-case-with-validation-context-provider-instance", - cluster: &v3clusterpb.Cluster{ - ClusterDiscoveryType: 
&v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_EDS}, - EdsClusterConfig: &v3clusterpb.Cluster_EdsClusterConfig{ - EdsConfig: &v3corepb.ConfigSource{ - ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{ - Ads: &v3corepb.AggregatedConfigSource{}, - }, - }, - ServiceName: serviceName, - }, - LbPolicy: v3clusterpb.Cluster_ROUND_ROBIN, - TransportSocket: &v3corepb.TransportSocket{ - Name: "envoy.transport_sockets.tls", - ConfigType: &v3corepb.TransportSocket_TypedConfig{ - TypedConfig: &anypb.Any{ - TypeUrl: version.V3UpstreamTLSContextURL, - Value: func() []byte { - tls := &v3tlspb.UpstreamTlsContext{ - CommonTlsContext: &v3tlspb.CommonTlsContext{ - TlsCertificateCertificateProviderInstance: &v3tlspb.CommonTlsContext_CertificateProviderInstance{ - InstanceName: identityPluginInstance, - CertificateName: identityCertName, - }, - ValidationContextType: &v3tlspb.CommonTlsContext_ValidationContextCertificateProviderInstance{ - ValidationContextCertificateProviderInstance: &v3tlspb.CommonTlsContext_CertificateProviderInstance{ - InstanceName: rootPluginInstance, - CertificateName: rootCertName, - }, - }, - }, - } - mtls, _ := proto.Marshal(tls) - return mtls - }(), - }, - }, - }, - }, - wantUpdate: ClusterUpdate{ - ServiceName: serviceName, - EnableLRS: false, - SecurityCfg: &SecurityConfig{ - RootInstanceName: rootPluginInstance, - RootCertName: rootCertName, - IdentityInstanceName: identityPluginInstance, - IdentityCertName: identityCertName, - }, - }, - }, - { - name: "happy-case-with-combined-validation-context", - cluster: &v3clusterpb.Cluster{ - ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_EDS}, - EdsClusterConfig: &v3clusterpb.Cluster_EdsClusterConfig{ - EdsConfig: &v3corepb.ConfigSource{ - ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{ - Ads: &v3corepb.AggregatedConfigSource{}, - }, - }, - ServiceName: serviceName, - }, - LbPolicy: v3clusterpb.Cluster_ROUND_ROBIN, - TransportSocket: &v3corepb.TransportSocket{ - Name: 
"envoy.transport_sockets.tls", - ConfigType: &v3corepb.TransportSocket_TypedConfig{ - TypedConfig: &anypb.Any{ - TypeUrl: version.V3UpstreamTLSContextURL, - Value: func() []byte { - tls := &v3tlspb.UpstreamTlsContext{ - CommonTlsContext: &v3tlspb.CommonTlsContext{ - TlsCertificateCertificateProviderInstance: &v3tlspb.CommonTlsContext_CertificateProviderInstance{ - InstanceName: identityPluginInstance, - CertificateName: identityCertName, - }, - ValidationContextType: &v3tlspb.CommonTlsContext_CombinedValidationContext{ - CombinedValidationContext: &v3tlspb.CommonTlsContext_CombinedCertificateValidationContext{ - DefaultValidationContext: &v3tlspb.CertificateValidationContext{ - MatchSubjectAltNames: []*v3matcherpb.StringMatcher{ - {MatchPattern: &v3matcherpb.StringMatcher_Exact{Exact: san1}}, - {MatchPattern: &v3matcherpb.StringMatcher_Exact{Exact: san2}}, - }, - }, - ValidationContextCertificateProviderInstance: &v3tlspb.CommonTlsContext_CertificateProviderInstance{ - InstanceName: rootPluginInstance, - CertificateName: rootCertName, - }, - }, - }, - }, - } - mtls, _ := proto.Marshal(tls) - return mtls - }(), - }, - }, - }, - }, - wantUpdate: ClusterUpdate{ - ServiceName: serviceName, - EnableLRS: false, - SecurityCfg: &SecurityConfig{ - RootInstanceName: rootPluginInstance, - RootCertName: rootCertName, - IdentityInstanceName: identityPluginInstance, - IdentityCertName: identityCertName, - AcceptedSANs: []string{san1, san2}, - }, - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - update, err := validateCluster(test.cluster) - if ((err != nil) != test.wantErr) || !cmp.Equal(update, test.wantUpdate, cmpopts.EquateEmpty()) { - t.Errorf("validateCluster(%+v) = (%+v, %v), want: (%+v, %v)", test.cluster, update, err, test.wantUpdate, test.wantErr) - } - }) - } -} - -func (s) TestUnmarshalCluster(t *testing.T) { - const ( - v2ClusterName = "v2clusterName" - v3ClusterName = "v3clusterName" - v2Service = "v2Service" - v3Service = 
"v2Service" - ) - var ( - v2Cluster = &v2xdspb.Cluster{ - Name: v2ClusterName, - ClusterDiscoveryType: &v2xdspb.Cluster_Type{Type: v2xdspb.Cluster_EDS}, - EdsClusterConfig: &v2xdspb.Cluster_EdsClusterConfig{ - EdsConfig: &v2corepb.ConfigSource{ - ConfigSourceSpecifier: &v2corepb.ConfigSource_Ads{ - Ads: &v2corepb.AggregatedConfigSource{}, - }, - }, - ServiceName: v2Service, - }, - LbPolicy: v2xdspb.Cluster_ROUND_ROBIN, - LrsServer: &v2corepb.ConfigSource{ - ConfigSourceSpecifier: &v2corepb.ConfigSource_Self{ - Self: &v2corepb.SelfConfigSource{}, - }, - }, - } - v2ClusterAny = &anypb.Any{ - TypeUrl: version.V2ClusterURL, - Value: func() []byte { - mcl, _ := proto.Marshal(v2Cluster) - return mcl - }(), - } - - v3Cluster = &v3clusterpb.Cluster{ - Name: v3ClusterName, - ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_EDS}, - EdsClusterConfig: &v3clusterpb.Cluster_EdsClusterConfig{ - EdsConfig: &v3corepb.ConfigSource{ - ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{ - Ads: &v3corepb.AggregatedConfigSource{}, - }, - }, - ServiceName: v3Service, - }, - LbPolicy: v3clusterpb.Cluster_ROUND_ROBIN, - LrsServer: &v3corepb.ConfigSource{ - ConfigSourceSpecifier: &v3corepb.ConfigSource_Self{ - Self: &v3corepb.SelfConfigSource{}, - }, - }, - } - v3ClusterAny = &anypb.Any{ - TypeUrl: version.V3ClusterURL, - Value: func() []byte { - mcl, _ := proto.Marshal(v3Cluster) - return mcl - }(), - } - ) - const testVersion = "test-version-cds" - - tests := []struct { - name string - resources []*anypb.Any - wantUpdate map[string]ClusterUpdate - wantMD UpdateMetadata - wantErr bool - }{ - { - name: "non-cluster resource type", - resources: []*anypb.Any{{TypeUrl: version.V3HTTPConnManagerURL}}, - wantMD: UpdateMetadata{ - Status: ServiceStatusNACKed, - Version: testVersion, - ErrState: &UpdateErrorMetadata{ - Version: testVersion, - Err: errPlaceHolder, - }, - }, - wantErr: true, - }, - { - name: "badly marshaled cluster resource", - resources: []*anypb.Any{ - { - 
TypeUrl: version.V3ClusterURL, - Value: []byte{1, 2, 3, 4}, - }, - }, - wantMD: UpdateMetadata{ - Status: ServiceStatusNACKed, - Version: testVersion, - ErrState: &UpdateErrorMetadata{ - Version: testVersion, - Err: errPlaceHolder, - }, - }, - wantErr: true, - }, - { - name: "bad cluster resource", - resources: []*anypb.Any{ - { - TypeUrl: version.V3ClusterURL, - Value: func() []byte { - cl := &v3clusterpb.Cluster{ - Name: "test", - ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_STATIC}, - } - mcl, _ := proto.Marshal(cl) - return mcl - }(), - }, - }, - wantUpdate: map[string]ClusterUpdate{"test": {}}, - wantMD: UpdateMetadata{ - Status: ServiceStatusNACKed, - Version: testVersion, - ErrState: &UpdateErrorMetadata{ - Version: testVersion, - Err: errPlaceHolder, - }, - }, - wantErr: true, - }, - { - name: "v2 cluster", - resources: []*anypb.Any{v2ClusterAny}, - wantUpdate: map[string]ClusterUpdate{ - v2ClusterName: { - ServiceName: v2Service, EnableLRS: true, - Raw: v2ClusterAny, - }, - }, - wantMD: UpdateMetadata{ - Status: ServiceStatusACKed, - Version: testVersion, - }, - }, - { - name: "v3 cluster", - resources: []*anypb.Any{v3ClusterAny}, - wantUpdate: map[string]ClusterUpdate{ - v3ClusterName: { - ServiceName: v3Service, EnableLRS: true, - Raw: v3ClusterAny, - }, - }, - wantMD: UpdateMetadata{ - Status: ServiceStatusACKed, - Version: testVersion, - }, - }, - { - name: "multiple clusters", - resources: []*anypb.Any{v2ClusterAny, v3ClusterAny}, - wantUpdate: map[string]ClusterUpdate{ - v2ClusterName: { - ServiceName: v2Service, EnableLRS: true, - Raw: v2ClusterAny, - }, - v3ClusterName: { - ServiceName: v3Service, EnableLRS: true, - Raw: v3ClusterAny, - }, - }, - wantMD: UpdateMetadata{ - Status: ServiceStatusACKed, - Version: testVersion, - }, - }, - { - // To test that unmarshal keeps processing on errors. 
- name: "good and bad clusters", - resources: []*anypb.Any{ - v2ClusterAny, - { - // bad cluster resource - TypeUrl: version.V3ClusterURL, - Value: func() []byte { - cl := &v3clusterpb.Cluster{ - Name: "bad", - ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_STATIC}, - } - mcl, _ := proto.Marshal(cl) - return mcl - }(), - }, - v3ClusterAny, - }, - wantUpdate: map[string]ClusterUpdate{ - v2ClusterName: { - ServiceName: v2Service, EnableLRS: true, - Raw: v2ClusterAny, - }, - v3ClusterName: { - ServiceName: v3Service, EnableLRS: true, - Raw: v3ClusterAny, - }, - "bad": {}, - }, - wantMD: UpdateMetadata{ - Status: ServiceStatusNACKed, - Version: testVersion, - ErrState: &UpdateErrorMetadata{ - Version: testVersion, - Err: errPlaceHolder, - }, - }, - wantErr: true, - }, - } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - update, md, err := UnmarshalCluster(testVersion, test.resources, nil) - if (err != nil) != test.wantErr { - t.Fatalf("UnmarshalCluster(), got err: %v, wantErr: %v", err, test.wantErr) - } - if diff := cmp.Diff(update, test.wantUpdate, cmpOpts); diff != "" { - t.Errorf("got unexpected update, diff (-got +want): %v", diff) - } - if diff := cmp.Diff(md, test.wantMD, cmpOptsIgnoreDetails); diff != "" { - t.Errorf("got unexpected metadata, diff (-got +want): %v", diff) - } - }) - } -} diff --git a/xds/internal/client/lds_test.go b/xds/internal/client/lds_test.go deleted file mode 100644 index 26e79b78d13..00000000000 --- a/xds/internal/client/lds_test.go +++ /dev/null @@ -1,1628 +0,0 @@ -/* - * - * Copyright 2020 gRPC authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - */ - -package client - -import ( - "fmt" - "strings" - "testing" - "time" - - v1typepb "github.com/cncf/udpa/go/udpa/type/v1" - v3routepb "github.com/envoyproxy/go-control-plane/envoy/config/route/v3" - "github.com/golang/protobuf/proto" - "github.com/golang/protobuf/ptypes" - spb "github.com/golang/protobuf/ptypes/struct" - "github.com/google/go-cmp/cmp" - "google.golang.org/grpc/xds/internal/env" - "google.golang.org/grpc/xds/internal/httpfilter" - "google.golang.org/grpc/xds/internal/version" - "google.golang.org/protobuf/types/known/durationpb" - - v2xdspb "github.com/envoyproxy/go-control-plane/envoy/api/v2" - v2corepb "github.com/envoyproxy/go-control-plane/envoy/api/v2/core" - v3corepb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" - v2httppb "github.com/envoyproxy/go-control-plane/envoy/config/filter/network/http_connection_manager/v2" - v2listenerpb "github.com/envoyproxy/go-control-plane/envoy/config/listener/v2" - v3listenerpb "github.com/envoyproxy/go-control-plane/envoy/config/listener/v3" - v3httppb "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/network/http_connection_manager/v3" - v3tlspb "github.com/envoyproxy/go-control-plane/envoy/extensions/transport_sockets/tls/v3" - anypb "github.com/golang/protobuf/ptypes/any" - wrapperspb "github.com/golang/protobuf/ptypes/wrappers" -) - -func (s) TestUnmarshalListener_ClientSide(t *testing.T) { - const ( - v2LDSTarget = "lds.target.good:2222" - v3LDSTarget = "lds.target.good:3333" - v2RouteConfigName = "v2RouteConfig" - v3RouteConfigName = 
"v3RouteConfig" - ) - - var ( - v2Lis = &anypb.Any{ - TypeUrl: version.V2ListenerURL, - Value: func() []byte { - cm := &v2httppb.HttpConnectionManager{ - RouteSpecifier: &v2httppb.HttpConnectionManager_Rds{ - Rds: &v2httppb.Rds{ - ConfigSource: &v2corepb.ConfigSource{ - ConfigSourceSpecifier: &v2corepb.ConfigSource_Ads{Ads: &v2corepb.AggregatedConfigSource{}}, - }, - RouteConfigName: v2RouteConfigName, - }, - }, - } - mcm, _ := proto.Marshal(cm) - lis := &v2xdspb.Listener{ - Name: v2LDSTarget, - ApiListener: &v2listenerpb.ApiListener{ - ApiListener: &anypb.Any{ - TypeUrl: version.V2HTTPConnManagerURL, - Value: mcm, - }, - }, - } - mLis, _ := proto.Marshal(lis) - return mLis - }(), - } - customFilter = &v3httppb.HttpFilter{ - Name: "customFilter", - ConfigType: &v3httppb.HttpFilter_TypedConfig{TypedConfig: customFilterConfig}, - } - typedStructFilter = &v3httppb.HttpFilter{ - Name: "customFilter", - ConfigType: &v3httppb.HttpFilter_TypedConfig{TypedConfig: wrappedCustomFilterTypedStructConfig}, - } - customOptionalFilter = &v3httppb.HttpFilter{ - Name: "customFilter", - ConfigType: &v3httppb.HttpFilter_TypedConfig{TypedConfig: customFilterConfig}, - IsOptional: true, - } - customFilter2 = &v3httppb.HttpFilter{ - Name: "customFilter2", - ConfigType: &v3httppb.HttpFilter_TypedConfig{TypedConfig: customFilterConfig}, - } - errFilter = &v3httppb.HttpFilter{ - Name: "errFilter", - ConfigType: &v3httppb.HttpFilter_TypedConfig{TypedConfig: errFilterConfig}, - } - errOptionalFilter = &v3httppb.HttpFilter{ - Name: "errFilter", - ConfigType: &v3httppb.HttpFilter_TypedConfig{TypedConfig: errFilterConfig}, - IsOptional: true, - } - clientOnlyCustomFilter = &v3httppb.HttpFilter{ - Name: "clientOnlyCustomFilter", - ConfigType: &v3httppb.HttpFilter_TypedConfig{TypedConfig: clientOnlyCustomFilterConfig}, - } - serverOnlyCustomFilter = &v3httppb.HttpFilter{ - Name: "serverOnlyCustomFilter", - ConfigType: &v3httppb.HttpFilter_TypedConfig{TypedConfig: serverOnlyCustomFilterConfig}, - 
} - serverOnlyOptionalCustomFilter = &v3httppb.HttpFilter{ - Name: "serverOnlyOptionalCustomFilter", - ConfigType: &v3httppb.HttpFilter_TypedConfig{TypedConfig: serverOnlyCustomFilterConfig}, - IsOptional: true, - } - unknownFilter = &v3httppb.HttpFilter{ - Name: "unknownFilter", - ConfigType: &v3httppb.HttpFilter_TypedConfig{TypedConfig: unknownFilterConfig}, - } - unknownOptionalFilter = &v3httppb.HttpFilter{ - Name: "unknownFilter", - ConfigType: &v3httppb.HttpFilter_TypedConfig{TypedConfig: unknownFilterConfig}, - IsOptional: true, - } - v3LisWithFilters = func(fs ...*v3httppb.HttpFilter) *anypb.Any { - hcm := &v3httppb.HttpConnectionManager{ - RouteSpecifier: &v3httppb.HttpConnectionManager_Rds{ - Rds: &v3httppb.Rds{ - ConfigSource: &v3corepb.ConfigSource{ - ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{Ads: &v3corepb.AggregatedConfigSource{}}, - }, - RouteConfigName: v3RouteConfigName, - }, - }, - CommonHttpProtocolOptions: &v3corepb.HttpProtocolOptions{ - MaxStreamDuration: durationpb.New(time.Second), - }, - HttpFilters: fs, - } - return &anypb.Any{ - TypeUrl: version.V3ListenerURL, - Value: func() []byte { - mcm, _ := ptypes.MarshalAny(hcm) - lis := &v3listenerpb.Listener{ - Name: v3LDSTarget, - ApiListener: &v3listenerpb.ApiListener{ - ApiListener: mcm, - }, - } - mLis, _ := proto.Marshal(lis) - return mLis - }(), - } - } - ) - const testVersion = "test-version-lds-client" - - tests := []struct { - name string - resources []*anypb.Any - wantUpdate map[string]ListenerUpdate - wantMD UpdateMetadata - wantErr bool - disableFI bool // disable fault injection - }{ - { - name: "non-listener resource", - resources: []*anypb.Any{{TypeUrl: version.V3HTTPConnManagerURL}}, - wantMD: UpdateMetadata{ - Status: ServiceStatusNACKed, - Version: testVersion, - ErrState: &UpdateErrorMetadata{ - Version: testVersion, - Err: errPlaceHolder, - }, - }, - wantErr: true, - }, - { - name: "badly marshaled listener resource", - resources: []*anypb.Any{ - { - TypeUrl: 
version.V3ListenerURL, - Value: func() []byte { - lis := &v3listenerpb.Listener{ - Name: v3LDSTarget, - ApiListener: &v3listenerpb.ApiListener{ - ApiListener: &anypb.Any{ - TypeUrl: version.V3HTTPConnManagerURL, - Value: []byte{1, 2, 3, 4}, - }, - }, - } - mLis, _ := proto.Marshal(lis) - return mLis - }(), - }, - }, - wantUpdate: map[string]ListenerUpdate{v3LDSTarget: {}}, - wantMD: UpdateMetadata{ - Status: ServiceStatusNACKed, - Version: testVersion, - ErrState: &UpdateErrorMetadata{ - Version: testVersion, - Err: errPlaceHolder, - }, - }, - wantErr: true, - }, - { - name: "wrong type in apiListener", - resources: []*anypb.Any{ - { - TypeUrl: version.V3ListenerURL, - Value: func() []byte { - lis := &v3listenerpb.Listener{ - Name: v3LDSTarget, - ApiListener: &v3listenerpb.ApiListener{ - ApiListener: &anypb.Any{ - TypeUrl: version.V2ListenerURL, - Value: func() []byte { - cm := &v3httppb.HttpConnectionManager{ - RouteSpecifier: &v3httppb.HttpConnectionManager_Rds{ - Rds: &v3httppb.Rds{ - ConfigSource: &v3corepb.ConfigSource{ - ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{Ads: &v3corepb.AggregatedConfigSource{}}, - }, - RouteConfigName: v3RouteConfigName, - }, - }, - } - mcm, _ := proto.Marshal(cm) - return mcm - }(), - }, - }, - } - mLis, _ := proto.Marshal(lis) - return mLis - }(), - }, - }, - wantUpdate: map[string]ListenerUpdate{v3LDSTarget: {}}, - wantMD: UpdateMetadata{ - Status: ServiceStatusNACKed, - Version: testVersion, - ErrState: &UpdateErrorMetadata{ - Version: testVersion, - Err: errPlaceHolder, - }, - }, - wantErr: true, - }, - { - name: "empty httpConnMgr in apiListener", - resources: []*anypb.Any{ - { - TypeUrl: version.V3ListenerURL, - Value: func() []byte { - lis := &v3listenerpb.Listener{ - Name: v3LDSTarget, - ApiListener: &v3listenerpb.ApiListener{ - ApiListener: &anypb.Any{ - TypeUrl: version.V2ListenerURL, - Value: func() []byte { - cm := &v3httppb.HttpConnectionManager{ - RouteSpecifier: &v3httppb.HttpConnectionManager_Rds{ - Rds: 
&v3httppb.Rds{}, - }, - } - mcm, _ := proto.Marshal(cm) - return mcm - }(), - }, - }, - } - mLis, _ := proto.Marshal(lis) - return mLis - }(), - }, - }, - wantUpdate: map[string]ListenerUpdate{v3LDSTarget: {}}, - wantMD: UpdateMetadata{ - Status: ServiceStatusNACKed, - Version: testVersion, - ErrState: &UpdateErrorMetadata{ - Version: testVersion, - Err: errPlaceHolder, - }, - }, - wantErr: true, - }, - { - name: "scopedRoutes routeConfig in apiListener", - resources: []*anypb.Any{ - { - TypeUrl: version.V3ListenerURL, - Value: func() []byte { - lis := &v3listenerpb.Listener{ - Name: v3LDSTarget, - ApiListener: &v3listenerpb.ApiListener{ - ApiListener: &anypb.Any{ - TypeUrl: version.V2ListenerURL, - Value: func() []byte { - cm := &v3httppb.HttpConnectionManager{ - RouteSpecifier: &v3httppb.HttpConnectionManager_ScopedRoutes{}, - } - mcm, _ := proto.Marshal(cm) - return mcm - }(), - }, - }, - } - mLis, _ := proto.Marshal(lis) - return mLis - }(), - }, - }, - wantUpdate: map[string]ListenerUpdate{v3LDSTarget: {}}, - wantMD: UpdateMetadata{ - Status: ServiceStatusNACKed, - Version: testVersion, - ErrState: &UpdateErrorMetadata{ - Version: testVersion, - Err: errPlaceHolder, - }, - }, - wantErr: true, - }, - { - name: "rds.ConfigSource in apiListener is not ADS", - resources: []*anypb.Any{ - { - TypeUrl: version.V3ListenerURL, - Value: func() []byte { - lis := &v3listenerpb.Listener{ - Name: v3LDSTarget, - ApiListener: &v3listenerpb.ApiListener{ - ApiListener: &anypb.Any{ - TypeUrl: version.V2ListenerURL, - Value: func() []byte { - cm := &v3httppb.HttpConnectionManager{ - RouteSpecifier: &v3httppb.HttpConnectionManager_Rds{ - Rds: &v3httppb.Rds{ - ConfigSource: &v3corepb.ConfigSource{ - ConfigSourceSpecifier: &v3corepb.ConfigSource_Path{ - Path: "/some/path", - }, - }, - RouteConfigName: v3RouteConfigName, - }, - }, - } - mcm, _ := proto.Marshal(cm) - return mcm - }(), - }, - }, - } - mLis, _ := proto.Marshal(lis) - return mLis - }(), - }, - }, - wantUpdate: 
map[string]ListenerUpdate{v3LDSTarget: {}}, - wantMD: UpdateMetadata{ - Status: ServiceStatusNACKed, - Version: testVersion, - ErrState: &UpdateErrorMetadata{ - Version: testVersion, - Err: errPlaceHolder, - }, - }, - wantErr: true, - }, - { - name: "empty resource list", - wantMD: UpdateMetadata{ - Status: ServiceStatusACKed, - Version: testVersion, - }, - }, - { - name: "v3 with no filters", - resources: []*anypb.Any{v3LisWithFilters()}, - wantUpdate: map[string]ListenerUpdate{ - v3LDSTarget: {RouteConfigName: v3RouteConfigName, MaxStreamDuration: time.Second, Raw: v3LisWithFilters()}, - }, - wantMD: UpdateMetadata{ - Status: ServiceStatusACKed, - Version: testVersion, - }, - }, - { - name: "v3 with custom filter", - resources: []*anypb.Any{v3LisWithFilters(customFilter)}, - wantUpdate: map[string]ListenerUpdate{ - v3LDSTarget: { - RouteConfigName: v3RouteConfigName, MaxStreamDuration: time.Second, - HTTPFilters: []HTTPFilter{{ - Name: "customFilter", - Filter: httpFilter{}, - Config: filterConfig{Cfg: customFilterConfig}, - }}, - Raw: v3LisWithFilters(customFilter), - }, - }, - wantMD: UpdateMetadata{ - Status: ServiceStatusACKed, - Version: testVersion, - }, - }, - { - name: "v3 with custom filter in typed struct", - resources: []*anypb.Any{v3LisWithFilters(typedStructFilter)}, - wantUpdate: map[string]ListenerUpdate{ - v3LDSTarget: { - RouteConfigName: v3RouteConfigName, MaxStreamDuration: time.Second, - HTTPFilters: []HTTPFilter{{ - Name: "customFilter", - Filter: httpFilter{}, - Config: filterConfig{Cfg: customFilterTypedStructConfig}, - }}, - Raw: v3LisWithFilters(typedStructFilter), - }, - }, - wantMD: UpdateMetadata{ - Status: ServiceStatusACKed, - Version: testVersion, - }, - }, - { - name: "v3 with optional custom filter", - resources: []*anypb.Any{v3LisWithFilters(customOptionalFilter)}, - wantUpdate: map[string]ListenerUpdate{ - v3LDSTarget: { - RouteConfigName: v3RouteConfigName, MaxStreamDuration: time.Second, - HTTPFilters: []HTTPFilter{{ - Name: 
"customFilter", - Filter: httpFilter{}, - Config: filterConfig{Cfg: customFilterConfig}, - }}, - Raw: v3LisWithFilters(customOptionalFilter), - }, - }, - wantMD: UpdateMetadata{ - Status: ServiceStatusACKed, - Version: testVersion, - }, - }, - { - name: "v3 with custom filter, fault injection disabled", - resources: []*anypb.Any{v3LisWithFilters(customFilter)}, - wantUpdate: map[string]ListenerUpdate{ - v3LDSTarget: {RouteConfigName: v3RouteConfigName, MaxStreamDuration: time.Second, Raw: v3LisWithFilters(customFilter)}, - }, - wantMD: UpdateMetadata{ - Status: ServiceStatusACKed, - Version: testVersion, - }, - disableFI: true, - }, - { - name: "v3 with two filters with same name", - resources: []*anypb.Any{v3LisWithFilters(customFilter, customFilter)}, - wantUpdate: map[string]ListenerUpdate{v3LDSTarget: {}}, - wantMD: UpdateMetadata{ - Status: ServiceStatusNACKed, - Version: testVersion, - ErrState: &UpdateErrorMetadata{ - Version: testVersion, - Err: errPlaceHolder, - }, - }, - wantErr: true, - }, - { - name: "v3 with two filters - same type different name", - resources: []*anypb.Any{v3LisWithFilters(customFilter, customFilter2)}, - wantUpdate: map[string]ListenerUpdate{ - v3LDSTarget: { - RouteConfigName: v3RouteConfigName, MaxStreamDuration: time.Second, - HTTPFilters: []HTTPFilter{{ - Name: "customFilter", - Filter: httpFilter{}, - Config: filterConfig{Cfg: customFilterConfig}, - }, { - Name: "customFilter2", - Filter: httpFilter{}, - Config: filterConfig{Cfg: customFilterConfig}, - }}, - Raw: v3LisWithFilters(customFilter, customFilter2), - }, - }, - wantMD: UpdateMetadata{ - Status: ServiceStatusACKed, - Version: testVersion, - }, - }, - { - name: "v3 with server-only filter", - resources: []*anypb.Any{v3LisWithFilters(serverOnlyCustomFilter)}, - wantUpdate: map[string]ListenerUpdate{v3LDSTarget: {}}, - wantMD: UpdateMetadata{ - Status: ServiceStatusNACKed, - Version: testVersion, - ErrState: &UpdateErrorMetadata{ - Version: testVersion, - Err: 
errPlaceHolder, - }, - }, - wantErr: true, - }, - { - name: "v3 with optional server-only filter", - resources: []*anypb.Any{v3LisWithFilters(serverOnlyOptionalCustomFilter)}, - wantUpdate: map[string]ListenerUpdate{ - v3LDSTarget: { - RouteConfigName: v3RouteConfigName, - MaxStreamDuration: time.Second, - Raw: v3LisWithFilters(serverOnlyOptionalCustomFilter), - }, - }, - wantMD: UpdateMetadata{ - Status: ServiceStatusACKed, - Version: testVersion, - }, - }, - { - name: "v3 with client-only filter", - resources: []*anypb.Any{v3LisWithFilters(clientOnlyCustomFilter)}, - wantUpdate: map[string]ListenerUpdate{ - v3LDSTarget: { - RouteConfigName: v3RouteConfigName, MaxStreamDuration: time.Second, - HTTPFilters: []HTTPFilter{{ - Name: "clientOnlyCustomFilter", - Filter: clientOnlyHTTPFilter{}, - Config: filterConfig{Cfg: clientOnlyCustomFilterConfig}, - }}, - Raw: v3LisWithFilters(clientOnlyCustomFilter), - }, - }, - wantMD: UpdateMetadata{ - Status: ServiceStatusACKed, - Version: testVersion, - }, - }, - { - name: "v3 with err filter", - resources: []*anypb.Any{v3LisWithFilters(errFilter)}, - wantUpdate: map[string]ListenerUpdate{v3LDSTarget: {}}, - wantMD: UpdateMetadata{ - Status: ServiceStatusNACKed, - Version: testVersion, - ErrState: &UpdateErrorMetadata{ - Version: testVersion, - Err: errPlaceHolder, - }, - }, - wantErr: true, - }, - { - name: "v3 with optional err filter", - resources: []*anypb.Any{v3LisWithFilters(errOptionalFilter)}, - wantUpdate: map[string]ListenerUpdate{v3LDSTarget: {}}, - wantMD: UpdateMetadata{ - Status: ServiceStatusNACKed, - Version: testVersion, - ErrState: &UpdateErrorMetadata{ - Version: testVersion, - Err: errPlaceHolder, - }, - }, - wantErr: true, - }, - { - name: "v3 with unknown filter", - resources: []*anypb.Any{v3LisWithFilters(unknownFilter)}, - wantUpdate: map[string]ListenerUpdate{v3LDSTarget: {}}, - wantMD: UpdateMetadata{ - Status: ServiceStatusNACKed, - Version: testVersion, - ErrState: &UpdateErrorMetadata{ - Version: 
testVersion, - Err: errPlaceHolder, - }, - }, - wantErr: true, - }, - { - name: "v3 with unknown filter (optional)", - resources: []*anypb.Any{v3LisWithFilters(unknownOptionalFilter)}, - wantUpdate: map[string]ListenerUpdate{ - v3LDSTarget: { - RouteConfigName: v3RouteConfigName, - MaxStreamDuration: time.Second, - Raw: v3LisWithFilters(unknownOptionalFilter), - }, - }, - wantMD: UpdateMetadata{ - Status: ServiceStatusACKed, - Version: testVersion, - }, - }, - { - name: "v3 with error filter, fault injection disabled", - resources: []*anypb.Any{v3LisWithFilters(errFilter)}, - wantUpdate: map[string]ListenerUpdate{ - v3LDSTarget: { - RouteConfigName: v3RouteConfigName, - MaxStreamDuration: time.Second, - Raw: v3LisWithFilters(errFilter), - }, - }, - wantMD: UpdateMetadata{ - Status: ServiceStatusACKed, - Version: testVersion, - }, - disableFI: true, - }, - { - name: "v2 listener resource", - resources: []*anypb.Any{v2Lis}, - wantUpdate: map[string]ListenerUpdate{ - v2LDSTarget: {RouteConfigName: v2RouteConfigName, Raw: v2Lis}, - }, - wantMD: UpdateMetadata{ - Status: ServiceStatusACKed, - Version: testVersion, - }, - }, - { - name: "v3 listener resource", - resources: []*anypb.Any{v3LisWithFilters()}, - wantUpdate: map[string]ListenerUpdate{ - v3LDSTarget: {RouteConfigName: v3RouteConfigName, MaxStreamDuration: time.Second, Raw: v3LisWithFilters()}, - }, - wantMD: UpdateMetadata{ - Status: ServiceStatusACKed, - Version: testVersion, - }, - }, - { - name: "multiple listener resources", - resources: []*anypb.Any{v2Lis, v3LisWithFilters()}, - wantUpdate: map[string]ListenerUpdate{ - v2LDSTarget: {RouteConfigName: v2RouteConfigName, Raw: v2Lis}, - v3LDSTarget: {RouteConfigName: v3RouteConfigName, MaxStreamDuration: time.Second, Raw: v3LisWithFilters()}, - }, - wantMD: UpdateMetadata{ - Status: ServiceStatusACKed, - Version: testVersion, - }, - }, - { - // To test that unmarshal keeps processing on errors. 
- name: "good and bad listener resources", - resources: []*anypb.Any{ - v2Lis, - { - TypeUrl: version.V3ListenerURL, - Value: func() []byte { - lis := &v3listenerpb.Listener{ - Name: "bad", - ApiListener: &v3listenerpb.ApiListener{ - ApiListener: &anypb.Any{ - TypeUrl: version.V2ListenerURL, - Value: func() []byte { - cm := &v3httppb.HttpConnectionManager{ - RouteSpecifier: &v3httppb.HttpConnectionManager_ScopedRoutes{}, - } - mcm, _ := proto.Marshal(cm) - return mcm - }()}}} - mLis, _ := proto.Marshal(lis) - return mLis - }(), - }, - v3LisWithFilters(), - }, - wantUpdate: map[string]ListenerUpdate{ - v2LDSTarget: {RouteConfigName: v2RouteConfigName, Raw: v2Lis}, - v3LDSTarget: {RouteConfigName: v3RouteConfigName, MaxStreamDuration: time.Second, Raw: v3LisWithFilters()}, - "bad": {}, - }, - wantMD: UpdateMetadata{ - Status: ServiceStatusNACKed, - Version: testVersion, - ErrState: &UpdateErrorMetadata{ - Version: testVersion, - Err: errPlaceHolder, - }, - }, - wantErr: true, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - oldFI := env.FaultInjectionSupport - env.FaultInjectionSupport = !test.disableFI - - update, md, err := UnmarshalListener(testVersion, test.resources, nil) - if (err != nil) != test.wantErr { - t.Fatalf("UnmarshalListener(), got err: %v, wantErr: %v", err, test.wantErr) - } - if diff := cmp.Diff(update, test.wantUpdate, cmpOpts); diff != "" { - t.Errorf("got unexpected update, diff (-got +want): %v", diff) - } - if diff := cmp.Diff(md, test.wantMD, cmpOptsIgnoreDetails); diff != "" { - t.Errorf("got unexpected metadata, diff (-got +want): %v", diff) - } - env.FaultInjectionSupport = oldFI - }) - } -} - -func (s) TestUnmarshalListener_ServerSide(t *testing.T) { - const v3LDSTarget = "grpc/server?udpa.resource.listening_address=0.0.0.0:9999" - - var ( - listenerEmptyTransportSocket = &anypb.Any{ - TypeUrl: version.V3ListenerURL, - Value: func() []byte { - lis := &v3listenerpb.Listener{ - Name: v3LDSTarget, - 
Address: &v3corepb.Address{ - Address: &v3corepb.Address_SocketAddress{ - SocketAddress: &v3corepb.SocketAddress{ - Address: "0.0.0.0", - PortSpecifier: &v3corepb.SocketAddress_PortValue{ - PortValue: 9999, - }, - }, - }, - }, - FilterChains: []*v3listenerpb.FilterChain{ - { - Name: "filter-chain-1", - }, - }, - } - mLis, _ := proto.Marshal(lis) - return mLis - }(), - } - listenerNoValidationContext = &anypb.Any{ - TypeUrl: version.V3ListenerURL, - Value: func() []byte { - lis := &v3listenerpb.Listener{ - Name: v3LDSTarget, - Address: &v3corepb.Address{ - Address: &v3corepb.Address_SocketAddress{ - SocketAddress: &v3corepb.SocketAddress{ - Address: "0.0.0.0", - PortSpecifier: &v3corepb.SocketAddress_PortValue{ - PortValue: 9999, - }, - }, - }, - }, - FilterChains: []*v3listenerpb.FilterChain{ - { - Name: "filter-chain-1", - TransportSocket: &v3corepb.TransportSocket{ - Name: "envoy.transport_sockets.tls", - ConfigType: &v3corepb.TransportSocket_TypedConfig{ - TypedConfig: &anypb.Any{ - TypeUrl: version.V3DownstreamTLSContextURL, - Value: func() []byte { - tls := &v3tlspb.DownstreamTlsContext{ - CommonTlsContext: &v3tlspb.CommonTlsContext{ - TlsCertificateCertificateProviderInstance: &v3tlspb.CommonTlsContext_CertificateProviderInstance{ - InstanceName: "identityPluginInstance", - CertificateName: "identityCertName", - }, - }, - } - mtls, _ := proto.Marshal(tls) - return mtls - }(), - }, - }, - }, - }, - }, - } - mLis, _ := proto.Marshal(lis) - return mLis - }(), - } - listenerWithValidationContext = &anypb.Any{ - TypeUrl: version.V3ListenerURL, - Value: func() []byte { - lis := &v3listenerpb.Listener{ - Name: v3LDSTarget, - Address: &v3corepb.Address{ - Address: &v3corepb.Address_SocketAddress{ - SocketAddress: &v3corepb.SocketAddress{ - Address: "0.0.0.0", - PortSpecifier: &v3corepb.SocketAddress_PortValue{ - PortValue: 9999, - }, - }, - }, - }, - FilterChains: []*v3listenerpb.FilterChain{ - { - Name: "filter-chain-1", - TransportSocket: &v3corepb.TransportSocket{ 
- Name: "envoy.transport_sockets.tls", - ConfigType: &v3corepb.TransportSocket_TypedConfig{ - TypedConfig: &anypb.Any{ - TypeUrl: version.V3DownstreamTLSContextURL, - Value: func() []byte { - tls := &v3tlspb.DownstreamTlsContext{ - RequireClientCertificate: &wrapperspb.BoolValue{Value: true}, - CommonTlsContext: &v3tlspb.CommonTlsContext{ - TlsCertificateCertificateProviderInstance: &v3tlspb.CommonTlsContext_CertificateProviderInstance{ - InstanceName: "identityPluginInstance", - CertificateName: "identityCertName", - }, - ValidationContextType: &v3tlspb.CommonTlsContext_ValidationContextCertificateProviderInstance{ - ValidationContextCertificateProviderInstance: &v3tlspb.CommonTlsContext_CertificateProviderInstance{ - InstanceName: "rootPluginInstance", - CertificateName: "rootCertName", - }, - }, - }, - } - mtls, _ := proto.Marshal(tls) - return mtls - }(), - }, - }, - }, - }, - }, - } - mLis, _ := proto.Marshal(lis) - return mLis - }(), - } - ) - - const testVersion = "test-version-lds-server" - - tests := []struct { - name string - resources []*anypb.Any - wantUpdate map[string]ListenerUpdate - wantMD UpdateMetadata - wantErr string - }{ - { - name: "no address field", - resources: []*anypb.Any{ - { - TypeUrl: version.V3ListenerURL, - Value: func() []byte { - lis := &v3listenerpb.Listener{ - Name: v3LDSTarget, - } - mLis, _ := proto.Marshal(lis) - return mLis - }(), - }, - }, - wantUpdate: map[string]ListenerUpdate{v3LDSTarget: {}}, - wantMD: UpdateMetadata{ - Status: ServiceStatusNACKed, - Version: testVersion, - ErrState: &UpdateErrorMetadata{ - Version: testVersion, - Err: errPlaceHolder, - }, - }, - wantErr: "no address field in LDS response", - }, - { - name: "no socket address field", - resources: []*anypb.Any{ - { - TypeUrl: version.V3ListenerURL, - Value: func() []byte { - lis := &v3listenerpb.Listener{ - Name: v3LDSTarget, - Address: &v3corepb.Address{}, - } - mLis, _ := proto.Marshal(lis) - return mLis - }(), - }, - }, - wantUpdate: 
map[string]ListenerUpdate{v3LDSTarget: {}}, - wantMD: UpdateMetadata{ - Status: ServiceStatusNACKed, - Version: testVersion, - ErrState: &UpdateErrorMetadata{ - Version: testVersion, - Err: errPlaceHolder, - }, - }, - wantErr: "no socket_address field in LDS response", - }, - { - name: "listener name does not match expected format", - resources: []*anypb.Any{ - { - TypeUrl: version.V3ListenerURL, - Value: func() []byte { - lis := &v3listenerpb.Listener{ - Name: "foo", - Address: &v3corepb.Address{ - Address: &v3corepb.Address_SocketAddress{ - SocketAddress: &v3corepb.SocketAddress{ - Address: "0.0.0.0", - PortSpecifier: &v3corepb.SocketAddress_PortValue{ - PortValue: 9999, - }, - }, - }, - }, - } - mLis, _ := proto.Marshal(lis) - return mLis - }(), - }, - }, - wantUpdate: map[string]ListenerUpdate{"foo": {}}, - wantMD: UpdateMetadata{ - Status: ServiceStatusNACKed, - Version: testVersion, - ErrState: &UpdateErrorMetadata{ - Version: testVersion, - Err: errPlaceHolder, - }, - }, - wantErr: "no host:port in name field of LDS response", - }, - { - name: "host mismatch", - resources: []*anypb.Any{ - { - TypeUrl: version.V3ListenerURL, - Value: func() []byte { - lis := &v3listenerpb.Listener{ - Name: v3LDSTarget, - Address: &v3corepb.Address{ - Address: &v3corepb.Address_SocketAddress{ - SocketAddress: &v3corepb.SocketAddress{ - Address: "1.2.3.4", - PortSpecifier: &v3corepb.SocketAddress_PortValue{ - PortValue: 9999, - }, - }, - }, - }, - } - mLis, _ := proto.Marshal(lis) - return mLis - }(), - }, - }, - wantUpdate: map[string]ListenerUpdate{v3LDSTarget: {}}, - wantMD: UpdateMetadata{ - Status: ServiceStatusNACKed, - Version: testVersion, - ErrState: &UpdateErrorMetadata{ - Version: testVersion, - Err: errPlaceHolder, - }, - }, - wantErr: "socket_address host does not match the one in name", - }, - { - name: "port mismatch", - resources: []*anypb.Any{ - { - TypeUrl: version.V3ListenerURL, - Value: func() []byte { - lis := &v3listenerpb.Listener{ - Name: v3LDSTarget, - 
Address: &v3corepb.Address{ - Address: &v3corepb.Address_SocketAddress{ - SocketAddress: &v3corepb.SocketAddress{ - Address: "0.0.0.0", - PortSpecifier: &v3corepb.SocketAddress_PortValue{ - PortValue: 1234, - }, - }, - }, - }, - } - mLis, _ := proto.Marshal(lis) - return mLis - }(), - }, - }, - wantUpdate: map[string]ListenerUpdate{v3LDSTarget: {}}, - wantMD: UpdateMetadata{ - Status: ServiceStatusNACKed, - Version: testVersion, - ErrState: &UpdateErrorMetadata{ - Version: testVersion, - Err: errPlaceHolder, - }, - }, - wantErr: "socket_address port does not match the one in name", - }, - { - name: "unexpected number of filter chains", - resources: []*anypb.Any{ - { - TypeUrl: version.V3ListenerURL, - Value: func() []byte { - lis := &v3listenerpb.Listener{ - Name: v3LDSTarget, - Address: &v3corepb.Address{ - Address: &v3corepb.Address_SocketAddress{ - SocketAddress: &v3corepb.SocketAddress{ - Address: "0.0.0.0", - PortSpecifier: &v3corepb.SocketAddress_PortValue{ - PortValue: 9999, - }, - }, - }, - }, - FilterChains: []*v3listenerpb.FilterChain{ - {Name: "filter-chain-1"}, - {Name: "filter-chain-2"}, - }, - } - mLis, _ := proto.Marshal(lis) - return mLis - }(), - }, - }, - wantUpdate: map[string]ListenerUpdate{v3LDSTarget: {}}, - wantMD: UpdateMetadata{ - Status: ServiceStatusNACKed, - Version: testVersion, - ErrState: &UpdateErrorMetadata{ - Version: testVersion, - Err: errPlaceHolder, - }, - }, - wantErr: "filter chains count in LDS response does not match expected", - }, - { - name: "unexpected transport socket name", - resources: []*anypb.Any{ - { - TypeUrl: version.V3ListenerURL, - Value: func() []byte { - lis := &v3listenerpb.Listener{ - Name: v3LDSTarget, - Address: &v3corepb.Address{ - Address: &v3corepb.Address_SocketAddress{ - SocketAddress: &v3corepb.SocketAddress{ - Address: "0.0.0.0", - PortSpecifier: &v3corepb.SocketAddress_PortValue{ - PortValue: 9999, - }, - }, - }, - }, - FilterChains: []*v3listenerpb.FilterChain{ - { - Name: "filter-chain-1", - 
TransportSocket: &v3corepb.TransportSocket{ - Name: "unsupported-transport-socket-name", - }, - }, - }, - } - mLis, _ := proto.Marshal(lis) - return mLis - }(), - }, - }, - wantUpdate: map[string]ListenerUpdate{v3LDSTarget: {}}, - wantMD: UpdateMetadata{ - Status: ServiceStatusNACKed, - Version: testVersion, - ErrState: &UpdateErrorMetadata{ - Version: testVersion, - Err: errPlaceHolder, - }, - }, - wantErr: "transport_socket field has unexpected name", - }, - { - name: "unexpected transport socket typedConfig URL", - resources: []*anypb.Any{ - { - TypeUrl: version.V3ListenerURL, - Value: func() []byte { - lis := &v3listenerpb.Listener{ - Name: v3LDSTarget, - Address: &v3corepb.Address{ - Address: &v3corepb.Address_SocketAddress{ - SocketAddress: &v3corepb.SocketAddress{ - Address: "0.0.0.0", - PortSpecifier: &v3corepb.SocketAddress_PortValue{ - PortValue: 9999, - }, - }, - }, - }, - FilterChains: []*v3listenerpb.FilterChain{ - { - Name: "filter-chain-1", - TransportSocket: &v3corepb.TransportSocket{ - Name: "envoy.transport_sockets.tls", - ConfigType: &v3corepb.TransportSocket_TypedConfig{ - TypedConfig: &anypb.Any{ - TypeUrl: version.V3UpstreamTLSContextURL, - }, - }, - }, - }, - }, - } - mLis, _ := proto.Marshal(lis) - return mLis - }(), - }, - }, - wantUpdate: map[string]ListenerUpdate{v3LDSTarget: {}}, - wantMD: UpdateMetadata{ - Status: ServiceStatusNACKed, - Version: testVersion, - ErrState: &UpdateErrorMetadata{ - Version: testVersion, - Err: errPlaceHolder, - }, - }, - wantErr: "transport_socket field has unexpected typeURL", - }, - { - name: "badly marshaled transport socket", - resources: []*anypb.Any{ - { - TypeUrl: version.V3ListenerURL, - Value: func() []byte { - lis := &v3listenerpb.Listener{ - Name: v3LDSTarget, - Address: &v3corepb.Address{ - Address: &v3corepb.Address_SocketAddress{ - SocketAddress: &v3corepb.SocketAddress{ - Address: "0.0.0.0", - PortSpecifier: &v3corepb.SocketAddress_PortValue{ - PortValue: 9999, - }, - }, - }, - }, - 
FilterChains: []*v3listenerpb.FilterChain{ - { - Name: "filter-chain-1", - TransportSocket: &v3corepb.TransportSocket{ - Name: "envoy.transport_sockets.tls", - ConfigType: &v3corepb.TransportSocket_TypedConfig{ - TypedConfig: &anypb.Any{ - TypeUrl: version.V3DownstreamTLSContextURL, - Value: []byte{1, 2, 3, 4}, - }, - }, - }, - }, - }, - } - mLis, _ := proto.Marshal(lis) - return mLis - }(), - }, - }, - wantUpdate: map[string]ListenerUpdate{v3LDSTarget: {}}, - wantMD: UpdateMetadata{ - Status: ServiceStatusNACKed, - Version: testVersion, - ErrState: &UpdateErrorMetadata{ - Version: testVersion, - Err: errPlaceHolder, - }, - }, - wantErr: "failed to unmarshal DownstreamTlsContext in LDS response", - }, - { - name: "missing CommonTlsContext", - resources: []*anypb.Any{ - { - TypeUrl: version.V3ListenerURL, - Value: func() []byte { - lis := &v3listenerpb.Listener{ - Name: v3LDSTarget, - Address: &v3corepb.Address{ - Address: &v3corepb.Address_SocketAddress{ - SocketAddress: &v3corepb.SocketAddress{ - Address: "0.0.0.0", - PortSpecifier: &v3corepb.SocketAddress_PortValue{ - PortValue: 9999, - }, - }, - }, - }, - FilterChains: []*v3listenerpb.FilterChain{ - { - Name: "filter-chain-1", - TransportSocket: &v3corepb.TransportSocket{ - Name: "envoy.transport_sockets.tls", - ConfigType: &v3corepb.TransportSocket_TypedConfig{ - TypedConfig: &anypb.Any{ - TypeUrl: version.V3DownstreamTLSContextURL, - Value: func() []byte { - tls := &v3tlspb.DownstreamTlsContext{} - mtls, _ := proto.Marshal(tls) - return mtls - }(), - }, - }, - }, - }, - }, - } - mLis, _ := proto.Marshal(lis) - return mLis - }(), - }, - }, - wantUpdate: map[string]ListenerUpdate{v3LDSTarget: {}}, - wantMD: UpdateMetadata{ - Status: ServiceStatusNACKed, - Version: testVersion, - ErrState: &UpdateErrorMetadata{ - Version: testVersion, - Err: errPlaceHolder, - }, - }, - wantErr: "DownstreamTlsContext in LDS response does not contain a CommonTlsContext", - }, - { - name: "unsupported validation context in transport 
socket", - resources: []*anypb.Any{ - { - TypeUrl: version.V3ListenerURL, - Value: func() []byte { - lis := &v3listenerpb.Listener{ - Name: v3LDSTarget, - Address: &v3corepb.Address{ - Address: &v3corepb.Address_SocketAddress{ - SocketAddress: &v3corepb.SocketAddress{ - Address: "0.0.0.0", - PortSpecifier: &v3corepb.SocketAddress_PortValue{ - PortValue: 9999, - }, - }, - }, - }, - FilterChains: []*v3listenerpb.FilterChain{ - { - Name: "filter-chain-1", - TransportSocket: &v3corepb.TransportSocket{ - Name: "envoy.transport_sockets.tls", - ConfigType: &v3corepb.TransportSocket_TypedConfig{ - TypedConfig: &anypb.Any{ - TypeUrl: version.V3DownstreamTLSContextURL, - Value: func() []byte { - tls := &v3tlspb.DownstreamTlsContext{ - CommonTlsContext: &v3tlspb.CommonTlsContext{ - ValidationContextType: &v3tlspb.CommonTlsContext_ValidationContextSdsSecretConfig{ - ValidationContextSdsSecretConfig: &v3tlspb.SdsSecretConfig{ - Name: "foo-sds-secret", - }, - }, - }, - } - mtls, _ := proto.Marshal(tls) - return mtls - }(), - }, - }, - }, - }, - }, - } - mLis, _ := proto.Marshal(lis) - return mLis - }(), - }, - }, - wantUpdate: map[string]ListenerUpdate{v3LDSTarget: {}}, - wantMD: UpdateMetadata{ - Status: ServiceStatusNACKed, - Version: testVersion, - ErrState: &UpdateErrorMetadata{ - Version: testVersion, - Err: errPlaceHolder, - }, - }, - wantErr: "validation context contains unexpected type", - }, - { - name: "empty transport socket", - resources: []*anypb.Any{listenerEmptyTransportSocket}, - wantUpdate: map[string]ListenerUpdate{ - v3LDSTarget: {Raw: listenerEmptyTransportSocket}, - }, - wantMD: UpdateMetadata{ - Status: ServiceStatusACKed, - Version: testVersion, - }, - }, - { - name: "no identity and root certificate providers", - resources: []*anypb.Any{ - { - TypeUrl: version.V3ListenerURL, - Value: func() []byte { - lis := &v3listenerpb.Listener{ - Name: v3LDSTarget, - Address: &v3corepb.Address{ - Address: &v3corepb.Address_SocketAddress{ - SocketAddress: 
&v3corepb.SocketAddress{ - Address: "0.0.0.0", - PortSpecifier: &v3corepb.SocketAddress_PortValue{ - PortValue: 9999, - }, - }, - }, - }, - FilterChains: []*v3listenerpb.FilterChain{ - { - Name: "filter-chain-1", - TransportSocket: &v3corepb.TransportSocket{ - Name: "envoy.transport_sockets.tls", - ConfigType: &v3corepb.TransportSocket_TypedConfig{ - TypedConfig: &anypb.Any{ - TypeUrl: version.V3DownstreamTLSContextURL, - Value: func() []byte { - tls := &v3tlspb.DownstreamTlsContext{ - RequireClientCertificate: &wrapperspb.BoolValue{Value: true}, - CommonTlsContext: &v3tlspb.CommonTlsContext{ - TlsCertificateCertificateProviderInstance: &v3tlspb.CommonTlsContext_CertificateProviderInstance{ - InstanceName: "identityPluginInstance", - CertificateName: "identityCertName", - }, - }, - } - mtls, _ := proto.Marshal(tls) - return mtls - }(), - }, - }, - }, - }, - }, - } - mLis, _ := proto.Marshal(lis) - return mLis - }(), - }, - }, - wantUpdate: map[string]ListenerUpdate{v3LDSTarget: {}}, - wantMD: UpdateMetadata{ - Status: ServiceStatusNACKed, - Version: testVersion, - ErrState: &UpdateErrorMetadata{ - Version: testVersion, - Err: errPlaceHolder, - }, - }, - wantErr: "security configuration on the server-side does not contain root certificate provider instance name, but require_client_cert field is set", - }, - { - name: "no identity certificate provider with require_client_cert", - resources: []*anypb.Any{ - { - TypeUrl: version.V3ListenerURL, - Value: func() []byte { - lis := &v3listenerpb.Listener{ - Name: v3LDSTarget, - Address: &v3corepb.Address{ - Address: &v3corepb.Address_SocketAddress{ - SocketAddress: &v3corepb.SocketAddress{ - Address: "0.0.0.0", - PortSpecifier: &v3corepb.SocketAddress_PortValue{ - PortValue: 9999, - }, - }, - }, - }, - FilterChains: []*v3listenerpb.FilterChain{ - { - Name: "filter-chain-1", - TransportSocket: &v3corepb.TransportSocket{ - Name: "envoy.transport_sockets.tls", - ConfigType: &v3corepb.TransportSocket_TypedConfig{ - TypedConfig: 
&anypb.Any{ - TypeUrl: version.V3DownstreamTLSContextURL, - Value: func() []byte { - tls := &v3tlspb.DownstreamTlsContext{ - CommonTlsContext: &v3tlspb.CommonTlsContext{}, - } - mtls, _ := proto.Marshal(tls) - return mtls - }(), - }, - }, - }, - }, - }, - } - mLis, _ := proto.Marshal(lis) - return mLis - }(), - }, - }, - wantUpdate: map[string]ListenerUpdate{v3LDSTarget: {}}, - wantMD: UpdateMetadata{ - Status: ServiceStatusNACKed, - Version: testVersion, - ErrState: &UpdateErrorMetadata{ - Version: testVersion, - Err: errPlaceHolder, - }, - }, - wantErr: "security configuration on the server-side does not contain identity certificate provider instance name", - }, - { - name: "happy case with no validation context", - resources: []*anypb.Any{listenerNoValidationContext}, - wantUpdate: map[string]ListenerUpdate{ - v3LDSTarget: { - SecurityCfg: &SecurityConfig{ - IdentityInstanceName: "identityPluginInstance", - IdentityCertName: "identityCertName", - }, - Raw: listenerNoValidationContext, - }, - }, - wantMD: UpdateMetadata{ - Status: ServiceStatusACKed, - Version: testVersion, - }, - }, - { - name: "happy case with validation context provider instance", - resources: []*anypb.Any{listenerWithValidationContext}, - wantUpdate: map[string]ListenerUpdate{ - v3LDSTarget: { - SecurityCfg: &SecurityConfig{ - RootInstanceName: "rootPluginInstance", - RootCertName: "rootCertName", - IdentityInstanceName: "identityPluginInstance", - IdentityCertName: "identityCertName", - RequireClientCert: true, - }, - Raw: listenerWithValidationContext, - }, - }, - wantMD: UpdateMetadata{ - Status: ServiceStatusACKed, - Version: testVersion, - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - gotUpdate, md, err := UnmarshalListener(testVersion, test.resources, nil) - if (err != nil) != (test.wantErr != "") { - t.Fatalf("UnmarshalListener(), got err: %v, wantErr: %v", err, test.wantErr) - } - if err != nil && !strings.Contains(err.Error(), test.wantErr) { 
- t.Fatalf("UnmarshalListener() = %v wantErr: %q", err, test.wantErr) - } - if diff := cmp.Diff(gotUpdate, test.wantUpdate, cmpOpts); diff != "" { - t.Errorf("got unexpected update, diff (-got +want): %v", diff) - } - if diff := cmp.Diff(md, test.wantMD, cmpOptsIgnoreDetails); diff != "" { - t.Errorf("got unexpected metadata, diff (-got +want): %v", diff) - } - }) - } -} - -type filterConfig struct { - httpfilter.FilterConfig - Cfg proto.Message - Override proto.Message -} - -// httpFilter allows testing the http filter registry and parsing functionality. -type httpFilter struct { - httpfilter.ClientInterceptorBuilder - httpfilter.ServerInterceptorBuilder -} - -func (httpFilter) TypeURLs() []string { return []string{"custom.filter"} } - -func (httpFilter) ParseFilterConfig(cfg proto.Message) (httpfilter.FilterConfig, error) { - return filterConfig{Cfg: cfg}, nil -} - -func (httpFilter) ParseFilterConfigOverride(override proto.Message) (httpfilter.FilterConfig, error) { - return filterConfig{Override: override}, nil -} - -// errHTTPFilter returns errors no matter what is passed to ParseFilterConfig. 
-type errHTTPFilter struct { - httpfilter.ClientInterceptorBuilder -} - -func (errHTTPFilter) TypeURLs() []string { return []string{"err.custom.filter"} } - -func (errHTTPFilter) ParseFilterConfig(cfg proto.Message) (httpfilter.FilterConfig, error) { - return nil, fmt.Errorf("error from ParseFilterConfig") -} - -func (errHTTPFilter) ParseFilterConfigOverride(override proto.Message) (httpfilter.FilterConfig, error) { - return nil, fmt.Errorf("error from ParseFilterConfigOverride") -} - -func init() { - httpfilter.Register(httpFilter{}) - httpfilter.Register(errHTTPFilter{}) - httpfilter.Register(serverOnlyHTTPFilter{}) - httpfilter.Register(clientOnlyHTTPFilter{}) -} - -// serverOnlyHTTPFilter does not implement ClientInterceptorBuilder -type serverOnlyHTTPFilter struct { - httpfilter.ServerInterceptorBuilder -} - -func (serverOnlyHTTPFilter) TypeURLs() []string { return []string{"serverOnly.custom.filter"} } - -func (serverOnlyHTTPFilter) ParseFilterConfig(cfg proto.Message) (httpfilter.FilterConfig, error) { - return filterConfig{Cfg: cfg}, nil -} - -func (serverOnlyHTTPFilter) ParseFilterConfigOverride(override proto.Message) (httpfilter.FilterConfig, error) { - return filterConfig{Override: override}, nil -} - -// clientOnlyHTTPFilter does not implement ServerInterceptorBuilder -type clientOnlyHTTPFilter struct { - httpfilter.ClientInterceptorBuilder -} - -func (clientOnlyHTTPFilter) TypeURLs() []string { return []string{"clientOnly.custom.filter"} } - -func (clientOnlyHTTPFilter) ParseFilterConfig(cfg proto.Message) (httpfilter.FilterConfig, error) { - return filterConfig{Cfg: cfg}, nil -} - -func (clientOnlyHTTPFilter) ParseFilterConfigOverride(override proto.Message) (httpfilter.FilterConfig, error) { - return filterConfig{Override: override}, nil -} - -var customFilterConfig = &anypb.Any{ - TypeUrl: "custom.filter", - Value: []byte{1, 2, 3}, -} - -var errFilterConfig = &anypb.Any{ - TypeUrl: "err.custom.filter", - Value: []byte{1, 2, 3}, -} - -var 
serverOnlyCustomFilterConfig = &anypb.Any{ - TypeUrl: "serverOnly.custom.filter", - Value: []byte{1, 2, 3}, -} - -var clientOnlyCustomFilterConfig = &anypb.Any{ - TypeUrl: "clientOnly.custom.filter", - Value: []byte{1, 2, 3}, -} - -var customFilterTypedStructConfig = &v1typepb.TypedStruct{ - TypeUrl: "custom.filter", - Value: &spb.Struct{ - Fields: map[string]*spb.Value{ - "foo": {Kind: &spb.Value_StringValue{StringValue: "bar"}}, - }, - }, -} -var wrappedCustomFilterTypedStructConfig *anypb.Any - -func init() { - var err error - wrappedCustomFilterTypedStructConfig, err = ptypes.MarshalAny(customFilterTypedStructConfig) - if err != nil { - panic(err.Error()) - } -} - -var unknownFilterConfig = &anypb.Any{ - TypeUrl: "unknown.custom.filter", - Value: []byte{1, 2, 3}, -} - -func wrappedOptionalFilter(name string) *anypb.Any { - filter := &v3routepb.FilterConfig{ - IsOptional: true, - Config: &anypb.Any{ - TypeUrl: name, - Value: []byte{1, 2, 3}, - }, - } - w, err := ptypes.MarshalAny(filter) - if err != nil { - panic("error marshalling any: " + err.Error()) - } - return w -} diff --git a/xds/internal/client/requests_counter.go b/xds/internal/client/requests_counter.go deleted file mode 100644 index 7ef18345ed6..00000000000 --- a/xds/internal/client/requests_counter.go +++ /dev/null @@ -1,82 +0,0 @@ -/* - * - * Copyright 2020 gRPC authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- * - */ - -package client - -import ( - "fmt" - "sync" - "sync/atomic" -) - -type servicesRequestsCounter struct { - mu sync.Mutex - services map[string]*ServiceRequestsCounter -} - -var src = &servicesRequestsCounter{ - services: make(map[string]*ServiceRequestsCounter), -} - -// ServiceRequestsCounter is used to track the total inflight requests for a -// service with the provided name. -type ServiceRequestsCounter struct { - ServiceName string - numRequests uint32 -} - -// GetServiceRequestsCounter returns the ServiceRequestsCounter with the -// provided serviceName. If one does not exist, it creates it. -func GetServiceRequestsCounter(serviceName string) *ServiceRequestsCounter { - src.mu.Lock() - defer src.mu.Unlock() - c, ok := src.services[serviceName] - if !ok { - c = &ServiceRequestsCounter{ServiceName: serviceName} - src.services[serviceName] = c - } - return c -} - -// StartRequest starts a request for a service, incrementing its number of -// requests by 1. Returns an error if the max number of requests is exceeded. -func (c *ServiceRequestsCounter) StartRequest(max uint32) error { - if atomic.LoadUint32(&c.numRequests) >= max { - return fmt.Errorf("max requests %v exceeded on service %v", max, c.ServiceName) - } - atomic.AddUint32(&c.numRequests, 1) - return nil -} - -// EndRequest ends a request for a service, decrementing its number of requests -// by 1. -func (c *ServiceRequestsCounter) EndRequest() { - atomic.AddUint32(&c.numRequests, ^uint32(0)) -} - -// ClearCounterForTesting clears the counter for the service. Should be only -// used in tests. 
-func ClearCounterForTesting(serviceName string) { - src.mu.Lock() - defer src.mu.Unlock() - c, ok := src.services[serviceName] - if !ok { - return - } - c.numRequests = 0 -} diff --git a/xds/internal/client/singleton.go b/xds/internal/client/singleton.go deleted file mode 100644 index 5d92b4146bb..00000000000 --- a/xds/internal/client/singleton.go +++ /dev/null @@ -1,101 +0,0 @@ -/* - * - * Copyright 2020 gRPC authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - */ - -package client - -import ( - "fmt" - "sync" - "time" - - "google.golang.org/grpc/xds/internal/client/bootstrap" -) - -const defaultWatchExpiryTimeout = 15 * time.Second - -// This is the Client returned by New(). It contains one client implementation, -// and maintains the refcount. -var singletonClient = &Client{} - -// To override in tests. -var bootstrapNewConfig = bootstrap.NewConfig - -// Client is a full fledged gRPC client which queries a set of discovery APIs -// (collectively termed as xDS) on a remote management server, to discover -// various dynamic resources. -// -// The xds client is a singleton. It will be shared by the xds resolver and -// balancer implementations, across multiple ClientConns and Servers. -type Client struct { - *clientImpl - - // This mu protects all the fields, including the embedded clientImpl above. 
- mu sync.Mutex - refCount int -} - -// New returns a new xdsClient configured by the bootstrap file specified in env -// variable GRPC_XDS_BOOTSTRAP. -func New() (*Client, error) { - singletonClient.mu.Lock() - defer singletonClient.mu.Unlock() - // If the client implementation was created, increment ref count and return - // the client. - if singletonClient.clientImpl != nil { - singletonClient.refCount++ - return singletonClient, nil - } - - // Create the new client implementation. - config, err := bootstrapNewConfig() - if err != nil { - return nil, fmt.Errorf("xds: failed to read bootstrap file: %v", err) - } - c, err := newWithConfig(config, defaultWatchExpiryTimeout) - if err != nil { - return nil, err - } - - singletonClient.clientImpl = c - singletonClient.refCount++ - return singletonClient, nil -} - -// Close closes the client. It does ref count of the xds client implementation, -// and closes the gRPC connection to the management server when ref count -// reaches 0. -func (c *Client) Close() { - c.mu.Lock() - defer c.mu.Unlock() - c.refCount-- - if c.refCount == 0 { - c.clientImpl.Close() - // Set clientImpl back to nil. So if New() is called after this, a new - // implementation will be created. - c.clientImpl = nil - } -} - -// NewWithConfigForTesting is exported for testing only. -func NewWithConfigForTesting(config *bootstrap.Config, watchExpiryTimeout time.Duration) (*Client, error) { - cl, err := newWithConfig(config, watchExpiryTimeout) - if err != nil { - return nil, err - } - return &Client{clientImpl: cl, refCount: 1}, nil -} diff --git a/xds/internal/client/tests/README.md b/xds/internal/client/tests/README.md deleted file mode 100644 index 6dc940c103f..00000000000 --- a/xds/internal/client/tests/README.md +++ /dev/null @@ -1 +0,0 @@ -This package contains tests which cannot live in the `client` package because they need to import one of the API client packages (which itself has a dependency on the `client` package). 
diff --git a/xds/internal/client/watchers_listener_test.go b/xds/internal/client/watchers_listener_test.go deleted file mode 100644 index bf3a122da07..00000000000 --- a/xds/internal/client/watchers_listener_test.go +++ /dev/null @@ -1,358 +0,0 @@ -/* - * - * Copyright 2020 gRPC authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - */ - -package client - -import ( - "context" - "testing" - - "google.golang.org/grpc/internal/testutils" -) - -type ldsUpdateErr struct { - u ListenerUpdate - err error -} - -// TestLDSWatch covers the cases: -// - an update is received after a watch() -// - an update for another resource name -// - an update is received after cancel() -func (s) TestLDSWatch(t *testing.T) { - apiClientCh, cleanup := overrideNewAPIClient() - defer cleanup() - - client, err := newWithConfig(clientOpts(testXDSServer, false)) - if err != nil { - t.Fatalf("failed to create client: %v", err) - } - defer client.Close() - - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() - c, err := apiClientCh.Receive(ctx) - if err != nil { - t.Fatalf("timeout when waiting for API client to be created: %v", err) - } - apiClient := c.(*testAPIClient) - - ldsUpdateCh := testutils.NewChannel() - cancelWatch := client.WatchListener(testLDSName, func(update ListenerUpdate, err error) { - ldsUpdateCh.Send(ldsUpdateErr{u: update, err: err}) - }) - if _, err := apiClient.addWatches[ListenerResource].Receive(ctx); err != nil 
{ - t.Fatalf("want new watch to start, got error %v", err) - } - - wantUpdate := ListenerUpdate{RouteConfigName: testRDSName} - client.NewListeners(map[string]ListenerUpdate{testLDSName: wantUpdate}, UpdateMetadata{}) - if err := verifyListenerUpdate(ctx, ldsUpdateCh, wantUpdate); err != nil { - t.Fatal(err) - } - - // Another update, with an extra resource for a different resource name. - client.NewListeners(map[string]ListenerUpdate{ - testLDSName: wantUpdate, - "randomName": {}, - }, UpdateMetadata{}) - if err := verifyListenerUpdate(ctx, ldsUpdateCh, wantUpdate); err != nil { - t.Fatal(err) - } - - // Cancel watch, and send update again. - cancelWatch() - client.NewListeners(map[string]ListenerUpdate{testLDSName: wantUpdate}, UpdateMetadata{}) - sCtx, sCancel := context.WithTimeout(ctx, defaultTestShortTimeout) - defer sCancel() - if u, err := ldsUpdateCh.Receive(sCtx); err != context.DeadlineExceeded { - t.Errorf("unexpected ListenerUpdate: %v, %v, want channel recv timeout", u, err) - } -} - -// TestLDSTwoWatchSameResourceName covers the case where an update is received -// after two watch() for the same resource name. 
-func (s) TestLDSTwoWatchSameResourceName(t *testing.T) { - apiClientCh, cleanup := overrideNewAPIClient() - defer cleanup() - - client, err := newWithConfig(clientOpts(testXDSServer, false)) - if err != nil { - t.Fatalf("failed to create client: %v", err) - } - defer client.Close() - - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() - c, err := apiClientCh.Receive(ctx) - if err != nil { - t.Fatalf("timeout when waiting for API client to be created: %v", err) - } - apiClient := c.(*testAPIClient) - - const count = 2 - var ( - ldsUpdateChs []*testutils.Channel - cancelLastWatch func() - ) - - for i := 0; i < count; i++ { - ldsUpdateCh := testutils.NewChannel() - ldsUpdateChs = append(ldsUpdateChs, ldsUpdateCh) - cancelLastWatch = client.WatchListener(testLDSName, func(update ListenerUpdate, err error) { - ldsUpdateCh.Send(ldsUpdateErr{u: update, err: err}) - }) - - if i == 0 { - // A new watch is registered on the underlying API client only for - // the first iteration because we are using the same resource name. - if _, err := apiClient.addWatches[ListenerResource].Receive(ctx); err != nil { - t.Fatalf("want new watch to start, got error %v", err) - } - } - } - - wantUpdate := ListenerUpdate{RouteConfigName: testRDSName} - client.NewListeners(map[string]ListenerUpdate{testLDSName: wantUpdate}, UpdateMetadata{}) - for i := 0; i < count; i++ { - if err := verifyListenerUpdate(ctx, ldsUpdateChs[i], wantUpdate); err != nil { - t.Fatal(err) - } - } - - // Cancel the last watch, and send update again. 
- cancelLastWatch() - client.NewListeners(map[string]ListenerUpdate{testLDSName: wantUpdate}, UpdateMetadata{}) - for i := 0; i < count-1; i++ { - if err := verifyListenerUpdate(ctx, ldsUpdateChs[i], wantUpdate); err != nil { - t.Fatal(err) - } - } - - sCtx, sCancel := context.WithTimeout(ctx, defaultTestShortTimeout) - defer sCancel() - if u, err := ldsUpdateChs[count-1].Receive(sCtx); err != context.DeadlineExceeded { - t.Errorf("unexpected ListenerUpdate: %v, %v, want channel recv timeout", u, err) - } -} - -// TestLDSThreeWatchDifferentResourceName covers the case where an update is -// received after three watch() for different resource names. -func (s) TestLDSThreeWatchDifferentResourceName(t *testing.T) { - apiClientCh, cleanup := overrideNewAPIClient() - defer cleanup() - - client, err := newWithConfig(clientOpts(testXDSServer, false)) - if err != nil { - t.Fatalf("failed to create client: %v", err) - } - defer client.Close() - - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() - c, err := apiClientCh.Receive(ctx) - if err != nil { - t.Fatalf("timeout when waiting for API client to be created: %v", err) - } - apiClient := c.(*testAPIClient) - - var ldsUpdateChs []*testutils.Channel - const count = 2 - - // Two watches for the same name. - for i := 0; i < count; i++ { - ldsUpdateCh := testutils.NewChannel() - ldsUpdateChs = append(ldsUpdateChs, ldsUpdateCh) - client.WatchListener(testLDSName+"1", func(update ListenerUpdate, err error) { - ldsUpdateCh.Send(ldsUpdateErr{u: update, err: err}) - }) - - if i == 0 { - // A new watch is registered on the underlying API client only for - // the first iteration because we are using the same resource name. - if _, err := apiClient.addWatches[ListenerResource].Receive(ctx); err != nil { - t.Fatalf("want new watch to start, got error %v", err) - } - } - } - - // Third watch for a different name. 
- ldsUpdateCh2 := testutils.NewChannel() - client.WatchListener(testLDSName+"2", func(update ListenerUpdate, err error) { - ldsUpdateCh2.Send(ldsUpdateErr{u: update, err: err}) - }) - if _, err := apiClient.addWatches[ListenerResource].Receive(ctx); err != nil { - t.Fatalf("want new watch to start, got error %v", err) - } - - wantUpdate1 := ListenerUpdate{RouteConfigName: testRDSName + "1"} - wantUpdate2 := ListenerUpdate{RouteConfigName: testRDSName + "2"} - client.NewListeners(map[string]ListenerUpdate{ - testLDSName + "1": wantUpdate1, - testLDSName + "2": wantUpdate2, - }, UpdateMetadata{}) - - for i := 0; i < count; i++ { - if err := verifyListenerUpdate(ctx, ldsUpdateChs[i], wantUpdate1); err != nil { - t.Fatal(err) - } - } - if err := verifyListenerUpdate(ctx, ldsUpdateCh2, wantUpdate2); err != nil { - t.Fatal(err) - } -} - -// TestLDSWatchAfterCache covers the case where watch is called after the update -// is in cache. -func (s) TestLDSWatchAfterCache(t *testing.T) { - apiClientCh, cleanup := overrideNewAPIClient() - defer cleanup() - - client, err := newWithConfig(clientOpts(testXDSServer, false)) - if err != nil { - t.Fatalf("failed to create client: %v", err) - } - defer client.Close() - - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() - c, err := apiClientCh.Receive(ctx) - if err != nil { - t.Fatalf("timeout when waiting for API client to be created: %v", err) - } - apiClient := c.(*testAPIClient) - - ldsUpdateCh := testutils.NewChannel() - client.WatchListener(testLDSName, func(update ListenerUpdate, err error) { - ldsUpdateCh.Send(ldsUpdateErr{u: update, err: err}) - }) - if _, err := apiClient.addWatches[ListenerResource].Receive(ctx); err != nil { - t.Fatalf("want new watch to start, got error %v", err) - } - - wantUpdate := ListenerUpdate{RouteConfigName: testRDSName} - client.NewListeners(map[string]ListenerUpdate{testLDSName: wantUpdate}, UpdateMetadata{}) - if err := verifyListenerUpdate(ctx, 
ldsUpdateCh, wantUpdate); err != nil { - t.Fatal(err) - } - - // Another watch for the resource in cache. - ldsUpdateCh2 := testutils.NewChannel() - client.WatchListener(testLDSName, func(update ListenerUpdate, err error) { - ldsUpdateCh2.Send(ldsUpdateErr{u: update, err: err}) - }) - sCtx, sCancel := context.WithTimeout(ctx, defaultTestShortTimeout) - defer sCancel() - if n, err := apiClient.addWatches[ListenerResource].Receive(sCtx); err != context.DeadlineExceeded { - t.Fatalf("want no new watch to start (recv timeout), got resource name: %v error %v", n, err) - } - - // New watch should receive the update. - if err := verifyListenerUpdate(ctx, ldsUpdateCh2, wantUpdate); err != nil { - t.Fatal(err) - } - - // Old watch should see nothing. - sCtx, sCancel = context.WithTimeout(ctx, defaultTestShortTimeout) - defer sCancel() - if u, err := ldsUpdateCh.Receive(sCtx); err != context.DeadlineExceeded { - t.Errorf("unexpected ListenerUpdate: %v, %v, want channel recv timeout", u, err) - } -} - -// TestLDSResourceRemoved covers the cases: -// - an update is received after a watch() -// - another update is received, with one resource removed -// - this should trigger callback with resource removed error -// - one more update without the removed resource -// - the callback (above) shouldn't receive any update -func (s) TestLDSResourceRemoved(t *testing.T) { - apiClientCh, cleanup := overrideNewAPIClient() - defer cleanup() - - client, err := newWithConfig(clientOpts(testXDSServer, false)) - if err != nil { - t.Fatalf("failed to create client: %v", err) - } - defer client.Close() - - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() - c, err := apiClientCh.Receive(ctx) - if err != nil { - t.Fatalf("timeout when waiting for API client to be created: %v", err) - } - apiClient := c.(*testAPIClient) - - ldsUpdateCh1 := testutils.NewChannel() - client.WatchListener(testLDSName+"1", func(update ListenerUpdate, err error) { - 
ldsUpdateCh1.Send(ldsUpdateErr{u: update, err: err}) - }) - if _, err := apiClient.addWatches[ListenerResource].Receive(ctx); err != nil { - t.Fatalf("want new watch to start, got error %v", err) - } - // Another watch for a different name. - ldsUpdateCh2 := testutils.NewChannel() - client.WatchListener(testLDSName+"2", func(update ListenerUpdate, err error) { - ldsUpdateCh2.Send(ldsUpdateErr{u: update, err: err}) - }) - if _, err := apiClient.addWatches[ListenerResource].Receive(ctx); err != nil { - t.Fatalf("want new watch to start, got error %v", err) - } - - wantUpdate1 := ListenerUpdate{RouteConfigName: testEDSName + "1"} - wantUpdate2 := ListenerUpdate{RouteConfigName: testEDSName + "2"} - client.NewListeners(map[string]ListenerUpdate{ - testLDSName + "1": wantUpdate1, - testLDSName + "2": wantUpdate2, - }, UpdateMetadata{}) - if err := verifyListenerUpdate(ctx, ldsUpdateCh1, wantUpdate1); err != nil { - t.Fatal(err) - } - if err := verifyListenerUpdate(ctx, ldsUpdateCh2, wantUpdate2); err != nil { - t.Fatal(err) - } - - // Send another update to remove resource 1. - client.NewListeners(map[string]ListenerUpdate{testLDSName + "2": wantUpdate2}, UpdateMetadata{}) - - // Watcher 1 should get an error. - if u, err := ldsUpdateCh1.Receive(ctx); err != nil || ErrType(u.(ldsUpdateErr).err) != ErrorTypeResourceNotFound { - t.Errorf("unexpected ListenerUpdate: %v, error receiving from channel: %v, want update with error resource not found", u, err) - } - - // Watcher 2 should get the same update again. - if err := verifyListenerUpdate(ctx, ldsUpdateCh2, wantUpdate2); err != nil { - t.Fatal(err) - } - - // Send one more update without resource 1. - client.NewListeners(map[string]ListenerUpdate{testLDSName + "2": wantUpdate2}, UpdateMetadata{}) - - // Watcher 1 should not see an update. 
- sCtx, sCancel := context.WithTimeout(ctx, defaultTestShortTimeout) - defer sCancel() - if u, err := ldsUpdateCh1.Receive(sCtx); err != context.DeadlineExceeded { - t.Errorf("unexpected ListenerUpdate: %v, want receiving from channel timeout", u) - } - - // Watcher 2 should get the same update again. - if err := verifyListenerUpdate(ctx, ldsUpdateCh2, wantUpdate2); err != nil { - t.Fatal(err) - } -} diff --git a/xds/internal/client/xds.go b/xds/internal/client/xds.go deleted file mode 100644 index 78dc139e8fe..00000000000 --- a/xds/internal/client/xds.go +++ /dev/null @@ -1,915 +0,0 @@ -/* - * - * Copyright 2020 gRPC authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- * - */ - -package client - -import ( - "errors" - "fmt" - "net" - "strconv" - "strings" - "time" - - v1typepb "github.com/cncf/udpa/go/udpa/type/v1" - v3clusterpb "github.com/envoyproxy/go-control-plane/envoy/config/cluster/v3" - v3corepb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" - v3endpointpb "github.com/envoyproxy/go-control-plane/envoy/config/endpoint/v3" - v3listenerpb "github.com/envoyproxy/go-control-plane/envoy/config/listener/v3" - v3routepb "github.com/envoyproxy/go-control-plane/envoy/config/route/v3" - v3httppb "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/network/http_connection_manager/v3" - v3tlspb "github.com/envoyproxy/go-control-plane/envoy/extensions/transport_sockets/tls/v3" - v3typepb "github.com/envoyproxy/go-control-plane/envoy/type/v3" - "github.com/golang/protobuf/proto" - "github.com/golang/protobuf/ptypes" - "google.golang.org/protobuf/types/known/anypb" - - "google.golang.org/grpc/internal/grpclog" - "google.golang.org/grpc/xds/internal" - "google.golang.org/grpc/xds/internal/env" - "google.golang.org/grpc/xds/internal/httpfilter" - "google.golang.org/grpc/xds/internal/version" -) - -// TransportSocket proto message has a `name` field which is expected to be set -// to this value by the management server. -const transportSocketName = "envoy.transport_sockets.tls" - -// UnmarshalListener processes resources received in an LDS response, validates -// them, and transforms them into a native struct which contains only fields we -// are interested in. 
-func UnmarshalListener(version string, resources []*anypb.Any, logger *grpclog.PrefixLogger) (map[string]ListenerUpdate, UpdateMetadata, error) { - update := make(map[string]ListenerUpdate) - md, err := processAllResources(version, resources, logger, update) - return update, md, err -} - -func unmarshalListenerResource(r *anypb.Any, logger *grpclog.PrefixLogger) (string, ListenerUpdate, error) { - if !IsListenerResource(r.GetTypeUrl()) { - return "", ListenerUpdate{}, fmt.Errorf("unexpected resource type: %q ", r.GetTypeUrl()) - } - // TODO: Pass version.TransportAPI instead of relying upon the type URL - v2 := r.GetTypeUrl() == version.V2ListenerURL - lis := &v3listenerpb.Listener{} - if err := proto.Unmarshal(r.GetValue(), lis); err != nil { - return "", ListenerUpdate{}, fmt.Errorf("failed to unmarshal resource: %v", err) - } - logger.Infof("Resource with name: %v, type: %T, contains: %v", lis.GetName(), lis, lis) - - lu, err := processListener(lis, v2) - if err != nil { - return lis.GetName(), ListenerUpdate{}, err - } - lu.Raw = r - return lis.GetName(), *lu, nil -} - -func processListener(lis *v3listenerpb.Listener, v2 bool) (*ListenerUpdate, error) { - if lis.GetApiListener() != nil { - return processClientSideListener(lis, v2) - } - return processServerSideListener(lis) -} - -// processClientSideListener checks if the provided Listener proto meets -// the expected criteria. If so, it returns a non-empty routeConfigName. 
-func processClientSideListener(lis *v3listenerpb.Listener, v2 bool) (*ListenerUpdate, error) { - update := &ListenerUpdate{} - - apiLisAny := lis.GetApiListener().GetApiListener() - if !IsHTTPConnManagerResource(apiLisAny.GetTypeUrl()) { - return nil, fmt.Errorf("unexpected resource type: %q", apiLisAny.GetTypeUrl()) - } - apiLis := &v3httppb.HttpConnectionManager{} - if err := proto.Unmarshal(apiLisAny.GetValue(), apiLis); err != nil { - return nil, fmt.Errorf("failed to unmarshal api_listner: %v", err) - } - - switch apiLis.RouteSpecifier.(type) { - case *v3httppb.HttpConnectionManager_Rds: - if apiLis.GetRds().GetConfigSource().GetAds() == nil { - return nil, fmt.Errorf("ConfigSource is not ADS: %+v", lis) - } - name := apiLis.GetRds().GetRouteConfigName() - if name == "" { - return nil, fmt.Errorf("empty route_config_name: %+v", lis) - } - update.RouteConfigName = name - case *v3httppb.HttpConnectionManager_RouteConfig: - // TODO: Add support for specifying the RouteConfiguration inline - // in the LDS response. - return nil, fmt.Errorf("LDS response contains RDS config inline. Not supported for now: %+v", apiLis) - case nil: - return nil, fmt.Errorf("no RouteSpecifier: %+v", apiLis) - default: - return nil, fmt.Errorf("unsupported type %T for RouteSpecifier", apiLis.RouteSpecifier) - } - - if v2 { - return update, nil - } - - // The following checks and fields only apply to xDS protocol versions v3+. - - update.MaxStreamDuration = apiLis.GetCommonHttpProtocolOptions().GetMaxStreamDuration().AsDuration() - - var err error - if update.HTTPFilters, err = processHTTPFilters(apiLis.GetHttpFilters(), false); err != nil { - return nil, err - } - - return update, nil -} - -func unwrapHTTPFilterConfig(config *anypb.Any) (proto.Message, string, error) { - // The real type name is inside the TypedStruct. 
- s := new(v1typepb.TypedStruct) - if !ptypes.Is(config, s) { - return config, config.GetTypeUrl(), nil - } - if err := ptypes.UnmarshalAny(config, s); err != nil { - return nil, "", fmt.Errorf("error unmarshalling TypedStruct filter config: %v", err) - } - return s, s.GetTypeUrl(), nil -} - -func validateHTTPFilterConfig(cfg *anypb.Any, lds, optional bool) (httpfilter.Filter, httpfilter.FilterConfig, error) { - config, typeURL, err := unwrapHTTPFilterConfig(cfg) - if err != nil { - return nil, nil, err - } - filterBuilder := httpfilter.Get(typeURL) - if filterBuilder == nil { - if optional { - return nil, nil, nil - } - return nil, nil, fmt.Errorf("no filter implementation found for %q", typeURL) - } - parseFunc := filterBuilder.ParseFilterConfig - if !lds { - parseFunc = filterBuilder.ParseFilterConfigOverride - } - filterConfig, err := parseFunc(config) - if err != nil { - return nil, nil, fmt.Errorf("error parsing config for filter %q: %v", typeURL, err) - } - return filterBuilder, filterConfig, nil -} - -func processHTTPFilterOverrides(cfgs map[string]*anypb.Any) (map[string]httpfilter.FilterConfig, error) { - if !env.FaultInjectionSupport || len(cfgs) == 0 { - return nil, nil - } - m := make(map[string]httpfilter.FilterConfig) - for name, cfg := range cfgs { - optional := false - s := new(v3routepb.FilterConfig) - if ptypes.Is(cfg, s) { - if err := ptypes.UnmarshalAny(cfg, s); err != nil { - return nil, fmt.Errorf("filter override %q: error unmarshalling FilterConfig: %v", name, err) - } - cfg = s.GetConfig() - optional = s.GetIsOptional() - } - - httpFilter, config, err := validateHTTPFilterConfig(cfg, false, optional) - if err != nil { - return nil, fmt.Errorf("filter override %q: %v", name, err) - } - if httpFilter == nil { - // Optional configs are ignored. 
- continue - } - m[name] = config - } - return m, nil -} - -func processHTTPFilters(filters []*v3httppb.HttpFilter, server bool) ([]HTTPFilter, error) { - if !env.FaultInjectionSupport { - return nil, nil - } - - ret := make([]HTTPFilter, 0, len(filters)) - seenNames := make(map[string]bool, len(filters)) - for _, filter := range filters { - name := filter.GetName() - if name == "" { - return nil, errors.New("filter missing name field") - } - if seenNames[name] { - return nil, fmt.Errorf("duplicate filter name %q", name) - } - seenNames[name] = true - - httpFilter, config, err := validateHTTPFilterConfig(filter.GetTypedConfig(), true, filter.GetIsOptional()) - if err != nil { - return nil, err - } - if httpFilter == nil { - // Optional configs are ignored. - continue - } - if server { - if _, ok := httpFilter.(httpfilter.ServerInterceptorBuilder); !ok { - if filter.GetIsOptional() { - continue - } - return nil, fmt.Errorf("HTTP filter %q not supported server-side", name) - } - } else if _, ok := httpFilter.(httpfilter.ClientInterceptorBuilder); !ok { - if filter.GetIsOptional() { - continue - } - return nil, fmt.Errorf("HTTP filter %q not supported client-side", name) - } - - // Save name/config - ret = append(ret, HTTPFilter{Name: name, Filter: httpFilter, Config: config}) - } - return ret, nil -} - -func processServerSideListener(lis *v3listenerpb.Listener) (*ListenerUpdate, error) { - // Make sure that an address encoded in the received listener resource, and - // that it matches the one specified in the name. Listener names on the - // server-side as in the following format: - // grpc/server?udpa.resource.listening_address=IP:Port. 
- addr := lis.GetAddress() - if addr == nil { - return nil, fmt.Errorf("no address field in LDS response: %+v", lis) - } - sockAddr := addr.GetSocketAddress() - if sockAddr == nil { - return nil, fmt.Errorf("no socket_address field in LDS response: %+v", lis) - } - host, port, err := getAddressFromName(lis.GetName()) - if err != nil { - return nil, fmt.Errorf("no host:port in name field of LDS response: %+v, error: %v", lis, err) - } - if h := sockAddr.GetAddress(); host != h { - return nil, fmt.Errorf("socket_address host does not match the one in name. Got %q, want %q", h, host) - } - if p := strconv.Itoa(int(sockAddr.GetPortValue())); port != p { - return nil, fmt.Errorf("socket_address port does not match the one in name. Got %q, want %q", p, port) - } - - // Make sure the listener resource contains a single filter chain. We do not - // support multiple filter chains and picking the best match from the list. - fcs := lis.GetFilterChains() - if n := len(fcs); n != 1 { - return nil, fmt.Errorf("filter chains count in LDS response does not match expected. Got %d, want 1", n) - } - fc := fcs[0] - - // If the transport_socket field is not specified, it means that the control - // plane has not sent us any security config. This is fine and the server - // will use the fallback credentials configured as part of the - // xdsCredentials. 
- ts := fc.GetTransportSocket() - if ts == nil { - return &ListenerUpdate{}, nil - } - if name := ts.GetName(); name != transportSocketName { - return nil, fmt.Errorf("transport_socket field has unexpected name: %s", name) - } - any := ts.GetTypedConfig() - if any == nil || any.TypeUrl != version.V3DownstreamTLSContextURL { - return nil, fmt.Errorf("transport_socket field has unexpected typeURL: %s", any.TypeUrl) - } - downstreamCtx := &v3tlspb.DownstreamTlsContext{} - if err := proto.Unmarshal(any.GetValue(), downstreamCtx); err != nil { - return nil, fmt.Errorf("failed to unmarshal DownstreamTlsContext in LDS response: %v", err) - } - if downstreamCtx.GetCommonTlsContext() == nil { - return nil, errors.New("DownstreamTlsContext in LDS response does not contain a CommonTlsContext") - } - sc, err := securityConfigFromCommonTLSContext(downstreamCtx.GetCommonTlsContext()) - if err != nil { - return nil, err - } - if sc.IdentityInstanceName == "" { - return nil, errors.New("security configuration on the server-side does not contain identity certificate provider instance name") - } - sc.RequireClientCert = downstreamCtx.GetRequireClientCertificate().GetValue() - if sc.RequireClientCert && sc.RootInstanceName == "" { - return nil, errors.New("security configuration on the server-side does not contain root certificate provider instance name, but require_client_cert field is set") - } - return &ListenerUpdate{SecurityCfg: sc}, nil -} - -func getAddressFromName(name string) (host string, port string, err error) { - parts := strings.SplitN(name, "udpa.resource.listening_address=", 2) - if len(parts) != 2 { - return "", "", fmt.Errorf("udpa.resource_listening_address not found in name: %v", name) - } - return net.SplitHostPort(parts[1]) -} - -// UnmarshalRouteConfig processes resources received in an RDS response, -// validates them, and transforms them into a native struct which contains only -// fields we are interested in. 
// UnmarshalRouteConfig processes resources received in an RDS response,
// validates them, and transforms them into a native struct which contains only
// fields we are interested in. The provided hostname determines the route
// configuration resources of interest.
func UnmarshalRouteConfig(version string, resources []*anypb.Any, logger *grpclog.PrefixLogger) (map[string]RouteConfigUpdate, UpdateMetadata, error) {
	update := make(map[string]RouteConfigUpdate)
	md, err := processAllResources(version, resources, logger, update)
	return update, md, err
}

// unmarshalRouteConfigResource unmarshals and validates a single RDS
// resource, returning its name, the parsed update and an error.
func unmarshalRouteConfigResource(r *anypb.Any, logger *grpclog.PrefixLogger) (string, RouteConfigUpdate, error) {
	if !IsRouteConfigResource(r.GetTypeUrl()) {
		return "", RouteConfigUpdate{}, fmt.Errorf("unexpected resource type: %q ", r.GetTypeUrl())
	}
	rc := &v3routepb.RouteConfiguration{}
	if err := proto.Unmarshal(r.GetValue(), rc); err != nil {
		return "", RouteConfigUpdate{}, fmt.Errorf("failed to unmarshal resource: %v", err)
	}
	logger.Infof("Resource with name: %v, type: %T, contains: %v.", rc.GetName(), rc, rc)

	// TODO: Pass version.TransportAPI instead of relying upon the type URL
	v2 := r.GetTypeUrl() == version.V2RouteConfigURL
	u, err := generateRDSUpdateFromRouteConfiguration(rc, logger, v2)
	if err != nil {
		return rc.GetName(), RouteConfigUpdate{}, err
	}
	u.Raw = r
	return rc.GetName(), u, nil
}

// generateRDSUpdateFromRouteConfiguration checks if the provided
// RouteConfiguration meets the expected criteria. If so, it returns a
// RouteConfigUpdate with nil error.
//
// A RouteConfiguration resource is considered valid if and only if it contains
// a VirtualHost whose domain field matches the server name from the URI passed
// to the gRPC channel, and it contains a clusterName or a weighted cluster.
//
// The RouteConfiguration includes a list of VirtualHosts, which may have zero
// or more elements. We are interested in the element whose domains field
// matches the server name specified in the "xds:" URI. The only field in the
// VirtualHost proto that we are interested in is the list of routes.
We -// only look at the last route in the list (the default route), whose match -// field must be empty and whose route field must be set. Inside that route -// message, the cluster field will contain the clusterName or weighted clusters -// we are looking for. -func generateRDSUpdateFromRouteConfiguration(rc *v3routepb.RouteConfiguration, logger *grpclog.PrefixLogger, v2 bool) (RouteConfigUpdate, error) { - var vhs []*VirtualHost - for _, vh := range rc.GetVirtualHosts() { - routes, err := routesProtoToSlice(vh.Routes, logger, v2) - if err != nil { - return RouteConfigUpdate{}, fmt.Errorf("received route is invalid: %v", err) - } - vhOut := &VirtualHost{ - Domains: vh.GetDomains(), - Routes: routes, - } - if !v2 { - cfgs, err := processHTTPFilterOverrides(vh.GetTypedPerFilterConfig()) - if err != nil { - return RouteConfigUpdate{}, fmt.Errorf("virtual host %+v: %v", vh, err) - } - vhOut.HTTPFilterConfigOverride = cfgs - } - vhs = append(vhs, vhOut) - } - return RouteConfigUpdate{VirtualHosts: vhs}, nil -} - -func routesProtoToSlice(routes []*v3routepb.Route, logger *grpclog.PrefixLogger, v2 bool) ([]*Route, error) { - var routesRet []*Route - - for _, r := range routes { - match := r.GetMatch() - if match == nil { - return nil, fmt.Errorf("route %+v doesn't have a match", r) - } - - if len(match.GetQueryParameters()) != 0 { - // Ignore route with query parameters. 
- logger.Warningf("route %+v has query parameter matchers, the route will be ignored", r) - continue - } - - pathSp := match.GetPathSpecifier() - if pathSp == nil { - return nil, fmt.Errorf("route %+v doesn't have a path specifier", r) - } - - var route Route - switch pt := pathSp.(type) { - case *v3routepb.RouteMatch_Prefix: - route.Prefix = &pt.Prefix - case *v3routepb.RouteMatch_Path: - route.Path = &pt.Path - case *v3routepb.RouteMatch_SafeRegex: - route.Regex = &pt.SafeRegex.Regex - default: - return nil, fmt.Errorf("route %+v has an unrecognized path specifier: %+v", r, pt) - } - - if caseSensitive := match.GetCaseSensitive(); caseSensitive != nil { - route.CaseInsensitive = !caseSensitive.Value - } - - for _, h := range match.GetHeaders() { - var header HeaderMatcher - switch ht := h.GetHeaderMatchSpecifier().(type) { - case *v3routepb.HeaderMatcher_ExactMatch: - header.ExactMatch = &ht.ExactMatch - case *v3routepb.HeaderMatcher_SafeRegexMatch: - header.RegexMatch = &ht.SafeRegexMatch.Regex - case *v3routepb.HeaderMatcher_RangeMatch: - header.RangeMatch = &Int64Range{ - Start: ht.RangeMatch.Start, - End: ht.RangeMatch.End, - } - case *v3routepb.HeaderMatcher_PresentMatch: - header.PresentMatch = &ht.PresentMatch - case *v3routepb.HeaderMatcher_PrefixMatch: - header.PrefixMatch = &ht.PrefixMatch - case *v3routepb.HeaderMatcher_SuffixMatch: - header.SuffixMatch = &ht.SuffixMatch - default: - return nil, fmt.Errorf("route %+v has an unrecognized header matcher: %+v", r, ht) - } - header.Name = h.GetName() - invert := h.GetInvertMatch() - header.InvertMatch = &invert - route.Headers = append(route.Headers, &header) - } - - if fr := match.GetRuntimeFraction(); fr != nil { - d := fr.GetDefaultValue() - n := d.GetNumerator() - switch d.GetDenominator() { - case v3typepb.FractionalPercent_HUNDRED: - n *= 10000 - case v3typepb.FractionalPercent_TEN_THOUSAND: - n *= 100 - case v3typepb.FractionalPercent_MILLION: - } - route.Fraction = &n - } - - route.WeightedClusters 
= make(map[string]WeightedCluster) - action := r.GetRoute() - switch a := action.GetClusterSpecifier().(type) { - case *v3routepb.RouteAction_Cluster: - route.WeightedClusters[a.Cluster] = WeightedCluster{Weight: 1} - case *v3routepb.RouteAction_WeightedClusters: - wcs := a.WeightedClusters - var totalWeight uint32 - for _, c := range wcs.Clusters { - w := c.GetWeight().GetValue() - if w == 0 { - continue - } - wc := WeightedCluster{Weight: w} - if !v2 { - cfgs, err := processHTTPFilterOverrides(c.GetTypedPerFilterConfig()) - if err != nil { - return nil, fmt.Errorf("route %+v, action %+v: %v", r, a, err) - } - wc.HTTPFilterConfigOverride = cfgs - } - route.WeightedClusters[c.GetName()] = wc - totalWeight += w - } - if totalWeight != wcs.GetTotalWeight().GetValue() { - return nil, fmt.Errorf("route %+v, action %+v, weights of clusters do not add up to total total weight, got: %v, want %v", r, a, wcs.GetTotalWeight().GetValue(), totalWeight) - } - if totalWeight == 0 { - return nil, fmt.Errorf("route %+v, action %+v, has no valid cluster in WeightedCluster action", r, a) - } - case *v3routepb.RouteAction_ClusterHeader: - continue - } - - msd := action.GetMaxStreamDuration() - // Prefer grpc_timeout_header_max, if set. - dur := msd.GetGrpcTimeoutHeaderMax() - if dur == nil { - dur = msd.GetMaxStreamDuration() - } - if dur != nil { - d := dur.AsDuration() - route.MaxStreamDuration = &d - } - - if !v2 { - cfgs, err := processHTTPFilterOverrides(r.GetTypedPerFilterConfig()) - if err != nil { - return nil, fmt.Errorf("route %+v: %v", r, err) - } - route.HTTPFilterConfigOverride = cfgs - } - routesRet = append(routesRet, &route) - } - return routesRet, nil -} - -// UnmarshalCluster processes resources received in an CDS response, validates -// them, and transforms them into a native struct which contains only fields we -// are interested in. 
func UnmarshalCluster(version string, resources []*anypb.Any, logger *grpclog.PrefixLogger) (map[string]ClusterUpdate, UpdateMetadata, error) {
	update := make(map[string]ClusterUpdate)
	md, err := processAllResources(version, resources, logger, update)
	return update, md, err
}

// unmarshalClusterResource unmarshals and validates a single CDS resource,
// returning its name, the parsed update and an error.
func unmarshalClusterResource(r *anypb.Any, logger *grpclog.PrefixLogger) (string, ClusterUpdate, error) {
	if !IsClusterResource(r.GetTypeUrl()) {
		return "", ClusterUpdate{}, fmt.Errorf("unexpected resource type: %q ", r.GetTypeUrl())
	}

	cluster := &v3clusterpb.Cluster{}
	if err := proto.Unmarshal(r.GetValue(), cluster); err != nil {
		return "", ClusterUpdate{}, fmt.Errorf("failed to unmarshal resource: %v", err)
	}
	logger.Infof("Resource with name: %v, type: %T, contains: %v", cluster.GetName(), cluster, cluster)

	cu, err := validateCluster(cluster)
	if err != nil {
		return cluster.GetName(), ClusterUpdate{}, err
	}
	cu.Raw = r
	// If the Cluster message in the CDS response did not contain a
	// serviceName, we will just use the clusterName for EDS.
	if cu.ServiceName == "" {
		cu.ServiceName = cluster.GetName()
	}
	return cluster.GetName(), cu, nil
}

// validateCluster checks that the Cluster is an ADS-sourced EDS cluster with
// ROUND_ROBIN LB policy, and converts it into a ClusterUpdate.
func validateCluster(cluster *v3clusterpb.Cluster) (ClusterUpdate, error) {
	emptyUpdate := ClusterUpdate{ServiceName: "", EnableLRS: false}
	switch {
	case cluster.GetType() != v3clusterpb.Cluster_EDS:
		return emptyUpdate, fmt.Errorf("unexpected cluster type %v in response: %+v", cluster.GetType(), cluster)
	case cluster.GetEdsClusterConfig().GetEdsConfig().GetAds() == nil:
		return emptyUpdate, fmt.Errorf("unexpected edsConfig in response: %+v", cluster)
	case cluster.GetLbPolicy() != v3clusterpb.Cluster_ROUND_ROBIN:
		return emptyUpdate, fmt.Errorf("unexpected lbPolicy %v in response: %+v", cluster.GetLbPolicy(), cluster)
	}

	// Process security configuration received from the control plane iff the
	// corresponding environment variable is set.
	var sc *SecurityConfig
	if env.ClientSideSecuritySupport {
		var err error
		if sc, err = securityConfigFromCluster(cluster); err != nil {
			return emptyUpdate, err
		}
	}

	return ClusterUpdate{
		ServiceName: cluster.GetEdsClusterConfig().GetServiceName(),
		EnableLRS:   cluster.GetLrsServer().GetSelf() != nil,
		SecurityCfg: sc,
		MaxRequests: circuitBreakersFromCluster(cluster),
	}, nil
}

// securityConfigFromCluster extracts the relevant security configuration from
// the received Cluster resource.
func securityConfigFromCluster(cluster *v3clusterpb.Cluster) (*SecurityConfig, error) {
	// The Cluster resource contains a `transport_socket` field, which contains
	// a oneof `typed_config` field of type `protobuf.Any`. The any proto
	// contains a marshaled representation of an `UpstreamTlsContext` message.
	ts := cluster.GetTransportSocket()
	if ts == nil {
		return nil, nil
	}
	if name := ts.GetName(); name != transportSocketName {
		return nil, fmt.Errorf("transport_socket field has unexpected name: %s", name)
	}
	any := ts.GetTypedConfig()
	if any == nil || any.TypeUrl != version.V3UpstreamTLSContextURL {
		return nil, fmt.Errorf("transport_socket field has unexpected typeURL: %s", any.TypeUrl)
	}
	upstreamCtx := &v3tlspb.UpstreamTlsContext{}
	if err := proto.Unmarshal(any.GetValue(), upstreamCtx); err != nil {
		return nil, fmt.Errorf("failed to unmarshal UpstreamTlsContext in CDS response: %v", err)
	}
	if upstreamCtx.GetCommonTlsContext() == nil {
		return nil, errors.New("UpstreamTlsContext in CDS response does not contain a CommonTlsContext")
	}

	sc, err := securityConfigFromCommonTLSContext(upstreamCtx.GetCommonTlsContext())
	if err != nil {
		return nil, err
	}
	if sc.RootInstanceName == "" {
		return nil, errors.New("security configuration on the client-side does not contain root certificate provider instance name")
	}
	return sc, nil
}

// common is expected to be not nil.
// securityConfigFromCommonTLSContext converts a CommonTlsContext into the
// internal SecurityConfig representation (identity/root provider instances
// and accepted SANs).
func securityConfigFromCommonTLSContext(common *v3tlspb.CommonTlsContext) (*SecurityConfig, error) {
	// The `CommonTlsContext` contains a
	// `tls_certificate_certificate_provider_instance` field of type
	// `CertificateProviderInstance`, which contains the provider instance name
	// and the certificate name to fetch identity certs.
	sc := &SecurityConfig{}
	if identity := common.GetTlsCertificateCertificateProviderInstance(); identity != nil {
		sc.IdentityInstanceName = identity.GetInstanceName()
		sc.IdentityCertName = identity.GetCertificateName()
	}

	// The `CommonTlsContext` contains a `validation_context_type` field which
	// is a oneof. We can get the values that we are interested in from two of
	// those possible values:
	//  - combined validation context:
	//    - contains a default validation context which holds the list of
	//      accepted SANs.
	//    - contains certificate provider instance configuration
	//  - certificate provider instance configuration
	//    - in this case, we do not get a list of accepted SANs.
	switch t := common.GetValidationContextType().(type) {
	case *v3tlspb.CommonTlsContext_CombinedValidationContext:
		combined := common.GetCombinedValidationContext()
		if def := combined.GetDefaultValidationContext(); def != nil {
			for _, matcher := range def.GetMatchSubjectAltNames() {
				// We only support exact matches for now.
				if exact := matcher.GetExact(); exact != "" {
					sc.AcceptedSANs = append(sc.AcceptedSANs, exact)
				}
			}
		}
		if pi := combined.GetValidationContextCertificateProviderInstance(); pi != nil {
			sc.RootInstanceName = pi.GetInstanceName()
			sc.RootCertName = pi.GetCertificateName()
		}
	case *v3tlspb.CommonTlsContext_ValidationContextCertificateProviderInstance:
		pi := common.GetValidationContextCertificateProviderInstance()
		sc.RootInstanceName = pi.GetInstanceName()
		sc.RootCertName = pi.GetCertificateName()
	case nil:
		// It is valid for the validation context to be nil on the server side.
	default:
		return nil, fmt.Errorf("validation context contains unexpected type: %T", t)
	}
	return sc, nil
}

// circuitBreakersFromCluster extracts the circuit breakers configuration from
// the received cluster resource. Returns nil if no CircuitBreakers or no
// Thresholds in CircuitBreakers.
func circuitBreakersFromCluster(cluster *v3clusterpb.Cluster) *uint32 {
	if !env.CircuitBreakingSupport {
		return nil
	}
	for _, threshold := range cluster.GetCircuitBreakers().GetThresholds() {
		// Only the DEFAULT priority threshold is used.
		if threshold.GetPriority() != v3corepb.RoutingPriority_DEFAULT {
			continue
		}
		maxRequestsPb := threshold.GetMaxRequests()
		if maxRequestsPb == nil {
			return nil
		}
		maxRequests := maxRequestsPb.GetValue()
		return &maxRequests
	}
	return nil
}

// UnmarshalEndpoints processes resources received in an EDS response,
// validates them, and transforms them into a native struct which contains only
// fields we are interested in.
func UnmarshalEndpoints(version string, resources []*anypb.Any, logger *grpclog.PrefixLogger) (map[string]EndpointsUpdate, UpdateMetadata, error) {
	update := make(map[string]EndpointsUpdate)
	md, err := processAllResources(version, resources, logger, update)
	return update, md, err
}

// unmarshalEndpointsResource unmarshals and validates a single EDS resource,
// returning its cluster name, the parsed update and an error.
func unmarshalEndpointsResource(r *anypb.Any, logger *grpclog.PrefixLogger) (string, EndpointsUpdate, error) {
	if !IsEndpointsResource(r.GetTypeUrl()) {
		return "", EndpointsUpdate{}, fmt.Errorf("unexpected resource type: %q ", r.GetTypeUrl())
	}

	cla := &v3endpointpb.ClusterLoadAssignment{}
	if err := proto.Unmarshal(r.GetValue(), cla); err != nil {
		return "", EndpointsUpdate{}, fmt.Errorf("failed to unmarshal resource: %v", err)
	}
	logger.Infof("Resource with name: %v, type: %T, contains: %v", cla.GetClusterName(), cla, cla)

	u, err := parseEDSRespProto(cla)
	if err != nil {
		return cla.GetClusterName(), EndpointsUpdate{}, err
	}
	u.Raw = r
	return cla.GetClusterName(), u, nil
}

// parseAddress returns the "host:port" string form of a SocketAddress proto.
func parseAddress(socketAddress *v3corepb.SocketAddress) string {
	return net.JoinHostPort(socketAddress.GetAddress(), strconv.Itoa(int(socketAddress.GetPortValue())))
}

// parseDropPolicy converts a DropOverload proto into an OverloadDropConfig,
// normalizing the fractional percentage into numerator/denominator form.
func parseDropPolicy(dropPolicy *v3endpointpb.ClusterLoadAssignment_Policy_DropOverload) OverloadDropConfig {
	percentage := dropPolicy.GetDropPercentage()
	var (
		numerator   = percentage.GetNumerator()
		denominator uint32
	)
	switch percentage.GetDenominator() {
	case v3typepb.FractionalPercent_HUNDRED:
		denominator = 100
	case v3typepb.FractionalPercent_TEN_THOUSAND:
		denominator = 10000
	case v3typepb.FractionalPercent_MILLION:
		denominator = 1000000
	}
	return OverloadDropConfig{
		Category:    dropPolicy.GetCategory(),
		Numerator:   numerator,
		Denominator: denominator,
	}
}

// parseEndpoints converts a list of LbEndpoint protos into the internal
// Endpoint representation (health status, address, weight).
func parseEndpoints(lbEndpoints []*v3endpointpb.LbEndpoint) []Endpoint {
	endpoints := make([]Endpoint, 0, len(lbEndpoints))
	for _, lbEndpoint := range lbEndpoints {
		endpoints = append(endpoints, Endpoint{
			HealthStatus: EndpointHealthStatus(lbEndpoint.GetHealthStatus()),
			Address:      parseAddress(lbEndpoint.GetEndpoint().GetAddress().GetSocketAddress()),
			Weight:       lbEndpoint.GetLoadBalancingWeight().GetValue(),
		})
	}
	return endpoints
}

// parseEDSRespProto converts a ClusterLoadAssignment into an EndpointsUpdate,
// validating that localities carry IDs and that the set of priorities is
// contiguous starting at 0.
func parseEDSRespProto(m *v3endpointpb.ClusterLoadAssignment) (EndpointsUpdate, error) {
	ret := EndpointsUpdate{}
	for _, dropPolicy := range m.GetPolicy().GetDropOverloads() {
		ret.Drops = append(ret.Drops, parseDropPolicy(dropPolicy))
	}
	priorities := make(map[uint32]struct{})
	for _, locality := range m.Endpoints {
		l := locality.GetLocality()
		if l == nil {
			return EndpointsUpdate{}, fmt.Errorf("EDS response contains a locality without ID, locality: %+v", locality)
		}
		lid := internal.LocalityID{
			Region:  l.Region,
			Zone:    l.Zone,
			SubZone: l.SubZone,
		}
		priority := locality.GetPriority()
		priorities[priority] = struct{}{}
		ret.Localities = append(ret.Localities, Locality{
			ID:        lid,
			Endpoints: parseEndpoints(locality.GetLbEndpoints()),
			Weight:    locality.GetLoadBalancingWeight().GetValue(),
			Priority:  priority,
		})
	}
	// Priorities must form a dense range [0, len) with no gaps.
	for i := 0; i < len(priorities); i++ {
		if _, ok := priorities[uint32(i)]; !ok {
			return EndpointsUpdate{}, fmt.Errorf("priority %v missing (with different priorities %v received)", i, priorities)
		}
	}
	return ret, nil
}

// processAllResources unmarshals and validates the resources, populates the
// provided ret (a map), and returns metadata and error.
//
// The type of the resource is determined by the type of ret. E.g.
// map[string]ListenerUpdate means this is for LDS.
func processAllResources(version string, resources []*anypb.Any, logger *grpclog.PrefixLogger, ret interface{}) (UpdateMetadata, error) {
	timestamp := time.Now()
	md := UpdateMetadata{
		Version:   version,
		Timestamp: timestamp,
	}
	var topLevelErrors []error
	perResourceErrors := make(map[string]error)

	for _, r := range resources {
		switch ret2 := ret.(type) {
		case map[string]ListenerUpdate:
			name, update, err := unmarshalListenerResource(r, logger)
			if err == nil {
				ret2[name] = update
				continue
			}
			if name == "" {
				topLevelErrors = append(topLevelErrors, err)
				continue
			}
			perResourceErrors[name] = err
			// Add place holder in the map so we know this resource name was in
			// the response.
			ret2[name] = ListenerUpdate{}
		case map[string]RouteConfigUpdate:
			name, update, err := unmarshalRouteConfigResource(r, logger)
			if err == nil {
				ret2[name] = update
				continue
			}
			if name == "" {
				topLevelErrors = append(topLevelErrors, err)
				continue
			}
			perResourceErrors[name] = err
			// Add place holder in the map so we know this resource name was in
			// the response.
			ret2[name] = RouteConfigUpdate{}
		case map[string]ClusterUpdate:
			name, update, err := unmarshalClusterResource(r, logger)
			if err == nil {
				ret2[name] = update
				continue
			}
			if name == "" {
				topLevelErrors = append(topLevelErrors, err)
				continue
			}
			perResourceErrors[name] = err
			// Add place holder in the map so we know this resource name was in
			// the response.
			ret2[name] = ClusterUpdate{}
		case map[string]EndpointsUpdate:
			name, update, err := unmarshalEndpointsResource(r, logger)
			if err == nil {
				ret2[name] = update
				continue
			}
			if name == "" {
				topLevelErrors = append(topLevelErrors, err)
				continue
			}
			perResourceErrors[name] = err
			// Add place holder in the map so we know this resource name was in
			// the response.
			ret2[name] = EndpointsUpdate{}
		}
	}

	if len(topLevelErrors) == 0 && len(perResourceErrors) == 0 {
		md.Status = ServiceStatusACKed
		return md, nil
	}

	var typeStr string
	switch ret.(type) {
	case map[string]ListenerUpdate:
		typeStr = "LDS"
	case map[string]RouteConfigUpdate:
		typeStr = "RDS"
	case map[string]ClusterUpdate:
		typeStr = "CDS"
	case map[string]EndpointsUpdate:
		typeStr = "EDS"
	}

	md.Status = ServiceStatusNACKed
	errRet := combineErrors(typeStr, topLevelErrors, perResourceErrors)
	md.ErrState = &UpdateErrorMetadata{
		Version:   version,
		Err:       errRet,
		Timestamp: timestamp,
	}
	return md, errRet
}

// combineErrors flattens the top-level and per-resource errors collected for
// one response into a single error value.
func combineErrors(rType string, topLevelErrors []error, perResourceErrors map[string]error) error {
	var errStrB strings.Builder
	errStrB.WriteString(fmt.Sprintf("error parsing %q response: ", rType))
	if len(topLevelErrors) > 0 {
		errStrB.WriteString("top level errors: ")
		for i, err := range topLevelErrors {
			if i != 0 {
				errStrB.WriteString(";\n")
			}
			errStrB.WriteString(err.Error())
		}
	}
	if len(perResourceErrors) > 0 {
		var i int
		for name, err := range perResourceErrors {
			if i != 0 {
				errStrB.WriteString(";\n")
			}
			i++
			errStrB.WriteString(fmt.Sprintf("resource %q: %v", name, err.Error()))
		}
	}
	return errors.New(errStrB.String())
}

// ---- new file (per patch metadata): xds/internal/httpfilter/fault/fault.go ----

/*
 *
 * Copyright 2021 gRPC authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 */

// Package fault implements the Envoy Fault Injection HTTP filter.
package fault

import (
	"context"
	"errors"
	"fmt"
	"io"
	"strconv"
	"sync/atomic"
	"time"

	"github.com/golang/protobuf/proto"
	"github.com/golang/protobuf/ptypes"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/internal/grpcrand"
	iresolver "google.golang.org/grpc/internal/resolver"
	"google.golang.org/grpc/metadata"
	"google.golang.org/grpc/status"
	"google.golang.org/grpc/xds/internal/httpfilter"
	"google.golang.org/protobuf/types/known/anypb"

	cpb "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/common/fault/v3"
	fpb "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/http/fault/v3"
	tpb "github.com/envoyproxy/go-control-plane/envoy/type/v3"
)

// Metadata keys used to control abort faults on a per-RPC basis.
const headerAbortHTTPStatus = "x-envoy-fault-abort-request"
const headerAbortGRPCStatus = "x-envoy-fault-abort-grpc-request"
const headerAbortPercentage = "x-envoy-fault-abort-request-percentage"

// Metadata keys used to control delay faults on a per-RPC basis.
const headerDelayPercentage = "x-envoy-fault-delay-request-percentage"
const headerDelayDuration = "x-envoy-fault-delay-request"

// statusMap maps HTTP status codes configured for abort faults to the gRPC
// codes used when failing the RPC.
var statusMap = map[int]codes.Code{
	400: codes.Internal,
	401: codes.Unauthenticated,
	403: codes.PermissionDenied,
	404: codes.Unimplemented,
	429: codes.Unavailable,
	502: codes.Unavailable,
	503: codes.Unavailable,
	504: codes.Unavailable,
}

func init() {
	httpfilter.Register(builder{})
}

// builder implements the httpfilter.Filter interfaces for the fault
// injection filter.
type builder struct {
}

// config wraps the parsed HTTPFault proto as an httpfilter.FilterConfig.
type config struct {
	httpfilter.FilterConfig
	config *fpb.HTTPFault
}

func (builder) TypeURLs() []string {
	return []string{"type.googleapis.com/envoy.extensions.filters.http.fault.v3.HTTPFault"}
}

// Parsing is the same for the base config and the override config.
+func parseConfig(cfg proto.Message) (httpfilter.FilterConfig, error) { + if cfg == nil { + return nil, fmt.Errorf("fault: nil configuration message provided") + } + any, ok := cfg.(*anypb.Any) + if !ok { + return nil, fmt.Errorf("fault: error parsing config %v: unknown type %T", cfg, cfg) + } + msg := new(fpb.HTTPFault) + if err := ptypes.UnmarshalAny(any, msg); err != nil { + return nil, fmt.Errorf("fault: error parsing config %v: %v", cfg, err) + } + return config{config: msg}, nil +} + +func (builder) ParseFilterConfig(cfg proto.Message) (httpfilter.FilterConfig, error) { + return parseConfig(cfg) +} + +func (builder) ParseFilterConfigOverride(override proto.Message) (httpfilter.FilterConfig, error) { + return parseConfig(override) +} + +func (builder) IsTerminal() bool { + return false +} + +var _ httpfilter.ClientInterceptorBuilder = builder{} + +func (builder) BuildClientInterceptor(cfg, override httpfilter.FilterConfig) (iresolver.ClientInterceptor, error) { + if cfg == nil { + return nil, fmt.Errorf("fault: nil config provided") + } + + c, ok := cfg.(config) + if !ok { + return nil, fmt.Errorf("fault: incorrect config type provided (%T): %v", cfg, cfg) + } + + if override != nil { + // override completely replaces the listener configuration; but we + // still validate the listener config type. 
+ c, ok = override.(config) + if !ok { + return nil, fmt.Errorf("fault: incorrect override config type provided (%T): %v", override, override) + } + } + + icfg := c.config + if (icfg.GetMaxActiveFaults() != nil && icfg.GetMaxActiveFaults().GetValue() == 0) || + (icfg.GetDelay() == nil && icfg.GetAbort() == nil) { + return nil, nil + } + return &interceptor{config: icfg}, nil +} + +type interceptor struct { + config *fpb.HTTPFault +} + +var activeFaults uint32 // global active faults; accessed atomically + +func (i *interceptor) NewStream(ctx context.Context, ri iresolver.RPCInfo, done func(), newStream func(ctx context.Context, done func()) (iresolver.ClientStream, error)) (iresolver.ClientStream, error) { + if maxAF := i.config.GetMaxActiveFaults(); maxAF != nil { + defer atomic.AddUint32(&activeFaults, ^uint32(0)) // decrement counter + if af := atomic.AddUint32(&activeFaults, 1); af > maxAF.GetValue() { + // Would exceed maximum active fault limit. + return newStream(ctx, done) + } + } + + if err := injectDelay(ctx, i.config.GetDelay()); err != nil { + return nil, err + } + + if err := injectAbort(ctx, i.config.GetAbort()); err != nil { + if err == errOKStream { + return &okStream{ctx: ctx}, nil + } + return nil, err + } + return newStream(ctx, done) +} + +// For overriding in tests +var randIntn = grpcrand.Intn +var newTimer = time.NewTimer + +func injectDelay(ctx context.Context, delayCfg *cpb.FaultDelay) error { + numerator, denominator := splitPct(delayCfg.GetPercentage()) + var delay time.Duration + switch delayType := delayCfg.GetFaultDelaySecifier().(type) { + case *cpb.FaultDelay_FixedDelay: + delay = delayType.FixedDelay.AsDuration() + case *cpb.FaultDelay_HeaderDelay_: + md, _ := metadata.FromOutgoingContext(ctx) + v := md[headerDelayDuration] + if v == nil { + // No delay configured for this RPC. + return nil + } + ms, ok := parseIntFromMD(v) + if !ok { + // Malformed header; no delay. 
+ return nil + } + delay = time.Duration(ms) * time.Millisecond + if v := md[headerDelayPercentage]; v != nil { + if num, ok := parseIntFromMD(v); ok && num < numerator { + numerator = num + } + } + } + if delay == 0 || randIntn(denominator) >= numerator { + return nil + } + t := newTimer(delay) + select { + case <-t.C: + case <-ctx.Done(): + t.Stop() + return ctx.Err() + } + return nil +} + +func injectAbort(ctx context.Context, abortCfg *fpb.FaultAbort) error { + numerator, denominator := splitPct(abortCfg.GetPercentage()) + code := codes.OK + okCode := false + switch errType := abortCfg.GetErrorType().(type) { + case *fpb.FaultAbort_HttpStatus: + code, okCode = grpcFromHTTP(int(errType.HttpStatus)) + case *fpb.FaultAbort_GrpcStatus: + code, okCode = sanitizeGRPCCode(codes.Code(errType.GrpcStatus)), true + case *fpb.FaultAbort_HeaderAbort_: + md, _ := metadata.FromOutgoingContext(ctx) + if v := md[headerAbortHTTPStatus]; v != nil { + // HTTP status has priority over gRPC status. + if httpStatus, ok := parseIntFromMD(v); ok { + code, okCode = grpcFromHTTP(httpStatus) + } + } else if v := md[headerAbortGRPCStatus]; v != nil { + if grpcStatus, ok := parseIntFromMD(v); ok { + code, okCode = sanitizeGRPCCode(codes.Code(grpcStatus)), true + } + } + if v := md[headerAbortPercentage]; v != nil { + if num, ok := parseIntFromMD(v); ok && num < numerator { + numerator = num + } + } + } + if !okCode || randIntn(denominator) >= numerator { + return nil + } + if code == codes.OK { + return errOKStream + } + return status.Errorf(code, "RPC terminated due to fault injection") +} + +var errOKStream = errors.New("stream terminated early with OK status") + +// parseIntFromMD returns the integer in the last header or nil if parsing +// failed. 
+func parseIntFromMD(header []string) (int, bool) { + if len(header) == 0 { + return 0, false + } + v, err := strconv.Atoi(header[len(header)-1]) + return v, err == nil +} + +func splitPct(fp *tpb.FractionalPercent) (num int, den int) { + if fp == nil { + return 0, 100 + } + num = int(fp.GetNumerator()) + switch fp.GetDenominator() { + case tpb.FractionalPercent_HUNDRED: + return num, 100 + case tpb.FractionalPercent_TEN_THOUSAND: + return num, 10 * 1000 + case tpb.FractionalPercent_MILLION: + return num, 1000 * 1000 + } + return num, 100 +} + +func grpcFromHTTP(httpStatus int) (codes.Code, bool) { + if httpStatus < 200 || httpStatus >= 600 { + // Malformed; ignore this fault type. + return codes.OK, false + } + if c := statusMap[httpStatus]; c != codes.OK { + // OK = 0/the default for the map. + return c, true + } + // All undefined HTTP status codes convert to Unknown. HTTP status of 200 + // is "success", but gRPC converts to Unknown due to missing grpc status. + return codes.Unknown, true +} + +func sanitizeGRPCCode(c codes.Code) codes.Code { + if c > codes.Code(16) { + return codes.Unknown + } + return c +} + +type okStream struct { + ctx context.Context +} + +func (*okStream) Header() (metadata.MD, error) { return nil, nil } +func (*okStream) Trailer() metadata.MD { return nil } +func (*okStream) CloseSend() error { return nil } +func (o *okStream) Context() context.Context { return o.ctx } +func (*okStream) SendMsg(m interface{}) error { return io.EOF } +func (*okStream) RecvMsg(m interface{}) error { return io.EOF } diff --git a/xds/internal/httpfilter/fault/fault_test.go b/xds/internal/httpfilter/fault/fault_test.go new file mode 100644 index 00000000000..c2959054da9 --- /dev/null +++ b/xds/internal/httpfilter/fault/fault_test.go @@ -0,0 +1,672 @@ +//go:build !386 +// +build !386 + +/* + * + * Copyright 2020 gRPC authors. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// Package xds_test contains e2e tests for xDS use. +package fault + +import ( + "context" + "fmt" + "io" + "net" + "reflect" + "testing" + "time" + + "github.com/golang/protobuf/ptypes" + "github.com/google/uuid" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/internal/grpcrand" + "google.golang.org/grpc/internal/grpctest" + "google.golang.org/grpc/internal/testutils" + "google.golang.org/grpc/internal/xds" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" + xtestutils "google.golang.org/grpc/xds/internal/testutils" + "google.golang.org/grpc/xds/internal/testutils/e2e" + "google.golang.org/protobuf/types/known/wrapperspb" + + v3listenerpb "github.com/envoyproxy/go-control-plane/envoy/config/listener/v3" + cpb "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/common/fault/v3" + fpb "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/http/fault/v3" + v3httppb "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/network/http_connection_manager/v3" + tpb "github.com/envoyproxy/go-control-plane/envoy/type/v3" + testpb "google.golang.org/grpc/test/grpc_testing" + + _ "google.golang.org/grpc/xds/internal/balancer" // Register the balancers. + _ "google.golang.org/grpc/xds/internal/resolver" // Register the xds_resolver. 
+ _ "google.golang.org/grpc/xds/internal/xdsclient/v3" // Register the v3 xDS API client. +) + +const defaultTestTimeout = 10 * time.Second + +type s struct { + grpctest.Tester +} + +func Test(t *testing.T) { + grpctest.RunSubTests(t, s{}) +} + +type testService struct { + testpb.TestServiceServer +} + +func (*testService) EmptyCall(context.Context, *testpb.Empty) (*testpb.Empty, error) { + return &testpb.Empty{}, nil +} + +func (*testService) FullDuplexCall(stream testpb.TestService_FullDuplexCallServer) error { + // End RPC after client does a CloseSend. + for { + if _, err := stream.Recv(); err == io.EOF { + return nil + } else if err != nil { + return err + } + } +} + +// clientSetup performs a bunch of steps common to all xDS server tests here: +// - spin up an xDS management server on a local port +// - spin up a gRPC server and register the test service on it +// - create a local TCP listener and start serving on it +// +// Returns the following: +// - the management server: tests use this to configure resources +// - nodeID expected by the management server: this is set in the Node proto +// sent by the xdsClient for queries. +// - the port the server is listening on +// - cleanup function to be invoked by the tests when done +func clientSetup(t *testing.T) (*e2e.ManagementServer, string, uint32, func()) { + // Spin up a xDS management server on a local port. + nodeID := uuid.New().String() + fs, err := e2e.StartManagementServer() + if err != nil { + t.Fatal(err) + } + + // Create a bootstrap file in a temporary directory. + bootstrapCleanup, err := xds.SetupBootstrapFile(xds.BootstrapOptions{ + Version: xds.TransportV3, + NodeID: nodeID, + ServerURI: fs.Address, + ServerListenerResourceNameTemplate: "grpc/server", + }) + if err != nil { + t.Fatal(err) + } + + // Initialize a gRPC server and register the stubServer on it. 
+ server := grpc.NewServer() + testpb.RegisterTestServiceServer(server, &testService{}) + + // Create a local listener and pass it to Serve(). + lis, err := xtestutils.LocalTCPListener() + if err != nil { + t.Fatalf("xtestutils.LocalTCPListener() failed: %v", err) + } + + go func() { + if err := server.Serve(lis); err != nil { + t.Errorf("Serve() failed: %v", err) + } + }() + + return fs, nodeID, uint32(lis.Addr().(*net.TCPAddr).Port), func() { + fs.Stop() + bootstrapCleanup() + server.Stop() + } +} + +func (s) TestFaultInjection_Unary(t *testing.T) { + type subcase struct { + name string + code codes.Code + repeat int + randIn []int // Intn calls per-repeat (not per-subcase) + delays []time.Duration // NewTimer calls per-repeat (not per-subcase) + md metadata.MD + } + testCases := []struct { + name string + cfgs []*fpb.HTTPFault + randOutInc int + want []subcase + }{{ + name: "max faults zero", + cfgs: []*fpb.HTTPFault{{ + MaxActiveFaults: wrapperspb.UInt32(0), + Abort: &fpb.FaultAbort{ + Percentage: &tpb.FractionalPercent{Numerator: 100, Denominator: tpb.FractionalPercent_HUNDRED}, + ErrorType: &fpb.FaultAbort_GrpcStatus{GrpcStatus: uint32(codes.Aborted)}, + }, + }}, + randOutInc: 5, + want: []subcase{{ + code: codes.OK, + repeat: 25, + }}, + }, { + name: "no abort or delay", + cfgs: []*fpb.HTTPFault{{}}, + randOutInc: 5, + want: []subcase{{ + code: codes.OK, + repeat: 25, + }}, + }, { + name: "abort always", + cfgs: []*fpb.HTTPFault{{ + Abort: &fpb.FaultAbort{ + Percentage: &tpb.FractionalPercent{Numerator: 100, Denominator: tpb.FractionalPercent_HUNDRED}, + ErrorType: &fpb.FaultAbort_GrpcStatus{GrpcStatus: uint32(codes.Aborted)}, + }, + }}, + randOutInc: 5, + want: []subcase{{ + code: codes.Aborted, + randIn: []int{100}, + repeat: 25, + }}, + }, { + name: "abort 10%", + cfgs: []*fpb.HTTPFault{{ + Abort: &fpb.FaultAbort{ + Percentage: &tpb.FractionalPercent{Numerator: 100000, Denominator: tpb.FractionalPercent_MILLION}, + ErrorType: 
&fpb.FaultAbort_GrpcStatus{GrpcStatus: uint32(codes.Aborted)}, + }, + }}, + randOutInc: 50000, + want: []subcase{{ + name: "[0,10]%", + code: codes.Aborted, + randIn: []int{1000000}, + repeat: 2, + }, { + name: "(10,100]%", + code: codes.OK, + randIn: []int{1000000}, + repeat: 18, + }, { + name: "[0,10]% again", + code: codes.Aborted, + randIn: []int{1000000}, + repeat: 2, + }}, + }, { + name: "delay always", + cfgs: []*fpb.HTTPFault{{ + Delay: &cpb.FaultDelay{ + Percentage: &tpb.FractionalPercent{Numerator: 100, Denominator: tpb.FractionalPercent_HUNDRED}, + FaultDelaySecifier: &cpb.FaultDelay_FixedDelay{FixedDelay: ptypes.DurationProto(time.Second)}, + }, + }}, + randOutInc: 5, + want: []subcase{{ + randIn: []int{100}, + repeat: 25, + delays: []time.Duration{time.Second}, + }}, + }, { + name: "delay 10%", + cfgs: []*fpb.HTTPFault{{ + Delay: &cpb.FaultDelay{ + Percentage: &tpb.FractionalPercent{Numerator: 1000, Denominator: tpb.FractionalPercent_TEN_THOUSAND}, + FaultDelaySecifier: &cpb.FaultDelay_FixedDelay{FixedDelay: ptypes.DurationProto(time.Second)}, + }, + }}, + randOutInc: 500, + want: []subcase{{ + name: "[0,10]%", + randIn: []int{10000}, + repeat: 2, + delays: []time.Duration{time.Second}, + }, { + name: "(10,100]%", + randIn: []int{10000}, + repeat: 18, + }, { + name: "[0,10]% again", + randIn: []int{10000}, + repeat: 2, + delays: []time.Duration{time.Second}, + }}, + }, { + name: "delay 80%, abort 50%", + cfgs: []*fpb.HTTPFault{{ + Delay: &cpb.FaultDelay{ + Percentage: &tpb.FractionalPercent{Numerator: 80, Denominator: tpb.FractionalPercent_HUNDRED}, + FaultDelaySecifier: &cpb.FaultDelay_FixedDelay{FixedDelay: ptypes.DurationProto(3 * time.Second)}, + }, + Abort: &fpb.FaultAbort{ + Percentage: &tpb.FractionalPercent{Numerator: 50, Denominator: tpb.FractionalPercent_HUNDRED}, + ErrorType: &fpb.FaultAbort_GrpcStatus{GrpcStatus: uint32(codes.Unimplemented)}, + }, + }}, + randOutInc: 5, + want: []subcase{{ + name: "50% delay and abort", + code: 
codes.Unimplemented, + randIn: []int{100, 100}, + repeat: 10, + delays: []time.Duration{3 * time.Second}, + }, { + name: "30% delay, no abort", + randIn: []int{100, 100}, + repeat: 6, + delays: []time.Duration{3 * time.Second}, + }, { + name: "20% success", + randIn: []int{100, 100}, + repeat: 4, + }, { + name: "50% delay and abort again", + code: codes.Unimplemented, + randIn: []int{100, 100}, + repeat: 10, + delays: []time.Duration{3 * time.Second}, + }}, + }, { + name: "header abort", + cfgs: []*fpb.HTTPFault{{ + Abort: &fpb.FaultAbort{ + Percentage: &tpb.FractionalPercent{Numerator: 80, Denominator: tpb.FractionalPercent_HUNDRED}, + ErrorType: &fpb.FaultAbort_HeaderAbort_{}, + }, + }}, + randOutInc: 10, + want: []subcase{{ + name: "30% abort; [0,30]%", + md: metadata.MD{ + headerAbortGRPCStatus: []string{fmt.Sprintf("%d", codes.DataLoss)}, + headerAbortPercentage: []string{"30"}, + }, + code: codes.DataLoss, + randIn: []int{100}, + repeat: 3, + }, { + name: "30% abort; (30,60]%", + md: metadata.MD{ + headerAbortGRPCStatus: []string{fmt.Sprintf("%d", codes.DataLoss)}, + headerAbortPercentage: []string{"30"}, + }, + randIn: []int{100}, + repeat: 3, + }, { + name: "80% abort; (60,80]%", + md: metadata.MD{ + headerAbortGRPCStatus: []string{fmt.Sprintf("%d", codes.DataLoss)}, + headerAbortPercentage: []string{"80"}, + }, + code: codes.DataLoss, + randIn: []int{100}, + repeat: 2, + }, { + name: "cannot exceed percentage in filter", + md: metadata.MD{ + headerAbortGRPCStatus: []string{fmt.Sprintf("%d", codes.DataLoss)}, + headerAbortPercentage: []string{"100"}, + }, + randIn: []int{100}, + repeat: 2, + }, { + name: "HTTP Status 404", + md: metadata.MD{ + headerAbortHTTPStatus: []string{"404"}, + headerAbortPercentage: []string{"100"}, + }, + code: codes.Unimplemented, + randIn: []int{100}, + repeat: 1, + }, { + name: "HTTP Status 429", + md: metadata.MD{ + headerAbortHTTPStatus: []string{"429"}, + headerAbortPercentage: []string{"100"}, + }, + code: codes.Unavailable, 
+ randIn: []int{100}, + repeat: 1, + }, { + name: "HTTP Status 200", + md: metadata.MD{ + headerAbortHTTPStatus: []string{"200"}, + headerAbortPercentage: []string{"100"}, + }, + // No GRPC status, but HTTP Status of 200 translates to Unknown, + // per spec in statuscodes.md. + code: codes.Unknown, + randIn: []int{100}, + repeat: 1, + }, { + name: "gRPC Status OK", + md: metadata.MD{ + headerAbortGRPCStatus: []string{fmt.Sprintf("%d", codes.OK)}, + headerAbortPercentage: []string{"100"}, + }, + // This should be Unimplemented (mismatched request/response + // count), per spec in statuscodes.md, but grpc-go currently + // returns io.EOF which status.Code() converts to Unknown + code: codes.Unknown, + randIn: []int{100}, + repeat: 1, + }, { + name: "invalid header results in no abort", + md: metadata.MD{ + headerAbortGRPCStatus: []string{"error"}, + headerAbortPercentage: []string{"100"}, + }, + repeat: 1, + }, { + name: "invalid header results in default percentage", + md: metadata.MD{ + headerAbortGRPCStatus: []string{fmt.Sprintf("%d", codes.DataLoss)}, + headerAbortPercentage: []string{"error"}, + }, + code: codes.DataLoss, + randIn: []int{100}, + repeat: 1, + }}, + }, { + name: "header delay", + cfgs: []*fpb.HTTPFault{{ + Delay: &cpb.FaultDelay{ + Percentage: &tpb.FractionalPercent{Numerator: 80, Denominator: tpb.FractionalPercent_HUNDRED}, + FaultDelaySecifier: &cpb.FaultDelay_HeaderDelay_{}, + }, + }}, + randOutInc: 10, + want: []subcase{{ + name: "30% delay; [0,30]%", + md: metadata.MD{ + headerDelayDuration: []string{"2"}, + headerDelayPercentage: []string{"30"}, + }, + randIn: []int{100}, + delays: []time.Duration{2 * time.Millisecond}, + repeat: 3, + }, { + name: "30% delay; (30, 60]%", + md: metadata.MD{ + headerDelayDuration: []string{"2"}, + headerDelayPercentage: []string{"30"}, + }, + randIn: []int{100}, + repeat: 3, + }, { + name: "invalid header results in no delay", + md: metadata.MD{ + headerDelayDuration: []string{"error"}, + 
headerDelayPercentage: []string{"80"}, + }, + repeat: 1, + }, { + name: "invalid header results in default percentage", + md: metadata.MD{ + headerDelayDuration: []string{"2"}, + headerDelayPercentage: []string{"error"}, + }, + randIn: []int{100}, + delays: []time.Duration{2 * time.Millisecond}, + repeat: 1, + }, { + name: "invalid header results in default percentage", + md: metadata.MD{ + headerDelayDuration: []string{"2"}, + headerDelayPercentage: []string{"error"}, + }, + randIn: []int{100}, + repeat: 1, + }, { + name: "cannot exceed percentage in filter", + md: metadata.MD{ + headerDelayDuration: []string{"2"}, + headerDelayPercentage: []string{"100"}, + }, + randIn: []int{100}, + repeat: 1, + }}, + }, { + name: "abort then delay filters", + cfgs: []*fpb.HTTPFault{{ + Abort: &fpb.FaultAbort{ + Percentage: &tpb.FractionalPercent{Numerator: 50, Denominator: tpb.FractionalPercent_HUNDRED}, + ErrorType: &fpb.FaultAbort_GrpcStatus{GrpcStatus: uint32(codes.Unimplemented)}, + }, + }, { + Delay: &cpb.FaultDelay{ + Percentage: &tpb.FractionalPercent{Numerator: 80, Denominator: tpb.FractionalPercent_HUNDRED}, + FaultDelaySecifier: &cpb.FaultDelay_FixedDelay{FixedDelay: ptypes.DurationProto(time.Second)}, + }, + }}, + randOutInc: 10, + want: []subcase{{ + name: "50% delay and abort (abort skips delay)", + code: codes.Unimplemented, + randIn: []int{100}, + repeat: 5, + }, { + name: "30% delay, no abort", + randIn: []int{100, 100}, + repeat: 3, + delays: []time.Duration{time.Second}, + }, { + name: "20% success", + randIn: []int{100, 100}, + repeat: 2, + }}, + }} + + fs, nodeID, port, cleanup := clientSetup(t) + defer cleanup() + + for tcNum, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + defer func() { randIntn = grpcrand.Intn; newTimer = time.NewTimer }() + var intnCalls []int + var newTimerCalls []time.Duration + randOut := 0 + randIntn = func(n int) int { + intnCalls = append(intnCalls, n) + return randOut % n + } + + newTimer = func(d time.Duration) 
*time.Timer { + newTimerCalls = append(newTimerCalls, d) + return time.NewTimer(0) + } + + serviceName := fmt.Sprintf("myservice%d", tcNum) + resources := e2e.DefaultClientResources(e2e.ResourceParams{ + DialTarget: serviceName, + NodeID: nodeID, + Host: "localhost", + Port: port, + SecLevel: e2e.SecurityLevelNone, + }) + hcm := new(v3httppb.HttpConnectionManager) + err := ptypes.UnmarshalAny(resources.Listeners[0].GetApiListener().GetApiListener(), hcm) + if err != nil { + t.Fatal(err) + } + routerFilter := hcm.HttpFilters[len(hcm.HttpFilters)-1] + + hcm.HttpFilters = nil + for i, cfg := range tc.cfgs { + hcm.HttpFilters = append(hcm.HttpFilters, e2e.HTTPFilter(fmt.Sprintf("fault%d", i), cfg)) + } + hcm.HttpFilters = append(hcm.HttpFilters, routerFilter) + hcmAny := testutils.MarshalAny(hcm) + resources.Listeners[0].ApiListener.ApiListener = hcmAny + resources.Listeners[0].FilterChains[0].Filters[0].ConfigType = &v3listenerpb.Filter_TypedConfig{TypedConfig: hcmAny} + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + if err := fs.Update(ctx, resources); err != nil { + t.Fatal(err) + } + + // Create a ClientConn and run the test case. 
+ cc, err := grpc.Dial("xds:///"+serviceName, grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + t.Fatalf("failed to dial local test server: %v", err) + } + defer cc.Close() + + client := testpb.NewTestServiceClient(cc) + count := 0 + for _, want := range tc.want { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + if want.repeat == 0 { + t.Fatalf("invalid repeat count") + } + for n := 0; n < want.repeat; n++ { + intnCalls = nil + newTimerCalls = nil + ctx = metadata.NewOutgoingContext(ctx, want.md) + _, err := client.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(true)) + t.Logf("%v: RPC %d: err: %v, intnCalls: %v, newTimerCalls: %v", want.name, count, err, intnCalls, newTimerCalls) + if status.Code(err) != want.code || !reflect.DeepEqual(intnCalls, want.randIn) || !reflect.DeepEqual(newTimerCalls, want.delays) { + t.Fatalf("WANTED code: %v, intnCalls: %v, newTimerCalls: %v", want.code, want.randIn, want.delays) + } + randOut += tc.randOutInc + count++ + } + } + }) + } +} + +func (s) TestFaultInjection_MaxActiveFaults(t *testing.T) { + fs, nodeID, port, cleanup := clientSetup(t) + defer cleanup() + resources := e2e.DefaultClientResources(e2e.ResourceParams{ + DialTarget: "myservice", + NodeID: nodeID, + Host: "localhost", + Port: port, + SecLevel: e2e.SecurityLevelNone, + }) + hcm := new(v3httppb.HttpConnectionManager) + err := ptypes.UnmarshalAny(resources.Listeners[0].GetApiListener().GetApiListener(), hcm) + if err != nil { + t.Fatal(err) + } + + defer func() { newTimer = time.NewTimer }() + timers := make(chan *time.Timer, 2) + newTimer = func(d time.Duration) *time.Timer { + t := time.NewTimer(24 * time.Hour) // Will reset to fire. 
+ timers <- t + return t + } + + hcm.HttpFilters = append([]*v3httppb.HttpFilter{ + e2e.HTTPFilter("fault", &fpb.HTTPFault{ + MaxActiveFaults: wrapperspb.UInt32(2), + Delay: &cpb.FaultDelay{ + Percentage: &tpb.FractionalPercent{Numerator: 100, Denominator: tpb.FractionalPercent_HUNDRED}, + FaultDelaySecifier: &cpb.FaultDelay_FixedDelay{FixedDelay: ptypes.DurationProto(time.Second)}, + }, + })}, + hcm.HttpFilters...) + hcmAny := testutils.MarshalAny(hcm) + resources.Listeners[0].ApiListener.ApiListener = hcmAny + resources.Listeners[0].FilterChains[0].Filters[0].ConfigType = &v3listenerpb.Filter_TypedConfig{TypedConfig: hcmAny} + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + if err := fs.Update(ctx, resources); err != nil { + t.Fatal(err) + } + + // Create a ClientConn + cc, err := grpc.Dial("xds:///myservice", grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + t.Fatalf("failed to dial local test server: %v", err) + } + defer cc.Close() + + client := testpb.NewTestServiceClient(cc) + + streams := make(chan testpb.TestService_FullDuplexCallClient, 5) // startStream() is called 5 times + startStream := func() { + str, err := client.FullDuplexCall(ctx) + if err != nil { + t.Error("RPC error:", err) + } + streams <- str + } + endStream := func() { + str := <-streams + str.CloseSend() + if _, err := str.Recv(); err != io.EOF { + t.Error("stream error:", err) + } + } + releaseStream := func() { + timer := <-timers + timer.Reset(0) + } + + // Start three streams; two should delay. + go startStream() + go startStream() + go startStream() + + // End one of the streams. Ensure the others are blocked on creation. + endStream() + + select { + case <-streams: + t.Errorf("unexpected second stream created before delay expires") + case <-time.After(50 * time.Millisecond): + // Wait a short time to ensure no other streams were started yet. + } + + // Start one more; it should not be blocked. 
+ go startStream() + endStream() + + // Expire one stream's delay; it should be created. + releaseStream() + endStream() + + // Another new stream should delay. + go startStream() + select { + case <-streams: + t.Errorf("unexpected second stream created before delay expires") + case <-time.After(50 * time.Millisecond): + // Wait a short time to ensure no other streams were started yet. + } + + // Expire both pending timers and end the two streams. + releaseStream() + releaseStream() + endStream() + endStream() +} diff --git a/xds/internal/httpfilter/httpfilter.go b/xds/internal/httpfilter/httpfilter.go index 6650241fab7..b4399f9faeb 100644 --- a/xds/internal/httpfilter/httpfilter.go +++ b/xds/internal/httpfilter/httpfilter.go @@ -50,6 +50,9 @@ type Filter interface { // not accept a custom type. The resulting FilterConfig will later be // passed to Build. ParseFilterConfigOverride(proto.Message) (FilterConfig, error) + // IsTerminal returns whether this Filter is terminal or not (i.e. it must + // be last filter in the filter chain). + IsTerminal() bool } // ClientInterceptorBuilder constructs a Client Interceptor. If this type is @@ -65,9 +68,6 @@ type ClientInterceptorBuilder interface { // ServerInterceptorBuilder constructs a Server Interceptor. If this type is // implemented by a Filter, it is capable of working on a server. -// -// Server side filters are not currently supported, but this interface is -// defined for clarity. type ServerInterceptorBuilder interface { // BuildServerInterceptor uses the FilterConfigs produced above to produce // an HTTP filter interceptor for servers. config will always be non-nil, @@ -94,6 +94,11 @@ func Register(b Filter) { } } +// UnregisterForTesting unregisters the HTTP Filter for testing purposes. +func UnregisterForTesting(typeURL string) { + delete(m, typeURL) +} + // Get returns the HTTPFilter registered with typeURL. // // If no filter is register with typeURL, nil will be returned. 
diff --git a/xds/internal/httpfilter/rbac/rbac.go b/xds/internal/httpfilter/rbac/rbac.go new file mode 100644 index 00000000000..e92e2e64421 --- /dev/null +++ b/xds/internal/httpfilter/rbac/rbac.go @@ -0,0 +1,220 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// Package rbac implements the Envoy RBAC HTTP filter. +package rbac + +import ( + "context" + "errors" + "fmt" + "strings" + + "github.com/golang/protobuf/proto" + "github.com/golang/protobuf/ptypes" + "google.golang.org/grpc/internal/resolver" + "google.golang.org/grpc/internal/xds/env" + "google.golang.org/grpc/internal/xds/rbac" + "google.golang.org/grpc/xds/internal/httpfilter" + "google.golang.org/protobuf/types/known/anypb" + + v3rbacpb "github.com/envoyproxy/go-control-plane/envoy/config/rbac/v3" + rpb "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/http/rbac/v3" +) + +func init() { + if env.RBACSupport { + httpfilter.Register(builder{}) + } +} + +// RegisterForTesting registers the RBAC HTTP Filter for testing purposes, regardless +// of the RBAC environment variable. This is needed because there is no way to set the RBAC +// environment variable to true in a test before init() in this package is run. +func RegisterForTesting() { + httpfilter.Register(builder{}) +} + +// UnregisterForTesting unregisters the RBAC HTTP Filter for testing purposes. 
This is needed because +// there is no way to unregister the HTTP Filter after registering it solely for testing purposes using +// rbac.RegisterForTesting() +func UnregisterForTesting() { + for _, typeURL := range builder.TypeURLs(builder{}) { + httpfilter.UnregisterForTesting(typeURL) + } +} + +type builder struct { +} + +type config struct { + httpfilter.FilterConfig + config *rpb.RBAC +} + +func (builder) TypeURLs() []string { + return []string{ + "type.googleapis.com/envoy.extensions.filters.http.rbac.v3.RBAC", + "type.googleapis.com/envoy.extensions.filters.http.rbac.v3.RBACPerRoute", + } +} + +// Parsing is the same for the base config and the override config. +func parseConfig(rbacCfg *rpb.RBAC) (httpfilter.FilterConfig, error) { + // All the validation logic described in A41. + for _, policy := range rbacCfg.GetRules().GetPolicies() { + // "Policy.condition and Policy.checked_condition must cause a + // validation failure if present." - A41 + if policy.Condition != nil { + return nil, errors.New("rbac: Policy.condition is present") + } + if policy.CheckedCondition != nil { + return nil, errors.New("rbac: policy.CheckedCondition is present") + } + + // "It is also a validation failure if Permission or Principal has a + // header matcher for a grpc- prefixed header name or :scheme." 
- A41 + for _, principal := range policy.Principals { + if principal.GetHeader() != nil { + name := principal.GetHeader().GetName() + if name == ":scheme" || strings.HasPrefix(name, "grpc-") { + return nil, fmt.Errorf("rbac: principal header matcher for %v is :scheme or starts with grpc", name) + } + } + } + for _, permission := range policy.Permissions { + if permission.GetHeader() != nil { + name := permission.GetHeader().GetName() + if name == ":scheme" || strings.HasPrefix(name, "grpc-") { + return nil, fmt.Errorf("rbac: permission header matcher for %v is :scheme or starts with grpc", name) + } + } + } + } + return config{config: rbacCfg}, nil +} + +func (builder) ParseFilterConfig(cfg proto.Message) (httpfilter.FilterConfig, error) { + if cfg == nil { + return nil, fmt.Errorf("rbac: nil configuration message provided") + } + any, ok := cfg.(*anypb.Any) + if !ok { + return nil, fmt.Errorf("rbac: error parsing config %v: unknown type %T", cfg, cfg) + } + msg := new(rpb.RBAC) + if err := ptypes.UnmarshalAny(any, msg); err != nil { + return nil, fmt.Errorf("rbac: error parsing config %v: %v", cfg, err) + } + return parseConfig(msg) +} + +func (builder) ParseFilterConfigOverride(override proto.Message) (httpfilter.FilterConfig, error) { + if override == nil { + return nil, fmt.Errorf("rbac: nil configuration message provided") + } + any, ok := override.(*anypb.Any) + if !ok { + return nil, fmt.Errorf("rbac: error parsing override config %v: unknown type %T", override, override) + } + msg := new(rpb.RBACPerRoute) + if err := ptypes.UnmarshalAny(any, msg); err != nil { + return nil, fmt.Errorf("rbac: error parsing override config %v: %v", override, err) + } + return parseConfig(msg.Rbac) +} + +func (builder) IsTerminal() bool { + return false +} + +var _ httpfilter.ServerInterceptorBuilder = builder{} + +// BuildServerInterceptor is an optional interface builder implements in order +// to signify it works server side. 
+func (builder) BuildServerInterceptor(cfg httpfilter.FilterConfig, override httpfilter.FilterConfig) (resolver.ServerInterceptor, error) { + if cfg == nil { + return nil, fmt.Errorf("rbac: nil config provided") + } + + c, ok := cfg.(config) + if !ok { + return nil, fmt.Errorf("rbac: incorrect config type provided (%T): %v", cfg, cfg) + } + + if override != nil { + // override completely replaces the listener configuration; but we + // still validate the listener config type. + c, ok = override.(config) + if !ok { + return nil, fmt.Errorf("rbac: incorrect override config type provided (%T): %v", override, override) + } + } + + icfg := c.config + // "If absent, no enforcing RBAC policy will be applied" - RBAC + // Documentation for Rules field. + if icfg.Rules == nil { + return nil, nil + } + + // "At this time, if the RBAC.action is Action.LOG then the policy will be + // completely ignored, as if RBAC was not configurated." - A41 + if icfg.Rules.Action == v3rbacpb.RBAC_LOG { + return nil, nil + } + + // "Envoy aliases :authority and Host in its header map implementation, so + // they should be treated equivalent for the RBAC matchers; there must be no + // behavior change depending on which of the two header names is used in the + // RBAC policy." - A41. Loop through config's principals and policies, change + // any header matcher with value "host" to :authority", as that is what + // grpc-go shifts both headers to in transport layer. 
+ for _, policy := range icfg.Rules.GetPolicies() { + for _, principal := range policy.Principals { + if principal.GetHeader() != nil { + name := principal.GetHeader().GetName() + if name == "host" { + principal.GetHeader().Name = ":authority" + } + } + } + for _, permission := range policy.Permissions { + if permission.GetHeader() != nil { + name := permission.GetHeader().GetName() + if name == "host" { + permission.GetHeader().Name = ":authority" + } + } + } + } + + ce, err := rbac.NewChainEngine([]*v3rbacpb.RBAC{icfg.Rules}) + if err != nil { + return nil, fmt.Errorf("error constructing matching engine: %v", err) + } + return &interceptor{chainEngine: ce}, nil +} + +type interceptor struct { + chainEngine *rbac.ChainEngine +} + +func (i *interceptor) AllowRPC(ctx context.Context) error { + return i.chainEngine.IsAuthorized(ctx) +} diff --git a/xds/internal/httpfilter/router/router.go b/xds/internal/httpfilter/router/router.go index 26e3acb5a4f..1ac6518170f 100644 --- a/xds/internal/httpfilter/router/router.go +++ b/xds/internal/httpfilter/router/router.go @@ -73,7 +73,14 @@ func (builder) ParseFilterConfigOverride(override proto.Message) (httpfilter.Fil return config{}, nil } -var _ httpfilter.ClientInterceptorBuilder = builder{} +func (builder) IsTerminal() bool { + return true +} + +var ( + _ httpfilter.ClientInterceptorBuilder = builder{} + _ httpfilter.ServerInterceptorBuilder = builder{} +) func (builder) BuildClientInterceptor(cfg, override httpfilter.FilterConfig) (iresolver.ClientInterceptor, error) { if _, ok := cfg.(config); !ok { @@ -88,6 +95,18 @@ func (builder) BuildClientInterceptor(cfg, override httpfilter.FilterConfig) (ir return nil, nil } +func (builder) BuildServerInterceptor(cfg, override httpfilter.FilterConfig) (iresolver.ServerInterceptor, error) { + if _, ok := cfg.(config); !ok { + return nil, fmt.Errorf("router: incorrect config type provided (%T): %v", cfg, cfg) + } + if override != nil { + return nil, fmt.Errorf("router: unexpected 
override configuration specified: %v", override) + } + // The gRPC router is currently unimplemented on the server side. So we + // return a nil HTTPFilter, which will not be invoked. + return nil, nil +} + // The gRPC router filter does not currently support any configuration. Verify // type only. type config struct { diff --git a/xds/internal/internal.go b/xds/internal/internal.go index e4284ee02e0..0cccd382410 100644 --- a/xds/internal/internal.go +++ b/xds/internal/internal.go @@ -22,6 +22,8 @@ package internal import ( "encoding/json" "fmt" + + "google.golang.org/grpc/resolver" ) // LocalityID is xds.Locality without XXX fields, so it can be used as map @@ -53,3 +55,19 @@ func LocalityIDFromString(s string) (ret LocalityID, _ error) { } return ret, nil } + +type localityKeyType string + +const localityKey = localityKeyType("grpc.xds.internal.address.locality") + +// GetLocalityID returns the locality ID of addr. +func GetLocalityID(addr resolver.Address) LocalityID { + path, _ := addr.Attributes.Value(localityKey).(LocalityID) + return path +} + +// SetLocalityID sets locality ID in addr to l. +func SetLocalityID(addr resolver.Address, l LocalityID) resolver.Address { + addr.Attributes = addr.Attributes.WithValues(localityKey, l) + return addr +} diff --git a/xds/internal/resolver/matcher.go b/xds/internal/resolver/matcher.go deleted file mode 100644 index b7b5f3db0e3..00000000000 --- a/xds/internal/resolver/matcher.go +++ /dev/null @@ -1,161 +0,0 @@ -/* - * - * Copyright 2020 gRPC authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - * - */ - -package resolver - -import ( - "fmt" - "regexp" - "strings" - - "google.golang.org/grpc/internal/grpcrand" - "google.golang.org/grpc/internal/grpcutil" - iresolver "google.golang.org/grpc/internal/resolver" - "google.golang.org/grpc/metadata" - xdsclient "google.golang.org/grpc/xds/internal/client" -) - -func routeToMatcher(r *xdsclient.Route) (*compositeMatcher, error) { - var pathMatcher pathMatcherInterface - switch { - case r.Regex != nil: - re, err := regexp.Compile(*r.Regex) - if err != nil { - return nil, fmt.Errorf("failed to compile regex %q", *r.Regex) - } - pathMatcher = newPathRegexMatcher(re) - case r.Path != nil: - pathMatcher = newPathExactMatcher(*r.Path, r.CaseInsensitive) - case r.Prefix != nil: - pathMatcher = newPathPrefixMatcher(*r.Prefix, r.CaseInsensitive) - default: - return nil, fmt.Errorf("illegal route: missing path_matcher") - } - - var headerMatchers []headerMatcherInterface - for _, h := range r.Headers { - var matcherT headerMatcherInterface - switch { - case h.ExactMatch != nil && *h.ExactMatch != "": - matcherT = newHeaderExactMatcher(h.Name, *h.ExactMatch) - case h.RegexMatch != nil && *h.RegexMatch != "": - re, err := regexp.Compile(*h.RegexMatch) - if err != nil { - return nil, fmt.Errorf("failed to compile regex %q, skipping this matcher", *h.RegexMatch) - } - matcherT = newHeaderRegexMatcher(h.Name, re) - case h.PrefixMatch != nil && *h.PrefixMatch != "": - matcherT = newHeaderPrefixMatcher(h.Name, *h.PrefixMatch) - case h.SuffixMatch != nil && *h.SuffixMatch != "": - matcherT = newHeaderSuffixMatcher(h.Name, *h.SuffixMatch) - case h.RangeMatch != nil: - matcherT = newHeaderRangeMatcher(h.Name, h.RangeMatch.Start, h.RangeMatch.End) - case h.PresentMatch != nil: - matcherT = newHeaderPresentMatcher(h.Name, *h.PresentMatch) - default: - return nil, fmt.Errorf("illegal route: missing header_match_specifier") - } - 
if h.InvertMatch != nil && *h.InvertMatch { - matcherT = newInvertMatcher(matcherT) - } - headerMatchers = append(headerMatchers, matcherT) - } - - var fractionMatcher *fractionMatcher - if r.Fraction != nil { - fractionMatcher = newFractionMatcher(*r.Fraction) - } - return newCompositeMatcher(pathMatcher, headerMatchers, fractionMatcher), nil -} - -// compositeMatcher.match returns true if all matchers return true. -type compositeMatcher struct { - pm pathMatcherInterface - hms []headerMatcherInterface - fm *fractionMatcher -} - -func newCompositeMatcher(pm pathMatcherInterface, hms []headerMatcherInterface, fm *fractionMatcher) *compositeMatcher { - return &compositeMatcher{pm: pm, hms: hms, fm: fm} -} - -func (a *compositeMatcher) match(info iresolver.RPCInfo) bool { - if a.pm != nil && !a.pm.match(info.Method) { - return false - } - - // Call headerMatchers even if md is nil, because routes may match - // non-presence of some headers. - var md metadata.MD - if info.Context != nil { - md, _ = metadata.FromOutgoingContext(info.Context) - if extraMD, ok := grpcutil.ExtraMetadata(info.Context); ok { - md = metadata.Join(md, extraMD) - // Remove all binary headers. They are hard to match with. May need - // to add back if asked by users. - for k := range md { - if strings.HasSuffix(k, "-bin") { - delete(md, k) - } - } - } - } - for _, m := range a.hms { - if !m.match(md) { - return false - } - } - - if a.fm != nil && !a.fm.match() { - return false - } - return true -} - -func (a *compositeMatcher) String() string { - var ret string - if a.pm != nil { - ret += a.pm.String() - } - for _, m := range a.hms { - ret += m.String() - } - if a.fm != nil { - ret += a.fm.String() - } - return ret -} - -type fractionMatcher struct { - fraction int64 // real fraction is fraction/1,000,000. 
-} - -func newFractionMatcher(fraction uint32) *fractionMatcher { - return &fractionMatcher{fraction: int64(fraction)} -} - -var grpcrandInt63n = grpcrand.Int63n - -func (fm *fractionMatcher) match() bool { - t := grpcrandInt63n(1000000) - return t <= fm.fraction -} - -func (fm *fractionMatcher) String() string { - return fmt.Sprintf("fraction:%v", fm.fraction) -} diff --git a/xds/internal/resolver/matcher_header.go b/xds/internal/resolver/matcher_header.go deleted file mode 100644 index 05a92788d7b..00000000000 --- a/xds/internal/resolver/matcher_header.go +++ /dev/null @@ -1,188 +0,0 @@ -/* - * - * Copyright 2020 gRPC authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - */ - -package resolver - -import ( - "fmt" - "regexp" - "strconv" - "strings" - - "google.golang.org/grpc/metadata" -) - -type headerMatcherInterface interface { - match(metadata.MD) bool - String() string -} - -// mdValuesFromOutgoingCtx retrieves metadata from context. If there are -// multiple values, the values are concatenated with "," (comma and no space). -// -// All header matchers only match against the comma-concatenated string. 
-func mdValuesFromOutgoingCtx(md metadata.MD, key string) (string, bool) { - vs, ok := md[key] - if !ok { - return "", false - } - return strings.Join(vs, ","), true -} - -type headerExactMatcher struct { - key string - exact string -} - -func newHeaderExactMatcher(key, exact string) *headerExactMatcher { - return &headerExactMatcher{key: key, exact: exact} -} - -func (hem *headerExactMatcher) match(md metadata.MD) bool { - v, ok := mdValuesFromOutgoingCtx(md, hem.key) - if !ok { - return false - } - return v == hem.exact -} - -func (hem *headerExactMatcher) String() string { - return fmt.Sprintf("headerExact:%v:%v", hem.key, hem.exact) -} - -type headerRegexMatcher struct { - key string - re *regexp.Regexp -} - -func newHeaderRegexMatcher(key string, re *regexp.Regexp) *headerRegexMatcher { - return &headerRegexMatcher{key: key, re: re} -} - -func (hrm *headerRegexMatcher) match(md metadata.MD) bool { - v, ok := mdValuesFromOutgoingCtx(md, hrm.key) - if !ok { - return false - } - return hrm.re.MatchString(v) -} - -func (hrm *headerRegexMatcher) String() string { - return fmt.Sprintf("headerRegex:%v:%v", hrm.key, hrm.re.String()) -} - -type headerRangeMatcher struct { - key string - start, end int64 // represents [start, end). 
-} - -func newHeaderRangeMatcher(key string, start, end int64) *headerRangeMatcher { - return &headerRangeMatcher{key: key, start: start, end: end} -} - -func (hrm *headerRangeMatcher) match(md metadata.MD) bool { - v, ok := mdValuesFromOutgoingCtx(md, hrm.key) - if !ok { - return false - } - if i, err := strconv.ParseInt(v, 10, 64); err == nil && i >= hrm.start && i < hrm.end { - return true - } - return false -} - -func (hrm *headerRangeMatcher) String() string { - return fmt.Sprintf("headerRange:%v:[%d,%d)", hrm.key, hrm.start, hrm.end) -} - -type headerPresentMatcher struct { - key string - present bool -} - -func newHeaderPresentMatcher(key string, present bool) *headerPresentMatcher { - return &headerPresentMatcher{key: key, present: present} -} - -func (hpm *headerPresentMatcher) match(md metadata.MD) bool { - vs, ok := mdValuesFromOutgoingCtx(md, hpm.key) - present := ok && len(vs) > 0 - return present == hpm.present -} - -func (hpm *headerPresentMatcher) String() string { - return fmt.Sprintf("headerPresent:%v:%v", hpm.key, hpm.present) -} - -type headerPrefixMatcher struct { - key string - prefix string -} - -func newHeaderPrefixMatcher(key string, prefix string) *headerPrefixMatcher { - return &headerPrefixMatcher{key: key, prefix: prefix} -} - -func (hpm *headerPrefixMatcher) match(md metadata.MD) bool { - v, ok := mdValuesFromOutgoingCtx(md, hpm.key) - if !ok { - return false - } - return strings.HasPrefix(v, hpm.prefix) -} - -func (hpm *headerPrefixMatcher) String() string { - return fmt.Sprintf("headerPrefix:%v:%v", hpm.key, hpm.prefix) -} - -type headerSuffixMatcher struct { - key string - suffix string -} - -func newHeaderSuffixMatcher(key string, suffix string) *headerSuffixMatcher { - return &headerSuffixMatcher{key: key, suffix: suffix} -} - -func (hsm *headerSuffixMatcher) match(md metadata.MD) bool { - v, ok := mdValuesFromOutgoingCtx(md, hsm.key) - if !ok { - return false - } - return strings.HasSuffix(v, hsm.suffix) -} - -func (hsm 
*headerSuffixMatcher) String() string { - return fmt.Sprintf("headerSuffix:%v:%v", hsm.key, hsm.suffix) -} - -type invertMatcher struct { - m headerMatcherInterface -} - -func newInvertMatcher(m headerMatcherInterface) *invertMatcher { - return &invertMatcher{m: m} -} - -func (i *invertMatcher) match(md metadata.MD) bool { - return !i.m.match(md) -} - -func (i *invertMatcher) String() string { - return fmt.Sprintf("invert{%s}", i.m) -} diff --git a/xds/internal/resolver/serviceconfig.go b/xds/internal/resolver/serviceconfig.go index 7c1ec853e4c..ddf699f938b 100644 --- a/xds/internal/resolver/serviceconfig.go +++ b/xds/internal/resolver/serviceconfig.go @@ -22,18 +22,25 @@ import ( "context" "encoding/json" "fmt" + "math/bits" + "strings" "sync/atomic" "time" + xxhash "github.com/cespare/xxhash/v2" "google.golang.org/grpc/codes" + "google.golang.org/grpc/internal/grpcrand" iresolver "google.golang.org/grpc/internal/resolver" + "google.golang.org/grpc/internal/serviceconfig" "google.golang.org/grpc/internal/wrr" + "google.golang.org/grpc/internal/xds/env" + "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" "google.golang.org/grpc/xds/internal/balancer/clustermanager" - xdsclient "google.golang.org/grpc/xds/internal/client" - "google.golang.org/grpc/xds/internal/env" + "google.golang.org/grpc/xds/internal/balancer/ringhash" "google.golang.org/grpc/xds/internal/httpfilter" "google.golang.org/grpc/xds/internal/httpfilter/router" + "google.golang.org/grpc/xds/internal/xdsclient" ) const ( @@ -76,7 +83,7 @@ func (r *xdsResolver) pruneActiveClusters() { // serviceConfigJSON produces a service config in JSON format representing all // the clusters referenced in activeClusters. This includes clusters with zero // references, so they must be pruned first. 
-func serviceConfigJSON(activeClusters map[string]*clusterInfo) (string, error) { +func serviceConfigJSON(activeClusters map[string]*clusterInfo) ([]byte, error) { // Generate children (all entries in activeClusters). children := make(map[string]xdsChildConfig) for cluster := range activeClusters { @@ -93,14 +100,16 @@ func serviceConfigJSON(activeClusters map[string]*clusterInfo) (string, error) { bs, err := json.Marshal(sc) if err != nil { - return "", fmt.Errorf("failed to marshal json: %v", err) + return nil, fmt.Errorf("failed to marshal json: %v", err) } - return string(bs), nil + return bs, nil } type virtualHost struct { // map from filter name to its config httpFilterConfigOverride map[string]httpfilter.FilterConfig + // retry policy present in virtual host + retryConfig *xdsclient.RetryConfig } // routeCluster holds information about a cluster as referenced by a route. @@ -111,11 +120,13 @@ type routeCluster struct { } type route struct { - m *compositeMatcher // converted from route matchers - clusters wrr.WRR // holds *routeCluster entries + m *xdsclient.CompositeMatcher // converted from route matchers + clusters wrr.WRR // holds *routeCluster entries maxStreamDuration time.Duration // map from filter name to its config httpFilterConfigOverride map[string]httpfilter.FilterConfig + retryConfig *xdsclient.RetryConfig + hashPolicies []*xdsclient.HashPolicy } func (r route) String() string { @@ -139,7 +150,7 @@ func (cs *configSelector) SelectConfig(rpcInfo iresolver.RPCInfo) (*iresolver.RP var rt *route // Loop through routes in order and select first match. for _, r := range cs.routes { - if r.m.match(rpcInfo) { + if r.m.Match(rpcInfo) { rt = &r break } @@ -161,9 +172,15 @@ func (cs *configSelector) SelectConfig(rpcInfo iresolver.RPCInfo) (*iresolver.RP return nil, err } + lbCtx := clustermanager.SetPickedCluster(rpcInfo.Context, cluster.name) + // Request Hashes are only applicable for a Ring Hash LB. 
+ if env.RingHashSupport { + lbCtx = ringhash.SetRequestHash(lbCtx, cs.generateHash(rpcInfo, rt.hashPolicies)) + } + config := &iresolver.RPCConfig{ - // Communicate to the LB policy the chosen cluster. - Context: clustermanager.SetPickedCluster(rpcInfo.Context, cluster.name), + // Communicate to the LB policy the chosen cluster and request hash, if Ring Hash LB policy. + Context: lbCtx, OnCommitted: func() { // When the RPC is committed, the cluster is no longer required. // Decrease its ref. @@ -179,13 +196,83 @@ func (cs *configSelector) SelectConfig(rpcInfo iresolver.RPCInfo) (*iresolver.RP Interceptor: interceptor, } - if env.TimeoutSupport && rt.maxStreamDuration != 0 { + if rt.maxStreamDuration != 0 { config.MethodConfig.Timeout = &rt.maxStreamDuration } + if rt.retryConfig != nil { + config.MethodConfig.RetryPolicy = retryConfigToPolicy(rt.retryConfig) + } else if cs.virtualHost.retryConfig != nil { + config.MethodConfig.RetryPolicy = retryConfigToPolicy(cs.virtualHost.retryConfig) + } return config, nil } +func retryConfigToPolicy(config *xdsclient.RetryConfig) *serviceconfig.RetryPolicy { + return &serviceconfig.RetryPolicy{ + MaxAttempts: int(config.NumRetries) + 1, + InitialBackoff: config.RetryBackoff.BaseInterval, + MaxBackoff: config.RetryBackoff.MaxInterval, + BackoffMultiplier: 2, + RetryableStatusCodes: config.RetryOn, + } +} + +func (cs *configSelector) generateHash(rpcInfo iresolver.RPCInfo, hashPolicies []*xdsclient.HashPolicy) uint64 { + var hash uint64 + var generatedHash bool + for _, policy := range hashPolicies { + var policyHash uint64 + var generatedPolicyHash bool + switch policy.HashPolicyType { + case xdsclient.HashPolicyTypeHeader: + md, ok := metadata.FromOutgoingContext(rpcInfo.Context) + if !ok { + continue + } + values := md.Get(policy.HeaderName) + // If the header isn't present, no-op. 
+ if len(values) == 0 { + continue + } + joinedValues := strings.Join(values, ",") + if policy.Regex != nil { + joinedValues = policy.Regex.ReplaceAllString(joinedValues, policy.RegexSubstitution) + } + policyHash = xxhash.Sum64String(joinedValues) + generatedHash = true + generatedPolicyHash = true + case xdsclient.HashPolicyTypeChannelID: + // Hash the ClientConn pointer which logically uniquely + // identifies the client. + policyHash = xxhash.Sum64String(fmt.Sprintf("%p", &cs.r.cc)) + generatedHash = true + generatedPolicyHash = true + } + + // Deterministically combine the hash policies. Rotating prevents + // duplicate hash policies from cancelling each other out and preserves + // the 64 bits of entropy. + if generatedPolicyHash { + hash = bits.RotateLeft64(hash, 1) + hash = hash ^ policyHash + } + + // If terminal policy and a hash has already been generated, ignore the + // rest of the policies and use that hash already generated. + if policy.Terminal && generatedHash { + break + } + } + + if generatedHash { + return hash + } + // If no generated hash return a random long. In the grand scheme of things + // this logically will map to choosing a random backend to route request to. + return grpcrand.Uint64() +} + func (cs *configSelector) newInterceptor(rt *route, cluster *routeCluster) (iresolver.ClientInterceptor, error) { if len(cs.httpFilterConfig) == 0 { return nil, nil @@ -254,8 +341,11 @@ var newWRR = wrr.NewRandom // r.activeClusters for previously-unseen clusters. 
func (r *xdsResolver) newConfigSelector(su serviceUpdate) (*configSelector, error) { cs := &configSelector{ - r: r, - virtualHost: virtualHost{httpFilterConfigOverride: su.virtualHost.HTTPFilterConfigOverride}, + r: r, + virtualHost: virtualHost{ + httpFilterConfigOverride: su.virtualHost.HTTPFilterConfigOverride, + retryConfig: su.virtualHost.RetryConfig, + }, routes: make([]route, len(su.virtualHost.Routes)), clusters: make(map[string]*clusterInfo), httpFilterConfig: su.ldsConfig.httpFilterConfig, @@ -282,7 +372,7 @@ func (r *xdsResolver) newConfigSelector(su serviceUpdate) (*configSelector, erro cs.routes[i].clusters = clusters var err error - cs.routes[i].m, err = routeToMatcher(rt) + cs.routes[i].m, err = xdsclient.RouteToMatcher(rt) if err != nil { return nil, err } @@ -293,6 +383,8 @@ func (r *xdsResolver) newConfigSelector(su serviceUpdate) (*configSelector, erro } cs.routes[i].httpFilterConfigOverride = rt.HTTPFilterConfigOverride + cs.routes[i].retryConfig = rt.RetryConfig + cs.routes[i].hashPolicies = rt.HashPolicies } // Account for this config selector's clusters. 
Do this after no further diff --git a/xds/internal/resolver/serviceconfig_test.go b/xds/internal/resolver/serviceconfig_test.go index 1e253841e80..a1a48944dc4 100644 --- a/xds/internal/resolver/serviceconfig_test.go +++ b/xds/internal/resolver/serviceconfig_test.go @@ -19,10 +19,17 @@ package resolver import ( + "context" + "fmt" + "regexp" "testing" + xxhash "github.com/cespare/xxhash/v2" "github.com/google/go-cmp/cmp" + iresolver "google.golang.org/grpc/internal/resolver" + "google.golang.org/grpc/metadata" _ "google.golang.org/grpc/xds/internal/balancer/cdsbalancer" // To parse LB config + "google.golang.org/grpc/xds/internal/xdsclient" ) func (s) TestPruneActiveClusters(t *testing.T) { @@ -41,3 +48,70 @@ func (s) TestPruneActiveClusters(t *testing.T) { t.Fatalf("r.activeClusters = %v; want %v\nDiffs: %v", r.activeClusters, want, d) } } + +func (s) TestGenerateRequestHash(t *testing.T) { + cs := &configSelector{ + r: &xdsResolver{ + cc: &testClientConn{}, + }, + } + tests := []struct { + name string + hashPolicies []*xdsclient.HashPolicy + requestHashWant uint64 + rpcInfo iresolver.RPCInfo + }{ + // TestGenerateRequestHashHeaders tests generating request hashes for + // hash policies that specify to hash headers. + { + name: "test-generate-request-hash-headers", + hashPolicies: []*xdsclient.HashPolicy{{ + HashPolicyType: xdsclient.HashPolicyTypeHeader, + HeaderName: ":path", + Regex: func() *regexp.Regexp { return regexp.MustCompile("/products") }(), // Will replace /products with /new-products, to test find and replace functionality. 
+ RegexSubstitution: "/new-products", + }}, + requestHashWant: xxhash.Sum64String("/new-products"), + rpcInfo: iresolver.RPCInfo{ + Context: metadata.NewOutgoingContext(context.Background(), metadata.Pairs(":path", "/products")), + Method: "/some-method", + }, + }, + // TestGenerateHashChannelID tests generating request hashes for hash + // policies that specify to hash something that uniquely identifies the + // ClientConn (the pointer). + { + name: "test-generate-request-hash-channel-id", + hashPolicies: []*xdsclient.HashPolicy{{ + HashPolicyType: xdsclient.HashPolicyTypeChannelID, + }}, + requestHashWant: xxhash.Sum64String(fmt.Sprintf("%p", &cs.r.cc)), + rpcInfo: iresolver.RPCInfo{}, + }, + // TestGenerateRequestHashEmptyString tests generating request hashes + // for hash policies that specify to hash headers and replace empty + // strings in the headers. + { + name: "test-generate-request-hash-empty-string", + hashPolicies: []*xdsclient.HashPolicy{{ + HashPolicyType: xdsclient.HashPolicyTypeHeader, + HeaderName: ":path", + Regex: func() *regexp.Regexp { return regexp.MustCompile("") }(), + RegexSubstitution: "e", + }}, + requestHashWant: xxhash.Sum64String("eaebece"), + rpcInfo: iresolver.RPCInfo{ + Context: metadata.NewOutgoingContext(context.Background(), metadata.Pairs(":path", "abc")), + Method: "/some-method", + }, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + requestHashGot := cs.generateHash(test.rpcInfo, test.hashPolicies) + if requestHashGot != test.requestHashWant { + t.Fatalf("requestHashGot = %v, requestHashWant = %v", requestHashGot, test.requestHashWant) + } + }) + } +} diff --git a/xds/internal/resolver/watch_service.go b/xds/internal/resolver/watch_service.go index 913ac4ced15..da0bf95f3b9 100644 --- a/xds/internal/resolver/watch_service.go +++ b/xds/internal/resolver/watch_service.go @@ -20,12 +20,12 @@ package resolver import ( "fmt" - "strings" "sync" "time" "google.golang.org/grpc/internal/grpclog" - 
xdsclient "google.golang.org/grpc/xds/internal/client" + "google.golang.org/grpc/internal/pretty" + "google.golang.org/grpc/xds/internal/xdsclient" ) // serviceUpdate contains information received from the LDS/RDS responses which @@ -53,7 +53,7 @@ type ldsConfig struct { // Note that during race (e.g. an xDS response is received while the user is // calling cancel()), there's a small window where the callback can be called // after the watcher is canceled. The caller needs to handle this case. -func watchService(c xdsClientInterface, serviceName string, cb func(serviceUpdate, error), logger *grpclog.PrefixLogger) (cancel func()) { +func watchService(c xdsclient.XDSClient, serviceName string, cb func(serviceUpdate, error), logger *grpclog.PrefixLogger) (cancel func()) { w := &serviceUpdateWatcher{ logger: logger, c: c, @@ -69,7 +69,7 @@ func watchService(c xdsClientInterface, serviceName string, cb func(serviceUpdat // callback at the right time. type serviceUpdateWatcher struct { logger *grpclog.PrefixLogger - c xdsClientInterface + c xdsclient.XDSClient serviceName string ldsCancel func() serviceCb func(serviceUpdate, error) @@ -82,7 +82,7 @@ type serviceUpdateWatcher struct { } func (w *serviceUpdateWatcher) handleLDSResp(update xdsclient.ListenerUpdate, err error) { - w.logger.Infof("received LDS update: %+v, err: %v", update, err) + w.logger.Infof("received LDS update: %+v, err: %v", pretty.ToJSON(update), err) w.mu.Lock() defer w.mu.Unlock() if w.closed { @@ -110,13 +110,37 @@ func (w *serviceUpdateWatcher) handleLDSResp(update xdsclient.ListenerUpdate, er httpFilterConfig: update.HTTPFilters, } + if update.InlineRouteConfig != nil { + // If there was an RDS watch, cancel it. + w.rdsName = "" + if w.rdsCancel != nil { + w.rdsCancel() + w.rdsCancel = nil + } + + // Handle the inline RDS update as if it's from an RDS watch. 
+ w.updateVirtualHostsFromRDS(*update.InlineRouteConfig) + return + } + + // RDS name from update is not an empty string, need RDS to fetch the + // routes. + if w.rdsName == update.RouteConfigName { // If the new RouteConfigName is same as the previous, don't cancel and // restart the RDS watch. // // If the route name did change, then we must wait until the first RDS // update before reporting this LDS config. - w.serviceCb(w.lastUpdate, nil) + if w.lastUpdate.virtualHost != nil { + // We want to send an update with the new fields from the new LDS + // (e.g. max stream duration), and old fields from the the previous + // RDS. + // + // But note that this should only happen when virtual host is set, + // which means an RDS was received. + w.serviceCb(w.lastUpdate, nil) + } return } w.rdsName = update.RouteConfigName @@ -126,8 +150,20 @@ func (w *serviceUpdateWatcher) handleLDSResp(update xdsclient.ListenerUpdate, er w.rdsCancel = w.c.WatchRouteConfig(update.RouteConfigName, w.handleRDSResp) } +func (w *serviceUpdateWatcher) updateVirtualHostsFromRDS(update xdsclient.RouteConfigUpdate) { + matchVh := xdsclient.FindBestMatchingVirtualHost(w.serviceName, update.VirtualHosts) + if matchVh == nil { + // No matching virtual host found. 
+ w.serviceCb(serviceUpdate{}, fmt.Errorf("no matching virtual host found for %q", w.serviceName)) + return + } + + w.lastUpdate.virtualHost = matchVh + w.serviceCb(w.lastUpdate, nil) +} + func (w *serviceUpdateWatcher) handleRDSResp(update xdsclient.RouteConfigUpdate, err error) { - w.logger.Infof("received RDS update: %+v, err: %v", update, err) + w.logger.Infof("received RDS update: %+v, err: %v", pretty.ToJSON(update), err) w.mu.Lock() defer w.mu.Unlock() if w.closed { @@ -142,16 +178,7 @@ func (w *serviceUpdateWatcher) handleRDSResp(update xdsclient.RouteConfigUpdate, w.serviceCb(serviceUpdate{}, err) return } - - matchVh := findBestMatchingVirtualHost(w.serviceName, update.VirtualHosts) - if matchVh == nil { - // No matching virtual host found. - w.serviceCb(serviceUpdate{}, fmt.Errorf("no matching virtual host found for %q", w.serviceName)) - return - } - - w.lastUpdate.virtualHost = matchVh - w.serviceCb(w.lastUpdate, nil) + w.updateVirtualHostsFromRDS(update) } func (w *serviceUpdateWatcher) close() { @@ -164,97 +191,3 @@ func (w *serviceUpdateWatcher) close() { w.rdsCancel = nil } } - -type domainMatchType int - -const ( - domainMatchTypeInvalid domainMatchType = iota - domainMatchTypeUniversal - domainMatchTypePrefix - domainMatchTypeSuffix - domainMatchTypeExact -) - -// Exact > Suffix > Prefix > Universal > Invalid. 
-func (t domainMatchType) betterThan(b domainMatchType) bool { - return t > b -} - -func matchTypeForDomain(d string) domainMatchType { - if d == "" { - return domainMatchTypeInvalid - } - if d == "*" { - return domainMatchTypeUniversal - } - if strings.HasPrefix(d, "*") { - return domainMatchTypeSuffix - } - if strings.HasSuffix(d, "*") { - return domainMatchTypePrefix - } - if strings.Contains(d, "*") { - return domainMatchTypeInvalid - } - return domainMatchTypeExact -} - -func match(domain, host string) (domainMatchType, bool) { - switch typ := matchTypeForDomain(domain); typ { - case domainMatchTypeInvalid: - return typ, false - case domainMatchTypeUniversal: - return typ, true - case domainMatchTypePrefix: - // abc.* - return typ, strings.HasPrefix(host, strings.TrimSuffix(domain, "*")) - case domainMatchTypeSuffix: - // *.123 - return typ, strings.HasSuffix(host, strings.TrimPrefix(domain, "*")) - case domainMatchTypeExact: - return typ, domain == host - default: - return domainMatchTypeInvalid, false - } -} - -// findBestMatchingVirtualHost returns the virtual host whose domains field best -// matches host -// -// The domains field support 4 different matching pattern types: -// - Exact match -// - Suffix match (e.g. “*ABC”) -// - Prefix match (e.g. “ABC*) -// - Universal match (e.g. “*”) -// -// The best match is defined as: -// - A match is better if it’s matching pattern type is better -// - Exact match > suffix match > prefix match > universal match -// - If two matches are of the same pattern type, the longer match is better -// - This is to compare the length of the matching pattern, e.g. 
“*ABCDE” > -// “*ABC” -func findBestMatchingVirtualHost(host string, vHosts []*xdsclient.VirtualHost) *xdsclient.VirtualHost { - var ( - matchVh *xdsclient.VirtualHost - matchType = domainMatchTypeInvalid - matchLen int - ) - for _, vh := range vHosts { - for _, domain := range vh.Domains { - typ, matched := match(domain, host) - if typ == domainMatchTypeInvalid { - // The rds response is invalid. - return nil - } - if matchType.betterThan(typ) || matchType == typ && matchLen >= len(domain) || !matched { - // The previous match has better type, or the previous match has - // better length, or this domain isn't a match. - continue - } - matchVh = vh - matchType = typ - matchLen = len(domain) - } - } - return matchVh -} diff --git a/xds/internal/resolver/watch_service_test.go b/xds/internal/resolver/watch_service_test.go index 705a3d35ae1..1bf65c4d450 100644 --- a/xds/internal/resolver/watch_service_test.go +++ b/xds/internal/resolver/watch_service_test.go @@ -27,57 +27,11 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "google.golang.org/grpc/internal/testutils" - xdsclient "google.golang.org/grpc/xds/internal/client" "google.golang.org/grpc/xds/internal/testutils/fakeclient" + "google.golang.org/grpc/xds/internal/xdsclient" "google.golang.org/protobuf/proto" ) -func (s) TestMatchTypeForDomain(t *testing.T) { - tests := []struct { - d string - want domainMatchType - }{ - {d: "", want: domainMatchTypeInvalid}, - {d: "*", want: domainMatchTypeUniversal}, - {d: "bar.*", want: domainMatchTypePrefix}, - {d: "*.abc.com", want: domainMatchTypeSuffix}, - {d: "foo.bar.com", want: domainMatchTypeExact}, - {d: "foo.*.com", want: domainMatchTypeInvalid}, - } - for _, tt := range tests { - if got := matchTypeForDomain(tt.d); got != tt.want { - t.Errorf("matchTypeForDomain(%q) = %v, want %v", tt.d, got, tt.want) - } - } -} - -func (s) TestMatch(t *testing.T) { - tests := []struct { - name string - domain string - host string - wantTyp 
domainMatchType - wantMatched bool - }{ - {name: "invalid-empty", domain: "", host: "", wantTyp: domainMatchTypeInvalid, wantMatched: false}, - {name: "invalid", domain: "a.*.b", host: "", wantTyp: domainMatchTypeInvalid, wantMatched: false}, - {name: "universal", domain: "*", host: "abc.com", wantTyp: domainMatchTypeUniversal, wantMatched: true}, - {name: "prefix-match", domain: "abc.*", host: "abc.123", wantTyp: domainMatchTypePrefix, wantMatched: true}, - {name: "prefix-no-match", domain: "abc.*", host: "abcd.123", wantTyp: domainMatchTypePrefix, wantMatched: false}, - {name: "suffix-match", domain: "*.123", host: "abc.123", wantTyp: domainMatchTypeSuffix, wantMatched: true}, - {name: "suffix-no-match", domain: "*.123", host: "abc.1234", wantTyp: domainMatchTypeSuffix, wantMatched: false}, - {name: "exact-match", domain: "foo.bar", host: "foo.bar", wantTyp: domainMatchTypeExact, wantMatched: true}, - {name: "exact-no-match", domain: "foo.bar.com", host: "foo.bar", wantTyp: domainMatchTypeExact, wantMatched: false}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if gotTyp, gotMatched := match(tt.domain, tt.host); gotTyp != tt.wantTyp || gotMatched != tt.wantMatched { - t.Errorf("match() = %v, %v, want %v, %v", gotTyp, gotMatched, tt.wantTyp, tt.wantMatched) - } - }) - } -} - func (s) TestFindBestMatchingVirtualHost(t *testing.T) { var ( oneExactMatch = &xdsclient.VirtualHost{ @@ -121,7 +75,7 @@ func (s) TestFindBestMatchingVirtualHost(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if got := findBestMatchingVirtualHost(tt.host, tt.vHosts); !cmp.Equal(got, tt.want, cmp.Comparer(proto.Equal)) { + if got := xdsclient.FindBestMatchingVirtualHost(tt.host, tt.vHosts); !cmp.Equal(got, tt.want, cmp.Comparer(proto.Equal)) { t.Errorf("findBestMatchingxdsclient.VirtualHost() = %v, want %v", got, tt.want) } }) @@ -167,7 +121,7 @@ func (s) TestServiceWatch(t *testing.T) { waitForWatchRouteConfig(ctx, t, xdsC, 
routeStr) wantUpdate := serviceUpdate{virtualHost: &xdsclient.VirtualHost{Domains: []string{"target"}, Routes: []*xdsclient.Route{{Prefix: newStringP(""), WeightedClusters: map[string]xdsclient.WeightedCluster{cluster: {Weight: 1}}}}}} - xdsC.InvokeWatchRouteConfigCallback(xdsclient.RouteConfigUpdate{ + xdsC.InvokeWatchRouteConfigCallback("", xdsclient.RouteConfigUpdate{ VirtualHosts: []*xdsclient.VirtualHost{ { Domains: []string{targetStr}, @@ -185,7 +139,7 @@ func (s) TestServiceWatch(t *testing.T) { WeightedClusters: map[string]xdsclient.WeightedCluster{cluster: {Weight: 1}}, }}, }} - xdsC.InvokeWatchRouteConfigCallback(xdsclient.RouteConfigUpdate{ + xdsC.InvokeWatchRouteConfigCallback("", xdsclient.RouteConfigUpdate{ VirtualHosts: []*xdsclient.VirtualHost{ { Domains: []string{targetStr}, @@ -221,7 +175,7 @@ func (s) TestServiceWatchLDSUpdate(t *testing.T) { waitForWatchRouteConfig(ctx, t, xdsC, routeStr) wantUpdate := serviceUpdate{virtualHost: &xdsclient.VirtualHost{Domains: []string{"target"}, Routes: []*xdsclient.Route{{Prefix: newStringP(""), WeightedClusters: map[string]xdsclient.WeightedCluster{cluster: {Weight: 1}}}}}} - xdsC.InvokeWatchRouteConfigCallback(xdsclient.RouteConfigUpdate{ + xdsC.InvokeWatchRouteConfigCallback("", xdsclient.RouteConfigUpdate{ VirtualHosts: []*xdsclient.VirtualHost{ { Domains: []string{targetStr}, @@ -235,14 +189,14 @@ func (s) TestServiceWatchLDSUpdate(t *testing.T) { // Another LDS update with a different RDS_name. xdsC.InvokeWatchListenerCallback(xdsclient.ListenerUpdate{RouteConfigName: routeStr + "2"}, nil) - if err := xdsC.WaitForCancelRouteConfigWatch(ctx); err != nil { + if _, err := xdsC.WaitForCancelRouteConfigWatch(ctx); err != nil { t.Fatalf("wait for cancel route watch failed: %v, want nil", err) } waitForWatchRouteConfig(ctx, t, xdsC, routeStr+"2") // RDS update for the new name. 
wantUpdate2 := serviceUpdate{virtualHost: &xdsclient.VirtualHost{Domains: []string{"target"}, Routes: []*xdsclient.Route{{Prefix: newStringP(""), WeightedClusters: map[string]xdsclient.WeightedCluster{cluster + "2": {Weight: 1}}}}}} - xdsC.InvokeWatchRouteConfigCallback(xdsclient.RouteConfigUpdate{ + xdsC.InvokeWatchRouteConfigCallback(routeStr+"2", xdsclient.RouteConfigUpdate{ VirtualHosts: []*xdsclient.VirtualHost{ { Domains: []string{targetStr}, @@ -277,7 +231,7 @@ func (s) TestServiceWatchLDSUpdateMaxStreamDuration(t *testing.T) { WeightedClusters: map[string]xdsclient.WeightedCluster{cluster: {Weight: 1}}}}}, ldsConfig: ldsConfig{maxStreamDuration: time.Second}, } - xdsC.InvokeWatchRouteConfigCallback(xdsclient.RouteConfigUpdate{ + xdsC.InvokeWatchRouteConfigCallback("", xdsclient.RouteConfigUpdate{ VirtualHosts: []*xdsclient.VirtualHost{ { Domains: []string{targetStr}, @@ -301,7 +255,7 @@ func (s) TestServiceWatchLDSUpdateMaxStreamDuration(t *testing.T) { Prefix: newStringP(""), WeightedClusters: map[string]xdsclient.WeightedCluster{cluster + "2": {Weight: 1}}}}, }} - xdsC.InvokeWatchRouteConfigCallback(xdsclient.RouteConfigUpdate{ + xdsC.InvokeWatchRouteConfigCallback("", xdsclient.RouteConfigUpdate{ VirtualHosts: []*xdsclient.VirtualHost{ { Domains: []string{targetStr}, @@ -335,7 +289,7 @@ func (s) TestServiceNotCancelRDSOnSameLDSUpdate(t *testing.T) { Prefix: newStringP(""), WeightedClusters: map[string]xdsclient.WeightedCluster{cluster: {Weight: 1}}}}, }} - xdsC.InvokeWatchRouteConfigCallback(xdsclient.RouteConfigUpdate{ + xdsC.InvokeWatchRouteConfigCallback("", xdsclient.RouteConfigUpdate{ VirtualHosts: []*xdsclient.VirtualHost{ { Domains: []string{targetStr}, @@ -352,7 +306,85 @@ func (s) TestServiceNotCancelRDSOnSameLDSUpdate(t *testing.T) { xdsC.InvokeWatchListenerCallback(xdsclient.ListenerUpdate{RouteConfigName: routeStr}, nil) sCtx, sCancel := context.WithTimeout(ctx, defaultTestShortTimeout) defer sCancel() - if err := 
xdsC.WaitForCancelRouteConfigWatch(sCtx); err != context.DeadlineExceeded { + if _, err := xdsC.WaitForCancelRouteConfigWatch(sCtx); err != context.DeadlineExceeded { + t.Fatalf("wait for cancel route watch failed: %v, want nil", err) + } +} + +// TestServiceWatchInlineRDS covers the cases switching between: +// - LDS update contains RDS name to watch +// - LDS update contains inline RDS resource +func (s) TestServiceWatchInlineRDS(t *testing.T) { + serviceUpdateCh := testutils.NewChannel() + xdsC := fakeclient.NewClient() + cancelWatch := watchService(xdsC, targetStr, func(update serviceUpdate, err error) { + serviceUpdateCh.Send(serviceUpdateErr{u: update, err: err}) + }, nil) + defer cancelWatch() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + // First LDS update is LDS with RDS name to watch. + waitForWatchListener(ctx, t, xdsC, targetStr) + xdsC.InvokeWatchListenerCallback(xdsclient.ListenerUpdate{RouteConfigName: routeStr}, nil) + waitForWatchRouteConfig(ctx, t, xdsC, routeStr) + wantUpdate := serviceUpdate{virtualHost: &xdsclient.VirtualHost{Domains: []string{"target"}, Routes: []*xdsclient.Route{{Prefix: newStringP(""), WeightedClusters: map[string]xdsclient.WeightedCluster{cluster: {Weight: 1}}}}}} + xdsC.InvokeWatchRouteConfigCallback("", xdsclient.RouteConfigUpdate{ + VirtualHosts: []*xdsclient.VirtualHost{ + { + Domains: []string{targetStr}, + Routes: []*xdsclient.Route{{Prefix: newStringP(""), WeightedClusters: map[string]xdsclient.WeightedCluster{cluster: {Weight: 1}}}}, + }, + }, + }, nil) + if err := verifyServiceUpdate(ctx, serviceUpdateCh, wantUpdate); err != nil { + t.Fatal(err) + } + + // Switch LDS resp to a LDS with inline RDS resource + wantVirtualHosts2 := &xdsclient.VirtualHost{Domains: []string{"target"}, + Routes: []*xdsclient.Route{{ + Path: newStringP(""), + WeightedClusters: map[string]xdsclient.WeightedCluster{cluster: {Weight: 1}}, + }}, + } + wantUpdate2 := 
serviceUpdate{virtualHost: wantVirtualHosts2} + xdsC.InvokeWatchListenerCallback(xdsclient.ListenerUpdate{InlineRouteConfig: &xdsclient.RouteConfigUpdate{ + VirtualHosts: []*xdsclient.VirtualHost{wantVirtualHosts2}, + }}, nil) + // This inline RDS resource should cause the RDS watch to be canceled. + if _, err := xdsC.WaitForCancelRouteConfigWatch(ctx); err != nil { t.Fatalf("wait for cancel route watch failed: %v, want nil", err) } + if err := verifyServiceUpdate(ctx, serviceUpdateCh, wantUpdate2); err != nil { + t.Fatal(err) + } + + // Switch LDS update back to LDS with RDS name to watch. + xdsC.InvokeWatchListenerCallback(xdsclient.ListenerUpdate{RouteConfigName: routeStr}, nil) + waitForWatchRouteConfig(ctx, t, xdsC, routeStr) + xdsC.InvokeWatchRouteConfigCallback("", xdsclient.RouteConfigUpdate{ + VirtualHosts: []*xdsclient.VirtualHost{ + { + Domains: []string{targetStr}, + Routes: []*xdsclient.Route{{Prefix: newStringP(""), WeightedClusters: map[string]xdsclient.WeightedCluster{cluster: {Weight: 1}}}}, + }, + }, + }, nil) + if err := verifyServiceUpdate(ctx, serviceUpdateCh, wantUpdate); err != nil { + t.Fatal(err) + } + + // Switch LDS resp to a LDS with inline RDS resource again. + xdsC.InvokeWatchListenerCallback(xdsclient.ListenerUpdate{InlineRouteConfig: &xdsclient.RouteConfigUpdate{ + VirtualHosts: []*xdsclient.VirtualHost{wantVirtualHosts2}, + }}, nil) + // This inline RDS resource should cause the RDS watch to be canceled. 
+ if _, err := xdsC.WaitForCancelRouteConfigWatch(ctx); err != nil { + t.Fatalf("wait for cancel route watch failed: %v, want nil", err) + } + if err := verifyServiceUpdate(ctx, serviceUpdateCh, wantUpdate2); err != nil { + t.Fatal(err) + } } diff --git a/xds/internal/resolver/xds_resolver.go b/xds/internal/resolver/xds_resolver.go index d8c09db69b5..19ee01773e8 100644 --- a/xds/internal/resolver/xds_resolver.go +++ b/xds/internal/resolver/xds_resolver.go @@ -26,23 +26,35 @@ import ( "google.golang.org/grpc/credentials" "google.golang.org/grpc/internal/grpclog" "google.golang.org/grpc/internal/grpcsync" - "google.golang.org/grpc/resolver" - "google.golang.org/grpc/xds/internal/client/bootstrap" - + "google.golang.org/grpc/internal/pretty" iresolver "google.golang.org/grpc/internal/resolver" - xdsclient "google.golang.org/grpc/xds/internal/client" + "google.golang.org/grpc/resolver" + "google.golang.org/grpc/xds/internal/xdsclient" ) const xdsScheme = "xds" +// NewBuilder creates a new xds resolver builder using a specific xds bootstrap +// config, so tests can use multiple xds clients in different ClientConns at +// the same time. +func NewBuilder(config []byte) (resolver.Builder, error) { + return &xdsResolverBuilder{ + newXDSClient: func() (xdsclient.XDSClient, error) { + return xdsclient.NewClientWithBootstrapContents(config) + }, + }, nil +} + // For overriding in unittests. -var newXDSClient = func() (xdsClientInterface, error) { return xdsclient.New() } +var newXDSClient = func() (xdsclient.XDSClient, error) { return xdsclient.New() } func init() { resolver.Register(&xdsResolverBuilder{}) } -type xdsResolverBuilder struct{} +type xdsResolverBuilder struct { + newXDSClient func() (xdsclient.XDSClient, error) +} // Build helps implement the resolver.Builder interface. 
// @@ -59,6 +71,11 @@ func (b *xdsResolverBuilder) Build(t resolver.Target, cc resolver.ClientConn, op r.logger = prefixLogger((r)) r.logger.Infof("Creating resolver for target: %+v", t) + newXDSClient := newXDSClient + if b.newXDSClient != nil { + newXDSClient = b.newXDSClient + } + client, err := newXDSClient() if err != nil { return nil, fmt.Errorf("xds: failed to create xds-client: %v", err) @@ -100,15 +117,6 @@ func (*xdsResolverBuilder) Scheme() string { return xdsScheme } -// xdsClientInterface contains methods from xdsClient.Client which are used by -// the resolver. This will be faked out in unittests. -type xdsClientInterface interface { - WatchListener(serviceName string, cb func(xdsclient.ListenerUpdate, error)) func() - WatchRouteConfig(routeName string, cb func(xdsclient.RouteConfigUpdate, error)) func() - BootstrapConfig() *bootstrap.Config - Close() -} - // suWithError wraps the ServiceUpdate and error received through a watch API // callback, so that it can pushed onto the update channel as a single entity. type suWithError struct { @@ -130,7 +138,7 @@ type xdsResolver struct { logger *grpclog.PrefixLogger // The underlying xdsClient which performs all xDS requests and responses. - client xdsClientInterface + client xdsclient.XDSClient // A channel for the watch API callback to write service updates on to. The // updates are read by the run goroutine and passed on to the ClientConn. updateCh chan suWithError @@ -171,13 +179,13 @@ func (r *xdsResolver) sendNewServiceConfig(cs *configSelector) bool { r.cc.ReportError(err) return false } - r.logger.Infof("Received update on resource %v from xds-client %p, generated service config: %v", r.target.Endpoint, r.client, sc) + r.logger.Infof("Received update on resource %v from xds-client %p, generated service config: %v", r.target.Endpoint, r.client, pretty.FormatJSON(sc)) // Send the update to the ClientConn. 
state := iresolver.SetConfigSelector(resolver.State{ - ServiceConfig: r.cc.ParseServiceConfig(sc), + ServiceConfig: r.cc.ParseServiceConfig(string(sc)), }, cs) - r.cc.UpdateState(state) + r.cc.UpdateState(xdsclient.SetClient(state, r.client)) return true } diff --git a/xds/internal/resolver/xds_resolver_test.go b/xds/internal/resolver/xds_resolver_test.go index 53ea17042aa..90e6c1d4db0 100644 --- a/xds/internal/resolver/xds_resolver_test.go +++ b/xds/internal/resolver/xds_resolver_test.go @@ -26,6 +26,7 @@ import ( "testing" "time" + xxhash "github.com/cespare/xxhash/v2" "github.com/google/go-cmp/cmp" "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials/insecure" @@ -36,19 +37,20 @@ import ( iresolver "google.golang.org/grpc/internal/resolver" "google.golang.org/grpc/internal/testutils" "google.golang.org/grpc/internal/wrr" + "google.golang.org/grpc/internal/xds/env" + "google.golang.org/grpc/metadata" "google.golang.org/grpc/resolver" "google.golang.org/grpc/serviceconfig" "google.golang.org/grpc/status" _ "google.golang.org/grpc/xds/internal/balancer/cdsbalancer" // To parse LB config "google.golang.org/grpc/xds/internal/balancer/clustermanager" - "google.golang.org/grpc/xds/internal/client" - xdsclient "google.golang.org/grpc/xds/internal/client" - "google.golang.org/grpc/xds/internal/client/bootstrap" - "google.golang.org/grpc/xds/internal/env" + "google.golang.org/grpc/xds/internal/balancer/ringhash" "google.golang.org/grpc/xds/internal/httpfilter" "google.golang.org/grpc/xds/internal/httpfilter/router" xdstestutils "google.golang.org/grpc/xds/internal/testutils" "google.golang.org/grpc/xds/internal/testutils/fakeclient" + "google.golang.org/grpc/xds/internal/xdsclient" + "google.golang.org/grpc/xds/internal/xdsclient/bootstrap" ) const ( @@ -88,8 +90,9 @@ type testClientConn struct { errorCh *testutils.Channel } -func (t *testClientConn) UpdateState(s resolver.State) { +func (t *testClientConn) UpdateState(s resolver.State) error { 
t.stateCh.Send(s) + return nil } func (t *testClientConn) ReportError(err error) { @@ -112,19 +115,19 @@ func newTestClientConn() *testClientConn { func (s) TestResolverBuilder(t *testing.T) { tests := []struct { name string - xdsClientFunc func() (xdsClientInterface, error) + xdsClientFunc func() (xdsclient.XDSClient, error) wantErr bool }{ { name: "simple-good", - xdsClientFunc: func() (xdsClientInterface, error) { + xdsClientFunc: func() (xdsclient.XDSClient, error) { return fakeclient.NewClient(), nil }, wantErr: false, }, { name: "newXDSClient-throws-error", - xdsClientFunc: func() (xdsClientInterface, error) { + xdsClientFunc: func() (xdsclient.XDSClient, error) { return nil, errors.New("newXDSClient-throws-error") }, wantErr: true, @@ -165,7 +168,7 @@ func (s) TestResolverBuilder_xdsCredsBootstrapMismatch(t *testing.T) { // Fake out the xdsClient creation process by providing a fake, which does // not have any certificate provider configuration. oldClientMaker := newXDSClient - newXDSClient = func() (xdsClientInterface, error) { + newXDSClient = func() (xdsclient.XDSClient, error) { fc := fakeclient.NewClient() fc.SetBootstrapConfig(&bootstrap.Config{}) return fc, nil @@ -192,7 +195,7 @@ func (s) TestResolverBuilder_xdsCredsBootstrapMismatch(t *testing.T) { } type setupOpts struct { - xdsClientFunc func() (xdsClientInterface, error) + xdsClientFunc func() (xdsclient.XDSClient, error) } func testSetup(t *testing.T, opts setupOpts) (*xdsResolver, *testClientConn, func()) { @@ -252,7 +255,7 @@ func waitForWatchRouteConfig(ctx context.Context, t *testing.T, xdsC *fakeclient func (s) TestXDSResolverWatchCallbackAfterClose(t *testing.T) { xdsC := fakeclient.NewClient() xdsR, tcc, cancel := testSetup(t, setupOpts{ - xdsClientFunc: func() (xdsClientInterface, error) { return xdsC, nil }, + xdsClientFunc: func() (xdsclient.XDSClient, error) { return xdsC, nil }, }) defer cancel() @@ -265,11 +268,11 @@ func (s) TestXDSResolverWatchCallbackAfterClose(t *testing.T) { // 
Call the watchAPI callback after closing the resolver, and make sure no // update is triggerred on the ClientConn. xdsR.Close() - xdsC.InvokeWatchRouteConfigCallback(xdsclient.RouteConfigUpdate{ + xdsC.InvokeWatchRouteConfigCallback("", xdsclient.RouteConfigUpdate{ VirtualHosts: []*xdsclient.VirtualHost{ { Domains: []string{targetStr}, - Routes: []*client.Route{{Prefix: newStringP(""), WeightedClusters: map[string]xdsclient.WeightedCluster{cluster: {Weight: 1}}}}, + Routes: []*xdsclient.Route{{Prefix: newStringP(""), WeightedClusters: map[string]xdsclient.WeightedCluster{cluster: {Weight: 1}}}}, }, }, }, nil) @@ -279,17 +282,29 @@ func (s) TestXDSResolverWatchCallbackAfterClose(t *testing.T) { } } +// TestXDSResolverCloseClosesXDSClient tests that the XDS resolver's Close +// method closes the XDS client. +func (s) TestXDSResolverCloseClosesXDSClient(t *testing.T) { + xdsC := fakeclient.NewClient() + xdsR, _, cancel := testSetup(t, setupOpts{ + xdsClientFunc: func() (xdsclient.XDSClient, error) { return xdsC, nil }, + }) + defer cancel() + xdsR.Close() + if !xdsC.Closed.HasFired() { + t.Fatalf("xds client not closed by xds resolver Close method") + } +} + // TestXDSResolverBadServiceUpdate tests the case the xdsClient returns a bad // service update. func (s) TestXDSResolverBadServiceUpdate(t *testing.T) { xdsC := fakeclient.NewClient() xdsR, tcc, cancel := testSetup(t, setupOpts{ - xdsClientFunc: func() (xdsClientInterface, error) { return xdsC, nil }, + xdsClientFunc: func() (xdsclient.XDSClient, error) { return xdsC, nil }, }) - defer func() { - cancel() - xdsR.Close() - }() + defer xdsR.Close() + defer cancel() ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() @@ -300,7 +315,7 @@ func (s) TestXDSResolverBadServiceUpdate(t *testing.T) { // Invoke the watchAPI callback with a bad service update and wait for the // ReportError method to be called on the ClientConn. 
suErr := errors.New("bad serviceupdate") - xdsC.InvokeWatchRouteConfigCallback(xdsclient.RouteConfigUpdate{}, suErr) + xdsC.InvokeWatchRouteConfigCallback("", xdsclient.RouteConfigUpdate{}, suErr) if gotErrVal, gotErr := tcc.errorCh.Receive(ctx); gotErr != nil || gotErrVal != suErr { t.Fatalf("ClientConn.ReportError() received %v, want %v", gotErrVal, suErr) @@ -312,12 +327,10 @@ func (s) TestXDSResolverBadServiceUpdate(t *testing.T) { func (s) TestXDSResolverGoodServiceUpdate(t *testing.T) { xdsC := fakeclient.NewClient() xdsR, tcc, cancel := testSetup(t, setupOpts{ - xdsClientFunc: func() (xdsClientInterface, error) { return xdsC, nil }, + xdsClientFunc: func() (xdsclient.XDSClient, error) { return xdsC, nil }, }) - defer func() { - cancel() - xdsR.Close() - }() + defer xdsR.Close() + defer cancel() ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() @@ -332,7 +345,7 @@ func (s) TestXDSResolverGoodServiceUpdate(t *testing.T) { wantClusters map[string]bool }{ { - routes: []*client.Route{{Prefix: newStringP(""), WeightedClusters: map[string]xdsclient.WeightedCluster{"test-cluster-1": {Weight: 1}}}}, + routes: []*xdsclient.Route{{Prefix: newStringP(""), WeightedClusters: map[string]xdsclient.WeightedCluster{"test-cluster-1": {Weight: 1}}}}, wantJSON: `{"loadBalancingConfig":[{ "xds_cluster_manager_experimental":{ "children":{ @@ -344,7 +357,7 @@ func (s) TestXDSResolverGoodServiceUpdate(t *testing.T) { wantClusters: map[string]bool{"test-cluster-1": true}, }, { - routes: []*client.Route{{Prefix: newStringP(""), WeightedClusters: map[string]xdsclient.WeightedCluster{ + routes: []*xdsclient.Route{{Prefix: newStringP(""), WeightedClusters: map[string]xdsclient.WeightedCluster{ "cluster_1": {Weight: 75}, "cluster_2": {Weight: 25}, }}}, @@ -368,7 +381,7 @@ func (s) TestXDSResolverGoodServiceUpdate(t *testing.T) { wantClusters: map[string]bool{"cluster_1": true, "cluster_2": true}, }, { - routes: []*client.Route{{Prefix: 
newStringP(""), WeightedClusters: map[string]xdsclient.WeightedCluster{ + routes: []*xdsclient.Route{{Prefix: newStringP(""), WeightedClusters: map[string]xdsclient.WeightedCluster{ "cluster_1": {Weight: 75}, "cluster_2": {Weight: 25}, }}}, @@ -391,7 +404,7 @@ func (s) TestXDSResolverGoodServiceUpdate(t *testing.T) { } { // Invoke the watchAPI callback with a good service update and wait for the // UpdateState method to be called on the ClientConn. - xdsC.InvokeWatchRouteConfigCallback(xdsclient.RouteConfigUpdate{ + xdsC.InvokeWatchRouteConfigCallback("", xdsclient.RouteConfigUpdate{ VirtualHosts: []*xdsclient.VirtualHost{ { Domains: []string{targetStr}, @@ -404,7 +417,7 @@ func (s) TestXDSResolverGoodServiceUpdate(t *testing.T) { defer cancel() gotState, err := tcc.stateCh.Receive(ctx) if err != nil { - t.Fatalf("ClientConn.UpdateState returned error: %v", err) + t.Fatalf("Error waiting for UpdateState to be called: %v", err) } rState := gotState.(resolver.State) if err := rState.ServiceConfig.Err; err != nil { @@ -443,12 +456,77 @@ func (s) TestXDSResolverGoodServiceUpdate(t *testing.T) { } } +// TestXDSResolverRequestHash tests a case where a resolver receives a RouteConfig update +// with a HashPolicy specifying to generate a hash. The configSelector generated should +// successfully generate a Hash. 
+func (s) TestXDSResolverRequestHash(t *testing.T) { + oldRH := env.RingHashSupport + env.RingHashSupport = true + defer func() { env.RingHashSupport = oldRH }() + + xdsC := fakeclient.NewClient() + xdsR, tcc, cancel := testSetup(t, setupOpts{ + xdsClientFunc: func() (xdsclient.XDSClient, error) { return xdsC, nil }, + }) + defer xdsR.Close() + defer cancel() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + waitForWatchListener(ctx, t, xdsC, targetStr) + xdsC.InvokeWatchListenerCallback(xdsclient.ListenerUpdate{RouteConfigName: routeStr, HTTPFilters: routerFilterList}, nil) + waitForWatchRouteConfig(ctx, t, xdsC, routeStr) + // Invoke watchAPI callback with a good service update (with hash policies + // specified) and wait for UpdateState method to be called on ClientConn. + xdsC.InvokeWatchRouteConfigCallback("", xdsclient.RouteConfigUpdate{ + VirtualHosts: []*xdsclient.VirtualHost{ + { + Domains: []string{targetStr}, + Routes: []*xdsclient.Route{{ + Prefix: newStringP(""), + WeightedClusters: map[string]xdsclient.WeightedCluster{ + "cluster_1": {Weight: 75}, + "cluster_2": {Weight: 25}, + }, + HashPolicies: []*xdsclient.HashPolicy{{ + HashPolicyType: xdsclient.HashPolicyTypeHeader, + HeaderName: ":path", + }}, + }}, + }, + }, + }, nil) + + ctx, cancel = context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + gotState, err := tcc.stateCh.Receive(ctx) + if err != nil { + t.Fatalf("Error waiting for UpdateState to be called: %v", err) + } + rState := gotState.(resolver.State) + cs := iresolver.GetConfigSelector(rState) + if cs == nil { + t.Error("received nil config selector") + } + // Selecting a config when there was a hash policy specified in the route + // that will be selected should put a request hash in the config's context. 
+ res, err := cs.SelectConfig(iresolver.RPCInfo{Context: metadata.NewOutgoingContext(context.Background(), metadata.Pairs(":path", "/products"))}) + if err != nil { + t.Fatalf("Unexpected error from cs.SelectConfig(_): %v", err) + } + requestHashGot := ringhash.GetRequestHashForTesting(res.Context) + requestHashWant := xxhash.Sum64String("/products") + if requestHashGot != requestHashWant { + t.Fatalf("requestHashGot = %v, requestHashWant = %v", requestHashGot, requestHashWant) + } +} + // TestXDSResolverRemovedWithRPCs tests the case where a config selector sends // an empty update to the resolver after the resource is removed. func (s) TestXDSResolverRemovedWithRPCs(t *testing.T) { xdsC := fakeclient.NewClient() xdsR, tcc, cancel := testSetup(t, setupOpts{ - xdsClientFunc: func() (xdsClientInterface, error) { return xdsC, nil }, + xdsClientFunc: func() (xdsclient.XDSClient, error) { return xdsC, nil }, }) defer cancel() defer xdsR.Close() @@ -461,18 +539,18 @@ func (s) TestXDSResolverRemovedWithRPCs(t *testing.T) { // Invoke the watchAPI callback with a good service update and wait for the // UpdateState method to be called on the ClientConn. 
- xdsC.InvokeWatchRouteConfigCallback(xdsclient.RouteConfigUpdate{ + xdsC.InvokeWatchRouteConfigCallback("", xdsclient.RouteConfigUpdate{ VirtualHosts: []*xdsclient.VirtualHost{ { Domains: []string{targetStr}, - Routes: []*client.Route{{Prefix: newStringP(""), WeightedClusters: map[string]xdsclient.WeightedCluster{"test-cluster-1": {Weight: 1}}}}, + Routes: []*xdsclient.Route{{Prefix: newStringP(""), WeightedClusters: map[string]xdsclient.WeightedCluster{"test-cluster-1": {Weight: 1}}}}, }, }, }, nil) gotState, err := tcc.stateCh.Receive(ctx) if err != nil { - t.Fatalf("ClientConn.UpdateState returned error: %v", err) + t.Fatalf("Error waiting for UpdateState to be called: %v", err) } rState := gotState.(resolver.State) if err := rState.ServiceConfig.Err; err != nil { @@ -492,10 +570,10 @@ func (s) TestXDSResolverRemovedWithRPCs(t *testing.T) { // Delete the resource suErr := xdsclient.NewErrorf(xdsclient.ErrorTypeResourceNotFound, "resource removed error") - xdsC.InvokeWatchRouteConfigCallback(xdsclient.RouteConfigUpdate{}, suErr) + xdsC.InvokeWatchRouteConfigCallback("", xdsclient.RouteConfigUpdate{}, suErr) if _, err = tcc.stateCh.Receive(ctx); err != nil { - t.Fatalf("ClientConn.UpdateState returned error: %v", err) + t.Fatalf("Error waiting for UpdateState to be called: %v", err) } // "Finish the RPC"; this could cause a panic if the resolver doesn't @@ -508,7 +586,7 @@ func (s) TestXDSResolverRemovedWithRPCs(t *testing.T) { func (s) TestXDSResolverRemovedResource(t *testing.T) { xdsC := fakeclient.NewClient() xdsR, tcc, cancel := testSetup(t, setupOpts{ - xdsClientFunc: func() (xdsClientInterface, error) { return xdsC, nil }, + xdsClientFunc: func() (xdsclient.XDSClient, error) { return xdsC, nil }, }) defer cancel() defer xdsR.Close() @@ -521,11 +599,11 @@ func (s) TestXDSResolverRemovedResource(t *testing.T) { // Invoke the watchAPI callback with a good service update and wait for the // UpdateState method to be called on the ClientConn. 
- xdsC.InvokeWatchRouteConfigCallback(xdsclient.RouteConfigUpdate{ + xdsC.InvokeWatchRouteConfigCallback("", xdsclient.RouteConfigUpdate{ VirtualHosts: []*xdsclient.VirtualHost{ { Domains: []string{targetStr}, - Routes: []*client.Route{{Prefix: newStringP(""), WeightedClusters: map[string]xdsclient.WeightedCluster{"test-cluster-1": {Weight: 1}}}}, + Routes: []*xdsclient.Route{{Prefix: newStringP(""), WeightedClusters: map[string]xdsclient.WeightedCluster{"test-cluster-1": {Weight: 1}}}}, }, }, }, nil) @@ -541,7 +619,7 @@ func (s) TestXDSResolverRemovedResource(t *testing.T) { gotState, err := tcc.stateCh.Receive(ctx) if err != nil { - t.Fatalf("ClientConn.UpdateState returned error: %v", err) + t.Fatalf("Error waiting for UpdateState to be called: %v", err) } rState := gotState.(resolver.State) if err := rState.ServiceConfig.Err; err != nil { @@ -571,10 +649,10 @@ func (s) TestXDSResolverRemovedResource(t *testing.T) { // Delete the resource. The channel should receive a service config with the // original cluster but with an erroring config selector. suErr := xdsclient.NewErrorf(xdsclient.ErrorTypeResourceNotFound, "resource removed error") - xdsC.InvokeWatchRouteConfigCallback(xdsclient.RouteConfigUpdate{}, suErr) + xdsC.InvokeWatchRouteConfigCallback("", xdsclient.RouteConfigUpdate{}, suErr) if gotState, err = tcc.stateCh.Receive(ctx); err != nil { - t.Fatalf("ClientConn.UpdateState returned error: %v", err) + t.Fatalf("Error waiting for UpdateState to be called: %v", err) } rState = gotState.(resolver.State) if err := rState.ServiceConfig.Err; err != nil { @@ -599,7 +677,7 @@ func (s) TestXDSResolverRemovedResource(t *testing.T) { // In the meantime, an empty ServiceConfig update should have been sent. 
if gotState, err = tcc.stateCh.Receive(ctx); err != nil { - t.Fatalf("ClientConn.UpdateState returned error: %v", err) + t.Fatalf("Error waiting for UpdateState to be called: %v", err) } rState = gotState.(resolver.State) if err := rState.ServiceConfig.Err; err != nil { @@ -616,12 +694,10 @@ func (s) TestXDSResolverRemovedResource(t *testing.T) { func (s) TestXDSResolverWRR(t *testing.T) { xdsC := fakeclient.NewClient() xdsR, tcc, cancel := testSetup(t, setupOpts{ - xdsClientFunc: func() (xdsClientInterface, error) { return xdsC, nil }, + xdsClientFunc: func() (xdsclient.XDSClient, error) { return xdsC, nil }, }) - defer func() { - cancel() - xdsR.Close() - }() + defer xdsR.Close() + defer cancel() ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() @@ -634,11 +710,11 @@ func (s) TestXDSResolverWRR(t *testing.T) { // Invoke the watchAPI callback with a good service update and wait for the // UpdateState method to be called on the ClientConn. - xdsC.InvokeWatchRouteConfigCallback(xdsclient.RouteConfigUpdate{ + xdsC.InvokeWatchRouteConfigCallback("", xdsclient.RouteConfigUpdate{ VirtualHosts: []*xdsclient.VirtualHost{ { Domains: []string{targetStr}, - Routes: []*client.Route{{Prefix: newStringP(""), WeightedClusters: map[string]xdsclient.WeightedCluster{ + Routes: []*xdsclient.Route{{Prefix: newStringP(""), WeightedClusters: map[string]xdsclient.WeightedCluster{ "A": {Weight: 5}, "B": {Weight: 10}, }}}, @@ -648,7 +724,7 @@ func (s) TestXDSResolverWRR(t *testing.T) { gotState, err := tcc.stateCh.Receive(ctx) if err != nil { - t.Fatalf("ClientConn.UpdateState returned error: %v", err) + t.Fatalf("Error waiting for UpdateState to be called: %v", err) } rState := gotState.(resolver.State) if err := rState.ServiceConfig.Err; err != nil { @@ -676,15 +752,12 @@ func (s) TestXDSResolverWRR(t *testing.T) { } func (s) TestXDSResolverMaxStreamDuration(t *testing.T) { - defer func(old bool) { env.TimeoutSupport = old }(env.TimeoutSupport) 
xdsC := fakeclient.NewClient() xdsR, tcc, cancel := testSetup(t, setupOpts{ - xdsClientFunc: func() (xdsClientInterface, error) { return xdsC, nil }, + xdsClientFunc: func() (xdsclient.XDSClient, error) { return xdsC, nil }, }) - defer func() { - cancel() - xdsR.Close() - }() + defer xdsR.Close() + defer cancel() ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() @@ -697,11 +770,11 @@ func (s) TestXDSResolverMaxStreamDuration(t *testing.T) { // Invoke the watchAPI callback with a good service update and wait for the // UpdateState method to be called on the ClientConn. - xdsC.InvokeWatchRouteConfigCallback(xdsclient.RouteConfigUpdate{ + xdsC.InvokeWatchRouteConfigCallback("", xdsclient.RouteConfigUpdate{ VirtualHosts: []*xdsclient.VirtualHost{ { Domains: []string{targetStr}, - Routes: []*client.Route{{ + Routes: []*xdsclient.Route{{ Prefix: newStringP("/foo"), WeightedClusters: map[string]xdsclient.WeightedCluster{"A": {Weight: 1}}, MaxStreamDuration: newDurationP(5 * time.Second), @@ -719,7 +792,7 @@ func (s) TestXDSResolverMaxStreamDuration(t *testing.T) { gotState, err := tcc.stateCh.Receive(ctx) if err != nil { - t.Fatalf("ClientConn.UpdateState returned error: %v", err) + t.Fatalf("Error waiting for UpdateState to be called: %v", err) } rState := gotState.(resolver.State) if err := rState.ServiceConfig.Err; err != nil { @@ -732,35 +805,25 @@ func (s) TestXDSResolverMaxStreamDuration(t *testing.T) { } testCases := []struct { - name string - method string - timeoutSupport bool - want *time.Duration + name string + method string + want *time.Duration }{{ - name: "RDS setting", - method: "/foo/method", - timeoutSupport: true, - want: newDurationP(5 * time.Second), - }, { - name: "timeout support disabled", - method: "/foo/method", - timeoutSupport: false, - want: nil, + name: "RDS setting", + method: "/foo/method", + want: newDurationP(5 * time.Second), }, { - name: "explicit zero in RDS; ignore LDS", - method: 
"/bar/method", - timeoutSupport: true, - want: nil, + name: "explicit zero in RDS; ignore LDS", + method: "/bar/method", + want: nil, }, { - name: "no config in RDS; fallback to LDS", - method: "/baz/method", - timeoutSupport: true, - want: newDurationP(time.Second), + name: "no config in RDS; fallback to LDS", + method: "/baz/method", + want: newDurationP(time.Second), }} for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - env.TimeoutSupport = tc.timeoutSupport req := iresolver.RPCInfo{ Method: tc.method, Context: context.Background(), @@ -784,12 +847,10 @@ func (s) TestXDSResolverMaxStreamDuration(t *testing.T) { func (s) TestXDSResolverDelayedOnCommitted(t *testing.T) { xdsC := fakeclient.NewClient() xdsR, tcc, cancel := testSetup(t, setupOpts{ - xdsClientFunc: func() (xdsClientInterface, error) { return xdsC, nil }, + xdsClientFunc: func() (xdsclient.XDSClient, error) { return xdsC, nil }, }) - defer func() { - cancel() - xdsR.Close() - }() + defer xdsR.Close() + defer cancel() ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() @@ -799,18 +860,18 @@ func (s) TestXDSResolverDelayedOnCommitted(t *testing.T) { // Invoke the watchAPI callback with a good service update and wait for the // UpdateState method to be called on the ClientConn. 
- xdsC.InvokeWatchRouteConfigCallback(xdsclient.RouteConfigUpdate{ + xdsC.InvokeWatchRouteConfigCallback("", xdsclient.RouteConfigUpdate{ VirtualHosts: []*xdsclient.VirtualHost{ { Domains: []string{targetStr}, - Routes: []*client.Route{{Prefix: newStringP(""), WeightedClusters: map[string]xdsclient.WeightedCluster{"test-cluster-1": {Weight: 1}}}}, + Routes: []*xdsclient.Route{{Prefix: newStringP(""), WeightedClusters: map[string]xdsclient.WeightedCluster{"test-cluster-1": {Weight: 1}}}}, }, }, }, nil) gotState, err := tcc.stateCh.Receive(ctx) if err != nil { - t.Fatalf("ClientConn.UpdateState returned error: %v", err) + t.Fatalf("Error waiting for UpdateState to be called: %v", err) } rState := gotState.(resolver.State) if err := rState.ServiceConfig.Err; err != nil { @@ -849,27 +910,28 @@ func (s) TestXDSResolverDelayedOnCommitted(t *testing.T) { // Perform TWO updates to ensure the old config selector does not hold a // reference to test-cluster-1. - xdsC.InvokeWatchRouteConfigCallback(xdsclient.RouteConfigUpdate{ + xdsC.InvokeWatchRouteConfigCallback("", xdsclient.RouteConfigUpdate{ VirtualHosts: []*xdsclient.VirtualHost{ { Domains: []string{targetStr}, - Routes: []*client.Route{{Prefix: newStringP(""), WeightedClusters: map[string]xdsclient.WeightedCluster{"NEW": {Weight: 1}}}}, + Routes: []*xdsclient.Route{{Prefix: newStringP(""), WeightedClusters: map[string]xdsclient.WeightedCluster{"NEW": {Weight: 1}}}}, }, }, }, nil) - xdsC.InvokeWatchRouteConfigCallback(xdsclient.RouteConfigUpdate{ + tcc.stateCh.Receive(ctx) // Ignore the first update. 
+ + xdsC.InvokeWatchRouteConfigCallback("", xdsclient.RouteConfigUpdate{ VirtualHosts: []*xdsclient.VirtualHost{ { Domains: []string{targetStr}, - Routes: []*client.Route{{Prefix: newStringP(""), WeightedClusters: map[string]xdsclient.WeightedCluster{"NEW": {Weight: 1}}}}, + Routes: []*xdsclient.Route{{Prefix: newStringP(""), WeightedClusters: map[string]xdsclient.WeightedCluster{"NEW": {Weight: 1}}}}, }, }, }, nil) - tcc.stateCh.Receive(ctx) // Ignore the first update gotState, err = tcc.stateCh.Receive(ctx) if err != nil { - t.Fatalf("ClientConn.UpdateState returned error: %v", err) + t.Fatalf("Error waiting for UpdateState to be called: %v", err) } rState = gotState.(resolver.State) if err := rState.ServiceConfig.Err; err != nil { @@ -897,17 +959,17 @@ func (s) TestXDSResolverDelayedOnCommitted(t *testing.T) { // test-cluster-1. res.OnCommitted() - xdsC.InvokeWatchRouteConfigCallback(xdsclient.RouteConfigUpdate{ + xdsC.InvokeWatchRouteConfigCallback("", xdsclient.RouteConfigUpdate{ VirtualHosts: []*xdsclient.VirtualHost{ { Domains: []string{targetStr}, - Routes: []*client.Route{{Prefix: newStringP(""), WeightedClusters: map[string]xdsclient.WeightedCluster{"NEW": {Weight: 1}}}}, + Routes: []*xdsclient.Route{{Prefix: newStringP(""), WeightedClusters: map[string]xdsclient.WeightedCluster{"NEW": {Weight: 1}}}}, }, }, }, nil) gotState, err = tcc.stateCh.Receive(ctx) if err != nil { - t.Fatalf("ClientConn.UpdateState returned error: %v", err) + t.Fatalf("Error waiting for UpdateState to be called: %v", err) } rState = gotState.(resolver.State) if err := rState.ServiceConfig.Err; err != nil { @@ -934,12 +996,10 @@ func (s) TestXDSResolverDelayedOnCommitted(t *testing.T) { func (s) TestXDSResolverGoodUpdateAfterError(t *testing.T) { xdsC := fakeclient.NewClient() xdsR, tcc, cancel := testSetup(t, setupOpts{ - xdsClientFunc: func() (xdsClientInterface, error) { return xdsC, nil }, + xdsClientFunc: func() (xdsclient.XDSClient, error) { return xdsC, nil }, }) - defer 
func() { - cancel() - xdsR.Close() - }() + defer xdsR.Close() + defer cancel() ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() @@ -950,7 +1010,7 @@ func (s) TestXDSResolverGoodUpdateAfterError(t *testing.T) { // Invoke the watchAPI callback with a bad service update and wait for the // ReportError method to be called on the ClientConn. suErr := errors.New("bad serviceupdate") - xdsC.InvokeWatchRouteConfigCallback(xdsclient.RouteConfigUpdate{}, suErr) + xdsC.InvokeWatchRouteConfigCallback("", xdsclient.RouteConfigUpdate{}, suErr) if gotErrVal, gotErr := tcc.errorCh.Receive(ctx); gotErr != nil || gotErrVal != suErr { t.Fatalf("ClientConn.ReportError() received %v, want %v", gotErrVal, suErr) @@ -958,17 +1018,17 @@ func (s) TestXDSResolverGoodUpdateAfterError(t *testing.T) { // Invoke the watchAPI callback with a good service update and wait for the // UpdateState method to be called on the ClientConn. - xdsC.InvokeWatchRouteConfigCallback(xdsclient.RouteConfigUpdate{ + xdsC.InvokeWatchRouteConfigCallback("", xdsclient.RouteConfigUpdate{ VirtualHosts: []*xdsclient.VirtualHost{ { Domains: []string{targetStr}, - Routes: []*client.Route{{Prefix: newStringP(""), WeightedClusters: map[string]xdsclient.WeightedCluster{cluster: {Weight: 1}}}}, + Routes: []*xdsclient.Route{{Prefix: newStringP(""), WeightedClusters: map[string]xdsclient.WeightedCluster{cluster: {Weight: 1}}}}, }, }, }, nil) gotState, err := tcc.stateCh.Receive(ctx) if err != nil { - t.Fatalf("ClientConn.UpdateState returned error: %v", err) + t.Fatalf("Error waiting for UpdateState to be called: %v", err) } rState := gotState.(resolver.State) if err := rState.ServiceConfig.Err; err != nil { @@ -978,7 +1038,7 @@ func (s) TestXDSResolverGoodUpdateAfterError(t *testing.T) { // Invoke the watchAPI callback with a bad service update and wait for the // ReportError method to be called on the ClientConn. 
suErr2 := errors.New("bad serviceupdate 2") - xdsC.InvokeWatchRouteConfigCallback(xdsclient.RouteConfigUpdate{}, suErr2) + xdsC.InvokeWatchRouteConfigCallback("", xdsclient.RouteConfigUpdate{}, suErr2) if gotErrVal, gotErr := tcc.errorCh.Receive(ctx); gotErr != nil || gotErrVal != suErr2 { t.Fatalf("ClientConn.ReportError() received %v, want %v", gotErrVal, suErr2) } @@ -990,12 +1050,10 @@ func (s) TestXDSResolverGoodUpdateAfterError(t *testing.T) { func (s) TestXDSResolverResourceNotFoundError(t *testing.T) { xdsC := fakeclient.NewClient() xdsR, tcc, cancel := testSetup(t, setupOpts{ - xdsClientFunc: func() (xdsClientInterface, error) { return xdsC, nil }, + xdsClientFunc: func() (xdsclient.XDSClient, error) { return xdsC, nil }, }) - defer func() { - cancel() - xdsR.Close() - }() + defer xdsR.Close() + defer cancel() ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() @@ -1006,7 +1064,7 @@ func (s) TestXDSResolverResourceNotFoundError(t *testing.T) { // Invoke the watchAPI callback with a bad service update and wait for the // ReportError method to be called on the ClientConn. 
suErr := xdsclient.NewErrorf(xdsclient.ErrorTypeResourceNotFound, "resource removed error") - xdsC.InvokeWatchRouteConfigCallback(xdsclient.RouteConfigUpdate{}, suErr) + xdsC.InvokeWatchRouteConfigCallback("", xdsclient.RouteConfigUpdate{}, suErr) if gotErrVal, gotErr := tcc.errorCh.Receive(ctx); gotErr != context.DeadlineExceeded { t.Fatalf("ClientConn.ReportError() received %v, %v, want channel recv timeout", gotErrVal, gotErr) @@ -1016,7 +1074,7 @@ func (s) TestXDSResolverResourceNotFoundError(t *testing.T) { defer cancel() gotState, err := tcc.stateCh.Receive(ctx) if err != nil { - t.Fatalf("ClientConn.UpdateState returned error: %v", err) + t.Fatalf("Error waiting for UpdateState to be called: %v", err) } rState := gotState.(resolver.State) wantParsedConfig := internal.ParseServiceConfigForTesting.(func(string) *serviceconfig.ParseResult)("{}") @@ -1030,6 +1088,46 @@ func (s) TestXDSResolverResourceNotFoundError(t *testing.T) { } } +// TestXDSResolverMultipleLDSUpdates tests the case where two LDS updates with +// the same RDS name to watch are received without an RDS in between. Those LDS +// updates shouldn't trigger service config update. +// +// This test case also makes sure the resolver doesn't panic. +func (s) TestXDSResolverMultipleLDSUpdates(t *testing.T) { + xdsC := fakeclient.NewClient() + xdsR, tcc, cancel := testSetup(t, setupOpts{ + xdsClientFunc: func() (xdsclient.XDSClient, error) { return xdsC, nil }, + }) + defer xdsR.Close() + defer cancel() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + waitForWatchListener(ctx, t, xdsC, targetStr) + xdsC.InvokeWatchListenerCallback(xdsclient.ListenerUpdate{RouteConfigName: routeStr, HTTPFilters: routerFilterList}, nil) + waitForWatchRouteConfig(ctx, t, xdsC, routeStr) + defer replaceRandNumGenerator(0)() + + // Send a new LDS update, with the same fields. 
+ xdsC.InvokeWatchListenerCallback(xdsclient.ListenerUpdate{RouteConfigName: routeStr, HTTPFilters: routerFilterList}, nil) + ctx, cancel = context.WithTimeout(context.Background(), defaultTestShortTimeout) + defer cancel() + // Should NOT trigger a state update. + gotState, err := tcc.stateCh.Receive(ctx) + if err == nil { + t.Fatalf("ClientConn.UpdateState received %v, want timeout error", gotState) + } + + // Send a new LDS update, with the same RDS name, but different fields. + xdsC.InvokeWatchListenerCallback(xdsclient.ListenerUpdate{RouteConfigName: routeStr, MaxStreamDuration: time.Second, HTTPFilters: routerFilterList}, nil) + ctx, cancel = context.WithTimeout(context.Background(), defaultTestShortTimeout) + defer cancel() + gotState, err = tcc.stateCh.Receive(ctx) + if err == nil { + t.Fatalf("ClientConn.UpdateState received %v, want timeout error", gotState) + } +} + type filterBuilder struct { httpfilter.Filter // embedded as we do not need to implement registry / parsing in this test. path *[]string @@ -1173,12 +1271,10 @@ func (s) TestXDSResolverHTTPFilters(t *testing.T) { t.Run(tc.name, func(t *testing.T) { xdsC := fakeclient.NewClient() xdsR, tcc, cancel := testSetup(t, setupOpts{ - xdsClientFunc: func() (xdsClientInterface, error) { return xdsC, nil }, + xdsClientFunc: func() (xdsclient.XDSClient, error) { return xdsC, nil }, }) - defer func() { - cancel() - xdsR.Close() - }() + defer xdsR.Close() + defer cancel() ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() @@ -1197,11 +1293,11 @@ func (s) TestXDSResolverHTTPFilters(t *testing.T) { // Invoke the watchAPI callback with a good service update and wait for the // UpdateState method to be called on the ClientConn. 
- xdsC.InvokeWatchRouteConfigCallback(xdsclient.RouteConfigUpdate{ + xdsC.InvokeWatchRouteConfigCallback("", xdsclient.RouteConfigUpdate{ VirtualHosts: []*xdsclient.VirtualHost{ { Domains: []string{targetStr}, - Routes: []*client.Route{{ + Routes: []*xdsclient.Route{{ Prefix: newStringP("1"), WeightedClusters: map[string]xdsclient.WeightedCluster{ "A": {Weight: 1}, "B": {Weight: 1}, @@ -1220,7 +1316,7 @@ func (s) TestXDSResolverHTTPFilters(t *testing.T) { gotState, err := tcc.stateCh.Receive(ctx) if err != nil { - t.Fatalf("ClientConn.UpdateState returned error: %v", err) + t.Fatalf("Error waiting for UpdateState to be called: %v", err) } rState := gotState.(resolver.State) if err := rState.ServiceConfig.Err; err != nil { @@ -1295,13 +1391,13 @@ func (s) TestXDSResolverHTTPFilters(t *testing.T) { func replaceRandNumGenerator(start int64) func() { nextInt := start - grpcrandInt63n = func(int64) (ret int64) { + xdsclient.RandInt63n = func(int64) (ret int64) { ret = nextInt nextInt++ return } return func() { - grpcrandInt63n = grpcrand.Int63n + xdsclient.RandInt63n = grpcrand.Int63n } } diff --git a/xds/internal/server/conn_wrapper.go b/xds/internal/server/conn_wrapper.go new file mode 100644 index 00000000000..dd0374dc88e --- /dev/null +++ b/xds/internal/server/conn_wrapper.go @@ -0,0 +1,165 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +package server + +import ( + "errors" + "fmt" + "net" + "sync" + "time" + + "google.golang.org/grpc/credentials/tls/certprovider" + xdsinternal "google.golang.org/grpc/internal/credentials/xds" + "google.golang.org/grpc/xds/internal/xdsclient" +) + +// connWrapper is a thin wrapper around a net.Conn returned by Accept(). It +// provides the following additional functionality: +// 1. A way to retrieve the configured deadline. This is required by the +// ServerHandshake() method of the xdsCredentials when it attempts to read +// key material from the certificate providers. +// 2. Implements the XDSHandshakeInfo() method used by the xdsCredentials to +// retrieve the configured certificate providers. +// 3. xDS filter_chain matching logic to select appropriate security +// configuration for the incoming connection. +type connWrapper struct { + net.Conn + + // The specific filter chain picked for handling this connection. + filterChain *xdsclient.FilterChain + + // A reference fo the listenerWrapper on which this connection was accepted. + parent *listenerWrapper + + // The certificate providers created for this connection. + rootProvider, identityProvider certprovider.Provider + + // The connection deadline as configured by the grpc.Server on the rawConn + // that is returned by a call to Accept(). This is set to the connection + // timeout value configured by the user (or to a default value) before + // initiating the transport credential handshake, and set to zero after + // completing the HTTP2 handshake. + deadlineMu sync.Mutex + deadline time.Time + + // The virtual hosts with matchable routes and instantiated HTTP Filters per + // route. + virtualHosts []xdsclient.VirtualHostWithInterceptors +} + +// VirtualHosts returns the virtual hosts to be used for server side routing. 
+func (c *connWrapper) VirtualHosts() []xdsclient.VirtualHostWithInterceptors { + return c.virtualHosts +} + +// SetDeadline makes a copy of the passed in deadline and forwards the call to +// the underlying rawConn. +func (c *connWrapper) SetDeadline(t time.Time) error { + c.deadlineMu.Lock() + c.deadline = t + c.deadlineMu.Unlock() + return c.Conn.SetDeadline(t) +} + +// GetDeadline returns the configured deadline. This will be invoked by the +// ServerHandshake() method of the XdsCredentials, which needs a deadline to +// pass to the certificate provider. +func (c *connWrapper) GetDeadline() time.Time { + c.deadlineMu.Lock() + t := c.deadline + c.deadlineMu.Unlock() + return t +} + +// XDSHandshakeInfo returns a HandshakeInfo with appropriate security +// configuration for this connection. This method is invoked by the +// ServerHandshake() method of the XdsCredentials. +func (c *connWrapper) XDSHandshakeInfo() (*xdsinternal.HandshakeInfo, error) { + // Ideally this should never happen, since xdsCredentials are the only ones + // which will invoke this method at handshake time. But to be on the safe + // side, we avoid acting on the security configuration received from the + // control plane when the user has not configured the use of xDS + // credentials, by checking the value of this flag. + if !c.parent.xdsCredsInUse { + return nil, errors.New("user has not configured xDS credentials") + } + + if c.filterChain.SecurityCfg == nil { + // If the security config is empty, this means that the control plane + // did not provide any security configuration and therefore we should + // return an empty HandshakeInfo here so that the xdsCreds can use the + // configured fallback credentials. + return xdsinternal.NewHandshakeInfo(nil, nil), nil + } + + cpc := c.parent.xdsC.BootstrapConfig().CertProviderConfigs + // Identity provider name is mandatory on the server-side, and this is + // enforced when the resource is received at the XDSClient layer. 
+ secCfg := c.filterChain.SecurityCfg + ip, err := buildProviderFunc(cpc, secCfg.IdentityInstanceName, secCfg.IdentityCertName, true, false) + if err != nil { + return nil, err + } + // Root provider name is optional and required only when doing mTLS. + var rp certprovider.Provider + if instance, cert := secCfg.RootInstanceName, secCfg.RootCertName; instance != "" { + rp, err = buildProviderFunc(cpc, instance, cert, false, true) + if err != nil { + return nil, err + } + } + c.identityProvider = ip + c.rootProvider = rp + + xdsHI := xdsinternal.NewHandshakeInfo(c.rootProvider, c.identityProvider) + xdsHI.SetRequireClientCert(secCfg.RequireClientCert) + return xdsHI, nil +} + +// Close closes the providers and the underlying connection. +func (c *connWrapper) Close() error { + if c.identityProvider != nil { + c.identityProvider.Close() + } + if c.rootProvider != nil { + c.rootProvider.Close() + } + return c.Conn.Close() +} + +func buildProviderFunc(configs map[string]*certprovider.BuildableConfig, instanceName, certName string, wantIdentity, wantRoot bool) (certprovider.Provider, error) { + cfg, ok := configs[instanceName] + if !ok { + return nil, fmt.Errorf("certificate provider instance %q not found in bootstrap file", instanceName) + } + provider, err := cfg.Build(certprovider.BuildOptions{ + CertName: certName, + WantIdentity: wantIdentity, + WantRoot: wantRoot, + }) + if err != nil { + // This error is not expected since the bootstrap process parses the + // config and makes sure that it is acceptable to the plugin. Still, it + // is possible that the plugin parses the config successfully, but its + // Build() method errors out. 
+ return nil, fmt.Errorf("failed to get security plugin instance (%+v): %v", cfg, err) + } + return provider, nil +} diff --git a/xds/internal/server/listener_wrapper.go b/xds/internal/server/listener_wrapper.go new file mode 100644 index 00000000000..99c9a753230 --- /dev/null +++ b/xds/internal/server/listener_wrapper.go @@ -0,0 +1,442 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// Package server contains internal server-side functionality used by the public +// facing xds package. +package server + +import ( + "errors" + "fmt" + "net" + "sync" + "sync/atomic" + "time" + "unsafe" + + "google.golang.org/grpc/backoff" + "google.golang.org/grpc/connectivity" + "google.golang.org/grpc/grpclog" + internalbackoff "google.golang.org/grpc/internal/backoff" + internalgrpclog "google.golang.org/grpc/internal/grpclog" + "google.golang.org/grpc/internal/grpcsync" + "google.golang.org/grpc/internal/xds/env" + "google.golang.org/grpc/xds/internal/xdsclient" + "google.golang.org/grpc/xds/internal/xdsclient/bootstrap" +) + +var ( + logger = grpclog.Component("xds") + + // Backoff strategy for temporary errors received from Accept(). If this + // needs to be configurable, we can inject it through ListenerWrapperParams. 
+ bs = internalbackoff.Exponential{Config: backoff.Config{ + BaseDelay: 5 * time.Millisecond, + Multiplier: 2.0, + MaxDelay: 1 * time.Second, + }} + backoffFunc = bs.Backoff +) + +// ServingModeCallback is the callback that users can register to get notified +// about the server's serving mode changes. The callback is invoked with the +// address of the listener and its new mode. The err parameter is set to a +// non-nil error if the server has transitioned into not-serving mode. +type ServingModeCallback func(addr net.Addr, mode connectivity.ServingMode, err error) + +// DrainCallback is the callback that an xDS-enabled server registers to get +// notified about updates to the Listener configuration. The server is expected +// to gracefully shutdown existing connections, thereby forcing clients to +// reconnect and have the new configuration applied to the newly created +// connections. +type DrainCallback func(addr net.Addr) + +func prefixLogger(p *listenerWrapper) *internalgrpclog.PrefixLogger { + return internalgrpclog.NewPrefixLogger(logger, fmt.Sprintf("[xds-server-listener %p] ", p)) +} + +// XDSClient wraps the methods on the XDSClient which are required by +// the listenerWrapper. +type XDSClient interface { + WatchListener(string, func(xdsclient.ListenerUpdate, error)) func() + WatchRouteConfig(string, func(xdsclient.RouteConfigUpdate, error)) func() + BootstrapConfig() *bootstrap.Config +} + +// ListenerWrapperParams wraps parameters required to create a listenerWrapper. +type ListenerWrapperParams struct { + // Listener is the net.Listener passed by the user that is to be wrapped. + Listener net.Listener + // ListenerResourceName is the xDS Listener resource to request. + ListenerResourceName string + // XDSCredsInUse specifies whether or not the user expressed interest to + // receive security configuration from the control plane. + XDSCredsInUse bool + // XDSClient provides the functionality from the XDSClient required here. 
+ XDSClient XDSClient + // ModeCallback is the callback to invoke when the serving mode changes. + ModeCallback ServingModeCallback + // DrainCallback is the callback to invoke when the Listener gets a LDS + // update. + DrainCallback DrainCallback +} + +// NewListenerWrapper creates a new listenerWrapper with params. It returns a +// net.Listener and a channel which is written to, indicating that the former is +// ready to be passed to grpc.Serve(). +// +// Only TCP listeners are supported. +func NewListenerWrapper(params ListenerWrapperParams) (net.Listener, <-chan struct{}) { + lw := &listenerWrapper{ + Listener: params.Listener, + name: params.ListenerResourceName, + xdsCredsInUse: params.XDSCredsInUse, + xdsC: params.XDSClient, + modeCallback: params.ModeCallback, + drainCallback: params.DrainCallback, + isUnspecifiedAddr: params.Listener.Addr().(*net.TCPAddr).IP.IsUnspecified(), + + closed: grpcsync.NewEvent(), + goodUpdate: grpcsync.NewEvent(), + ldsUpdateCh: make(chan ldsUpdateWithError, 1), + rdsUpdateCh: make(chan rdsHandlerUpdate, 1), + } + lw.logger = prefixLogger(lw) + + // Serve() verifies that Addr() returns a valid TCPAddr. So, it is safe to + // ignore the error from SplitHostPort(). + lisAddr := lw.Listener.Addr().String() + lw.addr, lw.port, _ = net.SplitHostPort(lisAddr) + + lw.rdsHandler = newRDSHandler(lw.xdsC, lw.rdsUpdateCh) + + cancelWatch := lw.xdsC.WatchListener(lw.name, lw.handleListenerUpdate) + lw.logger.Infof("Watch started on resource name %v", lw.name) + lw.cancelWatch = func() { + cancelWatch() + lw.logger.Infof("Watch cancelled on resource name %v", lw.name) + } + go lw.run() + return lw, lw.goodUpdate.Done() +} + +type ldsUpdateWithError struct { + update xdsclient.ListenerUpdate + err error +} + +// listenerWrapper wraps the net.Listener associated with the listening address +// passed to Serve(). It also contains all other state associated with this +// particular invocation of Serve(). 
+type listenerWrapper struct { + net.Listener + logger *internalgrpclog.PrefixLogger + + name string + xdsCredsInUse bool + xdsC XDSClient + cancelWatch func() + modeCallback ServingModeCallback + drainCallback DrainCallback + + // Set to true if the listener is bound to the IP_ANY address (which is + // "0.0.0.0" for IPv4 and "::" for IPv6). + isUnspecifiedAddr bool + // Listening address and port. Used to validate the socket address in the + // Listener resource received from the control plane. + addr, port string + + // This is used to notify that a good update has been received and that + // Serve() can be invoked on the underlying gRPC server. Using an event + // instead of a vanilla channel simplifies the update handler as it need not + // keep track of whether the received update is the first one or not. + goodUpdate *grpcsync.Event + // A small race exists in the XDSClient code between the receipt of an xDS + // response and the user cancelling the associated watch. In this window, + // the registered callback may be invoked after the watch is canceled, and + // the user is expected to work around this. This event signifies that the + // listener is closed (and hence the watch is cancelled), and we drop any + // updates received in the callback if this event has fired. + closed *grpcsync.Event + + // mu guards access to the current serving mode and the filter chains. The + // reason for using an rw lock here is that these fields are read in + // Accept() for all incoming connections, but writes happen rarely (when we + // get a Listener resource update). + mu sync.RWMutex + // Current serving mode. + mode connectivity.ServingMode + // Filter chains received as part of the last good update. + filterChains *xdsclient.FilterChainManager + + // rdsHandler is used for any dynamic RDS resources specified in a LDS + // update. 
+ rdsHandler *rdsHandler + // rdsUpdates are the RDS resources received from the management + // server, keyed on the RouteName of the RDS resource. + rdsUpdates unsafe.Pointer // map[string]xdsclient.RouteConfigUpdate + // ldsUpdateCh is a channel for XDSClient LDS updates. + ldsUpdateCh chan ldsUpdateWithError + // rdsUpdateCh is a channel for XDSClient RDS updates. + rdsUpdateCh chan rdsHandlerUpdate +} + +// Accept blocks on an Accept() on the underlying listener, and wraps the +// returned net.connWrapper with the configured certificate providers. +func (l *listenerWrapper) Accept() (net.Conn, error) { + var retries int + for { + conn, err := l.Listener.Accept() + if err != nil { + // Temporary() method is implemented by certain error types returned + // from the net package, and it is useful for us to not shutdown the + // server in these conditions. The listen queue being full is one + // such case. + if ne, ok := err.(interface{ Temporary() bool }); !ok || !ne.Temporary() { + return nil, err + } + retries++ + timer := time.NewTimer(backoffFunc(retries)) + select { + case <-timer.C: + case <-l.closed.Done(): + timer.Stop() + // Continuing here will cause us to call Accept() again + // which will return a non-temporary error. + continue + } + continue + } + // Reset retries after a successful Accept(). + retries = 0 + + // Since the net.Conn represents an incoming connection, the source and + // destination address can be retrieved from the local address and + // remote address of the net.Conn respectively. + destAddr, ok1 := conn.LocalAddr().(*net.TCPAddr) + srcAddr, ok2 := conn.RemoteAddr().(*net.TCPAddr) + if !ok1 || !ok2 { + // If the incoming connection is not a TCP connection, which is + // really unexpected since we check whether the provided listener is + // a TCP listener in Serve(), we return an error which would cause + // us to stop serving. 
+ return nil, fmt.Errorf("received connection with non-TCP address (local: %T, remote %T)", conn.LocalAddr(), conn.RemoteAddr()) + } + + l.mu.RLock() + if l.mode == connectivity.ServingModeNotServing { + // Close connections as soon as we accept them when we are in + // "not-serving" mode. Since we accept a net.Listener from the user + // in Serve(), we cannot close the listener when we move to + // "not-serving". Closing the connection immediately upon accepting + // is one of the other ways to implement the "not-serving" mode as + // outlined in gRFC A36. + l.mu.RUnlock() + conn.Close() + continue + } + fc, err := l.filterChains.Lookup(xdsclient.FilterChainLookupParams{ + IsUnspecifiedListener: l.isUnspecifiedAddr, + DestAddr: destAddr.IP, + SourceAddr: srcAddr.IP, + SourcePort: srcAddr.Port, + }) + l.mu.RUnlock() + if err != nil { + // When a matching filter chain is not found, we close the + // connection right away, but do not return an error back to + // `grpc.Serve()` from where this Accept() was invoked. Returning an + // error to `grpc.Serve()` causes the server to shutdown. If we want + // to avoid the server from shutting down, we would need to return + // an error type which implements the `Temporary() bool` method, + // which is invoked by `grpc.Serve()` to see if the returned error + // represents a temporary condition. In the case of a temporary + // error, `grpc.Serve()` method sleeps for a small duration and + // therefore ends up blocking all connection attempts during that + // time frame, which is also not ideal for an error like this. 
+ l.logger.Warningf("connection from %s to %s failed to find any matching filter chain", conn.RemoteAddr().String(), conn.LocalAddr().String()) + conn.Close() + continue + } + if !env.RBACSupport { + return &connWrapper{Conn: conn, filterChain: fc, parent: l}, nil + } + var rc xdsclient.RouteConfigUpdate + if fc.InlineRouteConfig != nil { + rc = *fc.InlineRouteConfig + } else { + rcPtr := atomic.LoadPointer(&l.rdsUpdates) + rcuPtr := (*map[string]xdsclient.RouteConfigUpdate)(rcPtr) + // This shouldn't happen, but this error protects against a panic. + if rcuPtr == nil { + return nil, errors.New("route configuration pointer is nil") + } + rcu := *rcuPtr + rc = rcu[fc.RouteConfigName] + } + // The filter chain will construct a usuable route table on each + // connection accept. This is done because preinstantiating every route + // table before it is needed for a connection would potentially lead to + // a lot of cpu time and memory allocated for route tables that will + // never be used. There was also a thought to cache this configuration, + // and reuse it for the next accepted connection. However, this would + // lead to a lot of code complexity (RDS Updates for a given route name + // can come it at any time), and connections aren't accepted too often, + // so this reinstantation of the Route Configuration is an acceptable + // tradeoff for simplicity. + vhswi, err := fc.ConstructUsableRouteConfiguration(rc) + if err != nil { + l.logger.Warningf("route configuration construction: %v", err) + conn.Close() + continue + } + return &connWrapper{Conn: conn, filterChain: fc, parent: l, virtualHosts: vhswi}, nil + } +} + +// Close closes the underlying listener. It also cancels the xDS watch +// registered in Serve() and closes any certificate provider instances created +// based on security configuration received in the LDS response. 
+func (l *listenerWrapper) Close() error { + l.closed.Fire() + l.Listener.Close() + if l.cancelWatch != nil { + l.cancelWatch() + } + l.rdsHandler.close() + return nil +} + +// run is a long running goroutine which handles all xds updates. LDS and RDS +// push updates onto a channel which is read and acted upon from this goroutine. +func (l *listenerWrapper) run() { + for { + select { + case <-l.closed.Done(): + return + case u := <-l.ldsUpdateCh: + l.handleLDSUpdate(u) + case u := <-l.rdsUpdateCh: + l.handleRDSUpdate(u) + } + } +} + +// handleLDSUpdate is the callback which handles LDS Updates. It writes the +// received update to the update channel, which is picked up by the run +// goroutine. +func (l *listenerWrapper) handleListenerUpdate(update xdsclient.ListenerUpdate, err error) { + if l.closed.HasFired() { + l.logger.Warningf("Resource %q received update: %v with error: %v, after listener was closed", l.name, update, err) + return + } + // Remove any existing entry in ldsUpdateCh and replace with the new one, as the only update + // listener cares about is most recent update. + select { + case <-l.ldsUpdateCh: + default: + } + l.ldsUpdateCh <- ldsUpdateWithError{update: update, err: err} +} + +// handleRDSUpdate handles a full rds update from rds handler. On a successful +// update, the server will switch to ServingModeServing as the full +// configuration (both LDS and RDS) has been received. 
+func (l *listenerWrapper) handleRDSUpdate(update rdsHandlerUpdate) { + if l.closed.HasFired() { + l.logger.Warningf("RDS received update: %v with error: %v, after listener was closed", update.updates, update.err) + return + } + if update.err != nil { + l.logger.Warningf("Received error for rds names specified in resource %q: %+v", l.name, update.err) + if xdsclient.ErrType(update.err) == xdsclient.ErrorTypeResourceNotFound { + l.switchMode(nil, connectivity.ServingModeNotServing, update.err) + } + // For errors which are anything other than "resource-not-found", we + // continue to use the old configuration. + return + } + atomic.StorePointer(&l.rdsUpdates, unsafe.Pointer(&update.updates)) + + l.switchMode(l.filterChains, connectivity.ServingModeServing, nil) + l.goodUpdate.Fire() +} + +func (l *listenerWrapper) handleLDSUpdate(update ldsUpdateWithError) { + if update.err != nil { + l.logger.Warningf("Received error for resource %q: %+v", l.name, update.err) + if xdsclient.ErrType(update.err) == xdsclient.ErrorTypeResourceNotFound { + l.switchMode(nil, connectivity.ServingModeNotServing, update.err) + } + // For errors which are anything other than "resource-not-found", we + // continue to use the old configuration. + return + } + l.logger.Infof("Received update for resource %q: %+v", l.name, update.update) + + // Make sure that the socket address on the received Listener resource + // matches the address of the net.Listener passed to us by the user. This + // check is done here instead of at the XDSClient layer because of the + // following couple of reasons: + // - XDSClient cannot know the listening address of every listener in the + // system, and hence cannot perform this check. + // - this is a very context-dependent check and only the server has the + // appropriate context to perform this check. + // + // What this means is that the XDSClient has ACKed a resource which can push + // the server into a "not serving" mode. 
This is not ideal, but this is + // what we have decided to do. See gRPC A36 for more details. + ilc := update.update.InboundListenerCfg + if ilc.Address != l.addr || ilc.Port != l.port { + l.switchMode(nil, connectivity.ServingModeNotServing, fmt.Errorf("address (%s:%s) in Listener update does not match listening address: (%s:%s)", ilc.Address, ilc.Port, l.addr, l.port)) + return + } + + // "Updates to a Listener cause all older connections on that Listener to be + // gracefully shut down with a grace period of 10 minutes for long-lived + // RPC's, such that clients will reconnect and have the updated + // configuration apply." - A36 Note that this is not the same as moving the + // Server's state to ServingModeNotServing. That prevents new connections + // from being accepted, whereas here we simply want the clients to reconnect + // to get the updated configuration. + if env.RBACSupport { + if l.drainCallback != nil { + l.drainCallback(l.Listener.Addr()) + } + } + l.rdsHandler.updateRouteNamesToWatch(ilc.FilterChains.RouteConfigNames) + // If there are no dynamic RDS Configurations still needed to be received + // from the management server, this listener has all the configuration + // needed, and is ready to serve. 
+ if len(ilc.FilterChains.RouteConfigNames) == 0 { + l.switchMode(ilc.FilterChains, connectivity.ServingModeServing, nil) + l.goodUpdate.Fire() + } +} + +func (l *listenerWrapper) switchMode(fcs *xdsclient.FilterChainManager, newMode connectivity.ServingMode, err error) { + l.mu.Lock() + defer l.mu.Unlock() + + l.filterChains = fcs + l.mode = newMode + if l.modeCallback != nil { + l.modeCallback(l.Listener.Addr(), newMode, err) + } + l.logger.Warningf("Listener %q entering mode: %q due to error: %v", l.Addr(), newMode, err) +} diff --git a/xds/internal/server/listener_wrapper_test.go b/xds/internal/server/listener_wrapper_test.go new file mode 100644 index 00000000000..38372936366 --- /dev/null +++ b/xds/internal/server/listener_wrapper_test.go @@ -0,0 +1,484 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +package server + +import ( + "context" + "errors" + "net" + "strconv" + "testing" + "time" + + v3corepb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" + v3listenerpb "github.com/envoyproxy/go-control-plane/envoy/config/listener/v3" + v3routepb "github.com/envoyproxy/go-control-plane/envoy/config/route/v3" + v3httppb "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/network/http_connection_manager/v3" + v3tlspb "github.com/envoyproxy/go-control-plane/envoy/extensions/transport_sockets/tls/v3" + wrapperspb "github.com/golang/protobuf/ptypes/wrappers" + "google.golang.org/grpc/internal/grpctest" + "google.golang.org/grpc/internal/testutils" + "google.golang.org/grpc/internal/xds/env" + _ "google.golang.org/grpc/xds/internal/httpfilter/router" + "google.golang.org/grpc/xds/internal/testutils/e2e" + "google.golang.org/grpc/xds/internal/testutils/fakeclient" + "google.golang.org/grpc/xds/internal/xdsclient" +) + +const ( + fakeListenerHost = "0.0.0.0" + fakeListenerPort = 50051 + testListenerResourceName = "lds.target.1.2.3.4:1111" + defaultTestTimeout = 1 * time.Second + defaultTestShortTimeout = 10 * time.Millisecond +) + +var listenerWithRouteConfiguration = &v3listenerpb.Listener{ + FilterChains: []*v3listenerpb.FilterChain{ + { + FilterChainMatch: &v3listenerpb.FilterChainMatch{ + PrefixRanges: []*v3corepb.CidrRange{ + { + AddressPrefix: "192.168.0.0", + PrefixLen: &wrapperspb.UInt32Value{ + Value: uint32(16), + }, + }, + }, + SourceType: v3listenerpb.FilterChainMatch_SAME_IP_OR_LOOPBACK, + SourcePrefixRanges: []*v3corepb.CidrRange{ + { + AddressPrefix: "192.168.0.0", + PrefixLen: &wrapperspb.UInt32Value{ + Value: uint32(16), + }, + }, + }, + SourcePorts: []uint32{80}, + }, + Filters: []*v3listenerpb.Filter{ + { + Name: "filter-1", + ConfigType: &v3listenerpb.Filter_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3httppb.HttpConnectionManager{ + RouteSpecifier: &v3httppb.HttpConnectionManager_Rds{ + Rds: &v3httppb.Rds{ + 
ConfigSource: &v3corepb.ConfigSource{ + ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{Ads: &v3corepb.AggregatedConfigSource{}}, + }, + RouteConfigName: "route-1", + }, + }, + HttpFilters: []*v3httppb.HttpFilter{e2e.RouterHTTPFilter}, + }), + }, + }, + }, + }, + }, +} + +var listenerWithFilterChains = &v3listenerpb.Listener{ + FilterChains: []*v3listenerpb.FilterChain{ + { + FilterChainMatch: &v3listenerpb.FilterChainMatch{ + PrefixRanges: []*v3corepb.CidrRange{ + { + AddressPrefix: "192.168.0.0", + PrefixLen: &wrapperspb.UInt32Value{ + Value: uint32(16), + }, + }, + }, + SourceType: v3listenerpb.FilterChainMatch_SAME_IP_OR_LOOPBACK, + SourcePrefixRanges: []*v3corepb.CidrRange{ + { + AddressPrefix: "192.168.0.0", + PrefixLen: &wrapperspb.UInt32Value{ + Value: uint32(16), + }, + }, + }, + SourcePorts: []uint32{80}, + }, + TransportSocket: &v3corepb.TransportSocket{ + Name: "envoy.transport_sockets.tls", + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3tlspb.DownstreamTlsContext{ + CommonTlsContext: &v3tlspb.CommonTlsContext{ + TlsCertificateCertificateProviderInstance: &v3tlspb.CommonTlsContext_CertificateProviderInstance{ + InstanceName: "identityPluginInstance", + CertificateName: "identityCertName", + }, + }, + }), + }, + }, + Filters: []*v3listenerpb.Filter{ + { + Name: "filter-1", + ConfigType: &v3listenerpb.Filter_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3httppb.HttpConnectionManager{ + RouteSpecifier: &v3httppb.HttpConnectionManager_RouteConfig{ + RouteConfig: &v3routepb.RouteConfiguration{ + Name: "routeName", + VirtualHosts: []*v3routepb.VirtualHost{{ + Domains: []string{"lds.target.good:3333"}, + Routes: []*v3routepb.Route{{ + Match: &v3routepb.RouteMatch{ + PathSpecifier: &v3routepb.RouteMatch_Prefix{Prefix: "/"}, + }, + Action: &v3routepb.Route_NonForwardingAction{}, + }}}}}, + }, + HttpFilters: []*v3httppb.HttpFilter{e2e.RouterHTTPFilter}, + }), + }, + }, + }, + }, + }, +} + +type s struct { + 
grpctest.Tester +} + +func Test(t *testing.T) { + grpctest.RunSubTests(t, s{}) +} + +type tempError struct{} + +func (tempError) Error() string { + return "listenerWrapper test temporary error" +} + +func (tempError) Temporary() bool { + return true +} + +// connAndErr wraps a net.Conn and an error. +type connAndErr struct { + conn net.Conn + err error +} + +// fakeListener allows the user to inject conns returned by Accept(). +type fakeListener struct { + acceptCh chan connAndErr + closeCh *testutils.Channel +} + +func (fl *fakeListener) Accept() (net.Conn, error) { + cne, ok := <-fl.acceptCh + if !ok { + return nil, errors.New("a non-temporary error") + } + return cne.conn, cne.err +} + +func (fl *fakeListener) Close() error { + fl.closeCh.Send(nil) + return nil +} + +func (fl *fakeListener) Addr() net.Addr { + return &net.TCPAddr{ + IP: net.IPv4(0, 0, 0, 0), + Port: fakeListenerPort, + } +} + +// fakeConn overrides LocalAddr, RemoteAddr and Close methods. +type fakeConn struct { + net.Conn + local, remote net.Addr + closeCh *testutils.Channel +} + +func (fc *fakeConn) LocalAddr() net.Addr { + return fc.local +} + +func (fc *fakeConn) RemoteAddr() net.Addr { + return fc.remote +} + +func (fc *fakeConn) Close() error { + fc.closeCh.Send(nil) + return nil +} + +func newListenerWrapper(t *testing.T) (*listenerWrapper, <-chan struct{}, *fakeclient.Client, *fakeListener, func()) { + t.Helper() + + // Create a listener wrapper with a fake listener and fake XDSClient and + // verify that it extracts the host and port from the passed in listener. 
+ lis := &fakeListener{ + acceptCh: make(chan connAndErr, 1), + closeCh: testutils.NewChannel(), + } + xdsC := fakeclient.NewClient() + lParams := ListenerWrapperParams{ + Listener: lis, + ListenerResourceName: testListenerResourceName, + XDSClient: xdsC, + } + l, readyCh := NewListenerWrapper(lParams) + if l == nil { + t.Fatalf("NewListenerWrapper(%+v) returned nil", lParams) + } + lw, ok := l.(*listenerWrapper) + if !ok { + t.Fatalf("NewListenerWrapper(%+v) returned listener of type %T want *listenerWrapper", lParams, l) + } + if lw.addr != fakeListenerHost || lw.port != strconv.Itoa(fakeListenerPort) { + t.Fatalf("listenerWrapper has host:port %s:%s, want %s:%d", lw.addr, lw.port, fakeListenerHost, fakeListenerPort) + } + return lw, readyCh, xdsC, lis, func() { l.Close() } +} + +func (s) TestNewListenerWrapper(t *testing.T) { + _, readyCh, xdsC, _, cleanup := newListenerWrapper(t) + defer cleanup() + + // Verify that the listener wrapper registers a listener watch for the + // expected Listener resource name. + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + name, err := xdsC.WaitForWatchListener(ctx) + if err != nil { + t.Fatalf("error when waiting for a watch on a Listener resource: %v", err) + } + if name != testListenerResourceName { + t.Fatalf("listenerWrapper registered a lds watch on %s, want %s", name, testListenerResourceName) + } + + // Push an error to the listener update handler. 
+ xdsC.InvokeWatchListenerCallback(xdsclient.ListenerUpdate{}, errors.New("bad listener update")) + timer := time.NewTimer(defaultTestShortTimeout) + select { + case <-timer.C: + timer.Stop() + case <-readyCh: + t.Fatalf("ready channel written to after receipt of a bad Listener update") + } + + fcm, err := xdsclient.NewFilterChainManager(listenerWithFilterChains) + if err != nil { + t.Fatalf("xdsclient.NewFilterChainManager() failed with error: %v", err) + } + + // Push an update whose address does not match the address to which our + // listener is bound, and verify that the ready channel is not written to. + xdsC.InvokeWatchListenerCallback(xdsclient.ListenerUpdate{ + InboundListenerCfg: &xdsclient.InboundListenerConfig{ + Address: "10.0.0.1", + Port: "50051", + FilterChains: fcm, + }}, nil) + timer = time.NewTimer(defaultTestShortTimeout) + select { + case <-timer.C: + timer.Stop() + case <-readyCh: + t.Fatalf("ready channel written to after receipt of a bad Listener update") + } + + // Push a good update, and verify that the ready channel is written to. + // Since there are no dynamic RDS updates needed to be received, the + // ListenerWrapper does not have to wait for anything else before telling + // that it is ready. + xdsC.InvokeWatchListenerCallback(xdsclient.ListenerUpdate{ + InboundListenerCfg: &xdsclient.InboundListenerConfig{ + Address: fakeListenerHost, + Port: strconv.Itoa(fakeListenerPort), + FilterChains: fcm, + }}, nil) + + select { + case <-ctx.Done(): + t.Fatalf("timeout waiting for the ready channel to be written to after receipt of a good Listener update") + case <-readyCh: + } +} + +// TestNewListenerWrapperWithRouteUpdate tests the scenario where the listener +// gets built, starts a watch, that watch returns a list of Route Names to +// return, than receives an update from the rds handler. Only after receiving +// the update from the rds handler should it move the server to +// ServingModeServing. 
+func (s) TestNewListenerWrapperWithRouteUpdate(t *testing.T) { + oldRBAC := env.RBACSupport + env.RBACSupport = true + defer func() { + env.RBACSupport = oldRBAC + }() + _, readyCh, xdsC, _, cleanup := newListenerWrapper(t) + defer cleanup() + + // Verify that the listener wrapper registers a listener watch for the + // expected Listener resource name. + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + name, err := xdsC.WaitForWatchListener(ctx) + if err != nil { + t.Fatalf("error when waiting for a watch on a Listener resource: %v", err) + } + if name != testListenerResourceName { + t.Fatalf("listenerWrapper registered a lds watch on %s, want %s", name, testListenerResourceName) + } + fcm, err := xdsclient.NewFilterChainManager(listenerWithRouteConfiguration) + if err != nil { + t.Fatalf("xdsclient.NewFilterChainManager() failed with error: %v", err) + } + + // Push a good update which contains a Filter Chain that specifies dynamic + // RDS Resources that need to be received. This should ping rds handler + // about which rds names to start, which will eventually start a watch on + // xds client for rds name "route-1". + xdsC.InvokeWatchListenerCallback(xdsclient.ListenerUpdate{ + InboundListenerCfg: &xdsclient.InboundListenerConfig{ + Address: fakeListenerHost, + Port: strconv.Itoa(fakeListenerPort), + FilterChains: fcm, + }}, nil) + + // This should start a watch on xds client for rds name "route-1". + routeName, err := xdsC.WaitForWatchRouteConfig(ctx) + if err != nil { + t.Fatalf("error when waiting for a watch on a Route resource: %v", err) + } + if routeName != "route-1" { + t.Fatalf("listenerWrapper registered a lds watch on %s, want %s", routeName, "route-1") + } + + // This shouldn't invoke good update channel, as has not received rds updates yet. 
+ timer := time.NewTimer(defaultTestShortTimeout) + select { + case <-timer.C: + timer.Stop() + case <-readyCh: + t.Fatalf("ready channel written to without rds configuration specified") + } + + // Invoke rds callback for the started rds watch. This valid rds callback + // should trigger the listener wrapper to fire GoodUpdate, as it has + // received both it's LDS Configuration and also RDS Configuration, + // specified in LDS Configuration. + xdsC.InvokeWatchRouteConfigCallback("route-1", xdsclient.RouteConfigUpdate{}, nil) + + // All of the xDS updates have completed, so can expect to send a ping on + // good update channel. + select { + case <-ctx.Done(): + t.Fatalf("timeout waiting for the ready channel to be written to after receipt of a good rds update") + case <-readyCh: + } +} + +func (s) TestListenerWrapper_Accept(t *testing.T) { + boCh := testutils.NewChannel() + origBackoffFunc := backoffFunc + backoffFunc = func(v int) time.Duration { + boCh.Send(v) + return 0 + } + defer func() { backoffFunc = origBackoffFunc }() + + lw, readyCh, xdsC, lis, cleanup := newListenerWrapper(t) + defer cleanup() + + // Push a good update with a filter chain which accepts local connections on + // 192.168.0.0/16 subnet and port 80. + fcm, err := xdsclient.NewFilterChainManager(listenerWithFilterChains) + if err != nil { + t.Fatalf("xdsclient.NewFilterChainManager() failed with error: %v", err) + } + xdsC.InvokeWatchListenerCallback(xdsclient.ListenerUpdate{ + InboundListenerCfg: &xdsclient.InboundListenerConfig{ + Address: fakeListenerHost, + Port: strconv.Itoa(fakeListenerPort), + FilterChains: fcm, + }}, nil) + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + defer close(lis.acceptCh) + select { + case <-ctx.Done(): + t.Fatalf("timeout waiting for the ready channel to be written to after receipt of a good Listener update") + case <-readyCh: + } + + // Push a non-temporary error into Accept(). 
+ nonTempErr := errors.New("a non-temporary error") + lis.acceptCh <- connAndErr{err: nonTempErr} + if _, err := lw.Accept(); err != nonTempErr { + t.Fatalf("listenerWrapper.Accept() returned error: %v, want: %v", err, nonTempErr) + } + + // Invoke Accept() in a goroutine since we expect it to swallow: + // 1. temporary errors returned from the underlying listener + // 2. errors related to finding a matching filter chain for the incoming + // connection. + errCh := testutils.NewChannel() + go func() { + conn, err := lw.Accept() + if err != nil { + errCh.Send(err) + return + } + if _, ok := conn.(*connWrapper); !ok { + errCh.Send(errors.New("listenerWrapper.Accept() returned a Conn of type %T, want *connWrapper")) + return + } + errCh.Send(nil) + }() + + // Push a temporary error into Accept() and verify that it backs off. + lis.acceptCh <- connAndErr{err: tempError{}} + if _, err := boCh.Receive(ctx); err != nil { + t.Fatalf("error when waiting for Accept() to backoff on temporary errors: %v", err) + } + + // Push a fakeConn which does not match any filter chains configured on the + // received Listener resource. Verify that the conn is closed. + fc := &fakeConn{ + local: &net.TCPAddr{IP: net.IPv4(192, 168, 1, 2), Port: 79}, + remote: &net.TCPAddr{IP: net.IPv4(10, 1, 1, 1), Port: 80}, + closeCh: testutils.NewChannel(), + } + lis.acceptCh <- connAndErr{conn: fc} + if _, err := fc.closeCh.Receive(ctx); err != nil { + t.Fatalf("error when waiting for conn to be closed on no filter chain match: %v", err) + } + + // Push a fakeConn which matches the filter chains configured on the + // received Listener resource. Verify that Accept() returns. 
+ fc = &fakeConn{ + local: &net.TCPAddr{IP: net.IPv4(192, 168, 1, 2)}, + remote: &net.TCPAddr{IP: net.IPv4(192, 168, 1, 2), Port: 80}, + closeCh: testutils.NewChannel(), + } + lis.acceptCh <- connAndErr{conn: fc} + if _, err := errCh.Receive(ctx); err != nil { + t.Fatalf("error when waiting for Accept() to return the conn on filter chain match: %v", err) + } +} diff --git a/xds/internal/server/rds_handler.go b/xds/internal/server/rds_handler.go new file mode 100644 index 00000000000..cc676c4ca05 --- /dev/null +++ b/xds/internal/server/rds_handler.go @@ -0,0 +1,133 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package server + +import ( + "sync" + + "google.golang.org/grpc/xds/internal/xdsclient" +) + +// rdsHandlerUpdate wraps the full RouteConfigUpdate that are dynamically +// queried for a given server side listener. +type rdsHandlerUpdate struct { + updates map[string]xdsclient.RouteConfigUpdate + err error +} + +// rdsHandler handles any RDS queries that need to be started for a given server +// side listeners Filter Chains (i.e. not inline). +type rdsHandler struct { + xdsC XDSClient + + mu sync.Mutex + updates map[string]xdsclient.RouteConfigUpdate + cancels map[string]func() + + // For a rdsHandler update, the only update wrapped listener cares about is + // most recent one, so this channel will be opportunistically drained before + // sending any new updates. 
+ updateChannel chan rdsHandlerUpdate +} + +// newRDSHandler creates a new rdsHandler to watch for RDS resources. +// listenerWrapper updates the list of route names to watch by calling +// updateRouteNamesToWatch() upon receipt of new Listener configuration. +func newRDSHandler(xdsC XDSClient, ch chan rdsHandlerUpdate) *rdsHandler { + return &rdsHandler{ + xdsC: xdsC, + updateChannel: ch, + updates: make(map[string]xdsclient.RouteConfigUpdate), + cancels: make(map[string]func()), + } +} + +// updateRouteNamesToWatch handles a list of route names to watch for a given +// server side listener (if a filter chain specifies dynamic RDS configuration). +// This function handles all the logic with respect to any routes that may have +// been added or deleted as compared to what was previously present. +func (rh *rdsHandler) updateRouteNamesToWatch(routeNamesToWatch map[string]bool) { + rh.mu.Lock() + defer rh.mu.Unlock() + // Add and start watches for any new routes in + // routeNamesToWatch. + for routeName := range routeNamesToWatch { + if _, ok := rh.cancels[routeName]; !ok { + func(routeName string) { + rh.cancels[routeName] = rh.xdsC.WatchRouteConfig(routeName, func(update xdsclient.RouteConfigUpdate, err error) { + rh.handleRouteUpdate(routeName, update, err) + }) + }(routeName) + } + } + + // Delete and cancel watches for any routes from persisted routeNamesToWatch + // that are no longer present. + for routeName := range rh.cancels { + if _, ok := routeNamesToWatch[routeName]; !ok { + rh.cancels[routeName]() + delete(rh.cancels, routeName) + delete(rh.updates, routeName) + } + } + + // If the full list (determined by length) of updates is now successfully + // updated, the listener is ready to be updated. 
+ if len(rh.updates) == len(rh.cancels) && len(routeNamesToWatch) != 0 { + drainAndPush(rh.updateChannel, rdsHandlerUpdate{updates: rh.updates}) + } +} + +// handleRouteUpdate persists the route config for a given route name, and also +// sends an update to the Listener Wrapper on an error received or if the rds +// handler has a full collection of updates. +func (rh *rdsHandler) handleRouteUpdate(routeName string, update xdsclient.RouteConfigUpdate, err error) { + if err != nil { + drainAndPush(rh.updateChannel, rdsHandlerUpdate{err: err}) + return + } + rh.mu.Lock() + defer rh.mu.Unlock() + rh.updates[routeName] = update + + // If the full list (determined by length) of updates have successfully + // updated, the listener is ready to be updated. + if len(rh.updates) == len(rh.cancels) { + drainAndPush(rh.updateChannel, rdsHandlerUpdate{updates: rh.updates}) + } +} + +func drainAndPush(ch chan rdsHandlerUpdate, update rdsHandlerUpdate) { + select { + case <-ch: + default: + } + ch <- update +} + +// close() is meant to be called by wrapped listener when the wrapped listener +// is closed, and it cleans up resources by canceling all the active RDS +// watches. +func (rh *rdsHandler) close() { + rh.mu.Lock() + defer rh.mu.Unlock() + for _, cancel := range rh.cancels { + cancel() + } +} diff --git a/xds/internal/server/rds_handler_test.go b/xds/internal/server/rds_handler_test.go new file mode 100644 index 00000000000..d1daffd940c --- /dev/null +++ b/xds/internal/server/rds_handler_test.go @@ -0,0 +1,401 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package server + +import ( + "context" + "errors" + "fmt" + "testing" + + "github.com/google/go-cmp/cmp" + "google.golang.org/grpc/xds/internal/testutils/fakeclient" + "google.golang.org/grpc/xds/internal/xdsclient" +) + +const ( + route1 = "route1" + route2 = "route2" + route3 = "route3" +) + +// setupTests creates a rds handler with a fake xds client for control over the +// xds client. +func setupTests() (*rdsHandler, *fakeclient.Client, chan rdsHandlerUpdate) { + xdsC := fakeclient.NewClient() + ch := make(chan rdsHandlerUpdate, 1) + rh := newRDSHandler(xdsC, ch) + return rh, xdsC, ch +} + +// waitForFuncWithNames makes sure that a blocking function returns the correct +// set of names, where order doesn't matter. This takes away nondeterminism from +// ranging through a map. +func waitForFuncWithNames(ctx context.Context, f func(context.Context) (string, error), names ...string) error { + wantNames := make(map[string]bool, len(names)) + for _, name := range names { + wantNames[name] = true + } + gotNames := make(map[string]bool, len(names)) + for range wantNames { + name, err := f(ctx) + if err != nil { + return err + } + gotNames[name] = true + } + if !cmp.Equal(gotNames, wantNames) { + return fmt.Errorf("got routeNames %v, want %v", gotNames, wantNames) + } + return nil +} + +// TestSuccessCaseOneRDSWatch tests the simplest scenario: the rds handler +// receives a single route name, starts a watch for that route name, gets a +// successful update, and then writes an update to the update channel for +// listener to pick up. 
+func (s) TestSuccessCaseOneRDSWatch(t *testing.T) { + rh, fakeClient, ch := setupTests() + // When you first update the rds handler with a list of a single Route names + // that needs dynamic RDS Configuration, this Route name has not been seen + // before, so the RDS Handler should start a watch on that RouteName. + rh.updateRouteNamesToWatch(map[string]bool{route1: true}) + // The RDS Handler should start a watch for that routeName. + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + gotRoute, err := fakeClient.WaitForWatchRouteConfig(ctx) + if err != nil { + t.Fatalf("xdsClient.WatchRDS failed with error: %v", err) + } + if gotRoute != route1 { + t.Fatalf("xdsClient.WatchRDS called for route: %v, want %v", gotRoute, route1) + } + rdsUpdate := xdsclient.RouteConfigUpdate{} + // Invoke callback with the xds client with a certain route update. Due to + // this route update updating every route name that rds handler handles, + // this should write to the update channel to send to the listener. + fakeClient.InvokeWatchRouteConfigCallback(route1, rdsUpdate, nil) + rhuWant := map[string]xdsclient.RouteConfigUpdate{route1: rdsUpdate} + select { + case rhu := <-ch: + if diff := cmp.Diff(rhu.updates, rhuWant); diff != "" { + t.Fatalf("got unexpected route update, diff (-got, +want): %v", diff) + } + case <-ctx.Done(): + t.Fatal("Timed out waiting for update from update channel.") + } + // Close the rds handler. This is meant to be called when the lis wrapper is + // closed, and the call should cancel all the watches present (for this + // test, a single watch). 
+ rh.close() + routeNameDeleted, err := fakeClient.WaitForCancelRouteConfigWatch(ctx) + if err != nil { + t.Fatalf("xdsClient.CancelRDS failed with error: %v", err) + } + if routeNameDeleted != route1 { + t.Fatalf("xdsClient.CancelRDS called for route %v, want %v", routeNameDeleted, route1) + } +} + +// TestSuccessCaseTwoUpdates tests the case where the rds handler receives an +// update with a single Route, then receives a second update with two routes. +// The handler should start a watch for the added route, and if received a RDS +// update for that route it should send an update with both RDS updates present. +func (s) TestSuccessCaseTwoUpdates(t *testing.T) { + rh, fakeClient, ch := setupTests() + + rh.updateRouteNamesToWatch(map[string]bool{route1: true}) + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + gotRoute, err := fakeClient.WaitForWatchRouteConfig(ctx) + if err != nil { + t.Fatalf("xdsClient.WatchRDS failed with error: %v", err) + } + if gotRoute != route1 { + t.Fatalf("xdsClient.WatchRDS called for route: %v, want %v", gotRoute, route1) + } + + // Update the RDSHandler with route names which adds a route name to watch. + // This should trigger the RDSHandler to start a watch for the added route + // name to watch. + rh.updateRouteNamesToWatch(map[string]bool{route1: true, route2: true}) + gotRoute, err = fakeClient.WaitForWatchRouteConfig(ctx) + if err != nil { + t.Fatalf("xdsClient.WatchRDS failed with error: %v", err) + } + if gotRoute != route2 { + t.Fatalf("xdsClient.WatchRDS called for route: %v, want %v", gotRoute, route2) + } + + // Invoke the callback with an update for route 1. This shouldn't cause the + // handler to write an update, as it has not received RouteConfigurations + // for every RouteName. + rdsUpdate1 := xdsclient.RouteConfigUpdate{} + fakeClient.InvokeWatchRouteConfigCallback(route1, rdsUpdate1, nil) + + // The RDS Handler should not send an update. 
+ sCtx, sCtxCancel := context.WithTimeout(context.Background(), defaultTestShortTimeout) + defer sCtxCancel() + select { + case <-ch: + t.Fatal("RDS Handler wrote an update to updateChannel when it shouldn't have, as each route name has not received an update yet") + case <-sCtx.Done(): + } + + // Invoke the callback with an update for route 2. This should cause the + // handler to write an update, as it has received RouteConfigurations for + // every RouteName. + rdsUpdate2 := xdsclient.RouteConfigUpdate{} + fakeClient.InvokeWatchRouteConfigCallback(route2, rdsUpdate2, nil) + // The RDS Handler should then update the listener wrapper with an update + // with two route configurations, as both route names the RDS Handler handles + // have received an update. + rhuWant := map[string]xdsclient.RouteConfigUpdate{route1: rdsUpdate1, route2: rdsUpdate2} + select { + case rhu := <-ch: + if diff := cmp.Diff(rhu.updates, rhuWant); diff != "" { + t.Fatalf("got unexpected route update, diff (-got, +want): %v", diff) + } + case <-ctx.Done(): + t.Fatal("Timed out waiting for the rds handler update to be written to the update buffer.") + } + + // Close the rds handler. This is meant to be called when the lis wrapper is + // closed, and the call should cancel all the watches present (for this + // test, two watches on route1 and route2). + rh.close() + if err = waitForFuncWithNames(ctx, fakeClient.WaitForCancelRouteConfigWatch, route1, route2); err != nil { + t.Fatalf("Error while waiting for names: %v", err) + } +} + +// TestSuccessCaseDeletedRoute tests the case where the rds handler receives an +// update with two routes, then receives an update with only one route. The RDS +// Handler is expected to cancel the watch for the route no longer present. 
+func (s) TestSuccessCaseDeletedRoute(t *testing.T) { + rh, fakeClient, ch := setupTests() + + rh.updateRouteNamesToWatch(map[string]bool{route1: true, route2: true}) + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + // Will start two watches. + if err := waitForFuncWithNames(ctx, fakeClient.WaitForWatchRouteConfig, route1, route2); err != nil { + t.Fatalf("Error while waiting for names: %v", err) + } + + // Update the RDSHandler with route names which deletes a route name to + // watch. This should trigger the RDSHandler to cancel the watch for the + // deleted route name to watch. + rh.updateRouteNamesToWatch(map[string]bool{route1: true}) + // This should delete the watch for route2. + routeNameDeleted, err := fakeClient.WaitForCancelRouteConfigWatch(ctx) + if err != nil { + t.Fatalf("xdsClient.CancelRDS failed with error %v", err) + } + if routeNameDeleted != route2 { + t.Fatalf("xdsClient.CancelRDS called for route %v, want %v", routeNameDeleted, route2) + } + + rdsUpdate := xdsclient.RouteConfigUpdate{} + // Invoke callback with the xds client with a certain route update. Due to + // this route update updating every route name that rds handler handles, + // this should write to the update channel to send to the listener. 
+ fakeClient.InvokeWatchRouteConfigCallback(route1, rdsUpdate, nil) + rhuWant := map[string]xdsclient.RouteConfigUpdate{route1: rdsUpdate} + select { + case rhu := <-ch: + if diff := cmp.Diff(rhu.updates, rhuWant); diff != "" { + t.Fatalf("got unexpected route update, diff (-got, +want): %v", diff) + } + case <-ctx.Done(): + t.Fatal("Timed out waiting for update from update channel.") + } + + rh.close() + routeNameDeleted, err = fakeClient.WaitForCancelRouteConfigWatch(ctx) + if err != nil { + t.Fatalf("xdsClient.CancelRDS failed with error: %v", err) + } + if routeNameDeleted != route1 { + t.Fatalf("xdsClient.CancelRDS called for route %v, want %v", routeNameDeleted, route1) + } +} + +// TestSuccessCaseTwoUpdatesAddAndDeleteRoute tests the case where the rds +// handler receives an update with two routes, and then receives an update with +// two routes, one previously there and one added (i.e. 12 -> 23). This should +// cause the route that is no longer there to be deleted and cancelled, and the +// route that was added should have a watch started for it. +func (s) TestSuccessCaseTwoUpdatesAddAndDeleteRoute(t *testing.T) { + rh, fakeClient, ch := setupTests() + + rh.updateRouteNamesToWatch(map[string]bool{route1: true, route2: true}) + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + if err := waitForFuncWithNames(ctx, fakeClient.WaitForWatchRouteConfig, route1, route2); err != nil { + t.Fatalf("Error while waiting for names: %v", err) + } + + // Update the rds handler with two routes, one which was already there and a new route. + // This should cause the rds handler to delete/cancel watch for route 1 and start a watch + // for route 3. + rh.updateRouteNamesToWatch(map[string]bool{route2: true, route3: true}) + + // Start watch comes first, which should be for route3 as was just added. 
+ gotRoute, err := fakeClient.WaitForWatchRouteConfig(ctx) + if err != nil { + t.Fatalf("xdsClient.WatchRDS failed with error: %v", err) + } + if gotRoute != route3 { + t.Fatalf("xdsClient.WatchRDS called for route: %v, want %v", gotRoute, route3) + } + + // Then route 1 should be deleted/cancelled watch for, as it is no longer present + // in the new RouteName to watch map. + routeNameDeleted, err := fakeClient.WaitForCancelRouteConfigWatch(ctx) + if err != nil { + t.Fatalf("xdsClient.CancelRDS failed with error: %v", err) + } + if routeNameDeleted != route1 { + t.Fatalf("xdsClient.CancelRDS called for route %v, want %v", routeNameDeleted, route1) + } + + // Invoke the callback with an update for route 2. This shouldn't cause the + // handler to write an update, as it has not received RouteConfigurations + // for every RouteName. + rdsUpdate2 := xdsclient.RouteConfigUpdate{} + fakeClient.InvokeWatchRouteConfigCallback(route2, rdsUpdate2, nil) + + // The RDS Handler should not send an update. + sCtx, sCtxCancel := context.WithTimeout(context.Background(), defaultTestShortTimeout) + defer sCtxCancel() + select { + case <-ch: + t.Fatalf("RDS Handler wrote an update to updateChannel when it shouldn't have, as each route name has not received an update yet") + case <-sCtx.Done(): + } + + // Invoke the callback with an update for route 3. This should cause the + // handler to write an update, as it has received RouteConfigurations for + // every RouteName. + rdsUpdate3 := xdsclient.RouteConfigUpdate{} + fakeClient.InvokeWatchRouteConfigCallback(route3, rdsUpdate3, nil) + // The RDS Handler should then update the listener wrapper with an update + // with two route configurations, as both route names the RDS Handler handles + // have received an update. 
+ rhuWant := map[string]xdsclient.RouteConfigUpdate{route2: rdsUpdate2, route3: rdsUpdate3} + select { + case rhu := <-rh.updateChannel: + if diff := cmp.Diff(rhu.updates, rhuWant); diff != "" { + t.Fatalf("got unexpected route update, diff (-got, +want): %v", diff) + } + case <-ctx.Done(): + t.Fatal("Timed out waiting for the rds handler update to be written to the update buffer.") + } + // Close the rds handler. This is meant to be called when the lis wrapper is + // closed, and the call should cancel all the watches present (for this + // test, two watches on route2 and route3). + rh.close() + if err = waitForFuncWithNames(ctx, fakeClient.WaitForCancelRouteConfigWatch, route2, route3); err != nil { + t.Fatalf("Error while waiting for names: %v", err) + } +} + +// TestSuccessCaseSecondUpdateMakesRouteFull tests the scenario where the rds handler gets +// told to watch three rds configurations, gets two successful updates, then gets told to watch +// only those two. The rds handler should then write an update to update buffer. +func (s) TestSuccessCaseSecondUpdateMakesRouteFull(t *testing.T) { + rh, fakeClient, ch := setupTests() + + rh.updateRouteNamesToWatch(map[string]bool{route1: true, route2: true, route3: true}) + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + if err := waitForFuncWithNames(ctx, fakeClient.WaitForWatchRouteConfig, route1, route2, route3); err != nil { + t.Fatalf("Error while waiting for names: %v", err) + } + + // Invoke the callbacks for two of the three watches. Since RDS is not full, + // this shouldn't trigger rds handler to write an update to update buffer. + fakeClient.InvokeWatchRouteConfigCallback(route1, xdsclient.RouteConfigUpdate{}, nil) + fakeClient.InvokeWatchRouteConfigCallback(route2, xdsclient.RouteConfigUpdate{}, nil) + + // The RDS Handler should not send an update. 
+	sCtx, sCtxCancel := context.WithTimeout(context.Background(), defaultTestShortTimeout)
+	defer sCtxCancel()
+	select {
+	case <-rh.updateChannel:
+		t.Fatalf("RDS Handler wrote an update to updateChannel when it shouldn't have, as each route name has not received an update yet")
+	case <-sCtx.Done():
+	}
+
+	// Tell the rds handler to now only watch Route 1 and Route 2. This should
+	// trigger the rds handler to write an update to the update buffer as it now
+	// has full rds configuration.
+	rh.updateRouteNamesToWatch(map[string]bool{route1: true, route2: true})
+	// Route 3 should be deleted/cancelled watch for, as it is no longer present
+	// in the new RouteName to watch map.
+	routeNameDeleted, err := fakeClient.WaitForCancelRouteConfigWatch(ctx)
+	if err != nil {
+		t.Fatalf("xdsClient.CancelRDS failed with error: %v", err)
+	}
+	if routeNameDeleted != route3 {
+		t.Fatalf("xdsClient.CancelRDS called for route %v, want %v", routeNameDeleted, route3)
+	}
+	rhuWant := map[string]xdsclient.RouteConfigUpdate{route1: {}, route2: {}}
+	select {
+	case rhu := <-ch:
+		if diff := cmp.Diff(rhu.updates, rhuWant); diff != "" {
+			t.Fatalf("got unexpected route update, diff (-got, +want): %v", diff)
+		}
+	case <-ctx.Done():
+		t.Fatal("Timed out waiting for the rds handler update to be written to the update buffer.")
+	}
+}
+
+// TestErrorReceived tests the case where the rds handler receives a route name
+// to watch, then receives an update with an error. This error should be then
+// written to the update channel.
+func (s) TestErrorReceived(t *testing.T) { + rh, fakeClient, ch := setupTests() + + rh.updateRouteNamesToWatch(map[string]bool{route1: true}) + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + gotRoute, err := fakeClient.WaitForWatchRouteConfig(ctx) + if err != nil { + t.Fatalf("xdsClient.WatchRDS failed with error %v", err) + } + if gotRoute != route1 { + t.Fatalf("xdsClient.WatchRDS called for route: %v, want %v", gotRoute, route1) + } + + rdsErr := errors.New("some error") + fakeClient.InvokeWatchRouteConfigCallback(route1, xdsclient.RouteConfigUpdate{}, rdsErr) + select { + case rhu := <-ch: + if rhu.err.Error() != "some error" { + t.Fatalf("Did not receive the expected error, instead received: %v", rhu.err.Error()) + } + case <-ctx.Done(): + t.Fatal("Timed out waiting for update from update channel") + } +} diff --git a/xds/internal/test/e2e/README.md b/xds/internal/test/e2e/README.md new file mode 100644 index 00000000000..33cffa0da56 --- /dev/null +++ b/xds/internal/test/e2e/README.md @@ -0,0 +1,19 @@ +Build client and server binaries. + +```sh +go build -o ./binaries/client ../../../../interop/xds/client/ +go build -o ./binaries/server ../../../../interop/xds/server/ +``` + +Run the test + +```sh +go test . -v +``` + +The client/server paths are flags + +```sh +go test . -v -client=$HOME/grpc-java/interop-testing/build/install/grpc-interop-testing/bin/xds-test-client +``` +Note that grpc logs are only turned on for Go. diff --git a/xds/internal/test/e2e/controlplane.go b/xds/internal/test/e2e/controlplane.go new file mode 100644 index 00000000000..247991b83d3 --- /dev/null +++ b/xds/internal/test/e2e/controlplane.go @@ -0,0 +1,62 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package e2e + +import ( + "fmt" + + "github.com/google/uuid" + xdsinternal "google.golang.org/grpc/internal/xds" + "google.golang.org/grpc/xds/internal/testutils/e2e" +) + +type controlPlane struct { + server *e2e.ManagementServer + nodeID string + bootstrapContent string +} + +func newControlPlane(testName string) (*controlPlane, error) { + // Spin up an xDS management server on a local port. + server, err := e2e.StartManagementServer() + if err != nil { + return nil, fmt.Errorf("failed to spin up the xDS management server: %v", err) + } + + nodeID := uuid.New().String() + bootstrapContentBytes, err := xdsinternal.BootstrapContents(xdsinternal.BootstrapOptions{ + Version: xdsinternal.TransportV3, + NodeID: nodeID, + ServerURI: server.Address, + ServerListenerResourceNameTemplate: e2e.ServerListenerResourceNameTemplate, + }) + if err != nil { + server.Stop() + return nil, fmt.Errorf("failed to create bootstrap file: %v", err) + } + + return &controlPlane{ + server: server, + nodeID: nodeID, + bootstrapContent: string(bootstrapContentBytes), + }, nil +} + +func (cp *controlPlane) stop() { + cp.server.Stop() +} diff --git a/xds/internal/test/e2e/e2e.go b/xds/internal/test/e2e/e2e.go new file mode 100644 index 00000000000..ade6339bf53 --- /dev/null +++ b/xds/internal/test/e2e/e2e.go @@ -0,0 +1,178 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package e2e implements xds e2e tests using go-control-plane. +package e2e + +import ( + "context" + "fmt" + "io" + "os" + "os/exec" + + "google.golang.org/grpc" + channelzgrpc "google.golang.org/grpc/channelz/grpc_channelz_v1" + channelzpb "google.golang.org/grpc/channelz/grpc_channelz_v1" + testgrpc "google.golang.org/grpc/interop/grpc_testing" + testpb "google.golang.org/grpc/interop/grpc_testing" +) + +func cmd(path string, logger io.Writer, args []string, env []string) (*exec.Cmd, error) { + cmd := exec.Command(path, args...) + cmd.Env = append(os.Environ(), env...) + cmd.Stdout = logger + cmd.Stderr = logger + return cmd, nil +} + +const ( + clientStatsPort = 60363 // TODO: make this different per-test, only needed for parallel tests. +) + +type client struct { + cmd *exec.Cmd + + target string + statsCC *grpc.ClientConn +} + +// newClient create a client with the given target and bootstrap content. +func newClient(target, binaryPath, bootstrap string, logger io.Writer, flags ...string) (*client, error) { + cmd, err := cmd( + binaryPath, + logger, + append([]string{ + "--server=" + target, + "--print_response=true", + "--qps=100", + fmt.Sprintf("--stats_port=%d", clientStatsPort), + }, flags...), // Append any flags from caller. + []string{ + "GRPC_GO_LOG_VERBOSITY_LEVEL=99", + "GRPC_GO_LOG_SEVERITY_LEVEL=info", + "GRPC_XDS_BOOTSTRAP_CONFIG=" + bootstrap, // The bootstrap content doesn't need to be quoted. 
+ }, + ) + if err != nil { + return nil, fmt.Errorf("failed to run client cmd: %v", err) + } + cmd.Start() + + cc, err := grpc.Dial(fmt.Sprintf("localhost:%d", clientStatsPort), grpc.WithInsecure(), grpc.WithDefaultCallOptions(grpc.WaitForReady(true))) + if err != nil { + return nil, err + } + return &client{ + cmd: cmd, + target: target, + statsCC: cc, + }, nil +} + +func (c *client) clientStats(ctx context.Context) (*testpb.LoadBalancerStatsResponse, error) { + ccc := testgrpc.NewLoadBalancerStatsServiceClient(c.statsCC) + return ccc.GetClientStats(ctx, &testpb.LoadBalancerStatsRequest{ + NumRpcs: 100, + TimeoutSec: 10, + }) +} + +func (c *client) configRPCs(ctx context.Context, req *testpb.ClientConfigureRequest) error { + ccc := testgrpc.NewXdsUpdateClientConfigureServiceClient(c.statsCC) + _, err := ccc.Configure(ctx, req) + return err +} + +func (c *client) channelzSubChannels(ctx context.Context) ([]*channelzpb.Subchannel, error) { + ccc := channelzgrpc.NewChannelzClient(c.statsCC) + r, err := ccc.GetTopChannels(ctx, &channelzpb.GetTopChannelsRequest{}) + if err != nil { + return nil, err + } + + var ret []*channelzpb.Subchannel + for _, cc := range r.Channel { + if cc.Data.Target != c.target { + continue + } + for _, sc := range cc.SubchannelRef { + rr, err := ccc.GetSubchannel(ctx, &channelzpb.GetSubchannelRequest{SubchannelId: sc.SubchannelId}) + if err != nil { + return nil, err + } + ret = append(ret, rr.Subchannel) + } + } + return ret, nil +} + +func (c *client) stop() { + c.cmd.Process.Kill() + c.cmd.Wait() +} + +const ( + serverPort = 50051 // TODO: make this different per-test, only needed for parallel tests. +) + +type server struct { + cmd *exec.Cmd + port int +} + +// newServer creates multiple servers with the given bootstrap content. +// +// Each server gets a different hostname, in the format of +// -. 
+func newServers(hostnamePrefix, binaryPath, bootstrap string, logger io.Writer, count int) (_ []*server, err error) { + var ret []*server + defer func() { + if err != nil { + for _, s := range ret { + s.stop() + } + } + }() + for i := 0; i < count; i++ { + port := serverPort + i + cmd, err := cmd( + binaryPath, + logger, + []string{ + fmt.Sprintf("--port=%d", port), + fmt.Sprintf("--host_name_override=%s-%d", hostnamePrefix, i), + }, + []string{ + "GRPC_GO_LOG_VERBOSITY_LEVEL=99", + "GRPC_GO_LOG_SEVERITY_LEVEL=info", + "GRPC_XDS_BOOTSTRAP_CONFIG=" + bootstrap, // The bootstrap content doesn't need to be quoted., + }, + ) + if err != nil { + return nil, fmt.Errorf("failed to run server cmd: %v", err) + } + cmd.Start() + ret = append(ret, &server{cmd: cmd, port: port}) + } + return ret, nil +} + +func (s *server) stop() { + s.cmd.Process.Kill() + s.cmd.Wait() +} diff --git a/xds/internal/test/e2e/e2e_test.go b/xds/internal/test/e2e/e2e_test.go new file mode 100644 index 00000000000..6984566db2e --- /dev/null +++ b/xds/internal/test/e2e/e2e_test.go @@ -0,0 +1,257 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package e2e + +import ( + "bytes" + "context" + "flag" + "fmt" + "os" + "strconv" + "testing" + "time" + + v3clusterpb "github.com/envoyproxy/go-control-plane/envoy/config/cluster/v3" + v3routepb "github.com/envoyproxy/go-control-plane/envoy/config/route/v3" + channelzpb "google.golang.org/grpc/channelz/grpc_channelz_v1" + testpb "google.golang.org/grpc/interop/grpc_testing" + "google.golang.org/grpc/xds/internal/testutils/e2e" +) + +var ( + clientPath = flag.String("client", "./binaries/client", "The interop client") + serverPath = flag.String("server", "./binaries/server", "The interop server") +) + +type testOpts struct { + testName string + backendCount int + clientFlags []string +} + +func setup(t *testing.T, opts testOpts) (*controlPlane, *client, []*server) { + t.Helper() + if _, err := os.Stat(*clientPath); os.IsNotExist(err) { + t.Skip("skipped because client is not found") + } + if _, err := os.Stat(*serverPath); os.IsNotExist(err) { + t.Skip("skipped because server is not found") + } + backendCount := 1 + if opts.backendCount != 0 { + backendCount = opts.backendCount + } + + cp, err := newControlPlane(opts.testName) + if err != nil { + t.Fatalf("failed to start control-plane: %v", err) + } + t.Cleanup(cp.stop) + + var clientLog bytes.Buffer + c, err := newClient(fmt.Sprintf("xds:///%s", opts.testName), *clientPath, cp.bootstrapContent, &clientLog, opts.clientFlags...) + if err != nil { + t.Fatalf("failed to start client: %v", err) + } + t.Cleanup(c.stop) + + var serverLog bytes.Buffer + servers, err := newServers(opts.testName, *serverPath, cp.bootstrapContent, &serverLog, backendCount) + if err != nil { + t.Fatalf("failed to start server: %v", err) + } + t.Cleanup(func() { + for _, s := range servers { + s.stop() + } + }) + t.Cleanup(func() { + // TODO: find a better way to print the log. They are long, and hide the failure. 
+ t.Logf("\n----- client logs -----\n%v", clientLog.String()) + t.Logf("\n----- server logs -----\n%v", serverLog.String()) + }) + return cp, c, servers +} + +func TestPingPong(t *testing.T) { + const testName = "pingpong" + cp, c, _ := setup(t, testOpts{testName: testName}) + + resources := e2e.DefaultClientResources(e2e.ResourceParams{ + DialTarget: testName, + NodeID: cp.nodeID, + Host: "localhost", + Port: serverPort, + SecLevel: e2e.SecurityLevelNone, + }) + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + if err := cp.server.Update(ctx, resources); err != nil { + t.Fatalf("failed to update control plane resources: %v", err) + } + + st, err := c.clientStats(ctx) + if err != nil { + t.Fatalf("failed to get client stats: %v", err) + } + if st.NumFailures != 0 { + t.Fatalf("Got %v failures: %+v", st.NumFailures, st) + } +} + +// TestAffinity covers the affinity tests with ringhash policy. +// - client is configured to use ringhash, with 3 backends +// - all RPCs will hash a specific metadata header +// - verify that +// - all RPCs with the same metadata value are sent to the same backend +// - only one backend is Ready +// - send more RPCs with different metadata values until a new backend is picked, and verify that +// - only two backends are in Ready +func TestAffinity(t *testing.T) { + const ( + testName = "affinity" + backendCount = 3 + testMDKey = "xds_md" + testMDValue = "unary_yranu" + ) + cp, c, servers := setup(t, testOpts{ + testName: testName, + backendCount: backendCount, + clientFlags: []string{"--rpc=EmptyCall", fmt.Sprintf("--metadata=EmptyCall:%s:%s", testMDKey, testMDValue)}, + }) + + resources := e2e.DefaultClientResources(e2e.ResourceParams{ + DialTarget: testName, + NodeID: cp.nodeID, + Host: "localhost", + Port: serverPort, + SecLevel: e2e.SecurityLevelNone, + }) + + // Update EDS to multiple backends. 
+ var ports []uint32 + for _, s := range servers { + ports = append(ports, uint32(s.port)) + } + edsMsg := resources.Endpoints[0] + resources.Endpoints[0] = e2e.DefaultEndpoint( + edsMsg.ClusterName, + "localhost", + ports, + ) + + // Update CDS lbpolicy to ringhash. + cdsMsg := resources.Clusters[0] + cdsMsg.LbPolicy = v3clusterpb.Cluster_RING_HASH + + // Update RDS to hash the header. + rdsMsg := resources.Routes[0] + rdsMsg.VirtualHosts[0].Routes[0].Action = &v3routepb.Route_Route{Route: &v3routepb.RouteAction{ + ClusterSpecifier: &v3routepb.RouteAction_Cluster{Cluster: cdsMsg.Name}, + HashPolicy: []*v3routepb.RouteAction_HashPolicy{{ + PolicySpecifier: &v3routepb.RouteAction_HashPolicy_Header_{ + Header: &v3routepb.RouteAction_HashPolicy_Header{ + HeaderName: testMDKey, + }, + }, + }}, + }} + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + if err := cp.server.Update(ctx, resources); err != nil { + t.Fatalf("failed to update control plane resources: %v", err) + } + + // Note: We can skip CSDS check because there's no long delay as in TD. + // + // The client stats check doesn't race with the xds resource update because + // there's only one version of xds resource, updated at the beginning of the + // test. So there's no need to retry the stats call. + // + // In the future, we may add tests that update xds in the middle. Then we + // either need to retry clientStats(), or make a CSDS check before so the + // result is stable. + + st, err := c.clientStats(ctx) + if err != nil { + t.Fatalf("failed to get client stats: %v", err) + } + if st.NumFailures != 0 { + t.Fatalf("Got %v failures: %+v", st.NumFailures, st) + } + if len(st.RpcsByPeer) != 1 { + t.Fatalf("more than 1 backends got traffic: %v, want 1", st.RpcsByPeer) + } + + // Call channelz to verify that only one subchannel is in state Ready. 
+	scs, err := c.channelzSubChannels(ctx)
+	if err != nil {
+		t.Fatalf("failed to fetch channelz: %v", err)
+	}
+	verifySubConnStates(t, scs, map[channelzpb.ChannelConnectivityState_State]int{
+		channelzpb.ChannelConnectivityState_READY: 1,
+		channelzpb.ChannelConnectivityState_IDLE:  2,
+	})
+
+	// Send Unary call with different metadata value with integers starting from
+	// 0. Stop when a second peer is picked.
+	var (
+		diffPeerPicked bool
+		mdValue        int
+	)
+	for !diffPeerPicked {
+		if err := c.configRPCs(ctx, &testpb.ClientConfigureRequest{
+			Types: []testpb.ClientConfigureRequest_RpcType{
+				testpb.ClientConfigureRequest_EMPTY_CALL,
+				testpb.ClientConfigureRequest_UNARY_CALL,
+			},
+			Metadata: []*testpb.ClientConfigureRequest_Metadata{
+				{Type: testpb.ClientConfigureRequest_EMPTY_CALL, Key: testMDKey, Value: testMDValue},
+				{Type: testpb.ClientConfigureRequest_UNARY_CALL, Key: testMDKey, Value: strconv.Itoa(mdValue)},
+			},
+		}); err != nil {
+			t.Fatalf("failed to configure RPC: %v", err)
+		}
+
+		st, err := c.clientStats(ctx)
+		if err != nil {
+			t.Fatalf("failed to get client stats: %v", err)
+		}
+		if st.NumFailures != 0 {
+			t.Fatalf("Got %v failures: %+v", st.NumFailures, st)
+		}
+		if len(st.RpcsByPeer) == 2 {
+			break
+		}
+
+		mdValue++
+	}
+
+	// Call channelz to verify that, now that a second peer has been picked, two
+	// subchannels are in state Ready.
+ scs2, err := c.channelzSubChannels(ctx) + if err != nil { + t.Fatalf("failed to fetch channelz: %v", err) + } + verifySubConnStates(t, scs2, map[channelzpb.ChannelConnectivityState_State]int{ + channelzpb.ChannelConnectivityState_READY: 2, + channelzpb.ChannelConnectivityState_IDLE: 1, + }) +} diff --git a/internal/credentials/syscallconn_appengine.go b/xds/internal/test/e2e/e2e_utils.go similarity index 50% rename from internal/credentials/syscallconn_appengine.go rename to xds/internal/test/e2e/e2e_utils.go index a6144cd661c..34b0ee9eb09 100644 --- a/internal/credentials/syscallconn_appengine.go +++ b/xds/internal/test/e2e/e2e_utils.go @@ -1,8 +1,6 @@ -// +build appengine - /* * - * Copyright 2018 gRPC authors. + * Copyright 2021 gRPC authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,16 +13,24 @@ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. - * */ -package credentials +package e2e import ( - "net" + "testing" + + "github.com/google/go-cmp/cmp" + channelzpb "google.golang.org/grpc/channelz/grpc_channelz_v1" ) -// WrapSyscallConn returns newConn on appengine. 
-func WrapSyscallConn(rawConn, newConn net.Conn) net.Conn { - return newConn +func verifySubConnStates(t *testing.T, scs []*channelzpb.Subchannel, want map[channelzpb.ChannelConnectivityState_State]int) { + t.Helper() + var scStatsCount = map[channelzpb.ChannelConnectivityState_State]int{} + for _, sc := range scs { + scStatsCount[sc.Data.State.State]++ + } + if diff := cmp.Diff(scStatsCount, want); diff != "" { + t.Fatalf("got unexpected number of subchannels in state Ready, %v, scs: %v", diff, scs) + } } diff --git a/xds/internal/test/e2e/run.sh b/xds/internal/test/e2e/run.sh new file mode 100755 index 00000000000..4363d6cbd94 --- /dev/null +++ b/xds/internal/test/e2e/run.sh @@ -0,0 +1,6 @@ +#!/bin/bash + +mkdir binaries +go build -o ./binaries/client ../../../../interop/xds/client/ +go build -o ./binaries/server ../../../../interop/xds/server/ +go test . -v diff --git a/xds/internal/test/xds_client_affinity_test.go b/xds/internal/test/xds_client_affinity_test.go new file mode 100644 index 00000000000..e9ddfe157b1 --- /dev/null +++ b/xds/internal/test/xds_client_affinity_test.go @@ -0,0 +1,136 @@ +//go:build !386 +// +build !386 + +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +package xds_test + +import ( + "context" + "fmt" + "testing" + + v3clusterpb "github.com/envoyproxy/go-control-plane/envoy/config/cluster/v3" + v3corepb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" + v3routepb "github.com/envoyproxy/go-control-plane/envoy/config/route/v3" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/internal/xds/env" + testpb "google.golang.org/grpc/test/grpc_testing" + "google.golang.org/grpc/xds/internal/testutils/e2e" +) + +const hashHeaderName = "session_id" + +// hashRouteConfig returns a RouteConfig resource with hash policy set to +// header "session_id". +func hashRouteConfig(routeName, ldsTarget, clusterName string) *v3routepb.RouteConfiguration { + return &v3routepb.RouteConfiguration{ + Name: routeName, + VirtualHosts: []*v3routepb.VirtualHost{{ + Domains: []string{ldsTarget}, + Routes: []*v3routepb.Route{{ + Match: &v3routepb.RouteMatch{PathSpecifier: &v3routepb.RouteMatch_Prefix{Prefix: "/"}}, + Action: &v3routepb.Route_Route{Route: &v3routepb.RouteAction{ + ClusterSpecifier: &v3routepb.RouteAction_Cluster{Cluster: clusterName}, + HashPolicy: []*v3routepb.RouteAction_HashPolicy{{ + PolicySpecifier: &v3routepb.RouteAction_HashPolicy_Header_{ + Header: &v3routepb.RouteAction_HashPolicy_Header{ + HeaderName: hashHeaderName, + }, + }, + Terminal: true, + }}, + }}, + }}, + }}, + } +} + +// ringhashCluster returns a Cluster resource that picks ringhash as the lb +// policy. 
+func ringhashCluster(clusterName, edsServiceName string) *v3clusterpb.Cluster { + return &v3clusterpb.Cluster{ + Name: clusterName, + ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_EDS}, + EdsClusterConfig: &v3clusterpb.Cluster_EdsClusterConfig{ + EdsConfig: &v3corepb.ConfigSource{ + ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{ + Ads: &v3corepb.AggregatedConfigSource{}, + }, + }, + ServiceName: edsServiceName, + }, + LbPolicy: v3clusterpb.Cluster_RING_HASH, + } +} + +// TestClientSideAffinitySanityCheck tests that the affinity config can be +// propagated to pick the ring_hash policy. It doesn't test the affinity +// behavior in ring_hash policy. +func (s) TestClientSideAffinitySanityCheck(t *testing.T) { + defer func() func() { + old := env.RingHashSupport + env.RingHashSupport = true + return func() { env.RingHashSupport = old } + }()() + + managementServer, nodeID, _, resolver, cleanup1 := setupManagementServer(t) + defer cleanup1() + + port, cleanup2 := clientSetup(t, &testService{}) + defer cleanup2() + + const serviceName = "my-service-client-side-xds" + resources := e2e.DefaultClientResources(e2e.ResourceParams{ + DialTarget: serviceName, + NodeID: nodeID, + Host: "localhost", + Port: port, + SecLevel: e2e.SecurityLevelNone, + }) + // Replace RDS and CDS resources with ringhash config, but keep the resource + // names. + resources.Routes = []*v3routepb.RouteConfiguration{hashRouteConfig( + resources.Routes[0].Name, + resources.Listeners[0].Name, + resources.Clusters[0].Name, + )} + resources.Clusters = []*v3clusterpb.Cluster{ringhashCluster( + resources.Clusters[0].Name, + resources.Clusters[0].EdsClusterConfig.ServiceName, + )} + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + if err := managementServer.Update(ctx, resources); err != nil { + t.Fatal(err) + } + + // Create a ClientConn and make a successful RPC. 
+ cc, err := grpc.Dial(fmt.Sprintf("xds:///%s", serviceName), grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithResolvers(resolver)) + if err != nil { + t.Fatalf("failed to dial local test server: %v", err) + } + defer cc.Close() + + client := testpb.NewTestServiceClient(cc) + if _, err := client.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(true)); err != nil { + t.Fatalf("rpc EmptyCall() failed: %v", err) + } +} diff --git a/xds/internal/test/xds_client_integration_test.go b/xds/internal/test/xds_client_integration_test.go index c3ea71acaa3..23ea1546935 100644 --- a/xds/internal/test/xds_client_integration_test.go +++ b/xds/internal/test/xds_client_integration_test.go @@ -1,3 +1,4 @@ +//go:build !386 // +build !386 /* @@ -22,51 +23,35 @@ package xds_test import ( "context" + "fmt" "net" "testing" - "github.com/google/uuid" "google.golang.org/grpc" + "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/internal/stubserver" + "google.golang.org/grpc/internal/xds/env" + "google.golang.org/grpc/status" "google.golang.org/grpc/xds/internal/testutils" "google.golang.org/grpc/xds/internal/testutils/e2e" + v3routepb "github.com/envoyproxy/go-control-plane/envoy/config/route/v3" + wrapperspb "github.com/golang/protobuf/ptypes/wrappers" testpb "google.golang.org/grpc/test/grpc_testing" ) // clientSetup performs a bunch of steps common to all xDS client tests here: -// - spin up an xDS management server on a local port // - spin up a gRPC server and register the test service on it // - create a local TCP listener and start serving on it // // Returns the following: -// - the management server: tests use this to configure resources -// - nodeID expected by the management server: this is set in the Node proto -// sent by the xdsClient for queries. 
// - the port the server is listening on // - cleanup function to be invoked by the tests when done -func clientSetup(t *testing.T) (*e2e.ManagementServer, string, uint32, func()) { - // Spin up a xDS management server on a local port. - nodeID := uuid.New().String() - fs, err := e2e.StartManagementServer() - if err != nil { - t.Fatal(err) - } - - // Create a bootstrap file in a temporary directory. - bootstrapCleanup, err := e2e.SetupBootstrapFile(e2e.BootstrapOptions{ - Version: e2e.TransportV3, - NodeID: nodeID, - ServerURI: fs.Address, - ServerResourceNameID: "grpc/server", - }) - if err != nil { - t.Fatal(err) - } - +func clientSetup(t *testing.T, tss testpb.TestServiceServer) (uint32, func()) { // Initialize a gRPC server and register the stubServer on it. server := grpc.NewServer() - testpb.RegisterTestServiceServer(server, &testService{}) + testpb.RegisterTestServiceServer(server, tss) // Create a local listener and pass it to Serve(). lis, err := testutils.LocalTCPListener() @@ -80,33 +65,190 @@ func clientSetup(t *testing.T) (*e2e.ManagementServer, string, uint32, func()) { } }() - return fs, nodeID, uint32(lis.Addr().(*net.TCPAddr).Port), func() { - fs.Stop() - bootstrapCleanup() + return uint32(lis.Addr().(*net.TCPAddr).Port), func() { server.Stop() } } func (s) TestClientSideXDS(t *testing.T) { - fs, nodeID, port, cleanup := clientSetup(t) - defer cleanup() + managementServer, nodeID, _, resolver, cleanup1 := setupManagementServer(t) + defer cleanup1() - resources := e2e.DefaultClientResources("myservice", nodeID, "localhost", port) - if err := fs.Update(resources); err != nil { + port, cleanup2 := clientSetup(t, &testService{}) + defer cleanup2() + + const serviceName = "my-service-client-side-xds" + resources := e2e.DefaultClientResources(e2e.ResourceParams{ + DialTarget: serviceName, + NodeID: nodeID, + Host: "localhost", + Port: port, + SecLevel: e2e.SecurityLevelNone, + }) + ctx, cancel := context.WithTimeout(context.Background(), 
defaultTestTimeout) + defer cancel() + if err := managementServer.Update(ctx, resources); err != nil { t.Fatal(err) } // Create a ClientConn and make a successful RPC. - cc, err := grpc.Dial("xds:///myservice", grpc.WithTransportCredentials(insecure.NewCredentials())) + cc, err := grpc.Dial(fmt.Sprintf("xds:///%s", serviceName), grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithResolvers(resolver)) if err != nil { t.Fatalf("failed to dial local test server: %v", err) } defer cc.Close() client := testpb.NewTestServiceClient(cc) - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() if _, err := client.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(true)); err != nil { t.Fatalf("rpc EmptyCall() failed: %v", err) } } + +func (s) TestClientSideRetry(t *testing.T) { + if !env.RetrySupport { + // Skip this test if retry is not enabled. + return + } + + ctr := 0 + errs := []codes.Code{codes.ResourceExhausted} + ss := &stubserver.StubServer{ + EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) { + defer func() { ctr++ }() + if ctr < len(errs) { + return nil, status.Errorf(errs[ctr], "this should be retried") + } + return &testpb.Empty{}, nil + }, + } + + managementServer, nodeID, _, resolver, cleanup1 := setupManagementServer(t) + defer cleanup1() + + port, cleanup2 := clientSetup(t, ss) + defer cleanup2() + + const serviceName = "my-service-client-side-xds" + resources := e2e.DefaultClientResources(e2e.ResourceParams{ + DialTarget: serviceName, + NodeID: nodeID, + Host: "localhost", + Port: port, + SecLevel: e2e.SecurityLevelNone, + }) + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + if err := managementServer.Update(ctx, resources); err != nil { + t.Fatal(err) + } + + // Create a ClientConn and make a successful RPC. 
+ cc, err := grpc.Dial(fmt.Sprintf("xds:///%s", serviceName), grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithResolvers(resolver)) + if err != nil { + t.Fatalf("failed to dial local test server: %v", err) + } + defer cc.Close() + + client := testpb.NewTestServiceClient(cc) + defer cancel() + if _, err := client.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(true)); status.Code(err) != codes.ResourceExhausted { + t.Fatalf("rpc EmptyCall() = _, %v; want _, ResourceExhausted", err) + } + + testCases := []struct { + name string + vhPolicy *v3routepb.RetryPolicy + routePolicy *v3routepb.RetryPolicy + errs []codes.Code // the errors returned by the server for each RPC + tryAgainErr codes.Code // the error that would be returned if we are still using the old retry policies. + errWant codes.Code + }{{ + name: "virtualHost only, fail", + vhPolicy: &v3routepb.RetryPolicy{ + RetryOn: "resource-exhausted,unavailable", + NumRetries: &wrapperspb.UInt32Value{Value: 1}, + }, + errs: []codes.Code{codes.ResourceExhausted, codes.Unavailable}, + routePolicy: nil, + tryAgainErr: codes.ResourceExhausted, + errWant: codes.Unavailable, + }, { + name: "virtualHost only", + vhPolicy: &v3routepb.RetryPolicy{ + RetryOn: "resource-exhausted, unavailable", + NumRetries: &wrapperspb.UInt32Value{Value: 2}, + }, + errs: []codes.Code{codes.ResourceExhausted, codes.Unavailable}, + routePolicy: nil, + tryAgainErr: codes.Unavailable, + errWant: codes.OK, + }, { + name: "virtualHost+route, fail", + vhPolicy: &v3routepb.RetryPolicy{ + RetryOn: "resource-exhausted,unavailable", + NumRetries: &wrapperspb.UInt32Value{Value: 2}, + }, + routePolicy: &v3routepb.RetryPolicy{ + RetryOn: "resource-exhausted", + NumRetries: &wrapperspb.UInt32Value{Value: 2}, + }, + errs: []codes.Code{codes.ResourceExhausted, codes.Unavailable}, + tryAgainErr: codes.OK, + errWant: codes.Unavailable, + }, { + name: "virtualHost+route", + vhPolicy: &v3routepb.RetryPolicy{ + RetryOn: "resource-exhausted", + 
NumRetries: &wrapperspb.UInt32Value{Value: 2}, + }, + routePolicy: &v3routepb.RetryPolicy{ + RetryOn: "unavailable", + NumRetries: &wrapperspb.UInt32Value{Value: 2}, + }, + errs: []codes.Code{codes.Unavailable}, + tryAgainErr: codes.Unavailable, + errWant: codes.OK, + }, { + name: "virtualHost+route, not enough attempts", + vhPolicy: &v3routepb.RetryPolicy{ + RetryOn: "unavailable", + NumRetries: &wrapperspb.UInt32Value{Value: 2}, + }, + routePolicy: &v3routepb.RetryPolicy{ + RetryOn: "unavailable", + NumRetries: &wrapperspb.UInt32Value{Value: 1}, + }, + errs: []codes.Code{codes.Unavailable, codes.Unavailable}, + tryAgainErr: codes.OK, + errWant: codes.Unavailable, + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + errs = tc.errs + + // Confirm tryAgainErr is correct before updating resources. + ctr = 0 + _, err := client.EmptyCall(ctx, &testpb.Empty{}) + if code := status.Code(err); code != tc.tryAgainErr { + t.Fatalf("with old retry policy: EmptyCall() = _, %v; want _, %v", err, tc.tryAgainErr) + } + + resources.Routes[0].VirtualHosts[0].RetryPolicy = tc.vhPolicy + resources.Routes[0].VirtualHosts[0].Routes[0].GetRoute().RetryPolicy = tc.routePolicy + if err := managementServer.Update(ctx, resources); err != nil { + t.Fatal(err) + } + + for { + ctr = 0 + _, err := client.EmptyCall(ctx, &testpb.Empty{}) + if code := status.Code(err); code == tc.tryAgainErr { + continue + } else if code != tc.errWant { + t.Fatalf("rpc EmptyCall() = _, %v; want _, %v", err, tc.errWant) + } + break + } + }) + } +} diff --git a/xds/internal/test/xds_integration_test.go b/xds/internal/test/xds_integration_test.go index ae306ae7864..4b7cca3b828 100644 --- a/xds/internal/test/xds_integration_test.go +++ b/xds/internal/test/xds_integration_test.go @@ -1,3 +1,4 @@ +//go:build !386 // +build !386 /* @@ -23,15 +24,32 @@ package xds_test import ( "context" + "crypto/tls" + "crypto/x509" + "encoding/json" + "fmt" + "io/ioutil" + "os" + "path" "testing" "time" + 
"github.com/google/uuid" + + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/internal/grpctest" + "google.golang.org/grpc/resolver" + "google.golang.org/grpc/testdata" + "google.golang.org/grpc/xds" + "google.golang.org/grpc/xds/internal/testutils/e2e" + + xdsinternal "google.golang.org/grpc/internal/xds" testpb "google.golang.org/grpc/test/grpc_testing" ) const ( - defaultTestTimeout = 10 * time.Second + defaultTestTimeout = 10 * time.Second + defaultTestShortTimeout = 100 * time.Millisecond ) type s struct { @@ -49,3 +67,134 @@ type testService struct { func (*testService) EmptyCall(context.Context, *testpb.Empty) (*testpb.Empty, error) { return &testpb.Empty{}, nil } + +func (*testService) UnaryCall(context.Context, *testpb.SimpleRequest) (*testpb.SimpleResponse, error) { + return &testpb.SimpleResponse{}, nil +} + +func createTmpFile(src, dst string) error { + data, err := ioutil.ReadFile(src) + if err != nil { + return fmt.Errorf("ioutil.ReadFile(%q) failed: %v", src, err) + } + if err := ioutil.WriteFile(dst, data, os.ModePerm); err != nil { + return fmt.Errorf("ioutil.WriteFile(%q) failed: %v", dst, err) + } + return nil +} + +// createTmpDirWithFiles creates a temporary directory under the system default +// tempDir with the given dirSuffix. It also reads from certSrc, keySrc and +// rootSrc files and creates appropriate files under the newly created tempDir. +// Returns the name of the created tempDir. +func createTmpDirWithFiles(dirSuffix, certSrc, keySrc, rootSrc string) (string, error) { + // Create a temp directory. Passing an empty string for the first argument + // uses the system temp directory. 
+ dir, err := ioutil.TempDir("", dirSuffix) + if err != nil { + return "", fmt.Errorf("ioutil.TempDir() failed: %v", err) + } + + if err := createTmpFile(testdata.Path(certSrc), path.Join(dir, certFile)); err != nil { + return "", err + } + if err := createTmpFile(testdata.Path(keySrc), path.Join(dir, keyFile)); err != nil { + return "", err + } + if err := createTmpFile(testdata.Path(rootSrc), path.Join(dir, rootFile)); err != nil { + return "", err + } + return dir, nil +} + +// createClientTLSCredentials creates client-side TLS transport credentials. +func createClientTLSCredentials(t *testing.T) credentials.TransportCredentials { + t.Helper() + + cert, err := tls.LoadX509KeyPair(testdata.Path("x509/client1_cert.pem"), testdata.Path("x509/client1_key.pem")) + if err != nil { + t.Fatalf("tls.LoadX509KeyPair(x509/client1_cert.pem, x509/client1_key.pem) failed: %v", err) + } + b, err := ioutil.ReadFile(testdata.Path("x509/server_ca_cert.pem")) + if err != nil { + t.Fatalf("ioutil.ReadFile(x509/server_ca_cert.pem) failed: %v", err) + } + roots := x509.NewCertPool() + if !roots.AppendCertsFromPEM(b) { + t.Fatal("failed to append certificates") + } + return credentials.NewTLS(&tls.Config{ + Certificates: []tls.Certificate{cert}, + RootCAs: roots, + ServerName: "x.test.example.com", + }) +} + +// setupManagementServer performs the following: +// - spin up an xDS management server on a local port +// - set up certificates for consumption by the file_watcher plugin +// - creates a bootstrap file in a temporary location +// - creates an xDS resolver using the above bootstrap contents +// +// Returns the following: +// - management server +// - nodeID to be used by the client when connecting to the management server +// - bootstrap contents to be used by the client +// - xDS resolver builder to be used by the client +// - a cleanup function to be invoked at the end of the test +func setupManagementServer(t *testing.T) (*e2e.ManagementServer, string, []byte, 
resolver.Builder, func()) { + t.Helper() + + // Spin up an xDS management server on a local port. + server, err := e2e.StartManagementServer() + if err != nil { + t.Fatalf("Failed to spin up the xDS management server: %v", err) + } + defer func() { + if err != nil { + server.Stop() + } + }() + + // Create a directory to hold certs and key files used on the server side. + serverDir, err := createTmpDirWithFiles("testServerSideXDS*", "x509/server1_cert.pem", "x509/server1_key.pem", "x509/client_ca_cert.pem") + if err != nil { + server.Stop() + t.Fatal(err) + } + + // Create a directory to hold certs and key files used on the client side. + clientDir, err := createTmpDirWithFiles("testClientSideXDS*", "x509/client1_cert.pem", "x509/client1_key.pem", "x509/server_ca_cert.pem") + if err != nil { + server.Stop() + t.Fatal(err) + } + + // Create certificate providers section of the bootstrap config with entries + // for both the client and server sides. + cpc := map[string]json.RawMessage{ + e2e.ServerSideCertProviderInstance: e2e.DefaultFileWatcherConfig(path.Join(serverDir, certFile), path.Join(serverDir, keyFile), path.Join(serverDir, rootFile)), + e2e.ClientSideCertProviderInstance: e2e.DefaultFileWatcherConfig(path.Join(clientDir, certFile), path.Join(clientDir, keyFile), path.Join(clientDir, rootFile)), + } + + // Create a bootstrap file in a temporary directory. 
+ nodeID := uuid.New().String() + bootstrapContents, err := xdsinternal.BootstrapContents(xdsinternal.BootstrapOptions{ + Version: xdsinternal.TransportV3, + NodeID: nodeID, + ServerURI: server.Address, + CertificateProviders: cpc, + ServerListenerResourceNameTemplate: e2e.ServerListenerResourceNameTemplate, + }) + if err != nil { + server.Stop() + t.Fatalf("Failed to create bootstrap file: %v", err) + } + resolver, err := xds.NewXDSResolverWithConfigForTesting(bootstrapContents) + if err != nil { + server.Stop() + t.Fatalf("Failed to create xDS resolver for testing: %v", err) + } + + return server, nodeID, bootstrapContents, resolver, func() { server.Stop() } +} diff --git a/xds/internal/test/xds_security_config_nack_test.go b/xds/internal/test/xds_security_config_nack_test.go new file mode 100644 index 00000000000..7b8e36c3f3a --- /dev/null +++ b/xds/internal/test/xds_security_config_nack_test.go @@ -0,0 +1,372 @@ +//go:build !386 +// +build !386 + +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +package xds_test + +import ( + "context" + "fmt" + "testing" + + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + xdscreds "google.golang.org/grpc/credentials/xds" + "google.golang.org/grpc/internal/testutils" + "google.golang.org/grpc/xds/internal/testutils/e2e" + + v3corepb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" + v3tlspb "github.com/envoyproxy/go-control-plane/envoy/extensions/transport_sockets/tls/v3" + testpb "google.golang.org/grpc/test/grpc_testing" +) + +func (s) TestUnmarshalListener_WithUpdateValidatorFunc(t *testing.T) { + const ( + serviceName = "my-service-client-side-xds" + missingIdentityProviderInstance = "missing-identity-provider-instance" + missingRootProviderInstance = "missing-root-provider-instance" + ) + managementServer, nodeID, bootstrapContents, resolver, cleanup1 := setupManagementServer(t) + defer cleanup1() + + lis, cleanup2 := setupGRPCServer(t, bootstrapContents) + defer cleanup2() + + // Grab the host and port of the server and create client side xDS + // resources corresponding to it. + host, port, err := hostPortFromListener(lis) + if err != nil { + t.Fatalf("failed to retrieve host and port of server: %v", err) + } + + // Create xDS resources to be consumed on the client side. This + // includes the listener, route configuration, cluster (with + // security configuration) and endpoint resources. 
+ resources := e2e.DefaultClientResources(e2e.ResourceParams{ + DialTarget: serviceName, + NodeID: nodeID, + Host: host, + Port: port, + SecLevel: e2e.SecurityLevelMTLS, + }) + + tests := []struct { + name string + securityConfig *v3corepb.TransportSocket + wantErr bool + }{ + { + name: "both identity and root providers are not present in bootstrap", + securityConfig: &v3corepb.TransportSocket{ + Name: "envoy.transport_sockets.tls", + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3tlspb.DownstreamTlsContext{ + CommonTlsContext: &v3tlspb.CommonTlsContext{ + TlsCertificateProviderInstance: &v3tlspb.CertificateProviderPluginInstance{ + InstanceName: missingIdentityProviderInstance, + }, + ValidationContextType: &v3tlspb.CommonTlsContext_ValidationContext{ + ValidationContext: &v3tlspb.CertificateValidationContext{ + CaCertificateProviderInstance: &v3tlspb.CertificateProviderPluginInstance{ + InstanceName: missingRootProviderInstance, + }, + }, + }, + }, + }), + }, + }, + wantErr: true, + }, + { + name: "only identity provider is not present in bootstrap", + securityConfig: &v3corepb.TransportSocket{ + Name: "envoy.transport_sockets.tls", + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3tlspb.DownstreamTlsContext{ + CommonTlsContext: &v3tlspb.CommonTlsContext{ + TlsCertificateProviderInstance: &v3tlspb.CertificateProviderPluginInstance{ + InstanceName: missingIdentityProviderInstance, + }, + ValidationContextType: &v3tlspb.CommonTlsContext_ValidationContext{ + ValidationContext: &v3tlspb.CertificateValidationContext{ + CaCertificateProviderInstance: &v3tlspb.CertificateProviderPluginInstance{ + InstanceName: e2e.ServerSideCertProviderInstance, + }, + }, + }, + }, + }), + }, + }, + wantErr: true, + }, + { + name: "only root provider is not present in bootstrap", + securityConfig: &v3corepb.TransportSocket{ + Name: "envoy.transport_sockets.tls", + ConfigType: 
&v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3tlspb.DownstreamTlsContext{ + CommonTlsContext: &v3tlspb.CommonTlsContext{ + TlsCertificateProviderInstance: &v3tlspb.CertificateProviderPluginInstance{ + InstanceName: e2e.ServerSideCertProviderInstance, + }, + ValidationContextType: &v3tlspb.CommonTlsContext_ValidationContext{ + ValidationContext: &v3tlspb.CertificateValidationContext{ + CaCertificateProviderInstance: &v3tlspb.CertificateProviderPluginInstance{ + InstanceName: missingRootProviderInstance, + }, + }, + }, + }, + }), + }, + }, + wantErr: true, + }, + { + name: "both identity and root providers are present in bootstrap", + securityConfig: &v3corepb.TransportSocket{ + Name: "envoy.transport_sockets.tls", + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3tlspb.DownstreamTlsContext{ + CommonTlsContext: &v3tlspb.CommonTlsContext{ + TlsCertificateProviderInstance: &v3tlspb.CertificateProviderPluginInstance{ + InstanceName: e2e.ServerSideCertProviderInstance, + }, + ValidationContextType: &v3tlspb.CommonTlsContext_ValidationContext{ + ValidationContext: &v3tlspb.CertificateValidationContext{ + CaCertificateProviderInstance: &v3tlspb.CertificateProviderPluginInstance{ + InstanceName: e2e.ServerSideCertProviderInstance, + }, + }, + }, + }, + }), + }, + }, + wantErr: false, + }, + } + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + // Create an inbound xDS listener resource for the server side. + inboundLis := e2e.DefaultServerListener(host, port, e2e.SecurityLevelMTLS) + for _, fc := range inboundLis.GetFilterChains() { + fc.TransportSocket = test.securityConfig + } + + // Setup the management server with client and server resources. 
+ if len(resources.Listeners) == 1 { + resources.Listeners = append(resources.Listeners, inboundLis) + } else { + resources.Listeners[1] = inboundLis + } + if err := managementServer.Update(ctx, resources); err != nil { + t.Fatal(err) + } + + // Create client-side xDS credentials with an insecure fallback. + creds, err := xdscreds.NewClientCredentials(xdscreds.ClientOptions{FallbackCreds: insecure.NewCredentials()}) + if err != nil { + t.Fatal(err) + } + + // Create a ClientConn with the xds scheme and make an RPC. + cc, err := grpc.DialContext(ctx, fmt.Sprintf("xds:///%s", serviceName), grpc.WithTransportCredentials(creds), grpc.WithResolvers(resolver)) + if err != nil { + t.Fatalf("failed to dial local test server: %v", err) + } + defer cc.Close() + + // Make a context with a shorter timeout from the top level test + // context for cases where we expect failures. + timeout := defaultTestTimeout + if test.wantErr { + timeout = defaultTestShortTimeout + } + ctx2, cancel2 := context.WithTimeout(ctx, timeout) + defer cancel2() + client := testpb.NewTestServiceClient(cc) + if _, err := client.EmptyCall(ctx2, &testpb.Empty{}, grpc.WaitForReady(true)); (err != nil) != test.wantErr { + t.Fatalf("EmptyCall() returned err: %v, wantErr %v", err, test.wantErr) + } + }) + } +} + +func (s) TestUnmarshalCluster_WithUpdateValidatorFunc(t *testing.T) { + const ( + serviceName = "my-service-client-side-xds" + missingIdentityProviderInstance = "missing-identity-provider-instance" + missingRootProviderInstance = "missing-root-provider-instance" + ) + + tests := []struct { + name string + securityConfig *v3corepb.TransportSocket + wantErr bool + }{ + { + name: "both identity and root providers are not present in bootstrap", + securityConfig: &v3corepb.TransportSocket{ + Name: "envoy.transport_sockets.tls", + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3tlspb.UpstreamTlsContext{ + CommonTlsContext: &v3tlspb.CommonTlsContext{ + 
TlsCertificateProviderInstance: &v3tlspb.CertificateProviderPluginInstance{ + InstanceName: missingIdentityProviderInstance, + }, + ValidationContextType: &v3tlspb.CommonTlsContext_ValidationContext{ + ValidationContext: &v3tlspb.CertificateValidationContext{ + CaCertificateProviderInstance: &v3tlspb.CertificateProviderPluginInstance{ + InstanceName: missingRootProviderInstance, + }, + }, + }, + }, + }), + }, + }, + wantErr: true, + }, + { + name: "only identity provider is not present in bootstrap", + securityConfig: &v3corepb.TransportSocket{ + Name: "envoy.transport_sockets.tls", + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3tlspb.UpstreamTlsContext{ + CommonTlsContext: &v3tlspb.CommonTlsContext{ + TlsCertificateProviderInstance: &v3tlspb.CertificateProviderPluginInstance{ + InstanceName: missingIdentityProviderInstance, + }, + ValidationContextType: &v3tlspb.CommonTlsContext_ValidationContext{ + ValidationContext: &v3tlspb.CertificateValidationContext{ + CaCertificateProviderInstance: &v3tlspb.CertificateProviderPluginInstance{ + InstanceName: e2e.ClientSideCertProviderInstance, + }, + }, + }, + }, + }), + }, + }, + wantErr: true, + }, + { + name: "only root provider is not present in bootstrap", + securityConfig: &v3corepb.TransportSocket{ + Name: "envoy.transport_sockets.tls", + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3tlspb.UpstreamTlsContext{ + CommonTlsContext: &v3tlspb.CommonTlsContext{ + TlsCertificateProviderInstance: &v3tlspb.CertificateProviderPluginInstance{ + InstanceName: e2e.ClientSideCertProviderInstance, + }, + ValidationContextType: &v3tlspb.CommonTlsContext_ValidationContext{ + ValidationContext: &v3tlspb.CertificateValidationContext{ + CaCertificateProviderInstance: &v3tlspb.CertificateProviderPluginInstance{ + InstanceName: missingRootProviderInstance, + }, + }, + }, + }, + }), + }, + }, + wantErr: true, + }, + { + name: "both identity and root 
providers are present in bootstrap", + securityConfig: &v3corepb.TransportSocket{ + Name: "envoy.transport_sockets.tls", + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3tlspb.UpstreamTlsContext{ + CommonTlsContext: &v3tlspb.CommonTlsContext{ + TlsCertificateProviderInstance: &v3tlspb.CertificateProviderPluginInstance{ + InstanceName: e2e.ClientSideCertProviderInstance, + }, + ValidationContextType: &v3tlspb.CommonTlsContext_ValidationContext{ + ValidationContext: &v3tlspb.CertificateValidationContext{ + CaCertificateProviderInstance: &v3tlspb.CertificateProviderPluginInstance{ + InstanceName: e2e.ClientSideCertProviderInstance, + }, + }, + }, + }, + }), + }, + }, + wantErr: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + // setupManagementServer() sets up a bootstrap file with certificate + // provider instance names: `e2e.ServerSideCertProviderInstance` and + // `e2e.ClientSideCertProviderInstance`. + managementServer, nodeID, _, resolver, cleanup1 := setupManagementServer(t) + defer cleanup1() + + port, cleanup2 := clientSetup(t, &testService{}) + defer cleanup2() + + // This creates a `Cluster` resource with a security config which + // refers to `e2e.ClientSideCertProviderInstance` for both root and + // identity certs. 
+ resources := e2e.DefaultClientResources(e2e.ResourceParams{ + DialTarget: serviceName, + NodeID: nodeID, + Host: "localhost", + Port: port, + SecLevel: e2e.SecurityLevelMTLS, + }) + resources.Clusters[0].TransportSocket = test.securityConfig + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + if err := managementServer.Update(ctx, resources); err != nil { + t.Fatal(err) + } + + cc, err := grpc.Dial(fmt.Sprintf("xds:///%s", serviceName), grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithResolvers(resolver)) + if err != nil { + t.Fatalf("failed to dial local test server: %v", err) + } + defer cc.Close() + + // Make a context with a shorter timeout from the top level test + // context for cases where we expect failures. + timeout := defaultTestTimeout + if test.wantErr { + timeout = defaultTestShortTimeout + } + ctx2, cancel2 := context.WithTimeout(ctx, timeout) + defer cancel2() + client := testpb.NewTestServiceClient(cc) + if _, err := client.EmptyCall(ctx2, &testpb.Empty{}, grpc.WaitForReady(true)); (err != nil) != test.wantErr { + t.Fatalf("EmptyCall() returned err: %v, wantErr %v", err, test.wantErr) + } + }) + } +} diff --git a/xds/internal/test/xds_server_integration_test.go b/xds/internal/test/xds_server_integration_test.go index 169daf19d26..707a9605d82 100644 --- a/xds/internal/test/xds_server_integration_test.go +++ b/xds/internal/test/xds_server_integration_test.go @@ -1,3 +1,4 @@ +//go:build !386 // +build !386 /* @@ -23,36 +24,35 @@ package xds_test import ( "context" - "crypto/tls" - "crypto/x509" "fmt" - "io/ioutil" "net" - "os" - "path" "strconv" + "strings" "testing" - v3corepb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" - v3listenerpb "github.com/envoyproxy/go-control-plane/envoy/config/listener/v3" - v3tlspb "github.com/envoyproxy/go-control-plane/envoy/extensions/transport_sockets/tls/v3" - wrapperspb "github.com/golang/protobuf/ptypes/wrappers" - 
"github.com/google/uuid" - "google.golang.org/protobuf/proto" - "google.golang.org/protobuf/types/known/anypb" - "google.golang.org/grpc" "google.golang.org/grpc/codes" - "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/insecure" - xdscreds "google.golang.org/grpc/credentials/xds" + "google.golang.org/grpc/internal/testutils" + "google.golang.org/grpc/internal/xds/env" "google.golang.org/grpc/status" - testpb "google.golang.org/grpc/test/grpc_testing" - "google.golang.org/grpc/testdata" "google.golang.org/grpc/xds" - "google.golang.org/grpc/xds/internal/testutils" + "google.golang.org/grpc/xds/internal/httpfilter/rbac" "google.golang.org/grpc/xds/internal/testutils/e2e" - "google.golang.org/grpc/xds/internal/version" + + v3corepb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" + v3listenerpb "github.com/envoyproxy/go-control-plane/envoy/config/listener/v3" + v3rbacpb "github.com/envoyproxy/go-control-plane/envoy/config/rbac/v3" + v3routepb "github.com/envoyproxy/go-control-plane/envoy/config/route/v3" + rpb "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/http/rbac/v3" + v3routerpb "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/http/router/v3" + v3httppb "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/network/http_connection_manager/v3" + v3matcherpb "github.com/envoyproxy/go-control-plane/envoy/type/matcher/v3" + anypb "github.com/golang/protobuf/ptypes/any" + wrapperspb "github.com/golang/protobuf/ptypes/wrappers" + xdscreds "google.golang.org/grpc/credentials/xds" + testpb "google.golang.org/grpc/test/grpc_testing" + xdstestutils "google.golang.org/grpc/xds/internal/testutils" ) const ( @@ -62,156 +62,587 @@ const ( rootFile = "ca.pem" ) -func createTmpFile(t *testing.T, src, dst string) { +// setupGRPCServer performs the following: +// - spin up an xDS-enabled gRPC server, configure it with xdsCredentials and +// register the test service on it +// - create a local TCP 
listener and start serving on it +// +// Returns the following: +// - local listener on which the xDS-enabled gRPC server is serving on +// - cleanup function to be invoked by the tests when done +func setupGRPCServer(t *testing.T, bootstrapContents []byte) (net.Listener, func()) { t.Helper() - data, err := ioutil.ReadFile(src) + // Configure xDS credentials to be used on the server-side. + creds, err := xdscreds.NewServerCredentials(xdscreds.ServerOptions{ + FallbackCreds: insecure.NewCredentials(), + }) if err != nil { - t.Fatalf("ioutil.ReadFile(%q) failed: %v", src, err) - } - if err := ioutil.WriteFile(dst, data, os.ModePerm); err != nil { - t.Fatalf("ioutil.WriteFile(%q) failed: %v", dst, err) + t.Fatal(err) } - t.Logf("Wrote file at: %s", dst) - t.Logf("%s", string(data)) -} -// createTempDirWithFiles creates a temporary directory under the system default -// tempDir with the given dirSuffix. It also reads from certSrc, keySrc and -// rootSrc files are creates appropriate files under the newly create tempDir. -// Returns the name of the created tempDir. -func createTmpDirWithFiles(t *testing.T, dirSuffix, certSrc, keySrc, rootSrc string) string { - t.Helper() + // Initialize an xDS-enabled gRPC server and register the stubServer on it. + server := xds.NewGRPCServer(grpc.Creds(creds), xds.BootstrapContentsForTesting(bootstrapContents)) + testpb.RegisterTestServiceServer(server, &testService{}) - // Create a temp directory. Passing an empty string for the first argument - // uses the system temp directory. - dir, err := ioutil.TempDir("", dirSuffix) + // Create a local listener and pass it to Serve(). 
+ lis, err := xdstestutils.LocalTCPListener() if err != nil { - t.Fatalf("ioutil.TempDir() failed: %v", err) + t.Fatalf("testutils.LocalTCPListener() failed: %v", err) } - t.Logf("Using tmpdir: %s", dir) - createTmpFile(t, testdata.Path(certSrc), path.Join(dir, certFile)) - createTmpFile(t, testdata.Path(keySrc), path.Join(dir, keyFile)) - createTmpFile(t, testdata.Path(rootSrc), path.Join(dir, rootFile)) - return dir + go func() { + if err := server.Serve(lis); err != nil { + t.Errorf("Serve() failed: %v", err) + } + }() + + return lis, func() { + server.Stop() + } } -// createClientTLSCredentials creates client-side TLS transport credentials. -func createClientTLSCredentials(t *testing.T) credentials.TransportCredentials { - cert, err := tls.LoadX509KeyPair(testdata.Path("x509/client1_cert.pem"), testdata.Path("x509/client1_key.pem")) +func hostPortFromListener(lis net.Listener) (string, uint32, error) { + host, p, err := net.SplitHostPort(lis.Addr().String()) if err != nil { - t.Fatalf("tls.LoadX509KeyPair(x509/client1_cert.pem, x509/client1_key.pem) failed: %v", err) + return "", 0, fmt.Errorf("net.SplitHostPort(%s) failed: %v", lis.Addr().String(), err) } - b, err := ioutil.ReadFile(testdata.Path("x509/server_ca_cert.pem")) + port, err := strconv.ParseInt(p, 10, 32) if err != nil { - t.Fatalf("ioutil.ReadFile(x509/server_ca_cert.pem) failed: %v", err) + return "", 0, fmt.Errorf("strconv.ParseInt(%s, 10, 32) failed: %v", p, err) } - roots := x509.NewCertPool() - if !roots.AppendCertsFromPEM(b) { - t.Fatal("failed to append certificates") - } - return credentials.NewTLS(&tls.Config{ - Certificates: []tls.Certificate{cert}, - RootCAs: roots, - ServerName: "x.test.example.com", - }) + return host, uint32(port), nil } -// commonSetup performs a bunch of steps common to all xDS server tests here: -// - spin up an xDS management server on a local port -// - set up certificates for consumption by the file_watcher plugin -// - spin up an xDS-enabled gRPC server, 
configure it with xdsCredentials and -// register the test service on it -// - create a local TCP listener and start serving on it +// TestServerSideXDS_Fallback is an e2e test which verifies xDS credentials +// fallback functionality. // -// Returns the following: -// - the management server: tests use this to configure resources -// - nodeID expected by the management server: this is set in the Node proto -// sent by the xdsClient used on the xDS-enabled gRPC server -// - local listener on which the xDS-enabled gRPC server is serving on -// - cleanup function to be invoked by the tests when done -func commonSetup(t *testing.T) (*e2e.ManagementServer, string, net.Listener, func()) { - t.Helper() +// The following sequence of events happen as part of this test: +// - An xDS-enabled gRPC server is created and xDS credentials are configured. +// - xDS is enabled on the client by the use of the xds:/// scheme, and xDS +// credentials are configured. +// - Control plane is configured to not send any security configuration to both +// the client and the server. This results in both of them using the +// configured fallback credentials (which is insecure creds in this case). +func (s) TestServerSideXDS_Fallback(t *testing.T) { + managementServer, nodeID, bootstrapContents, resolver, cleanup1 := setupManagementServer(t) + defer cleanup1() + + lis, cleanup2 := setupGRPCServer(t, bootstrapContents) + defer cleanup2() - // Spin up a xDS management server on a local port. - nodeID := uuid.New().String() - fs, err := e2e.StartManagementServer() + // Grab the host and port of the server and create client side xDS resources + // corresponding to it. This contains default resources with no security + // configuration in the Cluster resources. 
+ host, port, err := hostPortFromListener(lis) if err != nil { - t.Fatal(err) + t.Fatalf("failed to retrieve host and port of server: %v", err) } + const serviceName = "my-service-fallback" + resources := e2e.DefaultClientResources(e2e.ResourceParams{ + DialTarget: serviceName, + NodeID: nodeID, + Host: host, + Port: port, + SecLevel: e2e.SecurityLevelNone, + }) - // Create certificate and key files in a temporary directory and generate - // certificate provider configuration for a file_watcher plugin. - tmpdir := createTmpDirWithFiles(t, "testServerSideXDS*", "x509/server1_cert.pem", "x509/server1_key.pem", "x509/client_ca_cert.pem") - cpc := e2e.DefaultFileWatcherConfig(path.Join(tmpdir, certFile), path.Join(tmpdir, keyFile), path.Join(tmpdir, rootFile)) + // Create an inbound xDS listener resource for the server side that does not + // contain any security configuration. This should force the server-side + // xdsCredentials to use fallback. + inboundLis := e2e.DefaultServerListener(host, port, e2e.SecurityLevelNone) + resources.Listeners = append(resources.Listeners, inboundLis) - // Create a bootstrap file in a temporary directory. - bootstrapCleanup, err := e2e.SetupBootstrapFile(e2e.BootstrapOptions{ - Version: e2e.TransportV3, - NodeID: nodeID, - ServerURI: fs.Address, - CertificateProviders: cpc, - ServerResourceNameID: "grpc/server", + // Setup the management server with client and server-side resources. + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + if err := managementServer.Update(ctx, resources); err != nil { + t.Fatal(err) + } + + // Create client-side xDS credentials with an insecure fallback. + creds, err := xdscreds.NewClientCredentials(xdscreds.ClientOptions{ + FallbackCreds: insecure.NewCredentials(), }) if err != nil { t.Fatal(err) } - // Configure xDS credentials to be used on the server-side. 
- creds, err := xdscreds.NewServerCredentials(xdscreds.ServerOptions{ + // Create a ClientConn with the xds scheme and make a successful RPC. + cc, err := grpc.DialContext(ctx, fmt.Sprintf("xds:///%s", serviceName), grpc.WithTransportCredentials(creds), grpc.WithResolvers(resolver)) + if err != nil { + t.Fatalf("failed to dial local test server: %v", err) + } + defer cc.Close() + + client := testpb.NewTestServiceClient(cc) + if _, err := client.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(true)); err != nil { + t.Errorf("rpc EmptyCall() failed: %v", err) + } +} + +// TestServerSideXDS_FileWatcherCerts is an e2e test which verifies xDS +// credentials with file watcher certificate provider. +// +// The following sequence of events happen as part of this test: +// - An xDS-enabled gRPC server is created and xDS credentials are configured. +// - xDS is enabled on the client by the use of the xds:/// scheme, and xDS +// credentials are configured. +// - Control plane is configured to send security configuration to both the +// client and the server, pointing to the file watcher certificate provider. +// We verify both TLS and mTLS scenarios. +func (s) TestServerSideXDS_FileWatcherCerts(t *testing.T) { + tests := []struct { + name string + secLevel e2e.SecurityLevel + }{ + { + name: "tls", + secLevel: e2e.SecurityLevelTLS, + }, + { + name: "mtls", + secLevel: e2e.SecurityLevelMTLS, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + managementServer, nodeID, bootstrapContents, resolver, cleanup1 := setupManagementServer(t) + defer cleanup1() + + lis, cleanup2 := setupGRPCServer(t, bootstrapContents) + defer cleanup2() + + // Grab the host and port of the server and create client side xDS + // resources corresponding to it. + host, port, err := hostPortFromListener(lis) + if err != nil { + t.Fatalf("failed to retrieve host and port of server: %v", err) + } + + // Create xDS resources to be consumed on the client side. 
This + // includes the listener, route configuration, cluster (with + // security configuration) and endpoint resources. + serviceName := "my-service-file-watcher-certs-" + test.name + resources := e2e.DefaultClientResources(e2e.ResourceParams{ + DialTarget: serviceName, + NodeID: nodeID, + Host: host, + Port: port, + SecLevel: test.secLevel, + }) + + // Create an inbound xDS listener resource for the server side that + // contains security configuration pointing to the file watcher + // plugin. + inboundLis := e2e.DefaultServerListener(host, port, test.secLevel) + resources.Listeners = append(resources.Listeners, inboundLis) + + // Setup the management server with client and server resources. + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + if err := managementServer.Update(ctx, resources); err != nil { + t.Fatal(err) + } + + // Create client-side xDS credentials with an insecure fallback. + creds, err := xdscreds.NewClientCredentials(xdscreds.ClientOptions{ + FallbackCreds: insecure.NewCredentials(), + }) + if err != nil { + t.Fatal(err) + } + + // Create a ClientConn with the xds scheme and make an RPC. + cc, err := grpc.DialContext(ctx, fmt.Sprintf("xds:///%s", serviceName), grpc.WithTransportCredentials(creds), grpc.WithResolvers(resolver)) + if err != nil { + t.Fatalf("failed to dial local test server: %v", err) + } + defer cc.Close() + + client := testpb.NewTestServiceClient(cc) + if _, err := client.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(true)); err != nil { + t.Fatalf("rpc EmptyCall() failed: %v", err) + } + }) + } +} + +// TestServerSideXDS_SecurityConfigChange is an e2e test where xDS is enabled on +// the server-side and xdsCredentials are configured for security. The control +// plane initially does not send any security configuration. This forces the +// xdsCredentials to use fallback creds, which in this case is insecure creds. 
+// We verify that a client connecting with TLS creds is not able to successfully +// make an RPC. The control plane then sends a listener resource with security +// configuration pointing to the use of the file_watcher plugin and we verify +// that the same client is now able to successfully make an RPC. +func (s) TestServerSideXDS_SecurityConfigChange(t *testing.T) { + managementServer, nodeID, bootstrapContents, resolver, cleanup1 := setupManagementServer(t) + defer cleanup1() + + lis, cleanup2 := setupGRPCServer(t, bootstrapContents) + defer cleanup2() + + // Grab the host and port of the server and create client side xDS resources + // corresponding to it. This contains default resources with no security + // configuration in the Cluster resource. This should force the xDS + // credentials on the client to use its fallback. + host, port, err := hostPortFromListener(lis) + if err != nil { + t.Fatalf("failed to retrieve host and port of server: %v", err) + } + const serviceName = "my-service-security-config-change" + resources := e2e.DefaultClientResources(e2e.ResourceParams{ + DialTarget: serviceName, + NodeID: nodeID, + Host: host, + Port: port, + SecLevel: e2e.SecurityLevelNone, + }) + + // Create an inbound xDS listener resource for the server side that does not + // contain any security configuration. This should force the xDS credentials + // on server to use its fallback. + inboundLis := e2e.DefaultServerListener(host, port, e2e.SecurityLevelNone) + resources.Listeners = append(resources.Listeners, inboundLis) + + // Setup the management server with client and server-side resources. + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + if err := managementServer.Update(ctx, resources); err != nil { + t.Fatal(err) + } + + // Create client-side xDS credentials with an insecure fallback. 
+ xdsCreds, err := xdscreds.NewClientCredentials(xdscreds.ClientOptions{ FallbackCreds: insecure.NewCredentials(), }) if err != nil { t.Fatal(err) } - // Initialize an xDS-enabled gRPC server and register the stubServer on it. - server := xds.NewGRPCServer(grpc.Creds(creds)) - testpb.RegisterTestServiceServer(server, &testService{}) + // Create a ClientConn with the xds scheme and make a successful RPC. + xdsCC, err := grpc.DialContext(ctx, fmt.Sprintf("xds:///%s", serviceName), grpc.WithTransportCredentials(xdsCreds), grpc.WithResolvers(resolver)) + if err != nil { + t.Fatalf("failed to dial local test server: %v", err) + } + defer xdsCC.Close() - // Create a local listener and pass it to Serve(). - lis, err := testutils.LocalTCPListener() + client := testpb.NewTestServiceClient(xdsCC) + if _, err := client.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(true)); err != nil { + t.Fatalf("rpc EmptyCall() failed: %v", err) + } + + // Create a ClientConn with TLS creds. This should fail since the server is + // using fallback credentials which in this case is insecure creds. + tlsCreds := createClientTLSCredentials(t) + tlsCC, err := grpc.DialContext(ctx, lis.Addr().String(), grpc.WithTransportCredentials(tlsCreds)) if err != nil { - t.Fatalf("testutils.LocalTCPListener() failed: %v", err) + t.Fatalf("failed to dial local test server: %v", err) } + defer tlsCC.Close() - go func() { - if err := server.Serve(lis); err != nil { - t.Errorf("Serve() failed: %v", err) - } - }() + // We don't set `waitForReady` here since we want this call to failfast. 
+ client = testpb.NewTestServiceClient(tlsCC) + if _, err := client.EmptyCall(ctx, &testpb.Empty{}); status.Code(err) != codes.Unavailable { + t.Fatal("rpc EmptyCall() succeeded when expected to fail") + } - return fs, nodeID, lis, func() { - fs.Stop() - bootstrapCleanup() - server.Stop() + // Switch server and client side resources with ones that contain required + // security configuration for mTLS with a file watcher certificate provider. + resources = e2e.DefaultClientResources(e2e.ResourceParams{ + DialTarget: serviceName, + NodeID: nodeID, + Host: host, + Port: port, + SecLevel: e2e.SecurityLevelMTLS, + }) + inboundLis = e2e.DefaultServerListener(host, port, e2e.SecurityLevelMTLS) + resources.Listeners = append(resources.Listeners, inboundLis) + if err := managementServer.Update(ctx, resources); err != nil { + t.Fatal(err) + } + + // Make another RPC with `waitForReady` set and expect this to succeed. + if _, err := client.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(true)); err != nil { + t.Fatalf("rpc EmptyCall() failed: %v", err) } } -func hostPortFromListener(t *testing.T, lis net.Listener) (string, uint32) { - t.Helper() +// TestServerSideXDS_RouteConfiguration is an e2e test which verifies routing +// functionality. The xDS enabled server will be set up with route configuration +// where the route configuration has routes with the correct routing actions +// (NonForwardingAction), and the RPC's matching those routes should proceed as +// normal. 
+func (s) TestServerSideXDS_RouteConfiguration(t *testing.T) { + oldRBAC := env.RBACSupport + env.RBACSupport = true + defer func() { + env.RBACSupport = oldRBAC + }() + managementServer, nodeID, bootstrapContents, resolver, cleanup1 := setupManagementServer(t) + defer cleanup1() - host, p, err := net.SplitHostPort(lis.Addr().String()) + lis, cleanup2 := setupGRPCServer(t, bootstrapContents) + defer cleanup2() + + host, port, err := hostPortFromListener(lis) if err != nil { - t.Fatalf("net.SplitHostPort(%s) failed: %v", lis.Addr().String(), err) + t.Fatalf("failed to retrieve host and port of server: %v", err) } - port, err := strconv.ParseInt(p, 10, 32) + const serviceName = "my-service-fallback" + resources := e2e.DefaultClientResources(e2e.ResourceParams{ + DialTarget: serviceName, + NodeID: nodeID, + Host: host, + Port: port, + SecLevel: e2e.SecurityLevelNone, + }) + + // Create an inbound xDS listener resource with route configuration which + // selectively will allow RPC's through or not. This will test routing in + // xds(Unary|Stream)Interceptors. + vhs := []*v3routepb.VirtualHost{ + // Virtual host that will never be matched to test Virtual Host selection. + { + Domains: []string{"this will not match*"}, + Routes: []*v3routepb.Route{ + { + Match: &v3routepb.RouteMatch{ + PathSpecifier: &v3routepb.RouteMatch_Prefix{Prefix: "/"}, + }, + Action: &v3routepb.Route_NonForwardingAction{}, + }, + }, + }, + // This Virtual Host will actually get matched to. + { + Domains: []string{"*"}, + Routes: []*v3routepb.Route{ + // A routing rule that can be selectively triggered based on properties about incoming RPC. + { + Match: &v3routepb.RouteMatch{ + PathSpecifier: &v3routepb.RouteMatch_Prefix{Prefix: "/grpc.testing.TestService/EmptyCall"}, + // "Fully-qualified RPC method name with leading slash. Same as :path header". + }, + // Correct Action, so RPC's that match this route should proceed to interceptor processing. 
+ Action: &v3routepb.Route_NonForwardingAction{}, + }, + // This routing rule is matched the same way as the one above, + // except has an incorrect action for the server side. However, + // since routing chooses the first route which matches an + // incoming RPC, this should never get invoked (iteration + // through this route slice is deterministic). + { + Match: &v3routepb.RouteMatch{ + PathSpecifier: &v3routepb.RouteMatch_Prefix{Prefix: "/grpc.testing.TestService/EmptyCall"}, + // "Fully-qualified RPC method name with leading slash. Same as :path header". + }, + // Incorrect Action, so RPC's that match this route should get denied. + Action: &v3routepb.Route_Route{ + Route: &v3routepb.RouteAction{ClusterSpecifier: &v3routepb.RouteAction_Cluster{Cluster: ""}}, + }, + }, + // Another routing rule that can be selectively triggered based on incoming RPC. + { + Match: &v3routepb.RouteMatch{ + PathSpecifier: &v3routepb.RouteMatch_Prefix{Prefix: "/grpc.testing.TestService/UnaryCall"}, + }, + // Wrong action (!Non_Forwarding_Action) so RPC's that match this route should get denied. + Action: &v3routepb.Route_Route{ + Route: &v3routepb.RouteAction{ClusterSpecifier: &v3routepb.RouteAction_Cluster{Cluster: ""}}, + }, + }, + // Another routing rule that can be selectively triggered based on incoming RPC. + { + Match: &v3routepb.RouteMatch{ + PathSpecifier: &v3routepb.RouteMatch_Prefix{Prefix: "/grpc.testing.TestService/StreamingInputCall"}, + }, + // Wrong action (!Non_Forwarding_Action) so RPC's that match this route should get denied. + Action: &v3routepb.Route_Route{ + Route: &v3routepb.RouteAction{ClusterSpecifier: &v3routepb.RouteAction_Cluster{Cluster: ""}}, + }, + }, + // Not matching route, this is be able to get invoked logically (i.e. doesn't have to match the Route configurations above). 
+ }}, + } + inboundLis := &v3listenerpb.Listener{ + Name: fmt.Sprintf(e2e.ServerListenerResourceNameTemplate, net.JoinHostPort(host, strconv.Itoa(int(port)))), + Address: &v3corepb.Address{ + Address: &v3corepb.Address_SocketAddress{ + SocketAddress: &v3corepb.SocketAddress{ + Address: host, + PortSpecifier: &v3corepb.SocketAddress_PortValue{ + PortValue: port, + }, + }, + }, + }, + FilterChains: []*v3listenerpb.FilterChain{ + { + Name: "v4-wildcard", + FilterChainMatch: &v3listenerpb.FilterChainMatch{ + PrefixRanges: []*v3corepb.CidrRange{ + { + AddressPrefix: "0.0.0.0", + PrefixLen: &wrapperspb.UInt32Value{ + Value: uint32(0), + }, + }, + }, + SourceType: v3listenerpb.FilterChainMatch_SAME_IP_OR_LOOPBACK, + SourcePrefixRanges: []*v3corepb.CidrRange{ + { + AddressPrefix: "0.0.0.0", + PrefixLen: &wrapperspb.UInt32Value{ + Value: uint32(0), + }, + }, + }, + }, + Filters: []*v3listenerpb.Filter{ + { + Name: "filter-1", + ConfigType: &v3listenerpb.Filter_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3httppb.HttpConnectionManager{ + HttpFilters: []*v3httppb.HttpFilter{e2e.HTTPFilter("router", &v3routerpb.Router{})}, + RouteSpecifier: &v3httppb.HttpConnectionManager_RouteConfig{ + RouteConfig: &v3routepb.RouteConfiguration{ + Name: "routeName", + VirtualHosts: vhs, + }, + }, + }), + }, + }, + }, + }, + { + Name: "v6-wildcard", + FilterChainMatch: &v3listenerpb.FilterChainMatch{ + PrefixRanges: []*v3corepb.CidrRange{ + { + AddressPrefix: "::", + PrefixLen: &wrapperspb.UInt32Value{ + Value: uint32(0), + }, + }, + }, + SourceType: v3listenerpb.FilterChainMatch_SAME_IP_OR_LOOPBACK, + SourcePrefixRanges: []*v3corepb.CidrRange{ + { + AddressPrefix: "::", + PrefixLen: &wrapperspb.UInt32Value{ + Value: uint32(0), + }, + }, + }, + }, + Filters: []*v3listenerpb.Filter{ + { + Name: "filter-1", + ConfigType: &v3listenerpb.Filter_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3httppb.HttpConnectionManager{ + HttpFilters: []*v3httppb.HttpFilter{e2e.HTTPFilter("router", 
&v3routerpb.Router{})}, + RouteSpecifier: &v3httppb.HttpConnectionManager_RouteConfig{ + RouteConfig: &v3routepb.RouteConfiguration{ + Name: "routeName", + VirtualHosts: vhs, + }, + }, + }), + }, + }, + }, + }, + }, + } + resources.Listeners = append(resources.Listeners, inboundLis) + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + // Setup the management server with client and server-side resources. + if err := managementServer.Update(ctx, resources); err != nil { + t.Fatal(err) + } + + cc, err := grpc.DialContext(ctx, fmt.Sprintf("xds:///%s", serviceName), grpc.WithInsecure(), grpc.WithResolvers(resolver)) if err != nil { - t.Fatalf("strconv.ParseInt(%s, 10, 32) failed: %v", p, err) + t.Fatalf("failed to dial local test server: %v", err) + } + defer cc.Close() + + client := testpb.NewTestServiceClient(cc) + + // This Empty Call should match to a route with a correct action + // (NonForwardingAction). Thus, this RPC should proceed as normal. There is + // a routing rule that this RPC would match to that has an incorrect action, + // but the server should only use the first route matched to with the + // correct action. + if _, err = client.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(true)); err != nil { + t.Fatalf("rpc EmptyCall() failed: %v", err) } - return host, uint32(port) + // This Unary Call should match to a route with an incorrect action. Thus, + // this RPC should not go through as per A36, and this call should receive + // an error with codes.Unavailable. + if _, err = client.UnaryCall(ctx, &testpb.SimpleRequest{}); status.Code(err) != codes.Unavailable { + t.Fatalf("client.UnaryCall() = _, %v, want _, error code %s", err, codes.Unavailable) + } + + // This Streaming Call should match to a route with an incorrect action. + // Thus, this RPC should not go through as per A36, and this call should + // receive an error with codes.Unavailable. 
+ stream, err := client.StreamingInputCall(ctx) + if err != nil { + t.Fatalf("StreamingInputCall(_) = _, %v, want ", err) + } + if _, err = stream.CloseAndRecv(); status.Code(err) != codes.Unavailable || !strings.Contains(err.Error(), "the incoming RPC matched to a route that was not of action type non forwarding") { + t.Fatalf("streaming RPC should have been denied") + } + + // This Full Duplex should not match to a route, and thus should return an + // error and not proceed. + dStream, err := client.FullDuplexCall(ctx) + if err != nil { + t.Fatalf("FullDuplexCall(_) = _, %v, want ", err) + } + if _, err = dStream.Recv(); status.Code(err) != codes.Unavailable || !strings.Contains(err.Error(), "the incoming RPC did not match a configured Route") { + t.Fatalf("streaming RPC should have been denied") + } } -// listenerResourceWithoutSecurityConfig returns a listener resource with no -// security configuration, and name and address fields matching the passed in -// net.Listener. -func listenerResourceWithoutSecurityConfig(t *testing.T, lis net.Listener) *v3listenerpb.Listener { - host, port := hostPortFromListener(t, lis) +// serverListenerWithRBACHTTPFilters returns an xds Listener resource with HTTP Filters defined in the HCM, and a route +// configuration that always matches to a route and a VH. +func serverListenerWithRBACHTTPFilters(host string, port uint32, rbacCfg *rpb.RBAC) *v3listenerpb.Listener { + // Rather than declare typed config inline, take a HCM proto and append the + // RBAC Filters to it. 
+ hcm := &v3httppb.HttpConnectionManager{ + RouteSpecifier: &v3httppb.HttpConnectionManager_RouteConfig{ + RouteConfig: &v3routepb.RouteConfiguration{ + Name: "routeName", + VirtualHosts: []*v3routepb.VirtualHost{{ + Domains: []string{"*"}, + Routes: []*v3routepb.Route{{ + Match: &v3routepb.RouteMatch{ + PathSpecifier: &v3routepb.RouteMatch_Prefix{Prefix: "/"}, + }, + Action: &v3routepb.Route_NonForwardingAction{}, + }}, + // This tests override parsing + building when RBAC Filter + // passed both normal and override config. + TypedPerFilterConfig: map[string]*anypb.Any{ + "rbac": testutils.MarshalAny(&rpb.RBACPerRoute{Rbac: rbacCfg}), + }, + }}}, + }, + } + hcm.HttpFilters = nil + hcm.HttpFilters = append(hcm.HttpFilters, e2e.HTTPFilter("rbac", rbacCfg)) + hcm.HttpFilters = append(hcm.HttpFilters, e2e.RouterHTTPFilter) + return &v3listenerpb.Listener{ - // This needs to match the name we are querying for. - Name: fmt.Sprintf("grpc/server?udpa.resource.listening_address=%s", lis.Addr().String()), + Name: fmt.Sprintf(e2e.ServerListenerResourceNameTemplate, net.JoinHostPort(host, strconv.Itoa(int(port)))), Address: &v3corepb.Address{ Address: &v3corepb.Address_SocketAddress{ SocketAddress: &v3corepb.SocketAddress{ @@ -224,176 +655,563 @@ func listenerResourceWithoutSecurityConfig(t *testing.T, lis net.Listener) *v3li }, FilterChains: []*v3listenerpb.FilterChain{ { - Name: "filter-chain-1", + Name: "v4-wildcard", + FilterChainMatch: &v3listenerpb.FilterChainMatch{ + PrefixRanges: []*v3corepb.CidrRange{ + { + AddressPrefix: "0.0.0.0", + PrefixLen: &wrapperspb.UInt32Value{ + Value: uint32(0), + }, + }, + }, + SourceType: v3listenerpb.FilterChainMatch_SAME_IP_OR_LOOPBACK, + SourcePrefixRanges: []*v3corepb.CidrRange{ + { + AddressPrefix: "0.0.0.0", + PrefixLen: &wrapperspb.UInt32Value{ + Value: uint32(0), + }, + }, + }, + }, + Filters: []*v3listenerpb.Filter{ + { + Name: "filter-1", + ConfigType: &v3listenerpb.Filter_TypedConfig{ + TypedConfig: testutils.MarshalAny(hcm), 
+ }, + }, + }, + }, + { + Name: "v6-wildcard", + FilterChainMatch: &v3listenerpb.FilterChainMatch{ + PrefixRanges: []*v3corepb.CidrRange{ + { + AddressPrefix: "::", + PrefixLen: &wrapperspb.UInt32Value{ + Value: uint32(0), + }, + }, + }, + SourceType: v3listenerpb.FilterChainMatch_SAME_IP_OR_LOOPBACK, + SourcePrefixRanges: []*v3corepb.CidrRange{ + { + AddressPrefix: "::", + PrefixLen: &wrapperspb.UInt32Value{ + Value: uint32(0), + }, + }, + }, + }, + Filters: []*v3listenerpb.Filter{ + { + Name: "filter-1", + ConfigType: &v3listenerpb.Filter_TypedConfig{ + TypedConfig: testutils.MarshalAny(hcm), + }, + }, + }, + }, + }, + } +} + +// TestRBACHTTPFilter tests the xds configured RBAC HTTP Filter. It sets up the +// full end to end flow, and makes sure certain RPC's are successful and proceed +// as normal and certain RPC's are denied by the RBAC HTTP Filter which gets +// called by hooked xds interceptors. +func (s) TestRBACHTTPFilter(t *testing.T) { + oldRBAC := env.RBACSupport + env.RBACSupport = true + defer func() { + env.RBACSupport = oldRBAC + }() + rbac.RegisterForTesting() + defer rbac.UnregisterForTesting() + tests := []struct { + name string + rbacCfg *rpb.RBAC + wantStatusEmptyCall codes.Code + wantStatusUnaryCall codes.Code + }{ + // This test tests an RBAC HTTP Filter which is configured to allow any RPC. + // Any RPC passing through this RBAC HTTP Filter should proceed as normal. + { + name: "allow-anything", + rbacCfg: &rpb.RBAC{ + Rules: &v3rbacpb.RBAC{ + Action: v3rbacpb.RBAC_ALLOW, + Policies: map[string]*v3rbacpb.Policy{ + "anyone": { + Permissions: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_Any{Any: true}}, + }, + Principals: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_Any{Any: true}}, + }, + }, + }, + }, + }, + wantStatusEmptyCall: codes.OK, + wantStatusUnaryCall: codes.OK, + }, + // This test tests an RBAC HTTP Filter which is configured to allow only + // RPC's with certain paths ("UnaryCall"). 
Only unary calls passing + // through this RBAC HTTP Filter should proceed as normal, and any + // others should be denied. + { + name: "allow-certain-path", + rbacCfg: &rpb.RBAC{ + Rules: &v3rbacpb.RBAC{ + Action: v3rbacpb.RBAC_ALLOW, + Policies: map[string]*v3rbacpb.Policy{ + "certain-path": { + Permissions: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_UrlPath{UrlPath: &v3matcherpb.PathMatcher{Rule: &v3matcherpb.PathMatcher_Path{Path: &v3matcherpb.StringMatcher{MatchPattern: &v3matcherpb.StringMatcher_Exact{Exact: "/grpc.testing.TestService/UnaryCall"}}}}}}, + }, + Principals: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_Any{Any: true}}, + }, + }, + }, + }, + }, + wantStatusEmptyCall: codes.PermissionDenied, + wantStatusUnaryCall: codes.OK, + }, + // This test that a RBAC Config with nil rules means that every RPC is + // allowed. This maps to the line "If absent, no enforcing RBAC policy + // will be applied" from the RBAC Proto documentation for the Rules + // field. + { + name: "absent-rules", + rbacCfg: &rpb.RBAC{ + Rules: nil, + }, + wantStatusEmptyCall: codes.OK, + wantStatusUnaryCall: codes.OK, + }, + // The two tests below test that configuring the xDS RBAC HTTP Filter + // with :authority and host header matchers end up being logically + // equivalent. This represents functionality from this line in A41 - + // "As documented for HeaderMatcher, Envoy aliases :authority and Host + // in its header map implementation, so they should be treated + // equivalent for the RBAC matchers; there must be no behavior change + // depending on which of the two header names is used in the RBAC + // policy." + + // This test tests an xDS RBAC Filter with an :authority header matcher. 
+ { + name: "match-on-authority", + rbacCfg: &rpb.RBAC{ + Rules: &v3rbacpb.RBAC{ + Action: v3rbacpb.RBAC_ALLOW, + Policies: map[string]*v3rbacpb.Policy{ + "match-on-authority": { + Permissions: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_Header{Header: &v3routepb.HeaderMatcher{Name: ":authority", HeaderMatchSpecifier: &v3routepb.HeaderMatcher_PrefixMatch{PrefixMatch: "my-service-fallback"}}}}, + }, + Principals: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_Any{Any: true}}, + }, + }, + }, + }, + }, + wantStatusEmptyCall: codes.OK, + wantStatusUnaryCall: codes.OK, + }, + // This test tests that configuring an xDS RBAC Filter with a host + // header matcher has the same behavior as if it was configured with + // :authority. Since host and authority are aliased, this should still + // continue to match on incoming RPC's :authority, just as the test + // above. + { + name: "match-on-host", + rbacCfg: &rpb.RBAC{ + Rules: &v3rbacpb.RBAC{ + Action: v3rbacpb.RBAC_ALLOW, + Policies: map[string]*v3rbacpb.Policy{ + "match-on-authority": { + Permissions: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_Header{Header: &v3routepb.HeaderMatcher{Name: "host", HeaderMatchSpecifier: &v3routepb.HeaderMatcher_PrefixMatch{PrefixMatch: "my-service-fallback"}}}}, + }, + Principals: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_Any{Any: true}}, + }, + }, + }, + }, + }, + wantStatusEmptyCall: codes.OK, + wantStatusUnaryCall: codes.OK, + }, + // This test tests that the RBAC HTTP Filter hard codes the :method + // header to POST. Since the RBAC Configuration says to deny every RPC + // with a method :POST, every RPC tried should be denied. 
+ { + name: "deny-post", + rbacCfg: &rpb.RBAC{ + Rules: &v3rbacpb.RBAC{ + Action: v3rbacpb.RBAC_DENY, + Policies: map[string]*v3rbacpb.Policy{ + "post-method": { + Permissions: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_Header{Header: &v3routepb.HeaderMatcher{Name: ":method", HeaderMatchSpecifier: &v3routepb.HeaderMatcher_ExactMatch{ExactMatch: "POST"}}}}, + }, + Principals: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_Any{Any: true}}, + }, + }, + }, + }, }, + wantStatusEmptyCall: codes.PermissionDenied, + wantStatusUnaryCall: codes.PermissionDenied, + }, + // This test tests that RBAC ignores the TE: trailers header (which is + // hardcoded in http2_client.go for every RPC). Since the RBAC + // Configuration says to only ALLOW RPC's with a TE: Trailers, every RPC + // tried should be denied. + { + name: "allow-only-te", + rbacCfg: &rpb.RBAC{ + Rules: &v3rbacpb.RBAC{ + Action: v3rbacpb.RBAC_ALLOW, + Policies: map[string]*v3rbacpb.Policy{ + "post-method": { + Permissions: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_Header{Header: &v3routepb.HeaderMatcher{Name: "TE", HeaderMatchSpecifier: &v3routepb.HeaderMatcher_ExactMatch{ExactMatch: "trailers"}}}}, + }, + Principals: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_Any{Any: true}}, + }, + }, + }, + }, + }, + wantStatusEmptyCall: codes.PermissionDenied, + wantStatusUnaryCall: codes.PermissionDenied, + }, + // This test tests that an RBAC Config with Action.LOG configured allows + // every RPC through. This maps to the line "At this time, if the + // RBAC.action is Action.LOG then the policy will be completely ignored, + // as if RBAC was not configurated." 
from A41 + { + name: "action-log", + rbacCfg: &rpb.RBAC{ + Rules: &v3rbacpb.RBAC{ + Action: v3rbacpb.RBAC_LOG, + Policies: map[string]*v3rbacpb.Policy{ + "anyone": { + Permissions: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_Any{Any: true}}, + }, + Principals: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_Any{Any: true}}, + }, + }, + }, + }, + }, + wantStatusEmptyCall: codes.OK, + wantStatusUnaryCall: codes.OK, }, } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + func() { + managementServer, nodeID, bootstrapContents, resolver, cleanup1 := setupManagementServer(t) + defer cleanup1() + + lis, cleanup2 := setupGRPCServer(t, bootstrapContents) + defer cleanup2() + + host, port, err := hostPortFromListener(lis) + if err != nil { + t.Fatalf("failed to retrieve host and port of server: %v", err) + } + const serviceName = "my-service-fallback" + resources := e2e.DefaultClientResources(e2e.ResourceParams{ + DialTarget: serviceName, + NodeID: nodeID, + Host: host, + Port: port, + SecLevel: e2e.SecurityLevelNone, + }) + inboundLis := serverListenerWithRBACHTTPFilters(host, port, test.rbacCfg) + resources.Listeners = append(resources.Listeners, inboundLis) + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + // Setup the management server with client and server-side resources. 
+ if err := managementServer.Update(ctx, resources); err != nil { + t.Fatal(err) + } + + cc, err := grpc.DialContext(ctx, fmt.Sprintf("xds:///%s", serviceName), grpc.WithInsecure(), grpc.WithResolvers(resolver)) + if err != nil { + t.Fatalf("failed to dial local test server: %v", err) + } + defer cc.Close() + + client := testpb.NewTestServiceClient(cc) + + if _, err := client.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(true)); status.Code(err) != test.wantStatusEmptyCall { + t.Fatalf("EmptyCall() returned err with status: %v, wantStatusEmptyCall: %v", status.Code(err), test.wantStatusEmptyCall) + } + + if _, err := client.UnaryCall(ctx, &testpb.SimpleRequest{}); status.Code(err) != test.wantStatusUnaryCall { + t.Fatalf("UnaryCall() returned err with status: %v, wantStatusUnaryCall: %v", err, test.wantStatusUnaryCall) + } + + // Toggle the RBAC Env variable off, this should disable RBAC and allow any RPC"s through (will not go through + // routing or processed by HTTP Filters and thus will never get denied by RBAC). + env.RBACSupport = false + if _, err := client.EmptyCall(ctx, &testpb.Empty{}); status.Code(err) != codes.OK { + t.Fatalf("EmptyCall() returned err with status: %v, once RBAC is disabled all RPC's should proceed as normal", status.Code(err)) + } + if _, err := client.UnaryCall(ctx, &testpb.SimpleRequest{}); status.Code(err) != codes.OK { + t.Fatalf("UnaryCall() returned err with status: %v, once RBAC is disabled all RPC's should proceed as normal", status.Code(err)) + } + // Toggle RBAC back on for next iterations. + env.RBACSupport = true + }() + }) + } } -// listenerResourceWithSecurityConfig returns a listener resource with security -// configuration pointing to the use of the file_watcher certificate provider -// plugin, and name and address fields matching the passed in net.Listener. 
-func listenerResourceWithSecurityConfig(t *testing.T, lis net.Listener) *v3listenerpb.Listener { - host, port := hostPortFromListener(t, lis) +// serverListenerWithBadRouteConfiguration returns an xds Listener resource with +// a Route Configuration that will never successfully match in order to test +// RBAC Environment variable being toggled on and off. +func serverListenerWithBadRouteConfiguration(host string, port uint32) *v3listenerpb.Listener { return &v3listenerpb.Listener{ - // This needs to match the name we are querying for. - Name: fmt.Sprintf("grpc/server?udpa.resource.listening_address=%s", lis.Addr().String()), + Name: fmt.Sprintf(e2e.ServerListenerResourceNameTemplate, net.JoinHostPort(host, strconv.Itoa(int(port)))), Address: &v3corepb.Address{ Address: &v3corepb.Address_SocketAddress{ SocketAddress: &v3corepb.SocketAddress{ Address: host, PortSpecifier: &v3corepb.SocketAddress_PortValue{ PortValue: port, - }}}}, + }, + }, + }, + }, FilterChains: []*v3listenerpb.FilterChain{ { - Name: "filter-chain-1", - TransportSocket: &v3corepb.TransportSocket{ - Name: "envoy.transport_sockets.tls", - ConfigType: &v3corepb.TransportSocket_TypedConfig{ - TypedConfig: &anypb.Any{ - TypeUrl: version.V3DownstreamTLSContextURL, - Value: func() []byte { - tls := &v3tlspb.DownstreamTlsContext{ - RequireClientCertificate: &wrapperspb.BoolValue{Value: true}, - CommonTlsContext: &v3tlspb.CommonTlsContext{ - TlsCertificateCertificateProviderInstance: &v3tlspb.CommonTlsContext_CertificateProviderInstance{ - InstanceName: "google_cloud_private_spiffe", - }, - ValidationContextType: &v3tlspb.CommonTlsContext_ValidationContextCertificateProviderInstance{ - ValidationContextCertificateProviderInstance: &v3tlspb.CommonTlsContext_CertificateProviderInstance{ - InstanceName: "google_cloud_private_spiffe", - }}}} - mtls, _ := proto.Marshal(tls) - return mtls - }(), - }}}}}, + Name: "v4-wildcard", + FilterChainMatch: &v3listenerpb.FilterChainMatch{ + PrefixRanges: 
[]*v3corepb.CidrRange{ + { + AddressPrefix: "0.0.0.0", + PrefixLen: &wrapperspb.UInt32Value{ + Value: uint32(0), + }, + }, + }, + SourceType: v3listenerpb.FilterChainMatch_SAME_IP_OR_LOOPBACK, + SourcePrefixRanges: []*v3corepb.CidrRange{ + { + AddressPrefix: "0.0.0.0", + PrefixLen: &wrapperspb.UInt32Value{ + Value: uint32(0), + }, + }, + }, + }, + Filters: []*v3listenerpb.Filter{ + { + Name: "filter-1", + ConfigType: &v3listenerpb.Filter_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3httppb.HttpConnectionManager{ + RouteSpecifier: &v3httppb.HttpConnectionManager_RouteConfig{ + RouteConfig: &v3routepb.RouteConfiguration{ + Name: "routeName", + VirtualHosts: []*v3routepb.VirtualHost{{ + // Incoming RPC's will try and match to Virtual Hosts based on their :authority header. + // Thus, incoming RPC's will never match to a Virtual Host (server side requires matching + // to a VH/Route of type Non Forwarding Action to proceed normally), and all incoming RPC's + // with this route configuration will be denied. 
+ Domains: []string{"will-never-match"}, + Routes: []*v3routepb.Route{{ + Match: &v3routepb.RouteMatch{ + PathSpecifier: &v3routepb.RouteMatch_Prefix{Prefix: "/"}, + }, + Action: &v3routepb.Route_NonForwardingAction{}, + }}}}}, + }, + HttpFilters: []*v3httppb.HttpFilter{e2e.RouterHTTPFilter}, + }), + }, + }, + }, + }, + { + Name: "v6-wildcard", + FilterChainMatch: &v3listenerpb.FilterChainMatch{ + PrefixRanges: []*v3corepb.CidrRange{ + { + AddressPrefix: "::", + PrefixLen: &wrapperspb.UInt32Value{ + Value: uint32(0), + }, + }, + }, + SourceType: v3listenerpb.FilterChainMatch_SAME_IP_OR_LOOPBACK, + SourcePrefixRanges: []*v3corepb.CidrRange{ + { + AddressPrefix: "::", + PrefixLen: &wrapperspb.UInt32Value{ + Value: uint32(0), + }, + }, + }, + }, + Filters: []*v3listenerpb.Filter{ + { + Name: "filter-1", + ConfigType: &v3listenerpb.Filter_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3httppb.HttpConnectionManager{ + RouteSpecifier: &v3httppb.HttpConnectionManager_RouteConfig{ + RouteConfig: &v3routepb.RouteConfiguration{ + Name: "routeName", + VirtualHosts: []*v3routepb.VirtualHost{{ + // Incoming RPC's will try and match to Virtual Hosts based on their :authority header. + // Thus, incoming RPC's will never match to a Virtual Host (server side requires matching + // to a VH/Route of type Non Forwarding Action to proceed normally), and all incoming RPC's + // with this route configuration will be denied. + Domains: []string{"will-never-match"}, + Routes: []*v3routepb.Route{{ + Match: &v3routepb.RouteMatch{ + PathSpecifier: &v3routepb.RouteMatch_Prefix{Prefix: "/"}, + }, + Action: &v3routepb.Route_NonForwardingAction{}, + }}}}}, + }, + HttpFilters: []*v3httppb.HttpFilter{e2e.RouterHTTPFilter}, + }), + }, + }, + }, + }, + }, } } -// TestServerSideXDS_Fallback is an e2e test where xDS is enabled on the -// server-side and xdsCredentials are configured for security. 
The control plane -// does not provide any security configuration and therefore the xdsCredentials -// uses fallback credentials, which in this case is insecure creds. -func (s) TestServerSideXDS_Fallback(t *testing.T) { - fs, nodeID, lis, cleanup := commonSetup(t) - defer cleanup() +func (s) TestRBACToggledOn_WithBadRouteConfiguration(t *testing.T) { + // Turn RBAC support on. + oldRBAC := env.RBACSupport + env.RBACSupport = true + defer func() { + env.RBACSupport = oldRBAC + }() - // Setup the fake management server to respond with a Listener resource that - // does not contain any security configuration. This should force the - // server-side xdsCredentials to use fallback. - listener := listenerResourceWithoutSecurityConfig(t, lis) - if err := fs.Update(e2e.UpdateOptions{ - NodeID: nodeID, - Listeners: []*v3listenerpb.Listener{listener}, - }); err != nil { - t.Error(err) - } + managementServer, nodeID, bootstrapContents, resolver, cleanup1 := setupManagementServer(t) + defer cleanup1() - // Create a ClientConn and make a successful RPC. - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() - cc, err := grpc.DialContext(ctx, lis.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials())) + lis, cleanup2 := setupGRPCServer(t, bootstrapContents) + defer cleanup2() + + host, port, err := hostPortFromListener(lis) if err != nil { - t.Fatalf("failed to dial local test server: %v", err) + t.Fatalf("failed to retrieve host and port of server: %v", err) } - defer cc.Close() + const serviceName = "my-service-fallback" - client := testpb.NewTestServiceClient(cc) - if _, err := client.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(true)); err != nil { - t.Fatalf("rpc EmptyCall() failed: %v", err) - } -} + // The inbound listener needs a route table that will never match on a VH, + // and thus shouldn't allow incoming RPC's to proceed. 
+ resources := e2e.DefaultClientResources(e2e.ResourceParams{ + DialTarget: serviceName, + NodeID: nodeID, + Host: host, + Port: port, + SecLevel: e2e.SecurityLevelNone, + }) + // Since RBAC support is turned ON, all the RPC's should get denied with + // status code Unavailable due to not matching to a route of type Non + // Forwarding Action (Route Table not configured properly). + inboundLis := serverListenerWithBadRouteConfiguration(host, port) + resources.Listeners = append(resources.Listeners, inboundLis) -// TestServerSideXDS_FileWatcherCerts is an e2e test where xDS is enabled on the -// server-side and xdsCredentials are configured for security. The control plane -// sends security configuration pointing to the use of the file_watcher plugin, -// and we verify that a client connecting with TLS creds is able to successfully -// make an RPC. -func (s) TestServerSideXDS_FileWatcherCerts(t *testing.T) { - fs, nodeID, lis, cleanup := commonSetup(t) - defer cleanup() - - // Setup the fake management server to respond with a Listener resource with - // security configuration pointing to the file watcher plugin and requiring - // mTLS. - listener := listenerResourceWithSecurityConfig(t, lis) - if err := fs.Update(e2e.UpdateOptions{ - NodeID: nodeID, - Listeners: []*v3listenerpb.Listener{listener}, - }); err != nil { - t.Error(err) - } - - // Create a ClientConn with TLS creds and make a successful RPC. - clientCreds := createClientTLSCredentials(t) ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() - cc, err := grpc.DialContext(ctx, lis.Addr().String(), grpc.WithTransportCredentials(clientCreds)) + // Setup the management server with client and server-side resources. 
+ if err := managementServer.Update(ctx, resources); err != nil { + t.Fatal(err) + } + + cc, err := grpc.DialContext(ctx, fmt.Sprintf("xds:///%s", serviceName), grpc.WithInsecure(), grpc.WithResolvers(resolver)) if err != nil { t.Fatalf("failed to dial local test server: %v", err) } defer cc.Close() client := testpb.NewTestServiceClient(cc) - if _, err := client.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(true)); err != nil { - t.Fatalf("rpc EmptyCall() failed: %v", err) + if _, err := client.EmptyCall(ctx, &testpb.Empty{}); status.Code(err) != codes.Unavailable { + t.Fatalf("EmptyCall() returned err with status: %v, if RBAC is disabled all RPC's should proceed as normal", status.Code(err)) + } + if _, err := client.UnaryCall(ctx, &testpb.SimpleRequest{}); status.Code(err) != codes.Unavailable { + t.Fatalf("UnaryCall() returned err with status: %v, if RBAC is disabled all RPC's should proceed as normal", status.Code(err)) } } -// TestServerSideXDS_SecurityConfigChange is an e2e test where xDS is enabled on -// the server-side and xdsCredentials are configured for security. The control -// plane initially does not any security configuration. This forces the -// xdsCredentials to use fallback creds, which is this case is insecure creds. -// We verify that a client connecting with TLS creds is not able to successfully -// make an RPC. The control plan then sends a listener resource with security -// configuration pointing to the use of the file_watcher plugin and we verify -// that the same client is now able to successfully make an RPC. -func (s) TestServerSideXDS_SecurityConfigChange(t *testing.T) { - fs, nodeID, lis, cleanup := commonSetup(t) - defer cleanup() +func (s) TestRBACToggledOff_WithBadRouteConfiguration(t *testing.T) { + // Turn RBAC support off. 
+ oldRBAC := env.RBACSupport + env.RBACSupport = false + defer func() { + env.RBACSupport = oldRBAC + }() - // Setup the fake management server to respond with a Listener resource that - // does not contain any security configuration. This should force the - // server-side xdsCredentials to use fallback. - listener := listenerResourceWithoutSecurityConfig(t, lis) - if err := fs.Update(e2e.UpdateOptions{ - NodeID: nodeID, - Listeners: []*v3listenerpb.Listener{listener}, - }); err != nil { - t.Error(err) + managementServer, nodeID, bootstrapContents, resolver, cleanup1 := setupManagementServer(t) + defer cleanup1() + + lis, cleanup2 := setupGRPCServer(t, bootstrapContents) + defer cleanup2() + + host, port, err := hostPortFromListener(lis) + if err != nil { + t.Fatalf("failed to retrieve host and port of server: %v", err) } + const serviceName = "my-service-fallback" + + // The inbound listener needs a route table that will never match on a VH, + // and thus shouldn't allow incoming RPC's to proceed. + resources := e2e.DefaultClientResources(e2e.ResourceParams{ + DialTarget: serviceName, + NodeID: nodeID, + Host: host, + Port: port, + SecLevel: e2e.SecurityLevelNone, + }) + // This bad route configuration shouldn't affect incoming RPC's from + // proceeding as normal, as the configuration shouldn't be parsed due to the + // RBAC Environment variable not being set to true. + inboundLis := serverListenerWithBadRouteConfiguration(host, port) + resources.Listeners = append(resources.Listeners, inboundLis) - // Create a ClientConn with TLS creds. This should fail since the server is - // using fallback credentials which in this case in insecure creds. - clientCreds := createClientTLSCredentials(t) ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() - cc, err := grpc.DialContext(ctx, lis.Addr().String(), grpc.WithTransportCredentials(clientCreds)) + // Setup the management server with client and server-side resources. 
+ if err := managementServer.Update(ctx, resources); err != nil { + t.Fatal(err) + } + + cc, err := grpc.DialContext(ctx, fmt.Sprintf("xds:///%s", serviceName), grpc.WithInsecure(), grpc.WithResolvers(resolver)) if err != nil { t.Fatalf("failed to dial local test server: %v", err) } defer cc.Close() - // We don't set 'waitForReady` here since we want this call to failfast. client := testpb.NewTestServiceClient(cc) - if _, err := client.EmptyCall(ctx, &testpb.Empty{}); status.Convert(err).Code() != codes.Unavailable { - t.Fatal("rpc EmptyCall() succeeded when expected to fail") - } - - // Setup the fake management server to respond with a Listener resource with - // security configuration pointing to the file watcher plugin and requiring - // mTLS. - listener = listenerResourceWithSecurityConfig(t, lis) - if err := fs.Update(e2e.UpdateOptions{ - NodeID: nodeID, - Listeners: []*v3listenerpb.Listener{listener}, - }); err != nil { - t.Error(err) + if _, err := client.EmptyCall(ctx, &testpb.Empty{}); status.Code(err) != codes.OK { + t.Fatalf("EmptyCall() returned err with status: %v, if RBAC is disabled all RPC's should proceed as normal", status.Code(err)) } - - // Make another RPC with `waitForReady` set and expect this to succeed. - if _, err := client.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(true)); err != nil { - t.Fatalf("rpc EmptyCall() failed: %v", err) + if _, err := client.UnaryCall(ctx, &testpb.SimpleRequest{}); status.Code(err) != codes.OK { + t.Fatalf("UnaryCall() returned err with status: %v, if RBAC is disabled all RPC's should proceed as normal", status.Code(err)) } } diff --git a/xds/internal/test/xds_server_serving_mode_test.go b/xds/internal/test/xds_server_serving_mode_test.go new file mode 100644 index 00000000000..ac4b3929cb6 --- /dev/null +++ b/xds/internal/test/xds_server_serving_mode_test.go @@ -0,0 +1,388 @@ +//go:build !386 +// +build !386 + +/* + * + * Copyright 2021 gRPC authors. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// Package xds_test contains e2e tests for xDS use. +package xds_test + +import ( + "context" + "fmt" + "net" + "testing" + "time" + + v3listenerpb "github.com/envoyproxy/go-control-plane/envoy/config/listener/v3" + "google.golang.org/grpc" + "google.golang.org/grpc/connectivity" + "google.golang.org/grpc/credentials/insecure" + xdscreds "google.golang.org/grpc/credentials/xds" + testpb "google.golang.org/grpc/test/grpc_testing" + "google.golang.org/grpc/xds" + xdstestutils "google.golang.org/grpc/xds/internal/testutils" + "google.golang.org/grpc/xds/internal/testutils/e2e" +) + +// TestServerSideXDS_RedundantUpdateSuppression tests the scenario where the +// control plane sends the same resource update. It verifies that the mode +// change callback is not invoked and client connections to the server are not +// recycled. +func (s) TestServerSideXDS_RedundantUpdateSuppression(t *testing.T) { + managementServer, nodeID, bootstrapContents, _, cleanup := setupManagementServer(t) + defer cleanup() + + creds, err := xdscreds.NewServerCredentials(xdscreds.ServerOptions{FallbackCreds: insecure.NewCredentials()}) + if err != nil { + t.Fatal(err) + } + lis, err := xdstestutils.LocalTCPListener() + if err != nil { + t.Fatalf("testutils.LocalTCPListener() failed: %v", err) + } + updateCh := make(chan connectivity.ServingMode, 1) + + // Create a server option to get notified about serving mode changes. 
+ ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + modeChangeOpt := xds.ServingModeCallback(func(addr net.Addr, args xds.ServingModeChangeArgs) { + t.Logf("serving mode for listener %q changed to %q, err: %v", addr.String(), args.Mode, args.Err) + updateCh <- args.Mode + }) + + // Initialize an xDS-enabled gRPC server and register the stubServer on it. + server := xds.NewGRPCServer(grpc.Creds(creds), modeChangeOpt, xds.BootstrapContentsForTesting(bootstrapContents)) + defer server.Stop() + testpb.RegisterTestServiceServer(server, &testService{}) + + // Setup the management server to respond with the listener resources. + host, port, err := hostPortFromListener(lis) + if err != nil { + t.Fatalf("failed to retrieve host and port of server: %v", err) + } + listener := e2e.DefaultServerListener(host, port, e2e.SecurityLevelNone) + resources := e2e.UpdateOptions{ + NodeID: nodeID, + Listeners: []*v3listenerpb.Listener{listener}, + } + if err := managementServer.Update(ctx, resources); err != nil { + t.Fatal(err) + } + + go func() { + if err := server.Serve(lis); err != nil { + t.Errorf("Serve() failed: %v", err) + } + }() + + // Wait for the listener to move to "serving" mode. + select { + case <-ctx.Done(): + t.Fatalf("timed out waiting for a mode change update: %v", err) + case mode := <-updateCh: + if mode != connectivity.ServingModeServing { + t.Fatalf("listener received new mode %v, want %v", mode, connectivity.ServingModeServing) + } + } + + // Create a ClientConn and make a successful RPCs. + cc, err := grpc.Dial(lis.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + t.Fatalf("failed to dial local test server: %v", err) + } + defer cc.Close() + waitForSuccessfulRPC(ctx, t, cc) + + // Start a goroutine to make sure that we do not see any connectivity state + // changes on the client connection. 
If redundant updates are not + // suppressed, server will recycle client connections. + errCh := make(chan error, 1) + go func() { + if cc.WaitForStateChange(ctx, connectivity.Ready) { + errCh <- fmt.Errorf("unexpected connectivity state change {%s --> %s} on the client connection", connectivity.Ready, cc.GetState()) + return + } + errCh <- nil + }() + + // Update the management server with the same listener resource. This will + // update the resource version though, and should result in a the management + // server sending the same resource to the xDS-enabled gRPC server. + if err := managementServer.Update(ctx, e2e.UpdateOptions{ + NodeID: nodeID, + Listeners: []*v3listenerpb.Listener{listener}, + }); err != nil { + t.Fatal(err) + } + + // Since redundant resource updates are suppressed, we should not see the + // mode change callback being invoked. + sCtx, sCancel := context.WithTimeout(ctx, defaultTestShortTimeout) + defer sCancel() + select { + case <-sCtx.Done(): + case mode := <-updateCh: + t.Fatalf("unexpected mode change callback with new mode %v", mode) + } + + // Make sure RPCs continue to succeed. + waitForSuccessfulRPC(ctx, t, cc) + + // Cancel the context to ensure that the WaitForStateChange call exits early + // and returns false. + cancel() + if err := <-errCh; err != nil { + t.Fatal(err) + } +} + +// TestServerSideXDS_ServingModeChanges tests the serving mode functionality in +// xDS enabled gRPC servers. It verifies that appropriate mode changes happen in +// the server, and also verifies behavior of clientConns under these modes. +func (s) TestServerSideXDS_ServingModeChanges(t *testing.T) { + managementServer, nodeID, bootstrapContents, _, cleanup := setupManagementServer(t) + defer cleanup() + + // Configure xDS credentials to be used on the server-side. 
+ creds, err := xdscreds.NewServerCredentials(xdscreds.ServerOptions{ + FallbackCreds: insecure.NewCredentials(), + }) + if err != nil { + t.Fatal(err) + } + + // Create two local listeners and pass it to Serve(). + lis1, err := xdstestutils.LocalTCPListener() + if err != nil { + t.Fatalf("testutils.LocalTCPListener() failed: %v", err) + } + lis2, err := xdstestutils.LocalTCPListener() + if err != nil { + t.Fatalf("testutils.LocalTCPListener() failed: %v", err) + } + + // Create a couple of channels on which mode updates will be pushed. + updateCh1 := make(chan connectivity.ServingMode, 1) + updateCh2 := make(chan connectivity.ServingMode, 1) + + // Create a server option to get notified about serving mode changes, and + // push the updated mode on the channels created above. + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + modeChangeOpt := xds.ServingModeCallback(func(addr net.Addr, args xds.ServingModeChangeArgs) { + t.Logf("serving mode for listener %q changed to %q, err: %v", addr.String(), args.Mode, args.Err) + switch addr.String() { + case lis1.Addr().String(): + updateCh1 <- args.Mode + case lis2.Addr().String(): + updateCh2 <- args.Mode + default: + t.Logf("serving mode callback invoked for unknown listener address: %q", addr.String()) + } + }) + + // Initialize an xDS-enabled gRPC server and register the stubServer on it. + server := xds.NewGRPCServer(grpc.Creds(creds), modeChangeOpt, xds.BootstrapContentsForTesting(bootstrapContents)) + defer server.Stop() + testpb.RegisterTestServiceServer(server, &testService{}) + + // Setup the management server to respond with server-side Listener + // resources for both listeners. 
+ host1, port1, err := hostPortFromListener(lis1) + if err != nil { + t.Fatalf("failed to retrieve host and port of server: %v", err) + } + listener1 := e2e.DefaultServerListener(host1, port1, e2e.SecurityLevelNone) + host2, port2, err := hostPortFromListener(lis2) + if err != nil { + t.Fatalf("failed to retrieve host and port of server: %v", err) + } + listener2 := e2e.DefaultServerListener(host2, port2, e2e.SecurityLevelNone) + resources := e2e.UpdateOptions{ + NodeID: nodeID, + Listeners: []*v3listenerpb.Listener{listener1, listener2}, + } + if err := managementServer.Update(ctx, resources); err != nil { + t.Fatal(err) + } + + go func() { + if err := server.Serve(lis1); err != nil { + t.Errorf("Serve() failed: %v", err) + } + }() + go func() { + if err := server.Serve(lis2); err != nil { + t.Errorf("Serve() failed: %v", err) + } + }() + + // Wait for both listeners to move to "serving" mode. + select { + case <-ctx.Done(): + t.Fatalf("timed out waiting for a mode change update: %v", err) + case mode := <-updateCh1: + if mode != connectivity.ServingModeServing { + t.Errorf("listener received new mode %v, want %v", mode, connectivity.ServingModeServing) + } + } + select { + case <-ctx.Done(): + t.Fatalf("timed out waiting for a mode change update: %v", err) + case mode := <-updateCh2: + if mode != connectivity.ServingModeServing { + t.Errorf("listener received new mode %v, want %v", mode, connectivity.ServingModeServing) + } + } + + // Create a ClientConn to the first listener and make a successful RPCs. + cc1, err := grpc.Dial(lis1.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + t.Fatalf("failed to dial local test server: %v", err) + } + defer cc1.Close() + waitForSuccessfulRPC(ctx, t, cc1) + + // Create a ClientConn to the second listener and make a successful RPCs. 
+ cc2, err := grpc.Dial(lis2.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + t.Fatalf("failed to dial local test server: %v", err) + } + defer cc2.Close() + waitForSuccessfulRPC(ctx, t, cc2) + + // Update the management server to remove the second listener resource. This + // should push only the second listener into "not-serving" mode. + if err := managementServer.Update(ctx, e2e.UpdateOptions{ + NodeID: nodeID, + Listeners: []*v3listenerpb.Listener{listener1}, + }); err != nil { + t.Error(err) + } + + // Wait for lis2 to move to "not-serving" mode. + select { + case <-ctx.Done(): + t.Fatalf("timed out waiting for a mode change update: %v", err) + case mode := <-updateCh2: + if mode != connectivity.ServingModeNotServing { + t.Errorf("listener received new mode %v, want %v", mode, connectivity.ServingModeNotServing) + } + } + + // Make sure RPCs succeed on cc1 and fail on cc2. + waitForSuccessfulRPC(ctx, t, cc1) + waitForFailedRPC(ctx, t, cc2) + + // Update the management server to remove the first listener resource as + // well. This should push the first listener into "not-serving" mode. Second + // listener is already in "not-serving" mode. + if err := managementServer.Update(ctx, e2e.UpdateOptions{ + NodeID: nodeID, + Listeners: []*v3listenerpb.Listener{}, + }); err != nil { + t.Error(err) + } + + // Wait for lis1 to move to "not-serving" mode. lis2 was already removed + // from the xdsclient's resource cache. So, lis2's callback will not be + // invoked this time around. + select { + case <-ctx.Done(): + t.Fatalf("timed out waiting for a mode change update: %v", err) + case mode := <-updateCh1: + if mode != connectivity.ServingModeNotServing { + t.Errorf("listener received new mode %v, want %v", mode, connectivity.ServingModeNotServing) + } + } + + // Make sure RPCs fail on both. + waitForFailedRPC(ctx, t, cc1) + waitForFailedRPC(ctx, t, cc2) + + // Make sure new connection attempts to "not-serving" servers fail. 
We use a + // short timeout since we expect this to fail. + sCtx, sCancel := context.WithTimeout(ctx, defaultTestShortTimeout) + defer sCancel() + if _, err := grpc.DialContext(sCtx, lis1.Addr().String(), grpc.WithBlock(), grpc.WithTransportCredentials(insecure.NewCredentials())); err == nil { + t.Fatal("successfully created clientConn to a server in \"not-serving\" state") + } + + // Update the management server with both listener resources. + if err := managementServer.Update(ctx, e2e.UpdateOptions{ + NodeID: nodeID, + Listeners: []*v3listenerpb.Listener{listener1, listener2}, + }); err != nil { + t.Error(err) + } + + // Wait for both listeners to move to "serving" mode. + select { + case <-ctx.Done(): + t.Fatalf("timed out waiting for a mode change update: %v", err) + case mode := <-updateCh1: + if mode != connectivity.ServingModeServing { + t.Errorf("listener received new mode %v, want %v", mode, connectivity.ServingModeServing) + } + } + select { + case <-ctx.Done(): + t.Fatalf("timed out waiting for a mode change update: %v", err) + case mode := <-updateCh2: + if mode != connectivity.ServingModeServing { + t.Errorf("listener received new mode %v, want %v", mode, connectivity.ServingModeServing) + } + } + + // The clientConns created earlier should be able to make RPCs now. + waitForSuccessfulRPC(ctx, t, cc1) + waitForSuccessfulRPC(ctx, t, cc2) +} + +func waitForSuccessfulRPC(ctx context.Context, t *testing.T, cc *grpc.ClientConn) { + t.Helper() + + c := testpb.NewTestServiceClient(cc) + if _, err := c.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(true)); err != nil { + t.Fatalf("rpc EmptyCall() failed: %v", err) + } +} + +func waitForFailedRPC(ctx context.Context, t *testing.T, cc *grpc.ClientConn) { + t.Helper() + + // Attempt one RPC before waiting for the ticker to expire. 
+ c := testpb.NewTestServiceClient(cc) + if _, err := c.EmptyCall(ctx, &testpb.Empty{}); err != nil { + return + } + + ticker := time.NewTimer(1 * time.Second) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + t.Fatalf("failure when waiting for RPCs to fail: %v", ctx.Err()) + case <-ticker.C: + if _, err := c.EmptyCall(ctx, &testpb.Empty{}); err != nil { + return + } + } + } +} diff --git a/xds/internal/testutils/balancer.go b/xds/internal/testutils/balancer.go index dab84a84e07..ff74da71cc9 100644 --- a/xds/internal/testutils/balancer.go +++ b/xds/internal/testutils/balancer.go @@ -46,21 +46,28 @@ var TestSubConns []*TestSubConn func init() { for i := 0; i < TestSubConnsCount; i++ { TestSubConns = append(TestSubConns, &TestSubConn{ - id: fmt.Sprintf("sc%d", i), + id: fmt.Sprintf("sc%d", i), + ConnectCh: make(chan struct{}, 1), }) } } // TestSubConn implements the SubConn interface, to be used in tests. type TestSubConn struct { - id string + id string + ConnectCh chan struct{} } // UpdateAddresses is a no-op. func (tsc *TestSubConn) UpdateAddresses([]resolver.Address) {} // Connect is a no-op. -func (tsc *TestSubConn) Connect() {} +func (tsc *TestSubConn) Connect() { + select { + case tsc.ConnectCh <- struct{}{}: + default: + } +} // String implements stringer to print human friendly error message. func (tsc *TestSubConn) String() string { @@ -76,8 +83,9 @@ type TestClientConn struct { RemoveSubConnCh chan balancer.SubConn // the last 10 subconn removed. UpdateAddressesAddrsCh chan []resolver.Address // last updated address via UpdateAddresses(). - NewPickerCh chan balancer.Picker // the last picker updated. - NewStateCh chan connectivity.State // the last state. + NewPickerCh chan balancer.Picker // the last picker updated. + NewStateCh chan connectivity.State // the last state. + ResolveNowCh chan resolver.ResolveNowOptions // the last ResolveNow(). 
subConnIdx int } @@ -92,8 +100,9 @@ func NewTestClientConn(t *testing.T) *TestClientConn { RemoveSubConnCh: make(chan balancer.SubConn, 10), UpdateAddressesAddrsCh: make(chan []resolver.Address, 1), - NewPickerCh: make(chan balancer.Picker, 1), - NewStateCh: make(chan connectivity.State, 1), + NewPickerCh: make(chan balancer.Picker, 1), + NewStateCh: make(chan connectivity.State, 1), + ResolveNowCh: make(chan resolver.ResolveNowOptions, 1), } } @@ -151,8 +160,12 @@ func (tcc *TestClientConn) UpdateState(bs balancer.State) { } // ResolveNow panics. -func (tcc *TestClientConn) ResolveNow(resolver.ResolveNowOptions) { - panic("not implemented") +func (tcc *TestClientConn) ResolveNow(o resolver.ResolveNowOptions) { + select { + case <-tcc.ResolveNowCh: + default: + } + tcc.ResolveNowCh <- o } // Target panics. diff --git a/xds/internal/testutils/e2e/bootstrap.go b/xds/internal/testutils/e2e/bootstrap.go index 25993f19fc3..99702032f81 100644 --- a/xds/internal/testutils/e2e/bootstrap.go +++ b/xds/internal/testutils/e2e/bootstrap.go @@ -21,98 +21,13 @@ package e2e import ( "encoding/json" "fmt" - "io/ioutil" - "os" - - "google.golang.org/grpc/xds/internal/env" ) -// TransportAPI refers to the API version for xDS transport protocol. -type TransportAPI int - -const ( - // TransportV2 refers to the v2 xDS transport protocol. - TransportV2 TransportAPI = iota - // TransportV3 refers to the v3 xDS transport protocol. - TransportV3 -) - -// BootstrapOptions wraps the parameters passed to SetupBootstrapFile. -type BootstrapOptions struct { - // Version is the xDS transport protocol version. - Version TransportAPI - // NodeID is the node identifier of the gRPC client/server node in the - // proxyless service mesh. - NodeID string - // ServerURI is the address of the management server. - ServerURI string - // ServerResourceNameID is the Listener resource name to fetch. - ServerResourceNameID string - // CertificateProviders is the certificate providers configuration. 
- CertificateProviders map[string]json.RawMessage -} - -// SetupBootstrapFile creates a temporary file with bootstrap contents, based on -// the passed in options, and updates the bootstrap environment variable to -// point to this file. -// -// Returns a cleanup function which will be non-nil if the setup process was -// completed successfully. It is the responsibility of the caller to invoke the -// cleanup function at the end of the test. -func SetupBootstrapFile(opts BootstrapOptions) (func(), error) { - f, err := ioutil.TempFile("", "test_xds_bootstrap_*") - if err != nil { - return nil, fmt.Errorf("failed to created bootstrap file: %v", err) - } - - cfg := &bootstrapConfig{ - XdsServers: []server{ - { - ServerURI: opts.ServerURI, - ChannelCreds: []creds{ - { - Type: "insecure", - }, - }, - }, - }, - Node: node{ - ID: opts.NodeID, - }, - CertificateProviders: opts.CertificateProviders, - GRPCServerResourceNameID: opts.ServerResourceNameID, - } - switch opts.Version { - case TransportV2: - // TODO: Add any v2 specific fields. - case TransportV3: - cfg.XdsServers[0].ServerFeatures = append(cfg.XdsServers[0].ServerFeatures, "xds_v3") - default: - return nil, fmt.Errorf("unsupported xDS transport protocol version: %v", opts.Version) - } - - bootstrapContents, err := json.MarshalIndent(cfg, "", " ") - if err != nil { - return nil, fmt.Errorf("failed to created bootstrap file: %v", err) - } - if err := ioutil.WriteFile(f.Name(), bootstrapContents, 0644); err != nil { - return nil, fmt.Errorf("failed to created bootstrap file: %v", err) - } - logger.Infof("Created bootstrap file at %q with contents: %s\n", f.Name(), bootstrapContents) - - origBootstrapFileName := env.BootstrapFileName - env.BootstrapFileName = f.Name() - return func() { - os.Remove(f.Name()) - env.BootstrapFileName = origBootstrapFileName - }, nil -} - // DefaultFileWatcherConfig is a helper function to create a default certificate // provider plugin configuration. 
The test is expected to have setup the files // appropriately before this configuration is used to instantiate providers. -func DefaultFileWatcherConfig(certPath, keyPath, caPath string) map[string]json.RawMessage { - cfg := fmt.Sprintf(`{ +func DefaultFileWatcherConfig(certPath, keyPath, caPath string) json.RawMessage { + return json.RawMessage(fmt.Sprintf(`{ "plugin_name": "file_watcher", "config": { "certificate_file": %q, @@ -120,30 +35,5 @@ func DefaultFileWatcherConfig(certPath, keyPath, caPath string) map[string]json. "ca_certificate_file": %q, "refresh_interval": "600s" } - }`, certPath, keyPath, caPath) - return map[string]json.RawMessage{ - "google_cloud_private_spiffe": json.RawMessage(cfg), - } -} - -type bootstrapConfig struct { - XdsServers []server `json:"xds_servers,omitempty"` - Node node `json:"node,omitempty"` - CertificateProviders map[string]json.RawMessage `json:"certificate_providers,omitempty"` - GRPCServerResourceNameID string `json:"grpc_server_resource_name_id,omitempty"` -} - -type server struct { - ServerURI string `json:"server_uri,omitempty"` - ChannelCreds []creds `json:"channel_creds,omitempty"` - ServerFeatures []string `json:"server_features,omitempty"` -} - -type creds struct { - Type string `json:"type,omitempty"` - Config interface{} `json:"config,omitempty"` -} - -type node struct { - ID string `json:"id,omitempty"` + }`, certPath, keyPath, caPath)) } diff --git a/xds/internal/testutils/e2e/clientresources.go b/xds/internal/testutils/e2e/clientresources.go index 79424b13b91..f3f7f6307c5 100644 --- a/xds/internal/testutils/e2e/clientresources.go +++ b/xds/internal/testutils/e2e/clientresources.go @@ -19,9 +19,13 @@ package e2e import ( + "fmt" + "net" + "strconv" + "github.com/envoyproxy/go-control-plane/pkg/wellknown" "github.com/golang/protobuf/proto" - "github.com/golang/protobuf/ptypes" + "google.golang.org/grpc/internal/testutils" v3clusterpb "github.com/envoyproxy/go-control-plane/envoy/config/cluster/v3" v3corepb 
"github.com/envoyproxy/go-control-plane/envoy/config/core/v3" @@ -30,37 +34,76 @@ import ( v3routepb "github.com/envoyproxy/go-control-plane/envoy/config/route/v3" v3routerpb "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/http/router/v3" v3httppb "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/network/http_connection_manager/v3" - anypb "github.com/golang/protobuf/ptypes/any" + v3tlspb "github.com/envoyproxy/go-control-plane/envoy/extensions/transport_sockets/tls/v3" wrapperspb "github.com/golang/protobuf/ptypes/wrappers" ) -func any(m proto.Message) *anypb.Any { - a, err := ptypes.MarshalAny(m) - if err != nil { - panic("error marshalling any: " + err.Error()) - } - return a +const ( + // ServerListenerResourceNameTemplate is the Listener resource name template + // used on the server side. + ServerListenerResourceNameTemplate = "grpc/server?xds.resource.listening_address=%s" + // ClientSideCertProviderInstance is the certificate provider instance name + // used in the Cluster resource on the client side. + ClientSideCertProviderInstance = "client-side-certificate-provider-instance" + // ServerSideCertProviderInstance is the certificate provider instance name + // used in the Listener resource on the server side. + ServerSideCertProviderInstance = "server-side-certificate-provider-instance" +) + +// SecurityLevel allows the test to control the security level to be used in the +// resource returned by this package. +type SecurityLevel int + +const ( + // SecurityLevelNone is used when no security configuration is required. + SecurityLevelNone SecurityLevel = iota + // SecurityLevelTLS is used when security configuration corresponding to TLS + // is required. Only the server presents an identity certificate in this + // configuration. + SecurityLevelTLS + // SecurityLevelMTLS is used when security configuration corresponding to + // mTLS is required. Both client and server present identity certificates in + // this configuration.
+ SecurityLevelMTLS +) + +// ResourceParams wraps the arguments to be passed to DefaultClientResources. +type ResourceParams struct { + // DialTarget is the client's dial target. This is used as the name of the + // Listener resource. + DialTarget string + // NodeID is the id of the xdsClient to which this update is to be pushed. + NodeID string + // Host is the host of the default Endpoint resource. + Host string + // Port is the port of the default Endpoint resource. + Port uint32 + // SecLevel controls the security configuration in the Cluster resource. + SecLevel SecurityLevel } // DefaultClientResources returns a set of resources (LDS, RDS, CDS, EDS) for a // client to generically connect to one server. -func DefaultClientResources(target, nodeID, host string, port uint32) UpdateOptions { - const routeConfigName = "route" - const clusterName = "cluster" - const endpointsName = "endpoints" - +func DefaultClientResources(params ResourceParams) UpdateOptions { + routeConfigName := "route-" + params.DialTarget + clusterName := "cluster-" + params.DialTarget + endpointsName := "endpoints-" + params.DialTarget return UpdateOptions{ - NodeID: nodeID, - Listeners: []*v3listenerpb.Listener{DefaultListener(target, routeConfigName)}, - Routes: []*v3routepb.RouteConfiguration{DefaultRouteConfig(routeConfigName, target, clusterName)}, - Clusters: []*v3clusterpb.Cluster{DefaultCluster(clusterName, endpointsName)}, - Endpoints: []*v3endpointpb.ClusterLoadAssignment{DefaultEndpoint(endpointsName, host, port)}, + NodeID: params.NodeID, + Listeners: []*v3listenerpb.Listener{DefaultClientListener(params.DialTarget, routeConfigName)}, + Routes: []*v3routepb.RouteConfiguration{DefaultRouteConfig(routeConfigName, params.DialTarget, clusterName)}, + Clusters: []*v3clusterpb.Cluster{DefaultCluster(clusterName, endpointsName, params.SecLevel)}, + Endpoints: []*v3endpointpb.ClusterLoadAssignment{DefaultEndpoint(endpointsName, params.Host, []uint32{params.Port})}, } } -//
DefaultListener returns a basic xds Listener resource. -func DefaultListener(target, routeName string) *v3listenerpb.Listener { - hcm := any(&v3httppb.HttpConnectionManager{ +// RouterHTTPFilter is the HTTP Filter configuration for the Router filter. +var RouterHTTPFilter = HTTPFilter("router", &v3routerpb.Router{}) + +// DefaultClientListener returns a basic xds Listener resource to be used on +// the client side. +func DefaultClientListener(target, routeName string) *v3listenerpb.Listener { + hcm := testutils.MarshalAny(&v3httppb.HttpConnectionManager{ RouteSpecifier: &v3httppb.HttpConnectionManager_Rds{Rds: &v3httppb.Rds{ ConfigSource: &v3corepb.ConfigSource{ ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{Ads: &v3corepb.AggregatedConfigSource{}}, @@ -82,12 +125,164 @@ func DefaultListener(target, routeName string) *v3listenerpb.Listener { } } +// DefaultServerListener returns a basic xds Listener resource to be used on +// the server side. +func DefaultServerListener(host string, port uint32, secLevel SecurityLevel) *v3listenerpb.Listener { + var tlsContext *v3tlspb.DownstreamTlsContext + switch secLevel { + case SecurityLevelNone: + case SecurityLevelTLS: + tlsContext = &v3tlspb.DownstreamTlsContext{ + CommonTlsContext: &v3tlspb.CommonTlsContext{ + TlsCertificateCertificateProviderInstance: &v3tlspb.CommonTlsContext_CertificateProviderInstance{ + InstanceName: ServerSideCertProviderInstance, + }, + }, + } + case SecurityLevelMTLS: + tlsContext = &v3tlspb.DownstreamTlsContext{ + RequireClientCertificate: &wrapperspb.BoolValue{Value: true}, + CommonTlsContext: &v3tlspb.CommonTlsContext{ + TlsCertificateCertificateProviderInstance: &v3tlspb.CommonTlsContext_CertificateProviderInstance{ + InstanceName: ServerSideCertProviderInstance, + }, + ValidationContextType: &v3tlspb.CommonTlsContext_ValidationContextCertificateProviderInstance{ + ValidationContextCertificateProviderInstance: &v3tlspb.CommonTlsContext_CertificateProviderInstance{ + InstanceName: 
ServerSideCertProviderInstance, + }, + }, + }, + } + } + + var ts *v3corepb.TransportSocket + if tlsContext != nil { + ts = &v3corepb.TransportSocket{ + Name: "envoy.transport_sockets.tls", + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(tlsContext), + }, + } + } + return &v3listenerpb.Listener{ + Name: fmt.Sprintf(ServerListenerResourceNameTemplate, net.JoinHostPort(host, strconv.Itoa(int(port)))), + Address: &v3corepb.Address{ + Address: &v3corepb.Address_SocketAddress{ + SocketAddress: &v3corepb.SocketAddress{ + Address: host, + PortSpecifier: &v3corepb.SocketAddress_PortValue{ + PortValue: port, + }, + }, + }, + }, + FilterChains: []*v3listenerpb.FilterChain{ + { + Name: "v4-wildcard", + FilterChainMatch: &v3listenerpb.FilterChainMatch{ + PrefixRanges: []*v3corepb.CidrRange{ + { + AddressPrefix: "0.0.0.0", + PrefixLen: &wrapperspb.UInt32Value{ + Value: uint32(0), + }, + }, + }, + SourceType: v3listenerpb.FilterChainMatch_SAME_IP_OR_LOOPBACK, + SourcePrefixRanges: []*v3corepb.CidrRange{ + { + AddressPrefix: "0.0.0.0", + PrefixLen: &wrapperspb.UInt32Value{ + Value: uint32(0), + }, + }, + }, + }, + Filters: []*v3listenerpb.Filter{ + { + Name: "filter-1", + ConfigType: &v3listenerpb.Filter_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3httppb.HttpConnectionManager{ + RouteSpecifier: &v3httppb.HttpConnectionManager_RouteConfig{ + RouteConfig: &v3routepb.RouteConfiguration{ + Name: "routeName", + VirtualHosts: []*v3routepb.VirtualHost{{ + // This "*" string matches on any incoming authority. This is to ensure any + // incoming RPC matches to Route_NonForwardingAction and will proceed as + // normal. 
+ Domains: []string{"*"}, + Routes: []*v3routepb.Route{{ + Match: &v3routepb.RouteMatch{ + PathSpecifier: &v3routepb.RouteMatch_Prefix{Prefix: "/"}, + }, + Action: &v3routepb.Route_NonForwardingAction{}, + }}}}}, + }, + HttpFilters: []*v3httppb.HttpFilter{RouterHTTPFilter}, + }), + }, + }, + }, + TransportSocket: ts, + }, + { + Name: "v6-wildcard", + FilterChainMatch: &v3listenerpb.FilterChainMatch{ + PrefixRanges: []*v3corepb.CidrRange{ + { + AddressPrefix: "::", + PrefixLen: &wrapperspb.UInt32Value{ + Value: uint32(0), + }, + }, + }, + SourceType: v3listenerpb.FilterChainMatch_SAME_IP_OR_LOOPBACK, + SourcePrefixRanges: []*v3corepb.CidrRange{ + { + AddressPrefix: "::", + PrefixLen: &wrapperspb.UInt32Value{ + Value: uint32(0), + }, + }, + }, + }, + Filters: []*v3listenerpb.Filter{ + { + Name: "filter-1", + ConfigType: &v3listenerpb.Filter_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3httppb.HttpConnectionManager{ + RouteSpecifier: &v3httppb.HttpConnectionManager_RouteConfig{ + RouteConfig: &v3routepb.RouteConfiguration{ + Name: "routeName", + VirtualHosts: []*v3routepb.VirtualHost{{ + // This "*" string matches on any incoming authority. This is to ensure any + // incoming RPC matches to Route_NonForwardingAction and will proceed as + // normal. + Domains: []string{"*"}, + Routes: []*v3routepb.Route{{ + Match: &v3routepb.RouteMatch{ + PathSpecifier: &v3routepb.RouteMatch_Prefix{Prefix: "/"}, + }, + Action: &v3routepb.Route_NonForwardingAction{}, + }}}}}, + }, + HttpFilters: []*v3httppb.HttpFilter{RouterHTTPFilter}, + }), + }, + }, + }, + TransportSocket: ts, + }, + }, + } +} + // HTTPFilter constructs an xds HttpFilter with the provided name and config. 
func HTTPFilter(name string, config proto.Message) *v3httppb.HttpFilter { return &v3httppb.HttpFilter{ Name: name, ConfigType: &v3httppb.HttpFilter_TypedConfig{ - TypedConfig: any(config), + TypedConfig: testutils.MarshalAny(config), }, } } @@ -109,8 +304,36 @@ func DefaultRouteConfig(routeName, ldsTarget, clusterName string) *v3routepb.Rou } // DefaultCluster returns a basic xds Cluster resource. -func DefaultCluster(clusterName, edsServiceName string) *v3clusterpb.Cluster { - return &v3clusterpb.Cluster{ +func DefaultCluster(clusterName, edsServiceName string, secLevel SecurityLevel) *v3clusterpb.Cluster { + var tlsContext *v3tlspb.UpstreamTlsContext + switch secLevel { + case SecurityLevelNone: + case SecurityLevelTLS: + tlsContext = &v3tlspb.UpstreamTlsContext{ + CommonTlsContext: &v3tlspb.CommonTlsContext{ + ValidationContextType: &v3tlspb.CommonTlsContext_ValidationContextCertificateProviderInstance{ + ValidationContextCertificateProviderInstance: &v3tlspb.CommonTlsContext_CertificateProviderInstance{ + InstanceName: ClientSideCertProviderInstance, + }, + }, + }, + } + case SecurityLevelMTLS: + tlsContext = &v3tlspb.UpstreamTlsContext{ + CommonTlsContext: &v3tlspb.CommonTlsContext{ + ValidationContextType: &v3tlspb.CommonTlsContext_ValidationContextCertificateProviderInstance{ + ValidationContextCertificateProviderInstance: &v3tlspb.CommonTlsContext_CertificateProviderInstance{ + InstanceName: ClientSideCertProviderInstance, + }, + }, + TlsCertificateCertificateProviderInstance: &v3tlspb.CommonTlsContext_CertificateProviderInstance{ + InstanceName: ClientSideCertProviderInstance, + }, + }, + } + } + + cluster := &v3clusterpb.Cluster{ Name: clusterName, ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_EDS}, EdsClusterConfig: &v3clusterpb.Cluster_EdsClusterConfig{ @@ -123,24 +346,37 @@ func DefaultCluster(clusterName, edsServiceName string) *v3clusterpb.Cluster { }, LbPolicy: v3clusterpb.Cluster_ROUND_ROBIN, } + if tlsContext != nil { + 
cluster.TransportSocket = &v3corepb.TransportSocket{ + Name: "envoy.transport_sockets.tls", + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(tlsContext), + }, + } + } + return cluster } // DefaultEndpoint returns a basic xds Endpoint resource. -func DefaultEndpoint(clusterName string, host string, port uint32) *v3endpointpb.ClusterLoadAssignment { +func DefaultEndpoint(clusterName string, host string, ports []uint32) *v3endpointpb.ClusterLoadAssignment { + var lbEndpoints []*v3endpointpb.LbEndpoint + for _, port := range ports { + lbEndpoints = append(lbEndpoints, &v3endpointpb.LbEndpoint{ + HostIdentifier: &v3endpointpb.LbEndpoint_Endpoint{Endpoint: &v3endpointpb.Endpoint{ + Address: &v3corepb.Address{Address: &v3corepb.Address_SocketAddress{ + SocketAddress: &v3corepb.SocketAddress{ + Protocol: v3corepb.SocketAddress_TCP, + Address: host, + PortSpecifier: &v3corepb.SocketAddress_PortValue{PortValue: port}}, + }}, + }}, + }) + } return &v3endpointpb.ClusterLoadAssignment{ ClusterName: clusterName, Endpoints: []*v3endpointpb.LocalityLbEndpoints{{ - Locality: &v3corepb.Locality{SubZone: "subzone"}, - LbEndpoints: []*v3endpointpb.LbEndpoint{{ - HostIdentifier: &v3endpointpb.LbEndpoint_Endpoint{Endpoint: &v3endpointpb.Endpoint{ - Address: &v3corepb.Address{Address: &v3corepb.Address_SocketAddress{ - SocketAddress: &v3corepb.SocketAddress{ - Protocol: v3corepb.SocketAddress_TCP, - Address: host, - PortSpecifier: &v3corepb.SocketAddress_PortValue{PortValue: uint32(port)}}, - }}, - }}, - }}, + Locality: &v3corepb.Locality{SubZone: "subzone"}, + LbEndpoints: lbEndpoints, LoadBalancingWeight: &wrapperspb.UInt32Value{Value: 1}, Priority: 0, }}, diff --git a/xds/internal/testutils/e2e/server.go b/xds/internal/testutils/e2e/server.go index cc55c595cae..e47dcc5213c 100644 --- a/xds/internal/testutils/e2e/server.go +++ b/xds/internal/testutils/e2e/server.go @@ -33,6 +33,7 @@ import ( v3discoverygrpc 
"github.com/envoyproxy/go-control-plane/envoy/service/discovery/v3" "github.com/envoyproxy/go-control-plane/pkg/cache/types" v3cache "github.com/envoyproxy/go-control-plane/pkg/cache/v3" + v3resource "github.com/envoyproxy/go-control-plane/pkg/resource/v3" v3server "github.com/envoyproxy/go-control-plane/pkg/server/v3" "google.golang.org/grpc" @@ -45,10 +46,22 @@ var logger = grpclog.Component("xds-e2e") // envoyproxy/go-control-plane/pkg/log. This is passed to the Snapshot cache. type serverLogger struct{} -func (l serverLogger) Debugf(format string, args ...interface{}) { logger.Infof(format, args...) } -func (l serverLogger) Infof(format string, args ...interface{}) { logger.Infof(format, args...) } -func (l serverLogger) Warnf(format string, args ...interface{}) { logger.Warningf(format, args...) } -func (l serverLogger) Errorf(format string, args ...interface{}) { logger.Errorf(format, args...) } +func (l serverLogger) Debugf(format string, args ...interface{}) { + msg := fmt.Sprintf(format, args...) + logger.InfoDepth(1, msg) +} +func (l serverLogger) Infof(format string, args ...interface{}) { + msg := fmt.Sprintf(format, args...) + logger.InfoDepth(1, msg) +} +func (l serverLogger) Warnf(format string, args ...interface{}) { + msg := fmt.Sprintf(format, args...) + logger.WarningDepth(1, msg) +} +func (l serverLogger) Errorf(format string, args ...interface{}) { + msg := fmt.Sprintf(format, args...) + logger.ErrorDepth(1, msg) +} // ManagementServer is a thin wrapper around the xDS control plane // implementation provided by envoyproxy/go-control-plane. @@ -113,22 +126,38 @@ type UpdateOptions struct { Clusters []*v3clusterpb.Cluster Routes []*v3routepb.RouteConfiguration Listeners []*v3listenerpb.Listener + // SkipValidation indicates whether we want to skip validation (by not + // calling snapshot.Consistent()). It can be useful for negative tests, + // where we send updates that the client will NACK. 
+ SkipValidation bool } // Update changes the resource snapshot held by the management server, which // updates connected clients as required. -func (s *ManagementServer) Update(opts UpdateOptions) error { +func (s *ManagementServer) Update(ctx context.Context, opts UpdateOptions) error { s.version++ // Create a snapshot with the passed in resources. - snapshot := v3cache.NewSnapshot(strconv.Itoa(s.version), resourceSlice(opts.Endpoints), resourceSlice(opts.Clusters), resourceSlice(opts.Routes), resourceSlice(opts.Listeners), nil /*runtimes*/, nil /*secrets*/) - if err := snapshot.Consistent(); err != nil { - return fmt.Errorf("failed to create new resource snapshot: %v", err) + resources := map[v3resource.Type][]types.Resource{ + v3resource.ListenerType: resourceSlice(opts.Listeners), + v3resource.RouteType: resourceSlice(opts.Routes), + v3resource.ClusterType: resourceSlice(opts.Clusters), + v3resource.EndpointType: resourceSlice(opts.Endpoints), + } + snapshot, err := v3cache.NewSnapshot(strconv.Itoa(s.version), resources) + if err != nil { + return fmt.Errorf("failed to create new snapshot cache: %v", err) + + } + if !opts.SkipValidation { + if err := snapshot.Consistent(); err != nil { + return fmt.Errorf("failed to create new resource snapshot: %v", err) + } } logger.Infof("Created new resource snapshot...") // Update the cache with the new resource snapshot. 
- if err := s.cache.SetSnapshot(opts.NodeID, snapshot); err != nil { + if err := s.cache.SetSnapshot(ctx, opts.NodeID, snapshot); err != nil { return fmt.Errorf("failed to update resource snapshot in management server: %v", err) } logger.Infof("Updated snapshot cache with resource snapshot...") @@ -141,7 +170,6 @@ func (s *ManagementServer) Stop() { s.cancel() } s.gs.Stop() - logger.Infof("Stopped the xDS management server...") } // resourceSlice accepts a slice of any type of proto messages and returns a diff --git a/xds/internal/testutils/fakeclient/client.go b/xds/internal/testutils/fakeclient/client.go index 0978125b8ae..b582fd9bee9 100644 --- a/xds/internal/testutils/fakeclient/client.go +++ b/xds/internal/testutils/fakeclient/client.go @@ -22,15 +22,21 @@ package fakeclient import ( "context" + "google.golang.org/grpc/internal/grpcsync" "google.golang.org/grpc/internal/testutils" - xdsclient "google.golang.org/grpc/xds/internal/client" - "google.golang.org/grpc/xds/internal/client/bootstrap" - "google.golang.org/grpc/xds/internal/client/load" + "google.golang.org/grpc/xds/internal/xdsclient" + "google.golang.org/grpc/xds/internal/xdsclient/bootstrap" + "google.golang.org/grpc/xds/internal/xdsclient/load" ) // Client is a fake implementation of an xds client. It exposes a bunch of // channels to signal the occurrence of various events. type Client struct { + // Embed XDSClient so this fake client implements the interface, but it's + // never set (it's always nil). This may cause nil panic since not all the + // methods are implemented. 
+ xdsclient.XDSClient + name string ldsWatchCh *testutils.Channel rdsWatchCh *testutils.Channel @@ -41,14 +47,16 @@ type Client struct { cdsCancelCh *testutils.Channel edsCancelCh *testutils.Channel loadReportCh *testutils.Channel - closeCh *testutils.Channel + lrsCancelCh *testutils.Channel loadStore *load.Store bootstrapCfg *bootstrap.Config - ldsCb func(xdsclient.ListenerUpdate, error) - rdsCb func(xdsclient.RouteConfigUpdate, error) - cdsCb func(xdsclient.ClusterUpdate, error) - edsCb func(xdsclient.EndpointsUpdate, error) + ldsCb func(xdsclient.ListenerUpdate, error) + rdsCbs map[string]func(xdsclient.RouteConfigUpdate, error) + cdsCbs map[string]func(xdsclient.ClusterUpdate, error) + edsCbs map[string]func(xdsclient.EndpointsUpdate, error) + + Closed *grpcsync.Event // fired when Close is called. } // WatchListener registers a LDS watch. @@ -87,10 +95,10 @@ func (xdsC *Client) WaitForCancelListenerWatch(ctx context.Context) error { // WatchRouteConfig registers a RDS watch. func (xdsC *Client) WatchRouteConfig(routeName string, callback func(xdsclient.RouteConfigUpdate, error)) func() { - xdsC.rdsCb = callback + xdsC.rdsCbs[routeName] = callback xdsC.rdsWatchCh.Send(routeName) return func() { - xdsC.rdsCancelCh.Send(nil) + xdsC.rdsCancelCh.Send(routeName) } } @@ -108,23 +116,39 @@ func (xdsC *Client) WaitForWatchRouteConfig(ctx context.Context) (string, error) // // Not thread safe with WatchRouteConfig. Only call this after // WaitForWatchRouteConfig. -func (xdsC *Client) InvokeWatchRouteConfigCallback(update xdsclient.RouteConfigUpdate, err error) { - xdsC.rdsCb(update, err) +func (xdsC *Client) InvokeWatchRouteConfigCallback(name string, update xdsclient.RouteConfigUpdate, err error) { + if len(xdsC.rdsCbs) != 1 { + xdsC.rdsCbs[name](update, err) + return + } + // Keeps functionality with previous usage of this on client side, if single + // callback call that callback. 
+ var routeName string + for route := range xdsC.rdsCbs { + routeName = route + } + xdsC.rdsCbs[routeName](update, err) } // WaitForCancelRouteConfigWatch waits for a RDS watch to be cancelled and returns // context.DeadlineExceeded otherwise. -func (xdsC *Client) WaitForCancelRouteConfigWatch(ctx context.Context) error { - _, err := xdsC.rdsCancelCh.Receive(ctx) - return err +func (xdsC *Client) WaitForCancelRouteConfigWatch(ctx context.Context) (string, error) { + val, err := xdsC.rdsCancelCh.Receive(ctx) + if err != nil { + return "", err + } + return val.(string), err } // WatchCluster registers a CDS watch. func (xdsC *Client) WatchCluster(clusterName string, callback func(xdsclient.ClusterUpdate, error)) func() { - xdsC.cdsCb = callback + // Due to the tree like structure of aggregate clusters, there can be multiple callbacks persisted for each cluster + // node. However, the client doesn't care about the parent child relationship between the nodes, only that it invokes + // the right callback for a particular cluster. + xdsC.cdsCbs[clusterName] = callback xdsC.cdsWatchCh.Send(clusterName) return func() { - xdsC.cdsCancelCh.Send(nil) + xdsC.cdsCancelCh.Send(clusterName) } } @@ -143,22 +167,36 @@ func (xdsC *Client) WaitForWatchCluster(ctx context.Context) (string, error) { // Not thread safe with WatchCluster. Only call this after // WaitForWatchCluster. func (xdsC *Client) InvokeWatchClusterCallback(update xdsclient.ClusterUpdate, err error) { - xdsC.cdsCb(update, err) + // Keeps functionality with previous usage of this, if single callback call that callback. + if len(xdsC.cdsCbs) == 1 { + var clusterName string + for cluster := range xdsC.cdsCbs { + clusterName = cluster + } + xdsC.cdsCbs[clusterName](update, err) + } else { + // Have what callback you call with the update determined by the service name in the ClusterUpdate. Left up to the + // caller to make sure the cluster update matches with a persisted callback. 
+ xdsC.cdsCbs[update.ClusterName](update, err) + } } // WaitForCancelClusterWatch waits for a CDS watch to be cancelled and returns // context.DeadlineExceeded otherwise. -func (xdsC *Client) WaitForCancelClusterWatch(ctx context.Context) error { - _, err := xdsC.cdsCancelCh.Receive(ctx) - return err +func (xdsC *Client) WaitForCancelClusterWatch(ctx context.Context) (string, error) { + clusterNameReceived, err := xdsC.cdsCancelCh.Receive(ctx) + if err != nil { + return "", err + } + return clusterNameReceived.(string), err } // WatchEndpoints registers an EDS watch for provided clusterName. func (xdsC *Client) WatchEndpoints(clusterName string, callback func(xdsclient.EndpointsUpdate, error)) (cancel func()) { - xdsC.edsCb = callback + xdsC.edsCbs[clusterName] = callback xdsC.edsWatchCh.Send(clusterName) return func() { - xdsC.edsCancelCh.Send(nil) + xdsC.edsCancelCh.Send(clusterName) } } @@ -176,15 +214,28 @@ func (xdsC *Client) WaitForWatchEDS(ctx context.Context) (string, error) { // // Not thread safe with WatchEndpoints. Only call this after // WaitForWatchEDS. -func (xdsC *Client) InvokeWatchEDSCallback(update xdsclient.EndpointsUpdate, err error) { - xdsC.edsCb(update, err) +func (xdsC *Client) InvokeWatchEDSCallback(name string, update xdsclient.EndpointsUpdate, err error) { + if len(xdsC.edsCbs) != 1 { + // This may panic if name isn't found. But it's fine for tests. + xdsC.edsCbs[name](update, err) + return + } + // Keeps functionality with previous usage of this, if single callback call + // that callback. + for n := range xdsC.edsCbs { + name = n + } + xdsC.edsCbs[name](update, err) } // WaitForCancelEDSWatch waits for a EDS watch to be cancelled and returns // context.DeadlineExceeded otherwise. 
-func (xdsC *Client) WaitForCancelEDSWatch(ctx context.Context) error { - _, err := xdsC.edsCancelCh.Receive(ctx) - return err +func (xdsC *Client) WaitForCancelEDSWatch(ctx context.Context) (string, error) { + edsNameReceived, err := xdsC.edsCancelCh.Receive(ctx) + if err != nil { + return "", err + } + return edsNameReceived.(string), err } // ReportLoadArgs wraps the arguments passed to ReportLoad. @@ -196,7 +247,16 @@ type ReportLoadArgs struct { // ReportLoad starts reporting load about clusterName to server. func (xdsC *Client) ReportLoad(server string) (loadStore *load.Store, cancel func()) { xdsC.loadReportCh.Send(ReportLoadArgs{Server: server}) - return xdsC.loadStore, func() {} + return xdsC.loadStore, func() { + xdsC.lrsCancelCh.Send(nil) + } +} + +// WaitForCancelReportLoad waits for a load report to be cancelled and returns +// context.DeadlineExceeded otherwise. +func (xdsC *Client) WaitForCancelReportLoad(ctx context.Context) error { + _, err := xdsC.lrsCancelCh.Receive(ctx) + return err } // LoadStore returns the underlying load data store. @@ -208,19 +268,15 @@ func (xdsC *Client) LoadStore() *load.Store { // returns the arguments passed to it. func (xdsC *Client) WaitForReportLoad(ctx context.Context) (ReportLoadArgs, error) { val, err := xdsC.loadReportCh.Receive(ctx) - return val.(ReportLoadArgs), err + if err != nil { + return ReportLoadArgs{}, err + } + return val.(ReportLoadArgs), nil } -// Close closes the xds client. +// Close fires xdsC.Closed, indicating it was called. func (xdsC *Client) Close() { - xdsC.closeCh.Send(nil) -} - -// WaitForClose waits for Close to be invoked on this client and returns -// context.DeadlineExceeded otherwise. -func (xdsC *Client) WaitForClose(ctx context.Context) error { - _, err := xdsC.closeCh.Receive(ctx) - return err + xdsC.Closed.Fire() } // BootstrapConfig returns the bootstrap config. 
@@ -250,15 +306,19 @@ func NewClientWithName(name string) *Client { return &Client{ name: name, ldsWatchCh: testutils.NewChannel(), - rdsWatchCh: testutils.NewChannel(), - cdsWatchCh: testutils.NewChannel(), - edsWatchCh: testutils.NewChannel(), + rdsWatchCh: testutils.NewChannelWithSize(10), + cdsWatchCh: testutils.NewChannelWithSize(10), + edsWatchCh: testutils.NewChannelWithSize(10), ldsCancelCh: testutils.NewChannel(), - rdsCancelCh: testutils.NewChannel(), - cdsCancelCh: testutils.NewChannel(), - edsCancelCh: testutils.NewChannel(), + rdsCancelCh: testutils.NewChannelWithSize(10), + cdsCancelCh: testutils.NewChannelWithSize(10), + edsCancelCh: testutils.NewChannelWithSize(10), loadReportCh: testutils.NewChannel(), - closeCh: testutils.NewChannel(), + lrsCancelCh: testutils.NewChannel(), loadStore: load.NewStore(), + rdsCbs: make(map[string]func(xdsclient.RouteConfigUpdate, error)), + cdsCbs: make(map[string]func(xdsclient.ClusterUpdate, error)), + edsCbs: make(map[string]func(xdsclient.EndpointsUpdate, error)), + Closed: grpcsync.NewEvent(), } } diff --git a/xds/internal/testutils/protos.go b/xds/internal/testutils/protos.go index e0dba0e2b30..fc3cdf307fc 100644 --- a/xds/internal/testutils/protos.go +++ b/xds/internal/testutils/protos.go @@ -59,7 +59,7 @@ type ClusterLoadAssignmentBuilder struct { // NewClusterLoadAssignmentBuilder creates a ClusterLoadAssignmentBuilder. func NewClusterLoadAssignmentBuilder(clusterName string, dropPercents map[string]uint32) *ClusterLoadAssignmentBuilder { - var drops []*v2xdspb.ClusterLoadAssignment_Policy_DropOverload + drops := make([]*v2xdspb.ClusterLoadAssignment_Policy_DropOverload, 0, len(dropPercents)) for n, d := range dropPercents { drops = append(drops, &v2xdspb.ClusterLoadAssignment_Policy_DropOverload{ Category: n, @@ -88,7 +88,7 @@ type AddLocalityOptions struct { // AddLocality adds a locality to the builder. 
func (clab *ClusterLoadAssignmentBuilder) AddLocality(subzone string, weight uint32, priority uint32, addrsWithPort []string, opts *AddLocalityOptions) { - var lbEndPoints []*v2endpointpb.LbEndpoint + lbEndPoints := make([]*v2endpointpb.LbEndpoint, 0, len(addrsWithPort)) for i, a := range addrsWithPort { host, portStr, err := net.SplitHostPort(a) if err != nil { diff --git a/xds/internal/xdsclient/attributes.go b/xds/internal/xdsclient/attributes.go new file mode 100644 index 00000000000..d2357df0727 --- /dev/null +++ b/xds/internal/xdsclient/attributes.go @@ -0,0 +1,59 @@ +/* + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package xdsclient + +import ( + "google.golang.org/grpc/resolver" + "google.golang.org/grpc/xds/internal/xdsclient/bootstrap" + "google.golang.org/grpc/xds/internal/xdsclient/load" +) + +type clientKeyType string + +const clientKey = clientKeyType("grpc.xds.internal.client.Client") + +// XDSClient is a full fledged gRPC client which queries a set of discovery APIs +// (collectively termed as xDS) on a remote management server, to discover +// various dynamic resources. 
+type XDSClient interface { + WatchListener(string, func(ListenerUpdate, error)) func() + WatchRouteConfig(string, func(RouteConfigUpdate, error)) func() + WatchCluster(string, func(ClusterUpdate, error)) func() + WatchEndpoints(clusterName string, edsCb func(EndpointsUpdate, error)) (cancel func()) + ReportLoad(server string) (*load.Store, func()) + + DumpLDS() (string, map[string]UpdateWithMD) + DumpRDS() (string, map[string]UpdateWithMD) + DumpCDS() (string, map[string]UpdateWithMD) + DumpEDS() (string, map[string]UpdateWithMD) + + BootstrapConfig() *bootstrap.Config + Close() +} + +// FromResolverState returns the Client from state, or nil if not present. +func FromResolverState(state resolver.State) XDSClient { + cs, _ := state.Attributes.Value(clientKey).(XDSClient) + return cs +} + +// SetClient sets c in state and returns the new state. +func SetClient(state resolver.State, c XDSClient) resolver.State { + state.Attributes = state.Attributes.WithValues(clientKey, c) + return state +} diff --git a/xds/internal/client/bootstrap/bootstrap.go b/xds/internal/xdsclient/bootstrap/bootstrap.go similarity index 88% rename from xds/internal/client/bootstrap/bootstrap.go rename to xds/internal/xdsclient/bootstrap/bootstrap.go index c48bb797f36..fa229d99593 100644 --- a/xds/internal/client/bootstrap/bootstrap.go +++ b/xds/internal/xdsclient/bootstrap/bootstrap.go @@ -35,7 +35,8 @@ import ( "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/credentials/tls/certprovider" "google.golang.org/grpc/internal" - "google.golang.org/grpc/xds/internal/env" + "google.golang.org/grpc/internal/pretty" + "google.golang.org/grpc/internal/xds/env" "google.golang.org/grpc/xds/internal/version" ) @@ -79,10 +80,12 @@ type Config struct { // CertProviderConfigs contains a mapping from certificate provider plugin // instance names to parsed buildable configs. 
CertProviderConfigs map[string]*certprovider.BuildableConfig - // ServerResourceNameID contains the value to be used as the id in the - // resource name used to fetch the Listener resource on the xDS-enabled gRPC - // server. - ServerResourceNameID string + // ServerListenerResourceNameTemplate is a template for the name of the + // Listener resource to subscribe to for a gRPC server. If the token `%s` is + // present in the string, it will be replaced with the server's listening + // "IP:port" (e.g., "0.0.0.0:8080", "[::]:8080"). For example, a value of + // "example/resource/%s" could become "example/resource/0.0.0.0:8080". + ServerListenerResourceNameTemplate string } type channelCreds struct { @@ -124,16 +127,18 @@ func bootstrapConfigFromEnvVariable() ([]byte, error) { // // The format of the bootstrap file will be as follows: // { -// "xds_server": { -// "server_uri": , -// "channel_creds": [ -// { -// "type": , -// "config": -// } -// ], -// "server_features": [ ... ], -// }, +// "xds_servers": [ +// { +// "server_uri": , +// "channel_creds": [ +// { +// "type": , +// "config": +// } +// ], +// "server_features": [ ... ], +// } +// ], // "node": , // "certificate_providers" : { // "default": { @@ -145,7 +150,7 @@ func bootstrapConfigFromEnvVariable() ([]byte, error) { // "config": { foo plugin config in JSON } // } // }, -// "grpc_server_resource_name_id": "grpc/server" +// "server_listener_resource_name_template": "grpc/server?xds.resource.listening_address=%s" // } // // Currently, we support exactly one type of credential, which is @@ -157,13 +162,19 @@ func bootstrapConfigFromEnvVariable() ([]byte, error) { // fields left unspecified, in which case the caller should use some sane // defaults. 
func NewConfig() (*Config, error) { - config := &Config{} - data, err := bootstrapConfigFromEnvVariable() if err != nil { return nil, fmt.Errorf("xds: Failed to read bootstrap config: %v", err) } logger.Debugf("Bootstrap content: %s", data) + return NewConfigFromContents(data) +} + +// NewConfigFromContents returns a new Config using the specified bootstrap +// file contents instead of reading the environment variable. This is only +// suitable for testing purposes. +func NewConfigFromContents(data []byte) (*Config, error) { + config := &Config{} var jsonData map[string]json.RawMessage if err := json.Unmarshal(data, &jsonData); err != nil { @@ -241,8 +252,8 @@ func NewConfig() (*Config, error) { configs[instance] = bc } config.CertProviderConfigs = configs - case "grpc_server_resource_name_id": - if err := json.Unmarshal(v, &config.ServerResourceNameID); err != nil { + case "server_listener_resource_name_template": + if err := json.Unmarshal(v, &config.ServerListenerResourceNameTemplate); err != nil { return nil, fmt.Errorf("xds: json.Unmarshal(%v) for field %q failed during bootstrap: %v", string(v), k, err) } } @@ -268,7 +279,7 @@ func NewConfig() (*Config, error) { if err := config.updateNodeProto(); err != nil { return nil, err } - logger.Infof("Bootstrap config for creating xds-client: %+v", config) + logger.Infof("Bootstrap config for creating xds-client: %v", pretty.ToJSON(config)) return config, nil } diff --git a/xds/internal/client/bootstrap/bootstrap_test.go b/xds/internal/xdsclient/bootstrap/bootstrap_test.go similarity index 95% rename from xds/internal/client/bootstrap/bootstrap_test.go rename to xds/internal/xdsclient/bootstrap/bootstrap_test.go index e881594b877..501d62102d2 100644 --- a/xds/internal/client/bootstrap/bootstrap_test.go +++ b/xds/internal/xdsclient/bootstrap/bootstrap_test.go @@ -36,7 +36,7 @@ import ( "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/credentials/tls/certprovider" "google.golang.org/grpc/internal" 
- "google.golang.org/grpc/xds/internal/env" + "google.golang.org/grpc/internal/xds/env" "google.golang.org/grpc/xds/internal/version" ) @@ -240,8 +240,8 @@ func (c *Config) compare(want *Config) error { if diff := cmp.Diff(want.NodeProto, c.NodeProto, cmp.Comparer(proto.Equal)); diff != "" { return fmt.Errorf("config.NodeProto diff (-want, +got):\n%s", diff) } - if c.ServerResourceNameID != want.ServerResourceNameID { - return fmt.Errorf("config.ServerResourceNameID is %q, want %q", c.ServerResourceNameID, want.ServerResourceNameID) + if c.ServerListenerResourceNameTemplate != want.ServerListenerResourceNameTemplate { + return fmt.Errorf("config.ServerListenerResourceNameTemplate is %q, want %q", c.ServerListenerResourceNameTemplate, want.ServerListenerResourceNameTemplate) } // A vanilla cmp.Equal or cmp.Diff will not produce useful error message @@ -711,9 +711,9 @@ func TestNewConfigWithCertificateProviders(t *testing.T) { } } -func TestNewConfigWithServerResourceNameID(t *testing.T) { +func TestNewConfigWithServerListenerResourceNameTemplate(t *testing.T) { cancel := setupBootstrapOverride(map[string]string{ - "badServerResourceNameID": ` + "badServerListenerResourceNameTemplate:": ` { "node": { "id": "ENVOY_NODE_ID", @@ -727,9 +727,9 @@ func TestNewConfigWithServerResourceNameID(t *testing.T) { { "type": "google_default" } ] }], - "grpc_server_resource_name_id": 123456789 + "server_listener_resource_name_template": 123456789 }`, - "goodServerResourceNameID": ` + "goodServerListenerResourceNameTemplate": ` { "node": { "id": "ENVOY_NODE_ID", @@ -743,7 +743,7 @@ func TestNewConfigWithServerResourceNameID(t *testing.T) { { "type": "google_default" } ] }], - "grpc_server_resource_name_id": "grpc/server" + "server_listener_resource_name_template": "grpc/server?xds.resource.listening_address=%s" }`, }) defer cancel() @@ -754,17 +754,17 @@ func TestNewConfigWithServerResourceNameID(t *testing.T) { wantErr bool }{ { - name: "badServerResourceNameID", + name: 
"badServerListenerResourceNameTemplate", wantErr: true, }, { - name: "goodServerResourceNameID", + name: "goodServerListenerResourceNameTemplate", wantConfig: &Config{ - BalancerName: "trafficdirector.googleapis.com:443", - Creds: grpc.WithCredentialsBundle(google.NewComputeEngineCredentials()), - TransportAPI: version.TransportV2, - NodeProto: v2NodeProto, - ServerResourceNameID: "grpc/server", + BalancerName: "trafficdirector.googleapis.com:443", + Creds: grpc.WithCredentialsBundle(google.NewComputeEngineCredentials()), + TransportAPI: version.TransportV2, + NodeProto: v2NodeProto, + ServerListenerResourceNameTemplate: "grpc/server?xds.resource.listening_address=%s", }, }, } diff --git a/xds/internal/client/bootstrap/logging.go b/xds/internal/xdsclient/bootstrap/logging.go similarity index 100% rename from xds/internal/client/bootstrap/logging.go rename to xds/internal/xdsclient/bootstrap/logging.go diff --git a/xds/internal/client/callback.go b/xds/internal/xdsclient/callback.go similarity index 60% rename from xds/internal/client/callback.go rename to xds/internal/xdsclient/callback.go index da8e2f62d6c..0c2665e84c0 100644 --- a/xds/internal/client/callback.go +++ b/xds/internal/xdsclient/callback.go @@ -16,7 +16,12 @@ * */ -package client +package xdsclient + +import ( + "google.golang.org/grpc/internal/pretty" + "google.golang.org/protobuf/proto" +) type watcherInfoWithUpdate struct { wi *watchInfo @@ -74,39 +79,49 @@ func (c *clientImpl) callCallback(wiu *watcherInfoWithUpdate) { // // A response can contain multiple resources. They will be parsed and put in a // map from resource name to the resource content. -func (c *clientImpl) NewListeners(updates map[string]ListenerUpdate, metadata UpdateMetadata) { +func (c *clientImpl) NewListeners(updates map[string]ListenerUpdateErrTuple, metadata UpdateMetadata) { c.mu.Lock() defer c.mu.Unlock() + c.ldsVersion = metadata.Version if metadata.ErrState != nil { - // On NACK, update overall version to the NACKed resp. 
c.ldsVersion = metadata.ErrState.Version - for name := range updates { - if _, ok := c.ldsWatchers[name]; ok { + } + for name, uErr := range updates { + if s, ok := c.ldsWatchers[name]; ok { + if uErr.Err != nil { // On error, keep previous version for each resource. But update // status and error. mdCopy := c.ldsMD[name] mdCopy.ErrState = metadata.ErrState mdCopy.Status = metadata.Status c.ldsMD[name] = mdCopy - // TODO: send the NACK error to the watcher. + for wi := range s { + wi.newError(uErr.Err) + } + continue } - } - return - } - - // If no error received, the status is ACK. - c.ldsVersion = metadata.Version - for name, update := range updates { - if s, ok := c.ldsWatchers[name]; ok { - // Only send the update if this is not an error. - for wi := range s { - wi.newUpdate(update) + // If we get here, it means that the update is a valid one. Notify + // watchers only if this is a first time update or it is different + // from the one currently cached. + if cur, ok := c.ldsCache[name]; !ok || !proto.Equal(cur.Raw, uErr.Update.Raw) { + for wi := range s { + wi.newUpdate(uErr.Update) + } } // Sync cache. - c.logger.Debugf("LDS resource with name %v, value %+v added to cache", name, update) - c.ldsCache[name] = update - c.ldsMD[name] = metadata + c.logger.Debugf("LDS resource with name %v, value %+v added to cache", name, pretty.ToJSON(uErr)) + c.ldsCache[name] = uErr.Update + // Set status to ACK, and clear error state. The metadata might be a + // NACK metadata because some other resources in the same response + // are invalid. + mdCopy := metadata + mdCopy.Status = ServiceStatusACKed + mdCopy.ErrState = nil + if metadata.ErrState != nil { + mdCopy.Version = metadata.ErrState.Version + } + c.ldsMD[name] = mdCopy } } // Resources not in the new update were removed by the server, so delete @@ -133,39 +148,50 @@ func (c *clientImpl) NewListeners(updates map[string]ListenerUpdate, metadata Up // // A response can contain multiple resources. 
They will be parsed and put in a // map from resource name to the resource content. -func (c *clientImpl) NewRouteConfigs(updates map[string]RouteConfigUpdate, metadata UpdateMetadata) { +func (c *clientImpl) NewRouteConfigs(updates map[string]RouteConfigUpdateErrTuple, metadata UpdateMetadata) { c.mu.Lock() defer c.mu.Unlock() + // If no error received, the status is ACK. + c.rdsVersion = metadata.Version if metadata.ErrState != nil { - // On NACK, update overall version to the NACKed resp. c.rdsVersion = metadata.ErrState.Version - for name := range updates { - if _, ok := c.rdsWatchers[name]; ok { + } + for name, uErr := range updates { + if s, ok := c.rdsWatchers[name]; ok { + if uErr.Err != nil { // On error, keep previous version for each resource. But update // status and error. mdCopy := c.rdsMD[name] mdCopy.ErrState = metadata.ErrState mdCopy.Status = metadata.Status c.rdsMD[name] = mdCopy - // TODO: send the NACK error to the watcher. + for wi := range s { + wi.newError(uErr.Err) + } + continue } - } - return - } - - // If no error received, the status is ACK. - c.rdsVersion = metadata.Version - for name, update := range updates { - if s, ok := c.rdsWatchers[name]; ok { - // Only send the update if this is not an error. - for wi := range s { - wi.newUpdate(update) + // If we get here, it means that the update is a valid one. Notify + // watchers only if this is a first time update or it is different + // from the one currently cached. + if cur, ok := c.rdsCache[name]; !ok || !proto.Equal(cur.Raw, uErr.Update.Raw) { + for wi := range s { + wi.newUpdate(uErr.Update) + } } // Sync cache. - c.logger.Debugf("RDS resource with name %v, value %+v added to cache", name, update) - c.rdsCache[name] = update - c.rdsMD[name] = metadata + c.logger.Debugf("RDS resource with name %v, value %+v added to cache", name, pretty.ToJSON(uErr)) + c.rdsCache[name] = uErr.Update + // Set status to ACK, and clear error state. 
The metadata might be a + // NACK metadata because some other resources in the same response + // are invalid. + mdCopy := metadata + mdCopy.Status = ServiceStatusACKed + mdCopy.ErrState = nil + if metadata.ErrState != nil { + mdCopy.Version = metadata.ErrState.Version + } + c.rdsMD[name] = mdCopy } } } @@ -175,39 +201,51 @@ func (c *clientImpl) NewRouteConfigs(updates map[string]RouteConfigUpdate, metad // // A response can contain multiple resources. They will be parsed and put in a // map from resource name to the resource content. -func (c *clientImpl) NewClusters(updates map[string]ClusterUpdate, metadata UpdateMetadata) { +func (c *clientImpl) NewClusters(updates map[string]ClusterUpdateErrTuple, metadata UpdateMetadata) { c.mu.Lock() defer c.mu.Unlock() + c.cdsVersion = metadata.Version if metadata.ErrState != nil { - // On NACK, update overall version to the NACKed resp. c.cdsVersion = metadata.ErrState.Version - for name := range updates { - if _, ok := c.cdsWatchers[name]; ok { + } + for name, uErr := range updates { + if s, ok := c.cdsWatchers[name]; ok { + if uErr.Err != nil { // On error, keep previous version for each resource. But update // status and error. mdCopy := c.cdsMD[name] mdCopy.ErrState = metadata.ErrState mdCopy.Status = metadata.Status c.cdsMD[name] = mdCopy - // TODO: send the NACK error to the watcher. + for wi := range s { + // Send the watcher the individual error, instead of the + // overall combined error from the metadata.ErrState. + wi.newError(uErr.Err) + } + continue } - } - return - } - - // If no error received, the status is ACK. - c.cdsVersion = metadata.Version - for name, update := range updates { - if s, ok := c.cdsWatchers[name]; ok { - // Only send the update if this is not an error. - for wi := range s { - wi.newUpdate(update) + // If we get here, it means that the update is a valid one. Notify + // watchers only if this is a first time update or it is different + // from the one currently cached. 
+ if cur, ok := c.cdsCache[name]; !ok || !proto.Equal(cur.Raw, uErr.Update.Raw) { + for wi := range s { + wi.newUpdate(uErr.Update) + } } // Sync cache. - c.logger.Debugf("CDS resource with name %v, value %+v added to cache", name, update) - c.cdsCache[name] = update - c.cdsMD[name] = metadata + c.logger.Debugf("CDS resource with name %v, value %+v added to cache", name, pretty.ToJSON(uErr)) + c.cdsCache[name] = uErr.Update + // Set status to ACK, and clear error state. The metadata might be a + // NACK metadata because some other resources in the same response + // are invalid. + mdCopy := metadata + mdCopy.Status = ServiceStatusACKed + mdCopy.ErrState = nil + if metadata.ErrState != nil { + mdCopy.Version = metadata.ErrState.Version + } + c.cdsMD[name] = mdCopy } } // Resources not in the new update were removed by the server, so delete @@ -234,39 +272,64 @@ func (c *clientImpl) NewClusters(updates map[string]ClusterUpdate, metadata Upda // // A response can contain multiple resources. They will be parsed and put in a // map from resource name to the resource content. -func (c *clientImpl) NewEndpoints(updates map[string]EndpointsUpdate, metadata UpdateMetadata) { +func (c *clientImpl) NewEndpoints(updates map[string]EndpointsUpdateErrTuple, metadata UpdateMetadata) { c.mu.Lock() defer c.mu.Unlock() + c.edsVersion = metadata.Version if metadata.ErrState != nil { - // On NACK, update overall version to the NACKed resp. c.edsVersion = metadata.ErrState.Version - for name := range updates { - if _, ok := c.edsWatchers[name]; ok { + } + for name, uErr := range updates { + if s, ok := c.edsWatchers[name]; ok { + if uErr.Err != nil { // On error, keep previous version for each resource. But update // status and error. mdCopy := c.edsMD[name] mdCopy.ErrState = metadata.ErrState mdCopy.Status = metadata.Status c.edsMD[name] = mdCopy - // TODO: send the NACK error to the watcher. 
+ for wi := range s { + // Send the watcher the individual error, instead of the + // overall combined error from the metadata.ErrState. + wi.newError(uErr.Err) + } + continue + } + // If we get here, it means that the update is a valid one. Notify + // watchers only if this is a first time update or it is different + // from the one currently cached. + if cur, ok := c.edsCache[name]; !ok || !proto.Equal(cur.Raw, uErr.Update.Raw) { + for wi := range s { + wi.newUpdate(uErr.Update) + } } + // Sync cache. + c.logger.Debugf("EDS resource with name %v, value %+v added to cache", name, pretty.ToJSON(uErr)) + c.edsCache[name] = uErr.Update + // Set status to ACK, and clear error state. The metadata might be a + // NACK metadata because some other resources in the same response + // are invalid. + mdCopy := metadata + mdCopy.Status = ServiceStatusACKed + mdCopy.ErrState = nil + if metadata.ErrState != nil { + mdCopy.Version = metadata.ErrState.Version + } + c.edsMD[name] = mdCopy } - return } +} - // If no error received, the status is ACK. - c.edsVersion = metadata.Version - for name, update := range updates { - if s, ok := c.edsWatchers[name]; ok { - // Only send the update if this is not an error. - for wi := range s { - wi.newUpdate(update) - } - // Sync cache. - c.logger.Debugf("EDS resource with name %v, value %+v added to cache", name, update) - c.edsCache[name] = update - c.edsMD[name] = metadata +// NewConnectionError is called by the underlying xdsAPIClient when it receives +// a connection error. The error will be forwarded to all the resource watchers. 
+func (c *clientImpl) NewConnectionError(err error) { + c.mu.Lock() + defer c.mu.Unlock() + + for _, s := range c.edsWatchers { + for wi := range s { + wi.newError(NewErrorf(ErrorTypeConnection, "xds: error received from xDS stream: %v", err)) } } } diff --git a/xds/internal/xdsclient/cds_test.go b/xds/internal/xdsclient/cds_test.go new file mode 100644 index 00000000000..21e3b05b908 --- /dev/null +++ b/xds/internal/xdsclient/cds_test.go @@ -0,0 +1,1590 @@ +/* + * + * Copyright 2020 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +package xdsclient + +import ( + "regexp" + "strings" + "testing" + + v2xdspb "github.com/envoyproxy/go-control-plane/envoy/api/v2" + v2corepb "github.com/envoyproxy/go-control-plane/envoy/api/v2/core" + v3clusterpb "github.com/envoyproxy/go-control-plane/envoy/config/cluster/v3" + v3corepb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" + v3endpointpb "github.com/envoyproxy/go-control-plane/envoy/config/endpoint/v3" + v3aggregateclusterpb "github.com/envoyproxy/go-control-plane/envoy/extensions/clusters/aggregate/v3" + v3tlspb "github.com/envoyproxy/go-control-plane/envoy/extensions/transport_sockets/tls/v3" + v3matcherpb "github.com/envoyproxy/go-control-plane/envoy/type/matcher/v3" + anypb "github.com/golang/protobuf/ptypes/any" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "google.golang.org/grpc/internal/testutils" + "google.golang.org/grpc/internal/xds/env" + "google.golang.org/grpc/internal/xds/matcher" + "google.golang.org/grpc/xds/internal/version" + "google.golang.org/protobuf/types/known/wrapperspb" +) + +const ( + clusterName = "clusterName" + serviceName = "service" +) + +var emptyUpdate = ClusterUpdate{ClusterName: clusterName, EnableLRS: false} + +func (s) TestValidateCluster_Failure(t *testing.T) { + tests := []struct { + name string + cluster *v3clusterpb.Cluster + wantUpdate ClusterUpdate + wantErr bool + }{ + { + name: "non-supported-cluster-type-static", + cluster: &v3clusterpb.Cluster{ + ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_STATIC}, + EdsClusterConfig: &v3clusterpb.Cluster_EdsClusterConfig{ + EdsConfig: &v3corepb.ConfigSource{ + ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{ + Ads: &v3corepb.AggregatedConfigSource{}, + }, + }, + }, + LbPolicy: v3clusterpb.Cluster_LEAST_REQUEST, + }, + wantUpdate: emptyUpdate, + wantErr: true, + }, + { + name: "non-supported-cluster-type-original-dst", + cluster: &v3clusterpb.Cluster{ + ClusterDiscoveryType: 
&v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_ORIGINAL_DST}, + EdsClusterConfig: &v3clusterpb.Cluster_EdsClusterConfig{ + EdsConfig: &v3corepb.ConfigSource{ + ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{ + Ads: &v3corepb.AggregatedConfigSource{}, + }, + }, + }, + LbPolicy: v3clusterpb.Cluster_LEAST_REQUEST, + }, + wantUpdate: emptyUpdate, + wantErr: true, + }, + { + name: "no-eds-config", + cluster: &v3clusterpb.Cluster{ + ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_EDS}, + LbPolicy: v3clusterpb.Cluster_ROUND_ROBIN, + }, + wantUpdate: emptyUpdate, + wantErr: true, + }, + { + name: "no-ads-config-source", + cluster: &v3clusterpb.Cluster{ + ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_EDS}, + EdsClusterConfig: &v3clusterpb.Cluster_EdsClusterConfig{}, + LbPolicy: v3clusterpb.Cluster_ROUND_ROBIN, + }, + wantUpdate: emptyUpdate, + wantErr: true, + }, + { + name: "non-round-robin-or-ring-hash-lb-policy", + cluster: &v3clusterpb.Cluster{ + ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_EDS}, + EdsClusterConfig: &v3clusterpb.Cluster_EdsClusterConfig{ + EdsConfig: &v3corepb.ConfigSource{ + ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{ + Ads: &v3corepb.AggregatedConfigSource{}, + }, + }, + }, + LbPolicy: v3clusterpb.Cluster_LEAST_REQUEST, + }, + wantUpdate: emptyUpdate, + wantErr: true, + }, + { + name: "logical-dns-multiple-localities", + cluster: &v3clusterpb.Cluster{ + Name: clusterName, + ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_LOGICAL_DNS}, + LbPolicy: v3clusterpb.Cluster_ROUND_ROBIN, + LoadAssignment: &v3endpointpb.ClusterLoadAssignment{ + Endpoints: []*v3endpointpb.LocalityLbEndpoints{ + // Invalid if there are more than one locality. 
+ {LbEndpoints: nil}, + {LbEndpoints: nil}, + }, + }, + }, + wantUpdate: emptyUpdate, + wantErr: true, + }, + { + name: "ring-hash-hash-function-not-xx-hash", + cluster: &v3clusterpb.Cluster{ + LbPolicy: v3clusterpb.Cluster_RING_HASH, + LbConfig: &v3clusterpb.Cluster_RingHashLbConfig_{ + RingHashLbConfig: &v3clusterpb.Cluster_RingHashLbConfig{ + HashFunction: v3clusterpb.Cluster_RingHashLbConfig_MURMUR_HASH_2, + }, + }, + }, + wantUpdate: emptyUpdate, + wantErr: true, + }, + { + name: "ring-hash-min-bound-greater-than-max", + cluster: &v3clusterpb.Cluster{ + LbPolicy: v3clusterpb.Cluster_RING_HASH, + LbConfig: &v3clusterpb.Cluster_RingHashLbConfig_{ + RingHashLbConfig: &v3clusterpb.Cluster_RingHashLbConfig{ + MinimumRingSize: wrapperspb.UInt64(100), + MaximumRingSize: wrapperspb.UInt64(10), + }, + }, + }, + wantUpdate: emptyUpdate, + wantErr: true, + }, + { + name: "ring-hash-min-bound-greater-than-upper-bound", + cluster: &v3clusterpb.Cluster{ + LbPolicy: v3clusterpb.Cluster_RING_HASH, + LbConfig: &v3clusterpb.Cluster_RingHashLbConfig_{ + RingHashLbConfig: &v3clusterpb.Cluster_RingHashLbConfig{ + MinimumRingSize: wrapperspb.UInt64(ringHashSizeUpperBound + 1), + }, + }, + }, + wantUpdate: emptyUpdate, + wantErr: true, + }, + { + name: "ring-hash-max-bound-greater-than-upper-bound", + cluster: &v3clusterpb.Cluster{ + LbPolicy: v3clusterpb.Cluster_RING_HASH, + LbConfig: &v3clusterpb.Cluster_RingHashLbConfig_{ + RingHashLbConfig: &v3clusterpb.Cluster_RingHashLbConfig{ + MaximumRingSize: wrapperspb.UInt64(ringHashSizeUpperBound + 1), + }, + }, + }, + wantUpdate: emptyUpdate, + wantErr: true, + }, + } + + oldAggregateAndDNSSupportEnv := env.AggregateAndDNSSupportEnv + env.AggregateAndDNSSupportEnv = true + defer func() { env.AggregateAndDNSSupportEnv = oldAggregateAndDNSSupportEnv }() + oldRingHashSupport := env.RingHashSupport + env.RingHashSupport = true + defer func() { env.RingHashSupport = oldRingHashSupport }() + for _, test := range tests { + t.Run(test.name, 
func(t *testing.T) { + if update, err := validateClusterAndConstructClusterUpdate(test.cluster); err == nil { + t.Errorf("validateClusterAndConstructClusterUpdate(%+v) = %v, wanted error", test.cluster, update) + } + }) + } +} + +func (s) TestValidateCluster_Success(t *testing.T) { + tests := []struct { + name string + cluster *v3clusterpb.Cluster + wantUpdate ClusterUpdate + }{ + { + name: "happy-case-logical-dns", + cluster: &v3clusterpb.Cluster{ + Name: clusterName, + ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_LOGICAL_DNS}, + LbPolicy: v3clusterpb.Cluster_ROUND_ROBIN, + LoadAssignment: &v3endpointpb.ClusterLoadAssignment{ + Endpoints: []*v3endpointpb.LocalityLbEndpoints{{ + LbEndpoints: []*v3endpointpb.LbEndpoint{{ + HostIdentifier: &v3endpointpb.LbEndpoint_Endpoint{ + Endpoint: &v3endpointpb.Endpoint{ + Address: &v3corepb.Address{ + Address: &v3corepb.Address_SocketAddress{ + SocketAddress: &v3corepb.SocketAddress{ + Address: "dns_host", + PortSpecifier: &v3corepb.SocketAddress_PortValue{ + PortValue: 8080, + }, + }, + }, + }, + }, + }, + }}, + }}, + }, + }, + wantUpdate: ClusterUpdate{ + ClusterName: clusterName, + ClusterType: ClusterTypeLogicalDNS, + DNSHostName: "dns_host:8080", + }, + }, + { + name: "happy-case-aggregate-v3", + cluster: &v3clusterpb.Cluster{ + Name: clusterName, + ClusterDiscoveryType: &v3clusterpb.Cluster_ClusterType{ + ClusterType: &v3clusterpb.Cluster_CustomClusterType{ + Name: "envoy.clusters.aggregate", + TypedConfig: testutils.MarshalAny(&v3aggregateclusterpb.ClusterConfig{ + Clusters: []string{"a", "b", "c"}, + }), + }, + }, + LbPolicy: v3clusterpb.Cluster_ROUND_ROBIN, + }, + wantUpdate: ClusterUpdate{ + ClusterName: clusterName, EnableLRS: false, ClusterType: ClusterTypeAggregate, + PrioritizedClusterNames: []string{"a", "b", "c"}, + }, + }, + { + name: "happy-case-no-service-name-no-lrs", + cluster: &v3clusterpb.Cluster{ + Name: clusterName, + ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: 
v3clusterpb.Cluster_EDS}, + EdsClusterConfig: &v3clusterpb.Cluster_EdsClusterConfig{ + EdsConfig: &v3corepb.ConfigSource{ + ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{ + Ads: &v3corepb.AggregatedConfigSource{}, + }, + }, + }, + LbPolicy: v3clusterpb.Cluster_ROUND_ROBIN, + }, + wantUpdate: emptyUpdate, + }, + { + name: "happy-case-no-lrs", + cluster: &v3clusterpb.Cluster{ + Name: clusterName, + ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_EDS}, + EdsClusterConfig: &v3clusterpb.Cluster_EdsClusterConfig{ + EdsConfig: &v3corepb.ConfigSource{ + ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{ + Ads: &v3corepb.AggregatedConfigSource{}, + }, + }, + ServiceName: serviceName, + }, + LbPolicy: v3clusterpb.Cluster_ROUND_ROBIN, + }, + wantUpdate: ClusterUpdate{ClusterName: clusterName, EDSServiceName: serviceName, EnableLRS: false}, + }, + { + name: "happiest-case", + cluster: &v3clusterpb.Cluster{ + Name: clusterName, + ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_EDS}, + EdsClusterConfig: &v3clusterpb.Cluster_EdsClusterConfig{ + EdsConfig: &v3corepb.ConfigSource{ + ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{ + Ads: &v3corepb.AggregatedConfigSource{}, + }, + }, + ServiceName: serviceName, + }, + LbPolicy: v3clusterpb.Cluster_ROUND_ROBIN, + LrsServer: &v3corepb.ConfigSource{ + ConfigSourceSpecifier: &v3corepb.ConfigSource_Self{ + Self: &v3corepb.SelfConfigSource{}, + }, + }, + }, + wantUpdate: ClusterUpdate{ClusterName: clusterName, EDSServiceName: serviceName, EnableLRS: true}, + }, + { + name: "happiest-case-with-circuitbreakers", + cluster: &v3clusterpb.Cluster{ + Name: clusterName, + ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_EDS}, + EdsClusterConfig: &v3clusterpb.Cluster_EdsClusterConfig{ + EdsConfig: &v3corepb.ConfigSource{ + ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{ + Ads: &v3corepb.AggregatedConfigSource{}, + }, + }, + ServiceName: serviceName, + }, + 
LbPolicy: v3clusterpb.Cluster_ROUND_ROBIN, + CircuitBreakers: &v3clusterpb.CircuitBreakers{ + Thresholds: []*v3clusterpb.CircuitBreakers_Thresholds{ + { + Priority: v3corepb.RoutingPriority_DEFAULT, + MaxRequests: wrapperspb.UInt32(512), + }, + { + Priority: v3corepb.RoutingPriority_HIGH, + MaxRequests: nil, + }, + }, + }, + LrsServer: &v3corepb.ConfigSource{ + ConfigSourceSpecifier: &v3corepb.ConfigSource_Self{ + Self: &v3corepb.SelfConfigSource{}, + }, + }, + }, + wantUpdate: ClusterUpdate{ClusterName: clusterName, EDSServiceName: serviceName, EnableLRS: true, MaxRequests: func() *uint32 { i := uint32(512); return &i }()}, + }, + { + name: "happiest-case-with-ring-hash-lb-policy-with-default-config", + cluster: &v3clusterpb.Cluster{ + Name: clusterName, + ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_EDS}, + EdsClusterConfig: &v3clusterpb.Cluster_EdsClusterConfig{ + EdsConfig: &v3corepb.ConfigSource{ + ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{ + Ads: &v3corepb.AggregatedConfigSource{}, + }, + }, + ServiceName: serviceName, + }, + LbPolicy: v3clusterpb.Cluster_RING_HASH, + LrsServer: &v3corepb.ConfigSource{ + ConfigSourceSpecifier: &v3corepb.ConfigSource_Self{ + Self: &v3corepb.SelfConfigSource{}, + }, + }, + }, + wantUpdate: ClusterUpdate{ + ClusterName: clusterName, EDSServiceName: serviceName, EnableLRS: true, + LBPolicy: &ClusterLBPolicyRingHash{MinimumRingSize: defaultRingHashMinSize, MaximumRingSize: defaultRingHashMaxSize}, + }, + }, + { + name: "happiest-case-with-ring-hash-lb-policy-with-none-default-config", + cluster: &v3clusterpb.Cluster{ + Name: clusterName, + ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_EDS}, + EdsClusterConfig: &v3clusterpb.Cluster_EdsClusterConfig{ + EdsConfig: &v3corepb.ConfigSource{ + ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{ + Ads: &v3corepb.AggregatedConfigSource{}, + }, + }, + ServiceName: serviceName, + }, + LbPolicy: v3clusterpb.Cluster_RING_HASH, + 
LbConfig: &v3clusterpb.Cluster_RingHashLbConfig_{ + RingHashLbConfig: &v3clusterpb.Cluster_RingHashLbConfig{ + MinimumRingSize: wrapperspb.UInt64(10), + MaximumRingSize: wrapperspb.UInt64(100), + }, + }, + LrsServer: &v3corepb.ConfigSource{ + ConfigSourceSpecifier: &v3corepb.ConfigSource_Self{ + Self: &v3corepb.SelfConfigSource{}, + }, + }, + }, + wantUpdate: ClusterUpdate{ + ClusterName: clusterName, EDSServiceName: serviceName, EnableLRS: true, + LBPolicy: &ClusterLBPolicyRingHash{MinimumRingSize: 10, MaximumRingSize: 100}, + }, + }, + } + + oldAggregateAndDNSSupportEnv := env.AggregateAndDNSSupportEnv + env.AggregateAndDNSSupportEnv = true + defer func() { env.AggregateAndDNSSupportEnv = oldAggregateAndDNSSupportEnv }() + oldRingHashSupport := env.RingHashSupport + env.RingHashSupport = true + defer func() { env.RingHashSupport = oldRingHashSupport }() + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + update, err := validateClusterAndConstructClusterUpdate(test.cluster) + if err != nil { + t.Errorf("validateClusterAndConstructClusterUpdate(%+v) failed: %v", test.cluster, err) + } + if diff := cmp.Diff(update, test.wantUpdate, cmpopts.EquateEmpty()); diff != "" { + t.Errorf("validateClusterAndConstructClusterUpdate(%+v) got diff: %v (-got, +want)", test.cluster, diff) + } + }) + } +} + +func (s) TestValidateClusterWithSecurityConfig_EnvVarOff(t *testing.T) { + // Turn off the env var protection for client-side security. 
+ origClientSideSecurityEnvVar := env.ClientSideSecuritySupport + env.ClientSideSecuritySupport = false + defer func() { env.ClientSideSecuritySupport = origClientSideSecurityEnvVar }() + + cluster := &v3clusterpb.Cluster{ + Name: clusterName, + ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_EDS}, + EdsClusterConfig: &v3clusterpb.Cluster_EdsClusterConfig{ + EdsConfig: &v3corepb.ConfigSource{ + ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{ + Ads: &v3corepb.AggregatedConfigSource{}, + }, + }, + ServiceName: serviceName, + }, + LbPolicy: v3clusterpb.Cluster_ROUND_ROBIN, + TransportSocket: &v3corepb.TransportSocket{ + Name: "envoy.transport_sockets.tls", + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3tlspb.UpstreamTlsContext{ + CommonTlsContext: &v3tlspb.CommonTlsContext{ + ValidationContextType: &v3tlspb.CommonTlsContext_ValidationContextCertificateProviderInstance{ + ValidationContextCertificateProviderInstance: &v3tlspb.CommonTlsContext_CertificateProviderInstance{ + InstanceName: "rootInstance", + CertificateName: "rootCert", + }, + }, + }, + }), + }, + }, + } + wantUpdate := ClusterUpdate{ + ClusterName: clusterName, + EDSServiceName: serviceName, + EnableLRS: false, + } + gotUpdate, err := validateClusterAndConstructClusterUpdate(cluster) + if err != nil { + t.Errorf("validateClusterAndConstructClusterUpdate() failed: %v", err) + } + if diff := cmp.Diff(wantUpdate, gotUpdate); diff != "" { + t.Errorf("validateClusterAndConstructClusterUpdate() returned unexpected diff (-want, got):\n%s", diff) + } +} + +func (s) TestSecurityConfigFromCommonTLSContextUsingNewFields_ErrorCases(t *testing.T) { + tests := []struct { + name string + common *v3tlspb.CommonTlsContext + server bool + wantErr string + }{ + { + name: "unsupported-tls_certificates-field-for-identity-certs", + common: &v3tlspb.CommonTlsContext{ + TlsCertificates: []*v3tlspb.TlsCertificate{ + {CertificateChain: 
&v3corepb.DataSource{}}, + }, + }, + wantErr: "unsupported field tls_certificates is set in CommonTlsContext message", + }, + { + name: "unsupported-tls_certificates_sds_secret_configs-field-for-identity-certs", + common: &v3tlspb.CommonTlsContext{ + TlsCertificateSdsSecretConfigs: []*v3tlspb.SdsSecretConfig{ + {Name: "sds-secrets-config"}, + }, + }, + wantErr: "unsupported field tls_certificate_sds_secret_configs is set in CommonTlsContext message", + }, + { + name: "unsupported-sds-validation-context", + common: &v3tlspb.CommonTlsContext{ + ValidationContextType: &v3tlspb.CommonTlsContext_ValidationContextSdsSecretConfig{ + ValidationContextSdsSecretConfig: &v3tlspb.SdsSecretConfig{ + Name: "foo-sds-secret", + }, + }, + }, + wantErr: "validation context contains unexpected type", + }, + { + name: "missing-ca_certificate_provider_instance-in-validation-context", + common: &v3tlspb.CommonTlsContext{ + ValidationContextType: &v3tlspb.CommonTlsContext_ValidationContext{ + ValidationContext: &v3tlspb.CertificateValidationContext{}, + }, + }, + wantErr: "expected field ca_certificate_provider_instance is missing in CommonTlsContext message", + }, + { + name: "unsupported-field-verify_certificate_spki-in-validation-context", + common: &v3tlspb.CommonTlsContext{ + ValidationContextType: &v3tlspb.CommonTlsContext_ValidationContext{ + ValidationContext: &v3tlspb.CertificateValidationContext{ + CaCertificateProviderInstance: &v3tlspb.CertificateProviderPluginInstance{ + InstanceName: "rootPluginInstance", + CertificateName: "rootCertName", + }, + VerifyCertificateSpki: []string{"spki"}, + }, + }, + }, + wantErr: "unsupported verify_certificate_spki field in CommonTlsContext message", + }, + { + name: "unsupported-field-verify_certificate_hash-in-validation-context", + common: &v3tlspb.CommonTlsContext{ + ValidationContextType: &v3tlspb.CommonTlsContext_ValidationContext{ + ValidationContext: &v3tlspb.CertificateValidationContext{ + CaCertificateProviderInstance: 
&v3tlspb.CertificateProviderPluginInstance{ + InstanceName: "rootPluginInstance", + CertificateName: "rootCertName", + }, + VerifyCertificateHash: []string{"hash"}, + }, + }, + }, + wantErr: "unsupported verify_certificate_hash field in CommonTlsContext message", + }, + { + name: "unsupported-field-require_signed_certificate_timestamp-in-validation-context", + common: &v3tlspb.CommonTlsContext{ + ValidationContextType: &v3tlspb.CommonTlsContext_ValidationContext{ + ValidationContext: &v3tlspb.CertificateValidationContext{ + CaCertificateProviderInstance: &v3tlspb.CertificateProviderPluginInstance{ + InstanceName: "rootPluginInstance", + CertificateName: "rootCertName", + }, + RequireSignedCertificateTimestamp: &wrapperspb.BoolValue{Value: true}, + }, + }, + }, + wantErr: "unsupported require_sugned_ceritificate_timestamp field in CommonTlsContext message", + }, + { + name: "unsupported-field-crl-in-validation-context", + common: &v3tlspb.CommonTlsContext{ + ValidationContextType: &v3tlspb.CommonTlsContext_ValidationContext{ + ValidationContext: &v3tlspb.CertificateValidationContext{ + CaCertificateProviderInstance: &v3tlspb.CertificateProviderPluginInstance{ + InstanceName: "rootPluginInstance", + CertificateName: "rootCertName", + }, + Crl: &v3corepb.DataSource{}, + }, + }, + }, + wantErr: "unsupported crl field in CommonTlsContext message", + }, + { + name: "unsupported-field-custom_validator_config-in-validation-context", + common: &v3tlspb.CommonTlsContext{ + ValidationContextType: &v3tlspb.CommonTlsContext_ValidationContext{ + ValidationContext: &v3tlspb.CertificateValidationContext{ + CaCertificateProviderInstance: &v3tlspb.CertificateProviderPluginInstance{ + InstanceName: "rootPluginInstance", + CertificateName: "rootCertName", + }, + CustomValidatorConfig: &v3corepb.TypedExtensionConfig{}, + }, + }, + }, + wantErr: "unsupported custom_validator_config field in CommonTlsContext message", + }, + { + name: 
"invalid-match_subject_alt_names-field-in-validation-context", + common: &v3tlspb.CommonTlsContext{ + ValidationContextType: &v3tlspb.CommonTlsContext_ValidationContext{ + ValidationContext: &v3tlspb.CertificateValidationContext{ + CaCertificateProviderInstance: &v3tlspb.CertificateProviderPluginInstance{ + InstanceName: "rootPluginInstance", + CertificateName: "rootCertName", + }, + MatchSubjectAltNames: []*v3matcherpb.StringMatcher{ + {MatchPattern: &v3matcherpb.StringMatcher_Prefix{Prefix: ""}}, + }, + }, + }, + }, + wantErr: "empty prefix is not allowed in StringMatcher", + }, + { + name: "unsupported-field-matching-subject-alt-names-in-validation-context-of-server", + common: &v3tlspb.CommonTlsContext{ + ValidationContextType: &v3tlspb.CommonTlsContext_ValidationContext{ + ValidationContext: &v3tlspb.CertificateValidationContext{ + CaCertificateProviderInstance: &v3tlspb.CertificateProviderPluginInstance{ + InstanceName: "rootPluginInstance", + CertificateName: "rootCertName", + }, + MatchSubjectAltNames: []*v3matcherpb.StringMatcher{ + {MatchPattern: &v3matcherpb.StringMatcher_Prefix{Prefix: "sanPrefix"}}, + }, + }, + }, + }, + server: true, + wantErr: "match_subject_alt_names field in validation context is not supported on the server", + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + _, err := securityConfigFromCommonTLSContextUsingNewFields(test.common, test.server) + if err == nil { + t.Fatal("securityConfigFromCommonTLSContextUsingNewFields() succeeded when expected to fail") + } + if !strings.Contains(err.Error(), test.wantErr) { + t.Fatalf("securityConfigFromCommonTLSContextUsingNewFields() returned err: %v, wantErr: %v", err, test.wantErr) + } + }) + } +} + +func (s) TestValidateClusterWithSecurityConfig(t *testing.T) { + const ( + identityPluginInstance = "identityPluginInstance" + identityCertName = "identityCert" + rootPluginInstance = "rootPluginInstance" + rootCertName = "rootCert" + clusterName = "cluster" + 
serviceName = "service" + sanExact = "san-exact" + sanPrefix = "san-prefix" + sanSuffix = "san-suffix" + sanRegexBad = "??" + sanRegexGood = "san?regex?" + sanContains = "san-contains" + ) + var sanRE = regexp.MustCompile(sanRegexGood) + + tests := []struct { + name string + cluster *v3clusterpb.Cluster + wantUpdate ClusterUpdate + wantErr bool + }{ + { + name: "transport-socket-matches", + cluster: &v3clusterpb.Cluster{ + ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_EDS}, + EdsClusterConfig: &v3clusterpb.Cluster_EdsClusterConfig{ + EdsConfig: &v3corepb.ConfigSource{ + ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{ + Ads: &v3corepb.AggregatedConfigSource{}, + }, + }, + ServiceName: serviceName, + }, + LbPolicy: v3clusterpb.Cluster_ROUND_ROBIN, + TransportSocketMatches: []*v3clusterpb.Cluster_TransportSocketMatch{ + {Name: "transport-socket-match-1"}, + }, + }, + wantErr: true, + }, + { + name: "transport-socket-unsupported-name", + cluster: &v3clusterpb.Cluster{ + ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_EDS}, + EdsClusterConfig: &v3clusterpb.Cluster_EdsClusterConfig{ + EdsConfig: &v3corepb.ConfigSource{ + ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{ + Ads: &v3corepb.AggregatedConfigSource{}, + }, + }, + ServiceName: serviceName, + }, + LbPolicy: v3clusterpb.Cluster_ROUND_ROBIN, + TransportSocket: &v3corepb.TransportSocket{ + Name: "unsupported-foo", + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: &anypb.Any{ + TypeUrl: version.V3UpstreamTLSContextURL, + }, + }, + }, + }, + wantErr: true, + }, + { + name: "transport-socket-unsupported-typeURL", + cluster: &v3clusterpb.Cluster{ + ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_EDS}, + EdsClusterConfig: &v3clusterpb.Cluster_EdsClusterConfig{ + EdsConfig: &v3corepb.ConfigSource{ + ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{ + Ads: &v3corepb.AggregatedConfigSource{}, + }, + }, + ServiceName: 
serviceName, + }, + LbPolicy: v3clusterpb.Cluster_ROUND_ROBIN, + TransportSocket: &v3corepb.TransportSocket{ + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: &anypb.Any{ + TypeUrl: version.V3HTTPConnManagerURL, + }, + }, + }, + }, + wantErr: true, + }, + { + name: "transport-socket-unsupported-type", + cluster: &v3clusterpb.Cluster{ + ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_EDS}, + EdsClusterConfig: &v3clusterpb.Cluster_EdsClusterConfig{ + EdsConfig: &v3corepb.ConfigSource{ + ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{ + Ads: &v3corepb.AggregatedConfigSource{}, + }, + }, + ServiceName: serviceName, + }, + LbPolicy: v3clusterpb.Cluster_ROUND_ROBIN, + TransportSocket: &v3corepb.TransportSocket{ + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: &anypb.Any{ + TypeUrl: version.V3UpstreamTLSContextURL, + Value: []byte{1, 2, 3, 4}, + }, + }, + }, + }, + wantErr: true, + }, + { + name: "transport-socket-unsupported-tls-params-field", + cluster: &v3clusterpb.Cluster{ + ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_EDS}, + EdsClusterConfig: &v3clusterpb.Cluster_EdsClusterConfig{ + EdsConfig: &v3corepb.ConfigSource{ + ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{ + Ads: &v3corepb.AggregatedConfigSource{}, + }, + }, + ServiceName: serviceName, + }, + LbPolicy: v3clusterpb.Cluster_ROUND_ROBIN, + TransportSocket: &v3corepb.TransportSocket{ + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3tlspb.UpstreamTlsContext{ + CommonTlsContext: &v3tlspb.CommonTlsContext{ + TlsParams: &v3tlspb.TlsParameters{}, + }, + }), + }, + }, + }, + wantErr: true, + }, + { + name: "transport-socket-unsupported-custom-handshaker-field", + cluster: &v3clusterpb.Cluster{ + ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_EDS}, + EdsClusterConfig: &v3clusterpb.Cluster_EdsClusterConfig{ + EdsConfig: &v3corepb.ConfigSource{ + 
ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{ + Ads: &v3corepb.AggregatedConfigSource{}, + }, + }, + ServiceName: serviceName, + }, + LbPolicy: v3clusterpb.Cluster_ROUND_ROBIN, + TransportSocket: &v3corepb.TransportSocket{ + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3tlspb.UpstreamTlsContext{ + CommonTlsContext: &v3tlspb.CommonTlsContext{ + CustomHandshaker: &v3corepb.TypedExtensionConfig{}, + }, + }), + }, + }, + }, + wantErr: true, + }, + { + name: "transport-socket-unsupported-validation-context", + cluster: &v3clusterpb.Cluster{ + ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_EDS}, + EdsClusterConfig: &v3clusterpb.Cluster_EdsClusterConfig{ + EdsConfig: &v3corepb.ConfigSource{ + ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{ + Ads: &v3corepb.AggregatedConfigSource{}, + }, + }, + ServiceName: serviceName, + }, + LbPolicy: v3clusterpb.Cluster_ROUND_ROBIN, + TransportSocket: &v3corepb.TransportSocket{ + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3tlspb.UpstreamTlsContext{ + CommonTlsContext: &v3tlspb.CommonTlsContext{ + ValidationContextType: &v3tlspb.CommonTlsContext_ValidationContextSdsSecretConfig{ + ValidationContextSdsSecretConfig: &v3tlspb.SdsSecretConfig{ + Name: "foo-sds-secret", + }, + }, + }, + }), + }, + }, + }, + wantErr: true, + }, + { + name: "transport-socket-without-validation-context", + cluster: &v3clusterpb.Cluster{ + ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_EDS}, + EdsClusterConfig: &v3clusterpb.Cluster_EdsClusterConfig{ + EdsConfig: &v3corepb.ConfigSource{ + ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{ + Ads: &v3corepb.AggregatedConfigSource{}, + }, + }, + ServiceName: serviceName, + }, + LbPolicy: v3clusterpb.Cluster_ROUND_ROBIN, + TransportSocket: &v3corepb.TransportSocket{ + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: 
testutils.MarshalAny(&v3tlspb.UpstreamTlsContext{ + CommonTlsContext: &v3tlspb.CommonTlsContext{}, + }), + }, + }, + }, + wantErr: true, + }, + { + name: "empty-prefix-in-matching-SAN", + cluster: &v3clusterpb.Cluster{ + ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_EDS}, + EdsClusterConfig: &v3clusterpb.Cluster_EdsClusterConfig{ + EdsConfig: &v3corepb.ConfigSource{ + ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{ + Ads: &v3corepb.AggregatedConfigSource{}, + }, + }, + ServiceName: serviceName, + }, + LbPolicy: v3clusterpb.Cluster_ROUND_ROBIN, + TransportSocket: &v3corepb.TransportSocket{ + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3tlspb.UpstreamTlsContext{ + CommonTlsContext: &v3tlspb.CommonTlsContext{ + ValidationContextType: &v3tlspb.CommonTlsContext_CombinedValidationContext{ + CombinedValidationContext: &v3tlspb.CommonTlsContext_CombinedCertificateValidationContext{ + DefaultValidationContext: &v3tlspb.CertificateValidationContext{ + MatchSubjectAltNames: []*v3matcherpb.StringMatcher{ + {MatchPattern: &v3matcherpb.StringMatcher_Prefix{Prefix: ""}}, + }, + }, + ValidationContextCertificateProviderInstance: &v3tlspb.CommonTlsContext_CertificateProviderInstance{ + InstanceName: rootPluginInstance, + CertificateName: rootCertName, + }, + }, + }, + }, + }), + }, + }, + }, + wantErr: true, + }, + { + name: "empty-suffix-in-matching-SAN", + cluster: &v3clusterpb.Cluster{ + ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_EDS}, + EdsClusterConfig: &v3clusterpb.Cluster_EdsClusterConfig{ + EdsConfig: &v3corepb.ConfigSource{ + ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{ + Ads: &v3corepb.AggregatedConfigSource{}, + }, + }, + ServiceName: serviceName, + }, + LbPolicy: v3clusterpb.Cluster_ROUND_ROBIN, + TransportSocket: &v3corepb.TransportSocket{ + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3tlspb.UpstreamTlsContext{ + 
CommonTlsContext: &v3tlspb.CommonTlsContext{ + ValidationContextType: &v3tlspb.CommonTlsContext_CombinedValidationContext{ + CombinedValidationContext: &v3tlspb.CommonTlsContext_CombinedCertificateValidationContext{ + DefaultValidationContext: &v3tlspb.CertificateValidationContext{ + MatchSubjectAltNames: []*v3matcherpb.StringMatcher{ + {MatchPattern: &v3matcherpb.StringMatcher_Suffix{Suffix: ""}}, + }, + }, + ValidationContextCertificateProviderInstance: &v3tlspb.CommonTlsContext_CertificateProviderInstance{ + InstanceName: rootPluginInstance, + CertificateName: rootCertName, + }, + }, + }, + }, + }), + }, + }, + }, + wantErr: true, + }, + { + name: "empty-contains-in-matching-SAN", + cluster: &v3clusterpb.Cluster{ + ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_EDS}, + EdsClusterConfig: &v3clusterpb.Cluster_EdsClusterConfig{ + EdsConfig: &v3corepb.ConfigSource{ + ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{ + Ads: &v3corepb.AggregatedConfigSource{}, + }, + }, + ServiceName: serviceName, + }, + LbPolicy: v3clusterpb.Cluster_ROUND_ROBIN, + TransportSocket: &v3corepb.TransportSocket{ + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3tlspb.UpstreamTlsContext{ + CommonTlsContext: &v3tlspb.CommonTlsContext{ + ValidationContextType: &v3tlspb.CommonTlsContext_CombinedValidationContext{ + CombinedValidationContext: &v3tlspb.CommonTlsContext_CombinedCertificateValidationContext{ + DefaultValidationContext: &v3tlspb.CertificateValidationContext{ + MatchSubjectAltNames: []*v3matcherpb.StringMatcher{ + {MatchPattern: &v3matcherpb.StringMatcher_Contains{Contains: ""}}, + }, + }, + ValidationContextCertificateProviderInstance: &v3tlspb.CommonTlsContext_CertificateProviderInstance{ + InstanceName: rootPluginInstance, + CertificateName: rootCertName, + }, + }, + }, + }, + }), + }, + }, + }, + wantErr: true, + }, + { + name: "invalid-regex-in-matching-SAN", + cluster: &v3clusterpb.Cluster{ + 
ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_EDS}, + EdsClusterConfig: &v3clusterpb.Cluster_EdsClusterConfig{ + EdsConfig: &v3corepb.ConfigSource{ + ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{ + Ads: &v3corepb.AggregatedConfigSource{}, + }, + }, + ServiceName: serviceName, + }, + LbPolicy: v3clusterpb.Cluster_ROUND_ROBIN, + TransportSocket: &v3corepb.TransportSocket{ + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3tlspb.UpstreamTlsContext{ + CommonTlsContext: &v3tlspb.CommonTlsContext{ + ValidationContextType: &v3tlspb.CommonTlsContext_CombinedValidationContext{ + CombinedValidationContext: &v3tlspb.CommonTlsContext_CombinedCertificateValidationContext{ + DefaultValidationContext: &v3tlspb.CertificateValidationContext{ + MatchSubjectAltNames: []*v3matcherpb.StringMatcher{ + {MatchPattern: &v3matcherpb.StringMatcher_SafeRegex{SafeRegex: &v3matcherpb.RegexMatcher{Regex: sanRegexBad}}}, + }, + }, + ValidationContextCertificateProviderInstance: &v3tlspb.CommonTlsContext_CertificateProviderInstance{ + InstanceName: rootPluginInstance, + CertificateName: rootCertName, + }, + }, + }, + }, + }), + }, + }, + }, + wantErr: true, + }, + { + name: "invalid-regex-in-matching-SAN-with-new-fields", + cluster: &v3clusterpb.Cluster{ + ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_EDS}, + EdsClusterConfig: &v3clusterpb.Cluster_EdsClusterConfig{ + EdsConfig: &v3corepb.ConfigSource{ + ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{ + Ads: &v3corepb.AggregatedConfigSource{}, + }, + }, + ServiceName: serviceName, + }, + LbPolicy: v3clusterpb.Cluster_ROUND_ROBIN, + TransportSocket: &v3corepb.TransportSocket{ + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3tlspb.UpstreamTlsContext{ + CommonTlsContext: &v3tlspb.CommonTlsContext{ + ValidationContextType: &v3tlspb.CommonTlsContext_CombinedValidationContext{ + CombinedValidationContext: 
&v3tlspb.CommonTlsContext_CombinedCertificateValidationContext{ + DefaultValidationContext: &v3tlspb.CertificateValidationContext{ + MatchSubjectAltNames: []*v3matcherpb.StringMatcher{ + {MatchPattern: &v3matcherpb.StringMatcher_SafeRegex{SafeRegex: &v3matcherpb.RegexMatcher{Regex: sanRegexBad}}}, + }, + CaCertificateProviderInstance: &v3tlspb.CertificateProviderPluginInstance{ + InstanceName: rootPluginInstance, + CertificateName: rootCertName, + }, + }, + }, + }, + }, + }), + }, + }, + }, + wantErr: true, + }, + { + name: "happy-case-with-no-identity-certs-using-deprecated-fields", + cluster: &v3clusterpb.Cluster{ + Name: clusterName, + ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_EDS}, + EdsClusterConfig: &v3clusterpb.Cluster_EdsClusterConfig{ + EdsConfig: &v3corepb.ConfigSource{ + ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{ + Ads: &v3corepb.AggregatedConfigSource{}, + }, + }, + ServiceName: serviceName, + }, + LbPolicy: v3clusterpb.Cluster_ROUND_ROBIN, + TransportSocket: &v3corepb.TransportSocket{ + Name: "envoy.transport_sockets.tls", + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3tlspb.UpstreamTlsContext{ + CommonTlsContext: &v3tlspb.CommonTlsContext{ + ValidationContextType: &v3tlspb.CommonTlsContext_ValidationContextCertificateProviderInstance{ + ValidationContextCertificateProviderInstance: &v3tlspb.CommonTlsContext_CertificateProviderInstance{ + InstanceName: rootPluginInstance, + CertificateName: rootCertName, + }, + }, + }, + }), + }, + }, + }, + wantUpdate: ClusterUpdate{ + ClusterName: clusterName, + EDSServiceName: serviceName, + EnableLRS: false, + SecurityCfg: &SecurityConfig{ + RootInstanceName: rootPluginInstance, + RootCertName: rootCertName, + }, + }, + }, + { + name: "happy-case-with-no-identity-certs-using-new-fields", + cluster: &v3clusterpb.Cluster{ + Name: clusterName, + ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_EDS}, + 
EdsClusterConfig: &v3clusterpb.Cluster_EdsClusterConfig{ + EdsConfig: &v3corepb.ConfigSource{ + ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{ + Ads: &v3corepb.AggregatedConfigSource{}, + }, + }, + ServiceName: serviceName, + }, + LbPolicy: v3clusterpb.Cluster_ROUND_ROBIN, + TransportSocket: &v3corepb.TransportSocket{ + Name: "envoy.transport_sockets.tls", + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3tlspb.UpstreamTlsContext{ + CommonTlsContext: &v3tlspb.CommonTlsContext{ + ValidationContextType: &v3tlspb.CommonTlsContext_ValidationContext{ + ValidationContext: &v3tlspb.CertificateValidationContext{ + CaCertificateProviderInstance: &v3tlspb.CertificateProviderPluginInstance{ + InstanceName: rootPluginInstance, + CertificateName: rootCertName, + }, + }, + }, + }, + }), + }, + }, + }, + wantUpdate: ClusterUpdate{ + ClusterName: clusterName, + EDSServiceName: serviceName, + EnableLRS: false, + SecurityCfg: &SecurityConfig{ + RootInstanceName: rootPluginInstance, + RootCertName: rootCertName, + }, + }, + }, + { + name: "happy-case-with-validation-context-provider-instance-using-deprecated-fields", + cluster: &v3clusterpb.Cluster{ + Name: clusterName, + ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_EDS}, + EdsClusterConfig: &v3clusterpb.Cluster_EdsClusterConfig{ + EdsConfig: &v3corepb.ConfigSource{ + ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{ + Ads: &v3corepb.AggregatedConfigSource{}, + }, + }, + ServiceName: serviceName, + }, + LbPolicy: v3clusterpb.Cluster_ROUND_ROBIN, + TransportSocket: &v3corepb.TransportSocket{ + Name: "envoy.transport_sockets.tls", + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3tlspb.UpstreamTlsContext{ + CommonTlsContext: &v3tlspb.CommonTlsContext{ + TlsCertificateCertificateProviderInstance: &v3tlspb.CommonTlsContext_CertificateProviderInstance{ + InstanceName: identityPluginInstance, + CertificateName: 
identityCertName, + }, + ValidationContextType: &v3tlspb.CommonTlsContext_ValidationContextCertificateProviderInstance{ + ValidationContextCertificateProviderInstance: &v3tlspb.CommonTlsContext_CertificateProviderInstance{ + InstanceName: rootPluginInstance, + CertificateName: rootCertName, + }, + }, + }, + }), + }, + }, + }, + wantUpdate: ClusterUpdate{ + ClusterName: clusterName, + EDSServiceName: serviceName, + EnableLRS: false, + SecurityCfg: &SecurityConfig{ + RootInstanceName: rootPluginInstance, + RootCertName: rootCertName, + IdentityInstanceName: identityPluginInstance, + IdentityCertName: identityCertName, + }, + }, + }, + { + name: "happy-case-with-validation-context-provider-instance-using-new-fields", + cluster: &v3clusterpb.Cluster{ + Name: clusterName, + ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_EDS}, + EdsClusterConfig: &v3clusterpb.Cluster_EdsClusterConfig{ + EdsConfig: &v3corepb.ConfigSource{ + ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{ + Ads: &v3corepb.AggregatedConfigSource{}, + }, + }, + ServiceName: serviceName, + }, + LbPolicy: v3clusterpb.Cluster_ROUND_ROBIN, + TransportSocket: &v3corepb.TransportSocket{ + Name: "envoy.transport_sockets.tls", + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3tlspb.UpstreamTlsContext{ + CommonTlsContext: &v3tlspb.CommonTlsContext{ + TlsCertificateProviderInstance: &v3tlspb.CertificateProviderPluginInstance{ + InstanceName: identityPluginInstance, + CertificateName: identityCertName, + }, + ValidationContextType: &v3tlspb.CommonTlsContext_ValidationContext{ + ValidationContext: &v3tlspb.CertificateValidationContext{ + CaCertificateProviderInstance: &v3tlspb.CertificateProviderPluginInstance{ + InstanceName: rootPluginInstance, + CertificateName: rootCertName, + }, + }, + }, + }, + }), + }, + }, + }, + wantUpdate: ClusterUpdate{ + ClusterName: clusterName, + EDSServiceName: serviceName, + EnableLRS: false, + SecurityCfg: 
&SecurityConfig{ + RootInstanceName: rootPluginInstance, + RootCertName: rootCertName, + IdentityInstanceName: identityPluginInstance, + IdentityCertName: identityCertName, + }, + }, + }, + { + name: "happy-case-with-combined-validation-context-using-deprecated-fields", + cluster: &v3clusterpb.Cluster{ + Name: clusterName, + ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_EDS}, + EdsClusterConfig: &v3clusterpb.Cluster_EdsClusterConfig{ + EdsConfig: &v3corepb.ConfigSource{ + ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{ + Ads: &v3corepb.AggregatedConfigSource{}, + }, + }, + ServiceName: serviceName, + }, + LbPolicy: v3clusterpb.Cluster_ROUND_ROBIN, + TransportSocket: &v3corepb.TransportSocket{ + Name: "envoy.transport_sockets.tls", + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3tlspb.UpstreamTlsContext{ + CommonTlsContext: &v3tlspb.CommonTlsContext{ + TlsCertificateCertificateProviderInstance: &v3tlspb.CommonTlsContext_CertificateProviderInstance{ + InstanceName: identityPluginInstance, + CertificateName: identityCertName, + }, + ValidationContextType: &v3tlspb.CommonTlsContext_CombinedValidationContext{ + CombinedValidationContext: &v3tlspb.CommonTlsContext_CombinedCertificateValidationContext{ + DefaultValidationContext: &v3tlspb.CertificateValidationContext{ + MatchSubjectAltNames: []*v3matcherpb.StringMatcher{ + { + MatchPattern: &v3matcherpb.StringMatcher_Exact{Exact: sanExact}, + IgnoreCase: true, + }, + {MatchPattern: &v3matcherpb.StringMatcher_Prefix{Prefix: sanPrefix}}, + {MatchPattern: &v3matcherpb.StringMatcher_Suffix{Suffix: sanSuffix}}, + {MatchPattern: &v3matcherpb.StringMatcher_SafeRegex{SafeRegex: &v3matcherpb.RegexMatcher{Regex: sanRegexGood}}}, + {MatchPattern: &v3matcherpb.StringMatcher_Contains{Contains: sanContains}}, + }, + }, + ValidationContextCertificateProviderInstance: &v3tlspb.CommonTlsContext_CertificateProviderInstance{ + InstanceName: rootPluginInstance, + 
CertificateName: rootCertName, + }, + }, + }, + }, + }), + }, + }, + }, + wantUpdate: ClusterUpdate{ + ClusterName: clusterName, + EDSServiceName: serviceName, + EnableLRS: false, + SecurityCfg: &SecurityConfig{ + RootInstanceName: rootPluginInstance, + RootCertName: rootCertName, + IdentityInstanceName: identityPluginInstance, + IdentityCertName: identityCertName, + SubjectAltNameMatchers: []matcher.StringMatcher{ + matcher.StringMatcherForTesting(newStringP(sanExact), nil, nil, nil, nil, true), + matcher.StringMatcherForTesting(nil, newStringP(sanPrefix), nil, nil, nil, false), + matcher.StringMatcherForTesting(nil, nil, newStringP(sanSuffix), nil, nil, false), + matcher.StringMatcherForTesting(nil, nil, nil, nil, sanRE, false), + matcher.StringMatcherForTesting(nil, nil, nil, newStringP(sanContains), nil, false), + }, + }, + }, + }, + { + name: "happy-case-with-combined-validation-context-using-new-fields", + cluster: &v3clusterpb.Cluster{ + Name: clusterName, + ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_EDS}, + EdsClusterConfig: &v3clusterpb.Cluster_EdsClusterConfig{ + EdsConfig: &v3corepb.ConfigSource{ + ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{ + Ads: &v3corepb.AggregatedConfigSource{}, + }, + }, + ServiceName: serviceName, + }, + LbPolicy: v3clusterpb.Cluster_ROUND_ROBIN, + TransportSocket: &v3corepb.TransportSocket{ + Name: "envoy.transport_sockets.tls", + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3tlspb.UpstreamTlsContext{ + CommonTlsContext: &v3tlspb.CommonTlsContext{ + TlsCertificateProviderInstance: &v3tlspb.CertificateProviderPluginInstance{ + InstanceName: identityPluginInstance, + CertificateName: identityCertName, + }, + ValidationContextType: &v3tlspb.CommonTlsContext_CombinedValidationContext{ + CombinedValidationContext: &v3tlspb.CommonTlsContext_CombinedCertificateValidationContext{ + DefaultValidationContext: &v3tlspb.CertificateValidationContext{ + 
MatchSubjectAltNames: []*v3matcherpb.StringMatcher{ + { + MatchPattern: &v3matcherpb.StringMatcher_Exact{Exact: sanExact}, + IgnoreCase: true, + }, + {MatchPattern: &v3matcherpb.StringMatcher_Prefix{Prefix: sanPrefix}}, + {MatchPattern: &v3matcherpb.StringMatcher_Suffix{Suffix: sanSuffix}}, + {MatchPattern: &v3matcherpb.StringMatcher_SafeRegex{SafeRegex: &v3matcherpb.RegexMatcher{Regex: sanRegexGood}}}, + {MatchPattern: &v3matcherpb.StringMatcher_Contains{Contains: sanContains}}, + }, + CaCertificateProviderInstance: &v3tlspb.CertificateProviderPluginInstance{ + InstanceName: rootPluginInstance, + CertificateName: rootCertName, + }, + }, + }, + }, + }, + }), + }, + }, + }, + wantUpdate: ClusterUpdate{ + ClusterName: clusterName, + EDSServiceName: serviceName, + EnableLRS: false, + SecurityCfg: &SecurityConfig{ + RootInstanceName: rootPluginInstance, + RootCertName: rootCertName, + IdentityInstanceName: identityPluginInstance, + IdentityCertName: identityCertName, + SubjectAltNameMatchers: []matcher.StringMatcher{ + matcher.StringMatcherForTesting(newStringP(sanExact), nil, nil, nil, nil, true), + matcher.StringMatcherForTesting(nil, newStringP(sanPrefix), nil, nil, nil, false), + matcher.StringMatcherForTesting(nil, nil, newStringP(sanSuffix), nil, nil, false), + matcher.StringMatcherForTesting(nil, nil, nil, nil, sanRE, false), + matcher.StringMatcherForTesting(nil, nil, nil, newStringP(sanContains), nil, false), + }, + }, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + update, err := validateClusterAndConstructClusterUpdate(test.cluster) + if (err != nil) != test.wantErr { + t.Errorf("validateClusterAndConstructClusterUpdate() returned err %v wantErr %v)", err, test.wantErr) + } + if diff := cmp.Diff(test.wantUpdate, update, cmpopts.EquateEmpty(), cmp.AllowUnexported(regexp.Regexp{})); diff != "" { + t.Errorf("validateClusterAndConstructClusterUpdate() returned unexpected diff (-want, +got):\n%s", diff) + } + }) + } +} + 
+func (s) TestUnmarshalCluster(t *testing.T) { + const ( + v2ClusterName = "v2clusterName" + v3ClusterName = "v3clusterName" + v2Service = "v2Service" + v3Service = "v2Service" + ) + var ( + v2ClusterAny = testutils.MarshalAny(&v2xdspb.Cluster{ + Name: v2ClusterName, + ClusterDiscoveryType: &v2xdspb.Cluster_Type{Type: v2xdspb.Cluster_EDS}, + EdsClusterConfig: &v2xdspb.Cluster_EdsClusterConfig{ + EdsConfig: &v2corepb.ConfigSource{ + ConfigSourceSpecifier: &v2corepb.ConfigSource_Ads{ + Ads: &v2corepb.AggregatedConfigSource{}, + }, + }, + ServiceName: v2Service, + }, + LbPolicy: v2xdspb.Cluster_ROUND_ROBIN, + LrsServer: &v2corepb.ConfigSource{ + ConfigSourceSpecifier: &v2corepb.ConfigSource_Self{ + Self: &v2corepb.SelfConfigSource{}, + }, + }, + }) + + v3ClusterAny = testutils.MarshalAny(&v3clusterpb.Cluster{ + Name: v3ClusterName, + ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_EDS}, + EdsClusterConfig: &v3clusterpb.Cluster_EdsClusterConfig{ + EdsConfig: &v3corepb.ConfigSource{ + ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{ + Ads: &v3corepb.AggregatedConfigSource{}, + }, + }, + ServiceName: v3Service, + }, + LbPolicy: v3clusterpb.Cluster_ROUND_ROBIN, + LrsServer: &v3corepb.ConfigSource{ + ConfigSourceSpecifier: &v3corepb.ConfigSource_Self{ + Self: &v3corepb.SelfConfigSource{}, + }, + }, + }) + ) + const testVersion = "test-version-cds" + + tests := []struct { + name string + resources []*anypb.Any + wantUpdate map[string]ClusterUpdateErrTuple + wantMD UpdateMetadata + wantErr bool + }{ + { + name: "non-cluster resource type", + resources: []*anypb.Any{{TypeUrl: version.V3HTTPConnManagerURL}}, + wantMD: UpdateMetadata{ + Status: ServiceStatusNACKed, + Version: testVersion, + ErrState: &UpdateErrorMetadata{ + Version: testVersion, + Err: cmpopts.AnyError, + }, + }, + wantErr: true, + }, + { + name: "badly marshaled cluster resource", + resources: []*anypb.Any{ + { + TypeUrl: version.V3ClusterURL, + Value: []byte{1, 2, 3, 4}, + }, + }, 
+ wantMD: UpdateMetadata{ + Status: ServiceStatusNACKed, + Version: testVersion, + ErrState: &UpdateErrorMetadata{ + Version: testVersion, + Err: cmpopts.AnyError, + }, + }, + wantErr: true, + }, + { + name: "bad cluster resource", + resources: []*anypb.Any{ + testutils.MarshalAny(&v3clusterpb.Cluster{ + Name: "test", + ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_STATIC}, + }), + }, + wantUpdate: map[string]ClusterUpdateErrTuple{ + "test": {Err: cmpopts.AnyError}, + }, + wantMD: UpdateMetadata{ + Status: ServiceStatusNACKed, + Version: testVersion, + ErrState: &UpdateErrorMetadata{ + Version: testVersion, + Err: cmpopts.AnyError, + }, + }, + wantErr: true, + }, + { + name: "v2 cluster", + resources: []*anypb.Any{v2ClusterAny}, + wantUpdate: map[string]ClusterUpdateErrTuple{ + v2ClusterName: {Update: ClusterUpdate{ + ClusterName: v2ClusterName, + EDSServiceName: v2Service, EnableLRS: true, + Raw: v2ClusterAny, + }}, + }, + wantMD: UpdateMetadata{ + Status: ServiceStatusACKed, + Version: testVersion, + }, + }, + { + name: "v3 cluster", + resources: []*anypb.Any{v3ClusterAny}, + wantUpdate: map[string]ClusterUpdateErrTuple{ + v3ClusterName: {Update: ClusterUpdate{ + ClusterName: v3ClusterName, + EDSServiceName: v3Service, EnableLRS: true, + Raw: v3ClusterAny, + }}, + }, + wantMD: UpdateMetadata{ + Status: ServiceStatusACKed, + Version: testVersion, + }, + }, + { + name: "multiple clusters", + resources: []*anypb.Any{v2ClusterAny, v3ClusterAny}, + wantUpdate: map[string]ClusterUpdateErrTuple{ + v2ClusterName: {Update: ClusterUpdate{ + ClusterName: v2ClusterName, + EDSServiceName: v2Service, EnableLRS: true, + Raw: v2ClusterAny, + }}, + v3ClusterName: {Update: ClusterUpdate{ + ClusterName: v3ClusterName, + EDSServiceName: v3Service, EnableLRS: true, + Raw: v3ClusterAny, + }}, + }, + wantMD: UpdateMetadata{ + Status: ServiceStatusACKed, + Version: testVersion, + }, + }, + { + // To test that unmarshal keeps processing on errors. 
+ name: "good and bad clusters", + resources: []*anypb.Any{ + v2ClusterAny, + // bad cluster resource + testutils.MarshalAny(&v3clusterpb.Cluster{ + Name: "bad", + ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_STATIC}, + }), + v3ClusterAny, + }, + wantUpdate: map[string]ClusterUpdateErrTuple{ + v2ClusterName: {Update: ClusterUpdate{ + ClusterName: v2ClusterName, + EDSServiceName: v2Service, EnableLRS: true, + Raw: v2ClusterAny, + }}, + v3ClusterName: {Update: ClusterUpdate{ + ClusterName: v3ClusterName, + EDSServiceName: v3Service, EnableLRS: true, + Raw: v3ClusterAny, + }}, + "bad": {Err: cmpopts.AnyError}, + }, + wantMD: UpdateMetadata{ + Status: ServiceStatusNACKed, + Version: testVersion, + ErrState: &UpdateErrorMetadata{ + Version: testVersion, + Err: cmpopts.AnyError, + }, + }, + wantErr: true, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + opts := &UnmarshalOptions{ + Version: testVersion, + Resources: test.resources, + } + update, md, err := UnmarshalCluster(opts) + if (err != nil) != test.wantErr { + t.Fatalf("UnmarshalCluster(%+v), got err: %v, wantErr: %v", opts, err, test.wantErr) + } + if diff := cmp.Diff(update, test.wantUpdate, cmpOpts); diff != "" { + t.Errorf("got unexpected update, diff (-got +want): %v", diff) + } + if diff := cmp.Diff(md, test.wantMD, cmpOptsIgnoreDetails); diff != "" { + t.Errorf("got unexpected metadata, diff (-got +want): %v", diff) + } + }) + } +} diff --git a/xds/internal/client/client.go b/xds/internal/xdsclient/client.go similarity index 68% rename from xds/internal/client/client.go rename to xds/internal/xdsclient/client.go index 21881cc6eae..3230c66c06e 100644 --- a/xds/internal/client/client.go +++ b/xds/internal/xdsclient/client.go @@ -16,14 +16,15 @@ * */ -// Package client implements a full fledged gRPC client for the xDS API used by -// the xds resolver and balancer implementations. 
-package client +// Package xdsclient implements a full fledged gRPC client for the xDS API used +// by the xds resolver and balancer implementations. +package xdsclient import ( "context" "errors" "fmt" + "regexp" "sync" "time" @@ -32,8 +33,10 @@ import ( "github.com/golang/protobuf/proto" "google.golang.org/protobuf/types/known/anypb" - "google.golang.org/grpc/xds/internal/client/load" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/internal/xds/matcher" "google.golang.org/grpc/xds/internal/httpfilter" + "google.golang.org/grpc/xds/internal/xdsclient/load" "google.golang.org/grpc" "google.golang.org/grpc/internal/backoff" @@ -42,8 +45,8 @@ import ( "google.golang.org/grpc/internal/grpcsync" "google.golang.org/grpc/keepalive" "google.golang.org/grpc/xds/internal" - "google.golang.org/grpc/xds/internal/client/bootstrap" "google.golang.org/grpc/xds/internal/version" + "google.golang.org/grpc/xds/internal/xdsclient/bootstrap" ) var ( @@ -69,12 +72,20 @@ func getAPIClientBuilder(version version.TransportAPI) APIClientBuilder { return nil } +// UpdateValidatorFunc performs validations on update structs using +// context/logic available at the xdsClient layer. Since these validation are +// performed on internal update structs, they can be shared between different +// API clients. +type UpdateValidatorFunc func(interface{}) error + // BuildOptions contains options to be passed to client builders. type BuildOptions struct { // Parent is a top-level xDS client which has the intelligence to take // appropriate action based on xDS responses received from the management // server. Parent UpdateHandler + // Validator performs post unmarshal validation checks. + Validator UpdateValidatorFunc // NodeProto contains the Node proto to be used in xDS requests. The actual // type depends on the transport protocol version used. NodeProto proto.Message @@ -131,14 +142,17 @@ type loadReportingOptions struct { // resource updates from an APIClient for a specific version. 
type UpdateHandler interface { // NewListeners handles updates to xDS listener resources. - NewListeners(map[string]ListenerUpdate, UpdateMetadata) + NewListeners(map[string]ListenerUpdateErrTuple, UpdateMetadata) // NewRouteConfigs handles updates to xDS RouteConfiguration resources. - NewRouteConfigs(map[string]RouteConfigUpdate, UpdateMetadata) + NewRouteConfigs(map[string]RouteConfigUpdateErrTuple, UpdateMetadata) // NewClusters handles updates to xDS Cluster resources. - NewClusters(map[string]ClusterUpdate, UpdateMetadata) + NewClusters(map[string]ClusterUpdateErrTuple, UpdateMetadata) // NewEndpoints handles updates to xDS ClusterLoadAssignment (or tersely // referred to as Endpoints) resources. - NewEndpoints(map[string]EndpointsUpdate, UpdateMetadata) + NewEndpoints(map[string]EndpointsUpdateErrTuple, UpdateMetadata) + // NewConnectionError handles connection errors from the xDS stream. The + // error will be reported to all the resource watchers. + NewConnectionError(err error) } // ServiceStatus is the status of the update. @@ -193,9 +207,15 @@ type UpdateMetadata struct { type ListenerUpdate struct { // RouteConfigName is the route configuration name corresponding to the // target which is being watched through LDS. + // + // Only one of RouteConfigName and InlineRouteConfig is set. RouteConfigName string - // SecurityCfg contains security configuration sent by the control plane. - SecurityCfg *SecurityConfig + // InlineRouteConfig is the inline route configuration (RDS response) + // returned inside LDS. + // + // Only one of RouteConfigName and InlineRouteConfig is set. + InlineRouteConfig *RouteConfigUpdate + // MaxStreamDuration contains the HTTP connection manager's // common_http_protocol_options.max_stream_duration field, or zero if // unset. @@ -203,6 +223,8 @@ type ListenerUpdate struct { // HTTPFilters is a list of HTTP filters (name, config) from the LDS // response. 
HTTPFilters []HTTPFilter + // InboundListenerCfg contains inbound listener configuration. + InboundListenerCfg *InboundListenerConfig // Raw is the resource from the xds response. Raw *anypb.Any @@ -221,15 +243,23 @@ type HTTPFilter struct { Config httpfilter.FilterConfig } -func (lu *ListenerUpdate) String() string { - return fmt.Sprintf("{RouteConfigName: %q, SecurityConfig: %+v", lu.RouteConfigName, lu.SecurityCfg) +// InboundListenerConfig contains information about the inbound listener, i.e +// the server-side listener. +type InboundListenerConfig struct { + // Address is the local address on which the inbound listener is expected to + // accept incoming connections. + Address string + // Port is the local port on which the inbound listener is expected to + // accept incoming connections. + Port string + // FilterChains is the list of filter chains associated with this listener. + FilterChains *FilterChainManager } // RouteConfigUpdate contains information received in an RDS response, which is // of interest to the registered RDS watcher. type RouteConfigUpdate struct { VirtualHosts []*VirtualHost - // Raw is the resource from the xds response. Raw *anypb.Any } @@ -248,18 +278,81 @@ type VirtualHost struct { // may be unused if the matching Route contains an override for that // filter. HTTPFilterConfigOverride map[string]httpfilter.FilterConfig + RetryConfig *RetryConfig +} + +// RetryConfig contains all retry-related configuration in either a VirtualHost +// or Route. +type RetryConfig struct { + // RetryOn is a set of status codes on which to retry. Only Canceled, + // DeadlineExceeded, Internal, ResourceExhausted, and Unavailable are + // supported; any other values will be omitted. + RetryOn map[codes.Code]bool + NumRetries uint32 // maximum number of retry attempts + RetryBackoff RetryBackoff // retry backoff policy +} + +// RetryBackoff describes the backoff policy for retries. 
+type RetryBackoff struct { + BaseInterval time.Duration // initial backoff duration between attempts + MaxInterval time.Duration // maximum backoff duration +} + +// HashPolicyType specifies the type of HashPolicy from a received RDS Response. +type HashPolicyType int + +const ( + // HashPolicyTypeHeader specifies to hash a Header in the incoming request. + HashPolicyTypeHeader HashPolicyType = iota + // HashPolicyTypeChannelID specifies to hash a unique Identifier of the + // Channel. In grpc-go, this will be done using the ClientConn pointer. + HashPolicyTypeChannelID +) + +// HashPolicy specifies the HashPolicy if the upstream cluster uses a hashing +// load balancer. +type HashPolicy struct { + HashPolicyType HashPolicyType + Terminal bool + // Fields used for type HEADER. + HeaderName string + Regex *regexp.Regexp + RegexSubstitution string } +// RouteAction is the action of the route from a received RDS response. +type RouteAction int + +const ( + // RouteActionUnsupported are routing types currently unsupported by grpc. + // According to A36, "A Route with an inappropriate action causes RPCs + // matching that route to fail." + RouteActionUnsupported RouteAction = iota + // RouteActionRoute is the expected route type on the client side. Route + // represents routing a request to some upstream cluster. On the client + // side, if an RPC matches to a route that is not RouteActionRoute, the RPC + // will fail according to A36. + RouteActionRoute + // RouteActionNonForwardingAction is the expected route type on the server + // side. NonForwardingAction represents when a route will generate a + // response directly, without forwarding to an upstream host. + RouteActionNonForwardingAction +) + // Route is both a specification of how to match a request as well as an // indication of the action to take upon match. 
type Route struct { - Path, Prefix, Regex *string + Path *string + Prefix *string + Regex *regexp.Regexp // Indicates if prefix/path matching should be case insensitive. The default // is false (case sensitive). CaseInsensitive bool Headers []*HeaderMatcher Fraction *uint32 + HashPolicies []*HashPolicy + // If the matchers above indicate a match, the below configuration is used. WeightedClusters map[string]WeightedCluster // If MaxStreamDuration is nil, it indicates neither of the route action's @@ -273,6 +366,9 @@ type Route struct { // unused if the matching WeightedCluster contains an override for that // filter. HTTPFilterConfigOverride map[string]httpfilter.FilterConfig + RetryConfig *RetryConfig + + RouteAction RouteAction } // WeightedCluster contains settings for an xds RouteAction.WeightedCluster. @@ -286,20 +382,20 @@ type WeightedCluster struct { // HeaderMatcher represents header matchers. type HeaderMatcher struct { - Name string `json:"name"` - InvertMatch *bool `json:"invertMatch,omitempty"` - ExactMatch *string `json:"exactMatch,omitempty"` - RegexMatch *string `json:"regexMatch,omitempty"` - PrefixMatch *string `json:"prefixMatch,omitempty"` - SuffixMatch *string `json:"suffixMatch,omitempty"` - RangeMatch *Int64Range `json:"rangeMatch,omitempty"` - PresentMatch *bool `json:"presentMatch,omitempty"` + Name string + InvertMatch *bool + ExactMatch *string + RegexMatch *regexp.Regexp + PrefixMatch *string + SuffixMatch *string + RangeMatch *Int64Range + PresentMatch *bool } // Int64Range is a range for header range match. type Int64Range struct { - Start int64 `json:"start"` - End int64 `json:"end"` + Start int64 + End int64 } // SecurityConfig contains the security configuration received as part of the @@ -322,29 +418,107 @@ type SecurityConfig struct { // IdentityCertName is the certificate name to be passed to the plugin // (looked up from the bootstrap file) while fetching identity certificates. 
IdentityCertName string - // AcceptedSANs is a list of Subject Alternative Names. During the TLS - // handshake, the SAN present in the peer certificate is compared against - // this list, and the handshake succeeds only if a match is found. Used only - // on the client-side. - AcceptedSANs []string + // SubjectAltNameMatchers is an optional list of match criteria for SANs + // specified on the peer certificate. Used only on the client-side. + // + // Some intricacies: + // - If this field is empty, then any peer certificate is accepted. + // - If the peer certificate contains a wildcard DNS SAN, and an `exact` + // matcher is configured, a wildcard DNS match is performed instead of a + // regular string comparison. + SubjectAltNameMatchers []matcher.StringMatcher // RequireClientCert indicates if the server handshake process expects the // client to present a certificate. Set to true when performing mTLS. Used // only on the server-side. RequireClientCert bool } +// Equal returns true if sc is equal to other. +func (sc *SecurityConfig) Equal(other *SecurityConfig) bool { + switch { + case sc == nil && other == nil: + return true + case (sc != nil) != (other != nil): + return false + } + switch { + case sc.RootInstanceName != other.RootInstanceName: + return false + case sc.RootCertName != other.RootCertName: + return false + case sc.IdentityInstanceName != other.IdentityInstanceName: + return false + case sc.IdentityCertName != other.IdentityCertName: + return false + case sc.RequireClientCert != other.RequireClientCert: + return false + default: + if len(sc.SubjectAltNameMatchers) != len(other.SubjectAltNameMatchers) { + return false + } + for i := 0; i < len(sc.SubjectAltNameMatchers); i++ { + if !sc.SubjectAltNameMatchers[i].Equal(other.SubjectAltNameMatchers[i]) { + return false + } + } + } + return true +} + +// ClusterType is the type of cluster from a received CDS response. 
+type ClusterType int + +const ( + // ClusterTypeEDS represents the EDS cluster type, which will delegate endpoint + // discovery to the management server. + ClusterTypeEDS ClusterType = iota + // ClusterTypeLogicalDNS represents the Logical DNS cluster type, which essentially + // maps to the gRPC behavior of using the DNS resolver with pick_first LB policy. + ClusterTypeLogicalDNS + // ClusterTypeAggregate represents the Aggregate Cluster type, which provides a + // prioritized list of clusters to use. It is used for failover between clusters + // with a different configuration. + ClusterTypeAggregate +) + +// ClusterLBPolicyRingHash represents ring_hash lb policy, and also contains its +// config. +type ClusterLBPolicyRingHash struct { + MinimumRingSize uint64 + MaximumRingSize uint64 +} + // ClusterUpdate contains information from a received CDS response, which is of // interest to the registered CDS watcher. type ClusterUpdate struct { - // ServiceName is the service name corresponding to the clusterName which - // is being watched for through CDS. - ServiceName string + ClusterType ClusterType + // ClusterName is the clusterName being watched for through CDS. + ClusterName string + // EDSServiceName is an optional name for EDS. If it's not set, the balancer + // should watch ClusterName for the EDS resources. + EDSServiceName string // EnableLRS indicates whether or not load should be reported through LRS. EnableLRS bool // SecurityCfg contains security configuration sent by the control plane. SecurityCfg *SecurityConfig // MaxRequests for circuit breaking, if any (otherwise nil). MaxRequests *uint32 + // DNSHostName is used only for cluster type DNS. It's the DNS name to + // resolve in "host:port" form + DNSHostName string + // PrioritizedClusterNames is used only for cluster type aggregate. It represents + // a prioritized list of cluster names. + PrioritizedClusterNames []string + + // LBPolicy is the lb policy for this cluster. 
+ // + // This only support round_robin and ring_hash. + // - if it's nil, the lb policy is round_robin + // - if it's not nil, the lb policy is ring_hash, the this field has the config. + // + // When we add more support policies, this can be made an interface, and + // will be set to different types based on the policy type. + LBPolicy *ClusterLBPolicyRingHash // Raw is the resource from the xds response. Raw *anypb.Any @@ -514,6 +688,7 @@ func newWithConfig(config *bootstrap.Config, watchExpiryTimeout time.Duration) ( apiClient, err := newAPIClient(config.TransportAPI, cc, BuildOptions{ Parent: c, + Validator: c.updateValidator, NodeProto: config.NodeProto, Backoff: backoff.DefaultExponential.Backoff, Logger: c.logger, @@ -529,7 +704,7 @@ func newWithConfig(config *bootstrap.Config, watchExpiryTimeout time.Duration) ( // BootstrapConfig returns the configuration read from the bootstrap file. // Callers must treat the return value as read-only. -func (c *Client) BootstrapConfig() *bootstrap.Config { +func (c *clientRefCounted) BootstrapConfig() *bootstrap.Config { return c.config } @@ -567,6 +742,64 @@ func (c *clientImpl) Close() { c.logger.Infof("Shutdown") } +func (c *clientImpl) filterChainUpdateValidator(fc *FilterChain) error { + if fc == nil { + return nil + } + return c.securityConfigUpdateValidator(fc.SecurityCfg) +} + +func (c *clientImpl) securityConfigUpdateValidator(sc *SecurityConfig) error { + if sc == nil { + return nil + } + if sc.IdentityInstanceName != "" { + if _, ok := c.config.CertProviderConfigs[sc.IdentityInstanceName]; !ok { + return fmt.Errorf("identitiy certificate provider instance name %q missing in bootstrap configuration", sc.IdentityInstanceName) + } + } + if sc.RootInstanceName != "" { + if _, ok := c.config.CertProviderConfigs[sc.RootInstanceName]; !ok { + return fmt.Errorf("root certificate provider instance name %q missing in bootstrap configuration", sc.RootInstanceName) + } + } + return nil +} + +func (c *clientImpl) 
updateValidator(u interface{}) error { + switch update := u.(type) { + case ListenerUpdate: + if update.InboundListenerCfg == nil || update.InboundListenerCfg.FilterChains == nil { + return nil + } + + fcm := update.InboundListenerCfg.FilterChains + for _, dst := range fcm.dstPrefixMap { + for _, srcType := range dst.srcTypeArr { + if srcType == nil { + continue + } + for _, src := range srcType.srcPrefixMap { + for _, fc := range src.srcPortMap { + if err := c.filterChainUpdateValidator(fc); err != nil { + return err + } + } + } + } + } + return c.filterChainUpdateValidator(fcm.def) + case ClusterUpdate: + return c.securityConfigUpdateValidator(update.SecurityCfg) + default: + // We currently invoke this update validation function only for LDS and + // CDS updates. In the future, if we wish to invoke it for other xDS + // updates, corresponding plumbing needs to be added to those unmarshal + // functions. + } + return nil +} + // ResourceType identifies resources in a transport protocol agnostic way. These // will be used in transport version agnostic code, while the versioned API // clients will map these to appropriate version URLs. 
diff --git a/xds/internal/client/client_test.go b/xds/internal/xdsclient/client_test.go similarity index 78% rename from xds/internal/client/client_test.go rename to xds/internal/xdsclient/client_test.go index 8275ea60e0d..7c3423cd5ad 100644 --- a/xds/internal/client/client_test.go +++ b/xds/internal/xdsclient/client_test.go @@ -16,7 +16,7 @@ * */ -package client +package xdsclient import ( "context" @@ -26,15 +26,16 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" + "google.golang.org/protobuf/types/known/anypb" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/internal/grpcsync" "google.golang.org/grpc/internal/grpctest" "google.golang.org/grpc/internal/testutils" - "google.golang.org/grpc/xds/internal/client/bootstrap" xdstestutils "google.golang.org/grpc/xds/internal/testutils" "google.golang.org/grpc/xds/internal/version" + "google.golang.org/grpc/xds/internal/xdsclient/bootstrap" "google.golang.org/protobuf/testing/protocmp" ) @@ -62,19 +63,11 @@ const ( var ( cmpOpts = cmp.Options{ cmpopts.EquateEmpty(), + cmp.FilterValues(func(x, y error) bool { return true }, cmpopts.EquateErrors()), cmp.Comparer(func(a, b time.Time) bool { return true }), - cmp.Comparer(func(x, y error) bool { - if x == nil || y == nil { - return x == nil && y == nil - } - return x.Error() == y.Error() - }), protocmp.Transform(), } - // When comparing NACK UpdateMetadata, we only care if error is nil, but not - // the details in error. 
- errPlaceHolder = fmt.Errorf("error whose details don't matter") cmpOptsIgnoreDetails = cmp.Options{ cmp.Comparer(func(a, b time.Time) bool { return true }), cmp.Comparer(func(x, y error) bool { @@ -170,7 +163,7 @@ func (s) TestWatchCallAnotherWatch(t *testing.T) { clusterUpdateCh := testutils.NewChannel() firstTime := true client.WatchCluster(testCDSName, func(update ClusterUpdate, err error) { - clusterUpdateCh.Send(clusterUpdateErr{u: update, err: err}) + clusterUpdateCh.Send(ClusterUpdateErrTuple{Update: update, Err: err}) // Calls another watch inline, to ensure there's deadlock. client.WatchCluster("another-random-name", func(ClusterUpdate, error) {}) @@ -183,63 +176,89 @@ func (s) TestWatchCallAnotherWatch(t *testing.T) { t.Fatalf("want new watch to start, got error %v", err) } - wantUpdate := ClusterUpdate{ServiceName: testEDSName} - client.NewClusters(map[string]ClusterUpdate{testCDSName: wantUpdate}, UpdateMetadata{}) - if err := verifyClusterUpdate(ctx, clusterUpdateCh, wantUpdate); err != nil { + wantUpdate := ClusterUpdate{ClusterName: testEDSName} + client.NewClusters(map[string]ClusterUpdateErrTuple{testCDSName: {Update: wantUpdate}}, UpdateMetadata{}) + if err := verifyClusterUpdate(ctx, clusterUpdateCh, wantUpdate, nil); err != nil { t.Fatal(err) } - wantUpdate2 := ClusterUpdate{ServiceName: testEDSName + "2"} - client.NewClusters(map[string]ClusterUpdate{testCDSName: wantUpdate2}, UpdateMetadata{}) - if err := verifyClusterUpdate(ctx, clusterUpdateCh, wantUpdate2); err != nil { + // The second update needs to be different in the underlying resource proto + // for the watch callback to be invoked. 
+ wantUpdate2 := ClusterUpdate{ClusterName: testEDSName + "2", Raw: &anypb.Any{}} + client.NewClusters(map[string]ClusterUpdateErrTuple{testCDSName: {Update: wantUpdate2}}, UpdateMetadata{}) + if err := verifyClusterUpdate(ctx, clusterUpdateCh, wantUpdate2, nil); err != nil { t.Fatal(err) } } -func verifyListenerUpdate(ctx context.Context, updateCh *testutils.Channel, wantUpdate ListenerUpdate) error { +func verifyListenerUpdate(ctx context.Context, updateCh *testutils.Channel, wantUpdate ListenerUpdate, wantErr error) error { u, err := updateCh.Receive(ctx) if err != nil { return fmt.Errorf("timeout when waiting for listener update: %v", err) } - gotUpdate := u.(ldsUpdateErr) - if gotUpdate.err != nil || !cmp.Equal(gotUpdate.u, wantUpdate) { - return fmt.Errorf("unexpected endpointsUpdate: (%v, %v), want: (%v, nil)", gotUpdate.u, gotUpdate.err, wantUpdate) + gotUpdate := u.(ListenerUpdateErrTuple) + if wantErr != nil { + if gotUpdate.Err != wantErr { + return fmt.Errorf("unexpected error: %v, want %v", gotUpdate.Err, wantErr) + } + return nil + } + if gotUpdate.Err != nil || !cmp.Equal(gotUpdate.Update, wantUpdate, protocmp.Transform()) { + return fmt.Errorf("unexpected endpointsUpdate: (%v, %v), want: (%v, nil)", gotUpdate.Update, gotUpdate.Err, wantUpdate) } return nil } -func verifyRouteConfigUpdate(ctx context.Context, updateCh *testutils.Channel, wantUpdate RouteConfigUpdate) error { +func verifyRouteConfigUpdate(ctx context.Context, updateCh *testutils.Channel, wantUpdate RouteConfigUpdate, wantErr error) error { u, err := updateCh.Receive(ctx) if err != nil { return fmt.Errorf("timeout when waiting for route configuration update: %v", err) } - gotUpdate := u.(rdsUpdateErr) - if gotUpdate.err != nil || !cmp.Equal(gotUpdate.u, wantUpdate) { - return fmt.Errorf("unexpected route config update: (%v, %v), want: (%v, nil)", gotUpdate.u, gotUpdate.err, wantUpdate) + gotUpdate := u.(RouteConfigUpdateErrTuple) + if wantErr != nil { + if gotUpdate.Err != wantErr { + 
return fmt.Errorf("unexpected error: %v, want %v", gotUpdate.Err, wantErr) + } + return nil + } + if gotUpdate.Err != nil || !cmp.Equal(gotUpdate.Update, wantUpdate, protocmp.Transform()) { + return fmt.Errorf("unexpected route config update: (%v, %v), want: (%v, nil)", gotUpdate.Update, gotUpdate.Err, wantUpdate) } return nil } -func verifyClusterUpdate(ctx context.Context, updateCh *testutils.Channel, wantUpdate ClusterUpdate) error { +func verifyClusterUpdate(ctx context.Context, updateCh *testutils.Channel, wantUpdate ClusterUpdate, wantErr error) error { u, err := updateCh.Receive(ctx) if err != nil { return fmt.Errorf("timeout when waiting for cluster update: %v", err) } - gotUpdate := u.(clusterUpdateErr) - if gotUpdate.err != nil || !cmp.Equal(gotUpdate.u, wantUpdate) { - return fmt.Errorf("unexpected clusterUpdate: (%v, %v), want: (%v, nil)", gotUpdate.u, gotUpdate.err, wantUpdate) + gotUpdate := u.(ClusterUpdateErrTuple) + if wantErr != nil { + if gotUpdate.Err != wantErr { + return fmt.Errorf("unexpected error: %v, want %v", gotUpdate.Err, wantErr) + } + return nil + } + if !cmp.Equal(gotUpdate.Update, wantUpdate, protocmp.Transform()) { + return fmt.Errorf("unexpected clusterUpdate: (%v, %v), want: (%v, nil)", gotUpdate.Update, gotUpdate.Err, wantUpdate) } return nil } -func verifyEndpointsUpdate(ctx context.Context, updateCh *testutils.Channel, wantUpdate EndpointsUpdate) error { +func verifyEndpointsUpdate(ctx context.Context, updateCh *testutils.Channel, wantUpdate EndpointsUpdate, wantErr error) error { u, err := updateCh.Receive(ctx) if err != nil { return fmt.Errorf("timeout when waiting for endpoints update: %v", err) } - gotUpdate := u.(endpointsUpdateErr) - if gotUpdate.err != nil || !cmp.Equal(gotUpdate.u, wantUpdate, cmpopts.EquateEmpty()) { - return fmt.Errorf("unexpected endpointsUpdate: (%v, %v), want: (%v, nil)", gotUpdate.u, gotUpdate.err, wantUpdate) + gotUpdate := u.(EndpointsUpdateErrTuple) + if wantErr != nil { + if gotUpdate.Err != 
wantErr { + return fmt.Errorf("unexpected error: %v, want %v", gotUpdate.Err, wantErr) + } + return nil + } + if gotUpdate.Err != nil || !cmp.Equal(gotUpdate.Update, wantUpdate, cmpopts.EquateEmpty(), protocmp.Transform()) { + return fmt.Errorf("unexpected endpointsUpdate: (%v, %v), want: (%v, nil)", gotUpdate.Update, gotUpdate.Err, wantUpdate) } return nil } @@ -261,7 +280,7 @@ func (s) TestClientNewSingleton(t *testing.T) { defer cleanup() // The first New(). Should create a Client and a new APIClient. - client, err := New() + client, err := newRefCounted() if err != nil { t.Fatalf("failed to create client: %v", err) } @@ -278,7 +297,7 @@ func (s) TestClientNewSingleton(t *testing.T) { // and should not create new API client. const count = 9 for i := 0; i < count; i++ { - tc, terr := New() + tc, terr := newRefCounted() if terr != nil { client.Close() t.Fatalf("%d-th call to New() failed with error: %v", i, terr) @@ -322,7 +341,7 @@ func (s) TestClientNewSingleton(t *testing.T) { // Call New() again after the previous Client is actually closed. Should // create a Client and a new APIClient. 
- client2, err2 := New() + client2, err2 := newRefCounted() if err2 != nil { t.Fatalf("failed to create client: %v", err) } diff --git a/xds/internal/client/dump.go b/xds/internal/xdsclient/dump.go similarity index 99% rename from xds/internal/client/dump.go rename to xds/internal/xdsclient/dump.go index 3fd18f6103b..db9b474f370 100644 --- a/xds/internal/client/dump.go +++ b/xds/internal/xdsclient/dump.go @@ -16,7 +16,7 @@ * */ -package client +package xdsclient import anypb "github.com/golang/protobuf/ptypes/any" diff --git a/xds/internal/client/tests/dump_test.go b/xds/internal/xdsclient/dump_test.go similarity index 76% rename from xds/internal/client/tests/dump_test.go rename to xds/internal/xdsclient/dump_test.go index 58220866eb1..d03479ca4ad 100644 --- a/xds/internal/client/tests/dump_test.go +++ b/xds/internal/xdsclient/dump_test.go @@ -16,7 +16,7 @@ * */ -package tests_test +package xdsclient_test import ( "fmt" @@ -28,7 +28,6 @@ import ( v3listenerpb "github.com/envoyproxy/go-control-plane/envoy/config/listener/v3" v3routepb "github.com/envoyproxy/go-control-plane/envoy/config/route/v3" v3httppb "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/network/http_connection_manager/v3" - "github.com/golang/protobuf/ptypes" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "google.golang.org/protobuf/testing/protocmp" @@ -37,9 +36,10 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" - xdsclient "google.golang.org/grpc/xds/internal/client" - "google.golang.org/grpc/xds/internal/client/bootstrap" + "google.golang.org/grpc/internal/testutils" xdstestutils "google.golang.org/grpc/xds/internal/testutils" + "google.golang.org/grpc/xds/internal/xdsclient" + "google.golang.org/grpc/xds/internal/xdsclient/bootstrap" ) const defaultTestWatchExpiryTimeout = 500 * time.Millisecond @@ -56,29 +56,22 @@ func (s) TestLDSConfigDump(t *testing.T) { listenersT := &v3listenerpb.Listener{ Name: ldsTargets[i], 
ApiListener: &v3listenerpb.ApiListener{ - ApiListener: func() *anypb.Any { - mcm, _ := ptypes.MarshalAny(&v3httppb.HttpConnectionManager{ - RouteSpecifier: &v3httppb.HttpConnectionManager_Rds{ - Rds: &v3httppb.Rds{ - ConfigSource: &v3corepb.ConfigSource{ - ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{Ads: &v3corepb.AggregatedConfigSource{}}, - }, - RouteConfigName: routeConfigNames[i], + ApiListener: testutils.MarshalAny(&v3httppb.HttpConnectionManager{ + RouteSpecifier: &v3httppb.HttpConnectionManager_Rds{ + Rds: &v3httppb.Rds{ + ConfigSource: &v3corepb.ConfigSource{ + ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{Ads: &v3corepb.AggregatedConfigSource{}}, }, + RouteConfigName: routeConfigNames[i], }, - CommonHttpProtocolOptions: &v3corepb.HttpProtocolOptions{ - MaxStreamDuration: durationpb.New(time.Second), - }, - }) - return mcm - }(), + }, + CommonHttpProtocolOptions: &v3corepb.HttpProtocolOptions{ + MaxStreamDuration: durationpb.New(time.Second), + }, + }), }, } - anyT, err := ptypes.MarshalAny(listenersT) - if err != nil { - t.Fatalf("failed to marshal proto to any: %v", err) - } - listenerRaws[ldsTargets[i]] = anyT + listenerRaws[ldsTargets[i]] = testutils.MarshalAny(listenersT) } client, err := xdsclient.NewWithConfigForTesting(&bootstrap.Config{ @@ -90,6 +83,7 @@ func (s) TestLDSConfigDump(t *testing.T) { t.Fatalf("failed to create client: %v", err) } defer client.Close() + updateHandler := client.(xdsclient.UpdateHandler) // Expected unknown. 
if err := compareDump(client.DumpLDS, "", map[string]xdsclient.UpdateWithMD{}); err != nil { @@ -107,16 +101,16 @@ func (s) TestLDSConfigDump(t *testing.T) { t.Fatalf(err.Error()) } - update0 := make(map[string]xdsclient.ListenerUpdate) + update0 := make(map[string]xdsclient.ListenerUpdateErrTuple) want0 := make(map[string]xdsclient.UpdateWithMD) for n, r := range listenerRaws { - update0[n] = xdsclient.ListenerUpdate{Raw: r} + update0[n] = xdsclient.ListenerUpdateErrTuple{Update: xdsclient.ListenerUpdate{Raw: r}} want0[n] = xdsclient.UpdateWithMD{ - MD: xdsclient.UpdateMetadata{Version: testVersion}, + MD: xdsclient.UpdateMetadata{Status: xdsclient.ServiceStatusACKed, Version: testVersion}, Raw: r, } } - client.NewListeners(update0, xdsclient.UpdateMetadata{Version: testVersion}) + updateHandler.NewListeners(update0, xdsclient.UpdateMetadata{Status: xdsclient.ServiceStatusACKed, Version: testVersion}) // Expect ACK. if err := compareDump(client.DumpLDS, testVersion, want0); err != nil { @@ -125,11 +119,13 @@ func (s) TestLDSConfigDump(t *testing.T) { const nackVersion = "lds-version-nack" var nackErr = fmt.Errorf("lds nack error") - client.NewListeners( - map[string]xdsclient.ListenerUpdate{ - ldsTargets[0]: {}, + updateHandler.NewListeners( + map[string]xdsclient.ListenerUpdateErrTuple{ + ldsTargets[0]: {Err: nackErr}, + ldsTargets[1]: {Update: xdsclient.ListenerUpdate{Raw: listenerRaws[ldsTargets[1]]}}, }, xdsclient.UpdateMetadata{ + Status: xdsclient.ServiceStatusNACKed, ErrState: &xdsclient.UpdateErrorMetadata{ Version: nackVersion, Err: nackErr, @@ -143,6 +139,7 @@ func (s) TestLDSConfigDump(t *testing.T) { // message, as well as the NACK error. 
wantDump[ldsTargets[0]] = xdsclient.UpdateWithMD{ MD: xdsclient.UpdateMetadata{ + Status: xdsclient.ServiceStatusNACKed, Version: testVersion, ErrState: &xdsclient.UpdateErrorMetadata{ Version: nackVersion, @@ -153,7 +150,7 @@ func (s) TestLDSConfigDump(t *testing.T) { } wantDump[ldsTargets[1]] = xdsclient.UpdateWithMD{ - MD: xdsclient.UpdateMetadata{Version: testVersion}, + MD: xdsclient.UpdateMetadata{Status: xdsclient.ServiceStatusACKed, Version: nackVersion}, Raw: listenerRaws[ldsTargets[1]], } if err := compareDump(client.DumpLDS, nackVersion, wantDump); err != nil { @@ -188,11 +185,7 @@ func (s) TestRDSConfigDump(t *testing.T) { }, } - anyT, err := ptypes.MarshalAny(routeConfigT) - if err != nil { - t.Fatalf("failed to marshal proto to any: %v", err) - } - routeRaws[rdsTargets[i]] = anyT + routeRaws[rdsTargets[i]] = testutils.MarshalAny(routeConfigT) } client, err := xdsclient.NewWithConfigForTesting(&bootstrap.Config{ @@ -204,6 +197,7 @@ func (s) TestRDSConfigDump(t *testing.T) { t.Fatalf("failed to create client: %v", err) } defer client.Close() + updateHandler := client.(xdsclient.UpdateHandler) // Expected unknown. 
if err := compareDump(client.DumpRDS, "", map[string]xdsclient.UpdateWithMD{}); err != nil { @@ -221,16 +215,16 @@ func (s) TestRDSConfigDump(t *testing.T) { t.Fatalf(err.Error()) } - update0 := make(map[string]xdsclient.RouteConfigUpdate) + update0 := make(map[string]xdsclient.RouteConfigUpdateErrTuple) want0 := make(map[string]xdsclient.UpdateWithMD) for n, r := range routeRaws { - update0[n] = xdsclient.RouteConfigUpdate{Raw: r} + update0[n] = xdsclient.RouteConfigUpdateErrTuple{Update: xdsclient.RouteConfigUpdate{Raw: r}} want0[n] = xdsclient.UpdateWithMD{ - MD: xdsclient.UpdateMetadata{Version: testVersion}, + MD: xdsclient.UpdateMetadata{Status: xdsclient.ServiceStatusACKed, Version: testVersion}, Raw: r, } } - client.NewRouteConfigs(update0, xdsclient.UpdateMetadata{Version: testVersion}) + updateHandler.NewRouteConfigs(update0, xdsclient.UpdateMetadata{Status: xdsclient.ServiceStatusACKed, Version: testVersion}) // Expect ACK. if err := compareDump(client.DumpRDS, testVersion, want0); err != nil { @@ -239,11 +233,13 @@ func (s) TestRDSConfigDump(t *testing.T) { const nackVersion = "rds-version-nack" var nackErr = fmt.Errorf("rds nack error") - client.NewRouteConfigs( - map[string]xdsclient.RouteConfigUpdate{ - rdsTargets[0]: {}, + updateHandler.NewRouteConfigs( + map[string]xdsclient.RouteConfigUpdateErrTuple{ + rdsTargets[0]: {Err: nackErr}, + rdsTargets[1]: {Update: xdsclient.RouteConfigUpdate{Raw: routeRaws[rdsTargets[1]]}}, }, xdsclient.UpdateMetadata{ + Status: xdsclient.ServiceStatusNACKed, ErrState: &xdsclient.UpdateErrorMetadata{ Version: nackVersion, Err: nackErr, @@ -257,6 +253,7 @@ func (s) TestRDSConfigDump(t *testing.T) { // message, as well as the NACK error. 
wantDump[rdsTargets[0]] = xdsclient.UpdateWithMD{ MD: xdsclient.UpdateMetadata{ + Status: xdsclient.ServiceStatusNACKed, Version: testVersion, ErrState: &xdsclient.UpdateErrorMetadata{ Version: nackVersion, @@ -266,7 +263,7 @@ func (s) TestRDSConfigDump(t *testing.T) { Raw: routeRaws[rdsTargets[0]], } wantDump[rdsTargets[1]] = xdsclient.UpdateWithMD{ - MD: xdsclient.UpdateMetadata{Version: testVersion}, + MD: xdsclient.UpdateMetadata{Status: xdsclient.ServiceStatusACKed, Version: nackVersion}, Raw: routeRaws[rdsTargets[1]], } if err := compareDump(client.DumpRDS, nackVersion, wantDump); err != nil { @@ -302,11 +299,7 @@ func (s) TestCDSConfigDump(t *testing.T) { }, } - anyT, err := ptypes.MarshalAny(clusterT) - if err != nil { - t.Fatalf("failed to marshal proto to any: %v", err) - } - clusterRaws[cdsTargets[i]] = anyT + clusterRaws[cdsTargets[i]] = testutils.MarshalAny(clusterT) } client, err := xdsclient.NewWithConfigForTesting(&bootstrap.Config{ @@ -318,6 +311,7 @@ func (s) TestCDSConfigDump(t *testing.T) { t.Fatalf("failed to create client: %v", err) } defer client.Close() + updateHandler := client.(xdsclient.UpdateHandler) // Expected unknown. 
if err := compareDump(client.DumpCDS, "", map[string]xdsclient.UpdateWithMD{}); err != nil { @@ -335,16 +329,16 @@ func (s) TestCDSConfigDump(t *testing.T) { t.Fatalf(err.Error()) } - update0 := make(map[string]xdsclient.ClusterUpdate) + update0 := make(map[string]xdsclient.ClusterUpdateErrTuple) want0 := make(map[string]xdsclient.UpdateWithMD) for n, r := range clusterRaws { - update0[n] = xdsclient.ClusterUpdate{Raw: r} + update0[n] = xdsclient.ClusterUpdateErrTuple{Update: xdsclient.ClusterUpdate{Raw: r}} want0[n] = xdsclient.UpdateWithMD{ - MD: xdsclient.UpdateMetadata{Version: testVersion}, + MD: xdsclient.UpdateMetadata{Status: xdsclient.ServiceStatusACKed, Version: testVersion}, Raw: r, } } - client.NewClusters(update0, xdsclient.UpdateMetadata{Version: testVersion}) + updateHandler.NewClusters(update0, xdsclient.UpdateMetadata{Status: xdsclient.ServiceStatusACKed, Version: testVersion}) // Expect ACK. if err := compareDump(client.DumpCDS, testVersion, want0); err != nil { @@ -353,11 +347,13 @@ func (s) TestCDSConfigDump(t *testing.T) { const nackVersion = "cds-version-nack" var nackErr = fmt.Errorf("cds nack error") - client.NewClusters( - map[string]xdsclient.ClusterUpdate{ - cdsTargets[0]: {}, + updateHandler.NewClusters( + map[string]xdsclient.ClusterUpdateErrTuple{ + cdsTargets[0]: {Err: nackErr}, + cdsTargets[1]: {Update: xdsclient.ClusterUpdate{Raw: clusterRaws[cdsTargets[1]]}}, }, xdsclient.UpdateMetadata{ + Status: xdsclient.ServiceStatusNACKed, ErrState: &xdsclient.UpdateErrorMetadata{ Version: nackVersion, Err: nackErr, @@ -371,6 +367,7 @@ func (s) TestCDSConfigDump(t *testing.T) { // message, as well as the NACK error. 
wantDump[cdsTargets[0]] = xdsclient.UpdateWithMD{ MD: xdsclient.UpdateMetadata{ + Status: xdsclient.ServiceStatusNACKed, Version: testVersion, ErrState: &xdsclient.UpdateErrorMetadata{ Version: nackVersion, @@ -380,7 +377,7 @@ func (s) TestCDSConfigDump(t *testing.T) { Raw: clusterRaws[cdsTargets[0]], } wantDump[cdsTargets[1]] = xdsclient.UpdateWithMD{ - MD: xdsclient.UpdateMetadata{Version: testVersion}, + MD: xdsclient.UpdateMetadata{Status: xdsclient.ServiceStatusACKed, Version: nackVersion}, Raw: clusterRaws[cdsTargets[1]], } if err := compareDump(client.DumpCDS, nackVersion, wantDump); err != nil { @@ -402,11 +399,7 @@ func (s) TestEDSConfigDump(t *testing.T) { clab0.AddLocality(localityNames[i], 1, 1, []string{addrs[i]}, nil) claT := clab0.Build() - anyT, err := ptypes.MarshalAny(claT) - if err != nil { - t.Fatalf("failed to marshal proto to any: %v", err) - } - endpointRaws[edsTargets[i]] = anyT + endpointRaws[edsTargets[i]] = testutils.MarshalAny(claT) } client, err := xdsclient.NewWithConfigForTesting(&bootstrap.Config{ @@ -418,6 +411,7 @@ func (s) TestEDSConfigDump(t *testing.T) { t.Fatalf("failed to create client: %v", err) } defer client.Close() + updateHandler := client.(xdsclient.UpdateHandler) // Expected unknown. 
if err := compareDump(client.DumpEDS, "", map[string]xdsclient.UpdateWithMD{}); err != nil { @@ -435,16 +429,16 @@ func (s) TestEDSConfigDump(t *testing.T) { t.Fatalf(err.Error()) } - update0 := make(map[string]xdsclient.EndpointsUpdate) + update0 := make(map[string]xdsclient.EndpointsUpdateErrTuple) want0 := make(map[string]xdsclient.UpdateWithMD) for n, r := range endpointRaws { - update0[n] = xdsclient.EndpointsUpdate{Raw: r} + update0[n] = xdsclient.EndpointsUpdateErrTuple{Update: xdsclient.EndpointsUpdate{Raw: r}} want0[n] = xdsclient.UpdateWithMD{ - MD: xdsclient.UpdateMetadata{Version: testVersion}, + MD: xdsclient.UpdateMetadata{Status: xdsclient.ServiceStatusACKed, Version: testVersion}, Raw: r, } } - client.NewEndpoints(update0, xdsclient.UpdateMetadata{Version: testVersion}) + updateHandler.NewEndpoints(update0, xdsclient.UpdateMetadata{Status: xdsclient.ServiceStatusACKed, Version: testVersion}) // Expect ACK. if err := compareDump(client.DumpEDS, testVersion, want0); err != nil { @@ -453,11 +447,13 @@ func (s) TestEDSConfigDump(t *testing.T) { const nackVersion = "eds-version-nack" var nackErr = fmt.Errorf("eds nack error") - client.NewEndpoints( - map[string]xdsclient.EndpointsUpdate{ - edsTargets[0]: {}, + updateHandler.NewEndpoints( + map[string]xdsclient.EndpointsUpdateErrTuple{ + edsTargets[0]: {Err: nackErr}, + edsTargets[1]: {Update: xdsclient.EndpointsUpdate{Raw: endpointRaws[edsTargets[1]]}}, }, xdsclient.UpdateMetadata{ + Status: xdsclient.ServiceStatusNACKed, ErrState: &xdsclient.UpdateErrorMetadata{ Version: nackVersion, Err: nackErr, @@ -471,6 +467,7 @@ func (s) TestEDSConfigDump(t *testing.T) { // message, as well as the NACK error. 
wantDump[edsTargets[0]] = xdsclient.UpdateWithMD{ MD: xdsclient.UpdateMetadata{ + Status: xdsclient.ServiceStatusNACKed, Version: testVersion, ErrState: &xdsclient.UpdateErrorMetadata{ Version: nackVersion, @@ -480,7 +477,7 @@ func (s) TestEDSConfigDump(t *testing.T) { Raw: endpointRaws[edsTargets[0]], } wantDump[edsTargets[1]] = xdsclient.UpdateWithMD{ - MD: xdsclient.UpdateMetadata{Version: testVersion}, + MD: xdsclient.UpdateMetadata{Status: xdsclient.ServiceStatusACKed, Version: nackVersion}, Raw: endpointRaws[edsTargets[1]], } if err := compareDump(client.DumpEDS, nackVersion, wantDump); err != nil { diff --git a/xds/internal/client/eds_test.go b/xds/internal/xdsclient/eds_test.go similarity index 82% rename from xds/internal/client/eds_test.go rename to xds/internal/xdsclient/eds_test.go index daa5d6525e1..d0af8a988d8 100644 --- a/xds/internal/client/eds_test.go +++ b/xds/internal/xdsclient/eds_test.go @@ -16,7 +16,7 @@ * */ -package client +package xdsclient import ( "fmt" @@ -27,10 +27,11 @@ import ( v3corepb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" v3endpointpb "github.com/envoyproxy/go-control-plane/envoy/config/endpoint/v3" v3typepb "github.com/envoyproxy/go-control-plane/envoy/type/v3" - "github.com/golang/protobuf/proto" anypb "github.com/golang/protobuf/ptypes/any" wrapperspb "github.com/golang/protobuf/ptypes/wrappers" "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "google.golang.org/grpc/internal/testutils" "google.golang.org/grpc/xds/internal" "google.golang.org/grpc/xds/internal/version" ) @@ -120,29 +121,24 @@ func (s) TestEDSParseRespProto(t *testing.T) { } func (s) TestUnmarshalEndpoints(t *testing.T) { - var v3EndpointsAny = &anypb.Any{ - TypeUrl: version.V3EndpointsURL, - Value: func() []byte { - clab0 := newClaBuilder("test", nil) - clab0.addLocality("locality-1", 1, 1, []string{"addr1:314"}, &addLocalityOptions{ - Health: []v3corepb.HealthStatus{v3corepb.HealthStatus_UNHEALTHY}, - Weight: 
[]uint32{271}, - }) - clab0.addLocality("locality-2", 1, 0, []string{"addr2:159"}, &addLocalityOptions{ - Health: []v3corepb.HealthStatus{v3corepb.HealthStatus_DRAINING}, - Weight: []uint32{828}, - }) - e := clab0.Build() - me, _ := proto.Marshal(e) - return me - }(), - } + var v3EndpointsAny = testutils.MarshalAny(func() *v3endpointpb.ClusterLoadAssignment { + clab0 := newClaBuilder("test", nil) + clab0.addLocality("locality-1", 1, 1, []string{"addr1:314"}, &addLocalityOptions{ + Health: []v3corepb.HealthStatus{v3corepb.HealthStatus_UNHEALTHY}, + Weight: []uint32{271}, + }) + clab0.addLocality("locality-2", 1, 0, []string{"addr2:159"}, &addLocalityOptions{ + Health: []v3corepb.HealthStatus{v3corepb.HealthStatus_DRAINING}, + Weight: []uint32{828}, + }) + return clab0.Build() + }()) const testVersion = "test-version-eds" tests := []struct { name string resources []*anypb.Any - wantUpdate map[string]EndpointsUpdate + wantUpdate map[string]EndpointsUpdateErrTuple wantMD UpdateMetadata wantErr bool }{ @@ -154,7 +150,7 @@ func (s) TestUnmarshalEndpoints(t *testing.T) { Version: testVersion, ErrState: &UpdateErrorMetadata{ Version: testVersion, - Err: errPlaceHolder, + Err: cmpopts.AnyError, }, }, wantErr: true, @@ -172,33 +168,26 @@ func (s) TestUnmarshalEndpoints(t *testing.T) { Version: testVersion, ErrState: &UpdateErrorMetadata{ Version: testVersion, - Err: errPlaceHolder, + Err: cmpopts.AnyError, }, }, wantErr: true, }, { name: "bad endpoints resource", - resources: []*anypb.Any{ - { - TypeUrl: version.V3EndpointsURL, - Value: func() []byte { - clab0 := newClaBuilder("test", nil) - clab0.addLocality("locality-1", 1, 0, []string{"addr1:314"}, nil) - clab0.addLocality("locality-2", 1, 2, []string{"addr2:159"}, nil) - e := clab0.Build() - me, _ := proto.Marshal(e) - return me - }(), - }, - }, - wantUpdate: map[string]EndpointsUpdate{"test": {}}, + resources: []*anypb.Any{testutils.MarshalAny(func() *v3endpointpb.ClusterLoadAssignment { + clab0 := newClaBuilder("test", 
nil) + clab0.addLocality("locality-1", 1, 0, []string{"addr1:314"}, nil) + clab0.addLocality("locality-2", 1, 2, []string{"addr2:159"}, nil) + return clab0.Build() + }())}, + wantUpdate: map[string]EndpointsUpdateErrTuple{"test": {Err: cmpopts.AnyError}}, wantMD: UpdateMetadata{ Status: ServiceStatusNACKed, Version: testVersion, ErrState: &UpdateErrorMetadata{ Version: testVersion, - Err: errPlaceHolder, + Err: cmpopts.AnyError, }, }, wantErr: true, @@ -206,8 +195,8 @@ func (s) TestUnmarshalEndpoints(t *testing.T) { { name: "v3 endpoints", resources: []*anypb.Any{v3EndpointsAny}, - wantUpdate: map[string]EndpointsUpdate{ - "test": { + wantUpdate: map[string]EndpointsUpdateErrTuple{ + "test": {Update: EndpointsUpdate{ Drops: nil, Localities: []Locality{ { @@ -232,7 +221,7 @@ func (s) TestUnmarshalEndpoints(t *testing.T) { }, }, Raw: v3EndpointsAny, - }, + }}, }, wantMD: UpdateMetadata{ Status: ServiceStatusACKed, @@ -244,21 +233,15 @@ func (s) TestUnmarshalEndpoints(t *testing.T) { name: "good and bad endpoints", resources: []*anypb.Any{ v3EndpointsAny, - { - // bad endpoints resource - TypeUrl: version.V3EndpointsURL, - Value: func() []byte { - clab0 := newClaBuilder("bad", nil) - clab0.addLocality("locality-1", 1, 0, []string{"addr1:314"}, nil) - clab0.addLocality("locality-2", 1, 2, []string{"addr2:159"}, nil) - e := clab0.Build() - me, _ := proto.Marshal(e) - return me - }(), - }, + testutils.MarshalAny(func() *v3endpointpb.ClusterLoadAssignment { + clab0 := newClaBuilder("bad", nil) + clab0.addLocality("locality-1", 1, 0, []string{"addr1:314"}, nil) + clab0.addLocality("locality-2", 1, 2, []string{"addr2:159"}, nil) + return clab0.Build() + }()), }, - wantUpdate: map[string]EndpointsUpdate{ - "test": { + wantUpdate: map[string]EndpointsUpdateErrTuple{ + "test": {Update: EndpointsUpdate{ Drops: nil, Localities: []Locality{ { @@ -283,15 +266,15 @@ func (s) TestUnmarshalEndpoints(t *testing.T) { }, }, Raw: v3EndpointsAny, - }, - "bad": {}, + }}, + "bad": {Err: 
cmpopts.AnyError}, }, wantMD: UpdateMetadata{ Status: ServiceStatusNACKed, Version: testVersion, ErrState: &UpdateErrorMetadata{ Version: testVersion, - Err: errPlaceHolder, + Err: cmpopts.AnyError, }, }, wantErr: true, @@ -299,9 +282,13 @@ func (s) TestUnmarshalEndpoints(t *testing.T) { } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - update, md, err := UnmarshalEndpoints(testVersion, test.resources, nil) + opts := &UnmarshalOptions{ + Version: testVersion, + Resources: test.resources, + } + update, md, err := UnmarshalEndpoints(opts) if (err != nil) != test.wantErr { - t.Fatalf("UnmarshalEndpoints(), got err: %v, wantErr: %v", err, test.wantErr) + t.Fatalf("UnmarshalEndpoints(%+v), got err: %v, wantErr: %v", opts, err, test.wantErr) } if diff := cmp.Diff(update, test.wantUpdate, cmpOpts); diff != "" { t.Errorf("got unexpected update, diff (-got +want): %v", diff) diff --git a/xds/internal/client/errors.go b/xds/internal/xdsclient/errors.go similarity index 98% rename from xds/internal/client/errors.go rename to xds/internal/xdsclient/errors.go index 34ae2738db0..4d6cdaaf9b4 100644 --- a/xds/internal/client/errors.go +++ b/xds/internal/xdsclient/errors.go @@ -16,7 +16,7 @@ * */ -package client +package xdsclient import "fmt" diff --git a/xds/internal/xdsclient/filter_chain.go b/xds/internal/xdsclient/filter_chain.go new file mode 100644 index 00000000000..f2b29f52a44 --- /dev/null +++ b/xds/internal/xdsclient/filter_chain.go @@ -0,0 +1,852 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package xdsclient + +import ( + "errors" + "fmt" + "net" + + v3listenerpb "github.com/envoyproxy/go-control-plane/envoy/config/listener/v3" + v3httppb "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/network/http_connection_manager/v3" + v3tlspb "github.com/envoyproxy/go-control-plane/envoy/extensions/transport_sockets/tls/v3" + "github.com/golang/protobuf/proto" + "github.com/golang/protobuf/ptypes" + "google.golang.org/grpc/internal/resolver" + "google.golang.org/grpc/internal/xds/env" + "google.golang.org/grpc/xds/internal/httpfilter" + "google.golang.org/grpc/xds/internal/version" +) + +const ( + // Used as the map key for unspecified prefixes. The actual value of this + // key is immaterial. + unspecifiedPrefixMapKey = "unspecified" + + // An unspecified destination or source prefix should be considered a less + // specific match than a wildcard prefix, `0.0.0.0/0` or `::/0`. Also, an + // unspecified prefix should match most v4 and v6 addresses compared to the + // wildcard prefixes which match only a specific network (v4 or v6). + // + // We use these constants when looking up the most specific prefix match. A + // wildcard prefix will match 0 bits, and to make sure that a wildcard + // prefix is considered a more specific match than an unspecified prefix, we + // use a value of -1 for the latter. + noPrefixMatch = -2 + unspecifiedPrefixMatch = -1 +) + +// FilterChain captures information from within a FilterChain message in a +// Listener resource. +type FilterChain struct { + // SecurityCfg contains transport socket security configuration. + SecurityCfg *SecurityConfig + // HTTPFilters represent the HTTP Filters that comprise this FilterChain. + HTTPFilters []HTTPFilter + // RouteConfigName is the route configuration name for this FilterChain. + // + // Only one of RouteConfigName and InlineRouteConfig is set. 
+ RouteConfigName string + // InlineRouteConfig is the inline route configuration (RDS response) + // returned for this filter chain. + // + // Only one of RouteConfigName and InlineRouteConfig is set. + InlineRouteConfig *RouteConfigUpdate +} + +// VirtualHostWithInterceptors captures information present in a VirtualHost +// update, and also contains routes with instantiated HTTP Filters. +type VirtualHostWithInterceptors struct { + // Domains are the domain names which map to this Virtual Host. On the + // server side, this will be dictated by the :authority header of the + // incoming RPC. + Domains []string + // Routes are the Routes for this Virtual Host. + Routes []RouteWithInterceptors +} + +// RouteWithInterceptors captures information in a Route, and contains +// a usable matcher and also instantiated HTTP Filters. +type RouteWithInterceptors struct { + // M is the matcher used to match to this route. + M *CompositeMatcher + // RouteAction is the type of routing action to initiate once matched to. + RouteAction RouteAction + // Interceptors are interceptors instantiated for this route. These will be + // constructed from a combination of the top level configuration and any + // HTTP Filter overrides present in Virtual Host or Route. + Interceptors []resolver.ServerInterceptor +} + +// ConstructUsableRouteConfiguration takes Route Configuration and converts it +// into matchable route configuration, with instantiated HTTP Filters per route. 
+func (f *FilterChain) ConstructUsableRouteConfiguration(config RouteConfigUpdate) ([]VirtualHostWithInterceptors, error) { + vhs := make([]VirtualHostWithInterceptors, len(config.VirtualHosts)) + for _, vh := range config.VirtualHosts { + vhwi, err := f.convertVirtualHost(vh) + if err != nil { + return nil, fmt.Errorf("virtual host construction: %v", err) + } + vhs = append(vhs, vhwi) + } + return vhs, nil +} + +func (f *FilterChain) convertVirtualHost(virtualHost *VirtualHost) (VirtualHostWithInterceptors, error) { + rs := make([]RouteWithInterceptors, len(virtualHost.Routes)) + for i, r := range virtualHost.Routes { + var err error + rs[i].RouteAction = r.RouteAction + rs[i].M, err = RouteToMatcher(r) + if err != nil { + return VirtualHostWithInterceptors{}, fmt.Errorf("matcher construction: %v", err) + } + for _, filter := range f.HTTPFilters { + // Route is highest priority on server side, as there is no concept + // of an upstream cluster on server side. + override := r.HTTPFilterConfigOverride[filter.Name] + if override == nil { + // Virtual Host is second priority. + override = virtualHost.HTTPFilterConfigOverride[filter.Name] + } + sb, ok := filter.Filter.(httpfilter.ServerInterceptorBuilder) + if !ok { + // Should not happen if it passed xdsClient validation. + return VirtualHostWithInterceptors{}, fmt.Errorf("filter does not support use in server") + } + si, err := sb.BuildServerInterceptor(filter.Config, override) + if err != nil { + return VirtualHostWithInterceptors{}, fmt.Errorf("filter construction: %v", err) + } + if si != nil { + rs[i].Interceptors = append(rs[i].Interceptors, si) + } + } + } + return VirtualHostWithInterceptors{Domains: virtualHost.Domains, Routes: rs}, nil +} + +// SourceType specifies the connection source IP match type. +type SourceType int + +const ( + // SourceTypeAny matches connection attempts from any source. + SourceTypeAny SourceType = iota + // SourceTypeSameOrLoopback matches connection attempts from the same host. 
+ SourceTypeSameOrLoopback + // SourceTypeExternal matches connection attempts from a different host. + SourceTypeExternal +) + +// FilterChainManager contains all the match criteria specified through all +// filter chains in a single Listener resource. It also contains the default +// filter chain specified in the Listener resource. It provides two important +// pieces of functionality: +// 1. Validate the filter chains in an incoming Listener resource to make sure +// that there aren't filter chains which contain the same match criteria. +// 2. As part of performing the above validation, it builds an internal data +// structure which will if used to look up the matching filter chain at +// connection time. +// +// The logic specified in the documentation around the xDS FilterChainMatch +// proto mentions 8 criteria to match on. +// The following order applies: +// +// 1. Destination port. +// 2. Destination IP address. +// 3. Server name (e.g. SNI for TLS protocol), +// 4. Transport protocol. +// 5. Application protocols (e.g. ALPN for TLS protocol). +// 6. Source type (e.g. any, local or external network). +// 7. Source IP address. +// 8. Source port. +type FilterChainManager struct { + // Destination prefix is the first match criteria that we support. + // Therefore, this multi-stage map is indexed on destination prefixes + // specified in the match criteria. + // Unspecified destination prefix matches end up as a wildcard entry here + // with a key of 0.0.0.0/0. + dstPrefixMap map[string]*destPrefixEntry + + // At connection time, we do not have the actual destination prefix to match + // on. We only have the real destination address of the incoming connection. + // This means that we cannot use the above map at connection time. This list + // contains the map entries from the above map that we can use at connection + // time to find matching destination prefixes in O(n) time. + // + // TODO: Implement LC-trie to support logarithmic time lookups. 
If that + // involves too much time/effort, sort this slice based on the netmask size. + dstPrefixes []*destPrefixEntry + + def *FilterChain // Default filter chain, if specified. + + // RouteConfigNames are the route configuration names which need to be + // dynamically queried for RDS Configuration for any FilterChains which + // specify to load RDS Configuration dynamically. + RouteConfigNames map[string]bool +} + +// destPrefixEntry is the value type of the map indexed on destination prefixes. +type destPrefixEntry struct { + // The actual destination prefix. Set to nil for unspecified prefixes. + net *net.IPNet + // We need to keep track of the transport protocols seen as part of the + // config validation (and internal structure building) phase. The only two + // values that we support are empty string and "raw_buffer", with the latter + // taking preference. Once we have seen one filter chain with "raw_buffer", + // we can drop everything other filter chain with an empty transport + // protocol. + rawBufferSeen bool + // For each specified source type in the filter chain match criteria, this + // array points to the set of specified source prefixes. + // Unspecified source type matches end up as a wildcard entry here with an + // index of 0, which actually represents the source type `ANY`. + srcTypeArr sourceTypesArray +} + +// An array for the fixed number of source types that we have. +type sourceTypesArray [3]*sourcePrefixes + +// sourcePrefixes contains source prefix related information specified in the +// match criteria. These are pointed to by the array of source types. +type sourcePrefixes struct { + // These are very similar to the 'dstPrefixMap' and 'dstPrefixes' field of + // FilterChainManager. Go there for more info. + srcPrefixMap map[string]*sourcePrefixEntry + srcPrefixes []*sourcePrefixEntry +} + +// sourcePrefixEntry contains match criteria per source prefix. +type sourcePrefixEntry struct { + // The actual destination prefix. 
Set to nil for unspecified prefixes. + net *net.IPNet + // Mapping from source ports specified in the match criteria to the actual + // filter chain. Unspecified source port matches en up as a wildcard entry + // here with a key of 0. + srcPortMap map[int]*FilterChain +} + +// NewFilterChainManager parses the received Listener resource and builds a +// FilterChainManager. Returns a non-nil error on validation failures. +// +// This function is only exported so that tests outside of this package can +// create a FilterChainManager. +func NewFilterChainManager(lis *v3listenerpb.Listener) (*FilterChainManager, error) { + // Parse all the filter chains and build the internal data structures. + fci := &FilterChainManager{ + dstPrefixMap: make(map[string]*destPrefixEntry), + RouteConfigNames: make(map[string]bool), + } + if err := fci.addFilterChains(lis.GetFilterChains()); err != nil { + return nil, err + } + // Build the source and dest prefix slices used by Lookup(). + fcSeen := false + for _, dstPrefix := range fci.dstPrefixMap { + fci.dstPrefixes = append(fci.dstPrefixes, dstPrefix) + for _, st := range dstPrefix.srcTypeArr { + if st == nil { + continue + } + for _, srcPrefix := range st.srcPrefixMap { + st.srcPrefixes = append(st.srcPrefixes, srcPrefix) + for _, fc := range srcPrefix.srcPortMap { + if fc != nil { + fcSeen = true + } + } + } + } + } + + // Retrieve the default filter chain. The match criteria specified on the + // default filter chain is never used. The default filter chain simply gets + // used when none of the other filter chains match. + var def *FilterChain + if dfc := lis.GetDefaultFilterChain(); dfc != nil { + var err error + if def, err = fci.filterChainFromProto(dfc); err != nil { + return nil, err + } + } + fci.def = def + + // If there are no supported filter chains and no default filter chain, we + // fail here. This will call the Listener resource to be NACK'ed. 
+ if !fcSeen && fci.def == nil { + return nil, fmt.Errorf("no supported filter chains and no default filter chain") + } + return fci, nil +} + +// addFilterChains parses the filter chains in fcs and adds the required +// internal data structures corresponding to the match criteria. +func (fci *FilterChainManager) addFilterChains(fcs []*v3listenerpb.FilterChain) error { + for _, fc := range fcs { + fcm := fc.GetFilterChainMatch() + if fcm.GetDestinationPort().GetValue() != 0 { + // Destination port is the first match criteria and we do not + // support filter chains which contains this match criteria. + logger.Warningf("Dropping filter chain %+v since it contains unsupported destination_port match field", fc) + continue + } + + // Build the internal representation of the filter chain match fields. + if err := fci.addFilterChainsForDestPrefixes(fc); err != nil { + return err + } + } + + return nil +} + +func (fci *FilterChainManager) addFilterChainsForDestPrefixes(fc *v3listenerpb.FilterChain) error { + ranges := fc.GetFilterChainMatch().GetPrefixRanges() + dstPrefixes := make([]*net.IPNet, 0, len(ranges)) + for _, pr := range ranges { + cidr := fmt.Sprintf("%s/%d", pr.GetAddressPrefix(), pr.GetPrefixLen().GetValue()) + _, ipnet, err := net.ParseCIDR(cidr) + if err != nil { + return fmt.Errorf("failed to parse destination prefix range: %+v", pr) + } + dstPrefixes = append(dstPrefixes, ipnet) + } + + if len(dstPrefixes) == 0 { + // Use the unspecified entry when destination prefix is unspecified, and + // set the `net` field to nil. 
+ if fci.dstPrefixMap[unspecifiedPrefixMapKey] == nil { + fci.dstPrefixMap[unspecifiedPrefixMapKey] = &destPrefixEntry{} + } + return fci.addFilterChainsForServerNames(fci.dstPrefixMap[unspecifiedPrefixMapKey], fc) + } + for _, prefix := range dstPrefixes { + p := prefix.String() + if fci.dstPrefixMap[p] == nil { + fci.dstPrefixMap[p] = &destPrefixEntry{net: prefix} + } + if err := fci.addFilterChainsForServerNames(fci.dstPrefixMap[p], fc); err != nil { + return err + } + } + return nil +} + +func (fci *FilterChainManager) addFilterChainsForServerNames(dstEntry *destPrefixEntry, fc *v3listenerpb.FilterChain) error { + // Filter chains specifying server names in their match criteria always fail + // a match at connection time. So, these filter chains can be dropped now. + if len(fc.GetFilterChainMatch().GetServerNames()) != 0 { + logger.Warningf("Dropping filter chain %+v since it contains unsupported server_names match field", fc) + return nil + } + + return fci.addFilterChainsForTransportProtocols(dstEntry, fc) +} + +func (fci *FilterChainManager) addFilterChainsForTransportProtocols(dstEntry *destPrefixEntry, fc *v3listenerpb.FilterChain) error { + tp := fc.GetFilterChainMatch().GetTransportProtocol() + switch { + case tp != "" && tp != "raw_buffer": + // Only allow filter chains with transport protocol set to empty string + // or "raw_buffer". + logger.Warningf("Dropping filter chain %+v since it contains unsupported value for transport_protocols match field", fc) + return nil + case tp == "" && dstEntry.rawBufferSeen: + // If we have already seen filter chains with transport protocol set to + // "raw_buffer", we can drop filter chains with transport protocol set + // to empty string, since the former takes precedence. + logger.Warningf("Dropping filter chain %+v since it contains unsupported value for transport_protocols match field", fc) + return nil + case tp != "" && !dstEntry.rawBufferSeen: + // This is the first "raw_buffer" that we are seeing. 
Set the bit and + // reset the source types array which might contain entries for filter + // chains with transport protocol set to empty string. + dstEntry.rawBufferSeen = true + dstEntry.srcTypeArr = sourceTypesArray{} + } + return fci.addFilterChainsForApplicationProtocols(dstEntry, fc) +} + +func (fci *FilterChainManager) addFilterChainsForApplicationProtocols(dstEntry *destPrefixEntry, fc *v3listenerpb.FilterChain) error { + if len(fc.GetFilterChainMatch().GetApplicationProtocols()) != 0 { + logger.Warningf("Dropping filter chain %+v since it contains unsupported application_protocols match field", fc) + return nil + } + return fci.addFilterChainsForSourceType(dstEntry, fc) +} + +// addFilterChainsForSourceType adds source types to the internal data +// structures and delegates control to addFilterChainsForSourcePrefixes to +// continue building the internal data structure. +func (fci *FilterChainManager) addFilterChainsForSourceType(dstEntry *destPrefixEntry, fc *v3listenerpb.FilterChain) error { + var srcType SourceType + switch st := fc.GetFilterChainMatch().GetSourceType(); st { + case v3listenerpb.FilterChainMatch_ANY: + srcType = SourceTypeAny + case v3listenerpb.FilterChainMatch_SAME_IP_OR_LOOPBACK: + srcType = SourceTypeSameOrLoopback + case v3listenerpb.FilterChainMatch_EXTERNAL: + srcType = SourceTypeExternal + default: + return fmt.Errorf("unsupported source type: %v", st) + } + + st := int(srcType) + if dstEntry.srcTypeArr[st] == nil { + dstEntry.srcTypeArr[st] = &sourcePrefixes{srcPrefixMap: make(map[string]*sourcePrefixEntry)} + } + return fci.addFilterChainsForSourcePrefixes(dstEntry.srcTypeArr[st].srcPrefixMap, fc) +} + +// addFilterChainsForSourcePrefixes adds source prefixes to the internal data +// structures and delegates control to addFilterChainsForSourcePorts to continue +// building the internal data structure. 
+func (fci *FilterChainManager) addFilterChainsForSourcePrefixes(srcPrefixMap map[string]*sourcePrefixEntry, fc *v3listenerpb.FilterChain) error { + ranges := fc.GetFilterChainMatch().GetSourcePrefixRanges() + srcPrefixes := make([]*net.IPNet, 0, len(ranges)) + for _, pr := range fc.GetFilterChainMatch().GetSourcePrefixRanges() { + cidr := fmt.Sprintf("%s/%d", pr.GetAddressPrefix(), pr.GetPrefixLen().GetValue()) + _, ipnet, err := net.ParseCIDR(cidr) + if err != nil { + return fmt.Errorf("failed to parse source prefix range: %+v", pr) + } + srcPrefixes = append(srcPrefixes, ipnet) + } + + if len(srcPrefixes) == 0 { + // Use the unspecified entry when source prefix is unspecified, and + // set the `net` field to nil. + if srcPrefixMap[unspecifiedPrefixMapKey] == nil { + srcPrefixMap[unspecifiedPrefixMapKey] = &sourcePrefixEntry{ + srcPortMap: make(map[int]*FilterChain), + } + } + return fci.addFilterChainsForSourcePorts(srcPrefixMap[unspecifiedPrefixMapKey], fc) + } + for _, prefix := range srcPrefixes { + p := prefix.String() + if srcPrefixMap[p] == nil { + srcPrefixMap[p] = &sourcePrefixEntry{ + net: prefix, + srcPortMap: make(map[int]*FilterChain), + } + } + if err := fci.addFilterChainsForSourcePorts(srcPrefixMap[p], fc); err != nil { + return err + } + } + return nil +} + +// addFilterChainsForSourcePorts adds source ports to the internal data +// structures and completes the process of building the internal data structure. +// It is here that we determine if there are multiple filter chains with +// overlapping matching rules. 
+func (fci *FilterChainManager) addFilterChainsForSourcePorts(srcEntry *sourcePrefixEntry, fcProto *v3listenerpb.FilterChain) error { + ports := fcProto.GetFilterChainMatch().GetSourcePorts() + srcPorts := make([]int, 0, len(ports)) + for _, port := range ports { + srcPorts = append(srcPorts, int(port)) + } + + fc, err := fci.filterChainFromProto(fcProto) + if err != nil { + return err + } + + if len(srcPorts) == 0 { + // Use the wildcard port '0', when source ports are unspecified. + if curFC := srcEntry.srcPortMap[0]; curFC != nil { + return errors.New("multiple filter chains with overlapping matching rules are defined") + } + srcEntry.srcPortMap[0] = fc + return nil + } + for _, port := range srcPorts { + if curFC := srcEntry.srcPortMap[port]; curFC != nil { + return errors.New("multiple filter chains with overlapping matching rules are defined") + } + srcEntry.srcPortMap[port] = fc + } + return nil +} + +// filterChainFromProto extracts the relevant information from the FilterChain +// proto and stores it in our internal representation. It also persists any +// RouteNames which need to be queried dynamically via RDS. +func (fci *FilterChainManager) filterChainFromProto(fc *v3listenerpb.FilterChain) (*FilterChain, error) { + filterChain, err := processNetworkFilters(fc.GetFilters()) + if err != nil { + return nil, err + } + // These route names will be dynamically queried via RDS in the wrapped + // listener, which receives the LDS response, if specified for the filter + // chain. + if filterChain.RouteConfigName != "" { + fci.RouteConfigNames[filterChain.RouteConfigName] = true + } + // If the transport_socket field is not specified, it means that the control + // plane has not sent us any security config. This is fine and the server + // will use the fallback credentials configured as part of the + // xdsCredentials. 
+ ts := fc.GetTransportSocket() + if ts == nil { + return filterChain, nil + } + if name := ts.GetName(); name != transportSocketName { + return nil, fmt.Errorf("transport_socket field has unexpected name: %s", name) + } + any := ts.GetTypedConfig() + if any == nil || any.TypeUrl != version.V3DownstreamTLSContextURL { + return nil, fmt.Errorf("transport_socket field has unexpected typeURL: %s", any.TypeUrl) + } + downstreamCtx := &v3tlspb.DownstreamTlsContext{} + if err := proto.Unmarshal(any.GetValue(), downstreamCtx); err != nil { + return nil, fmt.Errorf("failed to unmarshal DownstreamTlsContext in LDS response: %v", err) + } + if downstreamCtx.GetRequireSni().GetValue() { + return nil, fmt.Errorf("require_sni field set to true in DownstreamTlsContext message: %v", downstreamCtx) + } + if downstreamCtx.GetOcspStaplePolicy() != v3tlspb.DownstreamTlsContext_LENIENT_STAPLING { + return nil, fmt.Errorf("ocsp_staple_policy field set to unsupported value in DownstreamTlsContext message: %v", downstreamCtx) + } + // The following fields from `DownstreamTlsContext` are ignored: + // - disable_stateless_session_resumption + // - session_ticket_keys + // - session_ticket_keys_sds_secret_config + // - session_timeout + if downstreamCtx.GetCommonTlsContext() == nil { + return nil, errors.New("DownstreamTlsContext in LDS response does not contain a CommonTlsContext") + } + sc, err := securityConfigFromCommonTLSContext(downstreamCtx.GetCommonTlsContext(), true) + if err != nil { + return nil, err + } + if sc == nil { + // sc == nil is a valid case where the control plane has not sent us any + // security configuration. xDS creds will use fallback creds. 
+ return filterChain, nil + } + sc.RequireClientCert = downstreamCtx.GetRequireClientCertificate().GetValue() + if sc.RequireClientCert && sc.RootInstanceName == "" { + return nil, errors.New("security configuration on the server-side does not contain root certificate provider instance name, but require_client_cert field is set") + } + filterChain.SecurityCfg = sc + return filterChain, nil +} + +func processNetworkFilters(filters []*v3listenerpb.Filter) (*FilterChain, error) { + filterChain := &FilterChain{} + seenNames := make(map[string]bool, len(filters)) + seenHCM := false + for _, filter := range filters { + name := filter.GetName() + if name == "" { + return nil, fmt.Errorf("network filters {%+v} is missing name field in filter: {%+v}", filters, filter) + } + if seenNames[name] { + return nil, fmt.Errorf("network filters {%+v} has duplicate filter name %q", filters, name) + } + seenNames[name] = true + + // Network filters have a oneof field named `config_type` where we + // only support `TypedConfig` variant. + switch typ := filter.GetConfigType().(type) { + case *v3listenerpb.Filter_TypedConfig: + // The typed_config field has an `anypb.Any` proto which could + // directly contain the serialized bytes of the actual filter + // configuration, or it could be encoded as a `TypedStruct`. + // TODO: Add support for `TypedStruct`. + tc := filter.GetTypedConfig() + + // The only network filter that we currently support is the v3 + // HttpConnectionManager. So, we can directly check the type_url + // and unmarshal the config. + // TODO: Implement a registry of supported network filters (like + // we have for HTTP filters), when we have to support network + // filters other than HttpConnectionManager. 
+ if tc.GetTypeUrl() != version.V3HTTPConnManagerURL { + return nil, fmt.Errorf("network filters {%+v} has unsupported network filter %q in filter {%+v}", filters, tc.GetTypeUrl(), filter) + } + hcm := &v3httppb.HttpConnectionManager{} + if err := ptypes.UnmarshalAny(tc, hcm); err != nil { + return nil, fmt.Errorf("network filters {%+v} failed unmarshaling of network filter {%+v}: %v", filters, filter, err) + } + // "Any filters after HttpConnectionManager should be ignored during + // connection processing but still be considered for validity. + // HTTPConnectionManager must have valid http_filters." - A36 + filters, err := processHTTPFilters(hcm.GetHttpFilters(), true) + if err != nil { + return nil, fmt.Errorf("network filters {%+v} had invalid server side HTTP Filters {%+v}: %v", filters, hcm.GetHttpFilters(), err) + } + if !seenHCM { + // Validate for RBAC in only the HCM that will be used, since this isn't a logical validation failure, + // it's simply a validation to support RBAC HTTP Filter. + // "HttpConnectionManager.xff_num_trusted_hops must be unset or zero and + // HttpConnectionManager.original_ip_detection_extensions must be empty. If + // either field has an incorrect value, the Listener must be NACKed." - A41 + if hcm.XffNumTrustedHops != 0 { + return nil, fmt.Errorf("xff_num_trusted_hops must be unset or zero %+v", hcm) + } + if len(hcm.OriginalIpDetectionExtensions) != 0 { + return nil, fmt.Errorf("original_ip_detection_extensions must be empty %+v", hcm) + } + + // TODO: Implement terminal filter logic, as per A36. 
+ filterChain.HTTPFilters = filters + seenHCM = true + if !env.RBACSupport { + continue + } + switch hcm.RouteSpecifier.(type) { + case *v3httppb.HttpConnectionManager_Rds: + if hcm.GetRds().GetConfigSource().GetAds() == nil { + return nil, fmt.Errorf("ConfigSource is not ADS: %+v", hcm) + } + name := hcm.GetRds().GetRouteConfigName() + if name == "" { + return nil, fmt.Errorf("empty route_config_name: %+v", hcm) + } + filterChain.RouteConfigName = name + case *v3httppb.HttpConnectionManager_RouteConfig: + // "RouteConfiguration validation logic inherits all + // previous validations made for client-side usage as RDS + // does not distinguish between client-side and + // server-side." - A36 + // Can specify v3 here, as will never get to this function + // if v2. + routeU, err := generateRDSUpdateFromRouteConfiguration(hcm.GetRouteConfig(), nil, false) + if err != nil { + return nil, fmt.Errorf("failed to parse inline RDS resp: %v", err) + } + filterChain.InlineRouteConfig = &routeU + case nil: + // No-op, as no route specifier is a valid configuration on + // the server side. + default: + return nil, fmt.Errorf("unsupported type %T for RouteSpecifier", hcm.RouteSpecifier) + } + } + default: + return nil, fmt.Errorf("network filters {%+v} has unsupported config_type %T in filter %s", filters, typ, filter.GetName()) + } + } + if !seenHCM { + return nil, fmt.Errorf("network filters {%+v} missing HttpConnectionManager filter", filters) + } + return filterChain, nil +} + +// FilterChainLookupParams wraps parameters to be passed to Lookup. +type FilterChainLookupParams struct { + // IsUnspecified indicates whether the server is listening on a wildcard + // address, "0.0.0.0" for IPv4 and "::" for IPv6. Only when this is set to + // true, do we consider the destination prefixes specified in the filter + // chain match criteria. + IsUnspecifiedListener bool + // DestAddr is the local address of an incoming connection. 
+ DestAddr net.IP + // SourceAddr is the remote address of an incoming connection. + SourceAddr net.IP + // SourcePort is the remote port of an incoming connection. + SourcePort int +} + +// Lookup returns the most specific matching filter chain to be used for an +// incoming connection on the server side. +// +// Returns a non-nil error if no matching filter chain could be found or +// multiple matching filter chains were found, and in both cases, the incoming +// connection must be dropped. +func (fci *FilterChainManager) Lookup(params FilterChainLookupParams) (*FilterChain, error) { + dstPrefixes := filterByDestinationPrefixes(fci.dstPrefixes, params.IsUnspecifiedListener, params.DestAddr) + if len(dstPrefixes) == 0 { + if fci.def != nil { + return fci.def, nil + } + return nil, fmt.Errorf("no matching filter chain based on destination prefix match for %+v", params) + } + + srcType := SourceTypeExternal + if params.SourceAddr.Equal(params.DestAddr) || params.SourceAddr.IsLoopback() { + srcType = SourceTypeSameOrLoopback + } + srcPrefixes := filterBySourceType(dstPrefixes, srcType) + if len(srcPrefixes) == 0 { + if fci.def != nil { + return fci.def, nil + } + return nil, fmt.Errorf("no matching filter chain based on source type match for %+v", params) + } + srcPrefixEntry, err := filterBySourcePrefixes(srcPrefixes, params.SourceAddr) + if err != nil { + return nil, err + } + if fc := filterBySourcePorts(srcPrefixEntry, params.SourcePort); fc != nil { + return fc, nil + } + if fci.def != nil { + return fci.def, nil + } + return nil, fmt.Errorf("no matching filter chain after all match criteria for %+v", params) +} + +// filterByDestinationPrefixes is the first stage of the filter chain +// matching algorithm. It takes the complete set of configured filter chain +// matchers and returns the most specific matchers based on the destination +// prefix match criteria (the prefixes which match the most number of bits). 
+func filterByDestinationPrefixes(dstPrefixes []*destPrefixEntry, isUnspecified bool, dstAddr net.IP) []*destPrefixEntry { + if !isUnspecified { + // Destination prefix matchers are considered only when the listener is + // bound to the wildcard address. + return dstPrefixes + } + + var matchingDstPrefixes []*destPrefixEntry + maxSubnetMatch := noPrefixMatch + for _, prefix := range dstPrefixes { + if prefix.net != nil && !prefix.net.Contains(dstAddr) { + // Skip prefixes which don't match. + continue + } + // For unspecified prefixes, since we do not store a real net.IPNet + // inside prefix, we do not perform a match. Instead we simply set + // the matchSize to -1, which is less than the matchSize (0) for a + // wildcard prefix. + matchSize := unspecifiedPrefixMatch + if prefix.net != nil { + matchSize, _ = prefix.net.Mask.Size() + } + if matchSize < maxSubnetMatch { + continue + } + if matchSize > maxSubnetMatch { + maxSubnetMatch = matchSize + matchingDstPrefixes = make([]*destPrefixEntry, 0, 1) + } + matchingDstPrefixes = append(matchingDstPrefixes, prefix) + } + return matchingDstPrefixes +} + +// filterBySourceType is the second stage of the matching algorithm. It +// trims the filter chains based on the most specific source type match. 
+func filterBySourceType(dstPrefixes []*destPrefixEntry, srcType SourceType) []*sourcePrefixes { + var ( + srcPrefixes []*sourcePrefixes + bestSrcTypeMatch int + ) + for _, prefix := range dstPrefixes { + var ( + srcPrefix *sourcePrefixes + match int + ) + switch srcType { + case SourceTypeExternal: + match = int(SourceTypeExternal) + srcPrefix = prefix.srcTypeArr[match] + case SourceTypeSameOrLoopback: + match = int(SourceTypeSameOrLoopback) + srcPrefix = prefix.srcTypeArr[match] + } + if srcPrefix == nil { + match = int(SourceTypeAny) + srcPrefix = prefix.srcTypeArr[match] + } + if match < bestSrcTypeMatch { + continue + } + if match > bestSrcTypeMatch { + bestSrcTypeMatch = match + srcPrefixes = make([]*sourcePrefixes, 0) + } + if srcPrefix != nil { + // The source type array always has 3 entries, but these could be + // nil if the appropriate source type match was not specified. + srcPrefixes = append(srcPrefixes, srcPrefix) + } + } + return srcPrefixes +} + +// filterBySourcePrefixes is the third stage of the filter chain matching +// algorithm. It trims the filter chains based on the source prefix. At most one +// filter chain with the most specific match progress to the next stage. +func filterBySourcePrefixes(srcPrefixes []*sourcePrefixes, srcAddr net.IP) (*sourcePrefixEntry, error) { + var matchingSrcPrefixes []*sourcePrefixEntry + maxSubnetMatch := noPrefixMatch + for _, sp := range srcPrefixes { + for _, prefix := range sp.srcPrefixes { + if prefix.net != nil && !prefix.net.Contains(srcAddr) { + // Skip prefixes which don't match. + continue + } + // For unspecified prefixes, since we do not store a real net.IPNet + // inside prefix, we do not perform a match. Instead we simply set + // the matchSize to -1, which is less than the matchSize (0) for a + // wildcard prefix. 
+ matchSize := unspecifiedPrefixMatch + if prefix.net != nil { + matchSize, _ = prefix.net.Mask.Size() + } + if matchSize < maxSubnetMatch { + continue + } + if matchSize > maxSubnetMatch { + maxSubnetMatch = matchSize + matchingSrcPrefixes = make([]*sourcePrefixEntry, 0, 1) + } + matchingSrcPrefixes = append(matchingSrcPrefixes, prefix) + } + } + if len(matchingSrcPrefixes) == 0 { + // Finding no match is not an error condition. The caller will end up + // using the default filter chain if one was configured. + return nil, nil + } + // We expect at most a single matching source prefix entry at this point. If + // we have multiple entries here, and some of their source port matchers had + // wildcard entries, we could be left with more than one matching filter + // chain and hence would have been flagged as an invalid configuration at + // config validation time. + if len(matchingSrcPrefixes) != 1 { + return nil, errors.New("multiple matching filter chains") + } + return matchingSrcPrefixes[0], nil +} + +// filterBySourcePorts is the last stage of the filter chain matching +// algorithm. It trims the filter chains based on the source ports. +func filterBySourcePorts(spe *sourcePrefixEntry, srcPort int) *FilterChain { + if spe == nil { + return nil + } + // A match could be a wildcard match (this happens when the match + // criteria does not specify source ports) or a specific port match (this + // happens when the match criteria specifies a set of ports and the source + // port of the incoming connection matches one of the specified ports). The + // latter is considered to be a more specific match. 
+ if fc := spe.srcPortMap[srcPort]; fc != nil { + return fc + } + if fc := spe.srcPortMap[0]; fc != nil { + return fc + } + return nil +} diff --git a/xds/internal/xdsclient/filter_chain_test.go b/xds/internal/xdsclient/filter_chain_test.go new file mode 100644 index 00000000000..2cc73b0a511 --- /dev/null +++ b/xds/internal/xdsclient/filter_chain_test.go @@ -0,0 +1,2939 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package xdsclient + +import ( + "context" + "errors" + "fmt" + "net" + "strings" + "testing" + + v3corepb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" + v3listenerpb "github.com/envoyproxy/go-control-plane/envoy/config/listener/v3" + v3routepb "github.com/envoyproxy/go-control-plane/envoy/config/route/v3" + v3routerpb "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/http/router/v3" + v3httppb "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/network/http_connection_manager/v3" + v3tlspb "github.com/envoyproxy/go-control-plane/envoy/extensions/transport_sockets/tls/v3" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "google.golang.org/protobuf/testing/protocmp" + "google.golang.org/protobuf/types/known/anypb" + "google.golang.org/protobuf/types/known/wrapperspb" + + iresolver "google.golang.org/grpc/internal/resolver" + "google.golang.org/grpc/internal/testutils" + "google.golang.org/grpc/internal/xds/env" + 
"google.golang.org/grpc/xds/internal/httpfilter" + "google.golang.org/grpc/xds/internal/httpfilter/router" + "google.golang.org/grpc/xds/internal/testutils/e2e" + "google.golang.org/grpc/xds/internal/version" +) + +const ( + topLevel = "top level" + vhLevel = "virtual host level" + rLevel = "route level" +) + +var ( + routeConfig = &v3routepb.RouteConfiguration{ + Name: "routeName", + VirtualHosts: []*v3routepb.VirtualHost{{ + Domains: []string{"lds.target.good:3333"}, + Routes: []*v3routepb.Route{{ + Match: &v3routepb.RouteMatch{ + PathSpecifier: &v3routepb.RouteMatch_Prefix{Prefix: "/"}, + }, + Action: &v3routepb.Route_NonForwardingAction{}, + }}}}} + inlineRouteConfig = &RouteConfigUpdate{ + VirtualHosts: []*VirtualHost{{ + Domains: []string{"lds.target.good:3333"}, + Routes: []*Route{{Prefix: newStringP("/"), RouteAction: RouteActionNonForwardingAction}}, + }}} + emptyValidNetworkFilters = []*v3listenerpb.Filter{ + { + Name: "filter-1", + ConfigType: &v3listenerpb.Filter_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3httppb.HttpConnectionManager{ + RouteSpecifier: &v3httppb.HttpConnectionManager_RouteConfig{ + RouteConfig: routeConfig, + }, + HttpFilters: []*v3httppb.HttpFilter{emptyRouterFilter}, + }), + }, + }, + } + validServerSideHTTPFilter1 = &v3httppb.HttpFilter{ + Name: "serverOnlyCustomFilter", + ConfigType: &v3httppb.HttpFilter_TypedConfig{TypedConfig: serverOnlyCustomFilterConfig}, + } + validServerSideHTTPFilter2 = &v3httppb.HttpFilter{ + Name: "serverOnlyCustomFilter2", + ConfigType: &v3httppb.HttpFilter_TypedConfig{TypedConfig: serverOnlyCustomFilterConfig}, + } + emptyRouterFilter = e2e.RouterHTTPFilter + routerBuilder = httpfilter.Get(router.TypeURL) + routerConfig, _ = routerBuilder.ParseFilterConfig(testutils.MarshalAny(&v3routerpb.Router{})) + routerFilter = HTTPFilter{Name: "router", Filter: routerBuilder, Config: routerConfig} + routerFilterList = []HTTPFilter{routerFilter} +) + +// TestNewFilterChainImpl_Failure_BadMatchFields verifies 
cases where we have a +// single filter chain with match criteria that contains unsupported fields. +func TestNewFilterChainImpl_Failure_BadMatchFields(t *testing.T) { + tests := []struct { + desc string + lis *v3listenerpb.Listener + }{ + { + desc: "unsupported destination port field", + lis: &v3listenerpb.Listener{ + FilterChains: []*v3listenerpb.FilterChain{ + { + FilterChainMatch: &v3listenerpb.FilterChainMatch{DestinationPort: &wrapperspb.UInt32Value{Value: 666}}, + }, + }, + }, + }, + { + desc: "unsupported server names field", + lis: &v3listenerpb.Listener{ + FilterChains: []*v3listenerpb.FilterChain{ + { + FilterChainMatch: &v3listenerpb.FilterChainMatch{ServerNames: []string{"example-server"}}, + }, + }, + }, + }, + { + desc: "unsupported transport protocol ", + lis: &v3listenerpb.Listener{ + FilterChains: []*v3listenerpb.FilterChain{ + { + FilterChainMatch: &v3listenerpb.FilterChainMatch{TransportProtocol: "tls"}, + }, + }, + }, + }, + { + desc: "unsupported application protocol field", + lis: &v3listenerpb.Listener{ + FilterChains: []*v3listenerpb.FilterChain{ + { + FilterChainMatch: &v3listenerpb.FilterChainMatch{ApplicationProtocols: []string{"h2"}}, + }, + }, + }, + }, + { + desc: "bad dest address prefix", + lis: &v3listenerpb.Listener{ + FilterChains: []*v3listenerpb.FilterChain{ + { + FilterChainMatch: &v3listenerpb.FilterChainMatch{PrefixRanges: []*v3corepb.CidrRange{{AddressPrefix: "a.b.c.d"}}}, + }, + }, + }, + }, + { + desc: "bad dest prefix length", + lis: &v3listenerpb.Listener{ + FilterChains: []*v3listenerpb.FilterChain{ + { + FilterChainMatch: &v3listenerpb.FilterChainMatch{PrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("10.1.1.0", 50)}}, + }, + }, + }, + }, + { + desc: "bad source address prefix", + lis: &v3listenerpb.Listener{ + FilterChains: []*v3listenerpb.FilterChain{ + { + FilterChainMatch: &v3listenerpb.FilterChainMatch{SourcePrefixRanges: []*v3corepb.CidrRange{{AddressPrefix: "a.b.c.d"}}}, + }, + }, + }, + }, 
+ { + desc: "bad source prefix length", + lis: &v3listenerpb.Listener{ + FilterChains: []*v3listenerpb.FilterChain{ + { + FilterChainMatch: &v3listenerpb.FilterChainMatch{SourcePrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("10.1.1.0", 50)}}, + }, + }, + }, + }, + } + + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + if fci, err := NewFilterChainManager(test.lis); err == nil { + t.Fatalf("NewFilterChainManager() returned %v when expected to fail", fci) + } + }) + } +} + +// TestNewFilterChainImpl_Failure_OverlappingMatchingRules verifies cases where +// there are multiple filter chains and they have overlapping match rules. +func TestNewFilterChainImpl_Failure_OverlappingMatchingRules(t *testing.T) { + tests := []struct { + desc string + lis *v3listenerpb.Listener + }{ + { + desc: "matching destination prefixes with no other matchers", + lis: &v3listenerpb.Listener{ + FilterChains: []*v3listenerpb.FilterChain{ + { + FilterChainMatch: &v3listenerpb.FilterChainMatch{ + PrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("192.168.1.1", 16), cidrRangeFromAddressAndPrefixLen("10.0.0.0", 0)}, + }, + Filters: emptyValidNetworkFilters, + }, + { + FilterChainMatch: &v3listenerpb.FilterChainMatch{ + PrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("192.168.2.2", 16)}, + }, + Filters: emptyValidNetworkFilters, + }, + }, + }, + }, + { + desc: "matching source type", + lis: &v3listenerpb.Listener{ + FilterChains: []*v3listenerpb.FilterChain{ + { + FilterChainMatch: &v3listenerpb.FilterChainMatch{SourceType: v3listenerpb.FilterChainMatch_ANY}, + Filters: emptyValidNetworkFilters, + }, + { + FilterChainMatch: &v3listenerpb.FilterChainMatch{SourceType: v3listenerpb.FilterChainMatch_SAME_IP_OR_LOOPBACK}, + Filters: emptyValidNetworkFilters, + }, + { + FilterChainMatch: &v3listenerpb.FilterChainMatch{SourceType: v3listenerpb.FilterChainMatch_EXTERNAL}, + Filters: emptyValidNetworkFilters, + }, + { + 
FilterChainMatch: &v3listenerpb.FilterChainMatch{SourceType: v3listenerpb.FilterChainMatch_EXTERNAL}, + Filters: emptyValidNetworkFilters, + }, + }, + }, + }, + { + desc: "matching source prefixes", + lis: &v3listenerpb.Listener{ + FilterChains: []*v3listenerpb.FilterChain{ + { + FilterChainMatch: &v3listenerpb.FilterChainMatch{ + SourcePrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("192.168.1.1", 16), cidrRangeFromAddressAndPrefixLen("10.0.0.0", 0)}, + }, + Filters: emptyValidNetworkFilters, + }, + { + FilterChainMatch: &v3listenerpb.FilterChainMatch{ + SourcePrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("192.168.2.2", 16)}, + }, + Filters: emptyValidNetworkFilters, + }, + }, + }, + }, + { + desc: "matching source ports", + lis: &v3listenerpb.Listener{ + FilterChains: []*v3listenerpb.FilterChain{ + { + FilterChainMatch: &v3listenerpb.FilterChainMatch{SourcePorts: []uint32{1, 2, 3, 4, 5}}, + Filters: emptyValidNetworkFilters, + }, + { + FilterChainMatch: &v3listenerpb.FilterChainMatch{}, + Filters: emptyValidNetworkFilters, + }, + { + FilterChainMatch: &v3listenerpb.FilterChainMatch{SourcePorts: []uint32{5, 6, 7}}, + Filters: emptyValidNetworkFilters, + }, + }, + }, + }, + } + + const wantErr = "multiple filter chains with overlapping matching rules are defined" + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + if _, err := NewFilterChainManager(test.lis); err == nil || !strings.Contains(err.Error(), wantErr) { + t.Fatalf("NewFilterChainManager() returned err: %v, wantErr: %s", err, wantErr) + } + }) + } +} + +// TestNewFilterChainImpl_Failure_BadSecurityConfig verifies cases where the +// security configuration in the filter chain is invalid. 
+func TestNewFilterChainImpl_Failure_BadSecurityConfig(t *testing.T) { + tests := []struct { + desc string + lis *v3listenerpb.Listener + wantErr string + }{ + { + desc: "no filter chains", + lis: &v3listenerpb.Listener{}, + wantErr: "no supported filter chains and no default filter chain", + }, + { + desc: "unexpected transport socket name", + lis: &v3listenerpb.Listener{ + FilterChains: []*v3listenerpb.FilterChain{ + { + TransportSocket: &v3corepb.TransportSocket{Name: "unsupported-transport-socket-name"}, + Filters: emptyValidNetworkFilters, + }, + }, + }, + wantErr: "transport_socket field has unexpected name", + }, + { + desc: "unexpected transport socket URL", + lis: &v3listenerpb.Listener{ + FilterChains: []*v3listenerpb.FilterChain{ + { + TransportSocket: &v3corepb.TransportSocket{ + Name: "envoy.transport_sockets.tls", + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3tlspb.UpstreamTlsContext{}), + }, + }, + Filters: emptyValidNetworkFilters, + }, + }, + }, + wantErr: "transport_socket field has unexpected typeURL", + }, + { + desc: "badly marshaled transport socket", + lis: &v3listenerpb.Listener{ + FilterChains: []*v3listenerpb.FilterChain{ + { + TransportSocket: &v3corepb.TransportSocket{ + Name: "envoy.transport_sockets.tls", + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: &anypb.Any{ + TypeUrl: version.V3DownstreamTLSContextURL, + Value: []byte{1, 2, 3, 4}, + }, + }, + }, + Filters: emptyValidNetworkFilters, + }, + }, + }, + wantErr: "failed to unmarshal DownstreamTlsContext in LDS response", + }, + { + desc: "missing CommonTlsContext", + lis: &v3listenerpb.Listener{ + FilterChains: []*v3listenerpb.FilterChain{ + { + TransportSocket: &v3corepb.TransportSocket{ + Name: "envoy.transport_sockets.tls", + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3tlspb.DownstreamTlsContext{}), + }, + }, + Filters: emptyValidNetworkFilters, + }, + }, + }, + wantErr: 
"DownstreamTlsContext in LDS response does not contain a CommonTlsContext", + }, + { + desc: "require_sni-set-to-true-in-downstreamTlsContext", + lis: &v3listenerpb.Listener{ + FilterChains: []*v3listenerpb.FilterChain{ + { + TransportSocket: &v3corepb.TransportSocket{ + Name: "envoy.transport_sockets.tls", + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3tlspb.DownstreamTlsContext{ + RequireSni: &wrapperspb.BoolValue{Value: true}, + }), + }, + }, + Filters: emptyValidNetworkFilters, + }, + }, + }, + wantErr: "require_sni field set to true in DownstreamTlsContext message", + }, + { + desc: "unsupported-ocsp_staple_policy-in-downstreamTlsContext", + lis: &v3listenerpb.Listener{ + FilterChains: []*v3listenerpb.FilterChain{ + { + TransportSocket: &v3corepb.TransportSocket{ + Name: "envoy.transport_sockets.tls", + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3tlspb.DownstreamTlsContext{ + OcspStaplePolicy: v3tlspb.DownstreamTlsContext_STRICT_STAPLING, + }), + }, + }, + Filters: emptyValidNetworkFilters, + }, + }, + }, + wantErr: "ocsp_staple_policy field set to unsupported value in DownstreamTlsContext message", + }, + { + desc: "unsupported validation context in transport socket", + lis: &v3listenerpb.Listener{ + FilterChains: []*v3listenerpb.FilterChain{ + { + TransportSocket: &v3corepb.TransportSocket{ + Name: "envoy.transport_sockets.tls", + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3tlspb.DownstreamTlsContext{ + CommonTlsContext: &v3tlspb.CommonTlsContext{ + ValidationContextType: &v3tlspb.CommonTlsContext_ValidationContextSdsSecretConfig{ + ValidationContextSdsSecretConfig: &v3tlspb.SdsSecretConfig{ + Name: "foo-sds-secret", + }, + }, + }, + }), + }, + }, + Filters: emptyValidNetworkFilters, + }, + }, + }, + wantErr: "validation context contains unexpected type", + }, + { + desc: "unsupported match_subject_alt_names field in 
transport socket", + lis: &v3listenerpb.Listener{ + FilterChains: []*v3listenerpb.FilterChain{ + { + TransportSocket: &v3corepb.TransportSocket{ + Name: "envoy.transport_sockets.tls", + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3tlspb.DownstreamTlsContext{ + CommonTlsContext: &v3tlspb.CommonTlsContext{ + ValidationContextType: &v3tlspb.CommonTlsContext_ValidationContextSdsSecretConfig{ + ValidationContextSdsSecretConfig: &v3tlspb.SdsSecretConfig{ + Name: "foo-sds-secret", + }, + }, + }, + }), + }, + }, + Filters: emptyValidNetworkFilters, + }, + }, + }, + wantErr: "validation context contains unexpected type", + }, + { + desc: "no root certificate provider with require_client_cert", + lis: &v3listenerpb.Listener{ + FilterChains: []*v3listenerpb.FilterChain{ + { + TransportSocket: &v3corepb.TransportSocket{ + Name: "envoy.transport_sockets.tls", + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3tlspb.DownstreamTlsContext{ + RequireClientCertificate: &wrapperspb.BoolValue{Value: true}, + CommonTlsContext: &v3tlspb.CommonTlsContext{ + TlsCertificateCertificateProviderInstance: &v3tlspb.CommonTlsContext_CertificateProviderInstance{ + InstanceName: "identityPluginInstance", + CertificateName: "identityCertName", + }, + }, + }), + }, + }, + Filters: emptyValidNetworkFilters, + }, + }, + }, + wantErr: "security configuration on the server-side does not contain root certificate provider instance name, but require_client_cert field is set", + }, + { + desc: "no identity certificate provider", + lis: &v3listenerpb.Listener{ + FilterChains: []*v3listenerpb.FilterChain{ + { + TransportSocket: &v3corepb.TransportSocket{ + Name: "envoy.transport_sockets.tls", + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3tlspb.DownstreamTlsContext{ + CommonTlsContext: &v3tlspb.CommonTlsContext{}, + }), + }, + }, + Filters: emptyValidNetworkFilters, + }, + }, + }, 
+ wantErr: "security configuration on the server-side does not contain identity certificate provider instance name", + }, + } + + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + _, err := NewFilterChainManager(test.lis) + if err == nil || !strings.Contains(err.Error(), test.wantErr) { + t.Fatalf("NewFilterChainManager() returned err: %v, wantErr: %s", err, test.wantErr) + } + }) + } +} + +// TestNewFilterChainImpl_Success_RouteUpdate tests the construction of the +// filter chain with valid HTTP Filters present. +func TestNewFilterChainImpl_Success_RouteUpdate(t *testing.T) { + oldRBAC := env.RBACSupport + env.RBACSupport = true + defer func() { + env.RBACSupport = oldRBAC + }() + tests := []struct { + name string + lis *v3listenerpb.Listener + wantFC *FilterChainManager + }{ + { + name: "rds", + lis: &v3listenerpb.Listener{ + FilterChains: []*v3listenerpb.FilterChain{ + { + Name: "filter-chain-1", + Filters: []*v3listenerpb.Filter{ + { + Name: "hcm", + ConfigType: &v3listenerpb.Filter_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3httppb.HttpConnectionManager{ + RouteSpecifier: &v3httppb.HttpConnectionManager_Rds{ + Rds: &v3httppb.Rds{ + ConfigSource: &v3corepb.ConfigSource{ + ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{Ads: &v3corepb.AggregatedConfigSource{}}, + }, + RouteConfigName: "route-1", + }, + }, + HttpFilters: []*v3httppb.HttpFilter{emptyRouterFilter}, + }), + }, + }, + }, + }, + }, + DefaultFilterChain: &v3listenerpb.FilterChain{ + Filters: []*v3listenerpb.Filter{ + { + Name: "hcm", + ConfigType: &v3listenerpb.Filter_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3httppb.HttpConnectionManager{ + RouteSpecifier: &v3httppb.HttpConnectionManager_Rds{ + Rds: &v3httppb.Rds{ + ConfigSource: &v3corepb.ConfigSource{ + ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{Ads: &v3corepb.AggregatedConfigSource{}}, + }, + RouteConfigName: "route-1", + }, + }, + HttpFilters: []*v3httppb.HttpFilter{emptyRouterFilter}, + }), + }, + 
}, + }, + }, + }, + wantFC: &FilterChainManager{ + dstPrefixMap: map[string]*destPrefixEntry{ + unspecifiedPrefixMapKey: { + srcTypeArr: [3]*sourcePrefixes{ + { + srcPrefixMap: map[string]*sourcePrefixEntry{ + unspecifiedPrefixMapKey: { + srcPortMap: map[int]*FilterChain{ + 0: { + RouteConfigName: "route-1", + HTTPFilters: routerFilterList, + }, + }, + }, + }, + }, + }, + }, + }, + def: &FilterChain{ + RouteConfigName: "route-1", + HTTPFilters: routerFilterList, + }, + RouteConfigNames: map[string]bool{"route-1": true}, + }, + }, + { + name: "inline route config", + lis: &v3listenerpb.Listener{ + FilterChains: []*v3listenerpb.FilterChain{ + { + Name: "filter-chain-1", + Filters: []*v3listenerpb.Filter{ + { + Name: "hcm", + ConfigType: &v3listenerpb.Filter_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3httppb.HttpConnectionManager{ + RouteSpecifier: &v3httppb.HttpConnectionManager_RouteConfig{ + RouteConfig: routeConfig, + }, + HttpFilters: []*v3httppb.HttpFilter{emptyRouterFilter}, + }), + }, + }, + }, + }, + }, + DefaultFilterChain: &v3listenerpb.FilterChain{ + Filters: []*v3listenerpb.Filter{ + { + Name: "hcm", + ConfigType: &v3listenerpb.Filter_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3httppb.HttpConnectionManager{ + RouteSpecifier: &v3httppb.HttpConnectionManager_RouteConfig{ + RouteConfig: routeConfig, + }, + HttpFilters: []*v3httppb.HttpFilter{emptyRouterFilter}, + }), + }, + }, + }, + }, + }, + wantFC: &FilterChainManager{ + dstPrefixMap: map[string]*destPrefixEntry{ + unspecifiedPrefixMapKey: { + srcTypeArr: [3]*sourcePrefixes{ + { + srcPrefixMap: map[string]*sourcePrefixEntry{ + unspecifiedPrefixMapKey: { + srcPortMap: map[int]*FilterChain{ + 0: { + InlineRouteConfig: inlineRouteConfig, + HTTPFilters: routerFilterList, + }, + }, + }, + }, + }, + }, + }, + }, + def: &FilterChain{ + InlineRouteConfig: inlineRouteConfig, + HTTPFilters: routerFilterList, + }, + }, + }, + // two rds tests whether the Filter Chain Manager successfully persists + 
// the two RDS names that need to be dynamically queried. + { + name: "two rds", + lis: &v3listenerpb.Listener{ + FilterChains: []*v3listenerpb.FilterChain{ + { + Name: "filter-chain-1", + Filters: []*v3listenerpb.Filter{ + { + Name: "hcm", + ConfigType: &v3listenerpb.Filter_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3httppb.HttpConnectionManager{ + RouteSpecifier: &v3httppb.HttpConnectionManager_Rds{ + Rds: &v3httppb.Rds{ + ConfigSource: &v3corepb.ConfigSource{ + ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{Ads: &v3corepb.AggregatedConfigSource{}}, + }, + RouteConfigName: "route-1", + }, + }, + HttpFilters: []*v3httppb.HttpFilter{emptyRouterFilter}, + }), + }, + }, + }, + }, + }, + DefaultFilterChain: &v3listenerpb.FilterChain{ + Filters: []*v3listenerpb.Filter{ + { + Name: "hcm", + ConfigType: &v3listenerpb.Filter_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3httppb.HttpConnectionManager{ + RouteSpecifier: &v3httppb.HttpConnectionManager_Rds{ + Rds: &v3httppb.Rds{ + ConfigSource: &v3corepb.ConfigSource{ + ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{Ads: &v3corepb.AggregatedConfigSource{}}, + }, + RouteConfigName: "route-2", + }, + }, + HttpFilters: []*v3httppb.HttpFilter{emptyRouterFilter}, + }), + }, + }, + }, + }, + }, + wantFC: &FilterChainManager{ + dstPrefixMap: map[string]*destPrefixEntry{ + unspecifiedPrefixMapKey: { + srcTypeArr: [3]*sourcePrefixes{ + { + srcPrefixMap: map[string]*sourcePrefixEntry{ + unspecifiedPrefixMapKey: { + srcPortMap: map[int]*FilterChain{ + 0: { + RouteConfigName: "route-1", + HTTPFilters: routerFilterList, + }, + }, + }, + }, + }, + }, + }, + }, + def: &FilterChain{ + RouteConfigName: "route-2", + HTTPFilters: routerFilterList, + }, + RouteConfigNames: map[string]bool{ + "route-1": true, + "route-2": true, + }, + }, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + gotFC, err := NewFilterChainManager(test.lis) + if err != nil { + t.Fatalf("NewFilterChainManager() returned 
err: %v, wantErr: nil", err) + } + if !cmp.Equal(gotFC, test.wantFC, cmp.AllowUnexported(FilterChainManager{}, destPrefixEntry{}, sourcePrefixes{}, sourcePrefixEntry{}), cmpOpts) { + t.Fatalf("NewFilterChainManager() returned %+v, want: %+v", gotFC, test.wantFC) + } + }) + } +} + +// TestNewFilterChainImpl_Failure_BadRouteUpdate verifies cases where the Route +// Update in the filter chain are invalid. +func TestNewFilterChainImpl_Failure_BadRouteUpdate(t *testing.T) { + oldRBAC := env.RBACSupport + env.RBACSupport = true + defer func() { + env.RBACSupport = oldRBAC + }() + tests := []struct { + name string + lis *v3listenerpb.Listener + wantErr string + }{ + { + name: "not-ads", + lis: &v3listenerpb.Listener{ + FilterChains: []*v3listenerpb.FilterChain{ + { + Name: "filter-chain-1", + Filters: []*v3listenerpb.Filter{ + { + Name: "hcm", + ConfigType: &v3listenerpb.Filter_TypedConfig{ + + TypedConfig: testutils.MarshalAny(&v3httppb.HttpConnectionManager{ + RouteSpecifier: &v3httppb.HttpConnectionManager_Rds{ + Rds: &v3httppb.Rds{ + RouteConfigName: "route-1", + }, + }, + HttpFilters: []*v3httppb.HttpFilter{emptyRouterFilter}, + }), + }, + }, + }, + }, + }, + DefaultFilterChain: &v3listenerpb.FilterChain{ + Filters: []*v3listenerpb.Filter{ + { + Name: "hcm", + ConfigType: &v3listenerpb.Filter_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3httppb.HttpConnectionManager{ + RouteSpecifier: &v3httppb.HttpConnectionManager_Rds{ + Rds: &v3httppb.Rds{ + RouteConfigName: "route-1", + }, + }, + HttpFilters: []*v3httppb.HttpFilter{emptyRouterFilter}, + }), + }, + }, + }, + }, + }, + wantErr: "ConfigSource is not ADS", + }, + { + name: "unsupported-route-specifier", + lis: &v3listenerpb.Listener{ + FilterChains: []*v3listenerpb.FilterChain{ + { + Name: "filter-chain-1", + Filters: []*v3listenerpb.Filter{ + { + Name: "hcm", + ConfigType: &v3listenerpb.Filter_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3httppb.HttpConnectionManager{ + RouteSpecifier: 
&v3httppb.HttpConnectionManager_ScopedRoutes{}, + HttpFilters: []*v3httppb.HttpFilter{emptyRouterFilter}, + }), + }, + }, + }, + }, + }, + DefaultFilterChain: &v3listenerpb.FilterChain{ + Filters: []*v3listenerpb.Filter{ + { + Name: "hcm", + ConfigType: &v3listenerpb.Filter_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3httppb.HttpConnectionManager{ + RouteSpecifier: &v3httppb.HttpConnectionManager_ScopedRoutes{}, + HttpFilters: []*v3httppb.HttpFilter{emptyRouterFilter}, + }), + }, + }, + }, + }, + }, + wantErr: "unsupported type", + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + _, err := NewFilterChainManager(test.lis) + if err == nil || !strings.Contains(err.Error(), test.wantErr) { + t.Fatalf("NewFilterChainManager() returned err: %v, wantErr: %s", err, test.wantErr) + } + }) + } +} + +// TestNewFilterChainImpl_Failure_BadHTTPFilters verifies cases where the HTTP +// Filters in the filter chain are invalid. +func TestNewFilterChainImpl_Failure_BadHTTPFilters(t *testing.T) { + tests := []struct { + name string + lis *v3listenerpb.Listener + wantErr string + }{ + { + name: "client side HTTP filter", + lis: &v3listenerpb.Listener{ + Name: "grpc/server?xds.resource.listening_address=0.0.0.0:9999", + FilterChains: []*v3listenerpb.FilterChain{ + { + Name: "filter-chain-1", + Filters: []*v3listenerpb.Filter{ + { + Name: "hcm", + ConfigType: &v3listenerpb.Filter_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3httppb.HttpConnectionManager{ + HttpFilters: []*v3httppb.HttpFilter{ + { + Name: "clientOnlyCustomFilter", + ConfigType: &v3httppb.HttpFilter_TypedConfig{TypedConfig: clientOnlyCustomFilterConfig}, + }, + }, + }), + }, + }, + }, + }, + }, + }, + wantErr: "invalid server side HTTP Filters", + }, + { + name: "one valid then one invalid HTTP filter", + lis: &v3listenerpb.Listener{ + Name: "grpc/server?xds.resource.listening_address=0.0.0.0:9999", + FilterChains: []*v3listenerpb.FilterChain{ + { + Name: "filter-chain-1", + 
Filters: []*v3listenerpb.Filter{ + { + Name: "hcm", + ConfigType: &v3listenerpb.Filter_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3httppb.HttpConnectionManager{ + HttpFilters: []*v3httppb.HttpFilter{ + validServerSideHTTPFilter1, + { + Name: "clientOnlyCustomFilter", + ConfigType: &v3httppb.HttpFilter_TypedConfig{TypedConfig: clientOnlyCustomFilterConfig}, + }, + }, + }), + }, + }, + }, + }, + }, + }, + wantErr: "invalid server side HTTP Filters", + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + _, err := NewFilterChainManager(test.lis) + if err == nil || !strings.Contains(err.Error(), test.wantErr) { + t.Fatalf("NewFilterChainManager() returned err: %v, wantErr: %s", err, test.wantErr) + } + }) + } +} + +// TestNewFilterChainImpl_Success_HTTPFilters tests the construction of the +// filter chain with valid HTTP Filters present. +func TestNewFilterChainImpl_Success_HTTPFilters(t *testing.T) { + oldRBAC := env.RBACSupport + env.RBACSupport = true + defer func() { + env.RBACSupport = oldRBAC + }() + tests := []struct { + name string + lis *v3listenerpb.Listener + wantFC *FilterChainManager + }{ + { + name: "singular valid http filter", + lis: &v3listenerpb.Listener{ + FilterChains: []*v3listenerpb.FilterChain{ + { + Name: "filter-chain-1", + Filters: []*v3listenerpb.Filter{ + { + Name: "hcm", + ConfigType: &v3listenerpb.Filter_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3httppb.HttpConnectionManager{ + HttpFilters: []*v3httppb.HttpFilter{ + validServerSideHTTPFilter1, + emptyRouterFilter, + }, + RouteSpecifier: &v3httppb.HttpConnectionManager_RouteConfig{ + RouteConfig: routeConfig, + }, + }), + }, + }, + }, + }, + }, + DefaultFilterChain: &v3listenerpb.FilterChain{ + Filters: []*v3listenerpb.Filter{ + { + Name: "hcm", + ConfigType: &v3listenerpb.Filter_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3httppb.HttpConnectionManager{ + HttpFilters: []*v3httppb.HttpFilter{ + validServerSideHTTPFilter1, + 
emptyRouterFilter, + }, + RouteSpecifier: &v3httppb.HttpConnectionManager_RouteConfig{ + RouteConfig: routeConfig, + }, + }), + }, + }, + }, + }, + }, + wantFC: &FilterChainManager{ + dstPrefixMap: map[string]*destPrefixEntry{ + unspecifiedPrefixMapKey: { + srcTypeArr: [3]*sourcePrefixes{ + { + srcPrefixMap: map[string]*sourcePrefixEntry{ + unspecifiedPrefixMapKey: { + srcPortMap: map[int]*FilterChain{ + 0: {HTTPFilters: []HTTPFilter{ + { + Name: "serverOnlyCustomFilter", + Filter: serverOnlyHTTPFilter{}, + Config: filterConfig{Cfg: serverOnlyCustomFilterConfig}, + }, + routerFilter, + }, + InlineRouteConfig: inlineRouteConfig, + }, + }, + }, + }, + }, + }, + }, + }, + def: &FilterChain{ + HTTPFilters: []HTTPFilter{ + { + Name: "serverOnlyCustomFilter", + Filter: serverOnlyHTTPFilter{}, + Config: filterConfig{Cfg: serverOnlyCustomFilterConfig}, + }, + routerFilter, + }, + InlineRouteConfig: inlineRouteConfig, + }, + }, + }, + { + name: "two valid http filters", + lis: &v3listenerpb.Listener{ + FilterChains: []*v3listenerpb.FilterChain{ + { + Name: "filter-chain-1", + Filters: []*v3listenerpb.Filter{ + { + Name: "hcm", + ConfigType: &v3listenerpb.Filter_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3httppb.HttpConnectionManager{ + HttpFilters: []*v3httppb.HttpFilter{ + validServerSideHTTPFilter1, + validServerSideHTTPFilter2, + emptyRouterFilter, + }, + RouteSpecifier: &v3httppb.HttpConnectionManager_RouteConfig{ + RouteConfig: routeConfig, + }, + }), + }, + }, + }, + }, + }, + DefaultFilterChain: &v3listenerpb.FilterChain{ + Filters: []*v3listenerpb.Filter{ + { + Name: "hcm", + ConfigType: &v3listenerpb.Filter_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3httppb.HttpConnectionManager{ + HttpFilters: []*v3httppb.HttpFilter{ + validServerSideHTTPFilter1, + validServerSideHTTPFilter2, + emptyRouterFilter, + }, + RouteSpecifier: &v3httppb.HttpConnectionManager_RouteConfig{ + RouteConfig: routeConfig, + }, + }), + }, + }, + }, + }, + }, + wantFC: 
&FilterChainManager{ + dstPrefixMap: map[string]*destPrefixEntry{ + unspecifiedPrefixMapKey: { + srcTypeArr: [3]*sourcePrefixes{ + { + srcPrefixMap: map[string]*sourcePrefixEntry{ + unspecifiedPrefixMapKey: { + srcPortMap: map[int]*FilterChain{ + 0: {HTTPFilters: []HTTPFilter{ + { + Name: "serverOnlyCustomFilter", + Filter: serverOnlyHTTPFilter{}, + Config: filterConfig{Cfg: serverOnlyCustomFilterConfig}, + }, + { + Name: "serverOnlyCustomFilter2", + Filter: serverOnlyHTTPFilter{}, + Config: filterConfig{Cfg: serverOnlyCustomFilterConfig}, + }, + routerFilter, + }, + InlineRouteConfig: inlineRouteConfig, + }, + }, + }, + }, + }, + }, + }, + }, + def: &FilterChain{HTTPFilters: []HTTPFilter{ + { + Name: "serverOnlyCustomFilter", + Filter: serverOnlyHTTPFilter{}, + Config: filterConfig{Cfg: serverOnlyCustomFilterConfig}, + }, + { + Name: "serverOnlyCustomFilter2", + Filter: serverOnlyHTTPFilter{}, + Config: filterConfig{Cfg: serverOnlyCustomFilterConfig}, + }, + routerFilter, + }, + InlineRouteConfig: inlineRouteConfig, + }, + }, + }, + // In the case of two HTTP Connection Manager's being present, the + // second HTTP Connection Manager should be validated, but ignored. 
+ { + name: "two hcms", + lis: &v3listenerpb.Listener{ + FilterChains: []*v3listenerpb.FilterChain{ + { + Name: "filter-chain-1", + Filters: []*v3listenerpb.Filter{ + { + Name: "hcm", + ConfigType: &v3listenerpb.Filter_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3httppb.HttpConnectionManager{ + HttpFilters: []*v3httppb.HttpFilter{ + validServerSideHTTPFilter1, + validServerSideHTTPFilter2, + emptyRouterFilter, + }, + RouteSpecifier: &v3httppb.HttpConnectionManager_RouteConfig{ + RouteConfig: routeConfig, + }, + }), + }, + }, + { + Name: "hcm2", + ConfigType: &v3listenerpb.Filter_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3httppb.HttpConnectionManager{ + HttpFilters: []*v3httppb.HttpFilter{ + validServerSideHTTPFilter1, + emptyRouterFilter, + }, + RouteSpecifier: &v3httppb.HttpConnectionManager_RouteConfig{ + RouteConfig: routeConfig, + }, + }), + }, + }, + }, + }, + }, + DefaultFilterChain: &v3listenerpb.FilterChain{ + Filters: []*v3listenerpb.Filter{ + { + Name: "hcm", + ConfigType: &v3listenerpb.Filter_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3httppb.HttpConnectionManager{ + HttpFilters: []*v3httppb.HttpFilter{ + validServerSideHTTPFilter1, + validServerSideHTTPFilter2, + emptyRouterFilter, + }, + RouteSpecifier: &v3httppb.HttpConnectionManager_RouteConfig{ + RouteConfig: routeConfig, + }, + }), + }, + }, + { + Name: "hcm2", + ConfigType: &v3listenerpb.Filter_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3httppb.HttpConnectionManager{ + HttpFilters: []*v3httppb.HttpFilter{ + validServerSideHTTPFilter1, + emptyRouterFilter, + }, + RouteSpecifier: &v3httppb.HttpConnectionManager_RouteConfig{ + RouteConfig: routeConfig, + }, + }), + }, + }, + }, + }, + }, + wantFC: &FilterChainManager{ + dstPrefixMap: map[string]*destPrefixEntry{ + unspecifiedPrefixMapKey: { + srcTypeArr: [3]*sourcePrefixes{ + { + srcPrefixMap: map[string]*sourcePrefixEntry{ + unspecifiedPrefixMapKey: { + srcPortMap: map[int]*FilterChain{ + 0: {HTTPFilters: 
[]HTTPFilter{ + { + Name: "serverOnlyCustomFilter", + Filter: serverOnlyHTTPFilter{}, + Config: filterConfig{Cfg: serverOnlyCustomFilterConfig}, + }, + { + Name: "serverOnlyCustomFilter2", + Filter: serverOnlyHTTPFilter{}, + Config: filterConfig{Cfg: serverOnlyCustomFilterConfig}, + }, + routerFilter, + }, + InlineRouteConfig: inlineRouteConfig, + }, + }, + }, + }, + }, + }, + }, + }, + def: &FilterChain{HTTPFilters: []HTTPFilter{ + { + Name: "serverOnlyCustomFilter", + Filter: serverOnlyHTTPFilter{}, + Config: filterConfig{Cfg: serverOnlyCustomFilterConfig}, + }, + { + Name: "serverOnlyCustomFilter2", + Filter: serverOnlyHTTPFilter{}, + Config: filterConfig{Cfg: serverOnlyCustomFilterConfig}, + }, + routerFilter, + }, + InlineRouteConfig: inlineRouteConfig, + }, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + gotFC, err := NewFilterChainManager(test.lis) + if err != nil { + t.Fatalf("NewFilterChainManager() returned err: %v, wantErr: nil", err) + } + if !cmp.Equal(gotFC, test.wantFC, cmp.AllowUnexported(FilterChainManager{}, destPrefixEntry{}, sourcePrefixes{}, sourcePrefixEntry{}), cmpOpts) { + t.Fatalf("NewFilterChainManager() returned %+v, want: %+v", gotFC, test.wantFC) + } + }) + } +} + +// TestNewFilterChainImpl_Success_SecurityConfig verifies cases where the +// security configuration in the filter chain contains valid data. 
+func TestNewFilterChainImpl_Success_SecurityConfig(t *testing.T) { + oldRBAC := env.RBACSupport + env.RBACSupport = true + defer func() { + env.RBACSupport = oldRBAC + }() + tests := []struct { + desc string + lis *v3listenerpb.Listener + wantFC *FilterChainManager + }{ + { + desc: "empty transport socket", + lis: &v3listenerpb.Listener{ + FilterChains: []*v3listenerpb.FilterChain{ + { + Name: "filter-chain-1", + Filters: emptyValidNetworkFilters, + }, + }, + DefaultFilterChain: &v3listenerpb.FilterChain{ + Filters: emptyValidNetworkFilters, + }, + }, + wantFC: &FilterChainManager{ + dstPrefixMap: map[string]*destPrefixEntry{ + unspecifiedPrefixMapKey: { + srcTypeArr: [3]*sourcePrefixes{ + { + srcPrefixMap: map[string]*sourcePrefixEntry{ + unspecifiedPrefixMapKey: { + srcPortMap: map[int]*FilterChain{ + 0: { + InlineRouteConfig: inlineRouteConfig, + HTTPFilters: routerFilterList, + }, + }, + }, + }, + }, + }, + }, + }, + def: &FilterChain{ + InlineRouteConfig: inlineRouteConfig, + HTTPFilters: routerFilterList, + }, + }, + }, + { + desc: "no validation context", + lis: &v3listenerpb.Listener{ + FilterChains: []*v3listenerpb.FilterChain{ + { + TransportSocket: &v3corepb.TransportSocket{ + Name: "envoy.transport_sockets.tls", + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3tlspb.DownstreamTlsContext{ + CommonTlsContext: &v3tlspb.CommonTlsContext{ + TlsCertificateCertificateProviderInstance: &v3tlspb.CommonTlsContext_CertificateProviderInstance{ + InstanceName: "identityPluginInstance", + CertificateName: "identityCertName", + }, + }, + }), + }, + }, + Filters: emptyValidNetworkFilters, + }, + }, + DefaultFilterChain: &v3listenerpb.FilterChain{ + TransportSocket: &v3corepb.TransportSocket{ + Name: "envoy.transport_sockets.tls", + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3tlspb.DownstreamTlsContext{ + CommonTlsContext: &v3tlspb.CommonTlsContext{ + 
TlsCertificateCertificateProviderInstance: &v3tlspb.CommonTlsContext_CertificateProviderInstance{ + InstanceName: "defaultIdentityPluginInstance", + CertificateName: "defaultIdentityCertName", + }, + }, + }), + }, + }, + Filters: emptyValidNetworkFilters, + }, + }, + wantFC: &FilterChainManager{ + dstPrefixMap: map[string]*destPrefixEntry{ + unspecifiedPrefixMapKey: { + srcTypeArr: [3]*sourcePrefixes{ + { + srcPrefixMap: map[string]*sourcePrefixEntry{ + unspecifiedPrefixMapKey: { + srcPortMap: map[int]*FilterChain{ + 0: { + SecurityCfg: &SecurityConfig{ + IdentityInstanceName: "identityPluginInstance", + IdentityCertName: "identityCertName", + }, + InlineRouteConfig: inlineRouteConfig, + HTTPFilters: routerFilterList, + }, + }, + }, + }, + }, + }, + }, + }, + def: &FilterChain{ + SecurityCfg: &SecurityConfig{ + IdentityInstanceName: "defaultIdentityPluginInstance", + IdentityCertName: "defaultIdentityCertName", + }, + InlineRouteConfig: inlineRouteConfig, + HTTPFilters: routerFilterList, + }, + }, + }, + { + desc: "validation context with certificate provider", + lis: &v3listenerpb.Listener{ + FilterChains: []*v3listenerpb.FilterChain{ + { + TransportSocket: &v3corepb.TransportSocket{ + Name: "envoy.transport_sockets.tls", + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3tlspb.DownstreamTlsContext{ + RequireClientCertificate: &wrapperspb.BoolValue{Value: true}, + CommonTlsContext: &v3tlspb.CommonTlsContext{ + TlsCertificateCertificateProviderInstance: &v3tlspb.CommonTlsContext_CertificateProviderInstance{ + InstanceName: "identityPluginInstance", + CertificateName: "identityCertName", + }, + ValidationContextType: &v3tlspb.CommonTlsContext_ValidationContextCertificateProviderInstance{ + ValidationContextCertificateProviderInstance: &v3tlspb.CommonTlsContext_CertificateProviderInstance{ + InstanceName: "rootPluginInstance", + CertificateName: "rootCertName", + }, + }, + }, + }), + }, + }, + Filters: 
emptyValidNetworkFilters, + }, + }, + DefaultFilterChain: &v3listenerpb.FilterChain{ + Name: "default-filter-chain-1", + TransportSocket: &v3corepb.TransportSocket{ + Name: "envoy.transport_sockets.tls", + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3tlspb.DownstreamTlsContext{ + RequireClientCertificate: &wrapperspb.BoolValue{Value: true}, + CommonTlsContext: &v3tlspb.CommonTlsContext{ + TlsCertificateCertificateProviderInstance: &v3tlspb.CommonTlsContext_CertificateProviderInstance{ + InstanceName: "defaultIdentityPluginInstance", + CertificateName: "defaultIdentityCertName", + }, + ValidationContextType: &v3tlspb.CommonTlsContext_ValidationContextCertificateProviderInstance{ + ValidationContextCertificateProviderInstance: &v3tlspb.CommonTlsContext_CertificateProviderInstance{ + InstanceName: "defaultRootPluginInstance", + CertificateName: "defaultRootCertName", + }, + }, + }, + }), + }, + }, + Filters: emptyValidNetworkFilters, + }, + }, + wantFC: &FilterChainManager{ + dstPrefixMap: map[string]*destPrefixEntry{ + unspecifiedPrefixMapKey: { + srcTypeArr: [3]*sourcePrefixes{ + { + srcPrefixMap: map[string]*sourcePrefixEntry{ + unspecifiedPrefixMapKey: { + srcPortMap: map[int]*FilterChain{ + 0: { + SecurityCfg: &SecurityConfig{ + RootInstanceName: "rootPluginInstance", + RootCertName: "rootCertName", + IdentityInstanceName: "identityPluginInstance", + IdentityCertName: "identityCertName", + RequireClientCert: true, + }, + InlineRouteConfig: inlineRouteConfig, + HTTPFilters: routerFilterList, + }, + }, + }, + }, + }, + }, + }, + }, + def: &FilterChain{ + SecurityCfg: &SecurityConfig{ + RootInstanceName: "defaultRootPluginInstance", + RootCertName: "defaultRootCertName", + IdentityInstanceName: "defaultIdentityPluginInstance", + IdentityCertName: "defaultIdentityCertName", + RequireClientCert: true, + }, + InlineRouteConfig: inlineRouteConfig, + HTTPFilters: routerFilterList, + }, + }, + }, + } + + for _, test := range 
tests { + t.Run(test.desc, func(t *testing.T) { + gotFC, err := NewFilterChainManager(test.lis) + if err != nil { + t.Fatalf("NewFilterChainManager() returned err: %v, wantErr: nil", err) + } + if !cmp.Equal(gotFC, test.wantFC, cmp.AllowUnexported(FilterChainManager{}, destPrefixEntry{}, sourcePrefixes{}, sourcePrefixEntry{}), cmpopts.EquateEmpty()) { + t.Fatalf("NewFilterChainManager() returned %+v, want: %+v", gotFC, test.wantFC) + } + }) + } +} + +// TestNewFilterChainImpl_Success_UnsupportedMatchFields verifies cases where +// there are multiple filter chains, and one of them is valid while the other +// contains unsupported match fields. These configurations should lead to +// success at config validation time and the filter chains which contains +// unsupported match fields will be skipped at lookup time. +func TestNewFilterChainImpl_Success_UnsupportedMatchFields(t *testing.T) { + oldRBAC := env.RBACSupport + env.RBACSupport = true + defer func() { + env.RBACSupport = oldRBAC + }() + unspecifiedEntry := &destPrefixEntry{ + srcTypeArr: [3]*sourcePrefixes{ + { + srcPrefixMap: map[string]*sourcePrefixEntry{ + unspecifiedPrefixMapKey: { + srcPortMap: map[int]*FilterChain{ + 0: { + InlineRouteConfig: inlineRouteConfig, + HTTPFilters: routerFilterList, + }, + }, + }, + }, + }, + }, + } + + tests := []struct { + desc string + lis *v3listenerpb.Listener + wantFC *FilterChainManager + }{ + { + desc: "unsupported destination port", + lis: &v3listenerpb.Listener{ + FilterChains: []*v3listenerpb.FilterChain{ + { + Name: "good-chain", + Filters: emptyValidNetworkFilters, + }, + { + Name: "unsupported-destination-port", + FilterChainMatch: &v3listenerpb.FilterChainMatch{ + PrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("192.168.1.1", 16)}, + DestinationPort: &wrapperspb.UInt32Value{Value: 666}, + }, + Filters: emptyValidNetworkFilters, + }, + }, + DefaultFilterChain: &v3listenerpb.FilterChain{Filters: emptyValidNetworkFilters}, + }, + wantFC: 
&FilterChainManager{ + dstPrefixMap: map[string]*destPrefixEntry{ + unspecifiedPrefixMapKey: unspecifiedEntry, + }, + def: &FilterChain{ + InlineRouteConfig: inlineRouteConfig, + HTTPFilters: routerFilterList, + }, + }, + }, + { + desc: "unsupported server names", + lis: &v3listenerpb.Listener{ + FilterChains: []*v3listenerpb.FilterChain{ + { + Name: "good-chain", + Filters: emptyValidNetworkFilters, + }, + { + Name: "unsupported-server-names", + FilterChainMatch: &v3listenerpb.FilterChainMatch{ + PrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("192.168.1.1", 16)}, + ServerNames: []string{"example-server"}, + }, + Filters: emptyValidNetworkFilters, + }, + }, + DefaultFilterChain: &v3listenerpb.FilterChain{Filters: emptyValidNetworkFilters}, + }, + wantFC: &FilterChainManager{ + dstPrefixMap: map[string]*destPrefixEntry{ + unspecifiedPrefixMapKey: unspecifiedEntry, + "192.168.0.0/16": { + net: ipNetFromCIDR("192.168.2.2/16"), + }, + }, + def: &FilterChain{ + InlineRouteConfig: inlineRouteConfig, + HTTPFilters: routerFilterList, + }, + }, + }, + { + desc: "unsupported transport protocol", + lis: &v3listenerpb.Listener{ + FilterChains: []*v3listenerpb.FilterChain{ + { + Name: "good-chain", + Filters: emptyValidNetworkFilters, + }, + { + Name: "unsupported-transport-protocol", + FilterChainMatch: &v3listenerpb.FilterChainMatch{ + PrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("192.168.1.1", 16)}, + TransportProtocol: "tls", + }, + Filters: emptyValidNetworkFilters, + }, + }, + DefaultFilterChain: &v3listenerpb.FilterChain{Filters: emptyValidNetworkFilters}, + }, + wantFC: &FilterChainManager{ + dstPrefixMap: map[string]*destPrefixEntry{ + unspecifiedPrefixMapKey: unspecifiedEntry, + "192.168.0.0/16": { + net: ipNetFromCIDR("192.168.2.2/16"), + }, + }, + def: &FilterChain{ + InlineRouteConfig: inlineRouteConfig, + HTTPFilters: routerFilterList, + }, + }, + }, + { + desc: "unsupported application protocol", + lis: 
&v3listenerpb.Listener{ + FilterChains: []*v3listenerpb.FilterChain{ + { + Name: "good-chain", + Filters: emptyValidNetworkFilters, + }, + { + Name: "unsupported-application-protocol", + FilterChainMatch: &v3listenerpb.FilterChainMatch{ + PrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("192.168.1.1", 16)}, + ApplicationProtocols: []string{"h2"}, + }, + Filters: emptyValidNetworkFilters, + }, + }, + DefaultFilterChain: &v3listenerpb.FilterChain{Filters: emptyValidNetworkFilters}, + }, + wantFC: &FilterChainManager{ + dstPrefixMap: map[string]*destPrefixEntry{ + unspecifiedPrefixMapKey: unspecifiedEntry, + "192.168.0.0/16": { + net: ipNetFromCIDR("192.168.2.2/16"), + }, + }, + def: &FilterChain{ + InlineRouteConfig: inlineRouteConfig, + HTTPFilters: routerFilterList, + }, + }, + }, + } + + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + gotFC, err := NewFilterChainManager(test.lis) + if err != nil { + t.Fatalf("NewFilterChainManager() returned err: %v, wantErr: nil", err) + } + if !cmp.Equal(gotFC, test.wantFC, cmp.AllowUnexported(FilterChainManager{}, destPrefixEntry{}, sourcePrefixes{}, sourcePrefixEntry{}), cmpopts.EquateEmpty()) { + t.Fatalf("NewFilterChainManager() returned %+v, want: %+v", gotFC, test.wantFC) + } + }) + } +} + +// TestNewFilterChainImpl_Success_AllCombinations verifies different +// combinations of the supported match criteria. +func TestNewFilterChainImpl_Success_AllCombinations(t *testing.T) { + oldRBAC := env.RBACSupport + env.RBACSupport = true + defer func() { + env.RBACSupport = oldRBAC + }() + tests := []struct { + desc string + lis *v3listenerpb.Listener + wantFC *FilterChainManager + }{ + { + desc: "multiple destination prefixes", + lis: &v3listenerpb.Listener{ + FilterChains: []*v3listenerpb.FilterChain{ + { + // Unspecified destination prefix. + FilterChainMatch: &v3listenerpb.FilterChainMatch{}, + Filters: emptyValidNetworkFilters, + }, + { + // v4 wildcard destination prefix. 
+ FilterChainMatch: &v3listenerpb.FilterChainMatch{ + PrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("0.0.0.0", 0)}, + SourceType: v3listenerpb.FilterChainMatch_EXTERNAL, + }, + Filters: emptyValidNetworkFilters, + }, + { + // v6 wildcard destination prefix. + FilterChainMatch: &v3listenerpb.FilterChainMatch{ + PrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("::", 0)}, + SourceType: v3listenerpb.FilterChainMatch_EXTERNAL, + }, + Filters: emptyValidNetworkFilters, + }, + { + FilterChainMatch: &v3listenerpb.FilterChainMatch{PrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("192.168.1.1", 16)}}, + Filters: emptyValidNetworkFilters, + }, + { + FilterChainMatch: &v3listenerpb.FilterChainMatch{PrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("10.0.0.0", 8)}}, + Filters: emptyValidNetworkFilters, + }, + }, + DefaultFilterChain: &v3listenerpb.FilterChain{Filters: emptyValidNetworkFilters}, + }, + wantFC: &FilterChainManager{ + dstPrefixMap: map[string]*destPrefixEntry{ + unspecifiedPrefixMapKey: { + srcTypeArr: [3]*sourcePrefixes{ + { + srcPrefixMap: map[string]*sourcePrefixEntry{ + unspecifiedPrefixMapKey: { + srcPortMap: map[int]*FilterChain{ + 0: { + InlineRouteConfig: inlineRouteConfig, + HTTPFilters: routerFilterList, + }, + }, + }, + }, + }, + }, + }, + "0.0.0.0/0": { + net: ipNetFromCIDR("0.0.0.0/0"), + srcTypeArr: [3]*sourcePrefixes{ + nil, + nil, + { + srcPrefixMap: map[string]*sourcePrefixEntry{ + unspecifiedPrefixMapKey: { + srcPortMap: map[int]*FilterChain{ + 0: { + InlineRouteConfig: inlineRouteConfig, + HTTPFilters: routerFilterList, + }, + }, + }, + }, + }, + }, + }, + "::/0": { + net: ipNetFromCIDR("::/0"), + srcTypeArr: [3]*sourcePrefixes{ + nil, + nil, + { + srcPrefixMap: map[string]*sourcePrefixEntry{ + unspecifiedPrefixMapKey: { + srcPortMap: map[int]*FilterChain{ + 0: { + InlineRouteConfig: inlineRouteConfig, + HTTPFilters: routerFilterList, + }, + }, + }, + }, + }, + }, 
+ }, + "192.168.0.0/16": { + net: ipNetFromCIDR("192.168.2.2/16"), + srcTypeArr: [3]*sourcePrefixes{ + { + srcPrefixMap: map[string]*sourcePrefixEntry{ + unspecifiedPrefixMapKey: { + srcPortMap: map[int]*FilterChain{ + 0: { + InlineRouteConfig: inlineRouteConfig, + HTTPFilters: routerFilterList, + }, + }, + }, + }, + }, + }, + }, + "10.0.0.0/8": { + net: ipNetFromCIDR("10.0.0.0/8"), + srcTypeArr: [3]*sourcePrefixes{ + { + srcPrefixMap: map[string]*sourcePrefixEntry{ + unspecifiedPrefixMapKey: { + srcPortMap: map[int]*FilterChain{ + 0: { + InlineRouteConfig: inlineRouteConfig, + HTTPFilters: routerFilterList, + }, + }, + }, + }, + }, + }, + }, + }, + def: &FilterChain{ + InlineRouteConfig: inlineRouteConfig, + HTTPFilters: routerFilterList, + }, + }, + }, + { + desc: "multiple source types", + lis: &v3listenerpb.Listener{ + FilterChains: []*v3listenerpb.FilterChain{ + { + FilterChainMatch: &v3listenerpb.FilterChainMatch{SourceType: v3listenerpb.FilterChainMatch_SAME_IP_OR_LOOPBACK}, + Filters: emptyValidNetworkFilters, + }, + { + FilterChainMatch: &v3listenerpb.FilterChainMatch{ + PrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("192.168.1.1", 16)}, + SourceType: v3listenerpb.FilterChainMatch_EXTERNAL, + }, + Filters: emptyValidNetworkFilters, + }, + }, + DefaultFilterChain: &v3listenerpb.FilterChain{Filters: emptyValidNetworkFilters}, + }, + wantFC: &FilterChainManager{ + dstPrefixMap: map[string]*destPrefixEntry{ + unspecifiedPrefixMapKey: { + srcTypeArr: [3]*sourcePrefixes{ + nil, + { + srcPrefixMap: map[string]*sourcePrefixEntry{ + unspecifiedPrefixMapKey: { + srcPortMap: map[int]*FilterChain{ + 0: { + InlineRouteConfig: inlineRouteConfig, + HTTPFilters: routerFilterList, + }, + }, + }, + }, + }, + }, + }, + "192.168.0.0/16": { + net: ipNetFromCIDR("192.168.2.2/16"), + srcTypeArr: [3]*sourcePrefixes{ + nil, + nil, + { + srcPrefixMap: map[string]*sourcePrefixEntry{ + unspecifiedPrefixMapKey: { + srcPortMap: map[int]*FilterChain{ + 0: { + 
InlineRouteConfig: inlineRouteConfig, + HTTPFilters: routerFilterList, + }, + }, + }, + }, + }, + }, + }, + }, + def: &FilterChain{ + InlineRouteConfig: inlineRouteConfig, + HTTPFilters: routerFilterList, + }, + }, + }, + { + desc: "multiple source prefixes", + lis: &v3listenerpb.Listener{ + FilterChains: []*v3listenerpb.FilterChain{ + { + FilterChainMatch: &v3listenerpb.FilterChainMatch{SourcePrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("10.0.0.0", 8)}}, + Filters: emptyValidNetworkFilters, + }, + { + FilterChainMatch: &v3listenerpb.FilterChainMatch{ + PrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("192.168.1.1", 16)}, + SourcePrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("192.168.1.1", 16)}, + }, + Filters: emptyValidNetworkFilters, + }, + }, + DefaultFilterChain: &v3listenerpb.FilterChain{Filters: emptyValidNetworkFilters}, + }, + wantFC: &FilterChainManager{ + dstPrefixMap: map[string]*destPrefixEntry{ + unspecifiedPrefixMapKey: { + srcTypeArr: [3]*sourcePrefixes{ + { + srcPrefixMap: map[string]*sourcePrefixEntry{ + "10.0.0.0/8": { + net: ipNetFromCIDR("10.0.0.0/8"), + srcPortMap: map[int]*FilterChain{ + 0: { + InlineRouteConfig: inlineRouteConfig, + HTTPFilters: routerFilterList, + }, + }, + }, + }, + }, + }, + }, + "192.168.0.0/16": { + net: ipNetFromCIDR("192.168.2.2/16"), + srcTypeArr: [3]*sourcePrefixes{ + { + srcPrefixMap: map[string]*sourcePrefixEntry{ + "192.168.0.0/16": { + net: ipNetFromCIDR("192.168.0.0/16"), + srcPortMap: map[int]*FilterChain{ + 0: { + InlineRouteConfig: inlineRouteConfig, + HTTPFilters: routerFilterList, + }, + }, + }, + }, + }, + }, + }, + }, + def: &FilterChain{ + InlineRouteConfig: inlineRouteConfig, + HTTPFilters: routerFilterList, + }, + }, + }, + { + desc: "multiple source ports", + lis: &v3listenerpb.Listener{ + FilterChains: []*v3listenerpb.FilterChain{ + { + FilterChainMatch: &v3listenerpb.FilterChainMatch{SourcePorts: []uint32{1, 2, 3}}, + Filters: 
emptyValidNetworkFilters, + }, + { + FilterChainMatch: &v3listenerpb.FilterChainMatch{ + PrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("192.168.1.1", 16)}, + SourcePrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("192.168.1.1", 16)}, + SourceType: v3listenerpb.FilterChainMatch_EXTERNAL, + SourcePorts: []uint32{1, 2, 3}, + }, + Filters: emptyValidNetworkFilters, + }, + }, + DefaultFilterChain: &v3listenerpb.FilterChain{Filters: emptyValidNetworkFilters}, + }, + wantFC: &FilterChainManager{ + dstPrefixMap: map[string]*destPrefixEntry{ + unspecifiedPrefixMapKey: { + srcTypeArr: [3]*sourcePrefixes{ + { + srcPrefixMap: map[string]*sourcePrefixEntry{ + unspecifiedPrefixMapKey: { + srcPortMap: map[int]*FilterChain{ + 1: { + InlineRouteConfig: inlineRouteConfig, + HTTPFilters: routerFilterList, + }, + 2: { + InlineRouteConfig: inlineRouteConfig, + HTTPFilters: routerFilterList, + }, + 3: { + InlineRouteConfig: inlineRouteConfig, + HTTPFilters: routerFilterList, + }, + }, + }, + }, + }, + }, + }, + "192.168.0.0/16": { + net: ipNetFromCIDR("192.168.2.2/16"), + srcTypeArr: [3]*sourcePrefixes{ + nil, + nil, + { + srcPrefixMap: map[string]*sourcePrefixEntry{ + "192.168.0.0/16": { + net: ipNetFromCIDR("192.168.0.0/16"), + srcPortMap: map[int]*FilterChain{ + 1: { + InlineRouteConfig: inlineRouteConfig, + HTTPFilters: routerFilterList, + }, + 2: { + InlineRouteConfig: inlineRouteConfig, + HTTPFilters: routerFilterList, + }, + 3: { + InlineRouteConfig: inlineRouteConfig, + HTTPFilters: routerFilterList, + }, + }, + }, + }, + }, + }, + }, + }, + def: &FilterChain{ + InlineRouteConfig: inlineRouteConfig, + HTTPFilters: routerFilterList, + }, + }, + }, + { + desc: "some chains have unsupported fields", + lis: &v3listenerpb.Listener{ + FilterChains: []*v3listenerpb.FilterChain{ + { + FilterChainMatch: &v3listenerpb.FilterChainMatch{}, + Filters: emptyValidNetworkFilters, + }, + { + FilterChainMatch: &v3listenerpb.FilterChainMatch{PrefixRanges: 
[]*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("192.168.1.1", 16)}}, + Filters: emptyValidNetworkFilters, + }, + { + FilterChainMatch: &v3listenerpb.FilterChainMatch{ + PrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("10.0.0.0", 8)}, + TransportProtocol: "raw_buffer", + }, + Filters: emptyValidNetworkFilters, + }, + { + // This chain will be dropped in favor of the above + // filter chain because they both have the same + // destination prefix, but this one has an empty + // transport protocol while the above chain has the more + // preferred "raw_buffer". + FilterChainMatch: &v3listenerpb.FilterChainMatch{ + PrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("10.0.0.0", 8)}, + TransportProtocol: "", + SourceType: v3listenerpb.FilterChainMatch_EXTERNAL, + SourcePrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("10.0.0.0", 16)}, + }, + Filters: emptyValidNetworkFilters, + }, + { + // This chain will be dropped for unsupported server + // names. + FilterChainMatch: &v3listenerpb.FilterChainMatch{ + PrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("192.168.100.1", 32)}, + ServerNames: []string{"foo", "bar"}, + }, + Filters: emptyValidNetworkFilters, + }, + { + // This chain will be dropped for unsupported transport + // protocol. + FilterChainMatch: &v3listenerpb.FilterChainMatch{ + PrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("192.168.100.2", 32)}, + TransportProtocol: "not-raw-buffer", + }, + Filters: emptyValidNetworkFilters, + }, + { + // This chain will be dropped for unsupported + // application protocol. 
+ FilterChainMatch: &v3listenerpb.FilterChainMatch{ + PrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("192.168.100.3", 32)}, + ApplicationProtocols: []string{"h2"}, + }, + Filters: emptyValidNetworkFilters, + }, + }, + DefaultFilterChain: &v3listenerpb.FilterChain{Filters: emptyValidNetworkFilters}, + }, + wantFC: &FilterChainManager{ + dstPrefixMap: map[string]*destPrefixEntry{ + unspecifiedPrefixMapKey: { + srcTypeArr: [3]*sourcePrefixes{ + { + srcPrefixMap: map[string]*sourcePrefixEntry{ + unspecifiedPrefixMapKey: { + srcPortMap: map[int]*FilterChain{ + 0: { + InlineRouteConfig: inlineRouteConfig, + HTTPFilters: routerFilterList, + }, + }, + }, + }, + }, + }, + }, + "192.168.0.0/16": { + net: ipNetFromCIDR("192.168.2.2/16"), + srcTypeArr: [3]*sourcePrefixes{ + { + srcPrefixMap: map[string]*sourcePrefixEntry{ + unspecifiedPrefixMapKey: { + srcPortMap: map[int]*FilterChain{ + 0: { + InlineRouteConfig: inlineRouteConfig, + HTTPFilters: routerFilterList, + }, + }, + }, + }, + }, + }, + }, + "10.0.0.0/8": { + net: ipNetFromCIDR("10.0.0.0/8"), + srcTypeArr: [3]*sourcePrefixes{ + { + srcPrefixMap: map[string]*sourcePrefixEntry{ + unspecifiedPrefixMapKey: { + srcPortMap: map[int]*FilterChain{ + 0: { + InlineRouteConfig: inlineRouteConfig, + HTTPFilters: routerFilterList, + }, + }, + }, + }, + }, + }, + }, + "192.168.100.1/32": { + net: ipNetFromCIDR("192.168.100.1/32"), + srcTypeArr: [3]*sourcePrefixes{}, + }, + "192.168.100.2/32": { + net: ipNetFromCIDR("192.168.100.2/32"), + srcTypeArr: [3]*sourcePrefixes{}, + }, + "192.168.100.3/32": { + net: ipNetFromCIDR("192.168.100.3/32"), + srcTypeArr: [3]*sourcePrefixes{}, + }, + }, + def: &FilterChain{ + InlineRouteConfig: inlineRouteConfig, + HTTPFilters: routerFilterList, + }, + }, + }, + } + + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + gotFC, err := NewFilterChainManager(test.lis) + if err != nil { + t.Fatalf("NewFilterChainManager() returned err: %v, wantErr: nil", err) + } 
+ if !cmp.Equal(gotFC, test.wantFC, cmp.AllowUnexported(FilterChainManager{}, destPrefixEntry{}, sourcePrefixes{}, sourcePrefixEntry{})) { + t.Fatalf("NewFilterChainManager() returned %+v, want: %+v", gotFC, test.wantFC) + } + }) + } +} + +func TestLookup_Failures(t *testing.T) { + tests := []struct { + desc string + lis *v3listenerpb.Listener + params FilterChainLookupParams + wantErr string + }{ + { + desc: "no destination prefix match", + lis: &v3listenerpb.Listener{ + FilterChains: []*v3listenerpb.FilterChain{ + { + FilterChainMatch: &v3listenerpb.FilterChainMatch{PrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("192.168.1.1", 16)}}, + Filters: emptyValidNetworkFilters, + }, + }, + }, + params: FilterChainLookupParams{ + IsUnspecifiedListener: true, + DestAddr: net.IPv4(10, 1, 1, 1), + }, + wantErr: "no matching filter chain based on destination prefix match", + }, + { + desc: "no source type match", + lis: &v3listenerpb.Listener{ + FilterChains: []*v3listenerpb.FilterChain{ + { + FilterChainMatch: &v3listenerpb.FilterChainMatch{ + PrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("192.168.1.1", 16)}, + SourceType: v3listenerpb.FilterChainMatch_SAME_IP_OR_LOOPBACK, + }, + Filters: emptyValidNetworkFilters, + }, + }, + }, + params: FilterChainLookupParams{ + IsUnspecifiedListener: true, + DestAddr: net.IPv4(192, 168, 100, 1), + SourceAddr: net.IPv4(192, 168, 100, 2), + }, + wantErr: "no matching filter chain based on source type match", + }, + { + desc: "no source prefix match", + lis: &v3listenerpb.Listener{ + FilterChains: []*v3listenerpb.FilterChain{ + { + FilterChainMatch: &v3listenerpb.FilterChainMatch{ + SourcePrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("192.168.1.1", 24)}, + SourceType: v3listenerpb.FilterChainMatch_SAME_IP_OR_LOOPBACK, + }, + Filters: emptyValidNetworkFilters, + }, + }, + }, + params: FilterChainLookupParams{ + IsUnspecifiedListener: true, + DestAddr: net.IPv4(192, 168, 100, 
1), + SourceAddr: net.IPv4(192, 168, 100, 1), + }, + wantErr: "no matching filter chain after all match criteria", + }, + { + desc: "multiple matching filter chains", + lis: &v3listenerpb.Listener{ + FilterChains: []*v3listenerpb.FilterChain{ + { + FilterChainMatch: &v3listenerpb.FilterChainMatch{SourcePorts: []uint32{1, 2, 3}}, + Filters: emptyValidNetworkFilters, + }, + { + FilterChainMatch: &v3listenerpb.FilterChainMatch{ + PrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("192.168.1.1", 16)}, + SourcePorts: []uint32{1}, + }, + Filters: emptyValidNetworkFilters, + }, + }, + }, + params: FilterChainLookupParams{ + // IsUnspecified is not set. This means that the destination + // prefix matchers will be ignored. + DestAddr: net.IPv4(192, 168, 100, 1), + SourceAddr: net.IPv4(192, 168, 100, 1), + SourcePort: 1, + }, + wantErr: "multiple matching filter chains", + }, + { + desc: "no default filter chain", + lis: &v3listenerpb.Listener{ + FilterChains: []*v3listenerpb.FilterChain{ + { + FilterChainMatch: &v3listenerpb.FilterChainMatch{SourcePorts: []uint32{1, 2, 3}}, + Filters: emptyValidNetworkFilters, + }, + }, + }, + params: FilterChainLookupParams{ + IsUnspecifiedListener: true, + DestAddr: net.IPv4(192, 168, 100, 1), + SourceAddr: net.IPv4(192, 168, 100, 1), + SourcePort: 80, + }, + wantErr: "no matching filter chain after all match criteria", + }, + { + desc: "most specific match dropped for unsupported field", + lis: &v3listenerpb.Listener{ + FilterChains: []*v3listenerpb.FilterChain{ + { + // This chain will be picked in the destination prefix + // stage, but will be dropped at the server names stage. 
+ FilterChainMatch: &v3listenerpb.FilterChainMatch{ + PrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("192.168.100.1", 32)}, + ServerNames: []string{"foo"}, + }, + Filters: emptyValidNetworkFilters, + }, + { + FilterChainMatch: &v3listenerpb.FilterChainMatch{ + PrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("192.168.100.0", 16)}, + }, + Filters: emptyValidNetworkFilters, + }, + }, + }, + params: FilterChainLookupParams{ + IsUnspecifiedListener: true, + DestAddr: net.IPv4(192, 168, 100, 1), + SourceAddr: net.IPv4(192, 168, 100, 1), + SourcePort: 80, + }, + wantErr: "no matching filter chain based on source type match", + }, + } + + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + fci, err := NewFilterChainManager(test.lis) + if err != nil { + t.Fatalf("NewFilterChainManager() failed: %v", err) + } + fc, err := fci.Lookup(test.params) + if err == nil || !strings.Contains(err.Error(), test.wantErr) { + t.Fatalf("FilterChainManager.Lookup(%v) = (%v, %v) want (nil, %s)", test.params, fc, err, test.wantErr) + } + }) + } +} + +func TestLookup_Successes(t *testing.T) { + oldRBAC := env.RBACSupport + env.RBACSupport = true + defer func() { + env.RBACSupport = oldRBAC + }() + lisWithDefaultChain := &v3listenerpb.Listener{ + FilterChains: []*v3listenerpb.FilterChain{ + { + FilterChainMatch: &v3listenerpb.FilterChainMatch{PrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("192.168.1.1", 16)}}, + TransportSocket: &v3corepb.TransportSocket{ + Name: "envoy.transport_sockets.tls", + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3tlspb.DownstreamTlsContext{ + CommonTlsContext: &v3tlspb.CommonTlsContext{ + TlsCertificateCertificateProviderInstance: &v3tlspb.CommonTlsContext_CertificateProviderInstance{InstanceName: "instance1"}, + }, + }), + }, + }, + Filters: emptyValidNetworkFilters, + }, + }, + // A default filter chain with an empty transport socket. 
+ DefaultFilterChain: &v3listenerpb.FilterChain{ + TransportSocket: &v3corepb.TransportSocket{ + Name: "envoy.transport_sockets.tls", + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3tlspb.DownstreamTlsContext{ + CommonTlsContext: &v3tlspb.CommonTlsContext{ + TlsCertificateCertificateProviderInstance: &v3tlspb.CommonTlsContext_CertificateProviderInstance{InstanceName: "default"}, + }, + }), + }, + }, + Filters: emptyValidNetworkFilters, + }, + } + lisWithoutDefaultChain := &v3listenerpb.Listener{ + FilterChains: []*v3listenerpb.FilterChain{ + { + TransportSocket: transportSocketWithInstanceName("unspecified-dest-and-source-prefix"), + Filters: emptyValidNetworkFilters, + }, + { + FilterChainMatch: &v3listenerpb.FilterChainMatch{ + PrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("0.0.0.0", 0)}, + SourcePrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("0.0.0.0", 0)}, + }, + TransportSocket: transportSocketWithInstanceName("wildcard-prefixes-v4"), + Filters: emptyValidNetworkFilters, + }, + { + FilterChainMatch: &v3listenerpb.FilterChainMatch{ + SourcePrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("::", 0)}, + }, + TransportSocket: transportSocketWithInstanceName("wildcard-source-prefix-v6"), + Filters: emptyValidNetworkFilters, + }, + { + FilterChainMatch: &v3listenerpb.FilterChainMatch{PrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("192.168.1.1", 16)}}, + TransportSocket: transportSocketWithInstanceName("specific-destination-prefix-unspecified-source-type"), + Filters: emptyValidNetworkFilters, + }, + { + FilterChainMatch: &v3listenerpb.FilterChainMatch{ + PrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("192.168.1.1", 24)}, + SourceType: v3listenerpb.FilterChainMatch_EXTERNAL, + }, + TransportSocket: transportSocketWithInstanceName("specific-destination-prefix-specific-source-type"), + Filters: emptyValidNetworkFilters, 
+ }, + { + FilterChainMatch: &v3listenerpb.FilterChainMatch{ + PrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("192.168.1.1", 24)}, + SourcePrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("192.168.92.1", 24)}, + SourceType: v3listenerpb.FilterChainMatch_EXTERNAL, + }, + TransportSocket: transportSocketWithInstanceName("specific-destination-prefix-specific-source-type-specific-source-prefix"), + Filters: emptyValidNetworkFilters, + }, + { + FilterChainMatch: &v3listenerpb.FilterChainMatch{ + PrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("192.168.1.1", 24)}, + SourcePrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("192.168.92.1", 24)}, + SourceType: v3listenerpb.FilterChainMatch_EXTERNAL, + SourcePorts: []uint32{80}, + }, + TransportSocket: transportSocketWithInstanceName("specific-destination-prefix-specific-source-type-specific-source-prefix-specific-source-port"), + Filters: emptyValidNetworkFilters, + }, + }, + } + + tests := []struct { + desc string + lis *v3listenerpb.Listener + params FilterChainLookupParams + wantFC *FilterChain + }{ + { + desc: "default filter chain", + lis: lisWithDefaultChain, + params: FilterChainLookupParams{ + IsUnspecifiedListener: true, + DestAddr: net.IPv4(10, 1, 1, 1), + }, + wantFC: &FilterChain{ + SecurityCfg: &SecurityConfig{IdentityInstanceName: "default"}, + InlineRouteConfig: inlineRouteConfig, + HTTPFilters: routerFilterList, + }, + }, + { + desc: "unspecified destination match", + lis: lisWithoutDefaultChain, + params: FilterChainLookupParams{ + IsUnspecifiedListener: true, + DestAddr: net.ParseIP("2001:68::db8"), + SourceAddr: net.IPv4(10, 1, 1, 1), + SourcePort: 1, + }, + wantFC: &FilterChain{ + SecurityCfg: &SecurityConfig{IdentityInstanceName: "unspecified-dest-and-source-prefix"}, + InlineRouteConfig: inlineRouteConfig, + HTTPFilters: routerFilterList, + }, + }, + { + desc: "wildcard destination match v4", + lis: lisWithoutDefaultChain, 
+ params: FilterChainLookupParams{ + IsUnspecifiedListener: true, + DestAddr: net.IPv4(10, 1, 1, 1), + SourceAddr: net.IPv4(10, 1, 1, 1), + SourcePort: 1, + }, + wantFC: &FilterChain{ + SecurityCfg: &SecurityConfig{IdentityInstanceName: "wildcard-prefixes-v4"}, + InlineRouteConfig: inlineRouteConfig, + HTTPFilters: routerFilterList, + }, + }, + { + desc: "wildcard source match v6", + lis: lisWithoutDefaultChain, + params: FilterChainLookupParams{ + IsUnspecifiedListener: true, + DestAddr: net.ParseIP("2001:68::1"), + SourceAddr: net.ParseIP("2001:68::2"), + SourcePort: 1, + }, + wantFC: &FilterChain{ + SecurityCfg: &SecurityConfig{IdentityInstanceName: "wildcard-source-prefix-v6"}, + InlineRouteConfig: inlineRouteConfig, + HTTPFilters: routerFilterList, + }, + }, + { + desc: "specific destination and wildcard source type match", + lis: lisWithoutDefaultChain, + params: FilterChainLookupParams{ + IsUnspecifiedListener: true, + DestAddr: net.IPv4(192, 168, 100, 1), + SourceAddr: net.IPv4(192, 168, 100, 1), + SourcePort: 80, + }, + wantFC: &FilterChain{ + SecurityCfg: &SecurityConfig{IdentityInstanceName: "specific-destination-prefix-unspecified-source-type"}, + InlineRouteConfig: inlineRouteConfig, + HTTPFilters: routerFilterList, + }, + }, + { + desc: "specific destination and source type match", + lis: lisWithoutDefaultChain, + params: FilterChainLookupParams{ + IsUnspecifiedListener: true, + DestAddr: net.IPv4(192, 168, 1, 1), + SourceAddr: net.IPv4(10, 1, 1, 1), + SourcePort: 80, + }, + wantFC: &FilterChain{ + SecurityCfg: &SecurityConfig{IdentityInstanceName: "specific-destination-prefix-specific-source-type"}, + InlineRouteConfig: inlineRouteConfig, + HTTPFilters: routerFilterList, + }, + }, + { + desc: "specific destination source type and source prefix", + lis: lisWithoutDefaultChain, + params: FilterChainLookupParams{ + IsUnspecifiedListener: true, + DestAddr: net.IPv4(192, 168, 1, 1), + SourceAddr: net.IPv4(192, 168, 92, 100), + SourcePort: 70, + }, + 
wantFC: &FilterChain{ + SecurityCfg: &SecurityConfig{IdentityInstanceName: "specific-destination-prefix-specific-source-type-specific-source-prefix"}, + InlineRouteConfig: inlineRouteConfig, + HTTPFilters: routerFilterList, + }, + }, + { + desc: "specific destination source type source prefix and source port", + lis: lisWithoutDefaultChain, + params: FilterChainLookupParams{ + IsUnspecifiedListener: true, + DestAddr: net.IPv4(192, 168, 1, 1), + SourceAddr: net.IPv4(192, 168, 92, 100), + SourcePort: 80, + }, + wantFC: &FilterChain{ + SecurityCfg: &SecurityConfig{IdentityInstanceName: "specific-destination-prefix-specific-source-type-specific-source-prefix-specific-source-port"}, + InlineRouteConfig: inlineRouteConfig, + HTTPFilters: routerFilterList, + }, + }, + } + + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + fci, err := NewFilterChainManager(test.lis) + if err != nil { + t.Fatalf("NewFilterChainManager() failed: %v", err) + } + gotFC, err := fci.Lookup(test.params) + if err != nil { + t.Fatalf("FilterChainManager.Lookup(%v) failed: %v", test.params, err) + } + if !cmp.Equal(gotFC, test.wantFC, cmpopts.EquateEmpty()) { + t.Fatalf("FilterChainManager.Lookup(%v) = %v, want %v", test.params, gotFC, test.wantFC) + } + }) + } +} + +type filterCfg struct { + httpfilter.FilterConfig + // Level is what differentiates top level filters ("top level") vs. second + // level ("virtual host level"), and third level ("route level"). 
+ level string +} + +type filterBuilder struct { + httpfilter.Filter +} + +var _ httpfilter.ServerInterceptorBuilder = &filterBuilder{} + +func (fb *filterBuilder) BuildServerInterceptor(config httpfilter.FilterConfig, override httpfilter.FilterConfig) (iresolver.ServerInterceptor, error) { + var level string + level = config.(filterCfg).level + + if override != nil { + level = override.(filterCfg).level + } + return &serverInterceptor{level: level}, nil +} + +type serverInterceptor struct { + level string +} + +func (si *serverInterceptor) AllowRPC(context.Context) error { + return errors.New(si.level) +} + +func TestHTTPFilterInstantiation(t *testing.T) { + tests := []struct { + name string + filters []HTTPFilter + routeConfig RouteConfigUpdate + // A list of strings which will be built from iterating through the + // filters ["top level", "vh level", "route level", "route level"...] + // wantErrs is the list of error strings that will be constructed from + // the deterministic iteration through the vh list and route list. The + // error string will be determined by the level of config that the + // filter builder receives (i.e. top level, vs. virtual host level vs. + // route level). 
+ wantErrs []string + }{ + { + name: "one http filter no overrides", + filters: []HTTPFilter{ + {Name: "server-interceptor", Filter: &filterBuilder{}, Config: filterCfg{level: topLevel}}, + }, + routeConfig: RouteConfigUpdate{ + VirtualHosts: []*VirtualHost{ + { + Domains: []string{"target"}, + Routes: []*Route{{ + Prefix: newStringP("1"), + }, + }, + }, + }}, + wantErrs: []string{topLevel}, + }, + { + name: "one http filter vh override", + filters: []HTTPFilter{ + {Name: "server-interceptor", Filter: &filterBuilder{}, Config: filterCfg{level: topLevel}}, + }, + routeConfig: RouteConfigUpdate{ + VirtualHosts: []*VirtualHost{ + { + Domains: []string{"target"}, + Routes: []*Route{{ + Prefix: newStringP("1"), + }, + }, + HTTPFilterConfigOverride: map[string]httpfilter.FilterConfig{ + "server-interceptor": filterCfg{level: vhLevel}, + }, + }, + }}, + wantErrs: []string{vhLevel}, + }, + { + name: "one http filter route override", + filters: []HTTPFilter{ + {Name: "server-interceptor", Filter: &filterBuilder{}, Config: filterCfg{level: topLevel}}, + }, + routeConfig: RouteConfigUpdate{ + VirtualHosts: []*VirtualHost{ + { + Domains: []string{"target"}, + Routes: []*Route{{ + Prefix: newStringP("1"), + HTTPFilterConfigOverride: map[string]httpfilter.FilterConfig{ + "server-interceptor": filterCfg{level: rLevel}, + }, + }, + }, + }, + }}, + wantErrs: []string{rLevel}, + }, + // This tests the scenario where there are three http filters, and one + // gets overridden by route and one by virtual host. 
+ { + name: "three http filters vh override route override", + filters: []HTTPFilter{ + {Name: "server-interceptor1", Filter: &filterBuilder{}, Config: filterCfg{level: topLevel}}, + {Name: "server-interceptor2", Filter: &filterBuilder{}, Config: filterCfg{level: topLevel}}, + {Name: "server-interceptor3", Filter: &filterBuilder{}, Config: filterCfg{level: topLevel}}, + }, + routeConfig: RouteConfigUpdate{ + VirtualHosts: []*VirtualHost{ + { + Domains: []string{"target"}, + Routes: []*Route{{ + Prefix: newStringP("1"), + HTTPFilterConfigOverride: map[string]httpfilter.FilterConfig{ + "server-interceptor3": filterCfg{level: rLevel}, + }, + }, + }, + HTTPFilterConfigOverride: map[string]httpfilter.FilterConfig{ + "server-interceptor2": filterCfg{level: vhLevel}, + }, + }, + }}, + wantErrs: []string{topLevel, vhLevel, rLevel}, + }, + // This tests the scenario where there are three http filters, and two + // virtual hosts with different vh + route overrides for each virtual + // host. + { + name: "three http filters two vh", + filters: []HTTPFilter{ + {Name: "server-interceptor1", Filter: &filterBuilder{}, Config: filterCfg{level: topLevel}}, + {Name: "server-interceptor2", Filter: &filterBuilder{}, Config: filterCfg{level: topLevel}}, + {Name: "server-interceptor3", Filter: &filterBuilder{}, Config: filterCfg{level: topLevel}}, + }, + routeConfig: RouteConfigUpdate{ + VirtualHosts: []*VirtualHost{ + { + Domains: []string{"target"}, + Routes: []*Route{{ + Prefix: newStringP("1"), + HTTPFilterConfigOverride: map[string]httpfilter.FilterConfig{ + "server-interceptor3": filterCfg{level: rLevel}, + }, + }, + }, + HTTPFilterConfigOverride: map[string]httpfilter.FilterConfig{ + "server-interceptor2": filterCfg{level: vhLevel}, + }, + }, + { + Domains: []string{"target"}, + Routes: []*Route{{ + Prefix: newStringP("1"), + HTTPFilterConfigOverride: map[string]httpfilter.FilterConfig{ + "server-interceptor1": filterCfg{level: rLevel}, + "server-interceptor2": filterCfg{level: 
rLevel}, + }, + }, + }, + HTTPFilterConfigOverride: map[string]httpfilter.FilterConfig{ + "server-interceptor2": filterCfg{level: vhLevel}, + "server-interceptor3": filterCfg{level: vhLevel}, + }, + }, + }}, + wantErrs: []string{topLevel, vhLevel, rLevel, rLevel, rLevel, vhLevel}, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + fc := FilterChain{ + HTTPFilters: test.filters, + } + vhswi, err := fc.ConstructUsableRouteConfiguration(test.routeConfig) + if err != nil { + t.Fatalf("Error constructing usable route configuration: %v", err) + } + // Build out list of errors by iterating through the virtual hosts and routes, + // and running the filters in route configurations. + var errs []string + for _, vh := range vhswi { + for _, r := range vh.Routes { + for _, int := range r.Interceptors { + errs = append(errs, int.AllowRPC(context.Background()).Error()) + } + } + } + if !cmp.Equal(errs, test.wantErrs) { + t.Fatalf("List of errors %v, want %v", errs, test.wantErrs) + } + }) + } +} + +// The Equal() methods defined below help with using cmp.Equal() on these types +// which contain all unexported fields. + +func (fci *FilterChainManager) Equal(other *FilterChainManager) bool { + if (fci == nil) != (other == nil) { + return false + } + if fci == nil { + return true + } + switch { + case !cmp.Equal(fci.dstPrefixMap, other.dstPrefixMap, cmpopts.EquateEmpty()): + return false + // TODO: Support comparing dstPrefixes slice? 
+ case !cmp.Equal(fci.def, other.def, cmpopts.EquateEmpty(), protocmp.Transform()): + return false + case !cmp.Equal(fci.RouteConfigNames, other.RouteConfigNames, cmpopts.EquateEmpty()): + return false + } + return true +} + +func (dpe *destPrefixEntry) Equal(other *destPrefixEntry) bool { + if (dpe == nil) != (other == nil) { + return false + } + if dpe == nil { + return true + } + if !cmp.Equal(dpe.net, other.net) { + return false + } + for i, st := range dpe.srcTypeArr { + if !cmp.Equal(st, other.srcTypeArr[i], cmpopts.EquateEmpty()) { + return false + } + } + return true +} + +func (sp *sourcePrefixes) Equal(other *sourcePrefixes) bool { + if (sp == nil) != (other == nil) { + return false + } + if sp == nil { + return true + } + // TODO: Support comparing srcPrefixes slice? + return cmp.Equal(sp.srcPrefixMap, other.srcPrefixMap, cmpopts.EquateEmpty()) +} + +func (spe *sourcePrefixEntry) Equal(other *sourcePrefixEntry) bool { + if (spe == nil) != (other == nil) { + return false + } + if spe == nil { + return true + } + switch { + case !cmp.Equal(spe.net, other.net): + return false + case !cmp.Equal(spe.srcPortMap, other.srcPortMap, cmpopts.EquateEmpty(), protocmp.Transform()): + return false + } + return true +} + +// The String() methods defined below help with debugging test failures as the +// regular %v or %+v formatting directives do not expands pointer fields inside +// structs, and these types have a lot of pointers pointing to other structs. 
+func (fci *FilterChainManager) String() string { + if fci == nil { + return "" + } + + var sb strings.Builder + if fci.dstPrefixMap != nil { + sb.WriteString("destination_prefix_map: map {\n") + for k, v := range fci.dstPrefixMap { + sb.WriteString(fmt.Sprintf("%q: %v\n", k, v)) + } + sb.WriteString("}\n") + } + if fci.dstPrefixes != nil { + sb.WriteString("destination_prefixes: [") + for _, p := range fci.dstPrefixes { + sb.WriteString(fmt.Sprintf("%v ", p)) + } + sb.WriteString("]") + } + if fci.def != nil { + sb.WriteString(fmt.Sprintf("default_filter_chain: %+v ", fci.def)) + } + return sb.String() +} + +func (dpe *destPrefixEntry) String() string { + if dpe == nil { + return "" + } + var sb strings.Builder + if dpe.net != nil { + sb.WriteString(fmt.Sprintf("destination_prefix: %s ", dpe.net.String())) + } + sb.WriteString("source_types_array: [") + for _, st := range dpe.srcTypeArr { + sb.WriteString(fmt.Sprintf("%v ", st)) + } + sb.WriteString("]") + return sb.String() +} + +func (sp *sourcePrefixes) String() string { + if sp == nil { + return "" + } + var sb strings.Builder + if sp.srcPrefixMap != nil { + sb.WriteString("source_prefix_map: map {") + for k, v := range sp.srcPrefixMap { + sb.WriteString(fmt.Sprintf("%q: %v ", k, v)) + } + sb.WriteString("}") + } + if sp.srcPrefixes != nil { + sb.WriteString("source_prefixes: [") + for _, p := range sp.srcPrefixes { + sb.WriteString(fmt.Sprintf("%v ", p)) + } + sb.WriteString("]") + } + return sb.String() +} + +func (spe *sourcePrefixEntry) String() string { + if spe == nil { + return "" + } + var sb strings.Builder + if spe.net != nil { + sb.WriteString(fmt.Sprintf("source_prefix: %s ", spe.net.String())) + } + if spe.srcPortMap != nil { + sb.WriteString("source_ports_map: map {") + for k, v := range spe.srcPortMap { + sb.WriteString(fmt.Sprintf("%d: %+v ", k, v)) + } + sb.WriteString("}") + } + return sb.String() +} + +func (f *FilterChain) String() string { + if f == nil || f.SecurityCfg == nil { + return 
"" + } + return fmt.Sprintf("security_config: %v", f.SecurityCfg) +} + +func ipNetFromCIDR(cidr string) *net.IPNet { + _, ipnet, err := net.ParseCIDR(cidr) + if err != nil { + panic(err) + } + return ipnet +} + +func transportSocketWithInstanceName(name string) *v3corepb.TransportSocket { + return &v3corepb.TransportSocket{ + Name: "envoy.transport_sockets.tls", + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3tlspb.DownstreamTlsContext{ + CommonTlsContext: &v3tlspb.CommonTlsContext{ + TlsCertificateCertificateProviderInstance: &v3tlspb.CommonTlsContext_CertificateProviderInstance{InstanceName: name}, + }, + }), + }, + } +} + +func cidrRangeFromAddressAndPrefixLen(address string, len int) *v3corepb.CidrRange { + return &v3corepb.CidrRange{ + AddressPrefix: address, + PrefixLen: &wrapperspb.UInt32Value{ + Value: uint32(len), + }, + } +} diff --git a/xds/internal/xdsclient/lds_test.go b/xds/internal/xdsclient/lds_test.go new file mode 100644 index 00000000000..18e2f55ede4 --- /dev/null +++ b/xds/internal/xdsclient/lds_test.go @@ -0,0 +1,1944 @@ +/* + * + * Copyright 2020 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +package xdsclient + +import ( + "fmt" + "strings" + "testing" + "time" + + v1typepb "github.com/cncf/udpa/go/udpa/type/v1" + v3routepb "github.com/envoyproxy/go-control-plane/envoy/config/route/v3" + "github.com/golang/protobuf/proto" + spb "github.com/golang/protobuf/ptypes/struct" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "google.golang.org/protobuf/types/known/durationpb" + + "google.golang.org/grpc/internal/testutils" + "google.golang.org/grpc/internal/xds/env" + "google.golang.org/grpc/xds/internal/httpfilter" + _ "google.golang.org/grpc/xds/internal/httpfilter/router" + "google.golang.org/grpc/xds/internal/testutils/e2e" + "google.golang.org/grpc/xds/internal/version" + + v2xdspb "github.com/envoyproxy/go-control-plane/envoy/api/v2" + v2corepb "github.com/envoyproxy/go-control-plane/envoy/api/v2/core" + v3corepb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" + v2httppb "github.com/envoyproxy/go-control-plane/envoy/config/filter/network/http_connection_manager/v2" + v2listenerpb "github.com/envoyproxy/go-control-plane/envoy/config/listener/v2" + v3listenerpb "github.com/envoyproxy/go-control-plane/envoy/config/listener/v3" + v3httppb "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/network/http_connection_manager/v3" + v3tlspb "github.com/envoyproxy/go-control-plane/envoy/extensions/transport_sockets/tls/v3" + anypb "github.com/golang/protobuf/ptypes/any" + wrapperspb "github.com/golang/protobuf/ptypes/wrappers" +) + +func (s) TestUnmarshalListener_ClientSide(t *testing.T) { + const ( + v2LDSTarget = "lds.target.good:2222" + v3LDSTarget = "lds.target.good:3333" + v2RouteConfigName = "v2RouteConfig" + v3RouteConfigName = "v3RouteConfig" + routeName = "routeName" + testVersion = "test-version-lds-client" + ) + + var ( + v2Lis = testutils.MarshalAny(&v2xdspb.Listener{ + Name: v2LDSTarget, + ApiListener: &v2listenerpb.ApiListener{ + ApiListener: 
testutils.MarshalAny(&v2httppb.HttpConnectionManager{ + RouteSpecifier: &v2httppb.HttpConnectionManager_Rds{ + Rds: &v2httppb.Rds{ + ConfigSource: &v2corepb.ConfigSource{ + ConfigSourceSpecifier: &v2corepb.ConfigSource_Ads{Ads: &v2corepb.AggregatedConfigSource{}}, + }, + RouteConfigName: v2RouteConfigName, + }, + }, + }), + }, + }) + customFilter = &v3httppb.HttpFilter{ + Name: "customFilter", + ConfigType: &v3httppb.HttpFilter_TypedConfig{TypedConfig: customFilterConfig}, + } + typedStructFilter = &v3httppb.HttpFilter{ + Name: "customFilter", + ConfigType: &v3httppb.HttpFilter_TypedConfig{TypedConfig: wrappedCustomFilterTypedStructConfig}, + } + customOptionalFilter = &v3httppb.HttpFilter{ + Name: "customFilter", + ConfigType: &v3httppb.HttpFilter_TypedConfig{TypedConfig: customFilterConfig}, + IsOptional: true, + } + customFilter2 = &v3httppb.HttpFilter{ + Name: "customFilter2", + ConfigType: &v3httppb.HttpFilter_TypedConfig{TypedConfig: customFilterConfig}, + } + errFilter = &v3httppb.HttpFilter{ + Name: "errFilter", + ConfigType: &v3httppb.HttpFilter_TypedConfig{TypedConfig: errFilterConfig}, + } + errOptionalFilter = &v3httppb.HttpFilter{ + Name: "errFilter", + ConfigType: &v3httppb.HttpFilter_TypedConfig{TypedConfig: errFilterConfig}, + IsOptional: true, + } + clientOnlyCustomFilter = &v3httppb.HttpFilter{ + Name: "clientOnlyCustomFilter", + ConfigType: &v3httppb.HttpFilter_TypedConfig{TypedConfig: clientOnlyCustomFilterConfig}, + } + serverOnlyCustomFilter = &v3httppb.HttpFilter{ + Name: "serverOnlyCustomFilter", + ConfigType: &v3httppb.HttpFilter_TypedConfig{TypedConfig: serverOnlyCustomFilterConfig}, + } + serverOnlyOptionalCustomFilter = &v3httppb.HttpFilter{ + Name: "serverOnlyOptionalCustomFilter", + ConfigType: &v3httppb.HttpFilter_TypedConfig{TypedConfig: serverOnlyCustomFilterConfig}, + IsOptional: true, + } + unknownFilter = &v3httppb.HttpFilter{ + Name: "unknownFilter", + ConfigType: &v3httppb.HttpFilter_TypedConfig{TypedConfig: 
unknownFilterConfig}, + } + unknownOptionalFilter = &v3httppb.HttpFilter{ + Name: "unknownFilter", + ConfigType: &v3httppb.HttpFilter_TypedConfig{TypedConfig: unknownFilterConfig}, + IsOptional: true, + } + v3LisWithInlineRoute = testutils.MarshalAny(&v3listenerpb.Listener{ + Name: v3LDSTarget, + ApiListener: &v3listenerpb.ApiListener{ + ApiListener: testutils.MarshalAny(&v3httppb.HttpConnectionManager{ + RouteSpecifier: &v3httppb.HttpConnectionManager_RouteConfig{ + RouteConfig: &v3routepb.RouteConfiguration{ + Name: routeName, + VirtualHosts: []*v3routepb.VirtualHost{{ + Domains: []string{v3LDSTarget}, + Routes: []*v3routepb.Route{{ + Match: &v3routepb.RouteMatch{ + PathSpecifier: &v3routepb.RouteMatch_Prefix{Prefix: "/"}, + }, + Action: &v3routepb.Route_Route{ + Route: &v3routepb.RouteAction{ + ClusterSpecifier: &v3routepb.RouteAction_Cluster{Cluster: clusterName}, + }}}}}}}, + }, + HttpFilters: []*v3httppb.HttpFilter{emptyRouterFilter}, + CommonHttpProtocolOptions: &v3corepb.HttpProtocolOptions{ + MaxStreamDuration: durationpb.New(time.Second), + }, + }), + }, + }) + v3LisWithFilters = func(fs ...*v3httppb.HttpFilter) *anypb.Any { + fs = append(fs, emptyRouterFilter) + return testutils.MarshalAny(&v3listenerpb.Listener{ + Name: v3LDSTarget, + ApiListener: &v3listenerpb.ApiListener{ + ApiListener: testutils.MarshalAny( + &v3httppb.HttpConnectionManager{ + RouteSpecifier: &v3httppb.HttpConnectionManager_Rds{ + Rds: &v3httppb.Rds{ + ConfigSource: &v3corepb.ConfigSource{ + ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{Ads: &v3corepb.AggregatedConfigSource{}}, + }, + RouteConfigName: v3RouteConfigName, + }, + }, + CommonHttpProtocolOptions: &v3corepb.HttpProtocolOptions{ + MaxStreamDuration: durationpb.New(time.Second), + }, + HttpFilters: fs, + }), + }, + }) + } + v3LisToTestRBAC = func(xffNumTrustedHops uint32, originalIpDetectionExtensions []*v3corepb.TypedExtensionConfig) *anypb.Any { + return testutils.MarshalAny(&v3listenerpb.Listener{ + Name: v3LDSTarget, 
+ ApiListener: &v3listenerpb.ApiListener{ + ApiListener: testutils.MarshalAny( + &v3httppb.HttpConnectionManager{ + RouteSpecifier: &v3httppb.HttpConnectionManager_Rds{ + Rds: &v3httppb.Rds{ + ConfigSource: &v3corepb.ConfigSource{ + ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{Ads: &v3corepb.AggregatedConfigSource{}}, + }, + RouteConfigName: v3RouteConfigName, + }, + }, + CommonHttpProtocolOptions: &v3corepb.HttpProtocolOptions{ + MaxStreamDuration: durationpb.New(time.Second), + }, + HttpFilters: []*v3httppb.HttpFilter{emptyRouterFilter}, + XffNumTrustedHops: xffNumTrustedHops, + OriginalIpDetectionExtensions: originalIpDetectionExtensions, + }), + }, + }) + } + errMD = UpdateMetadata{ + Status: ServiceStatusNACKed, + Version: testVersion, + ErrState: &UpdateErrorMetadata{ + Version: testVersion, + Err: cmpopts.AnyError, + }, + } + ) + + tests := []struct { + name string + resources []*anypb.Any + wantUpdate map[string]ListenerUpdateErrTuple + wantMD UpdateMetadata + wantErr bool + }{ + { + name: "non-listener resource", + resources: []*anypb.Any{{TypeUrl: version.V3HTTPConnManagerURL}}, + wantMD: errMD, + wantErr: true, + }, + { + name: "badly marshaled listener resource", + resources: []*anypb.Any{ + { + TypeUrl: version.V3ListenerURL, + Value: func() []byte { + lis := &v3listenerpb.Listener{ + Name: v3LDSTarget, + ApiListener: &v3listenerpb.ApiListener{ + ApiListener: &anypb.Any{ + TypeUrl: version.V3HTTPConnManagerURL, + Value: []byte{1, 2, 3, 4}, + }, + }, + } + mLis, _ := proto.Marshal(lis) + return mLis + }(), + }, + }, + wantUpdate: map[string]ListenerUpdateErrTuple{v3LDSTarget: {Err: cmpopts.AnyError}}, + wantMD: errMD, + wantErr: true, + }, + { + name: "wrong type in apiListener", + resources: []*anypb.Any{testutils.MarshalAny(&v3listenerpb.Listener{ + Name: v3LDSTarget, + ApiListener: &v3listenerpb.ApiListener{ + ApiListener: testutils.MarshalAny(&v2xdspb.Listener{}), + }, + })}, + wantUpdate: map[string]ListenerUpdateErrTuple{v3LDSTarget: {Err: 
cmpopts.AnyError}}, + wantMD: errMD, + wantErr: true, + }, + { + name: "empty httpConnMgr in apiListener", + resources: []*anypb.Any{testutils.MarshalAny(&v3listenerpb.Listener{ + Name: v3LDSTarget, + ApiListener: &v3listenerpb.ApiListener{ + ApiListener: testutils.MarshalAny(&v3httppb.HttpConnectionManager{ + RouteSpecifier: &v3httppb.HttpConnectionManager_Rds{ + Rds: &v3httppb.Rds{}, + }, + }), + }, + })}, + wantUpdate: map[string]ListenerUpdateErrTuple{v3LDSTarget: {Err: cmpopts.AnyError}}, + wantMD: errMD, + wantErr: true, + }, + { + name: "scopedRoutes routeConfig in apiListener", + resources: []*anypb.Any{testutils.MarshalAny(&v3listenerpb.Listener{ + Name: v3LDSTarget, + ApiListener: &v3listenerpb.ApiListener{ + ApiListener: testutils.MarshalAny(&v3httppb.HttpConnectionManager{ + RouteSpecifier: &v3httppb.HttpConnectionManager_ScopedRoutes{}, + }), + }, + })}, + wantUpdate: map[string]ListenerUpdateErrTuple{v3LDSTarget: {Err: cmpopts.AnyError}}, + wantMD: errMD, + wantErr: true, + }, + { + name: "rds.ConfigSource in apiListener is not ADS", + resources: []*anypb.Any{testutils.MarshalAny(&v3listenerpb.Listener{ + Name: v3LDSTarget, + ApiListener: &v3listenerpb.ApiListener{ + ApiListener: testutils.MarshalAny(&v3httppb.HttpConnectionManager{ + RouteSpecifier: &v3httppb.HttpConnectionManager_Rds{ + Rds: &v3httppb.Rds{ + ConfigSource: &v3corepb.ConfigSource{ + ConfigSourceSpecifier: &v3corepb.ConfigSource_Path{ + Path: "/some/path", + }, + }, + RouteConfigName: v3RouteConfigName, + }, + }, + }), + }, + })}, + wantUpdate: map[string]ListenerUpdateErrTuple{v3LDSTarget: {Err: cmpopts.AnyError}}, + wantMD: errMD, + wantErr: true, + }, + { + name: "empty resource list", + wantMD: UpdateMetadata{ + Status: ServiceStatusACKed, + Version: testVersion, + }, + }, + { + name: "v3 with no filters", + resources: []*anypb.Any{v3LisWithFilters()}, + wantUpdate: map[string]ListenerUpdateErrTuple{ + v3LDSTarget: {Update: ListenerUpdate{RouteConfigName: v3RouteConfigName, 
MaxStreamDuration: time.Second, HTTPFilters: routerFilterList, Raw: v3LisWithFilters()}}, + }, + wantMD: UpdateMetadata{ + Status: ServiceStatusACKed, + Version: testVersion, + }, + }, + { + name: "v3 no terminal filter", + resources: []*anypb.Any{testutils.MarshalAny(&v3listenerpb.Listener{ + Name: v3LDSTarget, + ApiListener: &v3listenerpb.ApiListener{ + ApiListener: testutils.MarshalAny( + &v3httppb.HttpConnectionManager{ + RouteSpecifier: &v3httppb.HttpConnectionManager_Rds{ + Rds: &v3httppb.Rds{ + ConfigSource: &v3corepb.ConfigSource{ + ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{Ads: &v3corepb.AggregatedConfigSource{}}, + }, + RouteConfigName: v3RouteConfigName, + }, + }, + CommonHttpProtocolOptions: &v3corepb.HttpProtocolOptions{ + MaxStreamDuration: durationpb.New(time.Second), + }, + }), + }, + })}, + wantUpdate: map[string]ListenerUpdateErrTuple{v3LDSTarget: {Err: cmpopts.AnyError}}, + wantMD: errMD, + wantErr: true, + }, + { + name: "v3 with custom filter", + resources: []*anypb.Any{v3LisWithFilters(customFilter)}, + wantUpdate: map[string]ListenerUpdateErrTuple{ + v3LDSTarget: {Update: ListenerUpdate{ + RouteConfigName: v3RouteConfigName, MaxStreamDuration: time.Second, + HTTPFilters: []HTTPFilter{ + { + Name: "customFilter", + Filter: httpFilter{}, + Config: filterConfig{Cfg: customFilterConfig}, + }, + routerFilter, + }, + Raw: v3LisWithFilters(customFilter), + }}, + }, + wantMD: UpdateMetadata{ + Status: ServiceStatusACKed, + Version: testVersion, + }, + }, + { + name: "v3 with custom filter in typed struct", + resources: []*anypb.Any{v3LisWithFilters(typedStructFilter)}, + wantUpdate: map[string]ListenerUpdateErrTuple{ + v3LDSTarget: {Update: ListenerUpdate{ + RouteConfigName: v3RouteConfigName, MaxStreamDuration: time.Second, + HTTPFilters: []HTTPFilter{ + { + Name: "customFilter", + Filter: httpFilter{}, + Config: filterConfig{Cfg: customFilterTypedStructConfig}, + }, + routerFilter, + }, + Raw: v3LisWithFilters(typedStructFilter), + }}, + }, 
+ wantMD: UpdateMetadata{ + Status: ServiceStatusACKed, + Version: testVersion, + }, + }, + { + name: "v3 with optional custom filter", + resources: []*anypb.Any{v3LisWithFilters(customOptionalFilter)}, + wantUpdate: map[string]ListenerUpdateErrTuple{ + v3LDSTarget: {Update: ListenerUpdate{ + RouteConfigName: v3RouteConfigName, MaxStreamDuration: time.Second, + HTTPFilters: []HTTPFilter{ + { + Name: "customFilter", + Filter: httpFilter{}, + Config: filterConfig{Cfg: customFilterConfig}, + }, + routerFilter, + }, + Raw: v3LisWithFilters(customOptionalFilter), + }}, + }, + wantMD: UpdateMetadata{ + Status: ServiceStatusACKed, + Version: testVersion, + }, + }, + { + name: "v3 with two filters with same name", + resources: []*anypb.Any{v3LisWithFilters(customFilter, customFilter)}, + wantUpdate: map[string]ListenerUpdateErrTuple{v3LDSTarget: {Err: cmpopts.AnyError}}, + wantMD: errMD, + wantErr: true, + }, + { + name: "v3 with two filters - same type different name", + resources: []*anypb.Any{v3LisWithFilters(customFilter, customFilter2)}, + wantUpdate: map[string]ListenerUpdateErrTuple{ + v3LDSTarget: {Update: ListenerUpdate{ + RouteConfigName: v3RouteConfigName, MaxStreamDuration: time.Second, + HTTPFilters: []HTTPFilter{{ + Name: "customFilter", + Filter: httpFilter{}, + Config: filterConfig{Cfg: customFilterConfig}, + }, { + Name: "customFilter2", + Filter: httpFilter{}, + Config: filterConfig{Cfg: customFilterConfig}, + }, + routerFilter, + }, + Raw: v3LisWithFilters(customFilter, customFilter2), + }}, + }, + wantMD: UpdateMetadata{ + Status: ServiceStatusACKed, + Version: testVersion, + }, + }, + { + name: "v3 with server-only filter", + resources: []*anypb.Any{v3LisWithFilters(serverOnlyCustomFilter)}, + wantUpdate: map[string]ListenerUpdateErrTuple{v3LDSTarget: {Err: cmpopts.AnyError}}, + wantMD: errMD, + wantErr: true, + }, + { + name: "v3 with optional server-only filter", + resources: []*anypb.Any{v3LisWithFilters(serverOnlyOptionalCustomFilter)}, + 
wantUpdate: map[string]ListenerUpdateErrTuple{ + v3LDSTarget: {Update: ListenerUpdate{ + RouteConfigName: v3RouteConfigName, + MaxStreamDuration: time.Second, + Raw: v3LisWithFilters(serverOnlyOptionalCustomFilter), + HTTPFilters: routerFilterList, + }}, + }, + wantMD: UpdateMetadata{ + Status: ServiceStatusACKed, + Version: testVersion, + }, + }, + { + name: "v3 with client-only filter", + resources: []*anypb.Any{v3LisWithFilters(clientOnlyCustomFilter)}, + wantUpdate: map[string]ListenerUpdateErrTuple{ + v3LDSTarget: {Update: ListenerUpdate{ + RouteConfigName: v3RouteConfigName, MaxStreamDuration: time.Second, + HTTPFilters: []HTTPFilter{ + { + Name: "clientOnlyCustomFilter", + Filter: clientOnlyHTTPFilter{}, + Config: filterConfig{Cfg: clientOnlyCustomFilterConfig}, + }, + routerFilter}, + Raw: v3LisWithFilters(clientOnlyCustomFilter), + }}, + }, + wantMD: UpdateMetadata{ + Status: ServiceStatusACKed, + Version: testVersion, + }, + }, + { + name: "v3 with err filter", + resources: []*anypb.Any{v3LisWithFilters(errFilter)}, + wantUpdate: map[string]ListenerUpdateErrTuple{v3LDSTarget: {Err: cmpopts.AnyError}}, + wantMD: errMD, + wantErr: true, + }, + { + name: "v3 with optional err filter", + resources: []*anypb.Any{v3LisWithFilters(errOptionalFilter)}, + wantUpdate: map[string]ListenerUpdateErrTuple{v3LDSTarget: {Err: cmpopts.AnyError}}, + wantMD: errMD, + wantErr: true, + }, + { + name: "v3 with unknown filter", + resources: []*anypb.Any{v3LisWithFilters(unknownFilter)}, + wantUpdate: map[string]ListenerUpdateErrTuple{v3LDSTarget: {Err: cmpopts.AnyError}}, + wantMD: errMD, + wantErr: true, + }, + { + name: "v3 with unknown filter (optional)", + resources: []*anypb.Any{v3LisWithFilters(unknownOptionalFilter)}, + wantUpdate: map[string]ListenerUpdateErrTuple{ + v3LDSTarget: {Update: ListenerUpdate{ + RouteConfigName: v3RouteConfigName, + MaxStreamDuration: time.Second, + HTTPFilters: routerFilterList, + Raw: v3LisWithFilters(unknownOptionalFilter), + }}, + }, + 
wantMD: UpdateMetadata{ + Status: ServiceStatusACKed, + Version: testVersion, + }, + }, + { + name: "v2 listener resource", + resources: []*anypb.Any{v2Lis}, + wantUpdate: map[string]ListenerUpdateErrTuple{ + v2LDSTarget: {Update: ListenerUpdate{RouteConfigName: v2RouteConfigName, Raw: v2Lis}}, + }, + wantMD: UpdateMetadata{ + Status: ServiceStatusACKed, + Version: testVersion, + }, + }, + { + name: "v3 listener resource", + resources: []*anypb.Any{v3LisWithFilters()}, + wantUpdate: map[string]ListenerUpdateErrTuple{ + v3LDSTarget: {Update: ListenerUpdate{RouteConfigName: v3RouteConfigName, MaxStreamDuration: time.Second, HTTPFilters: routerFilterList, Raw: v3LisWithFilters()}}, + }, + wantMD: UpdateMetadata{ + Status: ServiceStatusACKed, + Version: testVersion, + }, + }, + // "To allow equating RBAC's direct_remote_ip and + // remote_ip...HttpConnectionManager.xff_num_trusted_hops must be unset + // or zero and HttpConnectionManager.original_ip_detection_extensions + // must be empty." - A41 + { + name: "rbac-allow-equating-direct-remote-ip-and-remote-ip-valid", + resources: []*anypb.Any{v3LisToTestRBAC(0, nil)}, + wantUpdate: map[string]ListenerUpdateErrTuple{ + v3LDSTarget: {Update: ListenerUpdate{ + RouteConfigName: v3RouteConfigName, + MaxStreamDuration: time.Second, + HTTPFilters: []HTTPFilter{routerFilter}, + Raw: v3LisToTestRBAC(0, nil), + }}, + }, + wantMD: UpdateMetadata{ + Status: ServiceStatusACKed, + Version: testVersion, + }, + }, + // In order to support xDS Configured RBAC HTTPFilter equating direct + // remote ip and remote ip, xffNumTrustedHops cannot be greater than + // zero. This is because if you can trust a ingress proxy hop when + // determining an origin clients ip address, direct remote ip != remote + // ip. 
+ { + name: "rbac-allow-equating-direct-remote-ip-and-remote-ip-invalid-num-untrusted-hops", + resources: []*anypb.Any{v3LisToTestRBAC(1, nil)}, + wantUpdate: map[string]ListenerUpdateErrTuple{v3LDSTarget: {Err: cmpopts.AnyError}}, + wantMD: errMD, + wantErr: true, + }, + // In order to support xDS Configured RBAC HTTPFilter equating direct + // remote ip and remote ip, originalIpDetectionExtensions must be empty. + // This is because if you have to ask ip-detection-extension for the + // original ip, direct remote ip might not equal remote ip. + { + name: "rbac-allow-equating-direct-remote-ip-and-remote-ip-invalid-original-ip-detection-extension", + resources: []*anypb.Any{v3LisToTestRBAC(0, []*v3corepb.TypedExtensionConfig{{Name: "something"}})}, + wantUpdate: map[string]ListenerUpdateErrTuple{v3LDSTarget: {Err: cmpopts.AnyError}}, + wantMD: errMD, + wantErr: true, + }, + { + name: "v3 listener with inline route configuration", + resources: []*anypb.Any{v3LisWithInlineRoute}, + wantUpdate: map[string]ListenerUpdateErrTuple{ + v3LDSTarget: {Update: ListenerUpdate{ + InlineRouteConfig: &RouteConfigUpdate{ + VirtualHosts: []*VirtualHost{{ + Domains: []string{v3LDSTarget}, + Routes: []*Route{{Prefix: newStringP("/"), WeightedClusters: map[string]WeightedCluster{clusterName: {Weight: 1}}, RouteAction: RouteActionRoute}}, + }}}, + MaxStreamDuration: time.Second, + Raw: v3LisWithInlineRoute, + HTTPFilters: routerFilterList, + }}, + }, + wantMD: UpdateMetadata{ + Status: ServiceStatusACKed, + Version: testVersion, + }, + }, + { + name: "multiple listener resources", + resources: []*anypb.Any{v2Lis, v3LisWithFilters()}, + wantUpdate: map[string]ListenerUpdateErrTuple{ + v2LDSTarget: {Update: ListenerUpdate{RouteConfigName: v2RouteConfigName, Raw: v2Lis}}, + v3LDSTarget: {Update: ListenerUpdate{RouteConfigName: v3RouteConfigName, MaxStreamDuration: time.Second, Raw: v3LisWithFilters(), HTTPFilters: routerFilterList}}, + }, + wantMD: UpdateMetadata{ + Status: 
ServiceStatusACKed, + Version: testVersion, + }, + }, + { + // To test that unmarshal keeps processing on errors. + name: "good and bad listener resources", + resources: []*anypb.Any{ + v2Lis, + testutils.MarshalAny(&v3listenerpb.Listener{ + Name: "bad", + ApiListener: &v3listenerpb.ApiListener{ + ApiListener: testutils.MarshalAny(&v3httppb.HttpConnectionManager{ + RouteSpecifier: &v3httppb.HttpConnectionManager_ScopedRoutes{}, + }), + }}), + v3LisWithFilters(), + }, + wantUpdate: map[string]ListenerUpdateErrTuple{ + v2LDSTarget: {Update: ListenerUpdate{RouteConfigName: v2RouteConfigName, Raw: v2Lis}}, + v3LDSTarget: {Update: ListenerUpdate{RouteConfigName: v3RouteConfigName, MaxStreamDuration: time.Second, Raw: v3LisWithFilters(), HTTPFilters: routerFilterList}}, + "bad": {Err: cmpopts.AnyError}, + }, + wantMD: errMD, + wantErr: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + opts := &UnmarshalOptions{ + Version: testVersion, + Resources: test.resources, + } + update, md, err := UnmarshalListener(opts) + if (err != nil) != test.wantErr { + t.Fatalf("UnmarshalListener(%+v), got err: %v, wantErr: %v", opts, err, test.wantErr) + } + if diff := cmp.Diff(update, test.wantUpdate, cmpOpts); diff != "" { + t.Errorf("got unexpected update, diff (-got +want): %v", diff) + } + if diff := cmp.Diff(md, test.wantMD, cmpOptsIgnoreDetails); diff != "" { + t.Errorf("got unexpected metadata, diff (-got +want): %v", diff) + } + }) + } +} + +func (s) TestUnmarshalListener_ServerSide(t *testing.T) { + oldRBAC := env.RBACSupport + env.RBACSupport = true + defer func() { + env.RBACSupport = oldRBAC + }() + const ( + v3LDSTarget = "grpc/server?xds.resource.listening_address=0.0.0.0:9999" + testVersion = "test-version-lds-server" + ) + + var ( + serverOnlyCustomFilter = &v3httppb.HttpFilter{ + Name: "serverOnlyCustomFilter", + ConfigType: &v3httppb.HttpFilter_TypedConfig{TypedConfig: serverOnlyCustomFilterConfig}, + } + routeConfig = 
&v3routepb.RouteConfiguration{ + Name: "routeName", + VirtualHosts: []*v3routepb.VirtualHost{{ + Domains: []string{"lds.target.good:3333"}, + Routes: []*v3routepb.Route{{ + Match: &v3routepb.RouteMatch{ + PathSpecifier: &v3routepb.RouteMatch_Prefix{Prefix: "/"}, + }, + Action: &v3routepb.Route_NonForwardingAction{}, + }}}}} + inlineRouteConfig = &RouteConfigUpdate{ + VirtualHosts: []*VirtualHost{{ + Domains: []string{"lds.target.good:3333"}, + Routes: []*Route{{Prefix: newStringP("/"), RouteAction: RouteActionNonForwardingAction}}, + }}} + emptyValidNetworkFilters = []*v3listenerpb.Filter{ + { + Name: "filter-1", + ConfigType: &v3listenerpb.Filter_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3httppb.HttpConnectionManager{ + RouteSpecifier: &v3httppb.HttpConnectionManager_RouteConfig{ + RouteConfig: routeConfig, + }, + HttpFilters: []*v3httppb.HttpFilter{e2e.RouterHTTPFilter}, + }), + }, + }, + } + localSocketAddress = &v3corepb.Address{ + Address: &v3corepb.Address_SocketAddress{ + SocketAddress: &v3corepb.SocketAddress{ + Address: "0.0.0.0", + PortSpecifier: &v3corepb.SocketAddress_PortValue{ + PortValue: 9999, + }, + }, + }, + } + listenerEmptyTransportSocket = testutils.MarshalAny(&v3listenerpb.Listener{ + Name: v3LDSTarget, + Address: localSocketAddress, + FilterChains: []*v3listenerpb.FilterChain{ + { + Name: "filter-chain-1", + Filters: emptyValidNetworkFilters, + }, + }, + }) + listenerNoValidationContextDeprecatedFields = testutils.MarshalAny(&v3listenerpb.Listener{ + Name: v3LDSTarget, + Address: localSocketAddress, + FilterChains: []*v3listenerpb.FilterChain{ + { + Name: "filter-chain-1", + Filters: emptyValidNetworkFilters, + TransportSocket: &v3corepb.TransportSocket{ + Name: "envoy.transport_sockets.tls", + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3tlspb.DownstreamTlsContext{ + CommonTlsContext: &v3tlspb.CommonTlsContext{ + TlsCertificateCertificateProviderInstance: 
&v3tlspb.CommonTlsContext_CertificateProviderInstance{ + InstanceName: "identityPluginInstance", + CertificateName: "identityCertName", + }, + }, + }), + }, + }, + }, + }, + DefaultFilterChain: &v3listenerpb.FilterChain{ + Name: "default-filter-chain-1", + Filters: emptyValidNetworkFilters, + TransportSocket: &v3corepb.TransportSocket{ + Name: "envoy.transport_sockets.tls", + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3tlspb.DownstreamTlsContext{ + CommonTlsContext: &v3tlspb.CommonTlsContext{ + TlsCertificateCertificateProviderInstance: &v3tlspb.CommonTlsContext_CertificateProviderInstance{ + InstanceName: "defaultIdentityPluginInstance", + CertificateName: "defaultIdentityCertName", + }, + }, + }), + }, + }, + }, + }) + listenerNoValidationContextNewFields = testutils.MarshalAny(&v3listenerpb.Listener{ + Name: v3LDSTarget, + Address: localSocketAddress, + FilterChains: []*v3listenerpb.FilterChain{ + { + Name: "filter-chain-1", + Filters: emptyValidNetworkFilters, + TransportSocket: &v3corepb.TransportSocket{ + Name: "envoy.transport_sockets.tls", + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3tlspb.DownstreamTlsContext{ + CommonTlsContext: &v3tlspb.CommonTlsContext{ + TlsCertificateProviderInstance: &v3tlspb.CertificateProviderPluginInstance{ + InstanceName: "identityPluginInstance", + CertificateName: "identityCertName", + }, + }, + }), + }, + }, + }, + }, + DefaultFilterChain: &v3listenerpb.FilterChain{ + Name: "default-filter-chain-1", + Filters: emptyValidNetworkFilters, + TransportSocket: &v3corepb.TransportSocket{ + Name: "envoy.transport_sockets.tls", + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3tlspb.DownstreamTlsContext{ + CommonTlsContext: &v3tlspb.CommonTlsContext{ + TlsCertificateProviderInstance: &v3tlspb.CertificateProviderPluginInstance{ + InstanceName: "defaultIdentityPluginInstance", + CertificateName: 
"defaultIdentityCertName", + }, + }, + }), + }, + }, + }, + }) + listenerWithValidationContextDeprecatedFields = testutils.MarshalAny(&v3listenerpb.Listener{ + Name: v3LDSTarget, + Address: localSocketAddress, + FilterChains: []*v3listenerpb.FilterChain{ + { + Name: "filter-chain-1", + Filters: emptyValidNetworkFilters, + TransportSocket: &v3corepb.TransportSocket{ + Name: "envoy.transport_sockets.tls", + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3tlspb.DownstreamTlsContext{ + RequireClientCertificate: &wrapperspb.BoolValue{Value: true}, + CommonTlsContext: &v3tlspb.CommonTlsContext{ + TlsCertificateCertificateProviderInstance: &v3tlspb.CommonTlsContext_CertificateProviderInstance{ + InstanceName: "identityPluginInstance", + CertificateName: "identityCertName", + }, + ValidationContextType: &v3tlspb.CommonTlsContext_ValidationContextCertificateProviderInstance{ + ValidationContextCertificateProviderInstance: &v3tlspb.CommonTlsContext_CertificateProviderInstance{ + InstanceName: "rootPluginInstance", + CertificateName: "rootCertName", + }, + }, + }, + }), + }, + }, + }, + }, + DefaultFilterChain: &v3listenerpb.FilterChain{ + Name: "default-filter-chain-1", + Filters: emptyValidNetworkFilters, + TransportSocket: &v3corepb.TransportSocket{ + Name: "envoy.transport_sockets.tls", + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3tlspb.DownstreamTlsContext{ + RequireClientCertificate: &wrapperspb.BoolValue{Value: true}, + CommonTlsContext: &v3tlspb.CommonTlsContext{ + TlsCertificateCertificateProviderInstance: &v3tlspb.CommonTlsContext_CertificateProviderInstance{ + InstanceName: "defaultIdentityPluginInstance", + CertificateName: "defaultIdentityCertName", + }, + ValidationContextType: &v3tlspb.CommonTlsContext_ValidationContextCertificateProviderInstance{ + ValidationContextCertificateProviderInstance: &v3tlspb.CommonTlsContext_CertificateProviderInstance{ + InstanceName: 
"defaultRootPluginInstance", + CertificateName: "defaultRootCertName", + }, + }, + }, + }), + }, + }, + }, + }) + listenerWithValidationContextNewFields = testutils.MarshalAny(&v3listenerpb.Listener{ + Name: v3LDSTarget, + Address: localSocketAddress, + FilterChains: []*v3listenerpb.FilterChain{ + { + Name: "filter-chain-1", + Filters: emptyValidNetworkFilters, + TransportSocket: &v3corepb.TransportSocket{ + Name: "envoy.transport_sockets.tls", + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3tlspb.DownstreamTlsContext{ + RequireClientCertificate: &wrapperspb.BoolValue{Value: true}, + CommonTlsContext: &v3tlspb.CommonTlsContext{ + TlsCertificateProviderInstance: &v3tlspb.CertificateProviderPluginInstance{ + InstanceName: "identityPluginInstance", + CertificateName: "identityCertName", + }, + ValidationContextType: &v3tlspb.CommonTlsContext_ValidationContext{ + ValidationContext: &v3tlspb.CertificateValidationContext{ + CaCertificateProviderInstance: &v3tlspb.CertificateProviderPluginInstance{ + InstanceName: "rootPluginInstance", + CertificateName: "rootCertName", + }, + }, + }, + }, + }), + }, + }, + }, + }, + DefaultFilterChain: &v3listenerpb.FilterChain{ + Name: "default-filter-chain-1", + Filters: emptyValidNetworkFilters, + TransportSocket: &v3corepb.TransportSocket{ + Name: "envoy.transport_sockets.tls", + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3tlspb.DownstreamTlsContext{ + RequireClientCertificate: &wrapperspb.BoolValue{Value: true}, + CommonTlsContext: &v3tlspb.CommonTlsContext{ + TlsCertificateProviderInstance: &v3tlspb.CertificateProviderPluginInstance{ + InstanceName: "defaultIdentityPluginInstance", + CertificateName: "defaultIdentityCertName", + }, + ValidationContextType: &v3tlspb.CommonTlsContext_CombinedValidationContext{ + CombinedValidationContext: &v3tlspb.CommonTlsContext_CombinedCertificateValidationContext{ + DefaultValidationContext: 
&v3tlspb.CertificateValidationContext{ + CaCertificateProviderInstance: &v3tlspb.CertificateProviderPluginInstance{ + InstanceName: "defaultRootPluginInstance", + CertificateName: "defaultRootCertName", + }, + }, + }, + }, + }, + }), + }, + }, + }, + }) + errMD = UpdateMetadata{ + Status: ServiceStatusNACKed, + Version: testVersion, + ErrState: &UpdateErrorMetadata{ + Version: testVersion, + Err: cmpopts.AnyError, + }, + } + ) + v3LisToTestRBAC := func(xffNumTrustedHops uint32, originalIpDetectionExtensions []*v3corepb.TypedExtensionConfig) *anypb.Any { + return testutils.MarshalAny(&v3listenerpb.Listener{ + Name: v3LDSTarget, + Address: localSocketAddress, + FilterChains: []*v3listenerpb.FilterChain{ + { + Name: "filter-chain-1", + Filters: []*v3listenerpb.Filter{ + { + Name: "filter-1", + ConfigType: &v3listenerpb.Filter_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3httppb.HttpConnectionManager{ + RouteSpecifier: &v3httppb.HttpConnectionManager_RouteConfig{ + RouteConfig: routeConfig, + }, + HttpFilters: []*v3httppb.HttpFilter{e2e.RouterHTTPFilter}, + XffNumTrustedHops: xffNumTrustedHops, + OriginalIpDetectionExtensions: originalIpDetectionExtensions, + }), + }, + }, + }, + }, + }, + }) + } + + tests := []struct { + name string + resources []*anypb.Any + wantUpdate map[string]ListenerUpdateErrTuple + wantMD UpdateMetadata + wantErr string + }{ + { + name: "non-empty listener filters", + resources: []*anypb.Any{testutils.MarshalAny(&v3listenerpb.Listener{ + Name: v3LDSTarget, + ListenerFilters: []*v3listenerpb.ListenerFilter{ + {Name: "listener-filter-1"}, + }, + })}, + wantUpdate: map[string]ListenerUpdateErrTuple{v3LDSTarget: {Err: cmpopts.AnyError}}, + wantMD: errMD, + wantErr: "unsupported field 'listener_filters'", + }, + { + name: "use_original_dst is set", + resources: []*anypb.Any{testutils.MarshalAny(&v3listenerpb.Listener{ + Name: v3LDSTarget, + UseOriginalDst: &wrapperspb.BoolValue{Value: true}, + })}, + wantUpdate: 
map[string]ListenerUpdateErrTuple{v3LDSTarget: {Err: cmpopts.AnyError}}, + wantMD: errMD, + wantErr: "unsupported field 'use_original_dst'", + }, + { + name: "no address field", + resources: []*anypb.Any{testutils.MarshalAny(&v3listenerpb.Listener{Name: v3LDSTarget})}, + wantUpdate: map[string]ListenerUpdateErrTuple{v3LDSTarget: {Err: cmpopts.AnyError}}, + wantMD: errMD, + wantErr: "no address field in LDS response", + }, + { + name: "no socket address field", + resources: []*anypb.Any{testutils.MarshalAny(&v3listenerpb.Listener{ + Name: v3LDSTarget, + Address: &v3corepb.Address{}, + })}, + wantUpdate: map[string]ListenerUpdateErrTuple{v3LDSTarget: {Err: cmpopts.AnyError}}, + wantMD: errMD, + wantErr: "no socket_address field in LDS response", + }, + { + name: "no filter chains and no default filter chain", + resources: []*anypb.Any{testutils.MarshalAny(&v3listenerpb.Listener{ + Name: v3LDSTarget, + Address: localSocketAddress, + FilterChains: []*v3listenerpb.FilterChain{ + { + FilterChainMatch: &v3listenerpb.FilterChainMatch{DestinationPort: &wrapperspb.UInt32Value{Value: 666}}, + Filters: emptyValidNetworkFilters, + }, + }, + })}, + wantUpdate: map[string]ListenerUpdateErrTuple{v3LDSTarget: {Err: cmpopts.AnyError}}, + wantMD: errMD, + wantErr: "no supported filter chains and no default filter chain", + }, + { + name: "missing http connection manager network filter", + resources: []*anypb.Any{testutils.MarshalAny(&v3listenerpb.Listener{ + Name: v3LDSTarget, + Address: localSocketAddress, + FilterChains: []*v3listenerpb.FilterChain{ + { + Name: "filter-chain-1", + }, + }, + })}, + wantUpdate: map[string]ListenerUpdateErrTuple{v3LDSTarget: {Err: cmpopts.AnyError}}, + wantMD: errMD, + wantErr: "missing HttpConnectionManager filter", + }, + { + name: "missing filter name in http filter", + resources: []*anypb.Any{testutils.MarshalAny(&v3listenerpb.Listener{ + Name: v3LDSTarget, + Address: localSocketAddress, + FilterChains: []*v3listenerpb.FilterChain{ + { + Name: 
"filter-chain-1", + Filters: []*v3listenerpb.Filter{ + { + ConfigType: &v3listenerpb.Filter_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3httppb.HttpConnectionManager{}), + }, + }, + }, + }, + }, + })}, + wantUpdate: map[string]ListenerUpdateErrTuple{v3LDSTarget: {Err: cmpopts.AnyError}}, + wantMD: errMD, + wantErr: "missing name field in filter", + }, + { + name: "duplicate filter names in http filter", + resources: []*anypb.Any{testutils.MarshalAny(&v3listenerpb.Listener{ + Name: v3LDSTarget, + Address: localSocketAddress, + FilterChains: []*v3listenerpb.FilterChain{ + { + Name: "filter-chain-1", + Filters: []*v3listenerpb.Filter{ + { + Name: "name", + ConfigType: &v3listenerpb.Filter_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3httppb.HttpConnectionManager{ + RouteSpecifier: &v3httppb.HttpConnectionManager_RouteConfig{ + RouteConfig: routeConfig, + }, + HttpFilters: []*v3httppb.HttpFilter{emptyRouterFilter}, + }), + }, + }, + { + Name: "name", + ConfigType: &v3listenerpb.Filter_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3httppb.HttpConnectionManager{ + RouteSpecifier: &v3httppb.HttpConnectionManager_RouteConfig{ + RouteConfig: routeConfig, + }, + HttpFilters: []*v3httppb.HttpFilter{emptyRouterFilter}, + }), + }, + }, + }, + }, + }, + })}, + wantUpdate: map[string]ListenerUpdateErrTuple{v3LDSTarget: {Err: cmpopts.AnyError}}, + wantMD: errMD, + wantErr: "duplicate filter name", + }, + { + name: "no terminal filter", + resources: []*anypb.Any{testutils.MarshalAny(&v3listenerpb.Listener{ + Name: v3LDSTarget, + Address: localSocketAddress, + FilterChains: []*v3listenerpb.FilterChain{ + { + Name: "filter-chain-1", + Filters: []*v3listenerpb.Filter{ + { + Name: "name", + ConfigType: &v3listenerpb.Filter_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3httppb.HttpConnectionManager{ + RouteSpecifier: &v3httppb.HttpConnectionManager_RouteConfig{ + RouteConfig: routeConfig, + }, + }), + }, + }, + }, + }, + }, + })}, + wantUpdate: 
map[string]ListenerUpdateErrTuple{v3LDSTarget: {Err: cmpopts.AnyError}}, + wantMD: errMD, + wantErr: "http filters list is empty", + }, + { + name: "terminal filter not last", + resources: []*anypb.Any{testutils.MarshalAny(&v3listenerpb.Listener{ + Name: v3LDSTarget, + Address: localSocketAddress, + FilterChains: []*v3listenerpb.FilterChain{ + { + Name: "filter-chain-1", + Filters: []*v3listenerpb.Filter{ + { + Name: "name", + ConfigType: &v3listenerpb.Filter_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3httppb.HttpConnectionManager{ + RouteSpecifier: &v3httppb.HttpConnectionManager_RouteConfig{ + RouteConfig: routeConfig, + }, + HttpFilters: []*v3httppb.HttpFilter{emptyRouterFilter, serverOnlyCustomFilter}, + }), + }, + }, + }, + }, + }, + })}, + wantUpdate: map[string]ListenerUpdateErrTuple{v3LDSTarget: {Err: cmpopts.AnyError}}, + wantMD: errMD, + wantErr: "is a terminal filter but it is not last in the filter chain", + }, + { + name: "last not terminal filter", + resources: []*anypb.Any{testutils.MarshalAny(&v3listenerpb.Listener{ + Name: v3LDSTarget, + Address: localSocketAddress, + FilterChains: []*v3listenerpb.FilterChain{ + { + Name: "filter-chain-1", + Filters: []*v3listenerpb.Filter{ + { + Name: "name", + ConfigType: &v3listenerpb.Filter_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3httppb.HttpConnectionManager{ + RouteSpecifier: &v3httppb.HttpConnectionManager_RouteConfig{ + RouteConfig: routeConfig, + }, + HttpFilters: []*v3httppb.HttpFilter{serverOnlyCustomFilter}, + }), + }, + }, + }, + }, + }, + })}, + wantUpdate: map[string]ListenerUpdateErrTuple{v3LDSTarget: {Err: cmpopts.AnyError}}, + wantMD: errMD, + wantErr: "is not a terminal filter", + }, + { + name: "unsupported oneof in typed config of http filter", + resources: []*anypb.Any{testutils.MarshalAny(&v3listenerpb.Listener{ + Name: v3LDSTarget, + Address: localSocketAddress, + FilterChains: []*v3listenerpb.FilterChain{ + { + Name: "filter-chain-1", + Filters: []*v3listenerpb.Filter{ + 
{ + Name: "name", + ConfigType: &v3listenerpb.Filter_ConfigDiscovery{}, + }, + }, + }, + }, + })}, + wantUpdate: map[string]ListenerUpdateErrTuple{v3LDSTarget: {Err: cmpopts.AnyError}}, + wantMD: errMD, + wantErr: "unsupported config_type", + }, + { + name: "overlapping filter chain match criteria", + resources: []*anypb.Any{testutils.MarshalAny(&v3listenerpb.Listener{ + Name: v3LDSTarget, + Address: localSocketAddress, + FilterChains: []*v3listenerpb.FilterChain{ + { + FilterChainMatch: &v3listenerpb.FilterChainMatch{SourcePorts: []uint32{1, 2, 3, 4, 5}}, + Filters: emptyValidNetworkFilters, + }, + { + FilterChainMatch: &v3listenerpb.FilterChainMatch{}, + Filters: emptyValidNetworkFilters, + }, + { + FilterChainMatch: &v3listenerpb.FilterChainMatch{SourcePorts: []uint32{5, 6, 7}}, + Filters: emptyValidNetworkFilters, + }, + }, + })}, + wantUpdate: map[string]ListenerUpdateErrTuple{v3LDSTarget: {Err: cmpopts.AnyError}}, + wantMD: errMD, + wantErr: "multiple filter chains with overlapping matching rules are defined", + }, + { + name: "unsupported network filter", + resources: []*anypb.Any{testutils.MarshalAny(&v3listenerpb.Listener{ + Name: v3LDSTarget, + Address: localSocketAddress, + FilterChains: []*v3listenerpb.FilterChain{ + { + Name: "filter-chain-1", + Filters: []*v3listenerpb.Filter{ + { + Name: "name", + ConfigType: &v3listenerpb.Filter_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3httppb.LocalReplyConfig{}), + }, + }, + }, + }, + }, + })}, + wantUpdate: map[string]ListenerUpdateErrTuple{v3LDSTarget: {Err: cmpopts.AnyError}}, + wantMD: errMD, + wantErr: "unsupported network filter", + }, + { + name: "badly marshaled network filter", + resources: []*anypb.Any{testutils.MarshalAny(&v3listenerpb.Listener{ + Name: v3LDSTarget, + Address: localSocketAddress, + FilterChains: []*v3listenerpb.FilterChain{ + { + Name: "filter-chain-1", + Filters: []*v3listenerpb.Filter{ + { + Name: "name", + ConfigType: &v3listenerpb.Filter_TypedConfig{ + TypedConfig: 
&anypb.Any{ + TypeUrl: version.V3HTTPConnManagerURL, + Value: []byte{1, 2, 3, 4}, + }, + }, + }, + }, + }, + }, + })}, + wantUpdate: map[string]ListenerUpdateErrTuple{v3LDSTarget: {Err: cmpopts.AnyError}}, + wantMD: errMD, + wantErr: "failed unmarshaling of network filter", + }, + { + name: "unexpected transport socket name", + resources: []*anypb.Any{testutils.MarshalAny(&v3listenerpb.Listener{ + Name: v3LDSTarget, + Address: localSocketAddress, + FilterChains: []*v3listenerpb.FilterChain{ + { + Name: "filter-chain-1", + Filters: emptyValidNetworkFilters, + TransportSocket: &v3corepb.TransportSocket{ + Name: "unsupported-transport-socket-name", + }, + }, + }, + })}, + wantUpdate: map[string]ListenerUpdateErrTuple{v3LDSTarget: {Err: cmpopts.AnyError}}, + wantMD: errMD, + wantErr: "transport_socket field has unexpected name", + }, + { + name: "unexpected transport socket typedConfig URL", + resources: []*anypb.Any{testutils.MarshalAny(&v3listenerpb.Listener{ + Name: v3LDSTarget, + Address: localSocketAddress, + FilterChains: []*v3listenerpb.FilterChain{ + { + Name: "filter-chain-1", + Filters: emptyValidNetworkFilters, + TransportSocket: &v3corepb.TransportSocket{ + Name: "envoy.transport_sockets.tls", + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3tlspb.UpstreamTlsContext{}), + }, + }, + }, + }, + })}, + wantUpdate: map[string]ListenerUpdateErrTuple{v3LDSTarget: {Err: cmpopts.AnyError}}, + wantMD: errMD, + wantErr: "transport_socket field has unexpected typeURL", + }, + { + name: "badly marshaled transport socket", + resources: []*anypb.Any{testutils.MarshalAny(&v3listenerpb.Listener{ + Name: v3LDSTarget, + Address: localSocketAddress, + FilterChains: []*v3listenerpb.FilterChain{ + { + Name: "filter-chain-1", + Filters: emptyValidNetworkFilters, + TransportSocket: &v3corepb.TransportSocket{ + Name: "envoy.transport_sockets.tls", + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: &anypb.Any{ + TypeUrl: 
version.V3DownstreamTLSContextURL, + Value: []byte{1, 2, 3, 4}, + }, + }, + }, + }, + }, + })}, + wantUpdate: map[string]ListenerUpdateErrTuple{v3LDSTarget: {Err: cmpopts.AnyError}}, + wantMD: errMD, + wantErr: "failed to unmarshal DownstreamTlsContext in LDS response", + }, + { + name: "missing CommonTlsContext", + resources: []*anypb.Any{testutils.MarshalAny(&v3listenerpb.Listener{ + Name: v3LDSTarget, + Address: localSocketAddress, + FilterChains: []*v3listenerpb.FilterChain{ + { + Name: "filter-chain-1", + Filters: emptyValidNetworkFilters, + TransportSocket: &v3corepb.TransportSocket{ + Name: "envoy.transport_sockets.tls", + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3tlspb.DownstreamTlsContext{}), + }, + }, + }, + }, + })}, + wantUpdate: map[string]ListenerUpdateErrTuple{v3LDSTarget: {Err: cmpopts.AnyError}}, + wantMD: errMD, + wantErr: "DownstreamTlsContext in LDS response does not contain a CommonTlsContext", + }, + { + name: "rbac-allow-equating-direct-remote-ip-and-remote-ip-valid", + resources: []*anypb.Any{v3LisToTestRBAC(0, nil)}, + wantUpdate: map[string]ListenerUpdateErrTuple{ + v3LDSTarget: {Update: ListenerUpdate{ + InboundListenerCfg: &InboundListenerConfig{ + Address: "0.0.0.0", + Port: "9999", + FilterChains: &FilterChainManager{ + dstPrefixMap: map[string]*destPrefixEntry{ + unspecifiedPrefixMapKey: { + srcTypeArr: [3]*sourcePrefixes{ + { + srcPrefixMap: map[string]*sourcePrefixEntry{ + unspecifiedPrefixMapKey: { + srcPortMap: map[int]*FilterChain{ + 0: { + InlineRouteConfig: inlineRouteConfig, + HTTPFilters: routerFilterList, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + Raw: listenerEmptyTransportSocket, + }}, + }, + wantMD: UpdateMetadata{ + Status: ServiceStatusACKed, + Version: testVersion, + }, + }, + { + name: "rbac-allow-equating-direct-remote-ip-and-remote-ip-invalid-num-untrusted-hops", + resources: []*anypb.Any{v3LisToTestRBAC(1, nil)}, + wantUpdate: 
map[string]ListenerUpdateErrTuple{v3LDSTarget: {Err: cmpopts.AnyError}}, + wantMD: errMD, + wantErr: "xff_num_trusted_hops must be unset or zero", + }, + { + name: "rbac-allow-equating-direct-remote-ip-and-remote-ip-invalid-original-ip-detection-extension", + resources: []*anypb.Any{v3LisToTestRBAC(0, []*v3corepb.TypedExtensionConfig{{Name: "something"}})}, + wantUpdate: map[string]ListenerUpdateErrTuple{v3LDSTarget: {Err: cmpopts.AnyError}}, + wantMD: errMD, + wantErr: "original_ip_detection_extensions must be empty", + }, + { + name: "unsupported validation context in transport socket", + resources: []*anypb.Any{testutils.MarshalAny(&v3listenerpb.Listener{ + Name: v3LDSTarget, + Address: localSocketAddress, + FilterChains: []*v3listenerpb.FilterChain{ + { + Name: "filter-chain-1", + Filters: emptyValidNetworkFilters, + TransportSocket: &v3corepb.TransportSocket{ + Name: "envoy.transport_sockets.tls", + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3tlspb.DownstreamTlsContext{ + CommonTlsContext: &v3tlspb.CommonTlsContext{ + ValidationContextType: &v3tlspb.CommonTlsContext_ValidationContextSdsSecretConfig{ + ValidationContextSdsSecretConfig: &v3tlspb.SdsSecretConfig{ + Name: "foo-sds-secret", + }, + }, + }, + }), + }, + }, + }, + }, + })}, + wantUpdate: map[string]ListenerUpdateErrTuple{v3LDSTarget: {Err: cmpopts.AnyError}}, + wantMD: errMD, + wantErr: "validation context contains unexpected type", + }, + { + name: "empty transport socket", + resources: []*anypb.Any{listenerEmptyTransportSocket}, + wantUpdate: map[string]ListenerUpdateErrTuple{ + v3LDSTarget: {Update: ListenerUpdate{ + InboundListenerCfg: &InboundListenerConfig{ + Address: "0.0.0.0", + Port: "9999", + FilterChains: &FilterChainManager{ + dstPrefixMap: map[string]*destPrefixEntry{ + unspecifiedPrefixMapKey: { + srcTypeArr: [3]*sourcePrefixes{ + { + srcPrefixMap: map[string]*sourcePrefixEntry{ + unspecifiedPrefixMapKey: { + srcPortMap: map[int]*FilterChain{ 
+ 0: { + InlineRouteConfig: inlineRouteConfig, + HTTPFilters: routerFilterList, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + Raw: listenerEmptyTransportSocket, + }}, + }, + wantMD: UpdateMetadata{ + Status: ServiceStatusACKed, + Version: testVersion, + }, + }, + { + name: "no identity and root certificate providers using deprecated fields", + resources: []*anypb.Any{testutils.MarshalAny(&v3listenerpb.Listener{ + Name: v3LDSTarget, + Address: localSocketAddress, + FilterChains: []*v3listenerpb.FilterChain{ + { + Name: "filter-chain-1", + Filters: emptyValidNetworkFilters, + TransportSocket: &v3corepb.TransportSocket{ + Name: "envoy.transport_sockets.tls", + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3tlspb.DownstreamTlsContext{ + RequireClientCertificate: &wrapperspb.BoolValue{Value: true}, + CommonTlsContext: &v3tlspb.CommonTlsContext{ + TlsCertificateCertificateProviderInstance: &v3tlspb.CommonTlsContext_CertificateProviderInstance{ + InstanceName: "identityPluginInstance", + CertificateName: "identityCertName", + }, + }, + }), + }, + }, + }, + }, + })}, + wantUpdate: map[string]ListenerUpdateErrTuple{v3LDSTarget: {Err: cmpopts.AnyError}}, + wantMD: errMD, + wantErr: "security configuration on the server-side does not contain root certificate provider instance name, but require_client_cert field is set", + }, + { + name: "no identity and root certificate providers using new fields", + resources: []*anypb.Any{testutils.MarshalAny(&v3listenerpb.Listener{ + Name: v3LDSTarget, + Address: localSocketAddress, + FilterChains: []*v3listenerpb.FilterChain{ + { + Name: "filter-chain-1", + Filters: emptyValidNetworkFilters, + TransportSocket: &v3corepb.TransportSocket{ + Name: "envoy.transport_sockets.tls", + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3tlspb.DownstreamTlsContext{ + RequireClientCertificate: &wrapperspb.BoolValue{Value: true}, + CommonTlsContext: 
&v3tlspb.CommonTlsContext{ + TlsCertificateProviderInstance: &v3tlspb.CertificateProviderPluginInstance{ + InstanceName: "identityPluginInstance", + CertificateName: "identityCertName", + }, + }, + }), + }, + }, + }, + }, + })}, + wantUpdate: map[string]ListenerUpdateErrTuple{v3LDSTarget: {Err: cmpopts.AnyError}}, + wantMD: errMD, + wantErr: "security configuration on the server-side does not contain root certificate provider instance name, but require_client_cert field is set", + }, + { + name: "no identity certificate provider with require_client_cert", + resources: []*anypb.Any{testutils.MarshalAny(&v3listenerpb.Listener{ + Name: v3LDSTarget, + Address: localSocketAddress, + FilterChains: []*v3listenerpb.FilterChain{ + { + Name: "filter-chain-1", + Filters: emptyValidNetworkFilters, + TransportSocket: &v3corepb.TransportSocket{ + Name: "envoy.transport_sockets.tls", + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3tlspb.DownstreamTlsContext{ + CommonTlsContext: &v3tlspb.CommonTlsContext{}, + }), + }, + }, + }, + }, + })}, + wantUpdate: map[string]ListenerUpdateErrTuple{v3LDSTarget: {Err: cmpopts.AnyError}}, + wantMD: errMD, + wantErr: "security configuration on the server-side does not contain identity certificate provider instance name", + }, + { + name: "happy case with no validation context using deprecated fields", + resources: []*anypb.Any{listenerNoValidationContextDeprecatedFields}, + wantUpdate: map[string]ListenerUpdateErrTuple{ + v3LDSTarget: {Update: ListenerUpdate{ + InboundListenerCfg: &InboundListenerConfig{ + Address: "0.0.0.0", + Port: "9999", + FilterChains: &FilterChainManager{ + dstPrefixMap: map[string]*destPrefixEntry{ + unspecifiedPrefixMapKey: { + srcTypeArr: [3]*sourcePrefixes{ + { + srcPrefixMap: map[string]*sourcePrefixEntry{ + unspecifiedPrefixMapKey: { + srcPortMap: map[int]*FilterChain{ + 0: { + SecurityCfg: &SecurityConfig{ + IdentityInstanceName: "identityPluginInstance", + 
IdentityCertName: "identityCertName", + }, + InlineRouteConfig: inlineRouteConfig, + HTTPFilters: routerFilterList, + }, + }, + }, + }, + }, + }, + }, + }, + def: &FilterChain{ + SecurityCfg: &SecurityConfig{ + IdentityInstanceName: "defaultIdentityPluginInstance", + IdentityCertName: "defaultIdentityCertName", + }, + InlineRouteConfig: inlineRouteConfig, + HTTPFilters: routerFilterList, + }, + }, + }, + Raw: listenerNoValidationContextDeprecatedFields, + }}, + }, + wantMD: UpdateMetadata{ + Status: ServiceStatusACKed, + Version: testVersion, + }, + }, + { + name: "happy case with no validation context using new fields", + resources: []*anypb.Any{listenerNoValidationContextNewFields}, + wantUpdate: map[string]ListenerUpdateErrTuple{ + v3LDSTarget: {Update: ListenerUpdate{ + InboundListenerCfg: &InboundListenerConfig{ + Address: "0.0.0.0", + Port: "9999", + FilterChains: &FilterChainManager{ + dstPrefixMap: map[string]*destPrefixEntry{ + unspecifiedPrefixMapKey: { + srcTypeArr: [3]*sourcePrefixes{ + { + srcPrefixMap: map[string]*sourcePrefixEntry{ + unspecifiedPrefixMapKey: { + srcPortMap: map[int]*FilterChain{ + 0: { + SecurityCfg: &SecurityConfig{ + IdentityInstanceName: "identityPluginInstance", + IdentityCertName: "identityCertName", + }, + InlineRouteConfig: inlineRouteConfig, + HTTPFilters: routerFilterList, + }, + }, + }, + }, + }, + }, + }, + }, + def: &FilterChain{ + SecurityCfg: &SecurityConfig{ + IdentityInstanceName: "defaultIdentityPluginInstance", + IdentityCertName: "defaultIdentityCertName", + }, + InlineRouteConfig: inlineRouteConfig, + HTTPFilters: routerFilterList, + }, + }, + }, + Raw: listenerNoValidationContextNewFields, + }}, + }, + wantMD: UpdateMetadata{ + Status: ServiceStatusACKed, + Version: testVersion, + }, + }, + { + name: "happy case with validation context provider instance with deprecated fields", + resources: []*anypb.Any{listenerWithValidationContextDeprecatedFields}, + wantUpdate: map[string]ListenerUpdateErrTuple{ + v3LDSTarget: 
{Update: ListenerUpdate{ + InboundListenerCfg: &InboundListenerConfig{ + Address: "0.0.0.0", + Port: "9999", + FilterChains: &FilterChainManager{ + dstPrefixMap: map[string]*destPrefixEntry{ + unspecifiedPrefixMapKey: { + srcTypeArr: [3]*sourcePrefixes{ + { + srcPrefixMap: map[string]*sourcePrefixEntry{ + unspecifiedPrefixMapKey: { + srcPortMap: map[int]*FilterChain{ + 0: { + SecurityCfg: &SecurityConfig{ + RootInstanceName: "rootPluginInstance", + RootCertName: "rootCertName", + IdentityInstanceName: "identityPluginInstance", + IdentityCertName: "identityCertName", + RequireClientCert: true, + }, + InlineRouteConfig: inlineRouteConfig, + HTTPFilters: routerFilterList, + }, + }, + }, + }, + }, + }, + }, + }, + def: &FilterChain{ + SecurityCfg: &SecurityConfig{ + RootInstanceName: "defaultRootPluginInstance", + RootCertName: "defaultRootCertName", + IdentityInstanceName: "defaultIdentityPluginInstance", + IdentityCertName: "defaultIdentityCertName", + RequireClientCert: true, + }, + InlineRouteConfig: inlineRouteConfig, + HTTPFilters: routerFilterList, + }, + }, + }, + Raw: listenerWithValidationContextDeprecatedFields, + }}, + }, + wantMD: UpdateMetadata{ + Status: ServiceStatusACKed, + Version: testVersion, + }, + }, + { + name: "happy case with validation context provider instance with new fields", + resources: []*anypb.Any{listenerWithValidationContextNewFields}, + wantUpdate: map[string]ListenerUpdateErrTuple{ + v3LDSTarget: {Update: ListenerUpdate{ + InboundListenerCfg: &InboundListenerConfig{ + Address: "0.0.0.0", + Port: "9999", + FilterChains: &FilterChainManager{ + dstPrefixMap: map[string]*destPrefixEntry{ + unspecifiedPrefixMapKey: { + srcTypeArr: [3]*sourcePrefixes{ + { + srcPrefixMap: map[string]*sourcePrefixEntry{ + unspecifiedPrefixMapKey: { + srcPortMap: map[int]*FilterChain{ + 0: { + SecurityCfg: &SecurityConfig{ + RootInstanceName: "rootPluginInstance", + RootCertName: "rootCertName", + IdentityInstanceName: "identityPluginInstance", + 
IdentityCertName: "identityCertName", + RequireClientCert: true, + }, + InlineRouteConfig: inlineRouteConfig, + HTTPFilters: routerFilterList, + }, + }, + }, + }, + }, + }, + }, + }, + def: &FilterChain{ + SecurityCfg: &SecurityConfig{ + RootInstanceName: "defaultRootPluginInstance", + RootCertName: "defaultRootCertName", + IdentityInstanceName: "defaultIdentityPluginInstance", + IdentityCertName: "defaultIdentityCertName", + RequireClientCert: true, + }, + InlineRouteConfig: inlineRouteConfig, + HTTPFilters: routerFilterList, + }, + }, + }, + Raw: listenerWithValidationContextNewFields, + }}, + }, + wantMD: UpdateMetadata{ + Status: ServiceStatusACKed, + Version: testVersion, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + opts := &UnmarshalOptions{ + Version: testVersion, + Resources: test.resources, + } + gotUpdate, md, err := UnmarshalListener(opts) + if (err != nil) != (test.wantErr != "") { + t.Fatalf("UnmarshalListener(%+v), got err: %v, wantErr: %v", opts, err, test.wantErr) + } + if err != nil && !strings.Contains(err.Error(), test.wantErr) { + t.Fatalf("UnmarshalListener(%+v) = %v wantErr: %q", opts, err, test.wantErr) + } + if diff := cmp.Diff(gotUpdate, test.wantUpdate, cmpOpts); diff != "" { + t.Errorf("got unexpected update, diff (-got +want): %v", diff) + } + if diff := cmp.Diff(md, test.wantMD, cmpOptsIgnoreDetails); diff != "" { + t.Errorf("got unexpected metadata, diff (-got +want): %v", diff) + } + }) + } +} + +type filterConfig struct { + httpfilter.FilterConfig + Cfg proto.Message + Override proto.Message +} + +// httpFilter allows testing the http filter registry and parsing functionality. 
+type httpFilter struct { + httpfilter.ClientInterceptorBuilder + httpfilter.ServerInterceptorBuilder +} + +func (httpFilter) TypeURLs() []string { return []string{"custom.filter"} } + +func (httpFilter) ParseFilterConfig(cfg proto.Message) (httpfilter.FilterConfig, error) { + return filterConfig{Cfg: cfg}, nil +} + +func (httpFilter) ParseFilterConfigOverride(override proto.Message) (httpfilter.FilterConfig, error) { + return filterConfig{Override: override}, nil +} + +func (httpFilter) IsTerminal() bool { + return false +} + +// errHTTPFilter returns errors no matter what is passed to ParseFilterConfig. +type errHTTPFilter struct { + httpfilter.ClientInterceptorBuilder +} + +func (errHTTPFilter) TypeURLs() []string { return []string{"err.custom.filter"} } + +func (errHTTPFilter) ParseFilterConfig(cfg proto.Message) (httpfilter.FilterConfig, error) { + return nil, fmt.Errorf("error from ParseFilterConfig") +} + +func (errHTTPFilter) ParseFilterConfigOverride(override proto.Message) (httpfilter.FilterConfig, error) { + return nil, fmt.Errorf("error from ParseFilterConfigOverride") +} + +func (errHTTPFilter) IsTerminal() bool { + return false +} + +func init() { + httpfilter.Register(httpFilter{}) + httpfilter.Register(errHTTPFilter{}) + httpfilter.Register(serverOnlyHTTPFilter{}) + httpfilter.Register(clientOnlyHTTPFilter{}) +} + +// serverOnlyHTTPFilter does not implement ClientInterceptorBuilder +type serverOnlyHTTPFilter struct { + httpfilter.ServerInterceptorBuilder +} + +func (serverOnlyHTTPFilter) TypeURLs() []string { return []string{"serverOnly.custom.filter"} } + +func (serverOnlyHTTPFilter) ParseFilterConfig(cfg proto.Message) (httpfilter.FilterConfig, error) { + return filterConfig{Cfg: cfg}, nil +} + +func (serverOnlyHTTPFilter) ParseFilterConfigOverride(override proto.Message) (httpfilter.FilterConfig, error) { + return filterConfig{Override: override}, nil +} + +func (serverOnlyHTTPFilter) IsTerminal() bool { + return false +} + +// 
clientOnlyHTTPFilter does not implement ServerInterceptorBuilder +type clientOnlyHTTPFilter struct { + httpfilter.ClientInterceptorBuilder +} + +func (clientOnlyHTTPFilter) TypeURLs() []string { return []string{"clientOnly.custom.filter"} } + +func (clientOnlyHTTPFilter) ParseFilterConfig(cfg proto.Message) (httpfilter.FilterConfig, error) { + return filterConfig{Cfg: cfg}, nil +} + +func (clientOnlyHTTPFilter) ParseFilterConfigOverride(override proto.Message) (httpfilter.FilterConfig, error) { + return filterConfig{Override: override}, nil +} + +func (clientOnlyHTTPFilter) IsTerminal() bool { + return false +} + +var customFilterConfig = &anypb.Any{ + TypeUrl: "custom.filter", + Value: []byte{1, 2, 3}, +} + +var errFilterConfig = &anypb.Any{ + TypeUrl: "err.custom.filter", + Value: []byte{1, 2, 3}, +} + +var serverOnlyCustomFilterConfig = &anypb.Any{ + TypeUrl: "serverOnly.custom.filter", + Value: []byte{1, 2, 3}, +} + +var clientOnlyCustomFilterConfig = &anypb.Any{ + TypeUrl: "clientOnly.custom.filter", + Value: []byte{1, 2, 3}, +} + +var customFilterTypedStructConfig = &v1typepb.TypedStruct{ + TypeUrl: "custom.filter", + Value: &spb.Struct{ + Fields: map[string]*spb.Value{ + "foo": {Kind: &spb.Value_StringValue{StringValue: "bar"}}, + }, + }, +} +var wrappedCustomFilterTypedStructConfig *anypb.Any + +func init() { + wrappedCustomFilterTypedStructConfig = testutils.MarshalAny(customFilterTypedStructConfig) +} + +var unknownFilterConfig = &anypb.Any{ + TypeUrl: "unknown.custom.filter", + Value: []byte{1, 2, 3}, +} + +func wrappedOptionalFilter(name string) *anypb.Any { + return testutils.MarshalAny(&v3routepb.FilterConfig{ + IsOptional: true, + Config: &anypb.Any{ + TypeUrl: name, + Value: []byte{1, 2, 3}, + }, + }) +} diff --git a/xds/internal/client/load/reporter.go b/xds/internal/xdsclient/load/reporter.go similarity index 100% rename from xds/internal/client/load/reporter.go rename to xds/internal/xdsclient/load/reporter.go diff --git 
a/xds/internal/client/load/store.go b/xds/internal/xdsclient/load/store.go similarity index 100% rename from xds/internal/client/load/store.go rename to xds/internal/xdsclient/load/store.go diff --git a/xds/internal/client/load/store_test.go b/xds/internal/xdsclient/load/store_test.go similarity index 100% rename from xds/internal/client/load/store_test.go rename to xds/internal/xdsclient/load/store_test.go diff --git a/xds/internal/client/loadreport.go b/xds/internal/xdsclient/loadreport.go similarity index 98% rename from xds/internal/client/loadreport.go rename to xds/internal/xdsclient/loadreport.go index be42a6e0c38..32a71dada7f 100644 --- a/xds/internal/client/loadreport.go +++ b/xds/internal/xdsclient/loadreport.go @@ -15,13 +15,13 @@ * limitations under the License. */ -package client +package xdsclient import ( "context" "google.golang.org/grpc" - "google.golang.org/grpc/xds/internal/client/load" + "google.golang.org/grpc/xds/internal/xdsclient/load" ) // ReportLoad starts an load reporting stream to the given server. 
If the server diff --git a/xds/internal/client/tests/loadreport_test.go b/xds/internal/xdsclient/loadreport_test.go similarity index 94% rename from xds/internal/client/tests/loadreport_test.go rename to xds/internal/xdsclient/loadreport_test.go index af145e7f2a9..88a08eb43fd 100644 --- a/xds/internal/client/tests/loadreport_test.go +++ b/xds/internal/xdsclient/loadreport_test.go @@ -16,7 +16,7 @@ * */ -package tests_test +package xdsclient_test import ( "context" @@ -32,13 +32,13 @@ import ( "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/status" - "google.golang.org/grpc/xds/internal/client" - "google.golang.org/grpc/xds/internal/client/bootstrap" "google.golang.org/grpc/xds/internal/testutils/fakeserver" "google.golang.org/grpc/xds/internal/version" + "google.golang.org/grpc/xds/internal/xdsclient" + "google.golang.org/grpc/xds/internal/xdsclient/bootstrap" "google.golang.org/protobuf/testing/protocmp" - _ "google.golang.org/grpc/xds/internal/client/v2" // Register the v2 xDS API client. + _ "google.golang.org/grpc/xds/internal/xdsclient/v2" // Register the v2 xDS API client. 
) const ( @@ -54,7 +54,7 @@ func (s) TestLRSClient(t *testing.T) { } defer sCleanup() - xdsC, err := client.NewWithConfigForTesting(&bootstrap.Config{ + xdsC, err := xdsclient.NewWithConfigForTesting(&bootstrap.Config{ BalancerName: fs.Address, Creds: grpc.WithTransportCredentials(insecure.NewCredentials()), NodeProto: &v2corepb.Node{}, diff --git a/xds/internal/client/logging.go b/xds/internal/xdsclient/logging.go similarity index 98% rename from xds/internal/client/logging.go rename to xds/internal/xdsclient/logging.go index bff3fb1d3df..e28ea0d0410 100644 --- a/xds/internal/client/logging.go +++ b/xds/internal/xdsclient/logging.go @@ -16,7 +16,7 @@ * */ -package client +package xdsclient import ( "fmt" diff --git a/xds/internal/xdsclient/matcher.go b/xds/internal/xdsclient/matcher.go new file mode 100644 index 00000000000..e663e02769f --- /dev/null +++ b/xds/internal/xdsclient/matcher.go @@ -0,0 +1,278 @@ +/* + * + * Copyright 2020 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package xdsclient + +import ( + "fmt" + "strings" + + "google.golang.org/grpc/internal/grpcrand" + "google.golang.org/grpc/internal/grpcutil" + iresolver "google.golang.org/grpc/internal/resolver" + "google.golang.org/grpc/internal/xds/matcher" + "google.golang.org/grpc/metadata" +) + +// RouteToMatcher converts a route to a Matcher to match incoming RPC's against. 
+func RouteToMatcher(r *Route) (*CompositeMatcher, error) { + var pm pathMatcher + switch { + case r.Regex != nil: + pm = newPathRegexMatcher(r.Regex) + case r.Path != nil: + pm = newPathExactMatcher(*r.Path, r.CaseInsensitive) + case r.Prefix != nil: + pm = newPathPrefixMatcher(*r.Prefix, r.CaseInsensitive) + default: + return nil, fmt.Errorf("illegal route: missing path_matcher") + } + + headerMatchers := make([]matcher.HeaderMatcher, 0, len(r.Headers)) + for _, h := range r.Headers { + var matcherT matcher.HeaderMatcher + switch { + case h.ExactMatch != nil && *h.ExactMatch != "": + matcherT = matcher.NewHeaderExactMatcher(h.Name, *h.ExactMatch) + case h.RegexMatch != nil: + matcherT = matcher.NewHeaderRegexMatcher(h.Name, h.RegexMatch) + case h.PrefixMatch != nil && *h.PrefixMatch != "": + matcherT = matcher.NewHeaderPrefixMatcher(h.Name, *h.PrefixMatch) + case h.SuffixMatch != nil && *h.SuffixMatch != "": + matcherT = matcher.NewHeaderSuffixMatcher(h.Name, *h.SuffixMatch) + case h.RangeMatch != nil: + matcherT = matcher.NewHeaderRangeMatcher(h.Name, h.RangeMatch.Start, h.RangeMatch.End) + case h.PresentMatch != nil: + matcherT = matcher.NewHeaderPresentMatcher(h.Name, *h.PresentMatch) + default: + return nil, fmt.Errorf("illegal route: missing header_match_specifier") + } + if h.InvertMatch != nil && *h.InvertMatch { + matcherT = matcher.NewInvertMatcher(matcherT) + } + headerMatchers = append(headerMatchers, matcherT) + } + + var fractionMatcher *fractionMatcher + if r.Fraction != nil { + fractionMatcher = newFractionMatcher(*r.Fraction) + } + return newCompositeMatcher(pm, headerMatchers, fractionMatcher), nil +} + +// CompositeMatcher is a matcher that holds onto many matchers and aggregates +// the matching results. 
+type CompositeMatcher struct { + pm pathMatcher + hms []matcher.HeaderMatcher + fm *fractionMatcher +} + +func newCompositeMatcher(pm pathMatcher, hms []matcher.HeaderMatcher, fm *fractionMatcher) *CompositeMatcher { + return &CompositeMatcher{pm: pm, hms: hms, fm: fm} +} + +// Match returns true if all matchers return true. +func (a *CompositeMatcher) Match(info iresolver.RPCInfo) bool { + if a.pm != nil && !a.pm.match(info.Method) { + return false + } + + // Call headerMatchers even if md is nil, because routes may match + // non-presence of some headers. + var md metadata.MD + if info.Context != nil { + md, _ = metadata.FromOutgoingContext(info.Context) + if extraMD, ok := grpcutil.ExtraMetadata(info.Context); ok { + md = metadata.Join(md, extraMD) + // Remove all binary headers. They are hard to match with. May need + // to add back if asked by users. + for k := range md { + if strings.HasSuffix(k, "-bin") { + delete(md, k) + } + } + } + } + for _, m := range a.hms { + if !m.Match(md) { + return false + } + } + + if a.fm != nil && !a.fm.match() { + return false + } + return true +} + +func (a *CompositeMatcher) String() string { + var ret string + if a.pm != nil { + ret += a.pm.String() + } + for _, m := range a.hms { + ret += m.String() + } + if a.fm != nil { + ret += a.fm.String() + } + return ret +} + +type fractionMatcher struct { + fraction int64 // real fraction is fraction/1,000,000. +} + +func newFractionMatcher(fraction uint32) *fractionMatcher { + return &fractionMatcher{fraction: int64(fraction)} +} + +// RandInt63n overwrites grpcrand for control in tests. 
+var RandInt63n = grpcrand.Int63n + +func (fm *fractionMatcher) match() bool { + t := RandInt63n(1000000) + return t <= fm.fraction +} + +func (fm *fractionMatcher) String() string { + return fmt.Sprintf("fraction:%v", fm.fraction) +} + +type domainMatchType int + +const ( + domainMatchTypeInvalid domainMatchType = iota + domainMatchTypeUniversal + domainMatchTypePrefix + domainMatchTypeSuffix + domainMatchTypeExact +) + +// Exact > Suffix > Prefix > Universal > Invalid. +func (t domainMatchType) betterThan(b domainMatchType) bool { + return t > b +} + +func matchTypeForDomain(d string) domainMatchType { + if d == "" { + return domainMatchTypeInvalid + } + if d == "*" { + return domainMatchTypeUniversal + } + if strings.HasPrefix(d, "*") { + return domainMatchTypeSuffix + } + if strings.HasSuffix(d, "*") { + return domainMatchTypePrefix + } + if strings.Contains(d, "*") { + return domainMatchTypeInvalid + } + return domainMatchTypeExact +} + +func match(domain, host string) (domainMatchType, bool) { + switch typ := matchTypeForDomain(domain); typ { + case domainMatchTypeInvalid: + return typ, false + case domainMatchTypeUniversal: + return typ, true + case domainMatchTypePrefix: + // abc.* + return typ, strings.HasPrefix(host, strings.TrimSuffix(domain, "*")) + case domainMatchTypeSuffix: + // *.123 + return typ, strings.HasSuffix(host, strings.TrimPrefix(domain, "*")) + case domainMatchTypeExact: + return typ, domain == host + default: + return domainMatchTypeInvalid, false + } +} + +// FindBestMatchingVirtualHost returns the virtual host whose domains field best +// matches host +// +// The domains field support 4 different matching pattern types: +// - Exact match +// - Suffix match (e.g. “*ABC”) +// - Prefix match (e.g. “ABC*) +// - Universal match (e.g. 
“*”) +// +// The best match is defined as: +// - A match is better if it’s matching pattern type is better +// - Exact match > suffix match > prefix match > universal match +// - If two matches are of the same pattern type, the longer match is better +// - This is to compare the length of the matching pattern, e.g. “*ABCDE” > +// “*ABC” +func FindBestMatchingVirtualHost(host string, vHosts []*VirtualHost) *VirtualHost { // Maybe move this crap to client + var ( + matchVh *VirtualHost + matchType = domainMatchTypeInvalid + matchLen int + ) + for _, vh := range vHosts { + for _, domain := range vh.Domains { + typ, matched := match(domain, host) + if typ == domainMatchTypeInvalid { + // The rds response is invalid. + return nil + } + if matchType.betterThan(typ) || matchType == typ && matchLen >= len(domain) || !matched { + // The previous match has better type, or the previous match has + // better length, or this domain isn't a match. + continue + } + matchVh = vh + matchType = typ + matchLen = len(domain) + } + } + return matchVh +} + +// FindBestMatchingVirtualHostServer returns the virtual host whose domains field best +// matches authority. +func FindBestMatchingVirtualHostServer(authority string, vHosts []VirtualHostWithInterceptors) *VirtualHostWithInterceptors { + var ( + matchVh *VirtualHostWithInterceptors + matchType = domainMatchTypeInvalid + matchLen int + ) + for _, vh := range vHosts { + for _, domain := range vh.Domains { + typ, matched := match(domain, authority) + if typ == domainMatchTypeInvalid { + // The rds response is invalid. + return nil + } + if matchType.betterThan(typ) || matchType == typ && matchLen >= len(domain) || !matched { + // The previous match has better type, or the previous match has + // better length, or this domain isn't a match. 
+ continue + } + matchVh = &vh + matchType = typ + matchLen = len(domain) + } + } + return matchVh +} diff --git a/xds/internal/resolver/matcher_path.go b/xds/internal/xdsclient/matcher_path.go similarity index 97% rename from xds/internal/resolver/matcher_path.go rename to xds/internal/xdsclient/matcher_path.go index 011d1a94c49..a00c6954ef5 100644 --- a/xds/internal/resolver/matcher_path.go +++ b/xds/internal/xdsclient/matcher_path.go @@ -16,14 +16,14 @@ * */ -package resolver +package xdsclient import ( "regexp" "strings" ) -type pathMatcherInterface interface { +type pathMatcher interface { match(path string) bool String() string } diff --git a/xds/internal/resolver/matcher_path_test.go b/xds/internal/xdsclient/matcher_path_test.go similarity index 99% rename from xds/internal/resolver/matcher_path_test.go rename to xds/internal/xdsclient/matcher_path_test.go index 263a049108e..a211034a60d 100644 --- a/xds/internal/resolver/matcher_path_test.go +++ b/xds/internal/xdsclient/matcher_path_test.go @@ -16,7 +16,7 @@ * */ -package resolver +package xdsclient import ( "regexp" diff --git a/xds/internal/resolver/matcher_test.go b/xds/internal/xdsclient/matcher_test.go similarity index 55% rename from xds/internal/resolver/matcher_test.go rename to xds/internal/xdsclient/matcher_test.go index 7657b87bf45..f750d07d6e4 100644 --- a/xds/internal/resolver/matcher_test.go +++ b/xds/internal/xdsclient/matcher_test.go @@ -16,7 +16,7 @@ * */ -package resolver +package xdsclient import ( "context" @@ -25,21 +25,22 @@ import ( "google.golang.org/grpc/internal/grpcrand" "google.golang.org/grpc/internal/grpcutil" iresolver "google.golang.org/grpc/internal/resolver" + "google.golang.org/grpc/internal/xds/matcher" "google.golang.org/grpc/metadata" ) func TestAndMatcherMatch(t *testing.T) { tests := []struct { name string - pm pathMatcherInterface - hm headerMatcherInterface + pm pathMatcher + hm matcher.HeaderMatcher info iresolver.RPCInfo want bool }{ { name: "both match", pm: 
newPathExactMatcher("/a/b", false), - hm: newHeaderExactMatcher("th", "tv"), + hm: matcher.NewHeaderExactMatcher("th", "tv"), info: iresolver.RPCInfo{ Method: "/a/b", Context: metadata.NewOutgoingContext(context.Background(), metadata.Pairs("th", "tv")), @@ -49,7 +50,7 @@ func TestAndMatcherMatch(t *testing.T) { { name: "both match with path case insensitive", pm: newPathExactMatcher("/A/B", true), - hm: newHeaderExactMatcher("th", "tv"), + hm: matcher.NewHeaderExactMatcher("th", "tv"), info: iresolver.RPCInfo{ Method: "/a/b", Context: metadata.NewOutgoingContext(context.Background(), metadata.Pairs("th", "tv")), @@ -59,7 +60,7 @@ func TestAndMatcherMatch(t *testing.T) { { name: "only one match", pm: newPathExactMatcher("/a/b", false), - hm: newHeaderExactMatcher("th", "tv"), + hm: matcher.NewHeaderExactMatcher("th", "tv"), info: iresolver.RPCInfo{ Method: "/z/y", Context: metadata.NewOutgoingContext(context.Background(), metadata.Pairs("th", "tv")), @@ -69,7 +70,7 @@ func TestAndMatcherMatch(t *testing.T) { { name: "both not match", pm: newPathExactMatcher("/z/y", false), - hm: newHeaderExactMatcher("th", "abc"), + hm: matcher.NewHeaderExactMatcher("th", "abc"), info: iresolver.RPCInfo{ Method: "/a/b", Context: metadata.NewOutgoingContext(context.Background(), metadata.Pairs("th", "tv")), @@ -79,7 +80,7 @@ func TestAndMatcherMatch(t *testing.T) { { name: "fake header", pm: newPathPrefixMatcher("/", false), - hm: newHeaderExactMatcher("content-type", "fake"), + hm: matcher.NewHeaderExactMatcher("content-type", "fake"), info: iresolver.RPCInfo{ Method: "/a/b", Context: grpcutil.WithExtraMetadata(context.Background(), metadata.Pairs( @@ -91,7 +92,7 @@ func TestAndMatcherMatch(t *testing.T) { { name: "binary header", pm: newPathPrefixMatcher("/", false), - hm: newHeaderPresentMatcher("t-bin", true), + hm: matcher.NewHeaderPresentMatcher("t-bin", true), info: iresolver.RPCInfo{ Method: "/a/b", Context: grpcutil.WithExtraMetadata( @@ -105,8 +106,8 @@ func 
TestAndMatcherMatch(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - a := newCompositeMatcher(tt.pm, []headerMatcherInterface{tt.hm}, nil) - if got := a.match(tt.info); got != tt.want { + a := newCompositeMatcher(tt.pm, []matcher.HeaderMatcher{tt.hm}, nil) + if got := a.Match(tt.info); got != tt.want { t.Errorf("match() = %v, want %v", got, tt.want) } }) @@ -117,11 +118,11 @@ func TestFractionMatcherMatch(t *testing.T) { const fraction = 500000 fm := newFractionMatcher(fraction) defer func() { - grpcrandInt63n = grpcrand.Int63n + RandInt63n = grpcrand.Int63n }() // rand > fraction, should return false. - grpcrandInt63n = func(n int64) int64 { + RandInt63n = func(n int64) int64 { return fraction + 1 } if matched := fm.match(); matched { @@ -129,7 +130,7 @@ func TestFractionMatcherMatch(t *testing.T) { } // rand == fraction, should return true. - grpcrandInt63n = func(n int64) int64 { + RandInt63n = func(n int64) int64 { return fraction } if matched := fm.match(); !matched { @@ -137,10 +138,56 @@ func TestFractionMatcherMatch(t *testing.T) { } // rand < fraction, should return true. 
- grpcrandInt63n = func(n int64) int64 { + RandInt63n = func(n int64) int64 { return fraction - 1 } if matched := fm.match(); !matched { t.Errorf("match() = %v, want match", matched) } } + +func (s) TestMatchTypeForDomain(t *testing.T) { + tests := []struct { + d string + want domainMatchType + }{ + {d: "", want: domainMatchTypeInvalid}, + {d: "*", want: domainMatchTypeUniversal}, + {d: "bar.*", want: domainMatchTypePrefix}, + {d: "*.abc.com", want: domainMatchTypeSuffix}, + {d: "foo.bar.com", want: domainMatchTypeExact}, + {d: "foo.*.com", want: domainMatchTypeInvalid}, + } + for _, tt := range tests { + if got := matchTypeForDomain(tt.d); got != tt.want { + t.Errorf("matchTypeForDomain(%q) = %v, want %v", tt.d, got, tt.want) + } + } +} + +func (s) TestMatch(t *testing.T) { + tests := []struct { + name string + domain string + host string + wantTyp domainMatchType + wantMatched bool + }{ + {name: "invalid-empty", domain: "", host: "", wantTyp: domainMatchTypeInvalid, wantMatched: false}, + {name: "invalid", domain: "a.*.b", host: "", wantTyp: domainMatchTypeInvalid, wantMatched: false}, + {name: "universal", domain: "*", host: "abc.com", wantTyp: domainMatchTypeUniversal, wantMatched: true}, + {name: "prefix-match", domain: "abc.*", host: "abc.123", wantTyp: domainMatchTypePrefix, wantMatched: true}, + {name: "prefix-no-match", domain: "abc.*", host: "abcd.123", wantTyp: domainMatchTypePrefix, wantMatched: false}, + {name: "suffix-match", domain: "*.123", host: "abc.123", wantTyp: domainMatchTypeSuffix, wantMatched: true}, + {name: "suffix-no-match", domain: "*.123", host: "abc.1234", wantTyp: domainMatchTypeSuffix, wantMatched: false}, + {name: "exact-match", domain: "foo.bar", host: "foo.bar", wantTyp: domainMatchTypeExact, wantMatched: true}, + {name: "exact-no-match", domain: "foo.bar.com", host: "foo.bar", wantTyp: domainMatchTypeExact, wantMatched: false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if gotTyp, gotMatched := 
match(tt.domain, tt.host); gotTyp != tt.wantTyp || gotMatched != tt.wantMatched { + t.Errorf("match() = %v, %v, want %v, %v", gotTyp, gotMatched, tt.wantTyp, tt.wantMatched) + } + }) + } +} diff --git a/xds/internal/client/rds_test.go b/xds/internal/xdsclient/rds_test.go similarity index 55% rename from xds/internal/client/rds_test.go rename to xds/internal/xdsclient/rds_test.go index 0c1e2b28538..c89e8cddca4 100644 --- a/xds/internal/client/rds_test.go +++ b/xds/internal/xdsclient/rds_test.go @@ -16,28 +16,31 @@ * */ -package client +package xdsclient import ( "fmt" + "regexp" "testing" "time" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/internal/testutils" + "google.golang.org/grpc/internal/xds/env" + "google.golang.org/grpc/xds/internal/httpfilter" + "google.golang.org/grpc/xds/internal/version" + "google.golang.org/protobuf/types/known/durationpb" + v2xdspb "github.com/envoyproxy/go-control-plane/envoy/api/v2" v2routepb "github.com/envoyproxy/go-control-plane/envoy/api/v2/route" v3corepb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" v3routepb "github.com/envoyproxy/go-control-plane/envoy/config/route/v3" + v3matcherpb "github.com/envoyproxy/go-control-plane/envoy/type/matcher/v3" v3typepb "github.com/envoyproxy/go-control-plane/envoy/type/v3" - "github.com/golang/protobuf/proto" anypb "github.com/golang/protobuf/ptypes/any" wrapperspb "github.com/golang/protobuf/ptypes/wrappers" - "github.com/google/go-cmp/cmp" - "github.com/google/go-cmp/cmp/cmpopts" - - "google.golang.org/grpc/xds/internal/env" - "google.golang.org/grpc/xds/internal/httpfilter" - "google.golang.org/grpc/xds/internal/version" - "google.golang.org/protobuf/types/known/durationpb" ) func (s) TestRDSGenerateRDSUpdateFromRouteConfiguration(t *testing.T) { @@ -72,11 +75,55 @@ func (s) TestRDSGenerateRDSUpdateFromRouteConfiguration(t *testing.T) { Routes: []*Route{{ Prefix: newStringP("/"), 
WeightedClusters: map[string]WeightedCluster{clusterName: {Weight: 1}}, + RouteAction: RouteActionRoute, }}, HTTPFilterConfigOverride: cfgs, }}, } } + goodRouteConfigWithRetryPolicy = func(vhrp *v3routepb.RetryPolicy, rrp *v3routepb.RetryPolicy) *v3routepb.RouteConfiguration { + return &v3routepb.RouteConfiguration{ + Name: routeName, + VirtualHosts: []*v3routepb.VirtualHost{{ + Domains: []string{ldsTarget}, + Routes: []*v3routepb.Route{{ + Match: &v3routepb.RouteMatch{PathSpecifier: &v3routepb.RouteMatch_Prefix{Prefix: "/"}}, + Action: &v3routepb.Route_Route{ + Route: &v3routepb.RouteAction{ + ClusterSpecifier: &v3routepb.RouteAction_Cluster{Cluster: clusterName}, + RetryPolicy: rrp, + }, + }, + }}, + RetryPolicy: vhrp, + }}, + } + } + goodUpdateWithRetryPolicy = func(vhrc *RetryConfig, rrc *RetryConfig) RouteConfigUpdate { + if !env.RetrySupport { + vhrc = nil + rrc = nil + } + return RouteConfigUpdate{ + VirtualHosts: []*VirtualHost{{ + Domains: []string{ldsTarget}, + Routes: []*Route{{ + Prefix: newStringP("/"), + WeightedClusters: map[string]WeightedCluster{clusterName: {Weight: 1}}, + RouteAction: RouteActionRoute, + RetryConfig: rrc, + }}, + RetryConfig: vhrc, + }}, + } + } + defaultRetryBackoff = RetryBackoff{BaseInterval: 25 * time.Millisecond, MaxInterval: 250 * time.Millisecond} + goodUpdateIfRetryDisabled = func() RouteConfigUpdate { + if env.RetrySupport { + return RouteConfigUpdate{} + } + return goodUpdateWithRetryPolicy(nil, nil) + } ) tests := []struct { @@ -84,7 +131,6 @@ func (s) TestRDSGenerateRDSUpdateFromRouteConfiguration(t *testing.T) { rc *v3routepb.RouteConfiguration wantUpdate RouteConfigUpdate wantError bool - disableFI bool // disable fault injection }{ { name: "default-route-match-field-is-nil", @@ -175,7 +221,10 @@ func (s) TestRDSGenerateRDSUpdateFromRouteConfiguration(t *testing.T) { VirtualHosts: []*VirtualHost{ { Domains: []string{ldsTarget}, - Routes: []*Route{{Prefix: newStringP("/"), CaseInsensitive: true, WeightedClusters: 
map[string]WeightedCluster{clusterName: {Weight: 1}}}}, + Routes: []*Route{{Prefix: newStringP("/"), + CaseInsensitive: true, + WeightedClusters: map[string]WeightedCluster{clusterName: {Weight: 1}}, + RouteAction: RouteActionRoute}}, }, }, }, @@ -217,11 +266,15 @@ func (s) TestRDSGenerateRDSUpdateFromRouteConfiguration(t *testing.T) { VirtualHosts: []*VirtualHost{ { Domains: []string{uninterestingDomain}, - Routes: []*Route{{Prefix: newStringP(""), WeightedClusters: map[string]WeightedCluster{uninterestingClusterName: {Weight: 1}}}}, + Routes: []*Route{{Prefix: newStringP(""), + WeightedClusters: map[string]WeightedCluster{uninterestingClusterName: {Weight: 1}}, + RouteAction: RouteActionRoute}}, }, { Domains: []string{ldsTarget}, - Routes: []*Route{{Prefix: newStringP(""), WeightedClusters: map[string]WeightedCluster{clusterName: {Weight: 1}}}}, + Routes: []*Route{{Prefix: newStringP(""), + WeightedClusters: map[string]WeightedCluster{clusterName: {Weight: 1}}, + RouteAction: RouteActionRoute}}, }, }, }, @@ -251,7 +304,9 @@ func (s) TestRDSGenerateRDSUpdateFromRouteConfiguration(t *testing.T) { VirtualHosts: []*VirtualHost{ { Domains: []string{ldsTarget}, - Routes: []*Route{{Prefix: newStringP("/"), WeightedClusters: map[string]WeightedCluster{clusterName: {Weight: 1}}}}, + Routes: []*Route{{Prefix: newStringP("/"), + WeightedClusters: map[string]WeightedCluster{clusterName: {Weight: 1}}, + RouteAction: RouteActionRoute}}, }, }, }, @@ -328,6 +383,7 @@ func (s) TestRDSGenerateRDSUpdateFromRouteConfiguration(t *testing.T) { "b": {Weight: 3}, "c": {Weight: 5}, }, + RouteAction: RouteActionRoute, }}, }, }, @@ -362,6 +418,7 @@ func (s) TestRDSGenerateRDSUpdateFromRouteConfiguration(t *testing.T) { Prefix: newStringP("/"), WeightedClusters: map[string]WeightedCluster{clusterName: {Weight: 1}}, MaxStreamDuration: newDurationP(time.Second), + RouteAction: RouteActionRoute, }}, }, }, @@ -396,6 +453,7 @@ func (s) TestRDSGenerateRDSUpdateFromRouteConfiguration(t *testing.T) 
{ Prefix: newStringP("/"), WeightedClusters: map[string]WeightedCluster{clusterName: {Weight: 1}}, MaxStreamDuration: newDurationP(time.Second), + RouteAction: RouteActionRoute, }}, }, }, @@ -430,6 +488,7 @@ func (s) TestRDSGenerateRDSUpdateFromRouteConfiguration(t *testing.T) { Prefix: newStringP("/"), WeightedClusters: map[string]WeightedCluster{clusterName: {Weight: 1}}, MaxStreamDuration: newDurationP(0), + RouteAction: RouteActionRoute, }}, }, }, @@ -471,18 +530,50 @@ func (s) TestRDSGenerateRDSUpdateFromRouteConfiguration(t *testing.T) { wantUpdate: goodUpdateWithFilterConfigs(nil), }, { - name: "good-route-config-with-http-err-filter-config-fi-disabled", - disableFI: true, - rc: goodRouteConfigWithFilterConfigs(map[string]*anypb.Any{"foo": errFilterConfig}), - wantUpdate: goodUpdateWithFilterConfigs(nil), + name: "good-route-config-with-retry-policy", + rc: goodRouteConfigWithRetryPolicy( + &v3routepb.RetryPolicy{RetryOn: "cancelled"}, + &v3routepb.RetryPolicy{RetryOn: "deadline-exceeded,unsupported", NumRetries: &wrapperspb.UInt32Value{Value: 2}}), + wantUpdate: goodUpdateWithRetryPolicy( + &RetryConfig{RetryOn: map[codes.Code]bool{codes.Canceled: true}, NumRetries: 1, RetryBackoff: defaultRetryBackoff}, + &RetryConfig{RetryOn: map[codes.Code]bool{codes.DeadlineExceeded: true}, NumRetries: 2, RetryBackoff: defaultRetryBackoff}), + }, + { + name: "good-route-config-with-retry-backoff", + rc: goodRouteConfigWithRetryPolicy( + &v3routepb.RetryPolicy{RetryOn: "internal", RetryBackOff: &v3routepb.RetryPolicy_RetryBackOff{BaseInterval: durationpb.New(10 * time.Millisecond), MaxInterval: durationpb.New(10 * time.Millisecond)}}, + &v3routepb.RetryPolicy{RetryOn: "resource-exhausted", RetryBackOff: &v3routepb.RetryPolicy_RetryBackOff{BaseInterval: durationpb.New(10 * time.Millisecond)}}), + wantUpdate: goodUpdateWithRetryPolicy( + &RetryConfig{RetryOn: map[codes.Code]bool{codes.Internal: true}, NumRetries: 1, RetryBackoff: RetryBackoff{BaseInterval: 10 * 
time.Millisecond, MaxInterval: 10 * time.Millisecond}}, + &RetryConfig{RetryOn: map[codes.Code]bool{codes.ResourceExhausted: true}, NumRetries: 1, RetryBackoff: RetryBackoff{BaseInterval: 10 * time.Millisecond, MaxInterval: 100 * time.Millisecond}}), + }, + { + name: "bad-retry-policy-0-retries", + rc: goodRouteConfigWithRetryPolicy(&v3routepb.RetryPolicy{RetryOn: "cancelled", NumRetries: &wrapperspb.UInt32Value{Value: 0}}, nil), + wantUpdate: goodUpdateIfRetryDisabled(), + wantError: env.RetrySupport, + }, + { + name: "bad-retry-policy-0-base-interval", + rc: goodRouteConfigWithRetryPolicy(&v3routepb.RetryPolicy{RetryOn: "cancelled", RetryBackOff: &v3routepb.RetryPolicy_RetryBackOff{BaseInterval: durationpb.New(0)}}, nil), + wantUpdate: goodUpdateIfRetryDisabled(), + wantError: env.RetrySupport, + }, + { + name: "bad-retry-policy-negative-max-interval", + rc: goodRouteConfigWithRetryPolicy(&v3routepb.RetryPolicy{RetryOn: "cancelled", RetryBackOff: &v3routepb.RetryPolicy_RetryBackOff{MaxInterval: durationpb.New(-time.Second)}}, nil), + wantUpdate: goodUpdateIfRetryDisabled(), + wantError: env.RetrySupport, + }, + { + name: "bad-retry-policy-negative-max-interval-no-known-retry-on", + rc: goodRouteConfigWithRetryPolicy(&v3routepb.RetryPolicy{RetryOn: "something", RetryBackOff: &v3routepb.RetryPolicy_RetryBackOff{MaxInterval: durationpb.New(-time.Second)}}, nil), + wantUpdate: goodUpdateIfRetryDisabled(), + wantError: env.RetrySupport, }, } - for _, test := range tests { t.Run(test.name, func(t *testing.T) { - oldFI := env.FaultInjectionSupport - env.FaultInjectionSupport = !test.disableFI - gotUpdate, gotError := generateRDSUpdateFromRouteConfiguration(test.rc, nil, false) if (gotError != nil) != test.wantError || !cmp.Equal(gotUpdate, test.wantUpdate, cmpopts.EquateEmpty(), @@ -490,8 +581,6 @@ func (s) TestRDSGenerateRDSUpdateFromRouteConfiguration(t *testing.T) { return fmt.Sprint(fc) })) { t.Errorf("generateRDSUpdateFromRouteConfiguration(%+v, %v) returned 
unexpected, diff (-want +got):\\n%s", test.rc, ldsTarget, cmp.Diff(test.wantUpdate, gotUpdate, cmpopts.EquateEmpty())) - - env.FaultInjectionSupport = oldFI } }) } @@ -537,17 +626,10 @@ func (s) TestUnmarshalRouteConfig(t *testing.T) { }, }, } - v2RouteConfig = &anypb.Any{ - TypeUrl: version.V2RouteConfigURL, - Value: func() []byte { - rc := &v2xdspb.RouteConfiguration{ - Name: v2RouteConfigName, - VirtualHosts: v2VirtualHost, - } - m, _ := proto.Marshal(rc) - return m - }(), - } + v2RouteConfig = testutils.MarshalAny(&v2xdspb.RouteConfiguration{ + Name: v2RouteConfigName, + VirtualHosts: v2VirtualHost, + }) v3VirtualHost = []*v3routepb.VirtualHost{ { Domains: []string{uninterestingDomain}, @@ -576,24 +658,17 @@ func (s) TestUnmarshalRouteConfig(t *testing.T) { }, }, } - v3RouteConfig = &anypb.Any{ - TypeUrl: version.V2RouteConfigURL, - Value: func() []byte { - rc := &v3routepb.RouteConfiguration{ - Name: v3RouteConfigName, - VirtualHosts: v3VirtualHost, - } - m, _ := proto.Marshal(rc) - return m - }(), - } + v3RouteConfig = testutils.MarshalAny(&v3routepb.RouteConfiguration{ + Name: v3RouteConfigName, + VirtualHosts: v3VirtualHost, + }) ) const testVersion = "test-version-rds" tests := []struct { name string resources []*anypb.Any - wantUpdate map[string]RouteConfigUpdate + wantUpdate map[string]RouteConfigUpdateErrTuple wantMD UpdateMetadata wantErr bool }{ @@ -605,7 +680,7 @@ func (s) TestUnmarshalRouteConfig(t *testing.T) { Version: testVersion, ErrState: &UpdateErrorMetadata{ Version: testVersion, - Err: errPlaceHolder, + Err: cmpopts.AnyError, }, }, wantErr: true, @@ -623,7 +698,7 @@ func (s) TestUnmarshalRouteConfig(t *testing.T) { Version: testVersion, ErrState: &UpdateErrorMetadata{ Version: testVersion, - Err: errPlaceHolder, + Err: cmpopts.AnyError, }, }, wantErr: true, @@ -638,20 +713,24 @@ func (s) TestUnmarshalRouteConfig(t *testing.T) { { name: "v2 routeConfig resource", resources: []*anypb.Any{v2RouteConfig}, - wantUpdate: 
map[string]RouteConfigUpdate{ - v2RouteConfigName: { + wantUpdate: map[string]RouteConfigUpdateErrTuple{ + v2RouteConfigName: {Update: RouteConfigUpdate{ VirtualHosts: []*VirtualHost{ { Domains: []string{uninterestingDomain}, - Routes: []*Route{{Prefix: newStringP(""), WeightedClusters: map[string]WeightedCluster{uninterestingClusterName: {Weight: 1}}}}, + Routes: []*Route{{Prefix: newStringP(""), + WeightedClusters: map[string]WeightedCluster{uninterestingClusterName: {Weight: 1}}, + RouteAction: RouteActionRoute}}, }, { Domains: []string{ldsTarget}, - Routes: []*Route{{Prefix: newStringP(""), WeightedClusters: map[string]WeightedCluster{v2ClusterName: {Weight: 1}}}}, + Routes: []*Route{{Prefix: newStringP(""), + WeightedClusters: map[string]WeightedCluster{v2ClusterName: {Weight: 1}}, + RouteAction: RouteActionRoute}}, }, }, Raw: v2RouteConfig, - }, + }}, }, wantMD: UpdateMetadata{ Status: ServiceStatusACKed, @@ -661,20 +740,24 @@ func (s) TestUnmarshalRouteConfig(t *testing.T) { { name: "v3 routeConfig resource", resources: []*anypb.Any{v3RouteConfig}, - wantUpdate: map[string]RouteConfigUpdate{ - v3RouteConfigName: { + wantUpdate: map[string]RouteConfigUpdateErrTuple{ + v3RouteConfigName: {Update: RouteConfigUpdate{ VirtualHosts: []*VirtualHost{ { Domains: []string{uninterestingDomain}, - Routes: []*Route{{Prefix: newStringP(""), WeightedClusters: map[string]WeightedCluster{uninterestingClusterName: {Weight: 1}}}}, + Routes: []*Route{{Prefix: newStringP(""), + WeightedClusters: map[string]WeightedCluster{uninterestingClusterName: {Weight: 1}}, + RouteAction: RouteActionRoute}}, }, { Domains: []string{ldsTarget}, - Routes: []*Route{{Prefix: newStringP(""), WeightedClusters: map[string]WeightedCluster{v3ClusterName: {Weight: 1}}}}, + Routes: []*Route{{Prefix: newStringP(""), + WeightedClusters: map[string]WeightedCluster{v3ClusterName: {Weight: 1}}, + RouteAction: RouteActionRoute}}, }, }, Raw: v3RouteConfig, - }, + }}, }, wantMD: UpdateMetadata{ Status: 
ServiceStatusACKed, @@ -684,33 +767,41 @@ func (s) TestUnmarshalRouteConfig(t *testing.T) { { name: "multiple routeConfig resources", resources: []*anypb.Any{v2RouteConfig, v3RouteConfig}, - wantUpdate: map[string]RouteConfigUpdate{ - v3RouteConfigName: { + wantUpdate: map[string]RouteConfigUpdateErrTuple{ + v3RouteConfigName: {Update: RouteConfigUpdate{ VirtualHosts: []*VirtualHost{ { Domains: []string{uninterestingDomain}, - Routes: []*Route{{Prefix: newStringP(""), WeightedClusters: map[string]WeightedCluster{uninterestingClusterName: {Weight: 1}}}}, + Routes: []*Route{{Prefix: newStringP(""), + WeightedClusters: map[string]WeightedCluster{uninterestingClusterName: {Weight: 1}}, + RouteAction: RouteActionRoute}}, }, { Domains: []string{ldsTarget}, - Routes: []*Route{{Prefix: newStringP(""), WeightedClusters: map[string]WeightedCluster{v3ClusterName: {Weight: 1}}}}, + Routes: []*Route{{Prefix: newStringP(""), + WeightedClusters: map[string]WeightedCluster{v3ClusterName: {Weight: 1}}, + RouteAction: RouteActionRoute}}, }, }, Raw: v3RouteConfig, - }, - v2RouteConfigName: { + }}, + v2RouteConfigName: {Update: RouteConfigUpdate{ VirtualHosts: []*VirtualHost{ { Domains: []string{uninterestingDomain}, - Routes: []*Route{{Prefix: newStringP(""), WeightedClusters: map[string]WeightedCluster{uninterestingClusterName: {Weight: 1}}}}, + Routes: []*Route{{Prefix: newStringP(""), + WeightedClusters: map[string]WeightedCluster{uninterestingClusterName: {Weight: 1}}, + RouteAction: RouteActionRoute}}, }, { Domains: []string{ldsTarget}, - Routes: []*Route{{Prefix: newStringP(""), WeightedClusters: map[string]WeightedCluster{v2ClusterName: {Weight: 1}}}}, + Routes: []*Route{{Prefix: newStringP(""), + WeightedClusters: map[string]WeightedCluster{v2ClusterName: {Weight: 1}}, + RouteAction: RouteActionRoute}}, }, }, Raw: v2RouteConfig, - }, + }}, }, wantMD: UpdateMetadata{ Status: ServiceStatusACKed, @@ -722,57 +813,58 @@ func (s) TestUnmarshalRouteConfig(t *testing.T) { name: "good 
and bad routeConfig resources", resources: []*anypb.Any{ v2RouteConfig, - { - TypeUrl: version.V2RouteConfigURL, - Value: func() []byte { - rc := &v3routepb.RouteConfiguration{ - Name: "bad", - VirtualHosts: []*v3routepb.VirtualHost{ - {Domains: []string{ldsTarget}, - Routes: []*v3routepb.Route{{ - Match: &v3routepb.RouteMatch{PathSpecifier: &v3routepb.RouteMatch_ConnectMatcher_{}}, - }}}}} - m, _ := proto.Marshal(rc) - return m - }(), - }, + testutils.MarshalAny(&v3routepb.RouteConfiguration{ + Name: "bad", + VirtualHosts: []*v3routepb.VirtualHost{ + {Domains: []string{ldsTarget}, + Routes: []*v3routepb.Route{{ + Match: &v3routepb.RouteMatch{PathSpecifier: &v3routepb.RouteMatch_ConnectMatcher_{}}, + }}}}}), v3RouteConfig, }, - wantUpdate: map[string]RouteConfigUpdate{ - v3RouteConfigName: { + wantUpdate: map[string]RouteConfigUpdateErrTuple{ + v3RouteConfigName: {Update: RouteConfigUpdate{ VirtualHosts: []*VirtualHost{ { Domains: []string{uninterestingDomain}, - Routes: []*Route{{Prefix: newStringP(""), WeightedClusters: map[string]WeightedCluster{uninterestingClusterName: {Weight: 1}}}}, + Routes: []*Route{{Prefix: newStringP(""), + WeightedClusters: map[string]WeightedCluster{uninterestingClusterName: {Weight: 1}}, + RouteAction: RouteActionRoute}}, }, { Domains: []string{ldsTarget}, - Routes: []*Route{{Prefix: newStringP(""), WeightedClusters: map[string]WeightedCluster{v3ClusterName: {Weight: 1}}}}, + Routes: []*Route{{Prefix: newStringP(""), + WeightedClusters: map[string]WeightedCluster{v3ClusterName: {Weight: 1}}, + RouteAction: RouteActionRoute}}, }, }, Raw: v3RouteConfig, - }, - v2RouteConfigName: { + }}, + v2RouteConfigName: {Update: RouteConfigUpdate{ VirtualHosts: []*VirtualHost{ { Domains: []string{uninterestingDomain}, - Routes: []*Route{{Prefix: newStringP(""), WeightedClusters: map[string]WeightedCluster{uninterestingClusterName: {Weight: 1}}}}, + Routes: []*Route{{Prefix: newStringP(""), + WeightedClusters: 
map[string]WeightedCluster{uninterestingClusterName: {Weight: 1}}, + RouteAction: RouteActionRoute}}, }, { Domains: []string{ldsTarget}, - Routes: []*Route{{Prefix: newStringP(""), WeightedClusters: map[string]WeightedCluster{v2ClusterName: {Weight: 1}}}}, + Routes: []*Route{{Prefix: newStringP(""), + WeightedClusters: map[string]WeightedCluster{v2ClusterName: {Weight: 1}}, + RouteAction: RouteActionRoute}}, }, }, Raw: v2RouteConfig, - }, - "bad": {}, + }}, + "bad": {Err: cmpopts.AnyError}, }, wantMD: UpdateMetadata{ Status: ServiceStatusNACKed, Version: testVersion, ErrState: &UpdateErrorMetadata{ Version: testVersion, - Err: errPlaceHolder, + Err: cmpopts.AnyError, }, }, wantErr: true, @@ -780,9 +872,13 @@ func (s) TestUnmarshalRouteConfig(t *testing.T) { } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - update, md, err := UnmarshalRouteConfig(testVersion, test.resources, nil) + opts := &UnmarshalOptions{ + Version: testVersion, + Resources: test.resources, + } + update, md, err := UnmarshalRouteConfig(opts) if (err != nil) != test.wantErr { - t.Fatalf("UnmarshalRouteConfig(), got err: %v, wantErr: %v", err, test.wantErr) + t.Fatalf("UnmarshalRouteConfig(%+v), got err: %v, wantErr: %v", opts, err, test.wantErr) } if diff := cmp.Diff(update, test.wantUpdate, cmpOpts); diff != "" { t.Errorf("got unexpected update, diff (-got +want): %v", diff) @@ -823,6 +919,7 @@ func (s) TestRoutesProtoToSlice(t *testing.T) { CaseInsensitive: true, WeightedClusters: map[string]WeightedCluster{"A": {Weight: 40}, "B": {Weight: 60, HTTPFilterConfigOverride: cfgs}}, HTTPFilterConfigOverride: cfgs, + RouteAction: RouteActionRoute, }} } ) @@ -832,7 +929,6 @@ func (s) TestRoutesProtoToSlice(t *testing.T) { routes []*v3routepb.Route wantRoutes []*Route wantErr bool - disableFI bool // disable fault injection }{ { name: "no path", @@ -863,6 +959,7 @@ func (s) TestRoutesProtoToSlice(t *testing.T) { Prefix: newStringP("/"), CaseInsensitive: true, WeightedClusters: 
map[string]WeightedCluster{"A": {Weight: 40}, "B": {Weight: 60}}, + RouteAction: RouteActionRoute, }}, }, { @@ -910,6 +1007,53 @@ func (s) TestRoutesProtoToSlice(t *testing.T) { }, Fraction: newUInt32P(10000), WeightedClusters: map[string]WeightedCluster{"A": {Weight: 40}, "B": {Weight: 60}}, + RouteAction: RouteActionRoute, + }}, + wantErr: false, + }, + { + name: "good with regex matchers", + routes: []*v3routepb.Route{ + { + Match: &v3routepb.RouteMatch{ + PathSpecifier: &v3routepb.RouteMatch_SafeRegex{SafeRegex: &v3matcherpb.RegexMatcher{Regex: "/a/"}}, + Headers: []*v3routepb.HeaderMatcher{ + { + Name: "th", + HeaderMatchSpecifier: &v3routepb.HeaderMatcher_SafeRegexMatch{SafeRegexMatch: &v3matcherpb.RegexMatcher{Regex: "tv"}}, + }, + }, + RuntimeFraction: &v3corepb.RuntimeFractionalPercent{ + DefaultValue: &v3typepb.FractionalPercent{ + Numerator: 1, + Denominator: v3typepb.FractionalPercent_HUNDRED, + }, + }, + }, + Action: &v3routepb.Route_Route{ + Route: &v3routepb.RouteAction{ + ClusterSpecifier: &v3routepb.RouteAction_WeightedClusters{ + WeightedClusters: &v3routepb.WeightedCluster{ + Clusters: []*v3routepb.WeightedCluster_ClusterWeight{ + {Name: "B", Weight: &wrapperspb.UInt32Value{Value: 60}}, + {Name: "A", Weight: &wrapperspb.UInt32Value{Value: 40}}, + }, + TotalWeight: &wrapperspb.UInt32Value{Value: 100}, + }}}}, + }, + }, + wantRoutes: []*Route{{ + Regex: func() *regexp.Regexp { return regexp.MustCompile("/a/") }(), + Headers: []*HeaderMatcher{ + { + Name: "th", + InvertMatch: newBoolP(false), + RegexMatch: func() *regexp.Regexp { return regexp.MustCompile("tv") }(), + }, + }, + Fraction: newUInt32P(10000), + WeightedClusters: map[string]WeightedCluster{"A": {Weight: 40}, "B": {Weight: 60}}, + RouteAction: RouteActionRoute, }}, wantErr: false, }, @@ -944,6 +1088,7 @@ func (s) TestRoutesProtoToSlice(t *testing.T) { wantRoutes: []*Route{{ Prefix: newStringP("/a/"), WeightedClusters: map[string]WeightedCluster{"A": {Weight: 40}, "B": {Weight: 60}}, + 
RouteAction: RouteActionRoute, }}, wantErr: false, }, @@ -958,6 +1103,44 @@ func (s) TestRoutesProtoToSlice(t *testing.T) { }, wantErr: true, }, + { + name: "bad regex in path specifier", + routes: []*v3routepb.Route{ + { + Match: &v3routepb.RouteMatch{ + PathSpecifier: &v3routepb.RouteMatch_SafeRegex{SafeRegex: &v3matcherpb.RegexMatcher{Regex: "??"}}, + Headers: []*v3routepb.HeaderMatcher{ + { + HeaderMatchSpecifier: &v3routepb.HeaderMatcher_PrefixMatch{PrefixMatch: "tv"}, + }, + }, + }, + Action: &v3routepb.Route_Route{ + Route: &v3routepb.RouteAction{ClusterSpecifier: &v3routepb.RouteAction_Cluster{Cluster: clusterName}}, + }, + }, + }, + wantErr: true, + }, + { + name: "bad regex in header specifier", + routes: []*v3routepb.Route{ + { + Match: &v3routepb.RouteMatch{ + PathSpecifier: &v3routepb.RouteMatch_Prefix{Prefix: "/a/"}, + Headers: []*v3routepb.HeaderMatcher{ + { + HeaderMatchSpecifier: &v3routepb.HeaderMatcher_SafeRegexMatch{SafeRegexMatch: &v3matcherpb.RegexMatcher{Regex: "??"}}, + }, + }, + }, + Action: &v3routepb.Route_Route{ + Route: &v3routepb.RouteAction{ClusterSpecifier: &v3routepb.RouteAction_Cluster{Cluster: clusterName}}, + }, + }, + }, + wantErr: true, + }, { name: "unrecognized header match specifier", routes: []*v3routepb.Route{ @@ -967,7 +1150,7 @@ func (s) TestRoutesProtoToSlice(t *testing.T) { Headers: []*v3routepb.HeaderMatcher{ { Name: "th", - HeaderMatchSpecifier: &v3routepb.HeaderMatcher_HiddenEnvoyDeprecatedRegexMatch{}, + HeaderMatchSpecifier: &v3routepb.HeaderMatcher_StringMatch{}, }, }, }, @@ -1011,6 +1194,227 @@ func (s) TestRoutesProtoToSlice(t *testing.T) { }, wantErr: true, }, + { + name: "totalWeight is nil in weighted clusters action", + routes: []*v3routepb.Route{ + { + Match: &v3routepb.RouteMatch{ + PathSpecifier: &v3routepb.RouteMatch_Prefix{Prefix: "/a/"}, + }, + Action: &v3routepb.Route_Route{ + Route: &v3routepb.RouteAction{ + ClusterSpecifier: &v3routepb.RouteAction_WeightedClusters{ + WeightedClusters: 
&v3routepb.WeightedCluster{ + Clusters: []*v3routepb.WeightedCluster_ClusterWeight{ + {Name: "B", Weight: &wrapperspb.UInt32Value{Value: 20}}, + {Name: "A", Weight: &wrapperspb.UInt32Value{Value: 30}}, + }, + }}}}, + }, + }, + wantErr: true, + }, + { + name: "The sum of all weighted clusters is not equal totalWeight", + routes: []*v3routepb.Route{ + { + Match: &v3routepb.RouteMatch{ + PathSpecifier: &v3routepb.RouteMatch_Prefix{Prefix: "/a/"}, + }, + Action: &v3routepb.Route_Route{ + Route: &v3routepb.RouteAction{ + ClusterSpecifier: &v3routepb.RouteAction_WeightedClusters{ + WeightedClusters: &v3routepb.WeightedCluster{ + Clusters: []*v3routepb.WeightedCluster_ClusterWeight{ + {Name: "B", Weight: &wrapperspb.UInt32Value{Value: 50}}, + {Name: "A", Weight: &wrapperspb.UInt32Value{Value: 20}}, + }, + TotalWeight: &wrapperspb.UInt32Value{Value: 100}, + }}}}, + }, + }, + wantErr: true, + }, + { + name: "unsupported cluster specifier", + routes: []*v3routepb.Route{ + { + Match: &v3routepb.RouteMatch{ + PathSpecifier: &v3routepb.RouteMatch_Prefix{Prefix: "/a/"}, + }, + Action: &v3routepb.Route_Route{ + Route: &v3routepb.RouteAction{ + ClusterSpecifier: &v3routepb.RouteAction_ClusterSpecifierPlugin{}}}, + }, + }, + wantErr: true, + }, + { + name: "default totalWeight is 100 in weighted clusters action", + routes: []*v3routepb.Route{ + { + Match: &v3routepb.RouteMatch{ + PathSpecifier: &v3routepb.RouteMatch_Prefix{Prefix: "/a/"}, + }, + Action: &v3routepb.Route_Route{ + Route: &v3routepb.RouteAction{ + ClusterSpecifier: &v3routepb.RouteAction_WeightedClusters{ + WeightedClusters: &v3routepb.WeightedCluster{ + Clusters: []*v3routepb.WeightedCluster_ClusterWeight{ + {Name: "B", Weight: &wrapperspb.UInt32Value{Value: 60}}, + {Name: "A", Weight: &wrapperspb.UInt32Value{Value: 40}}, + }, + }}}}, + }, + }, + wantRoutes: []*Route{{ + Prefix: newStringP("/a/"), + WeightedClusters: map[string]WeightedCluster{"A": {Weight: 40}, "B": {Weight: 60}}, + RouteAction: RouteActionRoute, + 
}}, + wantErr: false, + }, + { + name: "default totalWeight is 100 in weighted clusters action", + routes: []*v3routepb.Route{ + { + Match: &v3routepb.RouteMatch{ + PathSpecifier: &v3routepb.RouteMatch_Prefix{Prefix: "/a/"}, + }, + Action: &v3routepb.Route_Route{ + Route: &v3routepb.RouteAction{ + ClusterSpecifier: &v3routepb.RouteAction_WeightedClusters{ + WeightedClusters: &v3routepb.WeightedCluster{ + Clusters: []*v3routepb.WeightedCluster_ClusterWeight{ + {Name: "B", Weight: &wrapperspb.UInt32Value{Value: 30}}, + {Name: "A", Weight: &wrapperspb.UInt32Value{Value: 20}}, + }, + TotalWeight: &wrapperspb.UInt32Value{Value: 50}, + }}}}, + }, + }, + wantRoutes: []*Route{{ + Prefix: newStringP("/a/"), + WeightedClusters: map[string]WeightedCluster{"A": {Weight: 20}, "B": {Weight: 30}}, + RouteAction: RouteActionRoute, + }}, + wantErr: false, + }, + { + name: "good-with-channel-id-hash-policy", + routes: []*v3routepb.Route{ + { + Match: &v3routepb.RouteMatch{ + PathSpecifier: &v3routepb.RouteMatch_Prefix{Prefix: "/a/"}, + Headers: []*v3routepb.HeaderMatcher{ + { + Name: "th", + HeaderMatchSpecifier: &v3routepb.HeaderMatcher_PrefixMatch{ + PrefixMatch: "tv", + }, + InvertMatch: true, + }, + }, + RuntimeFraction: &v3corepb.RuntimeFractionalPercent{ + DefaultValue: &v3typepb.FractionalPercent{ + Numerator: 1, + Denominator: v3typepb.FractionalPercent_HUNDRED, + }, + }, + }, + Action: &v3routepb.Route_Route{ + Route: &v3routepb.RouteAction{ + ClusterSpecifier: &v3routepb.RouteAction_WeightedClusters{ + WeightedClusters: &v3routepb.WeightedCluster{ + Clusters: []*v3routepb.WeightedCluster_ClusterWeight{ + {Name: "B", Weight: &wrapperspb.UInt32Value{Value: 60}}, + {Name: "A", Weight: &wrapperspb.UInt32Value{Value: 40}}, + }, + TotalWeight: &wrapperspb.UInt32Value{Value: 100}, + }}, + HashPolicy: []*v3routepb.RouteAction_HashPolicy{ + {PolicySpecifier: &v3routepb.RouteAction_HashPolicy_FilterState_{FilterState: &v3routepb.RouteAction_HashPolicy_FilterState{Key: 
"io.grpc.channel_id"}}}, + }, + }}, + }, + }, + wantRoutes: []*Route{{ + Prefix: newStringP("/a/"), + Headers: []*HeaderMatcher{ + { + Name: "th", + InvertMatch: newBoolP(true), + PrefixMatch: newStringP("tv"), + }, + }, + Fraction: newUInt32P(10000), + WeightedClusters: map[string]WeightedCluster{"A": {Weight: 40}, "B": {Weight: 60}}, + HashPolicies: []*HashPolicy{ + {HashPolicyType: HashPolicyTypeChannelID}, + }, + RouteAction: RouteActionRoute, + }}, + wantErr: false, + }, + // This tests that policy.Regex ends up being nil if RegexRewrite is not + // set in xds response. + { + name: "good-with-header-hash-policy-no-regex-specified", + routes: []*v3routepb.Route{ + { + Match: &v3routepb.RouteMatch{ + PathSpecifier: &v3routepb.RouteMatch_Prefix{Prefix: "/a/"}, + Headers: []*v3routepb.HeaderMatcher{ + { + Name: "th", + HeaderMatchSpecifier: &v3routepb.HeaderMatcher_PrefixMatch{ + PrefixMatch: "tv", + }, + InvertMatch: true, + }, + }, + RuntimeFraction: &v3corepb.RuntimeFractionalPercent{ + DefaultValue: &v3typepb.FractionalPercent{ + Numerator: 1, + Denominator: v3typepb.FractionalPercent_HUNDRED, + }, + }, + }, + Action: &v3routepb.Route_Route{ + Route: &v3routepb.RouteAction{ + ClusterSpecifier: &v3routepb.RouteAction_WeightedClusters{ + WeightedClusters: &v3routepb.WeightedCluster{ + Clusters: []*v3routepb.WeightedCluster_ClusterWeight{ + {Name: "B", Weight: &wrapperspb.UInt32Value{Value: 60}}, + {Name: "A", Weight: &wrapperspb.UInt32Value{Value: 40}}, + }, + TotalWeight: &wrapperspb.UInt32Value{Value: 100}, + }}, + HashPolicy: []*v3routepb.RouteAction_HashPolicy{ + {PolicySpecifier: &v3routepb.RouteAction_HashPolicy_Header_{Header: &v3routepb.RouteAction_HashPolicy_Header{HeaderName: ":path"}}}, + }, + }}, + }, + }, + wantRoutes: []*Route{{ + Prefix: newStringP("/a/"), + Headers: []*HeaderMatcher{ + { + Name: "th", + InvertMatch: newBoolP(true), + PrefixMatch: newStringP("tv"), + }, + }, + Fraction: newUInt32P(10000), + WeightedClusters: 
map[string]WeightedCluster{"A": {Weight: 40}, "B": {Weight: 60}}, + HashPolicies: []*HashPolicy{ + {HashPolicyType: HashPolicyTypeHeader, + HeaderName: ":path"}, + }, + RouteAction: RouteActionRoute, + }}, + wantErr: false, + }, { name: "with custom HTTP filter config", routes: goodRouteWithFilterConfigs(map[string]*anypb.Any{"foo": customFilterConfig}), @@ -1026,12 +1430,6 @@ func (s) TestRoutesProtoToSlice(t *testing.T) { routes: goodRouteWithFilterConfigs(map[string]*anypb.Any{"foo": wrappedOptionalFilter("custom.filter")}), wantRoutes: goodUpdateWithFilterConfigs(map[string]httpfilter.FilterConfig{"foo": filterConfig{Override: customFilterConfig}}), }, - { - name: "with custom HTTP filter config, FI disabled", - disableFI: true, - routes: goodRouteWithFilterConfigs(map[string]*anypb.Any{"foo": customFilterConfig}), - wantRoutes: goodUpdateWithFilterConfigs(nil), - }, { name: "with erroring custom HTTP filter config", routes: goodRouteWithFilterConfigs(map[string]*anypb.Any{"foo": errFilterConfig}), @@ -1042,12 +1440,6 @@ func (s) TestRoutesProtoToSlice(t *testing.T) { routes: goodRouteWithFilterConfigs(map[string]*anypb.Any{"foo": wrappedOptionalFilter("err.custom.filter")}), wantErr: true, }, - { - name: "with erroring custom HTTP filter config, FI disabled", - disableFI: true, - routes: goodRouteWithFilterConfigs(map[string]*anypb.Any{"foo": errFilterConfig}), - wantRoutes: goodUpdateWithFilterConfigs(nil), - }, { name: "with unknown custom HTTP filter config", routes: goodRouteWithFilterConfigs(map[string]*anypb.Any{"foo": unknownFilterConfig}), @@ -1061,28 +1453,137 @@ func (s) TestRoutesProtoToSlice(t *testing.T) { } cmpOpts := []cmp.Option{ - cmp.AllowUnexported(Route{}, HeaderMatcher{}, Int64Range{}), + cmp.AllowUnexported(Route{}, HeaderMatcher{}, Int64Range{}, regexp.Regexp{}), cmpopts.EquateEmpty(), cmp.Transformer("FilterConfig", func(fc httpfilter.FilterConfig) string { return fmt.Sprint(fc) }), } - + oldRingHashSupport := env.RingHashSupport + 
env.RingHashSupport = true + defer func() { env.RingHashSupport = oldRingHashSupport }() for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - oldFI := env.FaultInjectionSupport - env.FaultInjectionSupport = !tt.disableFI - got, err := routesProtoToSlice(tt.routes, nil, false) if (err != nil) != tt.wantErr { - t.Errorf("routesProtoToSlice() error = %v, wantErr %v", err, tt.wantErr) - return + t.Fatalf("routesProtoToSlice() error = %v, wantErr %v", err, tt.wantErr) } - if !cmp.Equal(got, tt.wantRoutes, cmpOpts...) { - t.Errorf("routesProtoToSlice() got = %v, want %v, diff: %v", got, tt.wantRoutes, cmp.Diff(got, tt.wantRoutes, cmpOpts...)) + if diff := cmp.Diff(got, tt.wantRoutes, cmpOpts...); diff != "" { + t.Fatalf("routesProtoToSlice() returned unexpected diff (-got +want):\n%s", diff) } + }) + } +} + +func (s) TestHashPoliciesProtoToSlice(t *testing.T) { + tests := []struct { + name string + hashPolicies []*v3routepb.RouteAction_HashPolicy + wantHashPolicies []*HashPolicy + wantErr bool + }{ + // header-hash-policy tests a basic hash policy that specifies to hash a + // certain header. + { + name: "header-hash-policy", + hashPolicies: []*v3routepb.RouteAction_HashPolicy{ + { + PolicySpecifier: &v3routepb.RouteAction_HashPolicy_Header_{ + Header: &v3routepb.RouteAction_HashPolicy_Header{ + HeaderName: ":path", + RegexRewrite: &v3matcherpb.RegexMatchAndSubstitute{ + Pattern: &v3matcherpb.RegexMatcher{Regex: "/products"}, + Substitution: "/products", + }, + }, + }, + }, + }, + wantHashPolicies: []*HashPolicy{ + { + HashPolicyType: HashPolicyTypeHeader, + HeaderName: ":path", + Regex: func() *regexp.Regexp { return regexp.MustCompile("/products") }(), + RegexSubstitution: "/products", + }, + }, + }, + // channel-id-hash-policy tests a basic hash policy that specifies to + // hash a unique identifier of the channel. 
+ { + name: "channel-id-hash-policy", + hashPolicies: []*v3routepb.RouteAction_HashPolicy{ + {PolicySpecifier: &v3routepb.RouteAction_HashPolicy_FilterState_{FilterState: &v3routepb.RouteAction_HashPolicy_FilterState{Key: "io.grpc.channel_id"}}}, + }, + wantHashPolicies: []*HashPolicy{ + {HashPolicyType: HashPolicyTypeChannelID}, + }, + }, + // unsupported-filter-state-key tests that an unsupported key in the + // filter state hash policy are treated as a no-op. + { + name: "wrong-filter-state-key", + hashPolicies: []*v3routepb.RouteAction_HashPolicy{ + {PolicySpecifier: &v3routepb.RouteAction_HashPolicy_FilterState_{FilterState: &v3routepb.RouteAction_HashPolicy_FilterState{Key: "unsupported key"}}}, + }, + }, + // no-op-hash-policy tests that hash policies that are not supported by + // grpc are treated as a no-op. + { + name: "no-op-hash-policy", + hashPolicies: []*v3routepb.RouteAction_HashPolicy{ + {PolicySpecifier: &v3routepb.RouteAction_HashPolicy_FilterState_{}}, + }, + }, + // header-and-channel-id-hash-policy test that a list of header and + // channel id hash policies are successfully converted to an internal + // struct. 
+ { + name: "header-and-channel-id-hash-policy", + hashPolicies: []*v3routepb.RouteAction_HashPolicy{ + { + PolicySpecifier: &v3routepb.RouteAction_HashPolicy_Header_{ + Header: &v3routepb.RouteAction_HashPolicy_Header{ + HeaderName: ":path", + RegexRewrite: &v3matcherpb.RegexMatchAndSubstitute{ + Pattern: &v3matcherpb.RegexMatcher{Regex: "/products"}, + Substitution: "/products", + }, + }, + }, + }, + { + PolicySpecifier: &v3routepb.RouteAction_HashPolicy_FilterState_{FilterState: &v3routepb.RouteAction_HashPolicy_FilterState{Key: "io.grpc.channel_id"}}, + Terminal: true, + }, + }, + wantHashPolicies: []*HashPolicy{ + { + HashPolicyType: HashPolicyTypeHeader, + HeaderName: ":path", + Regex: func() *regexp.Regexp { return regexp.MustCompile("/products") }(), + RegexSubstitution: "/products", + }, + { + HashPolicyType: HashPolicyTypeChannelID, + Terminal: true, + }, + }, + }, + } - env.FaultInjectionSupport = oldFI + oldRingHashSupport := env.RingHashSupport + env.RingHashSupport = true + defer func() { env.RingHashSupport = oldRingHashSupport }() + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := hashPoliciesProtoToSlice(tt.hashPolicies, nil) + if (err != nil) != tt.wantErr { + t.Fatalf("hashPoliciesProtoToSlice() error = %v, wantErr %v", err, tt.wantErr) + } + if diff := cmp.Diff(got, tt.wantHashPolicies, cmp.AllowUnexported(regexp.Regexp{})); diff != "" { + t.Fatalf("hashPoliciesProtoToSlice() returned unexpected diff (-got +want):\n%s", diff) + } }) } } diff --git a/xds/internal/xdsclient/requests_counter.go b/xds/internal/xdsclient/requests_counter.go new file mode 100644 index 00000000000..beed2e9d0ad --- /dev/null +++ b/xds/internal/xdsclient/requests_counter.go @@ -0,0 +1,107 @@ +/* + * + * Copyright 2020 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package xdsclient + +import ( + "fmt" + "sync" + "sync/atomic" +) + +type clusterNameAndServiceName struct { + clusterName, edsServcieName string +} + +type clusterRequestsCounter struct { + mu sync.Mutex + clusters map[clusterNameAndServiceName]*ClusterRequestsCounter +} + +var src = &clusterRequestsCounter{ + clusters: make(map[clusterNameAndServiceName]*ClusterRequestsCounter), +} + +// ClusterRequestsCounter is used to track the total inflight requests for a +// service with the provided name. +type ClusterRequestsCounter struct { + ClusterName string + EDSServiceName string + numRequests uint32 +} + +// GetClusterRequestsCounter returns the ClusterRequestsCounter with the +// provided serviceName. If one does not exist, it creates it. +func GetClusterRequestsCounter(clusterName, edsServiceName string) *ClusterRequestsCounter { + src.mu.Lock() + defer src.mu.Unlock() + k := clusterNameAndServiceName{ + clusterName: clusterName, + edsServcieName: edsServiceName, + } + c, ok := src.clusters[k] + if !ok { + c = &ClusterRequestsCounter{ClusterName: clusterName} + src.clusters[k] = c + } + return c +} + +// StartRequest starts a request for a cluster, incrementing its number of +// requests by 1. Returns an error if the max number of requests is exceeded. +func (c *ClusterRequestsCounter) StartRequest(max uint32) error { + // Note that during race, the limits could be exceeded. This is allowed: + // "Since the implementation is eventually consistent, races between threads + // may allow limits to be potentially exceeded." 
+ // https://www.envoyproxy.io/docs/envoy/latest/intro/arch_overview/upstream/circuit_breaking#arch-overview-circuit-break. + if atomic.LoadUint32(&c.numRequests) >= max { + return fmt.Errorf("max requests %v exceeded on service %v", max, c.ClusterName) + } + atomic.AddUint32(&c.numRequests, 1) + return nil +} + +// EndRequest ends a request for a service, decrementing its number of requests +// by 1. +func (c *ClusterRequestsCounter) EndRequest() { + atomic.AddUint32(&c.numRequests, ^uint32(0)) +} + +// ClearCounterForTesting clears the counter for the service. Should be only +// used in tests. +func ClearCounterForTesting(clusterName, edsServiceName string) { + src.mu.Lock() + defer src.mu.Unlock() + k := clusterNameAndServiceName{ + clusterName: clusterName, + edsServcieName: edsServiceName, + } + c, ok := src.clusters[k] + if !ok { + return + } + c.numRequests = 0 +} + +// ClearAllCountersForTesting clears all the counters. Should be only used in +// tests. +func ClearAllCountersForTesting() { + src.mu.Lock() + defer src.mu.Unlock() + src.clusters = make(map[clusterNameAndServiceName]*ClusterRequestsCounter) +} diff --git a/xds/internal/client/requests_counter_test.go b/xds/internal/xdsclient/requests_counter_test.go similarity index 76% rename from xds/internal/client/requests_counter_test.go rename to xds/internal/xdsclient/requests_counter_test.go index fe532724d14..e2eeea774e2 100644 --- a/xds/internal/client/requests_counter_test.go +++ b/xds/internal/xdsclient/requests_counter_test.go @@ -16,7 +16,7 @@ * */ -package client +package xdsclient import ( "sync" @@ -24,6 +24,8 @@ import ( "testing" ) +const testService = "test-service-name" + type counterTest struct { name string maxRequests uint32 @@ -49,9 +51,9 @@ var tests = []counterTest{ }, } -func resetServiceRequestsCounter() { - src = &servicesRequestsCounter{ - services: make(map[string]*ServiceRequestsCounter), +func resetClusterRequestsCounter() { + src = &clusterRequestsCounter{ + clusters: 
make(map[clusterNameAndServiceName]*ClusterRequestsCounter), } } @@ -65,7 +67,7 @@ func testCounter(t *testing.T, test counterTest) { var successes, errors uint32 for i := 0; i < int(test.numRequests); i++ { go func() { - counter := GetServiceRequestsCounter(test.name) + counter := GetClusterRequestsCounter(test.name, testService) defer requestsDone.Done() err := counter.StartRequest(test.maxRequests) if err == nil { @@ -91,13 +93,17 @@ func testCounter(t *testing.T, test counterTest) { if test.expectedErrors == 0 && loadedError != nil { t.Errorf("error starting request: %v", loadedError.(error)) } - if successes != test.expectedSuccesses || errors != test.expectedErrors { + // We allow the limits to be exceeded during races. + // + // But we should never over-limit, so this test fails if there are less + // successes than expected. + if successes < test.expectedSuccesses || errors > test.expectedErrors { t.Errorf("unexpected number of (successes, errors), expected (%v, %v), encountered (%v, %v)", test.expectedSuccesses, test.expectedErrors, successes, errors) } } func (s) TestRequestsCounter(t *testing.T) { - defer resetServiceRequestsCounter() + defer resetClusterRequestsCounter() for _, test := range tests { t.Run(test.name, func(t *testing.T) { testCounter(t, test) @@ -105,18 +111,18 @@ func (s) TestRequestsCounter(t *testing.T) { } } -func (s) TestGetServiceRequestsCounter(t *testing.T) { - defer resetServiceRequestsCounter() +func (s) TestGetClusterRequestsCounter(t *testing.T) { + defer resetClusterRequestsCounter() for _, test := range tests { - counterA := GetServiceRequestsCounter(test.name) - counterB := GetServiceRequestsCounter(test.name) + counterA := GetClusterRequestsCounter(test.name, testService) + counterB := GetClusterRequestsCounter(test.name, testService) if counterA != counterB { t.Errorf("counter %v %v != counter %v %v", counterA, *counterA, counterB, *counterB) } } } -func startRequests(t *testing.T, n uint32, max uint32, counter 
*ServiceRequestsCounter) { +func startRequests(t *testing.T, n uint32, max uint32, counter *ClusterRequestsCounter) { for i := uint32(0); i < n; i++ { if err := counter.StartRequest(max); err != nil { t.Fatalf("error starting initial request: %v", err) @@ -125,11 +131,11 @@ func startRequests(t *testing.T, n uint32, max uint32, counter *ServiceRequestsC } func (s) TestSetMaxRequestsIncreased(t *testing.T) { - defer resetServiceRequestsCounter() - const serviceName string = "set-max-requests-increased" + defer resetClusterRequestsCounter() + const clusterName string = "set-max-requests-increased" var initialMax uint32 = 16 - counter := GetServiceRequestsCounter(serviceName) + counter := GetClusterRequestsCounter(clusterName, testService) startRequests(t, initialMax, initialMax, counter) if err := counter.StartRequest(initialMax); err == nil { t.Fatal("unexpected success on start request after max met") @@ -142,11 +148,11 @@ func (s) TestSetMaxRequestsIncreased(t *testing.T) { } func (s) TestSetMaxRequestsDecreased(t *testing.T) { - defer resetServiceRequestsCounter() - const serviceName string = "set-max-requests-decreased" + defer resetClusterRequestsCounter() + const clusterName string = "set-max-requests-decreased" var initialMax uint32 = 16 - counter := GetServiceRequestsCounter(serviceName) + counter := GetClusterRequestsCounter(clusterName, testService) startRequests(t, initialMax-1, initialMax, counter) newMax := initialMax - 1 diff --git a/xds/internal/xdsclient/singleton.go b/xds/internal/xdsclient/singleton.go new file mode 100644 index 00000000000..f045790e2a4 --- /dev/null +++ b/xds/internal/xdsclient/singleton.go @@ -0,0 +1,198 @@ +/* + * + * Copyright 2020 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package xdsclient + +import ( + "bytes" + "encoding/json" + "fmt" + "sync" + "time" + + "google.golang.org/grpc/xds/internal/xdsclient/bootstrap" +) + +const defaultWatchExpiryTimeout = 15 * time.Second + +// This is the Client returned by New(). It contains one client implementation, +// and maintains the refcount. +var singletonClient = &clientRefCounted{} + +// To override in tests. +var bootstrapNewConfig = bootstrap.NewConfig + +// clientRefCounted is ref-counted, and to be shared by the xds resolver and +// balancer implementations, across multiple ClientConns and Servers. +type clientRefCounted struct { + *clientImpl + + // This mu protects all the fields, including the embedded clientImpl above. + mu sync.Mutex + refCount int +} + +// New returns a new xdsClient configured by the bootstrap file specified in env +// variable GRPC_XDS_BOOTSTRAP or GRPC_XDS_BOOTSTRAP_CONFIG. +// +// The returned xdsClient is a singleton. This function creates the xds client +// if it doesn't already exist. +// +// Note that the first invocation of New() or NewWithConfig() sets the client +// singleton. The following calls will return the singleton xds client without +// checking or using the config. +func New() (XDSClient, error) { + // This cannot just return newRefCounted(), because in error cases, the + // returned nil is a typed nil (*clientRefCounted), which may cause nil + // checks fail. 
+ c, err := newRefCounted() + if err != nil { + return nil, err + } + return c, nil +} + +func newRefCounted() (*clientRefCounted, error) { + singletonClient.mu.Lock() + defer singletonClient.mu.Unlock() + // If the client implementation was created, increment ref count and return + // the client. + if singletonClient.clientImpl != nil { + singletonClient.refCount++ + return singletonClient, nil + } + + // Create the new client implementation. + config, err := bootstrapNewConfig() + if err != nil { + return nil, fmt.Errorf("xds: failed to read bootstrap file: %v", err) + } + c, err := newWithConfig(config, defaultWatchExpiryTimeout) + if err != nil { + return nil, err + } + + singletonClient.clientImpl = c + singletonClient.refCount++ + return singletonClient, nil +} + +// NewWithConfig returns a new xdsClient configured by the given config. +// +// The returned xdsClient is a singleton. This function creates the xds client +// if it doesn't already exist. +// +// Note that the first invocation of New() or NewWithConfig() sets the client +// singleton. The following calls will return the singleton xds client without +// checking or using the config. +// +// This function is internal only, for c2p resolver and testing to use. DO NOT +// use this elsewhere. Use New() instead. +func NewWithConfig(config *bootstrap.Config) (XDSClient, error) { + singletonClient.mu.Lock() + defer singletonClient.mu.Unlock() + // If the client implementation was created, increment ref count and return + // the client. + if singletonClient.clientImpl != nil { + singletonClient.refCount++ + return singletonClient, nil + } + + // Create the new client implementation. + c, err := newWithConfig(config, defaultWatchExpiryTimeout) + if err != nil { + return nil, err + } + + singletonClient.clientImpl = c + singletonClient.refCount++ + return singletonClient, nil +} + +// Close closes the client. 
It does ref count of the xds client implementation, +// and closes the gRPC connection to the management server when ref count +// reaches 0. +func (c *clientRefCounted) Close() { + c.mu.Lock() + defer c.mu.Unlock() + c.refCount-- + if c.refCount == 0 { + c.clientImpl.Close() + // Set clientImpl back to nil. So if New() is called after this, a new + // implementation will be created. + c.clientImpl = nil + } +} + +// NewWithConfigForTesting is exported for testing only. +// +// Note that this function doesn't set the singleton, so that the testing states +// don't leak. +func NewWithConfigForTesting(config *bootstrap.Config, watchExpiryTimeout time.Duration) (XDSClient, error) { + cl, err := newWithConfig(config, watchExpiryTimeout) + if err != nil { + return nil, err + } + return &clientRefCounted{clientImpl: cl, refCount: 1}, nil +} + +// NewClientWithBootstrapContents returns an xds client for this config, +// separate from the global singleton. This should be used for testing +// purposes only. +func NewClientWithBootstrapContents(contents []byte) (XDSClient, error) { + // Normalize the contents + buf := bytes.Buffer{} + err := json.Indent(&buf, contents, "", "") + if err != nil { + return nil, fmt.Errorf("xds: error normalizing JSON: %v", err) + } + contents = bytes.TrimSpace(buf.Bytes()) + + clientsMu.Lock() + defer clientsMu.Unlock() + if c := clients[string(contents)]; c != nil { + c.mu.Lock() + // Since we don't remove the *Client from the map when it is closed, we + // need to recreate the impl if the ref count dropped to zero. 
+ if c.refCount > 0 { + c.refCount++ + c.mu.Unlock() + return c, nil + } + c.mu.Unlock() + } + + bcfg, err := bootstrap.NewConfigFromContents(contents) + if err != nil { + return nil, fmt.Errorf("xds: error with bootstrap config: %v", err) + } + + cImpl, err := newWithConfig(bcfg, defaultWatchExpiryTimeout) + if err != nil { + return nil, err + } + + c := &clientRefCounted{clientImpl: cImpl, refCount: 1} + clients[string(contents)] = c + return c, nil +} + +var ( + clients = map[string]*clientRefCounted{} + clientsMu sync.Mutex +) diff --git a/xds/internal/client/transport_helper.go b/xds/internal/xdsclient/transport_helper.go similarity index 98% rename from xds/internal/client/transport_helper.go rename to xds/internal/xdsclient/transport_helper.go index b286a61d638..4c56daaf011 100644 --- a/xds/internal/client/transport_helper.go +++ b/xds/internal/xdsclient/transport_helper.go @@ -16,7 +16,7 @@ * */ -package client +package xdsclient import ( "context" @@ -24,7 +24,7 @@ import ( "time" "github.com/golang/protobuf/proto" - "google.golang.org/grpc/xds/internal/client/load" + "google.golang.org/grpc/xds/internal/xdsclient/load" "google.golang.org/grpc" "google.golang.org/grpc/internal/buffer" @@ -297,7 +297,7 @@ func (t *TransportHelper) sendExisting(stream grpc.ClientStream) bool { for rType, s := range t.watchMap { if err := t.vClient.SendRequest(stream, mapToSlice(s), rType, "", "", ""); err != nil { - t.logger.Errorf("ADS request failed: %v", err) + t.logger.Warningf("ADS request failed: %v", err) return false } } @@ -342,11 +342,12 @@ func (t *TransportHelper) recv(stream grpc.ClientStream) bool { } } -func mapToSlice(m map[string]bool) (ret []string) { +func mapToSlice(m map[string]bool) []string { + ret := make([]string, 0, len(m)) for i := range m { ret = append(ret, i) } - return + return ret } type watchAction struct { diff --git a/xds/internal/client/v2/ack_test.go b/xds/internal/xdsclient/v2/ack_test.go similarity index 99% rename from 
xds/internal/client/v2/ack_test.go rename to xds/internal/xdsclient/v2/ack_test.go index 813d8baa79d..d2f0605f6d0 100644 --- a/xds/internal/client/v2/ack_test.go +++ b/xds/internal/xdsclient/v2/ack_test.go @@ -31,9 +31,9 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/internal/testutils" - xdsclient "google.golang.org/grpc/xds/internal/client" "google.golang.org/grpc/xds/internal/testutils/fakeserver" "google.golang.org/grpc/xds/internal/version" + "google.golang.org/grpc/xds/internal/xdsclient" ) const ( diff --git a/xds/internal/client/v2/cds_test.go b/xds/internal/xdsclient/v2/cds_test.go similarity index 86% rename from xds/internal/client/v2/cds_test.go rename to xds/internal/xdsclient/v2/cds_test.go index c71b8453231..cef7563017c 100644 --- a/xds/internal/client/v2/cds_test.go +++ b/xds/internal/xdsclient/v2/cds_test.go @@ -24,10 +24,11 @@ import ( xdspb "github.com/envoyproxy/go-control-plane/envoy/api/v2" corepb "github.com/envoyproxy/go-control-plane/envoy/api/v2/core" - "github.com/golang/protobuf/ptypes" anypb "github.com/golang/protobuf/ptypes/any" - xdsclient "google.golang.org/grpc/xds/internal/client" + "github.com/google/go-cmp/cmp/cmpopts" + "google.golang.org/grpc/internal/testutils" "google.golang.org/grpc/xds/internal/version" + "google.golang.org/grpc/xds/internal/xdsclient" ) const ( @@ -63,8 +64,8 @@ var ( }, }, } - marshaledCluster1, _ = ptypes.MarshalAny(goodCluster1) - goodCluster2 = &xdspb.Cluster{ + marshaledCluster1 = testutils.MarshalAny(goodCluster1) + goodCluster2 = &xdspb.Cluster{ Name: goodClusterName2, ClusterDiscoveryType: &xdspb.Cluster_Type{Type: xdspb.Cluster_EDS}, EdsClusterConfig: &xdspb.Cluster_EdsClusterConfig{ @@ -77,8 +78,8 @@ var ( }, LbPolicy: xdspb.Cluster_ROUND_ROBIN, } - marshaledCluster2, _ = ptypes.MarshalAny(goodCluster2) - goodCDSResponse1 = &xdspb.DiscoveryResponse{ + marshaledCluster2 = testutils.MarshalAny(goodCluster2) + goodCDSResponse1 = &xdspb.DiscoveryResponse{ 
Resources: []*anypb.Any{ marshaledCluster1, }, @@ -100,7 +101,7 @@ func (s) TestCDSHandleResponse(t *testing.T) { name string cdsResponse *xdspb.DiscoveryResponse wantErr bool - wantUpdate map[string]xdsclient.ClusterUpdate + wantUpdate map[string]xdsclient.ClusterUpdateErrTuple wantUpdateMD xdsclient.UpdateMetadata wantUpdateErr bool }{ @@ -113,7 +114,7 @@ func (s) TestCDSHandleResponse(t *testing.T) { wantUpdateMD: xdsclient.UpdateMetadata{ Status: xdsclient.ServiceStatusNACKed, ErrState: &xdsclient.UpdateErrorMetadata{ - Err: errPlaceHolder, + Err: cmpopts.AnyError, }, }, wantUpdateErr: false, @@ -127,7 +128,7 @@ func (s) TestCDSHandleResponse(t *testing.T) { wantUpdateMD: xdsclient.UpdateMetadata{ Status: xdsclient.ServiceStatusNACKed, ErrState: &xdsclient.UpdateErrorMetadata{ - Err: errPlaceHolder, + Err: cmpopts.AnyError, }, }, wantUpdateErr: false, @@ -148,8 +149,8 @@ func (s) TestCDSHandleResponse(t *testing.T) { name: "one-uninteresting-cluster", cdsResponse: goodCDSResponse2, wantErr: false, - wantUpdate: map[string]xdsclient.ClusterUpdate{ - goodClusterName2: {ServiceName: serviceName2, Raw: marshaledCluster2}, + wantUpdate: map[string]xdsclient.ClusterUpdateErrTuple{ + goodClusterName2: {Update: xdsclient.ClusterUpdate{ClusterName: goodClusterName2, EDSServiceName: serviceName2, Raw: marshaledCluster2}}, }, wantUpdateMD: xdsclient.UpdateMetadata{ Status: xdsclient.ServiceStatusACKed, @@ -161,8 +162,8 @@ func (s) TestCDSHandleResponse(t *testing.T) { name: "one-good-cluster", cdsResponse: goodCDSResponse1, wantErr: false, - wantUpdate: map[string]xdsclient.ClusterUpdate{ - goodClusterName1: {ServiceName: serviceName1, EnableLRS: true, Raw: marshaledCluster1}, + wantUpdate: map[string]xdsclient.ClusterUpdateErrTuple{ + goodClusterName1: {Update: xdsclient.ClusterUpdate{ClusterName: goodClusterName1, EDSServiceName: serviceName1, EnableLRS: true, Raw: marshaledCluster1}}, }, wantUpdateMD: xdsclient.UpdateMetadata{ Status: xdsclient.ServiceStatusACKed, diff 
--git a/xds/internal/client/v2/client.go b/xds/internal/xdsclient/v2/client.go similarity index 82% rename from xds/internal/client/v2/client.go rename to xds/internal/xdsclient/v2/client.go index b6bc4908120..dc137f63e5f 100644 --- a/xds/internal/client/v2/client.go +++ b/xds/internal/xdsclient/v2/client.go @@ -27,8 +27,9 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/internal/grpclog" - xdsclient "google.golang.org/grpc/xds/internal/client" + "google.golang.org/grpc/internal/pretty" "google.golang.org/grpc/xds/internal/version" + "google.golang.org/grpc/xds/internal/xdsclient" v2xdspb "github.com/envoyproxy/go-control-plane/envoy/api/v2" v2corepb "github.com/envoyproxy/go-control-plane/envoy/api/v2/core" @@ -65,10 +66,11 @@ func newClient(cc *grpc.ClientConn, opts xdsclient.BuildOptions) (xdsclient.APIC return nil, fmt.Errorf("xds: unsupported Node proto type: %T, want %T", opts.NodeProto, (*v2corepb.Node)(nil)) } v2c := &client{ - cc: cc, - parent: opts.Parent, - nodeProto: nodeProto, - logger: opts.Logger, + cc: cc, + parent: opts.Parent, + nodeProto: nodeProto, + logger: opts.Logger, + updateValidator: opts.Validator, } v2c.ctx, v2c.cancelCtx = context.WithCancel(context.Background()) v2c.TransportHelper = xdsclient.NewTransportHelper(v2c, opts.Logger, opts.Backoff) @@ -89,8 +91,9 @@ type client struct { logger *grpclog.PrefixLogger // ClientConn to the xDS gRPC server. Owned by the parent xdsClient. 
- cc *grpc.ClientConn - nodeProto *v2corepb.Node + cc *grpc.ClientConn + nodeProto *v2corepb.Node + updateValidator xdsclient.UpdateValidatorFunc } func (v2c *client) NewStream(ctx context.Context) (grpc.ClientStream, error) { @@ -125,7 +128,7 @@ func (v2c *client) SendRequest(s grpc.ClientStream, resourceNames []string, rTyp if err := stream.Send(req); err != nil { return fmt.Errorf("xds: stream.Send(%+v) failed: %v", req, err) } - v2c.logger.Debugf("ADS request sent: %v", req) + v2c.logger.Debugf("ADS request sent: %v", pretty.ToJSON(req)) return nil } @@ -139,11 +142,11 @@ func (v2c *client) RecvResponse(s grpc.ClientStream) (proto.Message, error) { resp, err := stream.Recv() if err != nil { - // TODO: call watch callbacks with error when stream is broken. + v2c.parent.NewConnectionError(err) return nil, fmt.Errorf("xds: stream.Recv() failed: %v", err) } v2c.logger.Infof("ADS response received, type: %v", resp.GetTypeUrl()) - v2c.logger.Debugf("ADS response received: %v", resp) + v2c.logger.Debugf("ADS response received: %v", pretty.ToJSON(resp)) return resp, nil } @@ -185,7 +188,12 @@ func (v2c *client) HandleResponse(r proto.Message) (xdsclient.ResourceType, stri // server. On receipt of a good response, it also invokes the registered watcher // callback. func (v2c *client) handleLDSResponse(resp *v2xdspb.DiscoveryResponse) error { - update, md, err := xdsclient.UnmarshalListener(resp.GetVersionInfo(), resp.GetResources(), v2c.logger) + update, md, err := xdsclient.UnmarshalListener(&xdsclient.UnmarshalOptions{ + Version: resp.GetVersionInfo(), + Resources: resp.GetResources(), + Logger: v2c.logger, + UpdateValidator: v2c.updateValidator, + }) v2c.parent.NewListeners(update, md) return err } @@ -194,7 +202,12 @@ func (v2c *client) handleLDSResponse(resp *v2xdspb.DiscoveryResponse) error { // server. On receipt of a good response, it caches validated resources and also // invokes the registered watcher callback. 
func (v2c *client) handleRDSResponse(resp *v2xdspb.DiscoveryResponse) error { - update, md, err := xdsclient.UnmarshalRouteConfig(resp.GetVersionInfo(), resp.GetResources(), v2c.logger) + update, md, err := xdsclient.UnmarshalRouteConfig(&xdsclient.UnmarshalOptions{ + Version: resp.GetVersionInfo(), + Resources: resp.GetResources(), + Logger: v2c.logger, + UpdateValidator: v2c.updateValidator, + }) v2c.parent.NewRouteConfigs(update, md) return err } @@ -203,13 +216,23 @@ func (v2c *client) handleRDSResponse(resp *v2xdspb.DiscoveryResponse) error { // server. On receipt of a good response, it also invokes the registered watcher // callback. func (v2c *client) handleCDSResponse(resp *v2xdspb.DiscoveryResponse) error { - update, md, err := xdsclient.UnmarshalCluster(resp.GetVersionInfo(), resp.GetResources(), v2c.logger) + update, md, err := xdsclient.UnmarshalCluster(&xdsclient.UnmarshalOptions{ + Version: resp.GetVersionInfo(), + Resources: resp.GetResources(), + Logger: v2c.logger, + UpdateValidator: v2c.updateValidator, + }) v2c.parent.NewClusters(update, md) return err } func (v2c *client) handleEDSResponse(resp *v2xdspb.DiscoveryResponse) error { - update, md, err := xdsclient.UnmarshalEndpoints(resp.GetVersionInfo(), resp.GetResources(), v2c.logger) + update, md, err := xdsclient.UnmarshalEndpoints(&xdsclient.UnmarshalOptions{ + Version: resp.GetVersionInfo(), + Resources: resp.GetResources(), + Logger: v2c.logger, + UpdateValidator: v2c.updateValidator, + }) v2c.parent.NewEndpoints(update, md) return err } diff --git a/xds/internal/client/v2/client_test.go b/xds/internal/xdsclient/v2/client_test.go similarity index 88% rename from xds/internal/client/v2/client_test.go rename to xds/internal/xdsclient/v2/client_test.go index e770324e1b1..ed4322b0dc5 100644 --- a/xds/internal/client/v2/client_test.go +++ b/xds/internal/xdsclient/v2/client_test.go @@ -21,12 +21,10 @@ package v2 import ( "context" "errors" - "fmt" "testing" "time" 
"github.com/golang/protobuf/proto" - "github.com/golang/protobuf/ptypes" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "google.golang.org/grpc" @@ -36,9 +34,9 @@ import ( "google.golang.org/grpc/internal/testutils" "google.golang.org/grpc/resolver" "google.golang.org/grpc/resolver/manual" - xdsclient "google.golang.org/grpc/xds/internal/client" "google.golang.org/grpc/xds/internal/testutils/fakeserver" "google.golang.org/grpc/xds/internal/version" + "google.golang.org/grpc/xds/internal/xdsclient" "google.golang.org/protobuf/testing/protocmp" xdspb "github.com/envoyproxy/go-control-plane/envoy/api/v2" @@ -112,30 +110,24 @@ var ( }, }, } - marshaledConnMgr1, _ = proto.Marshal(goodHTTPConnManager1) - goodListener1 = &xdspb.Listener{ + marshaledConnMgr1 = testutils.MarshalAny(goodHTTPConnManager1) + goodListener1 = &xdspb.Listener{ Name: goodLDSTarget1, ApiListener: &listenerpb.ApiListener{ - ApiListener: &anypb.Any{ - TypeUrl: httpConnManagerURL, - Value: marshaledConnMgr1, - }, + ApiListener: marshaledConnMgr1, }, } - marshaledListener1, _ = ptypes.MarshalAny(goodListener1) - goodListener2 = &xdspb.Listener{ + marshaledListener1 = testutils.MarshalAny(goodListener1) + goodListener2 = &xdspb.Listener{ Name: goodLDSTarget2, ApiListener: &listenerpb.ApiListener{ - ApiListener: &anypb.Any{ - TypeUrl: httpConnManagerURL, - Value: marshaledConnMgr1, - }, + ApiListener: marshaledConnMgr1, }, } - marshaledListener2, _ = ptypes.MarshalAny(goodListener2) - noAPIListener = &xdspb.Listener{Name: goodLDSTarget1} - marshaledNoAPIListener, _ = proto.Marshal(noAPIListener) - badAPIListener2 = &xdspb.Listener{ + marshaledListener2 = testutils.MarshalAny(goodListener2) + noAPIListener = &xdspb.Listener{Name: goodLDSTarget1} + marshaledNoAPIListener = testutils.MarshalAny(noAPIListener) + badAPIListener2 = &xdspb.Listener{ Name: goodLDSTarget2, ApiListener: &listenerpb.ApiListener{ ApiListener: &anypb.Any{ @@ -168,13 +160,8 @@ var ( TypeUrl: 
version.V2ListenerURL, } badResourceTypeInLDSResponse = &xdspb.DiscoveryResponse{ - Resources: []*anypb.Any{ - { - TypeUrl: httpConnManagerURL, - Value: marshaledConnMgr1, - }, - }, - TypeUrl: version.V2ListenerURL, + Resources: []*anypb.Any{marshaledConnMgr1}, + TypeUrl: version.V2ListenerURL, } ldsResponseWithMultipleResources = &xdspb.DiscoveryResponse{ Resources: []*anypb.Any{ @@ -184,13 +171,8 @@ var ( TypeUrl: version.V2ListenerURL, } noAPIListenerLDSResponse = &xdspb.DiscoveryResponse{ - Resources: []*anypb.Any{ - { - TypeUrl: version.V2ListenerURL, - Value: marshaledNoAPIListener, - }, - }, - TypeUrl: version.V2ListenerURL, + Resources: []*anypb.Any{marshaledNoAPIListener}, + TypeUrl: version.V2ListenerURL, } goodBadUglyLDSResponse = &xdspb.DiscoveryResponse{ Resources: []*anypb.Any{ @@ -213,19 +195,14 @@ var ( TypeUrl: version.V2RouteConfigURL, } badResourceTypeInRDSResponse = &xdspb.DiscoveryResponse{ - Resources: []*anypb.Any{ - { - TypeUrl: httpConnManagerURL, - Value: marshaledConnMgr1, - }, - }, - TypeUrl: version.V2RouteConfigURL, + Resources: []*anypb.Any{marshaledConnMgr1}, + TypeUrl: version.V2RouteConfigURL, } noVirtualHostsRouteConfig = &xdspb.RouteConfiguration{ Name: goodRouteName1, } - marshaledNoVirtualHostsRouteConfig, _ = ptypes.MarshalAny(noVirtualHostsRouteConfig) - noVirtualHostsInRDSResponse = &xdspb.DiscoveryResponse{ + marshaledNoVirtualHostsRouteConfig = testutils.MarshalAny(noVirtualHostsRouteConfig) + noVirtualHostsInRDSResponse = &xdspb.DiscoveryResponse{ Resources: []*anypb.Any{ marshaledNoVirtualHostsRouteConfig, }, @@ -262,8 +239,8 @@ var ( }, }, } - marshaledGoodRouteConfig1, _ = ptypes.MarshalAny(goodRouteConfig1) - goodRouteConfig2 = &xdspb.RouteConfiguration{ + marshaledGoodRouteConfig1 = testutils.MarshalAny(goodRouteConfig1) + goodRouteConfig2 = &xdspb.RouteConfiguration{ Name: goodRouteName2, VirtualHosts: []*routepb.VirtualHost{ { @@ -294,8 +271,8 @@ var ( }, }, } - marshaledGoodRouteConfig2, _ = 
ptypes.MarshalAny(goodRouteConfig2) - goodRDSResponse1 = &xdspb.DiscoveryResponse{ + marshaledGoodRouteConfig2 = testutils.MarshalAny(goodRouteConfig2) + goodRDSResponse1 = &xdspb.DiscoveryResponse{ Resources: []*anypb.Any{ marshaledGoodRouteConfig1, }, @@ -307,9 +284,6 @@ var ( }, TypeUrl: version.V2RouteConfigURL, } - // An place holder error. When comparing UpdateErrorMetadata, we only check - // if error is nil, and don't compare error content. - errPlaceHolder = fmt.Errorf("err place holder") ) type watchHandleTestcase struct { @@ -327,7 +301,7 @@ type testUpdateReceiver struct { f func(rType xdsclient.ResourceType, d map[string]interface{}, md xdsclient.UpdateMetadata) } -func (t *testUpdateReceiver) NewListeners(d map[string]xdsclient.ListenerUpdate, metadata xdsclient.UpdateMetadata) { +func (t *testUpdateReceiver) NewListeners(d map[string]xdsclient.ListenerUpdateErrTuple, metadata xdsclient.UpdateMetadata) { dd := make(map[string]interface{}) for k, v := range d { dd[k] = v @@ -335,7 +309,7 @@ func (t *testUpdateReceiver) NewListeners(d map[string]xdsclient.ListenerUpdate, t.newUpdate(xdsclient.ListenerResource, dd, metadata) } -func (t *testUpdateReceiver) NewRouteConfigs(d map[string]xdsclient.RouteConfigUpdate, metadata xdsclient.UpdateMetadata) { +func (t *testUpdateReceiver) NewRouteConfigs(d map[string]xdsclient.RouteConfigUpdateErrTuple, metadata xdsclient.UpdateMetadata) { dd := make(map[string]interface{}) for k, v := range d { dd[k] = v @@ -343,7 +317,7 @@ func (t *testUpdateReceiver) NewRouteConfigs(d map[string]xdsclient.RouteConfigU t.newUpdate(xdsclient.RouteConfigResource, dd, metadata) } -func (t *testUpdateReceiver) NewClusters(d map[string]xdsclient.ClusterUpdate, metadata xdsclient.UpdateMetadata) { +func (t *testUpdateReceiver) NewClusters(d map[string]xdsclient.ClusterUpdateErrTuple, metadata xdsclient.UpdateMetadata) { dd := make(map[string]interface{}) for k, v := range d { dd[k] = v @@ -351,7 +325,7 @@ func (t *testUpdateReceiver) 
NewClusters(d map[string]xdsclient.ClusterUpdate, m t.newUpdate(xdsclient.ClusterResource, dd, metadata) } -func (t *testUpdateReceiver) NewEndpoints(d map[string]xdsclient.EndpointsUpdate, metadata xdsclient.UpdateMetadata) { +func (t *testUpdateReceiver) NewEndpoints(d map[string]xdsclient.EndpointsUpdateErrTuple, metadata xdsclient.UpdateMetadata) { dd := make(map[string]interface{}) for k, v := range d { dd[k] = v @@ -359,6 +333,8 @@ func (t *testUpdateReceiver) NewEndpoints(d map[string]xdsclient.EndpointsUpdate t.newUpdate(xdsclient.EndpointsResource, dd, metadata) } +func (t *testUpdateReceiver) NewConnectionError(error) {} + func (t *testUpdateReceiver) newUpdate(rType xdsclient.ResourceType, d map[string]interface{}, metadata xdsclient.UpdateMetadata) { t.f(rType, d, metadata) } @@ -387,27 +363,27 @@ func testWatchHandle(t *testing.T, test *watchHandleTestcase) { if rType == test.rType { switch test.rType { case xdsclient.ListenerResource: - dd := make(map[string]xdsclient.ListenerUpdate) + dd := make(map[string]xdsclient.ListenerUpdateErrTuple) for n, u := range d { - dd[n] = u.(xdsclient.ListenerUpdate) + dd[n] = u.(xdsclient.ListenerUpdateErrTuple) } gotUpdateCh.Send(updateErr{dd, md, nil}) case xdsclient.RouteConfigResource: - dd := make(map[string]xdsclient.RouteConfigUpdate) + dd := make(map[string]xdsclient.RouteConfigUpdateErrTuple) for n, u := range d { - dd[n] = u.(xdsclient.RouteConfigUpdate) + dd[n] = u.(xdsclient.RouteConfigUpdateErrTuple) } gotUpdateCh.Send(updateErr{dd, md, nil}) case xdsclient.ClusterResource: - dd := make(map[string]xdsclient.ClusterUpdate) + dd := make(map[string]xdsclient.ClusterUpdateErrTuple) for n, u := range d { - dd[n] = u.(xdsclient.ClusterUpdate) + dd[n] = u.(xdsclient.ClusterUpdateErrTuple) } gotUpdateCh.Send(updateErr{dd, md, nil}) case xdsclient.EndpointsResource: - dd := make(map[string]xdsclient.EndpointsUpdate) + dd := make(map[string]xdsclient.EndpointsUpdateErrTuple) for n, u := range d { - dd[n] = 
u.(xdsclient.EndpointsUpdate) + dd[n] = u.(xdsclient.EndpointsUpdateErrTuple) } gotUpdateCh.Send(updateErr{dd, md, nil}) } @@ -457,7 +433,7 @@ func testWatchHandle(t *testing.T, test *watchHandleTestcase) { cmpopts.EquateEmpty(), protocmp.Transform(), cmpopts.IgnoreFields(xdsclient.UpdateMetadata{}, "Timestamp"), cmpopts.IgnoreFields(xdsclient.UpdateErrorMetadata{}, "Timestamp"), - cmp.Comparer(func(x, y error) bool { return (x == nil) == (y == nil) }), + cmp.FilterValues(func(x, y error) bool { return true }, cmpopts.EquateErrors()), } uErr, err := gotUpdateCh.Receive(ctx) if err == context.DeadlineExceeded { @@ -546,7 +522,7 @@ func (s) TestV2ClientBackoffAfterRecvError(t *testing.T) { fakeServer.XDSResponseChan <- &fakeserver.Response{Err: errors.New("RPC error")} t.Log("Bad LDS response pushed to fakeServer...") - timer := time.NewTimer(defaultTestShortTimeout) + timer := time.NewTimer(defaultTestTimeout) select { case <-timer.C: t.Fatal("Timeout when expecting LDS update") @@ -688,7 +664,7 @@ func (s) TestV2ClientWatchWithoutStream(t *testing.T) { if v, err := callbackCh.Receive(ctx); err != nil { t.Fatal("Timeout when expecting LDS update") - } else if _, ok := v.(xdsclient.ListenerUpdate); !ok { + } else if _, ok := v.(xdsclient.ListenerUpdateErrTuple); !ok { t.Fatalf("Expect an LDS update from watcher, got %v", v) } } diff --git a/xds/internal/client/v2/eds_test.go b/xds/internal/xdsclient/v2/eds_test.go similarity index 85% rename from xds/internal/client/v2/eds_test.go rename to xds/internal/xdsclient/v2/eds_test.go index 0990e7ebae0..8176b6dfb93 100644 --- a/xds/internal/client/v2/eds_test.go +++ b/xds/internal/xdsclient/v2/eds_test.go @@ -23,12 +23,13 @@ import ( "time" v2xdspb "github.com/envoyproxy/go-control-plane/envoy/api/v2" - "github.com/golang/protobuf/ptypes" anypb "github.com/golang/protobuf/ptypes/any" + "github.com/google/go-cmp/cmp/cmpopts" + "google.golang.org/grpc/internal/testutils" "google.golang.org/grpc/xds/internal" - xdsclient 
"google.golang.org/grpc/xds/internal/client" - "google.golang.org/grpc/xds/internal/testutils" + xtestutils "google.golang.org/grpc/xds/internal/testutils" "google.golang.org/grpc/xds/internal/version" + "google.golang.org/grpc/xds/internal/xdsclient" ) var ( @@ -42,20 +43,14 @@ var ( TypeUrl: version.V2EndpointsURL, } badResourceTypeInEDSResponse = &v2xdspb.DiscoveryResponse{ - Resources: []*anypb.Any{ - { - TypeUrl: httpConnManagerURL, - Value: marshaledConnMgr1, - }, - }, - TypeUrl: version.V2EndpointsURL, + Resources: []*anypb.Any{marshaledConnMgr1}, + TypeUrl: version.V2EndpointsURL, } marshaledGoodCLA1 = func() *anypb.Any { - clab0 := testutils.NewClusterLoadAssignmentBuilder(goodEDSName, nil) + clab0 := xtestutils.NewClusterLoadAssignmentBuilder(goodEDSName, nil) clab0.AddLocality("locality-1", 1, 1, []string{"addr1:314"}, nil) clab0.AddLocality("locality-2", 1, 0, []string{"addr2:159"}, nil) - a, _ := ptypes.MarshalAny(clab0.Build()) - return a + return testutils.MarshalAny(clab0.Build()) }() goodEDSResponse1 = &v2xdspb.DiscoveryResponse{ Resources: []*anypb.Any{ @@ -64,10 +59,9 @@ var ( TypeUrl: version.V2EndpointsURL, } marshaledGoodCLA2 = func() *anypb.Any { - clab0 := testutils.NewClusterLoadAssignmentBuilder("not-goodEDSName", nil) + clab0 := xtestutils.NewClusterLoadAssignmentBuilder("not-goodEDSName", nil) clab0.AddLocality("locality-1", 1, 0, []string{"addr1:314"}, nil) - a, _ := ptypes.MarshalAny(clab0.Build()) - return a + return testutils.MarshalAny(clab0.Build()) }() goodEDSResponse2 = &v2xdspb.DiscoveryResponse{ Resources: []*anypb.Any{ @@ -82,7 +76,7 @@ func (s) TestEDSHandleResponse(t *testing.T) { name string edsResponse *v2xdspb.DiscoveryResponse wantErr bool - wantUpdate map[string]xdsclient.EndpointsUpdate + wantUpdate map[string]xdsclient.EndpointsUpdateErrTuple wantUpdateMD xdsclient.UpdateMetadata wantUpdateErr bool }{ @@ -95,7 +89,7 @@ func (s) TestEDSHandleResponse(t *testing.T) { wantUpdateMD: xdsclient.UpdateMetadata{ Status: 
xdsclient.ServiceStatusNACKed, ErrState: &xdsclient.UpdateErrorMetadata{ - Err: errPlaceHolder, + Err: cmpopts.AnyError, }, }, wantUpdateErr: false, @@ -109,7 +103,7 @@ func (s) TestEDSHandleResponse(t *testing.T) { wantUpdateMD: xdsclient.UpdateMetadata{ Status: xdsclient.ServiceStatusNACKed, ErrState: &xdsclient.UpdateErrorMetadata{ - Err: errPlaceHolder, + Err: cmpopts.AnyError, }, }, wantUpdateErr: false, @@ -119,8 +113,8 @@ func (s) TestEDSHandleResponse(t *testing.T) { name: "one-uninterestring-assignment", edsResponse: goodEDSResponse2, wantErr: false, - wantUpdate: map[string]xdsclient.EndpointsUpdate{ - "not-goodEDSName": { + wantUpdate: map[string]xdsclient.EndpointsUpdateErrTuple{ + "not-goodEDSName": {Update: xdsclient.EndpointsUpdate{ Localities: []xdsclient.Locality{ { Endpoints: []xdsclient.Endpoint{{Address: "addr1:314"}}, @@ -130,7 +124,7 @@ func (s) TestEDSHandleResponse(t *testing.T) { }, }, Raw: marshaledGoodCLA2, - }, + }}, }, wantUpdateMD: xdsclient.UpdateMetadata{ Status: xdsclient.ServiceStatusACKed, @@ -142,8 +136,8 @@ func (s) TestEDSHandleResponse(t *testing.T) { name: "one-good-assignment", edsResponse: goodEDSResponse1, wantErr: false, - wantUpdate: map[string]xdsclient.EndpointsUpdate{ - goodEDSName: { + wantUpdate: map[string]xdsclient.EndpointsUpdateErrTuple{ + goodEDSName: {Update: xdsclient.EndpointsUpdate{ Localities: []xdsclient.Locality{ { Endpoints: []xdsclient.Endpoint{{Address: "addr1:314"}}, @@ -159,7 +153,7 @@ func (s) TestEDSHandleResponse(t *testing.T) { }, }, Raw: marshaledGoodCLA1, - }, + }}, }, wantUpdateMD: xdsclient.UpdateMetadata{ Status: xdsclient.ServiceStatusACKed, diff --git a/xds/internal/client/v2/lds_test.go b/xds/internal/xdsclient/v2/lds_test.go similarity index 81% rename from xds/internal/client/v2/lds_test.go rename to xds/internal/xdsclient/v2/lds_test.go index 1f4c980fae5..a0600550095 100644 --- a/xds/internal/client/v2/lds_test.go +++ b/xds/internal/xdsclient/v2/lds_test.go @@ -23,8 +23,9 @@ import ( 
"time" v2xdspb "github.com/envoyproxy/go-control-plane/envoy/api/v2" + "github.com/google/go-cmp/cmp/cmpopts" - xdsclient "google.golang.org/grpc/xds/internal/client" + "google.golang.org/grpc/xds/internal/xdsclient" ) // TestLDSHandleResponse starts a fake xDS server, makes a ClientConn to it, @@ -35,7 +36,7 @@ func (s) TestLDSHandleResponse(t *testing.T) { name string ldsResponse *v2xdspb.DiscoveryResponse wantErr bool - wantUpdate map[string]xdsclient.ListenerUpdate + wantUpdate map[string]xdsclient.ListenerUpdateErrTuple wantUpdateMD xdsclient.UpdateMetadata wantUpdateErr bool }{ @@ -48,7 +49,7 @@ func (s) TestLDSHandleResponse(t *testing.T) { wantUpdateMD: xdsclient.UpdateMetadata{ Status: xdsclient.ServiceStatusNACKed, ErrState: &xdsclient.UpdateErrorMetadata{ - Err: errPlaceHolder, + Err: cmpopts.AnyError, }, }, wantUpdateErr: false, @@ -62,7 +63,7 @@ func (s) TestLDSHandleResponse(t *testing.T) { wantUpdateMD: xdsclient.UpdateMetadata{ Status: xdsclient.ServiceStatusNACKed, ErrState: &xdsclient.UpdateErrorMetadata{ - Err: errPlaceHolder, + Err: cmpopts.AnyError, }, }, wantUpdateErr: false, @@ -74,13 +75,13 @@ func (s) TestLDSHandleResponse(t *testing.T) { name: "no-apiListener-in-response", ldsResponse: noAPIListenerLDSResponse, wantErr: true, - wantUpdate: map[string]xdsclient.ListenerUpdate{ - goodLDSTarget1: {}, + wantUpdate: map[string]xdsclient.ListenerUpdateErrTuple{ + goodLDSTarget1: {Err: cmpopts.AnyError}, }, wantUpdateMD: xdsclient.UpdateMetadata{ Status: xdsclient.ServiceStatusNACKed, ErrState: &xdsclient.UpdateErrorMetadata{ - Err: errPlaceHolder, + Err: cmpopts.AnyError, }, }, wantUpdateErr: false, @@ -90,8 +91,8 @@ func (s) TestLDSHandleResponse(t *testing.T) { name: "one-good-listener", ldsResponse: goodLDSResponse1, wantErr: false, - wantUpdate: map[string]xdsclient.ListenerUpdate{ - goodLDSTarget1: {RouteConfigName: goodRouteName1, Raw: marshaledListener1}, + wantUpdate: map[string]xdsclient.ListenerUpdateErrTuple{ + goodLDSTarget1: 
{Update: xdsclient.ListenerUpdate{RouteConfigName: goodRouteName1, Raw: marshaledListener1}}, }, wantUpdateMD: xdsclient.UpdateMetadata{ Status: xdsclient.ServiceStatusACKed, @@ -104,9 +105,9 @@ func (s) TestLDSHandleResponse(t *testing.T) { name: "multiple-good-listener", ldsResponse: ldsResponseWithMultipleResources, wantErr: false, - wantUpdate: map[string]xdsclient.ListenerUpdate{ - goodLDSTarget1: {RouteConfigName: goodRouteName1, Raw: marshaledListener1}, - goodLDSTarget2: {RouteConfigName: goodRouteName1, Raw: marshaledListener2}, + wantUpdate: map[string]xdsclient.ListenerUpdateErrTuple{ + goodLDSTarget1: {Update: xdsclient.ListenerUpdate{RouteConfigName: goodRouteName1, Raw: marshaledListener1}}, + goodLDSTarget2: {Update: xdsclient.ListenerUpdate{RouteConfigName: goodRouteName1, Raw: marshaledListener2}}, }, wantUpdateMD: xdsclient.UpdateMetadata{ Status: xdsclient.ServiceStatusACKed, @@ -120,14 +121,14 @@ func (s) TestLDSHandleResponse(t *testing.T) { name: "good-bad-ugly-listeners", ldsResponse: goodBadUglyLDSResponse, wantErr: true, - wantUpdate: map[string]xdsclient.ListenerUpdate{ - goodLDSTarget1: {RouteConfigName: goodRouteName1, Raw: marshaledListener1}, - goodLDSTarget2: {}, + wantUpdate: map[string]xdsclient.ListenerUpdateErrTuple{ + goodLDSTarget1: {Update: xdsclient.ListenerUpdate{RouteConfigName: goodRouteName1, Raw: marshaledListener1}}, + goodLDSTarget2: {Err: cmpopts.AnyError}, }, wantUpdateMD: xdsclient.UpdateMetadata{ Status: xdsclient.ServiceStatusNACKed, ErrState: &xdsclient.UpdateErrorMetadata{ - Err: errPlaceHolder, + Err: cmpopts.AnyError, }, }, wantUpdateErr: false, @@ -137,8 +138,8 @@ func (s) TestLDSHandleResponse(t *testing.T) { name: "one-uninteresting-listener", ldsResponse: goodLDSResponse2, wantErr: false, - wantUpdate: map[string]xdsclient.ListenerUpdate{ - goodLDSTarget2: {RouteConfigName: goodRouteName1, Raw: marshaledListener2}, + wantUpdate: map[string]xdsclient.ListenerUpdateErrTuple{ + goodLDSTarget2: {Update: 
xdsclient.ListenerUpdate{RouteConfigName: goodRouteName1, Raw: marshaledListener2}}, }, wantUpdateMD: xdsclient.UpdateMetadata{ Status: xdsclient.ServiceStatusACKed, diff --git a/xds/internal/client/v2/loadreport.go b/xds/internal/xdsclient/v2/loadreport.go similarity index 88% rename from xds/internal/client/v2/loadreport.go rename to xds/internal/xdsclient/v2/loadreport.go index 69405fcd9ad..f0034e21c35 100644 --- a/xds/internal/client/v2/loadreport.go +++ b/xds/internal/xdsclient/v2/loadreport.go @@ -26,7 +26,8 @@ import ( "github.com/golang/protobuf/proto" "github.com/golang/protobuf/ptypes" - "google.golang.org/grpc/xds/internal/client/load" + "google.golang.org/grpc/internal/pretty" + "google.golang.org/grpc/xds/internal/xdsclient/load" v2corepb "github.com/envoyproxy/go-control-plane/envoy/api/v2/core" v2endpointpb "github.com/envoyproxy/go-control-plane/envoy/api/v2/endpoint" @@ -57,7 +58,7 @@ func (v2c *client) SendFirstLoadStatsRequest(s grpc.ClientStream) error { node.ClientFeatures = append(node.ClientFeatures, clientFeatureLRSSendAllClusters) req := &lrspb.LoadStatsRequest{Node: node} - v2c.logger.Infof("lrs: sending init LoadStatsRequest: %v", req) + v2c.logger.Infof("lrs: sending init LoadStatsRequest: %v", pretty.ToJSON(req)) return stream.Send(req) } @@ -71,7 +72,7 @@ func (v2c *client) HandleLoadStatsResponse(s grpc.ClientStream) ([]string, time. 
if err != nil { return nil, 0, fmt.Errorf("lrs: failed to receive first response: %v", err) } - v2c.logger.Infof("lrs: received first LoadStatsResponse: %+v", resp) + v2c.logger.Infof("lrs: received first LoadStatsResponse: %+v", pretty.ToJSON(resp)) interval, err := ptypes.Duration(resp.GetLoadReportingInterval()) if err != nil { @@ -98,24 +99,22 @@ func (v2c *client) SendLoadStatsRequest(s grpc.ClientStream, loads []*load.Data) return fmt.Errorf("lrs: Attempt to send request on unsupported stream type: %T", s) } - var clusterStats []*v2endpointpb.ClusterStats + clusterStats := make([]*v2endpointpb.ClusterStats, 0, len(loads)) for _, sd := range loads { - var ( - droppedReqs []*v2endpointpb.ClusterStats_DroppedRequests - localityStats []*v2endpointpb.UpstreamLocalityStats - ) + droppedReqs := make([]*v2endpointpb.ClusterStats_DroppedRequests, 0, len(sd.Drops)) for category, count := range sd.Drops { droppedReqs = append(droppedReqs, &v2endpointpb.ClusterStats_DroppedRequests{ Category: category, DroppedCount: count, }) } + localityStats := make([]*v2endpointpb.UpstreamLocalityStats, 0, len(sd.LocalityStats)) for l, localityData := range sd.LocalityStats { lid, err := internal.LocalityIDFromString(l) if err != nil { return err } - var loadMetricStats []*v2endpointpb.EndpointLoadMetricStats + loadMetricStats := make([]*v2endpointpb.EndpointLoadMetricStats, 0, len(localityData.LoadStats)) for name, loadData := range localityData.LoadStats { loadMetricStats = append(loadMetricStats, &v2endpointpb.EndpointLoadMetricStats{ MetricName: name, @@ -149,6 +148,6 @@ func (v2c *client) SendLoadStatsRequest(s grpc.ClientStream, loads []*load.Data) } req := &lrspb.LoadStatsRequest{ClusterStats: clusterStats} - v2c.logger.Infof("lrs: sending LRS loads: %+v", req) + v2c.logger.Infof("lrs: sending LRS loads: %+v", pretty.ToJSON(req)) return stream.Send(req) } diff --git a/xds/internal/client/v2/rds_test.go b/xds/internal/xdsclient/v2/rds_test.go similarity index 77% rename from 
xds/internal/client/v2/rds_test.go rename to xds/internal/xdsclient/v2/rds_test.go index dd145158b8a..3389f053946 100644 --- a/xds/internal/client/v2/rds_test.go +++ b/xds/internal/xdsclient/v2/rds_test.go @@ -24,9 +24,10 @@ import ( "time" xdspb "github.com/envoyproxy/go-control-plane/envoy/api/v2" + "github.com/google/go-cmp/cmp/cmpopts" - xdsclient "google.golang.org/grpc/xds/internal/client" "google.golang.org/grpc/xds/internal/testutils/fakeserver" + "google.golang.org/grpc/xds/internal/xdsclient" ) // doLDS makes a LDS watch, and waits for the response and ack to finish. @@ -49,7 +50,7 @@ func (s) TestRDSHandleResponseWithRouting(t *testing.T) { name string rdsResponse *xdspb.DiscoveryResponse wantErr bool - wantUpdate map[string]xdsclient.RouteConfigUpdate + wantUpdate map[string]xdsclient.RouteConfigUpdateErrTuple wantUpdateMD xdsclient.UpdateMetadata wantUpdateErr bool }{ @@ -62,7 +63,7 @@ func (s) TestRDSHandleResponseWithRouting(t *testing.T) { wantUpdateMD: xdsclient.UpdateMetadata{ Status: xdsclient.ServiceStatusNACKed, ErrState: &xdsclient.UpdateErrorMetadata{ - Err: errPlaceHolder, + Err: cmpopts.AnyError, }, }, wantUpdateErr: false, @@ -76,23 +77,23 @@ func (s) TestRDSHandleResponseWithRouting(t *testing.T) { wantUpdateMD: xdsclient.UpdateMetadata{ Status: xdsclient.ServiceStatusNACKed, ErrState: &xdsclient.UpdateErrorMetadata{ - Err: errPlaceHolder, + Err: cmpopts.AnyError, }, }, wantUpdateErr: false, }, - // No VirtualHosts in the response. Just one test case here for a bad + // No virtualHosts in the response. Just one test case here for a bad // RouteConfiguration, since the others are covered in // TestGetClusterFromRouteConfiguration. 
{ name: "no-virtual-hosts-in-response", rdsResponse: noVirtualHostsInRDSResponse, wantErr: false, - wantUpdate: map[string]xdsclient.RouteConfigUpdate{ - goodRouteName1: { + wantUpdate: map[string]xdsclient.RouteConfigUpdateErrTuple{ + goodRouteName1: {Update: xdsclient.RouteConfigUpdate{ VirtualHosts: nil, Raw: marshaledNoVirtualHostsRouteConfig, - }, + }}, }, wantUpdateMD: xdsclient.UpdateMetadata{ Status: xdsclient.ServiceStatusACKed, @@ -104,20 +105,25 @@ func (s) TestRDSHandleResponseWithRouting(t *testing.T) { name: "one-uninteresting-route-config", rdsResponse: goodRDSResponse2, wantErr: false, - wantUpdate: map[string]xdsclient.RouteConfigUpdate{ - goodRouteName2: { + wantUpdate: map[string]xdsclient.RouteConfigUpdateErrTuple{ + goodRouteName2: {Update: xdsclient.RouteConfigUpdate{ VirtualHosts: []*xdsclient.VirtualHost{ { Domains: []string{uninterestingDomain}, - Routes: []*xdsclient.Route{{Prefix: newStringP(""), WeightedClusters: map[string]xdsclient.WeightedCluster{uninterestingClusterName: {Weight: 1}}}}, + Routes: []*xdsclient.Route{{Prefix: newStringP(""), + WeightedClusters: map[string]xdsclient.WeightedCluster{uninterestingClusterName: {Weight: 1}}, + RouteAction: xdsclient.RouteActionRoute}}, }, { Domains: []string{goodLDSTarget1}, - Routes: []*xdsclient.Route{{Prefix: newStringP(""), WeightedClusters: map[string]xdsclient.WeightedCluster{goodClusterName2: {Weight: 1}}}}, + Routes: []*xdsclient.Route{{ + Prefix: newStringP(""), + WeightedClusters: map[string]xdsclient.WeightedCluster{goodClusterName2: {Weight: 1}}, + RouteAction: xdsclient.RouteActionRoute}}, }, }, Raw: marshaledGoodRouteConfig2, - }, + }}, }, wantUpdateMD: xdsclient.UpdateMetadata{ Status: xdsclient.ServiceStatusACKed, @@ -129,20 +135,25 @@ func (s) TestRDSHandleResponseWithRouting(t *testing.T) { name: "one-good-route-config", rdsResponse: goodRDSResponse1, wantErr: false, - wantUpdate: map[string]xdsclient.RouteConfigUpdate{ - goodRouteName1: { + wantUpdate: 
map[string]xdsclient.RouteConfigUpdateErrTuple{ + goodRouteName1: {Update: xdsclient.RouteConfigUpdate{ VirtualHosts: []*xdsclient.VirtualHost{ { Domains: []string{uninterestingDomain}, - Routes: []*xdsclient.Route{{Prefix: newStringP(""), WeightedClusters: map[string]xdsclient.WeightedCluster{uninterestingClusterName: {Weight: 1}}}}, + Routes: []*xdsclient.Route{{ + Prefix: newStringP(""), + WeightedClusters: map[string]xdsclient.WeightedCluster{uninterestingClusterName: {Weight: 1}}, + RouteAction: xdsclient.RouteActionRoute}}, }, { Domains: []string{goodLDSTarget1}, - Routes: []*xdsclient.Route{{Prefix: newStringP(""), WeightedClusters: map[string]xdsclient.WeightedCluster{goodClusterName1: {Weight: 1}}}}, + Routes: []*xdsclient.Route{{Prefix: newStringP(""), + WeightedClusters: map[string]xdsclient.WeightedCluster{goodClusterName1: {Weight: 1}}, + RouteAction: xdsclient.RouteActionRoute}}, }, }, Raw: marshaledGoodRouteConfig1, - }, + }}, }, wantUpdateMD: xdsclient.UpdateMetadata{ Status: xdsclient.ServiceStatusACKed, diff --git a/xds/internal/client/v3/client.go b/xds/internal/xdsclient/v3/client.go similarity index 82% rename from xds/internal/client/v3/client.go rename to xds/internal/xdsclient/v3/client.go index 55cae56d8cc..827c06b741b 100644 --- a/xds/internal/client/v3/client.go +++ b/xds/internal/xdsclient/v3/client.go @@ -28,8 +28,9 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/internal/grpclog" - xdsclient "google.golang.org/grpc/xds/internal/client" + "google.golang.org/grpc/internal/pretty" "google.golang.org/grpc/xds/internal/version" + "google.golang.org/grpc/xds/internal/xdsclient" v3corepb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" v3adsgrpc "github.com/envoyproxy/go-control-plane/envoy/service/discovery/v3" @@ -65,10 +66,11 @@ func newClient(cc *grpc.ClientConn, opts xdsclient.BuildOptions) (xdsclient.APIC return nil, fmt.Errorf("xds: unsupported Node proto type: %T, want %T", 
opts.NodeProto, v3corepb.Node{}) } v3c := &client{ - cc: cc, - parent: opts.Parent, - nodeProto: nodeProto, - logger: opts.Logger, + cc: cc, + parent: opts.Parent, + nodeProto: nodeProto, + logger: opts.Logger, + updateValidator: opts.Validator, } v3c.ctx, v3c.cancelCtx = context.WithCancel(context.Background()) v3c.TransportHelper = xdsclient.NewTransportHelper(v3c, opts.Logger, opts.Backoff) @@ -89,8 +91,9 @@ type client struct { logger *grpclog.PrefixLogger // ClientConn to the xDS gRPC server. Owned by the parent xdsClient. - cc *grpc.ClientConn - nodeProto *v3corepb.Node + cc *grpc.ClientConn + nodeProto *v3corepb.Node + updateValidator xdsclient.UpdateValidatorFunc } func (v3c *client) NewStream(ctx context.Context) (grpc.ClientStream, error) { @@ -125,7 +128,7 @@ func (v3c *client) SendRequest(s grpc.ClientStream, resourceNames []string, rTyp if err := stream.Send(req); err != nil { return fmt.Errorf("xds: stream.Send(%+v) failed: %v", req, err) } - v3c.logger.Debugf("ADS request sent: %v", req) + v3c.logger.Debugf("ADS request sent: %v", pretty.ToJSON(req)) return nil } @@ -139,11 +142,11 @@ func (v3c *client) RecvResponse(s grpc.ClientStream) (proto.Message, error) { resp, err := stream.Recv() if err != nil { - // TODO: call watch callbacks with error when stream is broken. + v3c.parent.NewConnectionError(err) return nil, fmt.Errorf("xds: stream.Recv() failed: %v", err) } v3c.logger.Infof("ADS response received, type: %v", resp.GetTypeUrl()) - v3c.logger.Debugf("ADS response received: %+v", resp) + v3c.logger.Debugf("ADS response received: %+v", pretty.ToJSON(resp)) return resp, nil } @@ -185,7 +188,12 @@ func (v3c *client) HandleResponse(r proto.Message) (xdsclient.ResourceType, stri // server. On receipt of a good response, it also invokes the registered watcher // callback. 
func (v3c *client) handleLDSResponse(resp *v3discoverypb.DiscoveryResponse) error { - update, md, err := xdsclient.UnmarshalListener(resp.GetVersionInfo(), resp.GetResources(), v3c.logger) + update, md, err := xdsclient.UnmarshalListener(&xdsclient.UnmarshalOptions{ + Version: resp.GetVersionInfo(), + Resources: resp.GetResources(), + Logger: v3c.logger, + UpdateValidator: v3c.updateValidator, + }) v3c.parent.NewListeners(update, md) return err } @@ -194,7 +202,12 @@ func (v3c *client) handleLDSResponse(resp *v3discoverypb.DiscoveryResponse) erro // server. On receipt of a good response, it caches validated resources and also // invokes the registered watcher callback. func (v3c *client) handleRDSResponse(resp *v3discoverypb.DiscoveryResponse) error { - update, md, err := xdsclient.UnmarshalRouteConfig(resp.GetVersionInfo(), resp.GetResources(), v3c.logger) + update, md, err := xdsclient.UnmarshalRouteConfig(&xdsclient.UnmarshalOptions{ + Version: resp.GetVersionInfo(), + Resources: resp.GetResources(), + Logger: v3c.logger, + UpdateValidator: v3c.updateValidator, + }) v3c.parent.NewRouteConfigs(update, md) return err } @@ -203,13 +216,23 @@ func (v3c *client) handleRDSResponse(resp *v3discoverypb.DiscoveryResponse) erro // server. On receipt of a good response, it also invokes the registered watcher // callback. 
func (v3c *client) handleCDSResponse(resp *v3discoverypb.DiscoveryResponse) error { - update, md, err := xdsclient.UnmarshalCluster(resp.GetVersionInfo(), resp.GetResources(), v3c.logger) + update, md, err := xdsclient.UnmarshalCluster(&xdsclient.UnmarshalOptions{ + Version: resp.GetVersionInfo(), + Resources: resp.GetResources(), + Logger: v3c.logger, + UpdateValidator: v3c.updateValidator, + }) v3c.parent.NewClusters(update, md) return err } func (v3c *client) handleEDSResponse(resp *v3discoverypb.DiscoveryResponse) error { - update, md, err := xdsclient.UnmarshalEndpoints(resp.GetVersionInfo(), resp.GetResources(), v3c.logger) + update, md, err := xdsclient.UnmarshalEndpoints(&xdsclient.UnmarshalOptions{ + Version: resp.GetVersionInfo(), + Resources: resp.GetResources(), + Logger: v3c.logger, + UpdateValidator: v3c.updateValidator, + }) v3c.parent.NewEndpoints(update, md) return err } diff --git a/xds/internal/client/v3/loadreport.go b/xds/internal/xdsclient/v3/loadreport.go similarity index 88% rename from xds/internal/client/v3/loadreport.go rename to xds/internal/xdsclient/v3/loadreport.go index 74e18632aa0..8cdb5476fbb 100644 --- a/xds/internal/client/v3/loadreport.go +++ b/xds/internal/xdsclient/v3/loadreport.go @@ -26,7 +26,8 @@ import ( "github.com/golang/protobuf/proto" "github.com/golang/protobuf/ptypes" - "google.golang.org/grpc/xds/internal/client/load" + "google.golang.org/grpc/internal/pretty" + "google.golang.org/grpc/xds/internal/xdsclient/load" v3corepb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" v3endpointpb "github.com/envoyproxy/go-control-plane/envoy/config/endpoint/v3" @@ -57,7 +58,7 @@ func (v3c *client) SendFirstLoadStatsRequest(s grpc.ClientStream) error { node.ClientFeatures = append(node.ClientFeatures, clientFeatureLRSSendAllClusters) req := &lrspb.LoadStatsRequest{Node: node} - v3c.logger.Infof("lrs: sending init LoadStatsRequest: %v", req) + v3c.logger.Infof("lrs: sending init LoadStatsRequest: %v", 
pretty.ToJSON(req)) return stream.Send(req) } @@ -71,7 +72,7 @@ func (v3c *client) HandleLoadStatsResponse(s grpc.ClientStream) ([]string, time. if err != nil { return nil, 0, fmt.Errorf("lrs: failed to receive first response: %v", err) } - v3c.logger.Infof("lrs: received first LoadStatsResponse: %+v", resp) + v3c.logger.Infof("lrs: received first LoadStatsResponse: %+v", pretty.ToJSON(resp)) interval, err := ptypes.Duration(resp.GetLoadReportingInterval()) if err != nil { @@ -98,24 +99,22 @@ func (v3c *client) SendLoadStatsRequest(s grpc.ClientStream, loads []*load.Data) return fmt.Errorf("lrs: Attempt to send request on unsupported stream type: %T", s) } - var clusterStats []*v3endpointpb.ClusterStats + clusterStats := make([]*v3endpointpb.ClusterStats, 0, len(loads)) for _, sd := range loads { - var ( - droppedReqs []*v3endpointpb.ClusterStats_DroppedRequests - localityStats []*v3endpointpb.UpstreamLocalityStats - ) + droppedReqs := make([]*v3endpointpb.ClusterStats_DroppedRequests, 0, len(sd.Drops)) for category, count := range sd.Drops { droppedReqs = append(droppedReqs, &v3endpointpb.ClusterStats_DroppedRequests{ Category: category, DroppedCount: count, }) } + localityStats := make([]*v3endpointpb.UpstreamLocalityStats, 0, len(sd.LocalityStats)) for l, localityData := range sd.LocalityStats { lid, err := internal.LocalityIDFromString(l) if err != nil { return err } - var loadMetricStats []*v3endpointpb.EndpointLoadMetricStats + loadMetricStats := make([]*v3endpointpb.EndpointLoadMetricStats, 0, len(localityData.LoadStats)) for name, loadData := range localityData.LoadStats { loadMetricStats = append(loadMetricStats, &v3endpointpb.EndpointLoadMetricStats{ MetricName: name, @@ -148,6 +147,6 @@ func (v3c *client) SendLoadStatsRequest(s grpc.ClientStream, loads []*load.Data) } req := &lrspb.LoadStatsRequest{ClusterStats: clusterStats} - v3c.logger.Infof("lrs: sending LRS loads: %+v", req) + v3c.logger.Infof("lrs: sending LRS loads: %+v", pretty.ToJSON(req)) 
return stream.Send(req) } diff --git a/xds/internal/client/watchers.go b/xds/internal/xdsclient/watchers.go similarity index 95% rename from xds/internal/client/watchers.go rename to xds/internal/xdsclient/watchers.go index 9fafe5a60f8..e26ed360308 100644 --- a/xds/internal/client/watchers.go +++ b/xds/internal/xdsclient/watchers.go @@ -16,12 +16,14 @@ * */ -package client +package xdsclient import ( "fmt" "sync" "time" + + "google.golang.org/grpc/internal/pretty" ) type watchInfoState int @@ -64,6 +66,17 @@ func (wi *watchInfo) newUpdate(update interface{}) { wi.c.scheduleCallback(wi, update, nil) } +func (wi *watchInfo) newError(err error) { + wi.mu.Lock() + defer wi.mu.Unlock() + if wi.state == watchInfoStateCanceled { + return + } + wi.state = watchInfoStateRespReceived + wi.expiryTimer.Stop() + wi.sendErrorLocked(err) +} + func (wi *watchInfo) resourceNotFound() { wi.mu.Lock() defer wi.mu.Unlock() @@ -161,22 +174,22 @@ func (c *clientImpl) watch(wi *watchInfo) (cancel func()) { switch wi.rType { case ListenerResource: if v, ok := c.ldsCache[resourceName]; ok { - c.logger.Debugf("LDS resource with name %v found in cache: %+v", wi.target, v) + c.logger.Debugf("LDS resource with name %v found in cache: %+v", wi.target, pretty.ToJSON(v)) wi.newUpdate(v) } case RouteConfigResource: if v, ok := c.rdsCache[resourceName]; ok { - c.logger.Debugf("RDS resource with name %v found in cache: %+v", wi.target, v) + c.logger.Debugf("RDS resource with name %v found in cache: %+v", wi.target, pretty.ToJSON(v)) wi.newUpdate(v) } case ClusterResource: if v, ok := c.cdsCache[resourceName]; ok { - c.logger.Debugf("CDS resource with name %v found in cache: %+v", wi.target, v) + c.logger.Debugf("CDS resource with name %v found in cache: %+v", wi.target, pretty.ToJSON(v)) wi.newUpdate(v) } case EndpointsResource: if v, ok := c.edsCache[resourceName]; ok { - c.logger.Debugf("EDS resource with name %v found in cache: %+v", wi.target, v) + c.logger.Debugf("EDS resource with name %v found 
in cache: %+v", wi.target, pretty.ToJSON(v)) wi.newUpdate(v) } } diff --git a/xds/internal/client/watchers_cluster_test.go b/xds/internal/xdsclient/watchers_cluster_test.go similarity index 58% rename from xds/internal/client/watchers_cluster_test.go rename to xds/internal/xdsclient/watchers_cluster_test.go index fdef0cf6164..c06319e959c 100644 --- a/xds/internal/client/watchers_cluster_test.go +++ b/xds/internal/xdsclient/watchers_cluster_test.go @@ -16,22 +16,19 @@ * */ -package client +package xdsclient import ( "context" + "fmt" "testing" "github.com/google/go-cmp/cmp" + "google.golang.org/protobuf/types/known/anypb" "google.golang.org/grpc/internal/testutils" ) -type clusterUpdateErr struct { - u ClusterUpdate - err error -} - // TestClusterWatch covers the cases: // - an update is received after a watch() // - an update for another resource name @@ -56,30 +53,34 @@ func (s) TestClusterWatch(t *testing.T) { clusterUpdateCh := testutils.NewChannel() cancelWatch := client.WatchCluster(testCDSName, func(update ClusterUpdate, err error) { - clusterUpdateCh.Send(clusterUpdateErr{u: update, err: err}) + clusterUpdateCh.Send(ClusterUpdateErrTuple{Update: update, Err: err}) }) if _, err := apiClient.addWatches[ClusterResource].Receive(ctx); err != nil { t.Fatalf("want new watch to start, got error %v", err) } - wantUpdate := ClusterUpdate{ServiceName: testEDSName} - client.NewClusters(map[string]ClusterUpdate{testCDSName: wantUpdate}, UpdateMetadata{}) - if err := verifyClusterUpdate(ctx, clusterUpdateCh, wantUpdate); err != nil { + wantUpdate := ClusterUpdate{ClusterName: testEDSName} + client.NewClusters(map[string]ClusterUpdateErrTuple{testCDSName: {Update: wantUpdate}}, UpdateMetadata{}) + if err := verifyClusterUpdate(ctx, clusterUpdateCh, wantUpdate, nil); err != nil { t.Fatal(err) } - // Another update, with an extra resource for a different resource name. 
- client.NewClusters(map[string]ClusterUpdate{ - testCDSName: wantUpdate, + // Push an update, with an extra resource for a different resource name. + // Specify a non-nil raw proto in the original resource to ensure that the + // new update is not considered equal to the old one. + newUpdate := wantUpdate + newUpdate.Raw = &anypb.Any{} + client.NewClusters(map[string]ClusterUpdateErrTuple{ + testCDSName: {Update: newUpdate}, "randomName": {}, }, UpdateMetadata{}) - if err := verifyClusterUpdate(ctx, clusterUpdateCh, wantUpdate); err != nil { + if err := verifyClusterUpdate(ctx, clusterUpdateCh, newUpdate, nil); err != nil { t.Fatal(err) } // Cancel watch, and send update again. cancelWatch() - client.NewClusters(map[string]ClusterUpdate{testCDSName: wantUpdate}, UpdateMetadata{}) + client.NewClusters(map[string]ClusterUpdateErrTuple{testCDSName: {Update: wantUpdate}}, UpdateMetadata{}) sCtx, sCancel := context.WithTimeout(ctx, defaultTestShortTimeout) defer sCancel() if u, err := clusterUpdateCh.Receive(sCtx); err != context.DeadlineExceeded { @@ -114,7 +115,7 @@ func (s) TestClusterTwoWatchSameResourceName(t *testing.T) { clusterUpdateCh := testutils.NewChannel() clusterUpdateChs = append(clusterUpdateChs, clusterUpdateCh) cancelLastWatch = client.WatchCluster(testCDSName, func(update ClusterUpdate, err error) { - clusterUpdateCh.Send(clusterUpdateErr{u: update, err: err}) + clusterUpdateCh.Send(ClusterUpdateErrTuple{Update: update, Err: err}) }) if i == 0 { @@ -126,27 +127,36 @@ func (s) TestClusterTwoWatchSameResourceName(t *testing.T) { } } - wantUpdate := ClusterUpdate{ServiceName: testEDSName} - client.NewClusters(map[string]ClusterUpdate{testCDSName: wantUpdate}, UpdateMetadata{}) + wantUpdate := ClusterUpdate{ClusterName: testEDSName} + client.NewClusters(map[string]ClusterUpdateErrTuple{testCDSName: {Update: wantUpdate}}, UpdateMetadata{}) for i := 0; i < count; i++ { - if err := verifyClusterUpdate(ctx, clusterUpdateChs[i], wantUpdate); err != nil { + if 
err := verifyClusterUpdate(ctx, clusterUpdateChs[i], wantUpdate, nil); err != nil { t.Fatal(err) } } - // Cancel the last watch, and send update again. + // Cancel the last watch, and send update again. None of the watchers should + // be notified because one has been cancelled, and the other is receiving + // the same update. cancelLastWatch() - client.NewClusters(map[string]ClusterUpdate{testCDSName: wantUpdate}, UpdateMetadata{}) - for i := 0; i < count-1; i++ { - if err := verifyClusterUpdate(ctx, clusterUpdateChs[i], wantUpdate); err != nil { - t.Fatal(err) - } + client.NewClusters(map[string]ClusterUpdateErrTuple{testCDSName: {Update: wantUpdate}}, UpdateMetadata{}) + for i := 0; i < count; i++ { + func() { + sCtx, sCancel := context.WithTimeout(ctx, defaultTestShortTimeout) + defer sCancel() + if u, err := clusterUpdateChs[i].Receive(sCtx); err != context.DeadlineExceeded { + t.Errorf("unexpected ClusterUpdate: %v, %v, want channel recv timeout", u, err) + } + }() } - sCtx, sCancel := context.WithTimeout(ctx, defaultTestShortTimeout) - defer sCancel() - if u, err := clusterUpdateChs[count-1].Receive(sCtx); err != context.DeadlineExceeded { - t.Errorf("unexpected clusterUpdate: %v, %v, want channel recv timeout", u, err) + // Push a new update and make sure the uncancelled watcher is invoked. + // Specify a non-nil raw proto to ensure that the new update is not + // considered equal to the old one. 
+ newUpdate := ClusterUpdate{ClusterName: testEDSName, Raw: &anypb.Any{}} + client.NewClusters(map[string]ClusterUpdateErrTuple{testCDSName: {Update: newUpdate}}, UpdateMetadata{}) + if err := verifyClusterUpdate(ctx, clusterUpdateChs[0], newUpdate, nil); err != nil { + t.Fatal(err) } } @@ -177,7 +187,7 @@ func (s) TestClusterThreeWatchDifferentResourceName(t *testing.T) { clusterUpdateCh := testutils.NewChannel() clusterUpdateChs = append(clusterUpdateChs, clusterUpdateCh) client.WatchCluster(testCDSName+"1", func(update ClusterUpdate, err error) { - clusterUpdateCh.Send(clusterUpdateErr{u: update, err: err}) + clusterUpdateCh.Send(ClusterUpdateErrTuple{Update: update, Err: err}) }) if i == 0 { @@ -192,25 +202,25 @@ func (s) TestClusterThreeWatchDifferentResourceName(t *testing.T) { // Third watch for a different name. clusterUpdateCh2 := testutils.NewChannel() client.WatchCluster(testCDSName+"2", func(update ClusterUpdate, err error) { - clusterUpdateCh2.Send(clusterUpdateErr{u: update, err: err}) + clusterUpdateCh2.Send(ClusterUpdateErrTuple{Update: update, Err: err}) }) if _, err := apiClient.addWatches[ClusterResource].Receive(ctx); err != nil { t.Fatalf("want new watch to start, got error %v", err) } - wantUpdate1 := ClusterUpdate{ServiceName: testEDSName + "1"} - wantUpdate2 := ClusterUpdate{ServiceName: testEDSName + "2"} - client.NewClusters(map[string]ClusterUpdate{ - testCDSName + "1": wantUpdate1, - testCDSName + "2": wantUpdate2, + wantUpdate1 := ClusterUpdate{ClusterName: testEDSName + "1"} + wantUpdate2 := ClusterUpdate{ClusterName: testEDSName + "2"} + client.NewClusters(map[string]ClusterUpdateErrTuple{ + testCDSName + "1": {Update: wantUpdate1}, + testCDSName + "2": {Update: wantUpdate2}, }, UpdateMetadata{}) for i := 0; i < count; i++ { - if err := verifyClusterUpdate(ctx, clusterUpdateChs[i], wantUpdate1); err != nil { + if err := verifyClusterUpdate(ctx, clusterUpdateChs[i], wantUpdate1, nil); err != nil { t.Fatal(err) } } - if err := 
verifyClusterUpdate(ctx, clusterUpdateCh2, wantUpdate2); err != nil { + if err := verifyClusterUpdate(ctx, clusterUpdateCh2, wantUpdate2, nil); err != nil { t.Fatal(err) } } @@ -237,24 +247,24 @@ func (s) TestClusterWatchAfterCache(t *testing.T) { clusterUpdateCh := testutils.NewChannel() client.WatchCluster(testCDSName, func(update ClusterUpdate, err error) { - clusterUpdateCh.Send(clusterUpdateErr{u: update, err: err}) + clusterUpdateCh.Send(ClusterUpdateErrTuple{Update: update, Err: err}) }) if _, err := apiClient.addWatches[ClusterResource].Receive(ctx); err != nil { t.Fatalf("want new watch to start, got error %v", err) } - wantUpdate := ClusterUpdate{ServiceName: testEDSName} - client.NewClusters(map[string]ClusterUpdate{ - testCDSName: wantUpdate, + wantUpdate := ClusterUpdate{ClusterName: testEDSName} + client.NewClusters(map[string]ClusterUpdateErrTuple{ + testCDSName: {Update: wantUpdate}, }, UpdateMetadata{}) - if err := verifyClusterUpdate(ctx, clusterUpdateCh, wantUpdate); err != nil { + if err := verifyClusterUpdate(ctx, clusterUpdateCh, wantUpdate, nil); err != nil { t.Fatal(err) } // Another watch for the resource in cache. clusterUpdateCh2 := testutils.NewChannel() client.WatchCluster(testCDSName, func(update ClusterUpdate, err error) { - clusterUpdateCh2.Send(clusterUpdateErr{u: update, err: err}) + clusterUpdateCh2.Send(ClusterUpdateErrTuple{Update: update, Err: err}) }) sCtx, sCancel := context.WithTimeout(ctx, defaultTestShortTimeout) defer sCancel() @@ -263,7 +273,7 @@ func (s) TestClusterWatchAfterCache(t *testing.T) { } // New watch should receives the update. 
- if err := verifyClusterUpdate(ctx, clusterUpdateCh2, wantUpdate); err != nil { + if err := verifyClusterUpdate(ctx, clusterUpdateCh2, wantUpdate, nil); err != nil { t.Fatal(err) } @@ -298,7 +308,7 @@ func (s) TestClusterWatchExpiryTimer(t *testing.T) { clusterUpdateCh := testutils.NewChannel() client.WatchCluster(testCDSName, func(u ClusterUpdate, err error) { - clusterUpdateCh.Send(clusterUpdateErr{u: u, err: err}) + clusterUpdateCh.Send(ClusterUpdateErrTuple{Update: u, Err: err}) }) if _, err := apiClient.addWatches[ClusterResource].Receive(ctx); err != nil { t.Fatalf("want new watch to start, got error %v", err) @@ -308,9 +318,9 @@ func (s) TestClusterWatchExpiryTimer(t *testing.T) { if err != nil { t.Fatalf("timeout when waiting for cluster update: %v", err) } - gotUpdate := u.(clusterUpdateErr) - if gotUpdate.err == nil || !cmp.Equal(gotUpdate.u, ClusterUpdate{}) { - t.Fatalf("unexpected clusterUpdate: (%v, %v), want: (ClusterUpdate{}, nil)", gotUpdate.u, gotUpdate.err) + gotUpdate := u.(ClusterUpdateErrTuple) + if gotUpdate.Err == nil || !cmp.Equal(gotUpdate.Update, ClusterUpdate{}) { + t.Fatalf("unexpected clusterUpdate: (%v, %v), want: (ClusterUpdate{}, nil)", gotUpdate.Update, gotUpdate.Err) } } @@ -337,17 +347,17 @@ func (s) TestClusterWatchExpiryTimerStop(t *testing.T) { clusterUpdateCh := testutils.NewChannel() client.WatchCluster(testCDSName, func(u ClusterUpdate, err error) { - clusterUpdateCh.Send(clusterUpdateErr{u: u, err: err}) + clusterUpdateCh.Send(ClusterUpdateErrTuple{Update: u, Err: err}) }) if _, err := apiClient.addWatches[ClusterResource].Receive(ctx); err != nil { t.Fatalf("want new watch to start, got error %v", err) } - wantUpdate := ClusterUpdate{ServiceName: testEDSName} - client.NewClusters(map[string]ClusterUpdate{ - testCDSName: wantUpdate, + wantUpdate := ClusterUpdate{ClusterName: testEDSName} + client.NewClusters(map[string]ClusterUpdateErrTuple{ + testCDSName: {Update: wantUpdate}, }, UpdateMetadata{}) - if err := 
verifyClusterUpdate(ctx, clusterUpdateCh, wantUpdate); err != nil { + if err := verifyClusterUpdate(ctx, clusterUpdateCh, wantUpdate, nil); err != nil { t.Fatal(err) } @@ -385,7 +395,7 @@ func (s) TestClusterResourceRemoved(t *testing.T) { clusterUpdateCh1 := testutils.NewChannel() client.WatchCluster(testCDSName+"1", func(update ClusterUpdate, err error) { - clusterUpdateCh1.Send(clusterUpdateErr{u: update, err: err}) + clusterUpdateCh1.Send(ClusterUpdateErrTuple{Update: update, Err: err}) }) if _, err := apiClient.addWatches[ClusterResource].Receive(ctx); err != nil { t.Fatalf("want new watch to start, got error %v", err) @@ -394,50 +404,152 @@ func (s) TestClusterResourceRemoved(t *testing.T) { // Another watch for a different name. clusterUpdateCh2 := testutils.NewChannel() client.WatchCluster(testCDSName+"2", func(update ClusterUpdate, err error) { - clusterUpdateCh2.Send(clusterUpdateErr{u: update, err: err}) + clusterUpdateCh2.Send(ClusterUpdateErrTuple{Update: update, Err: err}) }) if _, err := apiClient.addWatches[ClusterResource].Receive(ctx); err != nil { t.Fatalf("want new watch to start, got error %v", err) } - wantUpdate1 := ClusterUpdate{ServiceName: testEDSName + "1"} - wantUpdate2 := ClusterUpdate{ServiceName: testEDSName + "2"} - client.NewClusters(map[string]ClusterUpdate{ - testCDSName + "1": wantUpdate1, - testCDSName + "2": wantUpdate2, + wantUpdate1 := ClusterUpdate{ClusterName: testEDSName + "1"} + wantUpdate2 := ClusterUpdate{ClusterName: testEDSName + "2"} + client.NewClusters(map[string]ClusterUpdateErrTuple{ + testCDSName + "1": {Update: wantUpdate1}, + testCDSName + "2": {Update: wantUpdate2}, }, UpdateMetadata{}) - if err := verifyClusterUpdate(ctx, clusterUpdateCh1, wantUpdate1); err != nil { + if err := verifyClusterUpdate(ctx, clusterUpdateCh1, wantUpdate1, nil); err != nil { t.Fatal(err) } - if err := verifyClusterUpdate(ctx, clusterUpdateCh2, wantUpdate2); err != nil { + if err := verifyClusterUpdate(ctx, clusterUpdateCh2, 
wantUpdate2, nil); err != nil { t.Fatal(err) } // Send another update to remove resource 1. - client.NewClusters(map[string]ClusterUpdate{testCDSName + "2": wantUpdate2}, UpdateMetadata{}) + client.NewClusters(map[string]ClusterUpdateErrTuple{testCDSName + "2": {Update: wantUpdate2}}, UpdateMetadata{}) // Watcher 1 should get an error. - if u, err := clusterUpdateCh1.Receive(ctx); err != nil || ErrType(u.(clusterUpdateErr).err) != ErrorTypeResourceNotFound { + if u, err := clusterUpdateCh1.Receive(ctx); err != nil || ErrType(u.(ClusterUpdateErrTuple).Err) != ErrorTypeResourceNotFound { t.Errorf("unexpected clusterUpdate: %v, error receiving from channel: %v, want update with error resource not found", u, err) } - // Watcher 2 should get the same update again. - if err := verifyClusterUpdate(ctx, clusterUpdateCh2, wantUpdate2); err != nil { - t.Fatal(err) + // Watcher 2 should not see an update since the resource has not changed. + sCtx, sCancel := context.WithTimeout(ctx, defaultTestShortTimeout) + defer sCancel() + if u, err := clusterUpdateCh2.Receive(sCtx); err != context.DeadlineExceeded { + t.Errorf("unexpected ClusterUpdate: %v, want receiving from channel timeout", u) } - // Send one more update without resource 1. - client.NewClusters(map[string]ClusterUpdate{testCDSName + "2": wantUpdate2}, UpdateMetadata{}) + // Send another update with resource 2 modified. Specify a non-nil raw proto + // to ensure that the new update is not considered equal to the old one. + wantUpdate2 = ClusterUpdate{ClusterName: testEDSName + "2", Raw: &anypb.Any{}} + client.NewClusters(map[string]ClusterUpdateErrTuple{testCDSName + "2": {Update: wantUpdate2}}, UpdateMetadata{}) // Watcher 1 should not see an update. 
- sCtx, sCancel := context.WithTimeout(ctx, defaultTestShortTimeout) + sCtx, sCancel = context.WithTimeout(ctx, defaultTestShortTimeout) defer sCancel() if u, err := clusterUpdateCh1.Receive(sCtx); err != context.DeadlineExceeded { - t.Errorf("unexpected clusterUpdate: %v, %v, want channel recv timeout", u, err) + t.Errorf("unexpected Cluster: %v, want receiving from channel timeout", u) + } + + // Watcher 2 should get the update. + if err := verifyClusterUpdate(ctx, clusterUpdateCh2, wantUpdate2, nil); err != nil { + t.Fatal(err) + } +} + +// TestClusterWatchNACKError covers the case that an update is NACK'ed, and the +// watcher should also receive the error. +func (s) TestClusterWatchNACKError(t *testing.T) { + apiClientCh, cleanup := overrideNewAPIClient() + defer cleanup() + + client, err := newWithConfig(clientOpts(testXDSServer, false)) + if err != nil { + t.Fatalf("failed to create client: %v", err) + } + defer client.Close() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + c, err := apiClientCh.Receive(ctx) + if err != nil { + t.Fatalf("timeout when waiting for API client to be created: %v", err) + } + apiClient := c.(*testAPIClient) + + clusterUpdateCh := testutils.NewChannel() + cancelWatch := client.WatchCluster(testCDSName, func(update ClusterUpdate, err error) { + clusterUpdateCh.Send(ClusterUpdateErrTuple{Update: update, Err: err}) + }) + defer cancelWatch() + if _, err := apiClient.addWatches[ClusterResource].Receive(ctx); err != nil { + t.Fatalf("want new watch to start, got error %v", err) + } + + wantError := fmt.Errorf("testing error") + client.NewClusters(map[string]ClusterUpdateErrTuple{testCDSName: { + Err: wantError, + }}, UpdateMetadata{ErrState: &UpdateErrorMetadata{Err: wantError}}) + if err := verifyClusterUpdate(ctx, clusterUpdateCh, ClusterUpdate{}, wantError); err != nil { + t.Fatal(err) + } +} + +// TestClusterWatchPartialValid covers the case that a response contains both +// valid 
and invalid resources. This response will be NACK'ed by the xdsclient. +// But the watchers with valid resources should receive the update, those with +// invalida resources should receive an error. +func (s) TestClusterWatchPartialValid(t *testing.T) { + apiClientCh, cleanup := overrideNewAPIClient() + defer cleanup() + + client, err := newWithConfig(clientOpts(testXDSServer, false)) + if err != nil { + t.Fatalf("failed to create client: %v", err) + } + defer client.Close() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + c, err := apiClientCh.Receive(ctx) + if err != nil { + t.Fatalf("timeout when waiting for API client to be created: %v", err) + } + apiClient := c.(*testAPIClient) + + const badResourceName = "bad-resource" + updateChs := make(map[string]*testutils.Channel) + + for _, name := range []string{testCDSName, badResourceName} { + clusterUpdateCh := testutils.NewChannel() + cancelWatch := client.WatchCluster(name, func(update ClusterUpdate, err error) { + clusterUpdateCh.Send(ClusterUpdateErrTuple{Update: update, Err: err}) + }) + defer func() { + cancelWatch() + if _, err := apiClient.removeWatches[ClusterResource].Receive(ctx); err != nil { + t.Fatalf("want watch to be canceled, got err: %v", err) + } + }() + if _, err := apiClient.addWatches[ClusterResource].Receive(ctx); err != nil { + t.Fatalf("want new watch to start, got error %v", err) + } + updateChs[name] = clusterUpdateCh + } + + wantError := fmt.Errorf("testing error") + wantError2 := fmt.Errorf("individual error") + client.NewClusters(map[string]ClusterUpdateErrTuple{ + testCDSName: {Update: ClusterUpdate{ClusterName: testEDSName}}, + badResourceName: {Err: wantError2}, + }, UpdateMetadata{ErrState: &UpdateErrorMetadata{Err: wantError}}) + + // The valid resource should be sent to the watcher. 
+ if err := verifyClusterUpdate(ctx, updateChs[testCDSName], ClusterUpdate{ClusterName: testEDSName}, nil); err != nil { + t.Fatal(err) } - // Watcher 2 should get the same update again. - if err := verifyClusterUpdate(ctx, clusterUpdateCh2, wantUpdate2); err != nil { + // The failed watcher should receive an error. + if err := verifyClusterUpdate(ctx, updateChs[badResourceName], ClusterUpdate{}, wantError2); err != nil { t.Fatal(err) } } diff --git a/xds/internal/client/watchers_endpoints_test.go b/xds/internal/xdsclient/watchers_endpoints_test.go similarity index 57% rename from xds/internal/client/watchers_endpoints_test.go rename to xds/internal/xdsclient/watchers_endpoints_test.go index b79397414d4..b87723e5086 100644 --- a/xds/internal/client/watchers_endpoints_test.go +++ b/xds/internal/xdsclient/watchers_endpoints_test.go @@ -16,13 +16,15 @@ * */ -package client +package xdsclient import ( "context" + "fmt" "testing" "github.com/google/go-cmp/cmp" + "google.golang.org/protobuf/types/known/anypb" "google.golang.org/grpc/internal/testutils" "google.golang.org/grpc/xds/internal" @@ -45,11 +47,6 @@ var ( } ) -type endpointsUpdateErr struct { - u EndpointsUpdate - err error -} - // TestEndpointsWatch covers the cases: // - an update is received after a watch() // - an update for another resource name (which doesn't trigger callback) @@ -74,30 +71,35 @@ func (s) TestEndpointsWatch(t *testing.T) { endpointsUpdateCh := testutils.NewChannel() cancelWatch := client.WatchEndpoints(testCDSName, func(update EndpointsUpdate, err error) { - endpointsUpdateCh.Send(endpointsUpdateErr{u: update, err: err}) + endpointsUpdateCh.Send(EndpointsUpdateErrTuple{Update: update, Err: err}) }) if _, err := apiClient.addWatches[EndpointsResource].Receive(ctx); err != nil { t.Fatalf("want new watch to start, got error %v", err) } wantUpdate := EndpointsUpdate{Localities: []Locality{testLocalities[0]}} - client.NewEndpoints(map[string]EndpointsUpdate{testCDSName: wantUpdate}, 
UpdateMetadata{}) - if err := verifyEndpointsUpdate(ctx, endpointsUpdateCh, wantUpdate); err != nil { + client.NewEndpoints(map[string]EndpointsUpdateErrTuple{testCDSName: {Update: wantUpdate}}, UpdateMetadata{}) + if err := verifyEndpointsUpdate(ctx, endpointsUpdateCh, wantUpdate, nil); err != nil { t.Fatal(err) } - // Another update for a different resource name. - client.NewEndpoints(map[string]EndpointsUpdate{"randomName": {}}, UpdateMetadata{}) - sCtx, sCancel := context.WithTimeout(ctx, defaultTestShortTimeout) - defer sCancel() - if u, err := endpointsUpdateCh.Receive(sCtx); err != context.DeadlineExceeded { - t.Errorf("unexpected endpointsUpdate: %v, %v, want channel recv timeout", u, err) + // Push an update, with an extra resource for a different resource name. + // Specify a non-nil raw proto in the original resource to ensure that the + // new update is not considered equal to the old one. + newUpdate := wantUpdate + newUpdate.Raw = &anypb.Any{} + client.NewEndpoints(map[string]EndpointsUpdateErrTuple{ + testCDSName: {Update: newUpdate}, + "randomName": {}, + }, UpdateMetadata{}) + if err := verifyEndpointsUpdate(ctx, endpointsUpdateCh, newUpdate, nil); err != nil { + t.Fatal(err) } // Cancel watch, and send update again. 
cancelWatch() - client.NewEndpoints(map[string]EndpointsUpdate{testCDSName: wantUpdate}, UpdateMetadata{}) - sCtx, sCancel = context.WithTimeout(ctx, defaultTestShortTimeout) + client.NewEndpoints(map[string]EndpointsUpdateErrTuple{testCDSName: {Update: wantUpdate}}, UpdateMetadata{}) + sCtx, sCancel := context.WithTimeout(ctx, defaultTestShortTimeout) defer sCancel() if u, err := endpointsUpdateCh.Receive(sCtx); err != context.DeadlineExceeded { t.Errorf("unexpected endpointsUpdate: %v, %v, want channel recv timeout", u, err) @@ -133,7 +135,7 @@ func (s) TestEndpointsTwoWatchSameResourceName(t *testing.T) { endpointsUpdateCh := testutils.NewChannel() endpointsUpdateChs = append(endpointsUpdateChs, endpointsUpdateCh) cancelLastWatch = client.WatchEndpoints(testCDSName, func(update EndpointsUpdate, err error) { - endpointsUpdateCh.Send(endpointsUpdateErr{u: update, err: err}) + endpointsUpdateCh.Send(EndpointsUpdateErrTuple{Update: update, Err: err}) }) if i == 0 { @@ -146,26 +148,35 @@ func (s) TestEndpointsTwoWatchSameResourceName(t *testing.T) { } wantUpdate := EndpointsUpdate{Localities: []Locality{testLocalities[0]}} - client.NewEndpoints(map[string]EndpointsUpdate{testCDSName: wantUpdate}, UpdateMetadata{}) + client.NewEndpoints(map[string]EndpointsUpdateErrTuple{testCDSName: {Update: wantUpdate}}, UpdateMetadata{}) for i := 0; i < count; i++ { - if err := verifyEndpointsUpdate(ctx, endpointsUpdateChs[i], wantUpdate); err != nil { + if err := verifyEndpointsUpdate(ctx, endpointsUpdateChs[i], wantUpdate, nil); err != nil { t.Fatal(err) } } - // Cancel the last watch, and send update again. + // Cancel the last watch, and send update again. None of the watchers should + // be notified because one has been cancelled, and the other is receiving + // the same update. 
cancelLastWatch() - client.NewEndpoints(map[string]EndpointsUpdate{testCDSName: wantUpdate}, UpdateMetadata{}) - for i := 0; i < count-1; i++ { - if err := verifyEndpointsUpdate(ctx, endpointsUpdateChs[i], wantUpdate); err != nil { - t.Fatal(err) - } + client.NewEndpoints(map[string]EndpointsUpdateErrTuple{testCDSName: {Update: wantUpdate}}, UpdateMetadata{}) + for i := 0; i < count; i++ { + func() { + sCtx, sCancel := context.WithTimeout(ctx, defaultTestShortTimeout) + defer sCancel() + if u, err := endpointsUpdateChs[i].Receive(sCtx); err != context.DeadlineExceeded { + t.Errorf("unexpected endpointsUpdate: %v, %v, want channel recv timeout", u, err) + } + }() } - sCtx, sCancel := context.WithTimeout(ctx, defaultTestShortTimeout) - defer sCancel() - if u, err := endpointsUpdateChs[count-1].Receive(sCtx); err != context.DeadlineExceeded { - t.Errorf("unexpected endpointsUpdate: %v, %v, want channel recv timeout", u, err) + // Push a new update and make sure the uncancelled watcher is invoked. + // Specify a non-nil raw proto to ensure that the new update is not + // considered equal to the old one. + newUpdate := EndpointsUpdate{Localities: []Locality{testLocalities[0]}, Raw: &anypb.Any{}} + client.NewEndpoints(map[string]EndpointsUpdateErrTuple{testCDSName: {Update: newUpdate}}, UpdateMetadata{}) + if err := verifyEndpointsUpdate(ctx, endpointsUpdateChs[0], newUpdate, nil); err != nil { + t.Fatal(err) } } @@ -196,7 +207,7 @@ func (s) TestEndpointsThreeWatchDifferentResourceName(t *testing.T) { endpointsUpdateCh := testutils.NewChannel() endpointsUpdateChs = append(endpointsUpdateChs, endpointsUpdateCh) client.WatchEndpoints(testCDSName+"1", func(update EndpointsUpdate, err error) { - endpointsUpdateCh.Send(endpointsUpdateErr{u: update, err: err}) + endpointsUpdateCh.Send(EndpointsUpdateErrTuple{Update: update, Err: err}) }) if i == 0 { @@ -211,7 +222,7 @@ func (s) TestEndpointsThreeWatchDifferentResourceName(t *testing.T) { // Third watch for a different name. 
endpointsUpdateCh2 := testutils.NewChannel() client.WatchEndpoints(testCDSName+"2", func(update EndpointsUpdate, err error) { - endpointsUpdateCh2.Send(endpointsUpdateErr{u: update, err: err}) + endpointsUpdateCh2.Send(EndpointsUpdateErrTuple{Update: update, Err: err}) }) if _, err := apiClient.addWatches[EndpointsResource].Receive(ctx); err != nil { t.Fatalf("want new watch to start, got error %v", err) @@ -219,17 +230,17 @@ func (s) TestEndpointsThreeWatchDifferentResourceName(t *testing.T) { wantUpdate1 := EndpointsUpdate{Localities: []Locality{testLocalities[0]}} wantUpdate2 := EndpointsUpdate{Localities: []Locality{testLocalities[1]}} - client.NewEndpoints(map[string]EndpointsUpdate{ - testCDSName + "1": wantUpdate1, - testCDSName + "2": wantUpdate2, + client.NewEndpoints(map[string]EndpointsUpdateErrTuple{ + testCDSName + "1": {Update: wantUpdate1}, + testCDSName + "2": {Update: wantUpdate2}, }, UpdateMetadata{}) for i := 0; i < count; i++ { - if err := verifyEndpointsUpdate(ctx, endpointsUpdateChs[i], wantUpdate1); err != nil { + if err := verifyEndpointsUpdate(ctx, endpointsUpdateChs[i], wantUpdate1, nil); err != nil { t.Fatal(err) } } - if err := verifyEndpointsUpdate(ctx, endpointsUpdateCh2, wantUpdate2); err != nil { + if err := verifyEndpointsUpdate(ctx, endpointsUpdateCh2, wantUpdate2, nil); err != nil { t.Fatal(err) } } @@ -256,22 +267,22 @@ func (s) TestEndpointsWatchAfterCache(t *testing.T) { endpointsUpdateCh := testutils.NewChannel() client.WatchEndpoints(testCDSName, func(update EndpointsUpdate, err error) { - endpointsUpdateCh.Send(endpointsUpdateErr{u: update, err: err}) + endpointsUpdateCh.Send(EndpointsUpdateErrTuple{Update: update, Err: err}) }) if _, err := apiClient.addWatches[EndpointsResource].Receive(ctx); err != nil { t.Fatalf("want new watch to start, got error %v", err) } wantUpdate := EndpointsUpdate{Localities: []Locality{testLocalities[0]}} - client.NewEndpoints(map[string]EndpointsUpdate{testCDSName: wantUpdate}, 
UpdateMetadata{}) - if err := verifyEndpointsUpdate(ctx, endpointsUpdateCh, wantUpdate); err != nil { + client.NewEndpoints(map[string]EndpointsUpdateErrTuple{testCDSName: {Update: wantUpdate}}, UpdateMetadata{}) + if err := verifyEndpointsUpdate(ctx, endpointsUpdateCh, wantUpdate, nil); err != nil { t.Fatal(err) } // Another watch for the resource in cache. endpointsUpdateCh2 := testutils.NewChannel() client.WatchEndpoints(testCDSName, func(update EndpointsUpdate, err error) { - endpointsUpdateCh2.Send(endpointsUpdateErr{u: update, err: err}) + endpointsUpdateCh2.Send(EndpointsUpdateErrTuple{Update: update, Err: err}) }) sCtx, sCancel := context.WithTimeout(ctx, defaultTestShortTimeout) defer sCancel() @@ -280,7 +291,7 @@ func (s) TestEndpointsWatchAfterCache(t *testing.T) { } // New watch should receives the update. - if err := verifyEndpointsUpdate(ctx, endpointsUpdateCh2, wantUpdate); err != nil { + if err := verifyEndpointsUpdate(ctx, endpointsUpdateCh2, wantUpdate, nil); err != nil { t.Fatal(err) } @@ -315,7 +326,7 @@ func (s) TestEndpointsWatchExpiryTimer(t *testing.T) { endpointsUpdateCh := testutils.NewChannel() client.WatchEndpoints(testCDSName, func(update EndpointsUpdate, err error) { - endpointsUpdateCh.Send(endpointsUpdateErr{u: update, err: err}) + endpointsUpdateCh.Send(EndpointsUpdateErrTuple{Update: update, Err: err}) }) if _, err := apiClient.addWatches[EndpointsResource].Receive(ctx); err != nil { t.Fatalf("want new watch to start, got error %v", err) @@ -325,8 +336,104 @@ func (s) TestEndpointsWatchExpiryTimer(t *testing.T) { if err != nil { t.Fatalf("timeout when waiting for endpoints update: %v", err) } - gotUpdate := u.(endpointsUpdateErr) - if gotUpdate.err == nil || !cmp.Equal(gotUpdate.u, EndpointsUpdate{}) { - t.Fatalf("unexpected endpointsUpdate: (%v, %v), want: (EndpointsUpdate{}, nil)", gotUpdate.u, gotUpdate.err) + gotUpdate := u.(EndpointsUpdateErrTuple) + if gotUpdate.Err == nil || !cmp.Equal(gotUpdate.Update, EndpointsUpdate{}) { 
+ t.Fatalf("unexpected endpointsUpdate: (%v, %v), want: (EndpointsUpdate{}, nil)", gotUpdate.Update, gotUpdate.Err) + } +} + +// TestEndpointsWatchNACKError covers the case that an update is NACK'ed, and +// the watcher should also receive the error. +func (s) TestEndpointsWatchNACKError(t *testing.T) { + apiClientCh, cleanup := overrideNewAPIClient() + defer cleanup() + + client, err := newWithConfig(clientOpts(testXDSServer, false)) + if err != nil { + t.Fatalf("failed to create client: %v", err) + } + defer client.Close() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + c, err := apiClientCh.Receive(ctx) + if err != nil { + t.Fatalf("timeout when waiting for API client to be created: %v", err) + } + apiClient := c.(*testAPIClient) + + endpointsUpdateCh := testutils.NewChannel() + cancelWatch := client.WatchEndpoints(testCDSName, func(update EndpointsUpdate, err error) { + endpointsUpdateCh.Send(EndpointsUpdateErrTuple{Update: update, Err: err}) + }) + defer cancelWatch() + if _, err := apiClient.addWatches[EndpointsResource].Receive(ctx); err != nil { + t.Fatalf("want new watch to start, got error %v", err) + } + + wantError := fmt.Errorf("testing error") + client.NewEndpoints(map[string]EndpointsUpdateErrTuple{testCDSName: {Err: wantError}}, UpdateMetadata{ErrState: &UpdateErrorMetadata{Err: wantError}}) + if err := verifyEndpointsUpdate(ctx, endpointsUpdateCh, EndpointsUpdate{}, wantError); err != nil { + t.Fatal(err) + } +} + +// TestEndpointsWatchPartialValid covers the case that a response contains both +// valid and invalid resources. This response will be NACK'ed by the xdsclient. +// But the watchers with valid resources should receive the update, those with +// invalida resources should receive an error. 
+func (s) TestEndpointsWatchPartialValid(t *testing.T) { + apiClientCh, cleanup := overrideNewAPIClient() + defer cleanup() + + client, err := newWithConfig(clientOpts(testXDSServer, false)) + if err != nil { + t.Fatalf("failed to create client: %v", err) + } + defer client.Close() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + c, err := apiClientCh.Receive(ctx) + if err != nil { + t.Fatalf("timeout when waiting for API client to be created: %v", err) + } + apiClient := c.(*testAPIClient) + + const badResourceName = "bad-resource" + updateChs := make(map[string]*testutils.Channel) + + for _, name := range []string{testCDSName, badResourceName} { + endpointsUpdateCh := testutils.NewChannel() + cancelWatch := client.WatchEndpoints(name, func(update EndpointsUpdate, err error) { + endpointsUpdateCh.Send(EndpointsUpdateErrTuple{Update: update, Err: err}) + }) + defer func() { + cancelWatch() + if _, err := apiClient.removeWatches[EndpointsResource].Receive(ctx); err != nil { + t.Fatalf("want watch to be canceled, got err: %v", err) + } + }() + if _, err := apiClient.addWatches[EndpointsResource].Receive(ctx); err != nil { + t.Fatalf("want new watch to start, got error %v", err) + } + updateChs[name] = endpointsUpdateCh + } + + wantError := fmt.Errorf("testing error") + wantError2 := fmt.Errorf("individual error") + client.NewEndpoints(map[string]EndpointsUpdateErrTuple{ + testCDSName: {Update: EndpointsUpdate{Localities: []Locality{testLocalities[0]}}}, + badResourceName: {Err: wantError2}, + }, UpdateMetadata{ErrState: &UpdateErrorMetadata{Err: wantError}}) + + // The valid resource should be sent to the watcher. + if err := verifyEndpointsUpdate(ctx, updateChs[testCDSName], EndpointsUpdate{Localities: []Locality{testLocalities[0]}}, nil); err != nil { + t.Fatal(err) + } + + // The failed watcher should receive an error. 
+ if err := verifyEndpointsUpdate(ctx, updateChs[badResourceName], EndpointsUpdate{}, wantError2); err != nil { + t.Fatal(err) } } diff --git a/xds/internal/xdsclient/watchers_listener_test.go b/xds/internal/xdsclient/watchers_listener_test.go new file mode 100644 index 00000000000..176e6bbcb7b --- /dev/null +++ b/xds/internal/xdsclient/watchers_listener_test.go @@ -0,0 +1,591 @@ +/* + * + * Copyright 2020 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package xdsclient + +import ( + "context" + "fmt" + "testing" + + v3listenerpb "github.com/envoyproxy/go-control-plane/envoy/config/listener/v3" + v3httppb "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/network/http_connection_manager/v3" + "google.golang.org/grpc/internal/testutils" + "google.golang.org/protobuf/types/known/anypb" +) + +// TestLDSWatch covers the cases: +// - an update is received after a watch() +// - an update for another resource name +// - an update is received after cancel() +func (s) TestLDSWatch(t *testing.T) { + apiClientCh, cleanup := overrideNewAPIClient() + defer cleanup() + + client, err := newWithConfig(clientOpts(testXDSServer, false)) + if err != nil { + t.Fatalf("failed to create client: %v", err) + } + defer client.Close() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + c, err := apiClientCh.Receive(ctx) + if err != nil { + t.Fatalf("timeout when waiting for API client to be 
created: %v", err) + } + apiClient := c.(*testAPIClient) + + ldsUpdateCh := testutils.NewChannel() + cancelWatch := client.WatchListener(testLDSName, func(update ListenerUpdate, err error) { + ldsUpdateCh.Send(ListenerUpdateErrTuple{Update: update, Err: err}) + }) + if _, err := apiClient.addWatches[ListenerResource].Receive(ctx); err != nil { + t.Fatalf("want new watch to start, got error %v", err) + } + + wantUpdate := ListenerUpdate{RouteConfigName: testRDSName} + client.NewListeners(map[string]ListenerUpdateErrTuple{testLDSName: {Update: wantUpdate}}, UpdateMetadata{}) + if err := verifyListenerUpdate(ctx, ldsUpdateCh, wantUpdate, nil); err != nil { + t.Fatal(err) + } + + // Push an update, with an extra resource for a different resource name. + // Specify a non-nil raw proto in the original resource to ensure that the + // new update is not considered equal to the old one. + newUpdate := ListenerUpdate{RouteConfigName: testRDSName, Raw: &anypb.Any{}} + client.NewListeners(map[string]ListenerUpdateErrTuple{ + testLDSName: {Update: newUpdate}, + "randomName": {}, + }, UpdateMetadata{}) + if err := verifyListenerUpdate(ctx, ldsUpdateCh, newUpdate, nil); err != nil { + t.Fatal(err) + } + + // Cancel watch, and send update again. + cancelWatch() + client.NewListeners(map[string]ListenerUpdateErrTuple{testLDSName: {Update: wantUpdate}}, UpdateMetadata{}) + sCtx, sCancel := context.WithTimeout(ctx, defaultTestShortTimeout) + defer sCancel() + if u, err := ldsUpdateCh.Receive(sCtx); err != context.DeadlineExceeded { + t.Fatalf("unexpected ListenerUpdate: %v, %v, want channel recv timeout", u, err) + } +} + +// TestLDSTwoWatchSameResourceName covers the case where an update is received +// after two watch() for the same resource name. 
+func (s) TestLDSTwoWatchSameResourceName(t *testing.T) { + apiClientCh, cleanup := overrideNewAPIClient() + defer cleanup() + + client, err := newWithConfig(clientOpts(testXDSServer, false)) + if err != nil { + t.Fatalf("failed to create client: %v", err) + } + defer client.Close() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + c, err := apiClientCh.Receive(ctx) + if err != nil { + t.Fatalf("timeout when waiting for API client to be created: %v", err) + } + apiClient := c.(*testAPIClient) + + const count = 2 + var ( + ldsUpdateChs []*testutils.Channel + cancelLastWatch func() + ) + + for i := 0; i < count; i++ { + ldsUpdateCh := testutils.NewChannel() + ldsUpdateChs = append(ldsUpdateChs, ldsUpdateCh) + cancelLastWatch = client.WatchListener(testLDSName, func(update ListenerUpdate, err error) { + ldsUpdateCh.Send(ListenerUpdateErrTuple{Update: update, Err: err}) + }) + + if i == 0 { + // A new watch is registered on the underlying API client only for + // the first iteration because we are using the same resource name. + if _, err := apiClient.addWatches[ListenerResource].Receive(ctx); err != nil { + t.Fatalf("want new watch to start, got error %v", err) + } + } + } + + wantUpdate := ListenerUpdate{RouteConfigName: testRDSName} + client.NewListeners(map[string]ListenerUpdateErrTuple{testLDSName: {Update: wantUpdate}}, UpdateMetadata{}) + for i := 0; i < count; i++ { + if err := verifyListenerUpdate(ctx, ldsUpdateChs[i], wantUpdate, nil); err != nil { + t.Fatal(err) + } + } + + // Cancel the last watch, and send update again. None of the watchers should + // be notified because one has been cancelled, and the other is receiving + // the same update. 
+ cancelLastWatch() + client.NewListeners(map[string]ListenerUpdateErrTuple{testLDSName: {Update: wantUpdate}}, UpdateMetadata{}) + for i := 0; i < count; i++ { + func() { + sCtx, sCancel := context.WithTimeout(ctx, defaultTestShortTimeout) + defer sCancel() + if u, err := ldsUpdateChs[i].Receive(sCtx); err != context.DeadlineExceeded { + t.Errorf("unexpected ListenerUpdate: %v, %v, want channel recv timeout", u, err) + } + }() + } + + // Push a new update and make sure the uncancelled watcher is invoked. + // Specify a non-nil raw proto to ensure that the new update is not + // considered equal to the old one. + newUpdate := ListenerUpdate{RouteConfigName: testRDSName, Raw: &anypb.Any{}} + client.NewListeners(map[string]ListenerUpdateErrTuple{testLDSName: {Update: newUpdate}}, UpdateMetadata{}) + if err := verifyListenerUpdate(ctx, ldsUpdateChs[0], newUpdate, nil); err != nil { + t.Fatal(err) + } +} + +// TestLDSThreeWatchDifferentResourceName covers the case where an update is +// received after three watch() for different resource names. +func (s) TestLDSThreeWatchDifferentResourceName(t *testing.T) { + apiClientCh, cleanup := overrideNewAPIClient() + defer cleanup() + + client, err := newWithConfig(clientOpts(testXDSServer, false)) + if err != nil { + t.Fatalf("failed to create client: %v", err) + } + defer client.Close() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + c, err := apiClientCh.Receive(ctx) + if err != nil { + t.Fatalf("timeout when waiting for API client to be created: %v", err) + } + apiClient := c.(*testAPIClient) + + var ldsUpdateChs []*testutils.Channel + const count = 2 + + // Two watches for the same name. 
+ for i := 0; i < count; i++ { + ldsUpdateCh := testutils.NewChannel() + ldsUpdateChs = append(ldsUpdateChs, ldsUpdateCh) + client.WatchListener(testLDSName+"1", func(update ListenerUpdate, err error) { + ldsUpdateCh.Send(ListenerUpdateErrTuple{Update: update, Err: err}) + }) + + if i == 0 { + // A new watch is registered on the underlying API client only for + // the first iteration because we are using the same resource name. + if _, err := apiClient.addWatches[ListenerResource].Receive(ctx); err != nil { + t.Fatalf("want new watch to start, got error %v", err) + } + } + } + + // Third watch for a different name. + ldsUpdateCh2 := testutils.NewChannel() + client.WatchListener(testLDSName+"2", func(update ListenerUpdate, err error) { + ldsUpdateCh2.Send(ListenerUpdateErrTuple{Update: update, Err: err}) + }) + if _, err := apiClient.addWatches[ListenerResource].Receive(ctx); err != nil { + t.Fatalf("want new watch to start, got error %v", err) + } + + wantUpdate1 := ListenerUpdate{RouteConfigName: testRDSName + "1"} + wantUpdate2 := ListenerUpdate{RouteConfigName: testRDSName + "2"} + client.NewListeners(map[string]ListenerUpdateErrTuple{ + testLDSName + "1": {Update: wantUpdate1}, + testLDSName + "2": {Update: wantUpdate2}, + }, UpdateMetadata{}) + + for i := 0; i < count; i++ { + if err := verifyListenerUpdate(ctx, ldsUpdateChs[i], wantUpdate1, nil); err != nil { + t.Fatal(err) + } + } + if err := verifyListenerUpdate(ctx, ldsUpdateCh2, wantUpdate2, nil); err != nil { + t.Fatal(err) + } +} + +// TestLDSWatchAfterCache covers the case where watch is called after the update +// is in cache. 
+func (s) TestLDSWatchAfterCache(t *testing.T) { + apiClientCh, cleanup := overrideNewAPIClient() + defer cleanup() + + client, err := newWithConfig(clientOpts(testXDSServer, false)) + if err != nil { + t.Fatalf("failed to create client: %v", err) + } + defer client.Close() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + c, err := apiClientCh.Receive(ctx) + if err != nil { + t.Fatalf("timeout when waiting for API client to be created: %v", err) + } + apiClient := c.(*testAPIClient) + + ldsUpdateCh := testutils.NewChannel() + client.WatchListener(testLDSName, func(update ListenerUpdate, err error) { + ldsUpdateCh.Send(ListenerUpdateErrTuple{Update: update, Err: err}) + }) + if _, err := apiClient.addWatches[ListenerResource].Receive(ctx); err != nil { + t.Fatalf("want new watch to start, got error %v", err) + } + + wantUpdate := ListenerUpdate{RouteConfigName: testRDSName} + client.NewListeners(map[string]ListenerUpdateErrTuple{testLDSName: {Update: wantUpdate}}, UpdateMetadata{}) + if err := verifyListenerUpdate(ctx, ldsUpdateCh, wantUpdate, nil); err != nil { + t.Fatal(err) + } + + // Another watch for the resource in cache. + ldsUpdateCh2 := testutils.NewChannel() + client.WatchListener(testLDSName, func(update ListenerUpdate, err error) { + ldsUpdateCh2.Send(ListenerUpdateErrTuple{Update: update, Err: err}) + }) + sCtx, sCancel := context.WithTimeout(ctx, defaultTestShortTimeout) + defer sCancel() + if n, err := apiClient.addWatches[ListenerResource].Receive(sCtx); err != context.DeadlineExceeded { + t.Fatalf("want no new watch to start (recv timeout), got resource name: %v error %v", n, err) + } + + // New watch should receive the update. + if err := verifyListenerUpdate(ctx, ldsUpdateCh2, wantUpdate, nil); err != nil { + t.Fatal(err) + } + + // Old watch should see nothing. 
+ sCtx, sCancel = context.WithTimeout(ctx, defaultTestShortTimeout) + defer sCancel() + if u, err := ldsUpdateCh.Receive(sCtx); err != context.DeadlineExceeded { + t.Errorf("unexpected ListenerUpdate: %v, %v, want channel recv timeout", u, err) + } +} + +// TestLDSResourceRemoved covers the cases: +// - an update is received after a watch() +// - another update is received, with one resource removed +// - this should trigger callback with resource removed error +// - one more update without the removed resource +// - the callback (above) shouldn't receive any update +func (s) TestLDSResourceRemoved(t *testing.T) { + apiClientCh, cleanup := overrideNewAPIClient() + defer cleanup() + + client, err := newWithConfig(clientOpts(testXDSServer, false)) + if err != nil { + t.Fatalf("failed to create client: %v", err) + } + defer client.Close() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + c, err := apiClientCh.Receive(ctx) + if err != nil { + t.Fatalf("timeout when waiting for API client to be created: %v", err) + } + apiClient := c.(*testAPIClient) + + ldsUpdateCh1 := testutils.NewChannel() + client.WatchListener(testLDSName+"1", func(update ListenerUpdate, err error) { + ldsUpdateCh1.Send(ListenerUpdateErrTuple{Update: update, Err: err}) + }) + if _, err := apiClient.addWatches[ListenerResource].Receive(ctx); err != nil { + t.Fatalf("want new watch to start, got error %v", err) + } + // Another watch for a different name. 
+ ldsUpdateCh2 := testutils.NewChannel() + client.WatchListener(testLDSName+"2", func(update ListenerUpdate, err error) { + ldsUpdateCh2.Send(ListenerUpdateErrTuple{Update: update, Err: err}) + }) + if _, err := apiClient.addWatches[ListenerResource].Receive(ctx); err != nil { + t.Fatalf("want new watch to start, got error %v", err) + } + + wantUpdate1 := ListenerUpdate{RouteConfigName: testEDSName + "1"} + wantUpdate2 := ListenerUpdate{RouteConfigName: testEDSName + "2"} + client.NewListeners(map[string]ListenerUpdateErrTuple{ + testLDSName + "1": {Update: wantUpdate1}, + testLDSName + "2": {Update: wantUpdate2}, + }, UpdateMetadata{}) + if err := verifyListenerUpdate(ctx, ldsUpdateCh1, wantUpdate1, nil); err != nil { + t.Fatal(err) + } + if err := verifyListenerUpdate(ctx, ldsUpdateCh2, wantUpdate2, nil); err != nil { + t.Fatal(err) + } + + // Send another update to remove resource 1. + client.NewListeners(map[string]ListenerUpdateErrTuple{testLDSName + "2": {Update: wantUpdate2}}, UpdateMetadata{}) + + // Watcher 1 should get an error. + if u, err := ldsUpdateCh1.Receive(ctx); err != nil || ErrType(u.(ListenerUpdateErrTuple).Err) != ErrorTypeResourceNotFound { + t.Errorf("unexpected ListenerUpdate: %v, error receiving from channel: %v, want update with error resource not found", u, err) + } + + // Watcher 2 should not see an update since the resource has not changed. + sCtx, sCancel := context.WithTimeout(ctx, defaultTestShortTimeout) + defer sCancel() + if u, err := ldsUpdateCh2.Receive(sCtx); err != context.DeadlineExceeded { + t.Errorf("unexpected ListenerUpdate: %v, want receiving from channel timeout", u) + } + + // Send another update with resource 2 modified. Specify a non-nil raw proto + // to ensure that the new update is not considered equal to the old one. 
+ wantUpdate2 = ListenerUpdate{RouteConfigName: testEDSName + "2", Raw: &anypb.Any{}} + client.NewListeners(map[string]ListenerUpdateErrTuple{testLDSName + "2": {Update: wantUpdate2}}, UpdateMetadata{}) + + // Watcher 1 should not see an update. + sCtx, sCancel = context.WithTimeout(ctx, defaultTestShortTimeout) + defer sCancel() + if u, err := ldsUpdateCh1.Receive(sCtx); err != context.DeadlineExceeded { + t.Errorf("unexpected ListenerUpdate: %v, want receiving from channel timeout", u) + } + + // Watcher 2 should get the update. + if err := verifyListenerUpdate(ctx, ldsUpdateCh2, wantUpdate2, nil); err != nil { + t.Fatal(err) + } +} + +// TestListenerWatchNACKError covers the case that an update is NACK'ed, and the +// watcher should also receive the error. +func (s) TestListenerWatchNACKError(t *testing.T) { + apiClientCh, cleanup := overrideNewAPIClient() + defer cleanup() + + client, err := newWithConfig(clientOpts(testXDSServer, false)) + if err != nil { + t.Fatalf("failed to create client: %v", err) + } + defer client.Close() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + c, err := apiClientCh.Receive(ctx) + if err != nil { + t.Fatalf("timeout when waiting for API client to be created: %v", err) + } + apiClient := c.(*testAPIClient) + + ldsUpdateCh := testutils.NewChannel() + cancelWatch := client.WatchListener(testLDSName, func(update ListenerUpdate, err error) { + ldsUpdateCh.Send(ListenerUpdateErrTuple{Update: update, Err: err}) + }) + defer cancelWatch() + if _, err := apiClient.addWatches[ListenerResource].Receive(ctx); err != nil { + t.Fatalf("want new watch to start, got error %v", err) + } + + wantError := fmt.Errorf("testing error") + client.NewListeners(map[string]ListenerUpdateErrTuple{testLDSName: {Err: wantError}}, UpdateMetadata{ErrState: &UpdateErrorMetadata{Err: wantError}}) + if err := verifyListenerUpdate(ctx, ldsUpdateCh, ListenerUpdate{}, wantError); err != nil { + t.Fatal(err) + } +} + 
+// TestListenerWatchPartialValid covers the case that a response contains both +// valid and invalid resources. This response will be NACK'ed by the xdsclient. +// But the watchers with valid resources should receive the update, those with +// invalida resources should receive an error. +func (s) TestListenerWatchPartialValid(t *testing.T) { + apiClientCh, cleanup := overrideNewAPIClient() + defer cleanup() + + client, err := newWithConfig(clientOpts(testXDSServer, false)) + if err != nil { + t.Fatalf("failed to create client: %v", err) + } + defer client.Close() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + c, err := apiClientCh.Receive(ctx) + if err != nil { + t.Fatalf("timeout when waiting for API client to be created: %v", err) + } + apiClient := c.(*testAPIClient) + + const badResourceName = "bad-resource" + updateChs := make(map[string]*testutils.Channel) + + for _, name := range []string{testLDSName, badResourceName} { + ldsUpdateCh := testutils.NewChannel() + cancelWatch := client.WatchListener(name, func(update ListenerUpdate, err error) { + ldsUpdateCh.Send(ListenerUpdateErrTuple{Update: update, Err: err}) + }) + defer func() { + cancelWatch() + if _, err := apiClient.removeWatches[ListenerResource].Receive(ctx); err != nil { + t.Fatalf("want watch to be canceled, got err: %v", err) + } + }() + if _, err := apiClient.addWatches[ListenerResource].Receive(ctx); err != nil { + t.Fatalf("want new watch to start, got error %v", err) + } + updateChs[name] = ldsUpdateCh + } + + wantError := fmt.Errorf("testing error") + wantError2 := fmt.Errorf("individual error") + client.NewListeners(map[string]ListenerUpdateErrTuple{ + testLDSName: {Update: ListenerUpdate{RouteConfigName: testEDSName}}, + badResourceName: {Err: wantError2}, + }, UpdateMetadata{ErrState: &UpdateErrorMetadata{Err: wantError}}) + + // The valid resource should be sent to the watcher. 
+ if err := verifyListenerUpdate(ctx, updateChs[testLDSName], ListenerUpdate{RouteConfigName: testEDSName}, nil); err != nil { + t.Fatal(err) + } + + // The failed watcher should receive an error. + if err := verifyListenerUpdate(ctx, updateChs[badResourceName], ListenerUpdate{}, wantError2); err != nil { + t.Fatal(err) + } +} + +// TestListenerWatch_RedundantUpdateSupression tests scenarios where an update +// with an unmodified resource is suppressed, and modified resource is not. +func (s) TestListenerWatch_RedundantUpdateSupression(t *testing.T) { + apiClientCh, cleanup := overrideNewAPIClient() + defer cleanup() + + client, err := newWithConfig(clientOpts(testXDSServer, false)) + if err != nil { + t.Fatalf("failed to create client: %v", err) + } + defer client.Close() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + c, err := apiClientCh.Receive(ctx) + if err != nil { + t.Fatalf("timeout when waiting for API client to be created: %v", err) + } + apiClient := c.(*testAPIClient) + + ldsUpdateCh := testutils.NewChannel() + client.WatchListener(testLDSName, func(update ListenerUpdate, err error) { + ldsUpdateCh.Send(ListenerUpdateErrTuple{Update: update, Err: err}) + }) + if _, err := apiClient.addWatches[ListenerResource].Receive(ctx); err != nil { + t.Fatalf("want new watch to start, got error %v", err) + } + + basicListener := testutils.MarshalAny(&v3listenerpb.Listener{ + Name: testLDSName, + ApiListener: &v3listenerpb.ApiListener{ + ApiListener: testutils.MarshalAny(&v3httppb.HttpConnectionManager{ + RouteSpecifier: &v3httppb.HttpConnectionManager_Rds{ + Rds: &v3httppb.Rds{RouteConfigName: "route-config-name"}, + }, + }), + }, + }) + listenerWithFilter1 := testutils.MarshalAny(&v3listenerpb.Listener{ + Name: testLDSName, + ApiListener: &v3listenerpb.ApiListener{ + ApiListener: testutils.MarshalAny(&v3httppb.HttpConnectionManager{ + RouteSpecifier: &v3httppb.HttpConnectionManager_Rds{ + Rds: 
&v3httppb.Rds{RouteConfigName: "route-config-name"}, + }, + HttpFilters: []*v3httppb.HttpFilter{ + { + Name: "customFilter1", + ConfigType: &v3httppb.HttpFilter_TypedConfig{TypedConfig: customFilterConfig}, + }, + }, + }), + }, + }) + listenerWithFilter2 := testutils.MarshalAny(&v3listenerpb.Listener{ + Name: testLDSName, + ApiListener: &v3listenerpb.ApiListener{ + ApiListener: testutils.MarshalAny(&v3httppb.HttpConnectionManager{ + RouteSpecifier: &v3httppb.HttpConnectionManager_Rds{ + Rds: &v3httppb.Rds{RouteConfigName: "route-config-name"}, + }, + HttpFilters: []*v3httppb.HttpFilter{ + { + Name: "customFilter2", + ConfigType: &v3httppb.HttpFilter_TypedConfig{TypedConfig: customFilterConfig}, + }, + }, + }), + }, + }) + + tests := []struct { + update ListenerUpdate + wantCallback bool + }{ + { + // First update. Callback should be invoked. + update: ListenerUpdate{Raw: basicListener}, + wantCallback: true, + }, + { + // Same update as previous. Callback should be skipped. + update: ListenerUpdate{Raw: basicListener}, + wantCallback: false, + }, + { + // New update. Callback should be invoked. + update: ListenerUpdate{Raw: listenerWithFilter1}, + wantCallback: true, + }, + { + // Same update as previous. Callback should be skipped. + update: ListenerUpdate{Raw: listenerWithFilter1}, + wantCallback: false, + }, + { + // New update. Callback should be invoked. + update: ListenerUpdate{Raw: listenerWithFilter2}, + wantCallback: true, + }, + { + // Same update as previous. Callback should be skipped. 
+ update: ListenerUpdate{Raw: listenerWithFilter2}, + wantCallback: false, + }, + } + for _, test := range tests { + client.NewListeners(map[string]ListenerUpdateErrTuple{testLDSName: {Update: test.update}}, UpdateMetadata{}) + if test.wantCallback { + if err := verifyListenerUpdate(ctx, ldsUpdateCh, test.update, nil); err != nil { + t.Fatal(err) + } + } else { + sCtx, sCancel := context.WithTimeout(ctx, defaultTestShortTimeout) + defer sCancel() + if u, err := ldsUpdateCh.Receive(sCtx); err != context.DeadlineExceeded { + t.Errorf("unexpected ListenerUpdate: %v, want receiving from channel timeout", u) + } + } + } +} diff --git a/xds/internal/client/watchers_route_test.go b/xds/internal/xdsclient/watchers_route_test.go similarity index 54% rename from xds/internal/client/watchers_route_test.go rename to xds/internal/xdsclient/watchers_route_test.go index 5f44e549333..70c8dd829e9 100644 --- a/xds/internal/client/watchers_route_test.go +++ b/xds/internal/xdsclient/watchers_route_test.go @@ -16,22 +16,19 @@ * */ -package client +package xdsclient import ( "context" + "fmt" "testing" "github.com/google/go-cmp/cmp" + "google.golang.org/protobuf/types/known/anypb" "google.golang.org/grpc/internal/testutils" ) -type rdsUpdateErr struct { - u RouteConfigUpdate - err error -} - // TestRDSWatch covers the cases: // - an update is received after a watch() // - an update for another resource name (which doesn't trigger callback) @@ -56,7 +53,7 @@ func (s) TestRDSWatch(t *testing.T) { rdsUpdateCh := testutils.NewChannel() cancelWatch := client.WatchRouteConfig(testRDSName, func(update RouteConfigUpdate, err error) { - rdsUpdateCh.Send(rdsUpdateErr{u: update, err: err}) + rdsUpdateCh.Send(RouteConfigUpdateErrTuple{Update: update, Err: err}) }) if _, err := apiClient.addWatches[RouteConfigResource].Receive(ctx); err != nil { t.Fatalf("want new watch to start, got error %v", err) @@ -70,23 +67,28 @@ func (s) TestRDSWatch(t *testing.T) { }, }, } - 
client.NewRouteConfigs(map[string]RouteConfigUpdate{testRDSName: wantUpdate}, UpdateMetadata{}) - if err := verifyRouteConfigUpdate(ctx, rdsUpdateCh, wantUpdate); err != nil { + client.NewRouteConfigs(map[string]RouteConfigUpdateErrTuple{testRDSName: {Update: wantUpdate}}, UpdateMetadata{}) + if err := verifyRouteConfigUpdate(ctx, rdsUpdateCh, wantUpdate, nil); err != nil { t.Fatal(err) } - // Another update for a different resource name. - client.NewRouteConfigs(map[string]RouteConfigUpdate{"randomName": {}}, UpdateMetadata{}) - sCtx, sCancel := context.WithTimeout(ctx, defaultTestShortTimeout) - defer sCancel() - if u, err := rdsUpdateCh.Receive(sCtx); err != context.DeadlineExceeded { - t.Errorf("unexpected RouteConfigUpdate: %v, %v, want channel recv timeout", u, err) + // Push an update, with an extra resource for a different resource name. + // Specify a non-nil raw proto in the original resource to ensure that the + // new update is not considered equal to the old one. + newUpdate := wantUpdate + newUpdate.Raw = &anypb.Any{} + client.NewRouteConfigs(map[string]RouteConfigUpdateErrTuple{ + testRDSName: {Update: newUpdate}, + "randomName": {}, + }, UpdateMetadata{}) + if err := verifyRouteConfigUpdate(ctx, rdsUpdateCh, newUpdate, nil); err != nil { + t.Fatal(err) } // Cancel watch, and send update again. 
cancelWatch() - client.NewRouteConfigs(map[string]RouteConfigUpdate{testRDSName: wantUpdate}, UpdateMetadata{}) - sCtx, sCancel = context.WithTimeout(ctx, defaultTestShortTimeout) + client.NewRouteConfigs(map[string]RouteConfigUpdateErrTuple{testRDSName: {Update: wantUpdate}}, UpdateMetadata{}) + sCtx, sCancel := context.WithTimeout(ctx, defaultTestShortTimeout) defer sCancel() if u, err := rdsUpdateCh.Receive(sCtx); err != context.DeadlineExceeded { t.Errorf("unexpected RouteConfigUpdate: %v, %v, want channel recv timeout", u, err) @@ -122,7 +124,7 @@ func (s) TestRDSTwoWatchSameResourceName(t *testing.T) { rdsUpdateCh := testutils.NewChannel() rdsUpdateChs = append(rdsUpdateChs, rdsUpdateCh) cancelLastWatch = client.WatchRouteConfig(testRDSName, func(update RouteConfigUpdate, err error) { - rdsUpdateCh.Send(rdsUpdateErr{u: update, err: err}) + rdsUpdateCh.Send(RouteConfigUpdateErrTuple{Update: update, Err: err}) }) if i == 0 { @@ -142,26 +144,36 @@ func (s) TestRDSTwoWatchSameResourceName(t *testing.T) { }, }, } - client.NewRouteConfigs(map[string]RouteConfigUpdate{testRDSName: wantUpdate}, UpdateMetadata{}) + client.NewRouteConfigs(map[string]RouteConfigUpdateErrTuple{testRDSName: {Update: wantUpdate}}, UpdateMetadata{}) for i := 0; i < count; i++ { - if err := verifyRouteConfigUpdate(ctx, rdsUpdateChs[i], wantUpdate); err != nil { + if err := verifyRouteConfigUpdate(ctx, rdsUpdateChs[i], wantUpdate, nil); err != nil { t.Fatal(err) } } - // Cancel the last watch, and send update again. + // Cancel the last watch, and send update again. None of the watchers should + // be notified because one has been cancelled, and the other is receiving + // the same update. 
cancelLastWatch() - client.NewRouteConfigs(map[string]RouteConfigUpdate{testRDSName: wantUpdate}, UpdateMetadata{}) - for i := 0; i < count-1; i++ { - if err := verifyRouteConfigUpdate(ctx, rdsUpdateChs[i], wantUpdate); err != nil { - t.Fatal(err) - } + client.NewRouteConfigs(map[string]RouteConfigUpdateErrTuple{testRDSName: {Update: wantUpdate}}, UpdateMetadata{}) + for i := 0; i < count; i++ { + func() { + sCtx, sCancel := context.WithTimeout(ctx, defaultTestShortTimeout) + defer sCancel() + if u, err := rdsUpdateChs[i].Receive(sCtx); err != context.DeadlineExceeded { + t.Errorf("unexpected RouteConfigUpdate: %v, %v, want channel recv timeout", u, err) + } + }() } - sCtx, sCancel := context.WithTimeout(ctx, defaultTestShortTimeout) - defer sCancel() - if u, err := rdsUpdateChs[count-1].Receive(sCtx); err != context.DeadlineExceeded { - t.Errorf("unexpected RouteConfigUpdate: %v, %v, want channel recv timeout", u, err) + // Push a new update and make sure the uncancelled watcher is invoked. + // Specify a non-nil raw proto to ensure that the new update is not + // considered equal to the old one. + newUpdate := wantUpdate + newUpdate.Raw = &anypb.Any{} + client.NewRouteConfigs(map[string]RouteConfigUpdateErrTuple{testRDSName: {Update: newUpdate}}, UpdateMetadata{}) + if err := verifyRouteConfigUpdate(ctx, rdsUpdateChs[0], newUpdate, nil); err != nil { + t.Fatal(err) } } @@ -192,7 +204,7 @@ func (s) TestRDSThreeWatchDifferentResourceName(t *testing.T) { rdsUpdateCh := testutils.NewChannel() rdsUpdateChs = append(rdsUpdateChs, rdsUpdateCh) client.WatchRouteConfig(testRDSName+"1", func(update RouteConfigUpdate, err error) { - rdsUpdateCh.Send(rdsUpdateErr{u: update, err: err}) + rdsUpdateCh.Send(RouteConfigUpdateErrTuple{Update: update, Err: err}) }) if i == 0 { @@ -207,7 +219,7 @@ func (s) TestRDSThreeWatchDifferentResourceName(t *testing.T) { // Third watch for a different name. 
rdsUpdateCh2 := testutils.NewChannel() client.WatchRouteConfig(testRDSName+"2", func(update RouteConfigUpdate, err error) { - rdsUpdateCh2.Send(rdsUpdateErr{u: update, err: err}) + rdsUpdateCh2.Send(RouteConfigUpdateErrTuple{Update: update, Err: err}) }) if _, err := apiClient.addWatches[RouteConfigResource].Receive(ctx); err != nil { t.Fatalf("want new watch to start, got error %v", err) @@ -229,17 +241,17 @@ func (s) TestRDSThreeWatchDifferentResourceName(t *testing.T) { }, }, } - client.NewRouteConfigs(map[string]RouteConfigUpdate{ - testRDSName + "1": wantUpdate1, - testRDSName + "2": wantUpdate2, + client.NewRouteConfigs(map[string]RouteConfigUpdateErrTuple{ + testRDSName + "1": {Update: wantUpdate1}, + testRDSName + "2": {Update: wantUpdate2}, }, UpdateMetadata{}) for i := 0; i < count; i++ { - if err := verifyRouteConfigUpdate(ctx, rdsUpdateChs[i], wantUpdate1); err != nil { + if err := verifyRouteConfigUpdate(ctx, rdsUpdateChs[i], wantUpdate1, nil); err != nil { t.Fatal(err) } } - if err := verifyRouteConfigUpdate(ctx, rdsUpdateCh2, wantUpdate2); err != nil { + if err := verifyRouteConfigUpdate(ctx, rdsUpdateCh2, wantUpdate2, nil); err != nil { t.Fatal(err) } } @@ -266,7 +278,7 @@ func (s) TestRDSWatchAfterCache(t *testing.T) { rdsUpdateCh := testutils.NewChannel() client.WatchRouteConfig(testRDSName, func(update RouteConfigUpdate, err error) { - rdsUpdateCh.Send(rdsUpdateErr{u: update, err: err}) + rdsUpdateCh.Send(RouteConfigUpdateErrTuple{Update: update, Err: err}) }) if _, err := apiClient.addWatches[RouteConfigResource].Receive(ctx); err != nil { t.Fatalf("want new watch to start, got error %v", err) @@ -280,15 +292,15 @@ func (s) TestRDSWatchAfterCache(t *testing.T) { }, }, } - client.NewRouteConfigs(map[string]RouteConfigUpdate{testRDSName: wantUpdate}, UpdateMetadata{}) - if err := verifyRouteConfigUpdate(ctx, rdsUpdateCh, wantUpdate); err != nil { + client.NewRouteConfigs(map[string]RouteConfigUpdateErrTuple{testRDSName: {Update: wantUpdate}}, 
UpdateMetadata{}) + if err := verifyRouteConfigUpdate(ctx, rdsUpdateCh, wantUpdate, nil); err != nil { t.Fatal(err) } // Another watch for the resource in cache. rdsUpdateCh2 := testutils.NewChannel() client.WatchRouteConfig(testRDSName, func(update RouteConfigUpdate, err error) { - rdsUpdateCh2.Send(rdsUpdateErr{u: update, err: err}) + rdsUpdateCh2.Send(RouteConfigUpdateErrTuple{Update: update, Err: err}) }) sCtx, sCancel := context.WithTimeout(ctx, defaultTestShortTimeout) defer sCancel() @@ -297,7 +309,7 @@ func (s) TestRDSWatchAfterCache(t *testing.T) { } // New watch should receives the update. - if u, err := rdsUpdateCh2.Receive(ctx); err != nil || !cmp.Equal(u, rdsUpdateErr{wantUpdate, nil}, cmp.AllowUnexported(rdsUpdateErr{})) { + if u, err := rdsUpdateCh2.Receive(ctx); err != nil || !cmp.Equal(u, RouteConfigUpdateErrTuple{wantUpdate, nil}, cmp.AllowUnexported(RouteConfigUpdateErrTuple{})) { t.Errorf("unexpected RouteConfigUpdate: %v, error receiving from channel: %v", u, err) } @@ -308,3 +320,105 @@ func (s) TestRDSWatchAfterCache(t *testing.T) { t.Errorf("unexpected RouteConfigUpdate: %v, %v, want channel recv timeout", u, err) } } + +// TestRouteWatchNACKError covers the case that an update is NACK'ed, and the +// watcher should also receive the error. 
+func (s) TestRouteWatchNACKError(t *testing.T) { + apiClientCh, cleanup := overrideNewAPIClient() + defer cleanup() + + client, err := newWithConfig(clientOpts(testXDSServer, false)) + if err != nil { + t.Fatalf("failed to create client: %v", err) + } + defer client.Close() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + c, err := apiClientCh.Receive(ctx) + if err != nil { + t.Fatalf("timeout when waiting for API client to be created: %v", err) + } + apiClient := c.(*testAPIClient) + + rdsUpdateCh := testutils.NewChannel() + cancelWatch := client.WatchRouteConfig(testCDSName, func(update RouteConfigUpdate, err error) { + rdsUpdateCh.Send(RouteConfigUpdateErrTuple{Update: update, Err: err}) + }) + defer cancelWatch() + if _, err := apiClient.addWatches[RouteConfigResource].Receive(ctx); err != nil { + t.Fatalf("want new watch to start, got error %v", err) + } + + wantError := fmt.Errorf("testing error") + client.NewRouteConfigs(map[string]RouteConfigUpdateErrTuple{testCDSName: {Err: wantError}}, UpdateMetadata{ErrState: &UpdateErrorMetadata{Err: wantError}}) + if err := verifyRouteConfigUpdate(ctx, rdsUpdateCh, RouteConfigUpdate{}, wantError); err != nil { + t.Fatal(err) + } +} + +// TestRouteWatchPartialValid covers the case that a response contains both +// valid and invalid resources. This response will be NACK'ed by the xdsclient. +// But the watchers with valid resources should receive the update, those with +// invalida resources should receive an error. 
+func (s) TestRouteWatchPartialValid(t *testing.T) { + apiClientCh, cleanup := overrideNewAPIClient() + defer cleanup() + + client, err := newWithConfig(clientOpts(testXDSServer, false)) + if err != nil { + t.Fatalf("failed to create client: %v", err) + } + defer client.Close() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + c, err := apiClientCh.Receive(ctx) + if err != nil { + t.Fatalf("timeout when waiting for API client to be created: %v", err) + } + apiClient := c.(*testAPIClient) + + const badResourceName = "bad-resource" + updateChs := make(map[string]*testutils.Channel) + + for _, name := range []string{testRDSName, badResourceName} { + rdsUpdateCh := testutils.NewChannel() + cancelWatch := client.WatchRouteConfig(name, func(update RouteConfigUpdate, err error) { + rdsUpdateCh.Send(RouteConfigUpdateErrTuple{Update: update, Err: err}) + }) + defer func() { + cancelWatch() + if _, err := apiClient.removeWatches[RouteConfigResource].Receive(ctx); err != nil { + t.Fatalf("want watch to be canceled, got err: %v", err) + } + }() + if _, err := apiClient.addWatches[RouteConfigResource].Receive(ctx); err != nil { + t.Fatalf("want new watch to start, got error %v", err) + } + updateChs[name] = rdsUpdateCh + } + + wantError := fmt.Errorf("testing error") + wantError2 := fmt.Errorf("individual error") + client.NewRouteConfigs(map[string]RouteConfigUpdateErrTuple{ + testRDSName: {Update: RouteConfigUpdate{VirtualHosts: []*VirtualHost{{ + Domains: []string{testLDSName}, + Routes: []*Route{{Prefix: newStringP(""), WeightedClusters: map[string]WeightedCluster{testCDSName: {Weight: 1}}}}, + }}}}, + badResourceName: {Err: wantError2}, + }, UpdateMetadata{ErrState: &UpdateErrorMetadata{Err: wantError}}) + + // The valid resource should be sent to the watcher. 
+ if err := verifyRouteConfigUpdate(ctx, updateChs[testRDSName], RouteConfigUpdate{VirtualHosts: []*VirtualHost{{ + Domains: []string{testLDSName}, + Routes: []*Route{{Prefix: newStringP(""), WeightedClusters: map[string]WeightedCluster{testCDSName: {Weight: 1}}}}, + }}}, nil); err != nil { + t.Fatal(err) + } + + // The failed watcher should receive an error. + if err := verifyRouteConfigUpdate(ctx, updateChs[badResourceName], RouteConfigUpdate{}, wantError2); err != nil { + t.Fatal(err) + } +} diff --git a/xds/internal/xdsclient/xds.go b/xds/internal/xdsclient/xds.go new file mode 100644 index 00000000000..732c4e6addc --- /dev/null +++ b/xds/internal/xdsclient/xds.go @@ -0,0 +1,1334 @@ +/* + * + * Copyright 2020 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +package xdsclient + +import ( + "errors" + "fmt" + "net" + "regexp" + "strconv" + "strings" + "time" + + v1typepb "github.com/cncf/udpa/go/udpa/type/v1" + v3clusterpb "github.com/envoyproxy/go-control-plane/envoy/config/cluster/v3" + v3corepb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" + v3endpointpb "github.com/envoyproxy/go-control-plane/envoy/config/endpoint/v3" + v3listenerpb "github.com/envoyproxy/go-control-plane/envoy/config/listener/v3" + v3routepb "github.com/envoyproxy/go-control-plane/envoy/config/route/v3" + v3aggregateclusterpb "github.com/envoyproxy/go-control-plane/envoy/extensions/clusters/aggregate/v3" + v3httppb "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/network/http_connection_manager/v3" + v3tlspb "github.com/envoyproxy/go-control-plane/envoy/extensions/transport_sockets/tls/v3" + v3typepb "github.com/envoyproxy/go-control-plane/envoy/type/v3" + "github.com/golang/protobuf/proto" + "github.com/golang/protobuf/ptypes" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/internal/pretty" + "google.golang.org/grpc/internal/xds/matcher" + "google.golang.org/protobuf/types/known/anypb" + + "google.golang.org/grpc/internal/grpclog" + "google.golang.org/grpc/internal/xds/env" + "google.golang.org/grpc/xds/internal" + "google.golang.org/grpc/xds/internal/httpfilter" + "google.golang.org/grpc/xds/internal/version" +) + +// TransportSocket proto message has a `name` field which is expected to be set +// to this value by the management server. +const transportSocketName = "envoy.transport_sockets.tls" + +// UnmarshalOptions wraps the input parameters for `UnmarshalXxx` functions. +type UnmarshalOptions struct { + // Version is the version of the received response. + Version string + // Resources are the xDS resources resources in the received response. + Resources []*anypb.Any + // Logger is the prefix logger to be used during unmarshaling. 
+ Logger *grpclog.PrefixLogger + // UpdateValidator is a post unmarshal validation check provided by the + // upper layer. + UpdateValidator UpdateValidatorFunc +} + +// UnmarshalListener processes resources received in an LDS response, validates +// them, and transforms them into a native struct which contains only fields we +// are interested in. +func UnmarshalListener(opts *UnmarshalOptions) (map[string]ListenerUpdateErrTuple, UpdateMetadata, error) { + update := make(map[string]ListenerUpdateErrTuple) + md, err := processAllResources(opts, update) + return update, md, err +} + +func unmarshalListenerResource(r *anypb.Any, f UpdateValidatorFunc, logger *grpclog.PrefixLogger) (string, ListenerUpdate, error) { + if !IsListenerResource(r.GetTypeUrl()) { + return "", ListenerUpdate{}, fmt.Errorf("unexpected resource type: %q ", r.GetTypeUrl()) + } + // TODO: Pass version.TransportAPI instead of relying upon the type URL + v2 := r.GetTypeUrl() == version.V2ListenerURL + lis := &v3listenerpb.Listener{} + if err := proto.Unmarshal(r.GetValue(), lis); err != nil { + return "", ListenerUpdate{}, fmt.Errorf("failed to unmarshal resource: %v", err) + } + logger.Infof("Resource with name: %v, type: %T, contains: %v", lis.GetName(), lis, pretty.ToJSON(lis)) + + lu, err := processListener(lis, logger, v2) + if err != nil { + return lis.GetName(), ListenerUpdate{}, err + } + if f != nil { + if err := f(*lu); err != nil { + return lis.GetName(), ListenerUpdate{}, err + } + } + lu.Raw = r + return lis.GetName(), *lu, nil +} + +func processListener(lis *v3listenerpb.Listener, logger *grpclog.PrefixLogger, v2 bool) (*ListenerUpdate, error) { + if lis.GetApiListener() != nil { + return processClientSideListener(lis, logger, v2) + } + return processServerSideListener(lis) +} + +// processClientSideListener checks if the provided Listener proto meets +// the expected criteria. If so, it returns a non-empty routeConfigName. 
+func processClientSideListener(lis *v3listenerpb.Listener, logger *grpclog.PrefixLogger, v2 bool) (*ListenerUpdate, error) { + update := &ListenerUpdate{} + + apiLisAny := lis.GetApiListener().GetApiListener() + if !IsHTTPConnManagerResource(apiLisAny.GetTypeUrl()) { + return nil, fmt.Errorf("unexpected resource type: %q", apiLisAny.GetTypeUrl()) + } + apiLis := &v3httppb.HttpConnectionManager{} + if err := proto.Unmarshal(apiLisAny.GetValue(), apiLis); err != nil { + return nil, fmt.Errorf("failed to unmarshal api_listner: %v", err) + } + // "HttpConnectionManager.xff_num_trusted_hops must be unset or zero and + // HttpConnectionManager.original_ip_detection_extensions must be empty. If + // either field has an incorrect value, the Listener must be NACKed." - A41 + if apiLis.XffNumTrustedHops != 0 { + return nil, fmt.Errorf("xff_num_trusted_hops must be unset or zero %+v", apiLis) + } + if len(apiLis.OriginalIpDetectionExtensions) != 0 { + return nil, fmt.Errorf("original_ip_detection_extensions must be empty %+v", apiLis) + } + + switch apiLis.RouteSpecifier.(type) { + case *v3httppb.HttpConnectionManager_Rds: + if apiLis.GetRds().GetConfigSource().GetAds() == nil { + return nil, fmt.Errorf("ConfigSource is not ADS: %+v", lis) + } + name := apiLis.GetRds().GetRouteConfigName() + if name == "" { + return nil, fmt.Errorf("empty route_config_name: %+v", lis) + } + update.RouteConfigName = name + case *v3httppb.HttpConnectionManager_RouteConfig: + routeU, err := generateRDSUpdateFromRouteConfiguration(apiLis.GetRouteConfig(), logger, v2) + if err != nil { + return nil, fmt.Errorf("failed to parse inline RDS resp: %v", err) + } + update.InlineRouteConfig = &routeU + case nil: + return nil, fmt.Errorf("no RouteSpecifier: %+v", apiLis) + default: + return nil, fmt.Errorf("unsupported type %T for RouteSpecifier", apiLis.RouteSpecifier) + } + + if v2 { + return update, nil + } + + // The following checks and fields only apply to xDS protocol versions v3+. 
+ + update.MaxStreamDuration = apiLis.GetCommonHttpProtocolOptions().GetMaxStreamDuration().AsDuration() + + var err error + if update.HTTPFilters, err = processHTTPFilters(apiLis.GetHttpFilters(), false); err != nil { + return nil, err + } + + return update, nil +} + +func unwrapHTTPFilterConfig(config *anypb.Any) (proto.Message, string, error) { + // The real type name is inside the TypedStruct. + s := new(v1typepb.TypedStruct) + if !ptypes.Is(config, s) { + return config, config.GetTypeUrl(), nil + } + if err := ptypes.UnmarshalAny(config, s); err != nil { + return nil, "", fmt.Errorf("error unmarshalling TypedStruct filter config: %v", err) + } + return s, s.GetTypeUrl(), nil +} + +func validateHTTPFilterConfig(cfg *anypb.Any, lds, optional bool) (httpfilter.Filter, httpfilter.FilterConfig, error) { + config, typeURL, err := unwrapHTTPFilterConfig(cfg) + if err != nil { + return nil, nil, err + } + filterBuilder := httpfilter.Get(typeURL) + if filterBuilder == nil { + if optional { + return nil, nil, nil + } + return nil, nil, fmt.Errorf("no filter implementation found for %q", typeURL) + } + parseFunc := filterBuilder.ParseFilterConfig + if !lds { + parseFunc = filterBuilder.ParseFilterConfigOverride + } + filterConfig, err := parseFunc(config) + if err != nil { + return nil, nil, fmt.Errorf("error parsing config for filter %q: %v", typeURL, err) + } + return filterBuilder, filterConfig, nil +} + +func processHTTPFilterOverrides(cfgs map[string]*anypb.Any) (map[string]httpfilter.FilterConfig, error) { + if len(cfgs) == 0 { + return nil, nil + } + m := make(map[string]httpfilter.FilterConfig) + for name, cfg := range cfgs { + optional := false + s := new(v3routepb.FilterConfig) + if ptypes.Is(cfg, s) { + if err := ptypes.UnmarshalAny(cfg, s); err != nil { + return nil, fmt.Errorf("filter override %q: error unmarshalling FilterConfig: %v", name, err) + } + cfg = s.GetConfig() + optional = s.GetIsOptional() + } + + httpFilter, config, err := 
validateHTTPFilterConfig(cfg, false, optional) + if err != nil { + return nil, fmt.Errorf("filter override %q: %v", name, err) + } + if httpFilter == nil { + // Optional configs are ignored. + continue + } + m[name] = config + } + return m, nil +} + +func processHTTPFilters(filters []*v3httppb.HttpFilter, server bool) ([]HTTPFilter, error) { + ret := make([]HTTPFilter, 0, len(filters)) + seenNames := make(map[string]bool, len(filters)) + for _, filter := range filters { + name := filter.GetName() + if name == "" { + return nil, errors.New("filter missing name field") + } + if seenNames[name] { + return nil, fmt.Errorf("duplicate filter name %q", name) + } + seenNames[name] = true + + httpFilter, config, err := validateHTTPFilterConfig(filter.GetTypedConfig(), true, filter.GetIsOptional()) + if err != nil { + return nil, err + } + if httpFilter == nil { + // Optional configs are ignored. + continue + } + if server { + if _, ok := httpFilter.(httpfilter.ServerInterceptorBuilder); !ok { + if filter.GetIsOptional() { + continue + } + return nil, fmt.Errorf("HTTP filter %q not supported server-side", name) + } + } else if _, ok := httpFilter.(httpfilter.ClientInterceptorBuilder); !ok { + if filter.GetIsOptional() { + continue + } + return nil, fmt.Errorf("HTTP filter %q not supported client-side", name) + } + + // Save name/config + ret = append(ret, HTTPFilter{Name: name, Filter: httpFilter, Config: config}) + } + // "Validation will fail if a terminal filter is not the last filter in the + // chain or if a non-terminal filter is the last filter in the chain." 
- A39 + if len(ret) == 0 { + return nil, fmt.Errorf("http filters list is empty") + } + var i int + for ; i < len(ret)-1; i++ { + if ret[i].Filter.IsTerminal() { + return nil, fmt.Errorf("http filter %q is a terminal filter but it is not last in the filter chain", ret[i].Name) + } + } + if !ret[i].Filter.IsTerminal() { + return nil, fmt.Errorf("http filter %q is not a terminal filter", ret[len(ret)-1].Name) + } + return ret, nil +} + +func processServerSideListener(lis *v3listenerpb.Listener) (*ListenerUpdate, error) { + if n := len(lis.ListenerFilters); n != 0 { + return nil, fmt.Errorf("unsupported field 'listener_filters' contains %d entries", n) + } + if useOrigDst := lis.GetUseOriginalDst(); useOrigDst != nil && useOrigDst.GetValue() { + return nil, errors.New("unsupported field 'use_original_dst' is present and set to true") + } + addr := lis.GetAddress() + if addr == nil { + return nil, fmt.Errorf("no address field in LDS response: %+v", lis) + } + sockAddr := addr.GetSocketAddress() + if sockAddr == nil { + return nil, fmt.Errorf("no socket_address field in LDS response: %+v", lis) + } + lu := &ListenerUpdate{ + InboundListenerCfg: &InboundListenerConfig{ + Address: sockAddr.GetAddress(), + Port: strconv.Itoa(int(sockAddr.GetPortValue())), + }, + } + + fcMgr, err := NewFilterChainManager(lis) + if err != nil { + return nil, err + } + lu.InboundListenerCfg.FilterChains = fcMgr + return lu, nil +} + +// UnmarshalRouteConfig processes resources received in an RDS response, +// validates them, and transforms them into a native struct which contains only +// fields we are interested in. The provided hostname determines the route +// configuration resources of interest. 
+func UnmarshalRouteConfig(opts *UnmarshalOptions) (map[string]RouteConfigUpdateErrTuple, UpdateMetadata, error) { + update := make(map[string]RouteConfigUpdateErrTuple) + md, err := processAllResources(opts, update) + return update, md, err +} + +func unmarshalRouteConfigResource(r *anypb.Any, logger *grpclog.PrefixLogger) (string, RouteConfigUpdate, error) { + if !IsRouteConfigResource(r.GetTypeUrl()) { + return "", RouteConfigUpdate{}, fmt.Errorf("unexpected resource type: %q ", r.GetTypeUrl()) + } + rc := &v3routepb.RouteConfiguration{} + if err := proto.Unmarshal(r.GetValue(), rc); err != nil { + return "", RouteConfigUpdate{}, fmt.Errorf("failed to unmarshal resource: %v", err) + } + logger.Infof("Resource with name: %v, type: %T, contains: %v.", rc.GetName(), rc, pretty.ToJSON(rc)) + + // TODO: Pass version.TransportAPI instead of relying upon the type URL + v2 := r.GetTypeUrl() == version.V2RouteConfigURL + u, err := generateRDSUpdateFromRouteConfiguration(rc, logger, v2) + if err != nil { + return rc.GetName(), RouteConfigUpdate{}, err + } + u.Raw = r + return rc.GetName(), u, nil +} + +// generateRDSUpdateFromRouteConfiguration checks if the provided +// RouteConfiguration meets the expected criteria. If so, it returns a +// RouteConfigUpdate with nil error. +// +// A RouteConfiguration resource is considered valid when only if it contains a +// VirtualHost whose domain field matches the server name from the URI passed +// to the gRPC channel, and it contains a clusterName or a weighted cluster. +// +// The RouteConfiguration includes a list of virtualHosts, which may have zero +// or more elements. We are interested in the element whose domains field +// matches the server name specified in the "xds:" URI. The only field in the +// VirtualHost proto that the we are interested in is the list of routes. We +// only look at the last route in the list (the default route), whose match +// field must be empty and whose route field must be set. 
Inside that route +// message, the cluster field will contain the clusterName or weighted clusters +// we are looking for. +func generateRDSUpdateFromRouteConfiguration(rc *v3routepb.RouteConfiguration, logger *grpclog.PrefixLogger, v2 bool) (RouteConfigUpdate, error) { + vhs := make([]*VirtualHost, 0, len(rc.GetVirtualHosts())) + for _, vh := range rc.GetVirtualHosts() { + routes, err := routesProtoToSlice(vh.Routes, logger, v2) + if err != nil { + return RouteConfigUpdate{}, fmt.Errorf("received route is invalid: %v", err) + } + rc, err := generateRetryConfig(vh.GetRetryPolicy()) + if err != nil { + return RouteConfigUpdate{}, fmt.Errorf("received route is invalid: %v", err) + } + vhOut := &VirtualHost{ + Domains: vh.GetDomains(), + Routes: routes, + RetryConfig: rc, + } + if !v2 { + cfgs, err := processHTTPFilterOverrides(vh.GetTypedPerFilterConfig()) + if err != nil { + return RouteConfigUpdate{}, fmt.Errorf("virtual host %+v: %v", vh, err) + } + vhOut.HTTPFilterConfigOverride = cfgs + } + vhs = append(vhs, vhOut) + } + return RouteConfigUpdate{VirtualHosts: vhs}, nil +} + +func generateRetryConfig(rp *v3routepb.RetryPolicy) (*RetryConfig, error) { + if !env.RetrySupport || rp == nil { + return nil, nil + } + + cfg := &RetryConfig{RetryOn: make(map[codes.Code]bool)} + for _, s := range strings.Split(rp.GetRetryOn(), ",") { + switch strings.TrimSpace(strings.ToLower(s)) { + case "cancelled": + cfg.RetryOn[codes.Canceled] = true + case "deadline-exceeded": + cfg.RetryOn[codes.DeadlineExceeded] = true + case "internal": + cfg.RetryOn[codes.Internal] = true + case "resource-exhausted": + cfg.RetryOn[codes.ResourceExhausted] = true + case "unavailable": + cfg.RetryOn[codes.Unavailable] = true + } + } + + if rp.NumRetries == nil { + cfg.NumRetries = 1 + } else { + cfg.NumRetries = rp.GetNumRetries().Value + if cfg.NumRetries < 1 { + return nil, fmt.Errorf("retry_policy.num_retries = %v; must be >= 1", cfg.NumRetries) + } + } + + backoff := rp.GetRetryBackOff() + if 
backoff == nil { + cfg.RetryBackoff.BaseInterval = 25 * time.Millisecond + } else { + cfg.RetryBackoff.BaseInterval = backoff.GetBaseInterval().AsDuration() + if cfg.RetryBackoff.BaseInterval <= 0 { + return nil, fmt.Errorf("retry_policy.base_interval = %v; must be > 0", cfg.RetryBackoff.BaseInterval) + } + } + if max := backoff.GetMaxInterval(); max == nil { + cfg.RetryBackoff.MaxInterval = 10 * cfg.RetryBackoff.BaseInterval + } else { + cfg.RetryBackoff.MaxInterval = max.AsDuration() + if cfg.RetryBackoff.MaxInterval <= 0 { + return nil, fmt.Errorf("retry_policy.max_interval = %v; must be > 0", cfg.RetryBackoff.MaxInterval) + } + } + + if len(cfg.RetryOn) == 0 { + return &RetryConfig{}, nil + } + return cfg, nil +} + +func routesProtoToSlice(routes []*v3routepb.Route, logger *grpclog.PrefixLogger, v2 bool) ([]*Route, error) { + var routesRet []*Route + for _, r := range routes { + match := r.GetMatch() + if match == nil { + return nil, fmt.Errorf("route %+v doesn't have a match", r) + } + + if len(match.GetQueryParameters()) != 0 { + // Ignore route with query parameters. 
+ logger.Warningf("route %+v has query parameter matchers, the route will be ignored", r) + continue + } + + pathSp := match.GetPathSpecifier() + if pathSp == nil { + return nil, fmt.Errorf("route %+v doesn't have a path specifier", r) + } + + var route Route + switch pt := pathSp.(type) { + case *v3routepb.RouteMatch_Prefix: + route.Prefix = &pt.Prefix + case *v3routepb.RouteMatch_Path: + route.Path = &pt.Path + case *v3routepb.RouteMatch_SafeRegex: + regex := pt.SafeRegex.GetRegex() + re, err := regexp.Compile(regex) + if err != nil { + return nil, fmt.Errorf("route %+v contains an invalid regex %q", r, regex) + } + route.Regex = re + default: + return nil, fmt.Errorf("route %+v has an unrecognized path specifier: %+v", r, pt) + } + + if caseSensitive := match.GetCaseSensitive(); caseSensitive != nil { + route.CaseInsensitive = !caseSensitive.Value + } + + for _, h := range match.GetHeaders() { + var header HeaderMatcher + switch ht := h.GetHeaderMatchSpecifier().(type) { + case *v3routepb.HeaderMatcher_ExactMatch: + header.ExactMatch = &ht.ExactMatch + case *v3routepb.HeaderMatcher_SafeRegexMatch: + regex := ht.SafeRegexMatch.GetRegex() + re, err := regexp.Compile(regex) + if err != nil { + return nil, fmt.Errorf("route %+v contains an invalid regex %q", r, regex) + } + header.RegexMatch = re + case *v3routepb.HeaderMatcher_RangeMatch: + header.RangeMatch = &Int64Range{ + Start: ht.RangeMatch.Start, + End: ht.RangeMatch.End, + } + case *v3routepb.HeaderMatcher_PresentMatch: + header.PresentMatch = &ht.PresentMatch + case *v3routepb.HeaderMatcher_PrefixMatch: + header.PrefixMatch = &ht.PrefixMatch + case *v3routepb.HeaderMatcher_SuffixMatch: + header.SuffixMatch = &ht.SuffixMatch + default: + return nil, fmt.Errorf("route %+v has an unrecognized header matcher: %+v", r, ht) + } + header.Name = h.GetName() + invert := h.GetInvertMatch() + header.InvertMatch = &invert + route.Headers = append(route.Headers, &header) + } + + if fr := match.GetRuntimeFraction(); fr 
!= nil { + d := fr.GetDefaultValue() + n := d.GetNumerator() + switch d.GetDenominator() { + case v3typepb.FractionalPercent_HUNDRED: + n *= 10000 + case v3typepb.FractionalPercent_TEN_THOUSAND: + n *= 100 + case v3typepb.FractionalPercent_MILLION: + } + route.Fraction = &n + } + + switch r.GetAction().(type) { + case *v3routepb.Route_Route: + route.WeightedClusters = make(map[string]WeightedCluster) + action := r.GetRoute() + + // Hash Policies are only applicable for a Ring Hash LB. + if env.RingHashSupport { + hp, err := hashPoliciesProtoToSlice(action.HashPolicy, logger) + if err != nil { + return nil, err + } + route.HashPolicies = hp + } + + switch a := action.GetClusterSpecifier().(type) { + case *v3routepb.RouteAction_Cluster: + route.WeightedClusters[a.Cluster] = WeightedCluster{Weight: 1} + case *v3routepb.RouteAction_WeightedClusters: + wcs := a.WeightedClusters + var totalWeight uint32 + for _, c := range wcs.Clusters { + w := c.GetWeight().GetValue() + if w == 0 { + continue + } + wc := WeightedCluster{Weight: w} + if !v2 { + cfgs, err := processHTTPFilterOverrides(c.GetTypedPerFilterConfig()) + if err != nil { + return nil, fmt.Errorf("route %+v, action %+v: %v", r, a, err) + } + wc.HTTPFilterConfigOverride = cfgs + } + route.WeightedClusters[c.GetName()] = wc + totalWeight += w + } + // envoy xds doc + // default TotalWeight https://www.envoyproxy.io/docs/envoy/latest/api-v3/config/route/v3/route_components.proto.html#envoy-v3-api-field-config-route-v3-weightedcluster-total-weight + wantTotalWeight := uint32(100) + if tw := wcs.GetTotalWeight(); tw != nil { + wantTotalWeight = tw.GetValue() + } + if totalWeight != wantTotalWeight { + return nil, fmt.Errorf("route %+v, action %+v, weights of clusters do not add up to total total weight, got: %v, expected total weight from response: %v", r, a, totalWeight, wantTotalWeight) + } + if totalWeight == 0 { + return nil, fmt.Errorf("route %+v, action %+v, has no valid cluster in WeightedCluster action", r, a) 
+ } + case *v3routepb.RouteAction_ClusterHeader: + continue + default: + return nil, fmt.Errorf("route %+v, has an unknown ClusterSpecifier: %+v", r, a) + } + + msd := action.GetMaxStreamDuration() + // Prefer grpc_timeout_header_max, if set. + dur := msd.GetGrpcTimeoutHeaderMax() + if dur == nil { + dur = msd.GetMaxStreamDuration() + } + if dur != nil { + d := dur.AsDuration() + route.MaxStreamDuration = &d + } + + var err error + route.RetryConfig, err = generateRetryConfig(action.GetRetryPolicy()) + if err != nil { + return nil, fmt.Errorf("route %+v, action %+v: %v", r, action, err) + } + + route.RouteAction = RouteActionRoute + + case *v3routepb.Route_NonForwardingAction: + // Expected to be used on server side. + route.RouteAction = RouteActionNonForwardingAction + default: + route.RouteAction = RouteActionUnsupported + } + + if !v2 { + cfgs, err := processHTTPFilterOverrides(r.GetTypedPerFilterConfig()) + if err != nil { + return nil, fmt.Errorf("route %+v: %v", r, err) + } + route.HTTPFilterConfigOverride = cfgs + } + routesRet = append(routesRet, &route) + } + return routesRet, nil +} + +func hashPoliciesProtoToSlice(policies []*v3routepb.RouteAction_HashPolicy, logger *grpclog.PrefixLogger) ([]*HashPolicy, error) { + var hashPoliciesRet []*HashPolicy + for _, p := range policies { + policy := HashPolicy{Terminal: p.Terminal} + switch p.GetPolicySpecifier().(type) { + case *v3routepb.RouteAction_HashPolicy_Header_: + policy.HashPolicyType = HashPolicyTypeHeader + policy.HeaderName = p.GetHeader().GetHeaderName() + if rr := p.GetHeader().GetRegexRewrite(); rr != nil { + regex := rr.GetPattern().GetRegex() + re, err := regexp.Compile(regex) + if err != nil { + return nil, fmt.Errorf("hash policy %+v contains an invalid regex %q", p, regex) + } + policy.Regex = re + policy.RegexSubstitution = rr.GetSubstitution() + } + case *v3routepb.RouteAction_HashPolicy_FilterState_: + if p.GetFilterState().GetKey() != "io.grpc.channel_id" { + logger.Infof("hash policy 
%+v contains an invalid key for filter state policy %q", p, p.GetFilterState().GetKey()) + continue + } + policy.HashPolicyType = HashPolicyTypeChannelID + default: + logger.Infof("hash policy %T is an unsupported hash policy", p.GetPolicySpecifier()) + continue + } + + hashPoliciesRet = append(hashPoliciesRet, &policy) + } + return hashPoliciesRet, nil +} + +// UnmarshalCluster processes resources received in an CDS response, validates +// them, and transforms them into a native struct which contains only fields we +// are interested in. +func UnmarshalCluster(opts *UnmarshalOptions) (map[string]ClusterUpdateErrTuple, UpdateMetadata, error) { + update := make(map[string]ClusterUpdateErrTuple) + md, err := processAllResources(opts, update) + return update, md, err +} + +func unmarshalClusterResource(r *anypb.Any, f UpdateValidatorFunc, logger *grpclog.PrefixLogger) (string, ClusterUpdate, error) { + if !IsClusterResource(r.GetTypeUrl()) { + return "", ClusterUpdate{}, fmt.Errorf("unexpected resource type: %q ", r.GetTypeUrl()) + } + + cluster := &v3clusterpb.Cluster{} + if err := proto.Unmarshal(r.GetValue(), cluster); err != nil { + return "", ClusterUpdate{}, fmt.Errorf("failed to unmarshal resource: %v", err) + } + logger.Infof("Resource with name: %v, type: %T, contains: %v", cluster.GetName(), cluster, pretty.ToJSON(cluster)) + cu, err := validateClusterAndConstructClusterUpdate(cluster) + if err != nil { + return cluster.GetName(), ClusterUpdate{}, err + } + cu.Raw = r + if f != nil { + if err := f(cu); err != nil { + return "", ClusterUpdate{}, err + } + } + + return cluster.GetName(), cu, nil +} + +const ( + defaultRingHashMinSize = 1024 + defaultRingHashMaxSize = 8 * 1024 * 1024 // 8M + ringHashSizeUpperBound = 8 * 1024 * 1024 // 8M +) + +func validateClusterAndConstructClusterUpdate(cluster *v3clusterpb.Cluster) (ClusterUpdate, error) { + var lbPolicy *ClusterLBPolicyRingHash + switch cluster.GetLbPolicy() { + case v3clusterpb.Cluster_ROUND_ROBIN: + 
lbPolicy = nil // The default is round_robin, and there's no config to set. + case v3clusterpb.Cluster_RING_HASH: + if !env.RingHashSupport { + return ClusterUpdate{}, fmt.Errorf("unexpected lbPolicy %v in response: %+v", cluster.GetLbPolicy(), cluster) + } + rhc := cluster.GetRingHashLbConfig() + if rhc.GetHashFunction() != v3clusterpb.Cluster_RingHashLbConfig_XX_HASH { + return ClusterUpdate{}, fmt.Errorf("unsupported ring_hash hash function %v in response: %+v", rhc.GetHashFunction(), cluster) + } + // Minimum defaults to 1024 entries, and limited to 8M entries Maximum + // defaults to 8M entries, and limited to 8M entries + var minSize, maxSize uint64 = defaultRingHashMinSize, defaultRingHashMaxSize + if min := rhc.GetMinimumRingSize(); min != nil { + if min.GetValue() > ringHashSizeUpperBound { + return ClusterUpdate{}, fmt.Errorf("unexpected ring_hash mininum ring size %v in response: %+v", min.GetValue(), cluster) + } + minSize = min.GetValue() + } + if max := rhc.GetMaximumRingSize(); max != nil { + if max.GetValue() > ringHashSizeUpperBound { + return ClusterUpdate{}, fmt.Errorf("unexpected ring_hash maxinum ring size %v in response: %+v", max.GetValue(), cluster) + } + maxSize = max.GetValue() + } + if minSize > maxSize { + return ClusterUpdate{}, fmt.Errorf("ring_hash config min size %v is greater than max %v", minSize, maxSize) + } + lbPolicy = &ClusterLBPolicyRingHash{MinimumRingSize: minSize, MaximumRingSize: maxSize} + default: + return ClusterUpdate{}, fmt.Errorf("unexpected lbPolicy %v in response: %+v", cluster.GetLbPolicy(), cluster) + } + + // Process security configuration received from the control plane iff the + // corresponding environment variable is set. 
+ var sc *SecurityConfig + if env.ClientSideSecuritySupport { + var err error + if sc, err = securityConfigFromCluster(cluster); err != nil { + return ClusterUpdate{}, err + } + } + + ret := ClusterUpdate{ + ClusterName: cluster.GetName(), + EnableLRS: cluster.GetLrsServer().GetSelf() != nil, + SecurityCfg: sc, + MaxRequests: circuitBreakersFromCluster(cluster), + LBPolicy: lbPolicy, + } + + // Validate and set cluster type from the response. + switch { + case cluster.GetType() == v3clusterpb.Cluster_EDS: + if cluster.GetEdsClusterConfig().GetEdsConfig().GetAds() == nil { + return ClusterUpdate{}, fmt.Errorf("unexpected edsConfig in response: %+v", cluster) + } + ret.ClusterType = ClusterTypeEDS + ret.EDSServiceName = cluster.GetEdsClusterConfig().GetServiceName() + return ret, nil + case cluster.GetType() == v3clusterpb.Cluster_LOGICAL_DNS: + if !env.AggregateAndDNSSupportEnv { + return ClusterUpdate{}, fmt.Errorf("unsupported cluster type (%v, %v) in response: %+v", cluster.GetType(), cluster.GetClusterType(), cluster) + } + ret.ClusterType = ClusterTypeLogicalDNS + dnsHN, err := dnsHostNameFromCluster(cluster) + if err != nil { + return ClusterUpdate{}, err + } + ret.DNSHostName = dnsHN + return ret, nil + case cluster.GetClusterType() != nil && cluster.GetClusterType().Name == "envoy.clusters.aggregate": + if !env.AggregateAndDNSSupportEnv { + return ClusterUpdate{}, fmt.Errorf("unsupported cluster type (%v, %v) in response: %+v", cluster.GetType(), cluster.GetClusterType(), cluster) + } + clusters := &v3aggregateclusterpb.ClusterConfig{} + if err := proto.Unmarshal(cluster.GetClusterType().GetTypedConfig().GetValue(), clusters); err != nil { + return ClusterUpdate{}, fmt.Errorf("failed to unmarshal resource: %v", err) + } + ret.ClusterType = ClusterTypeAggregate + ret.PrioritizedClusterNames = clusters.Clusters + return ret, nil + default: + return ClusterUpdate{}, fmt.Errorf("unsupported cluster type (%v, %v) in response: %+v", cluster.GetType(), 
cluster.GetClusterType(), cluster) + } +} + +// dnsHostNameFromCluster extracts the DNS host name from the cluster's load +// assignment. +// +// There should be exactly one locality, with one endpoint, whose address +// contains the address and port. +func dnsHostNameFromCluster(cluster *v3clusterpb.Cluster) (string, error) { + loadAssignment := cluster.GetLoadAssignment() + if loadAssignment == nil { + return "", fmt.Errorf("load_assignment not present for LOGICAL_DNS cluster") + } + if len(loadAssignment.GetEndpoints()) != 1 { + return "", fmt.Errorf("load_assignment for LOGICAL_DNS cluster must have exactly one locality, got: %+v", loadAssignment) + } + endpoints := loadAssignment.GetEndpoints()[0].GetLbEndpoints() + if len(endpoints) != 1 { + return "", fmt.Errorf("locality for LOGICAL_DNS cluster must have exactly one endpoint, got: %+v", endpoints) + } + endpoint := endpoints[0].GetEndpoint() + if endpoint == nil { + return "", fmt.Errorf("endpoint for LOGICAL_DNS cluster not set") + } + socketAddr := endpoint.GetAddress().GetSocketAddress() + if socketAddr == nil { + return "", fmt.Errorf("socket address for endpoint for LOGICAL_DNS cluster not set") + } + if socketAddr.GetResolverName() != "" { + return "", fmt.Errorf("socket address for endpoint for LOGICAL_DNS cluster not set has unexpected custom resolver name: %v", socketAddr.GetResolverName()) + } + host := socketAddr.GetAddress() + if host == "" { + return "", fmt.Errorf("host for endpoint for LOGICAL_DNS cluster not set") + } + port := socketAddr.GetPortValue() + if port == 0 { + return "", fmt.Errorf("port for endpoint for LOGICAL_DNS cluster not set") + } + return net.JoinHostPort(host, strconv.Itoa(int(port))), nil +} + +// securityConfigFromCluster extracts the relevant security configuration from +// the received Cluster resource. 
+func securityConfigFromCluster(cluster *v3clusterpb.Cluster) (*SecurityConfig, error) { + if tsm := cluster.GetTransportSocketMatches(); len(tsm) != 0 { + return nil, fmt.Errorf("unsupport transport_socket_matches field is non-empty: %+v", tsm) + } + // The Cluster resource contains a `transport_socket` field, which contains + // a oneof `typed_config` field of type `protobuf.Any`. The any proto + // contains a marshaled representation of an `UpstreamTlsContext` message. + ts := cluster.GetTransportSocket() + if ts == nil { + return nil, nil + } + if name := ts.GetName(); name != transportSocketName { + return nil, fmt.Errorf("transport_socket field has unexpected name: %s", name) + } + any := ts.GetTypedConfig() + if any == nil || any.TypeUrl != version.V3UpstreamTLSContextURL { + return nil, fmt.Errorf("transport_socket field has unexpected typeURL: %s", any.TypeUrl) + } + upstreamCtx := &v3tlspb.UpstreamTlsContext{} + if err := proto.Unmarshal(any.GetValue(), upstreamCtx); err != nil { + return nil, fmt.Errorf("failed to unmarshal UpstreamTlsContext in CDS response: %v", err) + } + // The following fields from `UpstreamTlsContext` are ignored: + // - sni + // - allow_renegotiation + // - max_session_keys + if upstreamCtx.GetCommonTlsContext() == nil { + return nil, errors.New("UpstreamTlsContext in CDS response does not contain a CommonTlsContext") + } + + return securityConfigFromCommonTLSContext(upstreamCtx.GetCommonTlsContext(), false) +} + +// common is expected to be not nil. +// The `alpn_protocols` field is ignored. 
+func securityConfigFromCommonTLSContext(common *v3tlspb.CommonTlsContext, server bool) (*SecurityConfig, error) { + if common.GetTlsParams() != nil { + return nil, fmt.Errorf("unsupported tls_params field in CommonTlsContext message: %+v", common) + } + if common.GetCustomHandshaker() != nil { + return nil, fmt.Errorf("unsupported custom_handshaker field in CommonTlsContext message: %+v", common) + } + + // For now, if we can't get a valid security config from the new fields, we + // fallback to the old deprecated fields. + // TODO: Drop support for deprecated fields. NACK if err != nil here. + sc, _ := securityConfigFromCommonTLSContextUsingNewFields(common, server) + if sc == nil || sc.Equal(&SecurityConfig{}) { + var err error + sc, err = securityConfigFromCommonTLSContextWithDeprecatedFields(common, server) + if err != nil { + return nil, err + } + } + if sc != nil { + // sc == nil is a valid case where the control plane has not sent us any + // security configuration. xDS creds will use fallback creds. + if server { + if sc.IdentityInstanceName == "" { + return nil, errors.New("security configuration on the server-side does not contain identity certificate provider instance name") + } + } else { + if sc.RootInstanceName == "" { + return nil, errors.New("security configuration on the client-side does not contain root certificate provider instance name") + } + } + } + return sc, nil +} + +func securityConfigFromCommonTLSContextWithDeprecatedFields(common *v3tlspb.CommonTlsContext, server bool) (*SecurityConfig, error) { + // The `CommonTlsContext` contains a + // `tls_certificate_certificate_provider_instance` field of type + // `CertificateProviderInstance`, which contains the provider instance name + // and the certificate name to fetch identity certs. 
+ sc := &SecurityConfig{} + if identity := common.GetTlsCertificateCertificateProviderInstance(); identity != nil { + sc.IdentityInstanceName = identity.GetInstanceName() + sc.IdentityCertName = identity.GetCertificateName() + } + + // The `CommonTlsContext` contains a `validation_context_type` field which + // is a oneof. We can get the values that we are interested in from two of + // those possible values: + // - combined validation context: + // - contains a default validation context which holds the list of + // matchers for accepted SANs. + // - contains certificate provider instance configuration + // - certificate provider instance configuration + // - in this case, we do not get a list of accepted SANs. + switch t := common.GetValidationContextType().(type) { + case *v3tlspb.CommonTlsContext_CombinedValidationContext: + combined := common.GetCombinedValidationContext() + var matchers []matcher.StringMatcher + if def := combined.GetDefaultValidationContext(); def != nil { + for _, m := range def.GetMatchSubjectAltNames() { + matcher, err := matcher.StringMatcherFromProto(m) + if err != nil { + return nil, err + } + matchers = append(matchers, matcher) + } + } + if server && len(matchers) != 0 { + return nil, fmt.Errorf("match_subject_alt_names field in validation context is not supported on the server: %v", common) + } + sc.SubjectAltNameMatchers = matchers + if pi := combined.GetValidationContextCertificateProviderInstance(); pi != nil { + sc.RootInstanceName = pi.GetInstanceName() + sc.RootCertName = pi.GetCertificateName() + } + case *v3tlspb.CommonTlsContext_ValidationContextCertificateProviderInstance: + pi := common.GetValidationContextCertificateProviderInstance() + sc.RootInstanceName = pi.GetInstanceName() + sc.RootCertName = pi.GetCertificateName() + case nil: + // It is valid for the validation context to be nil on the server side. 
+ default: + return nil, fmt.Errorf("validation context contains unexpected type: %T", t) + } + return sc, nil +} + +// gRFC A29 https://github.com/grpc/proposal/blob/master/A29-xds-tls-security.md +// specifies the new way to fetch security configuration and says the following: +// +// Although there are various ways to obtain certificates as per this proto +// (which are supported by Envoy), gRPC supports only one of them and that is +// the `CertificateProviderPluginInstance` proto. +// +// This helper function attempts to fetch security configuration from the +// `CertificateProviderPluginInstance` message, given a CommonTlsContext. +func securityConfigFromCommonTLSContextUsingNewFields(common *v3tlspb.CommonTlsContext, server bool) (*SecurityConfig, error) { + // The `tls_certificate_provider_instance` field of type + // `CertificateProviderPluginInstance` is used to fetch the identity + // certificate provider. + sc := &SecurityConfig{} + identity := common.GetTlsCertificateProviderInstance() + if identity == nil && len(common.GetTlsCertificates()) != 0 { + return nil, fmt.Errorf("expected field tls_certificate_provider_instance is not set, while unsupported field tls_certificates is set in CommonTlsContext message: %+v", common) + } + if identity == nil && common.GetTlsCertificateSdsSecretConfigs() != nil { + return nil, fmt.Errorf("expected field tls_certificate_provider_instance is not set, while unsupported field tls_certificate_sds_secret_configs is set in CommonTlsContext message: %+v", common) + } + sc.IdentityInstanceName = identity.GetInstanceName() + sc.IdentityCertName = identity.GetCertificateName() + + // The `CommonTlsContext` contains a oneof field `validation_context_type`, + // which contains the `CertificateValidationContext` message in one of the + // following ways: + // - `validation_context` field + // - this is directly of type `CertificateValidationContext` + // - `combined_validation_context` field + // - this is of type 
`CombinedCertificateValidationContext` and contains + // a `default validation context` field of type + // `CertificateValidationContext` + // + // The `CertificateValidationContext` message has the following fields that + // we are interested in: + // - `ca_certificate_provider_instance` + // - this is of type `CertificateProviderPluginInstance` + // - `match_subject_alt_names` + // - this is a list of string matchers + // + // The `CertificateProviderPluginInstance` message contains two fields + // - instance_name + // - this is the certificate provider instance name to be looked up in + // the bootstrap configuration + // - certificate_name + // - this is an opaque name passed to the certificate provider + var validationCtx *v3tlspb.CertificateValidationContext + switch typ := common.GetValidationContextType().(type) { + case *v3tlspb.CommonTlsContext_ValidationContext: + validationCtx = common.GetValidationContext() + case *v3tlspb.CommonTlsContext_CombinedValidationContext: + validationCtx = common.GetCombinedValidationContext().GetDefaultValidationContext() + case nil: + // It is valid for the validation context to be nil on the server side. + return sc, nil + default: + return nil, fmt.Errorf("validation context contains unexpected type: %T", typ) + } + // If we get here, it means that the `CertificateValidationContext` message + // was found through one of the supported ways. It is an error if the + // validation context is specified, but it does not contain the + // ca_certificate_provider_instance field which contains information about + // the certificate provider to be used for the root certificates. 
+ if validationCtx.GetCaCertificateProviderInstance() == nil { + return nil, fmt.Errorf("expected field ca_certificate_provider_instance is missing in CommonTlsContext message: %+v", common) + } + // The following fields are ignored: + // - trusted_ca + // - watched_directory + // - allow_expired_certificate + // - trust_chain_verification + switch { + case len(validationCtx.GetVerifyCertificateSpki()) != 0: + return nil, fmt.Errorf("unsupported verify_certificate_spki field in CommonTlsContext message: %+v", common) + case len(validationCtx.GetVerifyCertificateHash()) != 0: + return nil, fmt.Errorf("unsupported verify_certificate_hash field in CommonTlsContext message: %+v", common) + case validationCtx.GetRequireSignedCertificateTimestamp().GetValue(): + return nil, fmt.Errorf("unsupported require_sugned_ceritificate_timestamp field in CommonTlsContext message: %+v", common) + case validationCtx.GetCrl() != nil: + return nil, fmt.Errorf("unsupported crl field in CommonTlsContext message: %+v", common) + case validationCtx.GetCustomValidatorConfig() != nil: + return nil, fmt.Errorf("unsupported custom_validator_config field in CommonTlsContext message: %+v", common) + } + + if rootProvider := validationCtx.GetCaCertificateProviderInstance(); rootProvider != nil { + sc.RootInstanceName = rootProvider.GetInstanceName() + sc.RootCertName = rootProvider.GetCertificateName() + } + var matchers []matcher.StringMatcher + for _, m := range validationCtx.GetMatchSubjectAltNames() { + matcher, err := matcher.StringMatcherFromProto(m) + if err != nil { + return nil, err + } + matchers = append(matchers, matcher) + } + if server && len(matchers) != 0 { + return nil, fmt.Errorf("match_subject_alt_names field in validation context is not supported on the server: %v", common) + } + sc.SubjectAltNameMatchers = matchers + return sc, nil +} + +// circuitBreakersFromCluster extracts the circuit breakers configuration from +// the received cluster resource. 
Returns nil if no CircuitBreakers or no +// Thresholds in CircuitBreakers. +func circuitBreakersFromCluster(cluster *v3clusterpb.Cluster) *uint32 { + for _, threshold := range cluster.GetCircuitBreakers().GetThresholds() { + if threshold.GetPriority() != v3corepb.RoutingPriority_DEFAULT { + continue + } + maxRequestsPb := threshold.GetMaxRequests() + if maxRequestsPb == nil { + return nil + } + maxRequests := maxRequestsPb.GetValue() + return &maxRequests + } + return nil +} + +// UnmarshalEndpoints processes resources received in an EDS response, +// validates them, and transforms them into a native struct which contains only +// fields we are interested in. +func UnmarshalEndpoints(opts *UnmarshalOptions) (map[string]EndpointsUpdateErrTuple, UpdateMetadata, error) { + update := make(map[string]EndpointsUpdateErrTuple) + md, err := processAllResources(opts, update) + return update, md, err +} + +func unmarshalEndpointsResource(r *anypb.Any, logger *grpclog.PrefixLogger) (string, EndpointsUpdate, error) { + if !IsEndpointsResource(r.GetTypeUrl()) { + return "", EndpointsUpdate{}, fmt.Errorf("unexpected resource type: %q ", r.GetTypeUrl()) + } + + cla := &v3endpointpb.ClusterLoadAssignment{} + if err := proto.Unmarshal(r.GetValue(), cla); err != nil { + return "", EndpointsUpdate{}, fmt.Errorf("failed to unmarshal resource: %v", err) + } + logger.Infof("Resource with name: %v, type: %T, contains: %v", cla.GetClusterName(), cla, pretty.ToJSON(cla)) + + u, err := parseEDSRespProto(cla) + if err != nil { + return cla.GetClusterName(), EndpointsUpdate{}, err + } + u.Raw = r + return cla.GetClusterName(), u, nil +} + +func parseAddress(socketAddress *v3corepb.SocketAddress) string { + return net.JoinHostPort(socketAddress.GetAddress(), strconv.Itoa(int(socketAddress.GetPortValue()))) +} + +func parseDropPolicy(dropPolicy *v3endpointpb.ClusterLoadAssignment_Policy_DropOverload) OverloadDropConfig { + percentage := dropPolicy.GetDropPercentage() + var ( + numerator = 
percentage.GetNumerator() + denominator uint32 + ) + switch percentage.GetDenominator() { + case v3typepb.FractionalPercent_HUNDRED: + denominator = 100 + case v3typepb.FractionalPercent_TEN_THOUSAND: + denominator = 10000 + case v3typepb.FractionalPercent_MILLION: + denominator = 1000000 + } + return OverloadDropConfig{ + Category: dropPolicy.GetCategory(), + Numerator: numerator, + Denominator: denominator, + } +} + +func parseEndpoints(lbEndpoints []*v3endpointpb.LbEndpoint) []Endpoint { + endpoints := make([]Endpoint, 0, len(lbEndpoints)) + for _, lbEndpoint := range lbEndpoints { + endpoints = append(endpoints, Endpoint{ + HealthStatus: EndpointHealthStatus(lbEndpoint.GetHealthStatus()), + Address: parseAddress(lbEndpoint.GetEndpoint().GetAddress().GetSocketAddress()), + Weight: lbEndpoint.GetLoadBalancingWeight().GetValue(), + }) + } + return endpoints +} + +func parseEDSRespProto(m *v3endpointpb.ClusterLoadAssignment) (EndpointsUpdate, error) { + ret := EndpointsUpdate{} + for _, dropPolicy := range m.GetPolicy().GetDropOverloads() { + ret.Drops = append(ret.Drops, parseDropPolicy(dropPolicy)) + } + priorities := make(map[uint32]struct{}) + for _, locality := range m.Endpoints { + l := locality.GetLocality() + if l == nil { + return EndpointsUpdate{}, fmt.Errorf("EDS response contains a locality without ID, locality: %+v", locality) + } + lid := internal.LocalityID{ + Region: l.Region, + Zone: l.Zone, + SubZone: l.SubZone, + } + priority := locality.GetPriority() + priorities[priority] = struct{}{} + ret.Localities = append(ret.Localities, Locality{ + ID: lid, + Endpoints: parseEndpoints(locality.GetLbEndpoints()), + Weight: locality.GetLoadBalancingWeight().GetValue(), + Priority: priority, + }) + } + for i := 0; i < len(priorities); i++ { + if _, ok := priorities[uint32(i)]; !ok { + return EndpointsUpdate{}, fmt.Errorf("priority %v missing (with different priorities %v received)", i, priorities) + } + } + return ret, nil +} + +// ListenerUpdateErrTuple is 
a tuple with the update and error. It contains the +// results from unmarshal functions. It's used to pass unmarshal results of +// multiple resources together, e.g. in maps like `map[string]{Update,error}`. +type ListenerUpdateErrTuple struct { + Update ListenerUpdate + Err error +} + +// RouteConfigUpdateErrTuple is a tuple with the update and error. It contains +// the results from unmarshal functions. It's used to pass unmarshal results of +// multiple resources together, e.g. in maps like `map[string]{Update,error}`. +type RouteConfigUpdateErrTuple struct { + Update RouteConfigUpdate + Err error +} + +// ClusterUpdateErrTuple is a tuple with the update and error. It contains the +// results from unmarshal functions. It's used to pass unmarshal results of +// multiple resources together, e.g. in maps like `map[string]{Update,error}`. +type ClusterUpdateErrTuple struct { + Update ClusterUpdate + Err error +} + +// EndpointsUpdateErrTuple is a tuple with the update and error. It contains the +// results from unmarshal functions. It's used to pass unmarshal results of +// multiple resources together, e.g. in maps like `map[string]{Update,error}`. +type EndpointsUpdateErrTuple struct { + Update EndpointsUpdate + Err error +} + +// processAllResources unmarshals and validates the resources, populates the +// provided ret (a map), and returns metadata and error. +// +// After this function, the ret map will be populated with both valid and +// invalid updates. Invalid resources will have an entry with the key as the +// resource name, value as an empty update. +// +// The type of the resource is determined by the type of ret. E.g. +// map[string]ListenerUpdate means this is for LDS. 
+func processAllResources(opts *UnmarshalOptions, ret interface{}) (UpdateMetadata, error) { + timestamp := time.Now() + md := UpdateMetadata{ + Version: opts.Version, + Timestamp: timestamp, + } + var topLevelErrors []error + perResourceErrors := make(map[string]error) + + for _, r := range opts.Resources { + switch ret2 := ret.(type) { + case map[string]ListenerUpdateErrTuple: + name, update, err := unmarshalListenerResource(r, opts.UpdateValidator, opts.Logger) + if err == nil { + ret2[name] = ListenerUpdateErrTuple{Update: update} + continue + } + if name == "" { + topLevelErrors = append(topLevelErrors, err) + continue + } + perResourceErrors[name] = err + // Add place holder in the map so we know this resource name was in + // the response. + ret2[name] = ListenerUpdateErrTuple{Err: err} + case map[string]RouteConfigUpdateErrTuple: + name, update, err := unmarshalRouteConfigResource(r, opts.Logger) + if err == nil { + ret2[name] = RouteConfigUpdateErrTuple{Update: update} + continue + } + if name == "" { + topLevelErrors = append(topLevelErrors, err) + continue + } + perResourceErrors[name] = err + // Add place holder in the map so we know this resource name was in + // the response. + ret2[name] = RouteConfigUpdateErrTuple{Err: err} + case map[string]ClusterUpdateErrTuple: + name, update, err := unmarshalClusterResource(r, opts.UpdateValidator, opts.Logger) + if err == nil { + ret2[name] = ClusterUpdateErrTuple{Update: update} + continue + } + if name == "" { + topLevelErrors = append(topLevelErrors, err) + continue + } + perResourceErrors[name] = err + // Add place holder in the map so we know this resource name was in + // the response. 
+ ret2[name] = ClusterUpdateErrTuple{Err: err} + case map[string]EndpointsUpdateErrTuple: + name, update, err := unmarshalEndpointsResource(r, opts.Logger) + if err == nil { + ret2[name] = EndpointsUpdateErrTuple{Update: update} + continue + } + if name == "" { + topLevelErrors = append(topLevelErrors, err) + continue + } + perResourceErrors[name] = err + // Add place holder in the map so we know this resource name was in + // the response. + ret2[name] = EndpointsUpdateErrTuple{Err: err} + } + } + + if len(topLevelErrors) == 0 && len(perResourceErrors) == 0 { + md.Status = ServiceStatusACKed + return md, nil + } + + var typeStr string + switch ret.(type) { + case map[string]ListenerUpdate: + typeStr = "LDS" + case map[string]RouteConfigUpdate: + typeStr = "RDS" + case map[string]ClusterUpdate: + typeStr = "CDS" + case map[string]EndpointsUpdate: + typeStr = "EDS" + } + + md.Status = ServiceStatusNACKed + errRet := combineErrors(typeStr, topLevelErrors, perResourceErrors) + md.ErrState = &UpdateErrorMetadata{ + Version: opts.Version, + Err: errRet, + Timestamp: timestamp, + } + return md, errRet +} + +func combineErrors(rType string, topLevelErrors []error, perResourceErrors map[string]error) error { + var errStrB strings.Builder + errStrB.WriteString(fmt.Sprintf("error parsing %q response: ", rType)) + if len(topLevelErrors) > 0 { + errStrB.WriteString("top level errors: ") + for i, err := range topLevelErrors { + if i != 0 { + errStrB.WriteString(";\n") + } + errStrB.WriteString(err.Error()) + } + } + if len(perResourceErrors) > 0 { + var i int + for name, err := range perResourceErrors { + if i != 0 { + errStrB.WriteString(";\n") + } + i++ + errStrB.WriteString(fmt.Sprintf("resource %q: %v", name, err.Error())) + } + } + return errors.New(errStrB.String()) +} diff --git a/xds/internal/client/tests/client_test.go b/xds/internal/xdsclient/xdsclient_test.go similarity index 92% rename from xds/internal/client/tests/client_test.go rename to 
xds/internal/xdsclient/xdsclient_test.go index f5a57fbcd21..f348df48161 100644 --- a/xds/internal/client/tests/client_test.go +++ b/xds/internal/xdsclient/xdsclient_test.go @@ -16,7 +16,7 @@ * */ -package tests_test +package xdsclient_test import ( "testing" @@ -25,11 +25,11 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/internal/grpctest" - xdsclient "google.golang.org/grpc/xds/internal/client" - "google.golang.org/grpc/xds/internal/client/bootstrap" - _ "google.golang.org/grpc/xds/internal/client/v2" // Register the v2 API client. "google.golang.org/grpc/xds/internal/testutils" "google.golang.org/grpc/xds/internal/version" + "google.golang.org/grpc/xds/internal/xdsclient" + "google.golang.org/grpc/xds/internal/xdsclient/bootstrap" + _ "google.golang.org/grpc/xds/internal/xdsclient/v2" // Register the v2 API client. ) type s struct { diff --git a/xds/server.go b/xds/server.go index f1c1e4181b8..b36fa64b500 100644 --- a/xds/server.go +++ b/xds/server.go @@ -23,89 +23,78 @@ import ( "errors" "fmt" "net" + "strings" "sync" - "time" "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/connectivity" "google.golang.org/grpc/credentials" - "google.golang.org/grpc/credentials/tls/certprovider" "google.golang.org/grpc/grpclog" "google.golang.org/grpc/internal" - xdsinternal "google.golang.org/grpc/internal/credentials/xds" + "google.golang.org/grpc/internal/buffer" internalgrpclog "google.golang.org/grpc/internal/grpclog" "google.golang.org/grpc/internal/grpcsync" - xdsclient "google.golang.org/grpc/xds/internal/client" - "google.golang.org/grpc/xds/internal/client/bootstrap" + iresolver "google.golang.org/grpc/internal/resolver" + "google.golang.org/grpc/internal/transport" + "google.golang.org/grpc/internal/xds/env" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" + "google.golang.org/grpc/xds/internal/server" + "google.golang.org/grpc/xds/internal/xdsclient" ) 
const serverPrefix = "[xds-server %p] " var ( // These new functions will be overridden in unit tests. - newXDSClient = func() (xdsClientInterface, error) { + newXDSClient = func() (xdsclient.XDSClient, error) { return xdsclient.New() } - newGRPCServer = func(opts ...grpc.ServerOption) grpcServerInterface { + newGRPCServer = func(opts ...grpc.ServerOption) grpcServer { return grpc.NewServer(opts...) } - // Unexported function to retrieve transport credentials from a gRPC server. - grpcGetServerCreds = internal.GetServerCredentials.(func(*grpc.Server) credentials.TransportCredentials) - buildProvider = buildProviderFunc - logger = grpclog.Component("xds") + grpcGetServerCreds = internal.GetServerCredentials.(func(*grpc.Server) credentials.TransportCredentials) + drainServerTransports = internal.DrainServerTransports.(func(*grpc.Server, string)) + logger = grpclog.Component("xds") ) func prefixLogger(p *GRPCServer) *internalgrpclog.PrefixLogger { return internalgrpclog.NewPrefixLogger(logger, fmt.Sprintf(serverPrefix, p)) } -// xdsClientInterface contains methods from xdsClient.Client which are used by -// the server. This is useful for overriding in unit tests. -type xdsClientInterface interface { - WatchListener(string, func(xdsclient.ListenerUpdate, error)) func() - BootstrapConfig() *bootstrap.Config - Close() -} - -// grpcServerInterface contains methods from grpc.Server which are used by the +// grpcServer contains methods from grpc.Server which are used by the // GRPCServer type here. This is useful for overriding in unit tests. -type grpcServerInterface interface { +type grpcServer interface { RegisterService(*grpc.ServiceDesc, interface{}) Serve(net.Listener) error Stop() GracefulStop() + GetServiceInfo() map[string]grpc.ServiceInfo } // GRPCServer wraps a gRPC server and provides server-side xDS functionality, by // communication with a management server using xDS APIs. 
It implements the // grpc.ServiceRegistrar interface and can be passed to service registration // functions in IDL generated code. -// -// Experimental -// -// Notice: This type is EXPERIMENTAL and may be changed or removed in a -// later release. type GRPCServer struct { - gs grpcServerInterface + gs grpcServer quit *grpcsync.Event logger *internalgrpclog.PrefixLogger xdsCredsInUse bool + opts *serverOptions // clientMu is used only in initXDSClient(), which is called at the // beginning of Serve(), where we have to decide if we have to create a // client or use an existing one. clientMu sync.Mutex - xdsC xdsClientInterface + xdsC xdsclient.XDSClient } // NewGRPCServer creates an xDS-enabled gRPC server using the passed in opts. // The underlying gRPC server has no service registered and has not started to // accept requests yet. -// -// Experimental -// -// Notice: This API is EXPERIMENTAL and may be changed or removed in a later -// release. func NewGRPCServer(opts ...grpc.ServerOption) *GRPCServer { newOpts := []grpc.ServerOption{ grpc.ChainUnaryInterceptor(xdsUnaryInterceptor), @@ -115,6 +104,7 @@ func NewGRPCServer(opts ...grpc.ServerOption) *GRPCServer { s := &GRPCServer{ gs: newGRPCServer(newOpts...), quit: grpcsync.NewEvent(), + opts: handleServerOptions(opts), } s.logger = prefixLogger(s) s.logger.Infof("Created xds.GRPCServer") @@ -135,6 +125,18 @@ func NewGRPCServer(opts ...grpc.ServerOption) *GRPCServer { return s } +// handleServerOptions iterates through the list of server options passed in by +// the user, and handles the xDS server specific options. +func handleServerOptions(opts []grpc.ServerOption) *serverOptions { + so := &serverOptions{} + for _, opt := range opts { + if o, ok := opt.(*serverOption); ok { + o.apply(so) + } + } + return so +} + // RegisterService registers a service and its implementation to the underlying // gRPC server. It is called from the IDL generated code. This must be called // before invoking Serve. 
@@ -142,6 +144,12 @@ func (s *GRPCServer) RegisterService(sd *grpc.ServiceDesc, ss interface{}) { s.gs.RegisterService(sd, ss) } +// GetServiceInfo returns a map from service names to ServiceInfo. +// Service names include the package names, in the form of .. +func (s *GRPCServer) GetServiceInfo() map[string]grpc.ServiceInfo { + return s.gs.GetServiceInfo() +} + // initXDSClient creates a new xdsClient if there is no existing one available. func (s *GRPCServer) initXDSClient() error { s.clientMu.Lock() @@ -151,6 +159,12 @@ func (s *GRPCServer) initXDSClient() error { return nil } + newXDSClient := newXDSClient + if s.opts.bootstrapContents != nil { + newXDSClient = func() (xdsclient.XDSClient, error) { + return xdsclient.NewClientWithBootstrapContents(s.opts.bootstrapContents) + } + } client, err := newXDSClient() if err != nil { return fmt.Errorf("xds: failed to create xds-client: %v", err) @@ -178,76 +192,60 @@ func (s *GRPCServer) Serve(lis net.Listener) error { if err := s.initXDSClient(); err != nil { return err } + cfg := s.xdsC.BootstrapConfig() + if cfg == nil { + return errors.New("bootstrap configuration is empty") + } // If xds credentials were specified by the user, but bootstrap configs do // not contain any certificate provider configuration, it is better to fail // right now rather than failing when attempting to create certificate // providers after receiving an LDS response with security configuration. if s.xdsCredsInUse { - bc := s.xdsC.BootstrapConfig() - if bc == nil || len(bc.CertProviderConfigs) == 0 { + if len(cfg.CertProviderConfigs) == 0 { return errors.New("xds: certificate_providers config missing in bootstrap file") } } - lw, err := s.newListenerWrapper(lis) - if lw == nil { - // Error returned can be nil (when Stop/GracefulStop() is called). So, - // we need to check the returned listenerWrapper instead. 
- return err - } - return s.gs.Serve(lw) -} - -// newListenerWrapper creates and returns a listenerWrapper, which is a thin -// wrapper around the passed in listener lis, that can be passed to -// grpcServer.Serve(). -// -// It then registers a watch for a Listener resource and blocks until a good -// response is received or the server is stopped by a call to -// Stop/GracefulStop(). -func (s *GRPCServer) newListenerWrapper(lis net.Listener) (*listenerWrapper, error) { - lw := &listenerWrapper{ - Listener: lis, - closed: grpcsync.NewEvent(), - xdsHI: xdsinternal.NewHandshakeInfo(nil, nil), - } - - // This is used to notify that a good update has been received and that - // Serve() can be invoked on the underlying gRPC server. Using a - // grpcsync.Event instead of a vanilla channel simplifies the update handler - // as it need not keep track of whether the received update is the first one - // or not. - goodUpdate := grpcsync.NewEvent() - - // The resource_name in the LDS request sent by the xDS-enabled gRPC server - // is of the following format: - // "/path/to/resource?udpa.resource.listening_address=IP:Port". The - // `/path/to/resource` part of the name is sourced from the bootstrap config - // field `grpc_server_resource_name_id`. If this field is not specified in - // the bootstrap file, we will use a default of `grpc/server`. - path := "grpc/server" - if cfg := s.xdsC.BootstrapConfig(); cfg != nil && cfg.ServerResourceNameID != "" { - path = cfg.ServerResourceNameID - } - name := fmt.Sprintf("%s?udpa.resource.listening_address=%s", path, lis.Addr().String()) - - // Register an LDS watch using our xdsClient, and specify the listening - // address as the resource name. 
- cancelWatch := s.xdsC.WatchListener(name, func(update xdsclient.ListenerUpdate, err error) { - s.handleListenerUpdate(listenerUpdate{ - lw: lw, - name: name, - lds: update, - err: err, - goodUpdate: goodUpdate, - }) + // The server listener resource name template from the bootstrap + // configuration contains a template for the name of the Listener resource + // to subscribe to for a gRPC server. If the token `%s` is present in the + // string, it will be replaced with the server's listening "IP:port" (e.g., + // "0.0.0.0:8080", "[::]:8080"). The absence of a template will be treated + // as an error since we do not have any default value for this. + if cfg.ServerListenerResourceNameTemplate == "" { + return errors.New("missing server_listener_resource_name_template in the bootstrap configuration") + } + name := cfg.ServerListenerResourceNameTemplate + if strings.Contains(cfg.ServerListenerResourceNameTemplate, "%s") { + name = strings.Replace(cfg.ServerListenerResourceNameTemplate, "%s", lis.Addr().String(), -1) + } + + modeUpdateCh := buffer.NewUnbounded() + go func() { + s.handleServingModeChanges(modeUpdateCh) + }() + + // Create a listenerWrapper which handles all functionality required by + // this particular instance of Serve(). + lw, goodUpdateCh := server.NewListenerWrapper(server.ListenerWrapperParams{ + Listener: lis, + ListenerResourceName: name, + XDSCredsInUse: s.xdsCredsInUse, + XDSClient: s.xdsC, + ModeCallback: func(addr net.Addr, mode connectivity.ServingMode, err error) { + modeUpdateCh.Put(&modeChangeArgs{ + addr: addr, + mode: mode, + err: err, + }) + }, + DrainCallback: func(addr net.Addr) { + if gs, ok := s.gs.(*grpc.Server); ok { + drainServerTransports(gs, addr.String()) + } + }, }) - s.logger.Infof("Watch started on resource name %v", name) - lw.cancelWatch = func() { - cancelWatch() - s.logger.Infof("Watch cancelled on resource name %v", name) - } // Block until a good LDS response is received or the server is stopped. 
select { @@ -256,123 +254,51 @@ func (s *GRPCServer) newListenerWrapper(lis net.Listener) (*listenerWrapper, err // need to explicitly close the listener. Cancellation of the xDS watch // is handled by the listenerWrapper. lw.Close() - return nil, nil - case <-goodUpdate.Done(): + return nil + case <-goodUpdateCh: } - return lw, nil + return s.gs.Serve(lw) } -// listenerUpdate wraps the information received from a registered LDS watcher. -type listenerUpdate struct { - lw *listenerWrapper // listener associated with this watch - name string // resource name being watched - lds xdsclient.ListenerUpdate // received update - err error // received error - goodUpdate *grpcsync.Event // event to fire upon a good update +// modeChangeArgs wraps argument required for invoking mode change callback. +type modeChangeArgs struct { + addr net.Addr + mode connectivity.ServingMode + err error } -func (s *GRPCServer) handleListenerUpdate(update listenerUpdate) { - if update.lw.closed.HasFired() { - s.logger.Warningf("Resource %q received update: %v with error: %v, after for listener was closed", update.name, update.lds, update.err) - return - } - - if update.err != nil { - // We simply log an error here and hope we get a successful update - // in the future. The error could be because of a timeout or an - // actual error, like the requested resource not found. In any case, - // it is fine for the server to hang indefinitely until Stop() is - // called. - s.logger.Warningf("Received error for resource %q: %+v", update.name, update.err) - return - } - s.logger.Infof("Received update for resource %q: %+v", update.name, update.lds.String()) - - if err := s.handleSecurityConfig(update.lds.SecurityCfg, update.lw); err != nil { - s.logger.Warningf("Invalid security config update: %v", err) - return - } - - // If we got all the way here, it means the received update was a good one. 
- update.goodUpdate.Fire() -} - -func (s *GRPCServer) handleSecurityConfig(config *xdsclient.SecurityConfig, lw *listenerWrapper) error { - // If xdsCredentials are not in use, i.e, the user did not want to get - // security configuration from the control plane, we should not be acting on - // the received security config here. Doing so poses a security threat. - if !s.xdsCredsInUse { - return nil - } - - // Security config being nil is a valid case where the control plane has - // not sent any security configuration. The xdsCredentials implementation - // handles this by delegating to its fallback credentials. - if config == nil { - // We need to explicitly set the fields to nil here since this might be - // a case of switching from a good security configuration to an empty - // one where fallback credentials are to be used. - lw.xdsHI.SetRootCertProvider(nil) - lw.xdsHI.SetIdentityCertProvider(nil) - lw.xdsHI.SetRequireClientCert(false) - return nil - } - - cpc := s.xdsC.BootstrapConfig().CertProviderConfigs - // Identity provider is mandatory on the server side. - identityProvider, err := buildProvider(cpc, config.IdentityInstanceName, config.IdentityCertName, true, false) - if err != nil { - return err - } - - // A root provider is required only when doing mTLS. - var rootProvider certprovider.Provider - if config.RootInstanceName != "" { - rootProvider, err = buildProvider(cpc, config.RootInstanceName, config.RootCertName, false, true) - if err != nil { - return err +// handleServingModeChanges runs as a separate goroutine, spawned from Serve(). +// It reads a channel on to which mode change arguments are pushed, and in turn +// invokes the user registered callback. It also calls an internal method on the +// underlying grpc.Server to gracefully close existing connections, if the +// listener moved to a "not-serving" mode. 
+func (s *GRPCServer) handleServingModeChanges(updateCh *buffer.Unbounded) { + for { + select { + case <-s.quit.Done(): + return + case u := <-updateCh.Get(): + updateCh.Load() + args := u.(*modeChangeArgs) + if args.mode == connectivity.ServingModeNotServing { + // We type assert our underlying gRPC server to the real + // grpc.Server here before trying to initiate the drain + // operation. This approach avoids performing the same type + // assertion in the grpc package which provides the + // implementation for internal.GetServerCredentials, and allows + // us to use a fake gRPC server in tests. + if gs, ok := s.gs.(*grpc.Server); ok { + drainServerTransports(gs, args.addr.String()) + } + } + if s.opts.modeCallback != nil { + s.opts.modeCallback(args.addr, ServingModeChangeArgs{ + Mode: args.mode, + Err: args.err, + }) + } } } - - // Close the old providers and cache the new ones. - lw.providerMu.Lock() - if lw.cachedIdentity != nil { - lw.cachedIdentity.Close() - } - if lw.cachedRoot != nil { - lw.cachedRoot.Close() - } - lw.cachedRoot = rootProvider - lw.cachedIdentity = identityProvider - - // We set all fields here, even if some of them are nil, since they - // could have been non-nil earlier. 
- lw.xdsHI.SetRootCertProvider(rootProvider) - lw.xdsHI.SetIdentityCertProvider(identityProvider) - lw.xdsHI.SetRequireClientCert(config.RequireClientCert) - lw.providerMu.Unlock() - - return nil -} - -func buildProviderFunc(configs map[string]*certprovider.BuildableConfig, instanceName, certName string, wantIdentity, wantRoot bool) (certprovider.Provider, error) { - cfg, ok := configs[instanceName] - if !ok { - return nil, fmt.Errorf("certificate provider instance %q not found in bootstrap file", instanceName) - } - provider, err := cfg.Build(certprovider.BuildOptions{ - CertName: certName, - WantIdentity: wantIdentity, - WantRoot: wantRoot, - }) - if err != nil { - // This error is not expected since the bootstrap process parses the - // config and makes sure that it is acceptable to the plugin. Still, it - // is possible that the plugin parses the config successfully, but its - // Build() method errors out. - return nil, fmt.Errorf("failed to get security plugin instance (%+v): %v", cfg, err) - } - return provider, nil } // Stop stops the underlying gRPC server. It immediately closes all open @@ -398,120 +324,79 @@ func (s *GRPCServer) GracefulStop() { } } +// routeAndProcess routes the incoming RPC to a configured route in the route +// table and also processes the RPC by running the incoming RPC through any HTTP +// Filters configured. 
+func routeAndProcess(ctx context.Context) error { + conn := transport.GetConnection(ctx) + cw, ok := conn.(interface { + VirtualHosts() []xdsclient.VirtualHostWithInterceptors + }) + if !ok { + return errors.New("missing virtual hosts in incoming context") + } + mn, ok := grpc.Method(ctx) + if !ok { + return errors.New("missing method name in incoming context") + } + md, ok := metadata.FromIncomingContext(ctx) + if !ok { + return errors.New("missing metadata in incoming context") + } + // A41 added logic to the core grpc implementation to guarantee that once + // the RPC gets to this point, there will be a single, unambiguous authority + // present in the header map. + authority := md.Get(":authority") + vh := xdsclient.FindBestMatchingVirtualHostServer(authority[0], cw.VirtualHosts()) + if vh == nil { + return status.Error(codes.Unavailable, "the incoming RPC did not match a configured Virtual Host") + } + + var rwi *xdsclient.RouteWithInterceptors + rpcInfo := iresolver.RPCInfo{ + Context: ctx, + Method: mn, + } + for _, r := range vh.Routes { + if r.M.Match(rpcInfo) { + // "NonForwardingAction is expected for all Routes used on server-side; a route with an inappropriate action causes + // RPCs matching that route to fail with UNAVAILABLE." - A36 + if r.RouteAction != xdsclient.RouteActionNonForwardingAction { + return status.Error(codes.Unavailable, "the incoming RPC matched to a route that was not of action type non forwarding") + } + rwi = &r + break + } + } + if rwi == nil { + return status.Error(codes.Unavailable, "the incoming RPC did not match a configured Route") + } + for _, interceptor := range rwi.Interceptors { + if err := interceptor.AllowRPC(ctx); err != nil { + return status.Errorf(codes.PermissionDenied, "Incoming RPC is not allowed: %v", err) + } + } + return nil +} + // xdsUnaryInterceptor is the unary interceptor added to the gRPC server to // perform any xDS specific functionality on unary RPCs. -// -// This is a no-op at this point. 
func xdsUnaryInterceptor(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) { + if env.RBACSupport { + if err := routeAndProcess(ctx); err != nil { + return nil, err + } + } return handler(ctx, req) } // xdsStreamInterceptor is the stream interceptor added to the gRPC server to // perform any xDS specific functionality on streaming RPCs. -// -// This is a no-op at this point. func xdsStreamInterceptor(srv interface{}, ss grpc.ServerStream, _ *grpc.StreamServerInfo, handler grpc.StreamHandler) error { - return handler(srv, ss) -} - -// listenerWrapper wraps the net.Listener associated with the listening address -// passed to Serve(). It also contains all other state associated with this -// particular invocation of Serve(). -type listenerWrapper struct { - net.Listener - cancelWatch func() - - // A small race exists in the xdsClient code where an xDS response is - // received while the user is calling cancel(). In this small window the - // registered callback can be called after the watcher is canceled. We avoid - // processing updates received in callbacks once the listener is closed, to - // make sure that we do not process updates received during this race - // window. - closed *grpcsync.Event - - // The certificate providers are cached here to that they can be closed when - // a new provider is to be created. - providerMu sync.Mutex - cachedRoot certprovider.Provider - cachedIdentity certprovider.Provider - - // Wraps all information required by the xds handshaker. - xdsHI *xdsinternal.HandshakeInfo -} - -// Accept blocks on an Accept() on the underlying listener, and wraps the -// returned net.Conn with the configured certificate providers. -func (l *listenerWrapper) Accept() (net.Conn, error) { - c, err := l.Listener.Accept() - if err != nil { - return nil, err - } - return &conn{Conn: c, xdsHI: l.xdsHI}, nil -} - -// Close closes the underlying listener. 
It also cancels the xDS watch -// registered in Serve() and closes any certificate provider instances created -// based on security configuration received in the LDS response. -func (l *listenerWrapper) Close() error { - l.closed.Fire() - l.Listener.Close() - if l.cancelWatch != nil { - l.cancelWatch() - } - - l.providerMu.Lock() - if l.cachedIdentity != nil { - l.cachedIdentity.Close() - l.cachedIdentity = nil - } - if l.cachedRoot != nil { - l.cachedRoot.Close() - l.cachedRoot = nil + if env.RBACSupport { + if err := routeAndProcess(ss.Context()); err != nil { + return err + } } - l.providerMu.Unlock() - - return nil -} - -// conn is a thin wrapper around a net.Conn returned by Accept(). -type conn struct { - net.Conn - - // This is the same HandshakeInfo as stored in the listenerWrapper that - // created this conn. The former updates the HandshakeInfo whenever it - // receives new security configuration. - xdsHI *xdsinternal.HandshakeInfo - - // The connection deadline as configured by the grpc.Server on the rawConn - // that is returned by a call to Accept(). This is set to the connection - // timeout value configured by the user (or to a default value) before - // initiating the transport credential handshake, and set to zero after - // completing the HTTP2 handshake. - deadlineMu sync.Mutex - deadline time.Time -} - -// SetDeadline makes a copy of the passed in deadline and forwards the call to -// the underlying rawConn. -func (c *conn) SetDeadline(t time.Time) error { - c.deadlineMu.Lock() - c.deadline = t - c.deadlineMu.Unlock() - return c.Conn.SetDeadline(t) -} - -// GetDeadline returns the configured deadline. This will be invoked by the -// ServerHandshake() method of the XdsCredentials, which needs a deadline to -// pass to the certificate provider. -func (c *conn) GetDeadline() time.Time { - c.deadlineMu.Lock() - t := c.deadline - c.deadlineMu.Unlock() - return t -} - -// XDSHandshakeInfo returns a pointer to the HandshakeInfo stored in conn. 
This -// will be invoked by the ServerHandshake() method of the XdsCredentials. -func (c *conn) XDSHandshakeInfo() *xdsinternal.HandshakeInfo { - return c.xdsHI + return handler(srv, ss) } diff --git a/xds/server_options.go b/xds/server_options.go new file mode 100644 index 00000000000..1d46c3adb7b --- /dev/null +++ b/xds/server_options.go @@ -0,0 +1,76 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package xds + +import ( + "net" + + "google.golang.org/grpc" + "google.golang.org/grpc/connectivity" +) + +type serverOptions struct { + modeCallback ServingModeCallbackFunc + bootstrapContents []byte +} + +type serverOption struct { + grpc.EmptyServerOption + apply func(*serverOptions) +} + +// ServingModeCallback returns a grpc.ServerOption which allows users to +// register a callback to get notified about serving mode changes. +func ServingModeCallback(cb ServingModeCallbackFunc) grpc.ServerOption { + return &serverOption{apply: func(o *serverOptions) { o.modeCallback = cb }} +} + +// ServingModeCallbackFunc is the callback that users can register to get +// notified about the server's serving mode changes. The callback is invoked +// with the address of the listener and its new mode. +// +// Users must not perform any blocking operations in this callback. 
+type ServingModeCallbackFunc func(addr net.Addr, args ServingModeChangeArgs) + +// ServingModeChangeArgs wraps the arguments passed to the serving mode callback +// function. +type ServingModeChangeArgs struct { + // Mode is the new serving mode of the server listener. + Mode connectivity.ServingMode + // Err is set to a non-nil error if the server has transitioned into + // not-serving mode. + Err error +} + +// BootstrapContentsForTesting returns a grpc.ServerOption which allows users +// to inject a bootstrap configuration used by only this server, instead of the +// global configuration from the environment variables. +// +// Testing Only +// +// This function should ONLY be used for testing and may not work with some +// other features, including the CSDS service. +// +// Experimental +// +// Notice: This API is EXPERIMENTAL and may be changed or removed in a +// later release. +func BootstrapContentsForTesting(contents []byte) grpc.ServerOption { + return &serverOption{apply: func(o *serverOptions) { o.bootstrapContents = contents }} +} diff --git a/xds/server_test.go b/xds/server_test.go index cde96307926..0866e0414ae 100644 --- a/xds/server_test.go +++ b/xds/server_test.go @@ -24,27 +24,99 @@ import ( "fmt" "net" "reflect" + "strings" "testing" "time" + v3corepb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" + v3listenerpb "github.com/envoyproxy/go-control-plane/envoy/config/listener/v3" + v3routepb "github.com/envoyproxy/go-control-plane/envoy/config/route/v3" + v3httppb "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/network/http_connection_manager/v3" + v3tlspb "github.com/envoyproxy/go-control-plane/envoy/extensions/transport_sockets/tls/v3" + wrapperspb "github.com/golang/protobuf/ptypes/wrappers" "google.golang.org/grpc" + "google.golang.org/grpc/connectivity" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/credentials/tls/certprovider" "google.golang.org/grpc/credentials/xds" 
"google.golang.org/grpc/internal/grpctest" "google.golang.org/grpc/internal/testutils" - xdsclient "google.golang.org/grpc/xds/internal/client" - "google.golang.org/grpc/xds/internal/client/bootstrap" + _ "google.golang.org/grpc/xds/internal/httpfilter/router" xdstestutils "google.golang.org/grpc/xds/internal/testutils" + "google.golang.org/grpc/xds/internal/testutils/e2e" "google.golang.org/grpc/xds/internal/testutils/fakeclient" + "google.golang.org/grpc/xds/internal/xdsclient" + "google.golang.org/grpc/xds/internal/xdsclient/bootstrap" ) const ( - defaultTestTimeout = 5 * time.Second - defaultTestShortTimeout = 10 * time.Millisecond - testServerResourceNameID = "/path/to/resource" + defaultTestTimeout = 5 * time.Second + defaultTestShortTimeout = 10 * time.Millisecond + testServerListenerResourceNameTemplate = "/path/to/resource/%s/%s" ) +var listenerWithFilterChains = &v3listenerpb.Listener{ + FilterChains: []*v3listenerpb.FilterChain{ + { + FilterChainMatch: &v3listenerpb.FilterChainMatch{ + PrefixRanges: []*v3corepb.CidrRange{ + { + AddressPrefix: "192.168.0.0", + PrefixLen: &wrapperspb.UInt32Value{ + Value: uint32(16), + }, + }, + }, + SourceType: v3listenerpb.FilterChainMatch_SAME_IP_OR_LOOPBACK, + SourcePrefixRanges: []*v3corepb.CidrRange{ + { + AddressPrefix: "192.168.0.0", + PrefixLen: &wrapperspb.UInt32Value{ + Value: uint32(16), + }, + }, + }, + SourcePorts: []uint32{80}, + }, + TransportSocket: &v3corepb.TransportSocket{ + Name: "envoy.transport_sockets.tls", + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3tlspb.DownstreamTlsContext{ + CommonTlsContext: &v3tlspb.CommonTlsContext{ + TlsCertificateCertificateProviderInstance: &v3tlspb.CommonTlsContext_CertificateProviderInstance{ + InstanceName: "identityPluginInstance", + CertificateName: "identityCertName", + }, + }, + }), + }, + }, + Filters: []*v3listenerpb.Filter{ + { + Name: "filter-1", + ConfigType: &v3listenerpb.Filter_TypedConfig{ + TypedConfig: 
testutils.MarshalAny(&v3httppb.HttpConnectionManager{ + RouteSpecifier: &v3httppb.HttpConnectionManager_RouteConfig{ + RouteConfig: &v3routepb.RouteConfiguration{ + Name: "routeName", + VirtualHosts: []*v3routepb.VirtualHost{{ + Domains: []string{"lds.target.good:3333"}, + Routes: []*v3routepb.Route{{ + Match: &v3routepb.RouteMatch{ + PathSpecifier: &v3routepb.RouteMatch_Prefix{Prefix: "/"}, + }, + Action: &v3routepb.Route_NonForwardingAction{}, + }}}}}, + }, + HttpFilters: []*v3httppb.HttpFilter{e2e.RouterHTTPFilter}, + }), + }, + }, + }, + }, + }, +} + type s struct { grpctest.Tester } @@ -65,9 +137,10 @@ func (f *fakeGRPCServer) RegisterService(*grpc.ServiceDesc, interface{}) { f.registerServiceCh.Send(nil) } -func (f *fakeGRPCServer) Serve(net.Listener) error { +func (f *fakeGRPCServer) Serve(lis net.Listener) error { f.serveCh.Send(nil) <-f.done + lis.Close() return nil } @@ -80,6 +153,10 @@ func (f *fakeGRPCServer) GracefulStop() { f.gracefulStopCh.Send(nil) } +func (f *fakeGRPCServer) GetServiceInfo() map[string]grpc.ServiceInfo { + panic("implement me") +} + func newFakeGRPCServer() *fakeGRPCServer { return &fakeGRPCServer{ done: make(chan struct{}), @@ -90,6 +167,14 @@ func newFakeGRPCServer() *fakeGRPCServer { } } +func splitHostPort(hostport string) (string, string) { + addr, port, err := net.SplitHostPort(hostport) + if err != nil { + panic(fmt.Sprintf("listener address %q does not parse: %v", hostport, err)) + } + return addr, port +} + func (s) TestNewServer(t *testing.T) { xdsCreds, err := xds.NewServerCredentials(xds.ServerOptions{FallbackCreds: insecure.NewCredentials()}) if err != nil { @@ -119,7 +204,7 @@ func (s) TestNewServer(t *testing.T) { wantServerOpts := len(test.serverOpts) + 2 origNewGRPCServer := newGRPCServer - newGRPCServer = func(opts ...grpc.ServerOption) grpcServerInterface { + newGRPCServer = func(opts ...grpc.ServerOption) grpcServer { if got := len(opts); got != wantServerOpts { t.Fatalf("%d ServerOptions passed to grpc.Server, 
want %d", got, wantServerOpts) } @@ -147,7 +232,7 @@ func (s) TestRegisterService(t *testing.T) { fs := newFakeGRPCServer() origNewGRPCServer := newGRPCServer - newGRPCServer = func(opts ...grpc.ServerOption) grpcServerInterface { return fs } + newGRPCServer = func(opts ...grpc.ServerOption) grpcServer { return fs } defer func() { newGRPCServer = origNewGRPCServer }() s := NewGRPCServer() @@ -173,8 +258,14 @@ var ( ) func init() { - fpb1 = &fakeProviderBuilder{name: fakeProvider1Name} - fpb2 = &fakeProviderBuilder{name: fakeProvider2Name} + fpb1 = &fakeProviderBuilder{ + name: fakeProvider1Name, + buildCh: testutils.NewChannel(), + } + fpb2 = &fakeProviderBuilder{ + name: fakeProvider2Name, + buildCh: testutils.NewChannel(), + } cfg1, _ := fpb1.ParseConfig(fakeConfig + "1111") cfg2, _ := fpb2.ParseConfig(fakeConfig + "2222") certProviderConfigs = map[string]*certprovider.BuildableConfig{ @@ -188,7 +279,8 @@ func init() { // fakeProviderBuilder builds new instances of fakeProvider and interprets the // config provided to it as a string. type fakeProviderBuilder struct { - name string + name string + buildCh *testutils.Channel } func (b *fakeProviderBuilder) ParseConfig(config interface{}) (*certprovider.BuildableConfig, error) { @@ -197,6 +289,7 @@ func (b *fakeProviderBuilder) ParseConfig(config interface{}) (*certprovider.Bui return nil, fmt.Errorf("providerBuilder %s received config of type %T, want string", b.name, config) } return certprovider.NewBuildableConfig(b.name, []byte(s), func(certprovider.BuildOptions) certprovider.Provider { + b.buildCh.Send(nil) return &fakeProvider{ Distributor: certprovider.NewDistributor(), config: s, @@ -220,19 +313,19 @@ func (p *fakeProvider) Close() { p.Distributor.Stop() } -// setupOverrides sets up overrides for bootstrap config, new xdsClient creation, -// new gRPC.Server creation, and certificate provider creation. 
-func setupOverrides() (*fakeGRPCServer, *testutils.Channel, *testutils.Channel, func()) { +// setupOverrides sets up overrides for bootstrap config, new xdsClient creation +// and new gRPC.Server creation. +func setupOverrides() (*fakeGRPCServer, *testutils.Channel, func()) { clientCh := testutils.NewChannel() origNewXDSClient := newXDSClient - newXDSClient = func() (xdsClientInterface, error) { + newXDSClient = func() (xdsclient.XDSClient, error) { c := fakeclient.NewClient() c.SetBootstrapConfig(&bootstrap.Config{ - BalancerName: "dummyBalancer", - Creds: grpc.WithTransportCredentials(insecure.NewCredentials()), - NodeProto: xdstestutils.EmptyNodeProtoV3, - ServerResourceNameID: testServerResourceNameID, - CertProviderConfigs: certProviderConfigs, + BalancerName: "dummyBalancer", + Creds: grpc.WithTransportCredentials(insecure.NewCredentials()), + NodeProto: xdstestutils.EmptyNodeProtoV3, + ServerListenerResourceNameTemplate: testServerListenerResourceNameTemplate, + CertProviderConfigs: certProviderConfigs, }) clientCh.Send(c) return c, nil @@ -240,20 +333,11 @@ func setupOverrides() (*fakeGRPCServer, *testutils.Channel, *testutils.Channel, fs := newFakeGRPCServer() origNewGRPCServer := newGRPCServer - newGRPCServer = func(opts ...grpc.ServerOption) grpcServerInterface { return fs } + newGRPCServer = func(opts ...grpc.ServerOption) grpcServer { return fs } - providerCh := testutils.NewChannel() - origBuildProvider := buildProvider - buildProvider = func(c map[string]*certprovider.BuildableConfig, id, cert string, wi, wr bool) (certprovider.Provider, error) { - p, err := origBuildProvider(c, id, cert, wi, wr) - providerCh.Send(nil) - return p, err - } - - return fs, clientCh, providerCh, func() { + return fs, clientCh, func() { newXDSClient = origNewXDSClient newGRPCServer = origNewGRPCServer - buildProvider = origBuildProvider } } @@ -261,16 +345,16 @@ func setupOverrides() (*fakeGRPCServer, *testutils.Channel, *testutils.Channel, // one. 
Tests that use xdsCredentials need a real grpc.Server instead of a fake // one, because the xDS-enabled server needs to read configured creds from the // underlying grpc.Server to confirm whether xdsCreds were configured. -func setupOverridesForXDSCreds(includeCertProviderCfg bool) (*testutils.Channel, *testutils.Channel, func()) { +func setupOverridesForXDSCreds(includeCertProviderCfg bool) (*testutils.Channel, func()) { clientCh := testutils.NewChannel() origNewXDSClient := newXDSClient - newXDSClient = func() (xdsClientInterface, error) { + newXDSClient = func() (xdsclient.XDSClient, error) { c := fakeclient.NewClient() bc := &bootstrap.Config{ - BalancerName: "dummyBalancer", - Creds: grpc.WithTransportCredentials(insecure.NewCredentials()), - NodeProto: xdstestutils.EmptyNodeProtoV3, - ServerResourceNameID: testServerResourceNameID, + BalancerName: "dummyBalancer", + Creds: grpc.WithTransportCredentials(insecure.NewCredentials()), + NodeProto: xdstestutils.EmptyNodeProtoV3, + ServerListenerResourceNameTemplate: testServerListenerResourceNameTemplate, } if includeCertProviderCfg { bc.CertProviderConfigs = certProviderConfigs @@ -280,18 +364,7 @@ func setupOverridesForXDSCreds(includeCertProviderCfg bool) (*testutils.Channel, return c, nil } - providerCh := testutils.NewChannel() - origBuildProvider := buildProvider - buildProvider = func(c map[string]*certprovider.BuildableConfig, id, cert string, wi, wr bool) (certprovider.Provider, error) { - p, err := origBuildProvider(c, id, cert, wi, wr) - providerCh.Send(nil) - return p, err - } - - return clientCh, providerCh, func() { - newXDSClient = origNewXDSClient - buildProvider = origBuildProvider - } + return clientCh, func() { newXDSClient = origNewXDSClient } } // TestServeSuccess tests the successful case of calling Serve(). @@ -303,10 +376,17 @@ func setupOverridesForXDSCreds(includeCertProviderCfg bool) (*testutils.Channel, // 4. 
Push a good response from the xdsClient, and make sure that Serve() on the // underlying grpc.Server is called. func (s) TestServeSuccess(t *testing.T) { - fs, clientCh, _, cleanup := setupOverrides() + fs, clientCh, cleanup := setupOverrides() defer cleanup() - server := NewGRPCServer() + // Create a new xDS-enabled gRPC server and pass it a server option to get + // notified about serving mode changes. + modeChangeCh := testutils.NewChannel() + modeChangeOption := ServingModeCallback(func(addr net.Addr, args ServingModeChangeArgs) { + t.Logf("server mode change callback invoked for listener %q with mode %q and error %v", addr.String(), args.Mode, args.Err) + modeChangeCh.Send(args.Mode) + }) + server := NewGRPCServer(modeChangeOption) defer server.Stop() lis, err := xdstestutils.LocalTCPListener() @@ -337,33 +417,89 @@ func (s) TestServeSuccess(t *testing.T) { if err != nil { t.Fatalf("error when waiting for a ListenerWatch: %v", err) } - wantName := fmt.Sprintf("%s?udpa.resource.listening_address=%s", client.BootstrapConfig().ServerResourceNameID, lis.Addr().String()) + wantName := strings.Replace(testServerListenerResourceNameTemplate, "%s", lis.Addr().String(), -1) if name != wantName { t.Fatalf("LDS watch registered for name %q, want %q", name, wantName) } // Push an error to the registered listener watch callback and make sure // that Serve does not return. - client.InvokeWatchListenerCallback(xdsclient.ListenerUpdate{}, errors.New("LDS error")) + client.InvokeWatchListenerCallback(xdsclient.ListenerUpdate{}, xdsclient.NewErrorf(xdsclient.ErrorTypeResourceNotFound, "LDS resource not found")) sCtx, sCancel := context.WithTimeout(context.Background(), defaultTestShortTimeout) defer sCancel() if _, err := serveDone.Receive(sCtx); err != context.DeadlineExceeded { t.Fatal("Serve() returned after a bad LDS response") } + // Make sure the serving mode changes appropriately. 
+ v, err := modeChangeCh.Receive(ctx) + if err != nil { + t.Fatalf("error when waiting for serving mode to change: %v", err) + } + if mode := v.(connectivity.ServingMode); mode != connectivity.ServingModeNotServing { + t.Fatalf("server mode is %q, want %q", mode, connectivity.ServingModeNotServing) + } + // Push a good LDS response, and wait for Serve() to be invoked on the // underlying grpc.Server. - client.InvokeWatchListenerCallback(xdsclient.ListenerUpdate{RouteConfigName: "routeconfig"}, nil) + fcm, err := xdsclient.NewFilterChainManager(listenerWithFilterChains) + if err != nil { + t.Fatalf("xdsclient.NewFilterChainManager() failed with error: %v", err) + } + addr, port := splitHostPort(lis.Addr().String()) + client.InvokeWatchListenerCallback(xdsclient.ListenerUpdate{ + RouteConfigName: "routeconfig", + InboundListenerCfg: &xdsclient.InboundListenerConfig{ + Address: addr, + Port: port, + FilterChains: fcm, + }, + }, nil) if _, err := fs.serveCh.Receive(ctx); err != nil { t.Fatalf("error when waiting for Serve() to be invoked on the grpc.Server") } + + // Make sure the serving mode changes appropriately. + v, err = modeChangeCh.Receive(ctx) + if err != nil { + t.Fatalf("error when waiting for serving mode to change: %v", err) + } + if mode := v.(connectivity.ServingMode); mode != connectivity.ServingModeServing { + t.Fatalf("server mode is %q, want %q", mode, connectivity.ServingModeServing) + } + + // Push an update to the registered listener watch callback with a Listener + // resource whose host:port does not match the actual listening address and + // port. This will push the listener to "not-serving" mode. 
+ client.InvokeWatchListenerCallback(xdsclient.ListenerUpdate{ + RouteConfigName: "routeconfig", + InboundListenerCfg: &xdsclient.InboundListenerConfig{ + Address: "10.20.30.40", + Port: "666", + FilterChains: fcm, + }, + }, nil) + sCtx, sCancel = context.WithTimeout(context.Background(), defaultTestShortTimeout) + defer sCancel() + if _, err := serveDone.Receive(sCtx); err != context.DeadlineExceeded { + t.Fatal("Serve() returned after a bad LDS response") + } + + // Make sure the serving mode changes appropriately. + v, err = modeChangeCh.Receive(ctx) + if err != nil { + t.Fatalf("error when waiting for serving mode to change: %v", err) + } + if mode := v.(connectivity.ServingMode); mode != connectivity.ServingModeNotServing { + t.Fatalf("server mode is %q, want %q", mode, connectivity.ServingModeNotServing) + } } // TestServeWithStop tests the case where Stop() is called before an LDS update // is received. This should cause Serve() to exit before calling Serve() on the // underlying grpc.Server. func (s) TestServeWithStop(t *testing.T) { - fs, clientCh, _, cleanup := setupOverrides() + fs, clientCh, cleanup := setupOverrides() defer cleanup() // Note that we are not deferring the Stop() here since we explicitly call @@ -399,7 +535,7 @@ func (s) TestServeWithStop(t *testing.T) { server.Stop() t.Fatalf("error when waiting for a ListenerWatch: %v", err) } - wantName := fmt.Sprintf("%s?udpa.resource.listening_address=%s", client.BootstrapConfig().ServerResourceNameID, lis.Addr().String()) + wantName := strings.Replace(testServerListenerResourceNameTemplate, "%s", lis.Addr().String(), -1) if name != wantName { server.Stop() t.Fatalf("LDS watch registered for name %q, wantPrefix %q", name, wantName) @@ -448,39 +584,79 @@ func (s) TestServeBootstrapFailure(t *testing.T) { } } -// TestServeBootstrapWithMissingCertProviders tests the case where the bootstrap -// config does not contain certificate provider configuration, but xdsCreds are -// passed to the server. 
Verifies that the call to Serve() fails. -func (s) TestServeBootstrapWithMissingCertProviders(t *testing.T) { - _, _, cleanup := setupOverridesForXDSCreds(false) - defer cleanup() - - xdsCreds, err := xds.NewServerCredentials(xds.ServerOptions{FallbackCreds: insecure.NewCredentials()}) - if err != nil { - t.Fatalf("failed to create xds server credentials: %v", err) +// TestServeBootstrapConfigInvalid tests the cases where the bootstrap config +// does not contain expected fields. Verifies that the call to Serve() fails. +func (s) TestServeBootstrapConfigInvalid(t *testing.T) { + tests := []struct { + desc string + bootstrapConfig *bootstrap.Config + }{ + { + desc: "bootstrap config is missing", + bootstrapConfig: nil, + }, + { + desc: "certificate provider config is missing", + bootstrapConfig: &bootstrap.Config{ + BalancerName: "dummyBalancer", + Creds: grpc.WithTransportCredentials(insecure.NewCredentials()), + NodeProto: xdstestutils.EmptyNodeProtoV3, + ServerListenerResourceNameTemplate: testServerListenerResourceNameTemplate, + }, + }, + { + desc: "server_listener_resource_name_template is missing", + bootstrapConfig: &bootstrap.Config{ + BalancerName: "dummyBalancer", + Creds: grpc.WithTransportCredentials(insecure.NewCredentials()), + NodeProto: xdstestutils.EmptyNodeProtoV3, + CertProviderConfigs: certProviderConfigs, + }, + }, } - server := NewGRPCServer(grpc.Creds(xdsCreds)) - defer server.Stop() - lis, err := xdstestutils.LocalTCPListener() - if err != nil { - t.Fatalf("xdstestutils.LocalTCPListener() failed: %v", err) - } + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + // Override the xdsClient creation with one that returns a fake + // xdsClient with the specified bootstrap configuration. 
+ clientCh := testutils.NewChannel() + origNewXDSClient := newXDSClient + newXDSClient = func() (xdsclient.XDSClient, error) { + c := fakeclient.NewClient() + c.SetBootstrapConfig(test.bootstrapConfig) + clientCh.Send(c) + return c, nil + } + defer func() { newXDSClient = origNewXDSClient }() - serveDone := testutils.NewChannel() - go func() { - err := server.Serve(lis) - serveDone.Send(err) - }() + xdsCreds, err := xds.NewServerCredentials(xds.ServerOptions{FallbackCreds: insecure.NewCredentials()}) + if err != nil { + t.Fatalf("failed to create xds server credentials: %v", err) + } + server := NewGRPCServer(grpc.Creds(xdsCreds)) + defer server.Stop() - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() - v, err := serveDone.Receive(ctx) - if err != nil { - t.Fatalf("error when waiting for Serve() to exit: %v", err) - } - if err, ok := v.(error); !ok || err == nil { - t.Fatal("Serve() did not exit with error") + lis, err := xdstestutils.LocalTCPListener() + if err != nil { + t.Fatalf("xdstestutils.LocalTCPListener() failed: %v", err) + } + + serveDone := testutils.NewChannel() + go func() { + err := server.Serve(lis) + serveDone.Send(err) + }() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + v, err := serveDone.Receive(ctx) + if err != nil { + t.Fatalf("error when waiting for Serve() to exit: %v", err) + } + if err, ok := v.(error); !ok || err == nil { + t.Fatal("Serve() did not exit with error") + } + }) } } @@ -488,7 +664,7 @@ func (s) TestServeBootstrapWithMissingCertProviders(t *testing.T) { // verifies that Server() exits with a non-nil error. 
func (s) TestServeNewClientFailure(t *testing.T) { origNewXDSClient := newXDSClient - newXDSClient = func() (xdsClientInterface, error) { + newXDSClient = func() (xdsclient.XDSClient, error) { return nil, errors.New("xdsClient creation failed") } defer func() { newXDSClient = origNewXDSClient }() @@ -522,7 +698,7 @@ func (s) TestServeNewClientFailure(t *testing.T) { // server is not configured with xDS credentials. Verifies that the security // config received as part of a Listener update is not acted upon. func (s) TestHandleListenerUpdate_NoXDSCreds(t *testing.T) { - fs, clientCh, providerCh, cleanup := setupOverrides() + fs, clientCh, cleanup := setupOverrides() defer cleanup() server := NewGRPCServer() @@ -556,7 +732,7 @@ func (s) TestHandleListenerUpdate_NoXDSCreds(t *testing.T) { if err != nil { t.Fatalf("error when waiting for a ListenerWatch: %v", err) } - wantName := fmt.Sprintf("%s?udpa.resource.listening_address=%s", client.BootstrapConfig().ServerResourceNameID, lis.Addr().String()) + wantName := strings.Replace(testServerListenerResourceNameTemplate, "%s", lis.Addr().String(), -1) if name != wantName { t.Fatalf("LDS watch registered for name %q, want %q", name, wantName) } @@ -564,12 +740,57 @@ func (s) TestHandleListenerUpdate_NoXDSCreds(t *testing.T) { // Push a good LDS response with security config, and wait for Serve() to be // invoked on the underlying grpc.Server. Also make sure that certificate // providers are not created. 
+ fcm, err := xdsclient.NewFilterChainManager(&v3listenerpb.Listener{ + FilterChains: []*v3listenerpb.FilterChain{ + { + TransportSocket: &v3corepb.TransportSocket{ + Name: "envoy.transport_sockets.tls", + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3tlspb.DownstreamTlsContext{ + CommonTlsContext: &v3tlspb.CommonTlsContext{ + TlsCertificateCertificateProviderInstance: &v3tlspb.CommonTlsContext_CertificateProviderInstance{ + InstanceName: "identityPluginInstance", + CertificateName: "identityCertName", + }, + }, + }), + }, + }, + Filters: []*v3listenerpb.Filter{ + { + Name: "filter-1", + ConfigType: &v3listenerpb.Filter_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3httppb.HttpConnectionManager{ + RouteSpecifier: &v3httppb.HttpConnectionManager_RouteConfig{ + RouteConfig: &v3routepb.RouteConfiguration{ + Name: "routeName", + VirtualHosts: []*v3routepb.VirtualHost{{ + Domains: []string{"lds.target.good:3333"}, + Routes: []*v3routepb.Route{{ + Match: &v3routepb.RouteMatch{ + PathSpecifier: &v3routepb.RouteMatch_Prefix{Prefix: "/"}, + }, + Action: &v3routepb.Route_NonForwardingAction{}, + }}}}}, + }, + HttpFilters: []*v3httppb.HttpFilter{e2e.RouterHTTPFilter}, + }), + }, + }, + }, + }, + }, + }) + if err != nil { + t.Fatalf("xdsclient.NewFilterChainManager() failed with error: %v", err) + } + addr, port := splitHostPort(lis.Addr().String()) client.InvokeWatchListenerCallback(xdsclient.ListenerUpdate{ RouteConfigName: "routeconfig", - SecurityCfg: &xdsclient.SecurityConfig{ - RootInstanceName: "default1", - IdentityInstanceName: "default2", - RequireClientCert: true, + InboundListenerCfg: &xdsclient.InboundListenerConfig{ + Address: addr, + Port: port, + FilterChains: fcm, }, }, nil) if _, err := fs.serveCh.Receive(ctx); err != nil { @@ -577,10 +798,8 @@ func (s) TestHandleListenerUpdate_NoXDSCreds(t *testing.T) { } // Make sure the security configuration is not acted upon. 
- sCtx, sCancel := context.WithTimeout(context.Background(), defaultTestShortTimeout) - defer sCancel() - if _, err := providerCh.Receive(sCtx); err != context.DeadlineExceeded { - t.Fatalf("certificate provider created when no xDS creds were specified") + if err := verifyCertProviderNotCreated(); err != nil { + t.Fatal(err) } } @@ -588,7 +807,7 @@ func (s) TestHandleListenerUpdate_NoXDSCreds(t *testing.T) { // server is configured with xDS credentials, but receives a Listener update // with an error. Verifies that no certificate providers are created. func (s) TestHandleListenerUpdate_ErrorUpdate(t *testing.T) { - clientCh, providerCh, cleanup := setupOverridesForXDSCreds(true) + clientCh, cleanup := setupOverridesForXDSCreds(true) defer cleanup() xdsCreds, err := xds.NewServerCredentials(xds.ServerOptions{FallbackCreds: insecure.NewCredentials()}) @@ -627,21 +846,14 @@ func (s) TestHandleListenerUpdate_ErrorUpdate(t *testing.T) { if err != nil { t.Fatalf("error when waiting for a ListenerWatch: %v", err) } - wantName := fmt.Sprintf("%s?udpa.resource.listening_address=%s", client.BootstrapConfig().ServerResourceNameID, lis.Addr().String()) + wantName := strings.Replace(testServerListenerResourceNameTemplate, "%s", lis.Addr().String(), -1) if name != wantName { t.Fatalf("LDS watch registered for name %q, want %q", name, wantName) } // Push an error to the registered listener watch callback and make sure // that Serve does not return. 
- client.InvokeWatchListenerCallback(xdsclient.ListenerUpdate{ - RouteConfigName: "routeconfig", - SecurityCfg: &xdsclient.SecurityConfig{ - RootInstanceName: "default1", - IdentityInstanceName: "default2", - RequireClientCert: true, - }, - }, errors.New("LDS error")) + client.InvokeWatchListenerCallback(xdsclient.ListenerUpdate{}, errors.New("LDS error")) sCtx, sCancel := context.WithTimeout(context.Background(), defaultTestShortTimeout) defer sCancel() if _, err := serveDone.Receive(sCtx); err != context.DeadlineExceeded { @@ -649,84 +861,21 @@ func (s) TestHandleListenerUpdate_ErrorUpdate(t *testing.T) { } // Also make sure that no certificate providers are created. - sCtx, sCancel = context.WithTimeout(context.Background(), defaultTestShortTimeout) - defer sCancel() - if _, err := providerCh.Receive(sCtx); err != context.DeadlineExceeded { - t.Fatalf("certificate provider created when no xDS creds were specified") + if err := verifyCertProviderNotCreated(); err != nil { + t.Fatal(err) } } -func (s) TestHandleListenerUpdate_ClosedListener(t *testing.T) { - clientCh, providerCh, cleanup := setupOverridesForXDSCreds(true) - defer cleanup() - - xdsCreds, err := xds.NewServerCredentials(xds.ServerOptions{FallbackCreds: insecure.NewCredentials()}) - if err != nil { - t.Fatalf("failed to create xds server credentials: %v", err) - } - - server := NewGRPCServer(grpc.Creds(xdsCreds)) - defer server.Stop() - - lis, err := xdstestutils.LocalTCPListener() - if err != nil { - t.Fatalf("xdstestutils.LocalTCPListener() failed: %v", err) - } - - // Call Serve() in a goroutine, and push on a channel when Serve returns. - serveDone := testutils.NewChannel() - go func() { serveDone.Send(server.Serve(lis)) }() - - // Wait for an xdsClient to be created. 
- ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() - c, err := clientCh.Receive(ctx) - if err != nil { - t.Fatalf("error when waiting for new xdsClient to be created: %v", err) - } - client := c.(*fakeclient.Client) - - // Wait for a listener watch to be registered on the xdsClient. - name, err := client.WaitForWatchListener(ctx) - if err != nil { - t.Fatalf("error when waiting for a ListenerWatch: %v", err) - } - wantName := fmt.Sprintf("%s?udpa.resource.listening_address=%s", client.BootstrapConfig().ServerResourceNameID, lis.Addr().String()) - if name != wantName { - t.Fatalf("LDS watch registered for name %q, want %q", name, wantName) - } - - // Push a good update to the registered listener watch callback. This will - // unblock the xds-enabled server which is waiting for a good listener - // update before calling Serve() on the underlying grpc.Server. - client.InvokeWatchListenerCallback(xdsclient.ListenerUpdate{ - RouteConfigName: "routeconfig", - SecurityCfg: &xdsclient.SecurityConfig{IdentityInstanceName: "default2"}, - }, nil) - if _, err := providerCh.Receive(ctx); err != nil { - t.Fatal("error when waiting for certificate provider to be created") - } - - // Close the listener passed to Serve(), and wait for the latter to return a - // non-nil error. - lis.Close() - v, err := serveDone.Receive(ctx) - if err != nil { - t.Fatalf("error when waiting for Serve() to exit: %v", err) - } - if err, ok := v.(error); !ok || err == nil { - t.Fatal("Serve() did not exit with error") - } - - // Push another listener update and make sure that no certificate providers - // are created. 
- client.InvokeWatchListenerCallback(xdsclient.ListenerUpdate{ - RouteConfigName: "routeconfig", - SecurityCfg: &xdsclient.SecurityConfig{IdentityInstanceName: "default1"}, - }, nil) +func verifyCertProviderNotCreated() error { sCtx, sCancel := context.WithTimeout(context.Background(), defaultTestShortTimeout) defer sCancel() - if _, err := providerCh.Receive(sCtx); err != context.DeadlineExceeded { - t.Fatalf("certificate provider created when no xDS creds were specified") + if _, err := fpb1.buildCh.Receive(sCtx); err != context.DeadlineExceeded { + return errors.New("certificate provider created when no xDS creds were specified") } + sCtx, sCancel = context.WithTimeout(context.Background(), defaultTestShortTimeout) + defer sCancel() + if _, err := fpb2.buildCh.Receive(sCtx); err != context.DeadlineExceeded { + return errors.New("certificate provider created when no xDS creds were specified") + } + return nil } diff --git a/xds/xds.go b/xds/xds.go index 5deafd130a2..27547b56d22 100644 --- a/xds/xds.go +++ b/xds/xds.go @@ -28,9 +28,67 @@ package xds import ( + "fmt" + + v3statusgrpc "github.com/envoyproxy/go-control-plane/envoy/service/status/v3" + "google.golang.org/grpc" + internaladmin "google.golang.org/grpc/internal/admin" + "google.golang.org/grpc/resolver" + "google.golang.org/grpc/xds/csds" + _ "google.golang.org/grpc/credentials/tls/certprovider/pemfile" // Register the file watcher certificate provider plugin. _ "google.golang.org/grpc/xds/internal/balancer" // Register the balancers. - _ "google.golang.org/grpc/xds/internal/client/v2" // Register the v2 xDS API client. - _ "google.golang.org/grpc/xds/internal/client/v3" // Register the v3 xDS API client. - _ "google.golang.org/grpc/xds/internal/resolver" // Register the xds_resolver. + _ "google.golang.org/grpc/xds/internal/httpfilter/fault" // Register the fault injection filter. + _ "google.golang.org/grpc/xds/internal/httpfilter/rbac" // Register the RBAC filter. 
+ _ "google.golang.org/grpc/xds/internal/httpfilter/router" // Register the router filter. + xdsresolver "google.golang.org/grpc/xds/internal/resolver" // Register the xds_resolver. + _ "google.golang.org/grpc/xds/internal/xdsclient/v2" // Register the v2 xDS API client. + _ "google.golang.org/grpc/xds/internal/xdsclient/v3" // Register the v3 xDS API client. ) + +func init() { + internaladmin.AddService(func(registrar grpc.ServiceRegistrar) (func(), error) { + var grpcServer *grpc.Server + switch ss := registrar.(type) { + case *grpc.Server: + grpcServer = ss + case *GRPCServer: + sss, ok := ss.gs.(*grpc.Server) + if !ok { + logger.Warningf("grpc server within xds.GRPCServer is not *grpc.Server, CSDS will not be registered") + return nil, nil + } + grpcServer = sss + default: + // Returning an error would cause the top level admin.Register() to + // fail. Log a warning instead. + logger.Warningf("server to register service on is neither a *grpc.Server or a *xds.GRPCServer, CSDS will not be registered") + return nil, nil + } + + csdss, err := csds.NewClientStatusDiscoveryServer() + if err != nil { + return nil, fmt.Errorf("failed to create csds server: %v", err) + } + v3statusgrpc.RegisterClientStatusDiscoveryServiceServer(grpcServer, csdss) + return csdss.Close, nil + }) +} + +// NewXDSResolverWithConfigForTesting creates a new xds resolver builder using +// the provided xds bootstrap config instead of the global configuration from +// the supported environment variables. The resolver.Builder is meant to be +// used in conjunction with the grpc.WithResolvers DialOption. +// +// Testing Only +// +// This function should ONLY be used for testing and may not work with some +// other features, including the CSDS service. +// +// Experimental +// +// Notice: This API is EXPERIMENTAL and may be changed or removed in a +// later release. 
+func NewXDSResolverWithConfigForTesting(bootstrapConfig []byte) (resolver.Builder, error) { + return xdsresolver.NewBuilder(bootstrapConfig) +}