From ab2a1bb475743e5149341b00d8256e8de9e7a220 Mon Sep 17 00:00:00 2001 From: tspiridonova Date: Wed, 13 Oct 2021 20:05:50 +0300 Subject: [PATCH] Squashed commit of the following: MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit commit 2fe71180762478c66b0027f780b95c40fc563a55 Author: Menghan Li Date: Mon Oct 11 15:42:10 2021 -0700 xds/e2e: move flag check to each test, and call t.Skip() (#4861) commit ea41fbfa10817592c85b4ada15d3d1ba3d6fdae7 Author: Easwar Swaminathan Date: Mon Oct 11 14:55:45 2021 -0700 examples: unix abstract socket (#4848) commit 6c56e211a0691f83bb59c5100e79f25d04cd4bb0 Author: Easwar Swaminathan Date: Mon Oct 11 14:55:12 2021 -0700 grpclb: add `target_field` to service config (#4847) commit 49f638878973e2ccc20859209a73e1c3a02de015 Author: Menghan Li Date: Mon Oct 11 11:06:15 2021 -0700 grpclog: support formatting output as JSON (#4854) commit b99d1040b71caf9b22be570edb68d85dcb6c515c Author: Ashitha Santhosh <55257063+ashithasantosh@users.noreply.github.com> Date: Fri Oct 8 17:09:55 2021 -0700 authz: create file watcher interceptor for gRPC SDK API (#4760) * authz: create file watcher interceptor for gRPC SDK API commit 03ca7b7d00cada2ff8a3ea7348fe6c3a2b2ee4fb Author: Zach Reyes <39203661+zasweq@users.noreply.github.com> Date: Thu Oct 7 22:46:49 2021 -0400 Added logs to rbac (#4853) Added logs to rbac commit 524d10cbce3e1c597c48589341c01332f71c3d93 Author: Terry Wilson Date: Thu Oct 7 11:58:49 2021 -0700 kokoro: source test driver install script from core repo (#4825) commit b9d7c74e01f89a52880332f584776ccfe0c27756 Author: Menghan Li Date: Thu Oct 7 11:47:53 2021 -0700 xds: local interop tests (#4823) commit 404d8fd5139bfb7be9b4bf675e76eed53207dd8c Author: Zach Reyes <39203661+zasweq@users.noreply.github.com> Date: Wed Oct 6 19:26:43 2021 -0400 Added imports for HTTP Filters (#4850) Added imports for HTTP Filters commit d16cfedb5f31caad933f6bb4f3aa3a85177fb989 Author: Zach Reyes <39203661+zasweq@users.noreply.github.com> Date: Wed Oct 6 19:26:22 2021 -0400 Rename env var (#4849) Rename env var commit 4bd99953513f3d9de6a75075cfb51bc2224429e0 Author: Easwar Swaminathan Date: Tue Oct 5 16:55:25 2021 -0700 xds: suppress redundant resource updates using proto.Equal (#4831) commit ee479e630f859849f23d70ebf2fa3021f5ad2658 Author: Menghan Li Date: Tue Oct 5 14:49:15 2021 -0700 creds/google: replace NewComputeEngineCredsWithOptions with NewDefaultCredentialsWithOptions (#4830) commit 02da625150e8ee126d4b84dfed27d2453f2617f4 Author: Doug Fawley Date: Mon Oct 4 14:01:09 2021 -0700 github: increase timeout for codeql and disable for PRs (#4841) commit f2974e7778b189c094a2d5fe087d680f7a72050e Author: Menghan Li Date: Mon Oct 4 11:54:27 2021 -0700 kokoro: remove expired letsencrypt.org cert and update (#4840) commit f068a13ef05d63510828ed59e7cf3651a02c7118 Author: Doug Fawley Date: Mon Oct 4 11:22:00 2021 -0700 server: add missing conn.Close if the connection dies before reading the HTTP/2 preface (#4837) commit 09970207abb5f88aeb1b31fd1190a0934e302e0f Author: Easwar Swaminathan Date: Fri Oct 1 15:27:28 2021 -0700 xds: remove race in TestUnmarshalCluster_WithUpdateValidatorFunc (#4836) commit b9f62538f003011893c70991c36bdb5b680eed8e Author: Easwar Swaminathan Date: Fri Oct 1 11:09:26 2021 -0700 rls: pull proto changes made in grpc-proto/pull/98 (#4832) commit 69e1b54deb77c76b6832f36e4b31d4316a34f17d Author: Doug Fawley Date: Fri Oct 1 11:09:12 2021 -0700 test: fix stayConnected to call Connect after state reports IDLE (#4821) commit 127c052c701b81daa5970c695438f6ef08c76040 Author: Mohan Li <67390330+mohanli-ml@users.noreply.github.com> Date: Thu Sep 30 13:06:50 2021 -0700 credentials/google: introduce a new API `NewComputeEngineCredsWithOptions` (#4767) commit 2ae5ac1637d68d584d20815d965538481a3c11c7 Author: Easwar Swaminathan Date: Thu Sep 30 10:04:19 2021 -0700 xds: nack if certprovider instance name is missing in bootstrap config (#4799) commit adb21c46100568b95ab41e97e3b267923a0a92a0 Author: Easwar Swaminathan Date: Wed Sep 29 16:58:46 2021 -0700 rls: improve config parsing (#4819) commit e6d0d2818a7380920b806ae629320500e739bd5c Author: Menghan Li Date: Tue Sep 28 13:55:29 2021 -0700 internal: log SubConn type if it's not the expected type (#4813) commit 34df1b42aecf459d913a1b6aaf835e1d4eea22d3 Author: Zach Reyes <39203661+zasweq@users.noreply.github.com> Date: Tue Sep 28 15:27:00 2021 -0400 xds: Small RBAC Changes defined in A41 (#4818) * xds: Small RBAC Changes defined in A41 commit 75f1d4b986342c24fab707fc6be37c51f9f8ee50 Author: Doug Fawley Date: Tue Sep 28 12:20:57 2021 -0700 transport: call stats handler for trailers before closeStream (#4816) commit 08927214a41e3a2d937658689167363942c06426 Author: Menghan Li Date: Tue Sep 28 10:11:52 2021 -0700 xds/rds: NACK unknown route action cluster specifier (#4788) commit 710419d32bfd469509bae5b73274f5825ad13554 Author: ZhenLian Date: Mon Sep 27 16:42:32 2021 -0700 advancedtls: add revocation support to client/server options (#4781) commit 4555155af248cab3368e5c5e650bd216366c8bb5 Author: Zach Reyes <39203661+zasweq@users.noreply.github.com> Date: Mon Sep 27 17:36:16 2021 -0400 xds: Small changes at xDS RBAC Layer (#4815) * xds: Small changes at xDS RBAC Layer commit 689f7b154ee8a3f3ab6a6107ff7ad78189baae06 Author: Zach Reyes <39203661+zasweq@users.noreply.github.com> Date: Mon Sep 27 16:55:46 2021 -0400 transport: logic specified in A41 to support RBAC xDS HTTP Filter (#4803) * transport: logic specified in A41 to support RBAC xDS HTTP Filter commit 11437f66f20f3473e09fcf3fb5c23d4388af936f Author: Doug Fawley Date: Fri Sep 24 15:29:25 2021 -0700 test: add option to make httpServer wait for END_STREAM; fix RetryStats race (#4811) commit 6ff68b489ecba2884aff152835d745389598935a Author: Doug Fawley Date: Thu Sep 23 14:40:18 2021 -0700 channelz: recommend using admin.Register instead (#4797) commit 78d3aa8b3ed1b59bf84db4242ac7c316e8943797 Author: Easwar Swaminathan Date: Thu Sep 23 07:43:14 2021 -0700 grpc: cleanup parse target and authority tests (#4787) commit 83a3461520f69c1896990dfae724101c1ed6a1d2 Author: Easwar Swaminathan Date: Wed Sep 22 17:43:36 2021 -0700 xds: have separate tests for RBAC on and off (#4807) commit d7208f02ca7721bef504d100b61c1ef8cd569390 Author: Doug Fawley Date: Wed Sep 22 16:35:39 2021 -0700 github: set a shorter timeout on testing jobs (#4806) commit 32cd3d617642c49c435ab2a435e716efd4a5e949 Author: apolcyn Date: Wed Sep 22 16:08:17 2021 -0700 interop: don't use WithBlock dial option in the client (#4805) commit d623accd30f0f13047e6e2b7147aee41691054c3 Author: Menghan Li Date: Wed Sep 22 16:01:18 2021 -0700 xds: fix parent balancers to handle Idle children (#4801) commit e6246c22eb0440d525ce1c226b0c9f1ea9ea693a Author: Evan Jones Date: Wed Sep 22 16:30:27 2021 -0400 server: optimize chain interceptors (-1 allocation, -10% time/call) (#4746) commit 458ea7640a92039aad37edc67b63e6d040a93320 Author: Zach Reyes <39203661+zasweq@users.noreply.github.com> Date: Wed Sep 22 15:08:44 2021 -0400 xds: Added validations for HCM to support xDS RBAC Filter (#4786) * xds: Added validations for HCM to support xDS RBAC Filter commit 1f12bf44284e6ba4be72cd028a2a1eb01c2d18bb Author: Yury Frolov <57130330+EinKrebs@users.noreply.github.com> Date: Wed Sep 22 23:04:45 2021 +0500 transport: fix a typo in http2_server.go (#4745) commit 606403ded29c7b922a66b4c5a449a1643269bc96 Author: Zach Reyes <39203661+zasweq@users.noreply.github.com> Date: Tue Sep 21 19:33:18 2021 -0400 transport: fix log spam from Server Authentication Handshake errors (#4798) * transport: fix log spam from Server Authentication Handshake errors commit 616977cc7974d6cbec50399297db7026d791c9dd Author: Doug Fawley Date: Tue Sep 21 11:32:51 2021 -0700 Change version to 1.42.0-dev (#4793) commit 4ddf8ceaa7b5de2170b082bfc7162c4887ddaeb5 Author: Doug Fawley Date: Tue Sep 21 10:55:00 2021 -0700 Revert "transport/server: add :method POST to incoming metadata (#4770)" (#4790) This reverts commit c84a5de06496bf8416cebf9d0058f481e37c165e. commit d53469981f2356f7c270d4b3beaafc6d1a653817 Author: Doug Fawley Date: Tue Sep 21 10:39:59 2021 -0700 transport: fix transparent retries when per-RPC credentials are in use (#4785) commit 5417cf809116a5e3e8ca06b15cb48cbffb946204 Author: Menghan Li Date: Mon Sep 20 13:27:27 2021 -0700 xds/test: delete use of removed types (#4784) They were deprecated, and removed later. commit 1109452fd118ec20164e859f71c0bb59fd209d21 Author: Lidi Zheng Date: Fri Sep 17 15:19:26 2021 -0700 [Backport grpc#27373] add testing_version flag (#4783) commit e469f0d5f5bcc1324dc3940c584e0969e2ea1f90 Author: Zach Reyes <39203661+zasweq@users.noreply.github.com> Date: Fri Sep 17 01:01:07 2021 -0400 xds: Add env var protection for RBAC HTTP Filter (#4765) * xds: Add env var protection for RBAC HTTP Filter commit 567da6b86340a83d509467638c91e68168bc1921 Author: Menghan Li Date: Thu Sep 16 13:38:35 2021 -0700 tlogger: print log type (#4774) Error logs cause tests to fail. This makes it easier (possible) to find the error log commit 03b2ebe5080c2b521c742cf6e06bd0824b75fc52 Author: Menghan Li Date: Thu Sep 16 11:07:04 2021 -0700 xds: enable ringhash and retry by default (#4776) commit b186ee8975f3c69bc36333a99fc82d1388977012 Author: Ed Warnicke Date: Thu Sep 16 09:59:36 2021 -0500 test/bufconn: add Listener.DialContext(context.Context) (#4763) commit 7cf9689be2d2b1e7f00dfc15d2516b7635c65c45 Author: Easwar Swaminathan Date: Wed Sep 15 15:38:01 2021 -0700 xds: validations for security config, as specified in A29 (#4762) * xds: validations for security config, as specified in A29 * make vet happy * fix error log * fix error msg in test commit 4f093b9a5afa5f3c8f29774dbdce8c02ce516d70 Author: Menghan Li Date: Wed Sep 15 14:47:18 2021 -0700 ringhash: the balancer (#4741) commit 4c5f7fb0eecd984708e0c1eeea7d426f275b22d3 Author: Easwar Swaminathan Date: Wed Sep 15 14:05:59 2021 -0700 xds: de-experimentalize xDS apis required for psm security (#4753) commit c84a5de06496bf8416cebf9d0058f481e37c165e Author: Zach Reyes <39203661+zasweq@users.noreply.github.com> Date: Wed Sep 15 17:02:08 2021 -0400 transport/server: add :method POST to incoming metadata (#4770) * transport/server: add :method POST to incoming metadata commit 98ccf472da9a7e01d53bd27e5ad537d46c1b5ca9 Author: Menghan Li Date: Wed Sep 15 13:35:51 2021 -0700 priority: handle Idle children the same way as Ready (#4769) commit 2d4e44a0cd75808908c9fb98aac764af6558ff6e Author: Menghan Li Date: Tue Sep 14 16:11:03 2021 -0700 xds/affinity: fix bugs in clusterresolver and xds-resolver (#4744) commit d41f21ca050b1721093702ede81c21b7e3bdaa63 Author: Doug Fawley Date: Tue Sep 14 15:11:42 2021 -0700 stats: support stats for all retry attempts; support transparent retry (#4749) commit 5d8e5aad40bedb696205b96b786f1d0e1326b3f8 Author: Kobi Date: Tue Sep 14 17:15:02 2021 +0300 Create NOTICE.txt (#4739) commit 5bfc05fb0cf08fd2a8257d2bca8dba552263ba7e Author: Easwar Swaminathan Date: Mon Sep 13 11:50:52 2021 -0700 grpc: clarify the use of transport.ErrConnClosing from createTransport() (#4757) commit 77ffb2ef318a2b8442b9fb10f80724013b2e65eb Author: Zach Reyes <39203661+zasweq@users.noreply.github.com> Date: Mon Sep 13 14:09:57 2021 -0400 xds: RBAC HTTP Filter (#4748) * xds: RBAC HTTP Filter commit 03268c8ed29e801944a2265a82f240f7c0e1b1c3 Author: Doug Fawley Date: Fri Sep 10 16:25:09 2021 -0700 balancer: fix aggregated state to not report idle with zero subconns (#4756) commit d25e31e741ddfb45f4126cd20e357185751e42c2 Author: Doug Fawley Date: Fri Sep 10 14:12:13 2021 -0700 client: fix case where GOAWAY would leak connections and memory (#4755) commit 7f560ef4c5224efb8a86f2877315c381c30fa126 Author: Easwar Swaminathan Date: Fri Sep 10 14:08:26 2021 -0700 grpc: close underlying transport when subConn is closed when in connecting state (#4751) commit 4e07a14b4e66e90ebf54ccc361012cb2b10724fd Author: Cesar Ghali Date: Fri Sep 10 13:58:12 2021 -0700 credentials/ALTS: Ensure ALTS record protocol names are consistent (#4754) commit 16cf65612e633d1cc0be8c65ee7a49fbe2b27825 Author: Menghan Li Date: Fri Sep 10 11:24:25 2021 -0700 xds: update xdsclient NACK to keep valid resources (#4743) commit 43e8fd4f69b65fd51d72578df4afa5c0519ca2b5 Author: Easwar Swaminathan Date: Fri Sep 10 10:59:25 2021 -0700 xds: don't remove env var protection for security on the client yet (#4752) Set the value to true by default, and remove it one release later. commit 0a99ae2d035feeb87506e767bd88d3b7364d1059 Author: Easwar Swaminathan Date: Fri Sep 10 09:04:59 2021 -0700 xds: support new fields to fetch security configuration (#4747) commit 2608e38e6386be7400720fecf2ece176c4cbc1b2 Author: Zach Reyes <39203661+zasweq@users.noreply.github.com> Date: Thu Sep 9 13:35:41 2021 -0400 xds: Added server side routing (#4726) * Added server side routing commit 1fe5adbbf82f15781a0ce3f704012dc44e6b8e63 Author: apolcyn Date: Wed Sep 8 17:31:51 2021 -0700 interop-testing: add soak test cases to interop client (#4677) commit a6a63177ae6094f9baa83b046bb4f20426ba5b82 Author: Doug Fawley Date: Wed Sep 8 10:00:44 2021 -0700 xds: add retry support (#4738) commit 2f3355d2244eb436564a93dfbe2b0ba907adeb98 Author: Easwar Swaminathan Date: Tue Sep 7 11:11:16 2021 -0700 xds: update go-control-plane to latest (#4737) commit 973e7cb9a17d398b9ddff102e19701f9e7a7a096 Author: Menghan Li Date: Tue Sep 7 10:41:26 2021 -0700 ringhash: the picker (#4730) commit 00a7dc8901e6f74713b131601d76cfc8fb62f8b0 Author: Easwar Swaminathan Date: Tue Sep 7 10:28:56 2021 -0700 xds: remove env var protection for security on client (#4735) commit c99a9c19b08500bd4259e95e3529ff483a0ae405 Author: Menghan Li Date: Tue Sep 7 10:10:36 2021 -0700 priority: forward the first IDLE state and picker (#4731) commit 0ca7dca97726252050774a4bff20d92ca5772331 Author: yihuaz Date: Tue Sep 7 09:12:01 2021 -0700 oauth: Allow access to Google API regional endpoints via Google Default Credentials (#4713) commit b2ba77a36ff809ab344b98368d9ecc3e12f943d6 Author: Easwar Swaminathan Date: Fri Sep 3 10:59:33 2021 -0700 xds: use separate update channels for listeners in test (#4712) commit c93e472777b9d2963eff865ff4ee9f0895876b43 Author: Zach Reyes <39203661+zasweq@users.noreply.github.com> Date: Thu Sep 2 14:43:26 2021 -0400 Fixed race in Filter Chain (#4728) commit b189f5e1bc9a495447332355df8a9648e65a2e44 Author: Ashitha Santhosh <55257063+ashithasantosh@users.noreply.github.com> Date: Thu Sep 2 11:22:07 2021 -0700 authz: create interceptors for gRPC security policy API (#4664) * Static Authorization Interceptor commit d6a5f5f4f3621542ec98cfed52c0620beab9fbd5 Author: Menghan Li Date: Thu Sep 2 10:49:35 2021 -0700 ringhash: the ring (#4701) commit 51003aa81e09b20c1a74ec88c961a68902349143 Author: Easwar Swaminathan Date: Wed Sep 1 13:49:44 2021 -0700 xds: start a management server per test (#4720) commit ed501aa1fd1d368d77e17de619046e2e1ebb82a9 Author: Tobias Klauser Date: Wed Sep 1 20:08:00 2021 +0200 xds/internal/resolver: update github.com/cespare/xxhash to v2 (#4671) github.com/cespare/xxhash/v2 supports Go ≥ 1.11 and this package states 1.11 in its go.mod file. The only symbol used from the xxhash package is the Sum64String func which still exists and works the same in v2. This gets rid of two indirect dependencies. commit f7d66b5846f00b6ab0b41a675aef9764176830fa Author: Lidi Zheng Date: Tue Aug 31 13:42:43 2021 -0700 Change to a non-workload-identity GKE cluster (#4723) commit 198d951db5082bddddd36e53efa8e9cbc924a228 Author: Zach Reyes <39203661+zasweq@users.noreply.github.com> Date: Tue Aug 31 09:27:06 2021 -0400 xds: Instantiated HTTP Filters on Server Side (#4669) * Instantiated HTTP Filters on Server Side commit ef66d13abb84ad6c6d99c8cbf3697607b7891f32 Author: Zach Reyes <39203661+zasweq@users.noreply.github.com> Date: Mon Aug 30 16:49:46 2021 -0400 xds: Required Router Filter for both Client and Server side (#4676) * Added isTerminal() to FilterAPI and required router filter on Client and Server side commit 85b9a1a0fa3fc7ce6677ac19267b380ef0cf59a7 Author: Easwar Swaminathan Date: Fri Aug 27 08:18:29 2021 -0700 xds: pass empty balancer.BuildOptions in clusterresolver_test (#4711) commit 43b19ef0e473c675b0ec7666a9856bf5edd7439e Author: Doug Fawley Date: Thu Aug 26 13:29:59 2021 -0700 grpctest: extend use of mutex to guard more things (#4710) commit d074cae66bc68d4ec5ccf427de2fce700223f4c7 Author: Doug Fawley Date: Thu Aug 26 11:21:36 2021 -0700 github: fold security tests into 'tests'; update testing to 1.17-1.15 (#4708) commit 0b372df5f45ee5e81aaae18ae9e5ad60eab60586 Author: Menghan Li Date: Thu Aug 26 10:21:09 2021 -0700 xds/client: NACK ringhash lb policy if env var is not set (#4707) commit 712e8d4f57fd4a4fbb83406148f9c71eb3e7714e Author: Easwar Swaminathan Date: Wed Aug 25 14:51:41 2021 -0700 Remove support for Go 1.13 and older (cont) (#4706) commit 498743c19e864d45b6761fd0b8c6cf7ad72eb271 Author: apolcyn Date: Wed Aug 25 14:03:53 2021 -0700 xds/c2p: update default XDS server name in C2P resolver (#4705) commit 6bd8e8cf30e25b6cde3ec16389ff470680c107b1 Author: Easwar Swaminathan Date: Tue Aug 24 14:24:34 2021 -0700 multiple: remove support for Go 1.11 (#4700) commit 5f4bc66745e1af8406741bb329a7bb7119631e02 Author: Doug Fawley Date: Tue Aug 24 13:52:18 2021 -0700 grpc: fix stayConnected function to connect upon entry (#4699) If stayConnected was called while the ClientConn was in IDLE already, it would never call Connect, and stay stuck in that state. This change ensures cc.Connect is always called at least once. commit 46ab723bb20867a29022047224194fefd311cb37 Author: Easwar Swaminathan Date: Tue Aug 24 12:30:13 2021 -0700 multiple: remove appengine specific build constraints and code (#4685) commit bfd964bba69658b989ff619c40383e59d13770f1 Author: Easwar Swaminathan Date: Tue Aug 24 11:19:04 2021 -0700 xds: use the defaultTestTimeout instead of the short one (#4684) commit dc3afb202f85e5540ece8743b114c7287a5f37a4 Author: Easwar Swaminathan Date: Tue Aug 24 11:04:25 2021 -0700 xds: deflake Test/ServerSideXDS_ServingModeChanges (#4689) commit 45a623cbefb83b4708e549616fde9c6d613710ad Author: Easwar Swaminathan Date: Tue Aug 24 10:02:55 2021 -0700 test: use non blocking dials in end2end_test (#4687) commit c361e9ea1646283baf7b23a5d060c45fce9a1dea Author: Zach Reyes <39203661+zasweq@users.noreply.github.com> Date: Mon Aug 23 19:39:14 2021 -0400 Move Server Credentials Handshake to transport (#4692) * Move Server Credentials Handshake to transport commit 8ab16ef276a33df4cdb106446eeff40ff56a6928 Author: Doug Fawley Date: Wed Aug 18 15:04:35 2021 -0700 balancer: add ExitIdle optional interface (#4673) commit 52cea2453436fbb4b962d3cb2da34da7ef6f10c7 Author: 吴亲库里 <36129334+wuqinqiang@users.noreply.github.com> Date: Thu Aug 19 04:31:22 2021 +0800 server: fix net.conn closed twice (#4663) commit a42567fe92f005c47e60146bdbb0d5f7fc232219 Author: Menghan Li Date: Thu Aug 12 11:12:02 2021 -0700 xds: support picking ringhash in xds client and cds policy (#4657) commit ad87ad009856d3423e067fc49b990d05e16d706c Author: Zach Reyes <39203661+zasweq@users.noreply.github.com> Date: Wed Aug 11 18:48:24 2021 -0400 xds: Add support for Dynamic RDS in listener wrapper (#4655) * Add support for Dynamic RDS in listener wrapper commit 88dc96b463fb9a695e6181750e78524df1903601 Author: Lidi Zheng Date: Wed Aug 11 14:33:44 2021 -0700 Copy the tag_and_push_docker_image method to grpc-go (#4667) commit 9c668aeab86903a70e291eb47a04f48d84e67006 Author: Aliaksandr Mianzhynski Date: Wed Aug 11 19:17:59 2021 +0300 all: preallocate slices where possible (#4609) commit c7c1e9e0ec7aed0a530cde1e7d2fc7382a6816a2 Author: Lidi Zheng Date: Tue Aug 10 20:31:26 2021 -0700 Update xDS client/server image per-branch tag after build (#4661) commit 997ce619eb555b6a481e741afa6390ad3cd80d5c Author: Doug Fawley Date: Tue Aug 10 13:22:34 2021 -0700 clientconn: do not automatically reconnect addrConns; go idle instead (#4613) commit 01bababd83492b6eb1c7046ab4c3a4b1bcc5e9d6 Author: Zach Reyes <39203661+zasweq@users.noreply.github.com> Date: Mon Aug 9 23:15:57 2021 -0400 Added connection to transport context (#4649) * Added connection to transport context commit 574137db7de3c10e010d5023626169f13540cef1 Author: Easwar Swaminathan Date: Fri Aug 6 10:56:44 2021 -0700 xds: fix flaky test (TestPickerUpdateAfterClose) (#4658) commit fc30d5b571f5981b71e8391a04e23c5f98eab4c3 Author: Menghan Li Date: Thu Aug 5 14:30:04 2021 -0700 xds/cluster_resolver: support RING_HASH as a child of cluster_resolver balancer (#4621) 1. merge endpoint picking and localility picking policy to one field in cluster_resolver's balancer config - This field only supports ROUND_ROBIN or RING_HASH. - This is to support RING_HASH policy, which is responsible both endpoint picking and locality picking. - If policy is RING_HASH, endpoints in localities will be flattened to a list of endpoints, and passed to the policy. 1. support building policy config with RING_HASH as a child - The config tree has one less layer comparing with ROUND_ROBIN - This also need to define RING_HASH's balancer config config 1. Deleted test `TestEDS_UpdateSubBalancerName` because now the balancer doesn't support updating child to a custom policy. commit 74370577fa163f6022fb88a5926192a7c26a3933 Author: Zach Reyes <39203661+zasweq@users.noreply.github.com> Date: Thu Aug 5 17:28:06 2021 -0400 xds: Add route to filterchain (#4610) * Added RDS Information from LDS in filter chain commit 6ba56c814be74c95e35a000582e074a380e545b0 Author: Menghan Li Date: Tue Aug 3 15:12:56 2021 -0700 transport: fix race accessing s.recvCompress (#4645) This is a backport of #4641 commit 0d6854ab5ecc205b0f7437919b7988f67144eba9 Author: Menghan Li Date: Tue Aug 3 14:17:02 2021 -0700 transport: fix race accessing s.recvCompress (#4641) commit edb9b3bc226676eba6fe1cddec44d082b5a30e4f Author: Doug Fawley Date: Mon Aug 2 15:56:58 2021 -0700 github: update stale bot to v4 (#4636) commit c052940bcd91bba85050ac193aeeca6e1c588e8a Author: Menghan Li Date: Mon Aug 2 13:05:02 2021 -0700 server: fix leaked net.Conn (#4633) This happens when NewServerTransport() returns nil, nil. The rawConn is closed when the transport is closed, which will never happen in this case (since the returned transport is nil). commit 8ed8dd26555f396d81f497415086ec73103e5825 Author: ZhenLian Date: Mon Aug 2 13:03:54 2021 -0700 advancedtls: fix a typo in crl.go (#4634) commit ea9b7a0a7651baaf43c5403cb83349fffb5162df Author: Easwar Swaminathan Date: Thu Jul 29 17:23:32 2021 -0700 xds: fix a typo (#4631) commit ad0a2a847cdfb3204c30d1423436fdeec8ff17bf Author: April Kyle Nassi Date: Wed Jul 28 14:46:46 2021 -0700 Update MAINTAINERS.md (#4628) moved 2 to emeritus list commit 61c704607b40236f021f3120e5a4b1c237ed8ade Author: raymonder jin Date: Thu Jul 29 02:02:38 2021 +0800 fix typo (#4616) commit 245ad25715e019716d10f5b24d761f85ff158c15 Author: Zach Reyes <39203661+zasweq@users.noreply.github.com> Date: Tue Jul 27 15:13:18 2021 -0400 Change version to 1.41.0-dev (#4625) commit 00edd8c13a7a27bc25c8de2a68cf6de35f88bd7e Author: Lidi Zheng Date: Mon Jul 26 13:02:56 2021 -0700 Add xDS k8s url-map test Kokoro job (#4614) commit 1ddab338690a578975747239ad4ecd2ae63b1965 Author: Doug Fawley Date: Fri Jul 23 10:37:18 2021 -0700 client: fix detection of whether IO was performed in NewStream (#4611) For transparent retry. Also allow non-WFR RPCs to retry indefinitely on errors that resulted in no I/O; the spec used to forbid it at one point during development, but it no longer does. commit 582ef458c6d8174087877ee83bb514abc16650a5 Author: Menghan Li Date: Thu Jul 22 16:12:30 2021 -0700 cluster_resolver: move balancer config types into cluster_resolver package and unexport (#4607) commit c513103bee39e1ebc3793e7128941794667779de Author: Zach Reyes <39203661+zasweq@users.noreply.github.com> Date: Wed Jul 21 22:42:38 2021 -0400 Add extra layer on top of RBAC Engine (#4576) * Add extra layer in RBAC commit a0bed723f1c00c8b07c6ceaf1f6ac2cb42ec0b35 Author: Zach Reyes <39203661+zasweq@users.noreply.github.com> Date: Wed Jul 21 21:58:19 2021 -0400 xds: add http filters to FilterChain matching (#4595) * Add HTTP Filters to FilterChain commit 0a8c63739a87bee6ff6097d272b63727659f4503 Author: apolcyn Date: Wed Jul 21 10:50:37 2021 -0700 grpclb: propagate the most recent connection error when grpclb enters transient failure (#4605) commit 8332d5b997af9e1554418167860351696d35e628 Author: lzhfromustc <43191155+lzhfromustc@users.noreply.github.com> Date: Wed Jul 21 13:40:04 2021 -0400 test: fix possible goroutine leaks in unit tests (#4570) commit 0300770df1c0b742f4eef4cce47ca315379ad4d1 Author: Menghan Li Date: Wed Jul 21 10:22:02 2021 -0700 xds: support cluster fallback in cluster_resolver (#4594) commit 65cabd74d8e18d7347fecd414fa8d83a00035f5f Author: Jille Timmermans Date: Tue Jul 20 19:58:14 2021 +0200 internal/binarylog: Fix data race when calling Write() and Close() in parallel (#4604) They both touched bufferedSink.writeTicker commit ce7bdf50abb1f7c7a5ba1a54890e6dac46eb87f7 Author: Matt Jones Date: Thu Jul 15 09:53:31 2021 -0700 advancedtls: CRL checking for golang gRPC (#4489) * Code for CRL checking for golang gRPC. commit 0103ea2d6c98f59ddd6ff09aa93f963936157213 Author: John Howard Date: Wed Jul 14 13:59:50 2021 -0700 client: improve GOAWAY debug messages (#4587) commit b586e9215896c69206b29af00f30bc34d483b6fc Author: Menghan Li Date: Wed Jul 14 13:10:19 2021 -0700 xds/client: notify the resource watchers of xDS errors (#4564) commit bfe1d0dc23ac33e7c8ebf125753e5fb0698a4bde Author: Jille Timmermans Date: Wed Jul 14 20:34:40 2021 +0200 binarylog: Use a simple boolean rather than a sync.Once (#4581) commit ba41bbac225e6e1a9b822fe636c40c3b7d977894 Author: James Protzman Date: Wed Jul 14 13:54:58 2021 -0400 transport: validate http 200 status for responses (#4474) commit ebfe3be62a82434bc83fd7b36410141a603a96be Author: Menghan Li Date: Mon Jul 12 16:42:02 2021 -0700 cluster_resolver: implement resource resolver to resolve EDS and DNS (#4531) commit 30dfb4b933a50fd366d7ed36ed4f71dbba2d382e Author: Jille Timmermans Date: Thu Jul 8 19:06:55 2021 +0200 binarylog: Don't continue after failing to marshal the proto (#4582) commit 51e780ce00959f0a2ba16ca7c65f3b99a91c3c61 Author: Jille Timmermans Date: Thu Jul 8 19:06:11 2021 +0200 internal/binarylog: Use defer to unlock mutexes (#4590) commit afad37618961fd1123d6582661895c6c533852ea Author: Easwar Swaminathan Date: Thu Jul 8 09:20:15 2021 -0700 Fix bootstrap format in comment (#4586) commit 91e0aeb192456225adf27966d04ada4cf8599915 Author: Jille Timmermans Date: Thu Jul 8 01:37:57 2021 +0200 binarylog: Don't leak the flusher goroutine when closing a Sink (#4583) time.Ticker.Stop() doesn't close the ticker channel, so we need to signal the goroutine to die some other way commit dd589923e1a17f5cc7c667359ae12d56bc1d3113 Author: Doug Fawley Date: Fri Jul 2 16:21:46 2021 -0700 clientconn: stop automatically connecting to idle subchannels returned by picker (#4579) commit 52546c5d89b7e362064f2a21c9d10803b44af15f Author: Ashitha Santhosh <55257063+ashithasantosh@users.noreply.github.com> Date: Wed Jun 30 11:14:57 2021 -0700 authorization: translate SDK policy to Envoy RBAC proto (#4523) * Translates SDK authorization policy to Envoy RBAC proto. commit b3f274c2babaeab7802d98e21a66209846437ff5 Author: Menghan Li Date: Tue Jun 29 11:45:16 2021 -0700 xds/cluster_impl: fix cluster_impl not correctly starting LoadReport stream (#4566) commit 83f9def5feb388c4fd7e6586bd55cf6bf6d46a01 Author: Vicent Martí <42793+vmg@users.noreply.github.com> Date: Mon Jun 28 18:51:21 2021 +0200 internal/transport: do not mask ConnectionError (#4561) commit 9b2fa9f8d3caed4aae28242f6ac7cd27c790806c Author: Aliaksandr Mianzhynski Date: Fri Jun 25 08:11:47 2021 +0300 server: improve chained interceptors performance (#4524) commit e24ede593630782a7718aeb27f116446e0284f90 Author: Menghan Li Date: Thu Jun 24 16:20:11 2021 -0700 xds: delete LRS policy and move the functionality to xds_cluster_impl (#4528) - (cluster_resolver) attach locality ID to addresses - (cluster_impl) wrap SubConn - (lrs) delete commit d9eb12feed7a0f45d4acbf478e83171f4c00210a Author: Doug Fawley Date: Wed Jun 23 14:15:56 2021 -0700 xdsclient: move tests out of tests directory (#4535) commit b9270c3a7f163541823e37485aae70fcf043d406 Author: Zach Reyes <39203661+zasweq@users.noreply.github.com> Date: Wed Jun 23 16:36:24 2021 -0400 client: add deadline for TransportCredentials handshaker (#4559) * Add deadline on connection for TransportCredentials handshake commit 4440c3b8306d28f4af5833bdf12ac54866dc1423 Author: Menghan Li Date: Tue Jun 22 14:57:05 2021 -0700 cluster_resolver: fix DiscoveryMechanismType marshal JSON (#4532) commit 14c7ed60ad7655f522345032f0c0c7ae05303816 Author: Menghan Li Date: Tue Jun 22 11:03:12 2021 -0700 xds/circuit_breaking: counters should be keyed by {cluster, EDS service name} pair (#4560) commit 50328cf800a44d78199311c2d93f5856e4b699c1 Author: Sergii Tkachenko Date: Mon Jun 21 15:11:57 2021 -0400 buildscripts: add option to use xds-k8s test driver from a fork (#4548) commit 4faa31f0a5809a5064ee128c9d855c0bedc1c783 Author: Iskandarov Lev Date: Fri Jun 18 23:21:07 2021 +0300 stats: add stream info inside stats.Begin (#4533) commit 74fe073e9acce820ff3815b78e49aadd10439d59 Author: Doug Fawley Date: Thu Jun 17 16:53:52 2021 -0700 Revert "xds: require router filter when filters are empty" (#4556) This reverts commit 00ae0c57cc0a418f5208906d4f68c4b682dc662c. commit 1c1e3f88d343d53aa7be5712e21d42d46892bc32 Author: Menghan Li Date: Thu Jun 17 11:29:17 2021 -0700 xds: fix test race in cluster_resolver (#4555) There's a race between update sub-balancer and the first EDS resp. If sub-balancer is updated after the first EDS resp, the old balancers (round_robin) will create two lingering SubConns that are not handled, which will mess up the following SubConn state updates. commit 151c8b770a05e77528859076e2869405ac403d1a Author: Menghan Li Date: Thu Jun 17 11:14:00 2021 -0700 xds/clusterimpl: fix race between picker update and ClientConn state update (#4551) commit 00ae0c57cc0a418f5208906d4f68c4b682dc662c Author: Aliaksandr Mianzhynski Date: Thu Jun 17 20:23:18 2021 +0300 xds: require router filter when filters are empty (#4553) commit 633fbe4dfee2289937bafe9c08ccb46d045c0310 Author: Zach Reyes <39203661+zasweq@users.noreply.github.com> Date: Thu Jun 17 09:00:05 2021 -0400 xds: generate per-request hash config selector (#4525) * xds: generate per-request hash in config selector commit 7e3535650101d07525dbbfe398caf82f4ea1a6c8 Author: Konrad Reiche Date: Wed Jun 16 16:56:04 2021 -0700 metadata: add Delete method to MD (#4549) commit 4c651eda23d0bc60edc6c932ce60f1246a2a2034 Author: Menghan Li Date: Wed Jun 16 11:04:33 2021 -0700 xds: move eds package to cluster_resolver (#4545) commit 549c53a90c2a61a4bbe4e067b21f709ead03e2de Author: Menghan Li Date: Tue Jun 15 14:03:10 2021 -0700 xds/eds: rewrite EDS policy using child policies (#4457) commit cd9f53ac49fe8d2ae979dd94cb0eb2a5e5b9660c Author: Menghan Li Date: Tue Jun 15 11:09:10 2021 -0700 xds/cds: update CDS balancer to partially handle aggregated cluster (#4539) commit f06e0060c6567a63a687be461f905268b9cc193d Author: Doug Fawley Date: Tue Jun 15 10:49:54 2021 -0700 Change version to 1.40.0-dev (#4543) commit 22c535818725b54cc34ccbc4b953318f19bc13a6 Author: Zach Reyes <39203661+zasweq@users.noreply.github.com> Date: Mon Jun 14 15:02:50 2021 -0400 xds: add HashPolicy fields to RDS update (#4521) * Add HashPolicy fields to RDS update commit 45549242f79aacb850de77336a76777bef8bbe01 Author: Menghan Li Date: Fri Jun 11 13:14:09 2021 -0700 internal: fix deadlock during switch_balancer and NewSubConn() (#4536) commit 2d3b1f900edcb0f08915526e01adb17d1c829180 Author: Dustin Ward Date: Fri Jun 11 12:48:03 2021 -0400 grpc: prevent deadlock in Test/ClientUpdatesParamsAfterGoAway on failure (#4534) commit 6351a55c3895e5658b2c59769c81109d962d0e04 Author: Doug Fawley Date: Thu Jun 10 09:33:06 2021 -0700 xds: remove env var protetion of advanced routing features (#4529) commit 95e48a892d6c51e95d2aa77742da72c2df14dc28 Author: Aliaksandr Mianzhynski Date: Wed Jun 9 21:05:17 2021 +0300 Add GetServiceInfo to xds.GRPCServer (#4507) commit aa1169ab7c3b34a8ed665b16ce9cfc5343306807 Author: Doug Fawley Date: Wed Jun 9 10:01:40 2021 -0700 vet: remove support for non-module-aware Go versions (#4530) commit b1418a6e74bc6bed7dad82588b6d817b5417b20b Author: Menghan Li Date: Tue Jun 8 16:05:50 2021 -0700 xds: export XDSClient interface and use it in balancer tests (#4510) - xdsclient.New returns the interface now - xdsclient.SetClient and xdsclient.FromResolverState take and return the interface now - cleanup xds balancer tests to pass xds_client in resolver state commit 7301a311748ce82f30d8bd8076fad23ec4c7c1df Author: Menghan Li Date: Mon Jun 7 21:57:17 2021 -0700 c2p: add random number to xDS node ID in google-c2p resolver (#4519) commit d30e2c91a0545bd393774c3775cd9f9c5f5a5673 Author: Doug Fawley Date: Mon Jun 7 17:13:48 2021 -0700 xds/resolver: test xds client closed by resolver Close (#4509) commit 656cad9ae5cf6ac93dc06669f308d29be7118481 Author: Doug Fawley Date: Fri Jun 4 12:00:13 2021 -0700 xds: standardize xds client field name (xdsClient) (#4518) commit 7f9eeeae36417349a8d33f515a2cac04afceb30e Author: Doug Fawley Date: Fri Jun 4 11:40:23 2021 -0700 xds: standardize builder type names (bb) and balancer receiver names (b) (#4517) commit 7beddeea913bd74a9d3b4e7ec49f0265a0ac7b88 Author: Doug Fawley Date: Fri Jun 4 08:58:26 2021 -0700 cleanup: remove "Interface" as suffix of (almost all) interface names (#4512) commit 5c164e2b8f227a29f4aa6b2de3afb2afa38880ba Author: Doug Fawley Date: Thu Jun 3 16:10:21 2021 -0700 xds: rename xds/internal/client package to xdsclient (#4511) commit 32d5490aee8dd29a6fbfe75dc8caade5b6aa5d87 Author: Menghan Li Date: Thu Jun 3 15:23:46 2021 -0700 metadata: convert keys to lowercase in FromContext() (#4416) commit c67c056bee6a3a40a36a8d42f91fe997442a2d07 Author: Jerry Y. Chen Date: Fri Jun 4 05:28:32 2021 +0800 doc: fix typo in package networktype (#4508) commit a3715292f8de67482ffe707076b000a15747815e Author: Menghan Li Date: Thu Jun 3 13:59:37 2021 -0700 csds: return empty response if xds client is not set (#4505) commit 0956b12520b5d76fe9d43f7eda8ad51765c44ce1 Author: Menghan Li Date: Wed Jun 2 21:22:13 2021 -0700 client: handle RemoveSubConn in goroutine to avoid deadlock (#4504) commit 174b1c28afaa3c1ca3518c251deb53f014603bbd Author: Easwar Swaminathan Date: Wed Jun 2 16:47:35 2021 -0700 internal/transport: skip log on EOF when reading client preface (#4458) commit e7b12ef3b15f6c46da7c5c3c71f4ca06ba410c1c Author: Menghan Li Date: Wed Jun 2 15:58:39 2021 -0700 cluster_resolver: add functions to build child balancer config (#4429) commit 3508452162f48011bf36f303f901f4efc50087ec Author: Doug Fawley Date: Wed Jun 2 10:48:18 2021 -0700 xds: add test-only injection of xds config to client and server (#4476) commit e5cad3dcff812a49f39c8105ffb5cc4881230e60 Author: laststem Date: Wed Jun 2 08:50:35 2021 +0900 doc: fix broken benchmark dashboard link in README.md (#4503) commit 8bdcb4c9ab8de15f6a60ebce93b6f4c8d86622ef Author: Evan Jones Date: Tue Jun 1 11:54:43 2021 -0400 client: Clarify that WaitForReady will block for CONNECTING channels (#4477) commit 2de42fcbbce31dcdf14ee24836a713b65fc06dae Author: Easwar Swaminathan Date: Wed May 26 15:35:27 2021 -0700 kokoro: Specify the correct path to the build config (#4495) commit 34bd6fbb8e3b570fdbda35e5537e389f7942b406 Author: Zach Reyes <39203661+zasweq@users.noreply.github.com> Date: Wed May 26 14:20:25 2021 -0400 xds: add RBAC Engine (#4471) * Added RBAC Engine commit 194dcc921a94aa12fc04e2b3262ac3e4f69142b1 Author: dkkb <82504881+dkkb@users.noreply.github.com> Date: Thu May 27 02:17:27 2021 +0800 example: improve hello world server with starting msg (#4468) commit 4bae49e05b281411fd01180f7893894e39941337 Author: Doug Fawley Date: Tue May 25 16:06:58 2021 -0700 mergeable: update relnotes regex (#4488) commit bbb542c3d9c07f587e0025c9bdf0768e9624951b Author: Easwar Swaminathan Date: Tue May 25 15:46:02 2021 -0700 Kokoro build configs for PSM security interop tests (#4481) commit e26e756f13345dd19470073c5c2920b65a24ac3c Author: Easwar Swaminathan Date: Tue May 25 15:43:14 2021 -0700 Enable logging in xds interop docker containers (#4482) commit 598e3f6a9dafe9f4da7b874f9ed8c8b3c0ff65ae Author: Doug Fawley Date: Tue May 25 11:46:30 2021 -0700 github: update lock bot to github actions (#4484) commit 67b720630d6a61ae4fb38d190f16ca7685078a18 Author: Doug Fawley Date: Tue May 25 11:45:53 2021 -0700 github: increase stale bot ops per run to process everything (#4485) commit 4ecb61bedbdef3fb4c52e4f06247d504b54ace9b Author: Doug Fawley Date: Tue May 25 11:24:19 2021 -0700 github: limit repo access of testing workflows (#4483) commit 69da917ce95ec0c81e53647b43b6da5b184fdb88 Author: Doug Fawley Date: Tue May 25 10:25:54 2021 -0700 github: update stale bot to github actions (#4480) commit 280df42a316deb7962dd49d32dedbea720806473 Author: Doug Fawley Date: Tue May 25 09:16:23 2021 -0700 mergeable: require RELEASE NOTES in PR description, milestone, and Type label (#4475) commit 728364accfb93cd52003fb38a6412c8e4965116b Author: Easwar Swaminathan Date: Mon May 24 17:30:40 2021 -0700 server: return UNIMPLEMENTED on receipt of malformed method name (#4464) commit c4ed6360a98355b1ca6e772a73bd27ece15de3e9 Author: Easwar Swaminathan Date: Mon May 24 17:30:29 2021 -0700 transport: remove RequestURI field from requests in transport test (#4465) commit 359fdbb7b310c71882a354675949a4ca95957d75 Author: Doug Fawley Date: Fri May 21 15:54:45 2021 -0700 Delete .travis.yml file (#4462) commit a8e85e0d5704da1f5bd858a7b47621e77fe5035b Author: Ehsan Afzali Date: Sat May 22 01:54:24 2021 +0300 server: allow PreparedMsgs to work for server streams (#3480) commit b1f7648a9fc72ce76cbcd42d8e2c60d9d9bed9fc Author: Doug Fawley Date: Fri May 21 15:15:58 2021 -0700 client: ensure LB policy is closed before closing resolver (#4478) commit 3dd75a6888ce5d1b5195c5cf72241d9e36f68e42 Author: AlphaBaby Date: Thu May 20 02:18:52 2021 +0800 xds_client/rds: weighted_cluster totalWeight default to 100 (#4439) commit 84d0920b59e5f138ffd4da11f7b2ab51e862b581 Author: Doug Fawley Date: Wed May 19 11:05:26 2021 -0700 transport: unblock read throttling when controlbuf exits (#4447) commit 86ac0fbc4037c1e748a650002d34a8044fff59e6 Author: Aaron Jheng Date: Thu May 20 01:57:27 2021 +0800 Documentation: Fix typo (#4445) commit 23a83dd097ec07fc7ddfb4a30c675763e4972ba4 Author: Doug Fawley Date: Tue May 18 15:26:51 2021 -0700 transport: various simplifications noticed during #4447 (#4455) commit c9c9a7536f5756744347acaba907189e53c38468 Author: Menghan Li Date: Tue May 18 10:32:05 2021 -0700 internal: fix test unset env var AggregateAndDNSSupportEnv (#4454) commit 74c40c963fefb22798e08e7cf13ef616786b2402 Author: Menghan Li Date: Tue May 18 10:31:27 2021 -0700 xds/cds: fix LOGICAL_DNS cluster semantics (#4434) commit 584fa418225e60652638b79c38a189be1ff00036 Author: Menghan Li Date: Tue May 18 10:30:43 2021 -0700 xds/testing: export variables for testing (#4449) The exported variables will be used by tests (to be added in a future PR, in another package) that use these balancers as child balancer. commit 2713b77e85261254c628d9c61d00f582e6a20d08 Author: Easwar Swaminathan Date: Mon May 17 17:27:58 2021 -0700 use depth logging from the e2e package (#4448) commit 39015b9c5e190f8b687d8c79f1e6353568974104 Author: Easwar Swaminathan Date: Mon May 17 15:03:59 2021 -0700 interop/xds: support xds security on interop server (#4444) commit 9749a79336273a1957e338d519ac553f4885cee9 Author: James Protzman Date: Mon May 17 17:49:15 2021 -0400 transport: remove decodeState from server to reduce allocations (#4423) commit 78e8edf34d3649c7459e9cf88855f5bbf4f8e6f9 Author: Easwar Swaminathan Date: Mon May 17 14:13:32 2021 -0700 interop/xds: dockerfile for the xds interop client (#4443) commit a12250e98f973530f34191d39f840ae435f00a91 Author: Menghan Li Date: Fri May 14 15:20:45 2021 -0700 xds/cds: add env var for aggregated and DNS cluster (#4440) commit 50c071e9b5431dcb90be089c7159efc63edff4cb Author: Zeke Lu Date: Sat May 15 05:09:26 2021 +0800 example: correct the default value for server_host_override (#4407) commit b759b408e84fd5e990073fdaa28cd24d8eb2adad Author: Zach Reyes <39203661+zasweq@users.noreply.github.com> Date: Fri May 14 17:02:10 2021 -0400 xds: moved shared matchers to internal/xds (#4441) * Moved shared matchers to internal/xds commit 71a1ca6c7f859658e44f0073fb754c4698216202 Author: Easwar Swaminathan Date: Fri May 14 11:13:26 2021 -0700 interop/xds: support xds credentials in interop client (#4436) commit dc77d7ffe311f78f2e577572d984af3c0a8df82b Author: Easwar Swaminathan Date: Wed May 12 18:03:52 2021 -0700 xds: revert a workaround made in #4413 (#4428) commit a16b156e990b0fb4100a4694e1c6dda779b08f77 Author: Menghan Li Date: Wed May 12 17:43:29 2021 -0700 internal: fix flaky test KeepaliveClientStaysHealthyWithResponsiveServer (#4427) Server should allow `NoStream`, otherwise there's a small chance (5/1000) the connection will be closed due to `too many pings`. commit 6fea90d7a884ad070a4f04863521eaf43e6c5d11 Author: Mayank Singhal Date: Thu May 13 05:45:47 2021 +0530 benchmark: do not allow addition of values lower than the minimum allowed in histogram stats commit a712a738897ceebf3b6690d722006b61013572e0 Author: Menghan Li Date: Wed May 12 16:25:07 2021 -0700 xds/cds: add separate fields for cluster name and eds service name (#4414) commit 397adad6a0d1d12ddd9b7f0101e902da274c15c8 Author: Easwar Swaminathan Date: Wed May 12 15:52:15 2021 -0700 update go.mod and go.sum to point to latest go-control-plane (#4425) commit 9cb99a52111e9b67165d498ec2c322774b54a5f1 Author: Menghan Li Date: Wed May 12 15:48:16 2021 -0700 xds: pretty print xDS updates and service config (#4405) commit 45e60095da54baad1e7ae28391941b64a40477e5 Author: Zach Reyes <39203661+zasweq@users.noreply.github.com> Date: Wed May 12 17:28:49 2021 -0400 xds: add support for aggregate clusters (#4332) Add support for aggregate clusters in CDS Balancer commit 8bf65c69b99ed9e1106c07c1f5d2f42f312b7ec5 Author: Easwar Swaminathan Date: Wed May 12 10:18:50 2021 -0700 xds: use same format while registering and watching resources (#4422) commit aa59641d5da52eaa3728c4624e16a3ac76688c39 Author: Easwar Swaminathan Date: Wed May 12 10:17:13 2021 -0700 interop: use credentials.NewTLS() when possible (#4390) commit a95a5c3bacecea965def0addd986b3ef709f6e27 Author: James Protzman Date: Wed May 12 11:49:07 2021 -0400 transport: remove decodeState from client to reduce allocations (#3313) commit 62adda2ece5ec803c824c5009b83cea86de5030d Author: Doug Fawley Date: Tue May 11 17:05:16 2021 -0700 client: fix ForceCodec to set content-type header appropriately (#4401) commit 81b8cca6a9d92794be3e789b179e798aa1bc3209 Author: Menghan Li Date: Tue May 11 15:28:46 2021 -0700 Change version to 1.39.0-dev (#4420) commit 5f95ad62331add45bbf5ee167b67cadc72e1d322 Author: Easwar Swaminathan Date: Tue May 11 10:39:31 2021 -0700 xds: workaround to deflake xds e2e tests (#4413) commit b1940e15f6778067675e2192d8947608e8a20e32 Author: Easwar Swaminathan Date: Mon May 10 10:11:31 2021 -0700 xds: register resources at the mgmt server before requesting them (#4406) commit 98c895f7e06adc82ad030c4f90bcada672f523a2 Author: Doug Fawley Date: Mon May 10 09:35:55 2021 -0700 cleanup: use testutils.MarshalAny in more places (#4404) commit 12a377b1e4c9f1960bd25f47b9156d9dbd732ed0 Author: Easwar Swaminathan Date: Fri May 7 15:42:59 2021 -0700 xds: nack route configuration with regexes that don't compile (#4388) commit c15291b0f5929ab8cf659269a11e8aa79cb71788 Author: Doug Fawley Date: Fri May 7 15:24:10 2021 -0700 client: initialize safe config selector when creating ClientConn (#4398) commit 328b1d171a65d7e855bcd7bb5cb1f973c7e6f5d2 Author: Doug Fawley Date: Fri May 7 14:37:52 2021 -0700 transport: allow InTapHandle to return status errors (#4365) commit aff517ba8a8ded7306801c3b95f1f7f480c1268b Author: Easwar Swaminathan Date: Fri May 7 14:35:48 2021 -0700 xds: make e2e tests use a single management server instance (#4399) commit 0439465fe2b4020767d9aab1bc3055e492c14089 Author: Doug Fawley Date: Fri May 7 11:57:56 2021 -0700 xds_resolver: fix flaky Test/XDSResolverDelayedOnCommitted (#4393) Before this change, if two xds client updates came too close together, the second one could replace the first one. The fix is to wait for the effects of the first update before sending the second update. I injected a synthetic delay into handling the updates from the channel to reproduce this flake 100%, and confirmed this change fixes it. As part of this change I also noticed that we're actually calling the context cancellation function twice via defers, and never the cancel function from the test setup, so I fixed that, too. commit 0ab423af82154f9466b48cfece8043314e7114d4 Author: Menghan Li Date: Fri May 7 11:55:48 2021 -0700 test: fix flaky GoAwayThenClose (#4394) In this test, we 1. make a streaming RPC on a connection 1. graceful stop it to send a GOAWAY 1. hard stop it, so the client will create a connection to another server Before this fix, 2 and 3 can happen too soon, so the RPC in 1 would fail and then transparent retry (because the stream is unprocessed by the server in that case). This retry attempt could pick the new connection, and then the RPC would block until timeout. After this streaming RPC fails, we make unary RPCs with the same deadline (note: deadline not timeout) as the streaming RPC and expect them to succeed. But they will also fail due to timeout. The fix is to make a round-trip on the streaming RPC first, to make sure it actually goes on the first connection. commit b6f206b84f739768a1c75c1c83fe50ed75845245 Author: Doug Fawley Date: Fri May 7 11:17:26 2021 -0700 grpc: improve docs on StreamDesc (#4397) commit c7ea734087dbbcdb22137ab3b7d8b16842b080bf Author: Zach Reyes <39203661+zasweq@users.noreply.github.com> Date: Fri May 7 08:28:34 2021 -0400 dns: fix flaky TestRateLimitedResolve (#4387) * Rewrote TestRateLimitedResolve in dns resolver test to get rid of flakiness. commit cb396472c2f78e923dc0b28565c9d704291196f8 Author: Menghan Li Date: Thu May 6 13:28:27 2021 -0700 Revert "grpc: call balancer.Close() before returning from ccBalancerWrapper.close()" (#4391) This reverts commit 28078834f35b944281662807d8ec071645c37307. commit d2d6bdae07e844b8a3502dcaf00dc7b1b5519a59 Author: Mikhail Mazurskiy <126021+ash2k@users.noreply.github.com> Date: Fri May 7 02:40:54 2021 +1000 server: add ForceServerCodec() to set a custom encoding.Codec on the server (#4205) commit d426aa5f2e5e809639b45d9619416ce22e56319a Author: Lidi Zheng Date: Wed May 5 13:37:13 2021 -0700 test: extend the xDS interop tests timeout to 360 mins (#4380) commit 40b25c5b2c2d1b06d5f5d750d759294c6037d995 Author: Easwar Swaminathan Date: Wed May 5 12:34:15 2021 -0700 xds: set correct order of certificate providers in handshake info (#4350) commit 0fc0397d779d96879d7b903c3fa1b9bd53e490e3 Author: Easwar Swaminathan Date: Tue May 4 16:54:57 2021 -0700 xds: actually close stuff in cds/eds `Close()` (#4381) commit 4f3aa7cfa157c38bd5c2da7f4568614f815ab4ad Author: Doug Fawley Date: Tue May 4 15:29:58 2021 -0700 xds: optimize fault injection filter with empty config (#4367) commit 79e55d64442716d4082d373540eac78b018e81c4 Author: Easwar Swaminathan Date: Tue May 4 15:06:43 2021 -0700 xds: use SendContext() to fail in time when the channel is full (#4386) commit 11bd77660dba95e270659c6a5077507ef37a8c41 Author: Doug Fawley Date: Tue May 4 14:51:32 2021 -0700 xds: work around xdsclient race in fault injection test (#4377) commit 75497df97f8bc9d5ec905d6e6b283a207eb3e9f0 Author: Easwar Swaminathan Date: Tue May 4 14:38:47 2021 -0700 meshca: remove meshca certificate provider implementation (#4385) commit ebd6aba6754d073a696e5727158cd0c917ce1019 Author: Menghan Li Date: Mon May 3 15:16:49 2021 -0700 Revert "xds/cds: add separate fields for cluster name and eds service name" (#4382) This reverts PRs #4352 (and two follow up fixes #4372 #4378). Because the xds interop tests were flaky. Revert before the branch cut. commit b418de839e738968aa8f845584efd0d34da4bae8 Author: Menghan Li Date: Fri Apr 30 11:53:31 2021 -0700 xds/eds: restart EDS watch after previous was canceled (#4378) commit 28078834f35b944281662807d8ec071645c37307 Author: Easwar Swaminathan Date: Thu Apr 29 21:44:26 2021 -0700 grpc: call balancer.Close() before returning from ccBalancerWrapper.close() (#4364) commit aa3ef8fb8ff6c92134743e780cf659eaa7eeccbc Author: Menghan Li Date: Thu Apr 29 12:17:56 2021 -0700 internal: regenerate proto (#4373) commit c3b66015bd51d33d3e0a75ea5086defcb9d05e64 Author: Menghan Li Date: Thu Apr 29 11:56:50 2021 -0700 xds/circuit_breaking: use cluster name as key, not EDS service name (#4372) commit 91d8f0c916d76f2a5aac9e846cd7ffcb838db769 Author: Menghan Li Date: Wed Apr 28 18:11:45 2021 -0700 serviceconfig: support marshalling BalancerConfig to JSON (#4368) commit b602d17e459c0e4d64e24b6d07875f58d5f40f0e Author: irfan sharif Date: Wed Apr 28 13:05:50 2021 -0400 metadata: reduce memory footprint in FromOutgoingContext (#4360) When Looking at memory profiles for cockroachdb/cockroach, we observed that the intermediate metadata.MD array constructed to iterate over appended metadata escaped to the heap. Fortunately, this is easily rectifiable. go build -gcflags '-m' google.golang.org/grpc/metadata ... google.golang.org/grpc/metadata/metadata.go:198:13: make([]MD, 0, len(raw.added) + 1) escapes to heap commit 24d03d9f769106b3c96b4145244ce682999d3d88 Author: Menghan Li Date: Tue Apr 27 15:22:25 2021 -0700 xds/priority: add ignore reresolution boolean to config (#4275) commit 7c5e73795d163c13e616aa53066f9e1d845275dd Author: Menghan Li Date: Tue Apr 27 13:37:48 2021 -0700 xds/cds: add separate fields for cluster name and eds service name (#4352) commit 145f12a95b19d2a2f926176cd63fe5645b376186 Author: Joshua Humphries Date: Tue Apr 27 16:15:08 2021 -0400 reflection: accept interface instead of grpc.Server struct in Register() (#4340) commit 52a707c0dafe4ac6c0443c3d83dfdeeb9b828684 Author: Easwar Swaminathan Date: Mon Apr 26 14:29:06 2021 -0700 xds: serving mode changes outlined in gRFC A36 (#4328) commit 9572fd6faeaee33fe295ce3a79eab729d05bb349 Author: apolcyn Date: Fri Apr 23 17:26:26 2021 -0700 client: include details about GOAWAYs in status messages (#4316) commit e158e3e82cbac01ba513de4b0982b35b1fcc6183 Author: Menghan Li Date: Fri Apr 23 13:15:21 2021 -0700 xds/lrs: server name is not required to be non-empty (#4356) commit 74fe6eaa41706a8451df3c03a0b131c70f71773d Author: Doug Fawley Date: Thu Apr 22 14:59:51 2021 -0700 github: testing action workflow improvements and update to test Go1.16 (#4358) commit f02863c306d287e05bcb796035b38fd956db1576 Author: Easwar Swaminathan Date: Thu Apr 22 14:58:58 2021 -0700 xds: specify "h2" as the alpn in xds creds (#4361) commit 7276af6dd73483d9edfedbef778c831f044736eb Author: Menghan Li Date: Thu Apr 22 10:45:24 2021 -0700 client: fix leaked addrConn struct when addresses are updated (#4347) commit f2783f271924fd379910c91fb62aae1dbfad83bd Author: Jan Tattermusch Date: Thu Apr 22 18:08:53 2021 +0200 Run emulated linux arm64 tests (#4344) commit 6f35bbbfb82de348a1537774af2ffd706cd3bb12 Author: Lidi Zheng Date: Wed Apr 21 17:27:51 2021 -0700 test: enable xDS CSDS test (#4354) commit 671707bdf3bfa85f176f07810de5100d0109776b Author: Menghan Li Date: Wed Apr 21 14:06:54 2021 -0700 internal: fix symbol undefined build failure (#4353) Caused by git merge commit 970aa0928304dec8dbf2bc11ee0dd49ad16c8f30 Author: Menghan Li Date: Wed Apr 21 10:11:28 2021 -0700 xds/balancers: export balancer names and config structs (#4334) commit 1c598a11a4e503e1cfd500999c040e72072dc16b Author: Zach Reyes <39203661+zasweq@users.noreply.github.com> Date: Tue Apr 20 13:20:09 2021 -0400 Move exponential backoff to DNS resolver from resolver.ClientConn (#4270) commit 41676e61b1d576484cf2c0315a25fe2c9438c769 Author: lzhfromustc <43191155+lzhfromustc@users.noreply.github.com> Date: Mon Apr 19 12:49:37 2021 -0400 Fix goroutine leaks (#4214) commit 1a870aec2ff99bb682d5e200763c9124185eafca Author: Menghan Li Date: Thu Apr 15 15:08:03 2021 -0700 xds/clusterimpl: trigger re-resolution on subconn transient_failure (#4314) commit 87eb5b7502493f758e76c4d09430c0049a81a557 Author: Doug Fawley Date: Tue Apr 13 16:19:17 2021 -0700 credentials/google: remove unnecessary dependency on xds protos (#4339) commit 6fafb9193bde04c61d75a2da9de53c4d029748b4 Author: Easwar Swaminathan Date: Tue Apr 13 15:31:34 2021 -0700 xds: support unspecified and wildcard filter chain prefixes (#4333) commit c229922995e2c1af095282ef4d17abcd7300ecaf Author: apolcyn Date: Tue Apr 13 13:06:05 2021 -0700 client: propagate connection error causes to RPC statuses (#4311) commit 7a6ab591158c9c43b13b229a5d0a6471abfbeca6 Author: Easwar Swaminathan Date: Tue Apr 13 11:47:25 2021 -0700 multiple: go mod tidy to make vet happy (#4337) commit 950ddd3c37fc38deaf95f3a27b5883af4776a679 Author: Menghan Li Date: Mon Apr 12 09:56:37 2021 -0700 xds/google_default_creds: handshake based on cluster name in address attributes (#4310) commit fab5982df20a27885393f866db267ee7b35808d2 Author: Easwar Swaminathan Date: Fri Apr 9 16:49:25 2021 -0700 xds: server-side listener network filter validation (#4312) commit d6abfb459860721299c6f0bc7ffcbed5f9feebe4 Author: Aliaksandr Mianzhynski Date: Sat Apr 10 02:30:59 2021 +0300 cmd/protoc-gen-go-grpc: add protoc and protoc-gen-go-grpc versions to top comment (#4313) commit 1d1bbb55b381f39fbe93edbb1d0fd96a6b1ecaef Author: Menghan Li Date: Thu Apr 8 16:11:44 2021 -0700 weightedtarget: handle updating child policy name (#4309) commit 2df4370b332809e4daf1e2109b2389500e64c1c0 Author: Easwar Swaminathan Date: Thu Apr 8 16:02:52 2021 -0700 examples: update xds examples for PSM security (#4256) commit 69f6f5a51249d3a9f4b6a9262167ddd984599cdc Author: Easwar Swaminathan Date: Thu Apr 8 15:52:49 2021 -0700 xds: add support for unsupported filter matchers (#4315) commit c7a203dcb5c97bf4cc7fd79b905b044ab14a5fbc Author: Menghan Li Date: Thu Apr 8 14:31:20 2021 -0700 xds/interop: move header/path matching to all (#4325) commit 1895da54b012305f2628e3feee697937149aac57 Author: Menghan Li Date: Thu Apr 8 11:34:02 2021 -0700 xds/resolver: fix panic when two LDS updates are receives without RDS in between (#4327) Also confirmed that the LDS updates shouldn't trigger state update without the RDS. commit 493d388ad24c7a3e957f552a1a15dccdd1c9124b Author: Doug Fawley Date: Tue Apr 6 15:09:00 2021 -0700 xds/csds: update proto imports to separate grpc from non-grpc symbols (#4326) commit 004ef8ade68b267f285c82e955a2f663c9a591be Author: Menghan Li Date: Tue Apr 6 13:47:15 2021 -0700 xds/clusterimpl: fix picker update race after balancer is closed (#4318) commit 9a10f357871cf04dbc16b064b993e81e66c660f7 Author: Menghan Li Date: Tue Apr 6 13:11:49 2021 -0700 balancergroup: fix leak child balancer not closed (#4308) commit 777b228b599fd383aafd29155c35741d617b564c Author: Menghan Li Date: Tue Apr 6 10:55:19 2021 -0700 xds: fix service request counter flaky test (#4324) commit 8892a7b247c0aef5059175bacee30f2b055aac88 Author: Menghan Li Date: Mon Apr 5 13:56:00 2021 -0700 [xds_interop_client_admin] xds/interop: register admin services and reflection (#4307) commit 5730f8d113ee31f14709a787572c4a3f3af5d3dd Author: ZhenLian Date: Fri Apr 2 11:19:22 2021 -0700 Invoke Go Vet Check in Sub-modules (#4302) * Invoke Go Vet Check in Sub-modules commit db816235452978bb98c6d18ac03ce643e9ab13fc Author: Zach Reyes <39203661+zasweq@users.noreply.github.com> Date: Thu Apr 1 14:41:47 2021 -0400 xds: Add fields to cluster update (#4277) * Added support for more fields in CDS response commit f6bb3972ed15a0aaf47730344c47e9840bb5dbba Author: Easwar Swaminathan Date: Wed Mar 31 16:58:24 2021 -0700 xds: filter chain matching logic for server-side (#4281) commit c72e1c8f7528615e2b5b887d279015abb2b6c659 Author: Menghan Li Date: Wed Mar 31 16:30:10 2021 -0700 xds/resolver: support inline RDS resource from LDS response (#4299) commit 0028242dbbf8efab46fb0e25cef649ef7bea1730 Author: Menghan Li Date: Wed Mar 31 10:36:16 2021 -0700 Change version to 1.38.0-dev (#4306) commit 4a19753e9dfdf7c54c4b44ae419876e94ef3a0cc Author: apolcyn Date: Fri Mar 26 10:09:12 2021 -0700 interop: add a flag to clients to statically configure grpclb (#4290) commit 2456c5cff04bb867e220f084bc88034f588c8aa8 Author: apolcyn Date: Thu Mar 25 20:56:46 2021 -0700 Allow using interop client for making Traffic Director RPCs (#4291) commit 80e380eff4edbfdacb4be1ae7d92c772400b2159 Author: longxboy Date: Fri Mar 26 04:08:24 2021 +0800 balancer/base: keep address attributes for pickers (#4253) commit 702608ffae4d03a6821b96d3e2311973d34b96dc Author: Doug Fawley Date: Wed Mar 24 10:20:16 2021 -0700 xds: enable timeout, circuit breaking, and fault injection by default (#4286) commit faf4e1c777f0c306e1632c8efda49f81f8de7646 Author: Doug Fawley Date: Tue Mar 23 15:19:03 2021 -0700 xds: rename proto import to grpc (#4287) commit 46da49ca604aef87498c628719b3408f27f4c6d7 Author: Doug Fawley Date: Tue Mar 23 13:26:01 2021 -0700 xds: use different proto import for grpc services (#4285) commit b331a48e06791ab7595f706af46b8bf9244d1f2e Author: Doug Fawley Date: Tue Mar 23 10:42:27 2021 -0700 alts: re-add vmOnGCP and once globals for easier testing (#4284) commit d26af8e3916597bde07641df24c3d38ca9b1f5a2 Author: Menghan Li Date: Mon Mar 22 15:14:11 2021 -0700 admin: implement admin services (#4274) commit f320c793495fc90c222831f1708c18119793e0f8 Author: Doug Fawley Date: Mon Mar 22 09:42:11 2021 -0700 test: enable fault_injection xds test (#4283) commit bce1cded4b05db45e02a87b94b75fa5cb07a76a5 Author: Menghan Li Date: Thu Mar 18 16:01:39 2021 -0700 internal: use strings.Replace instead strings.ReplaceAll (#4279) strings.ReplaceAll is only available after go 1.12. We still support go 1.11. commit 967933baf52a7bd113bfc23cbf4c5d01b8367d5b Author: Menghan Li Date: Tue Mar 16 14:50:07 2021 -0700 xds/cdsbalancer: move xds client close to run() (#4273) Otherwise client may be used by run() after closed. commit 95173a53fe5443098a11acca95a71eae006ecaa9 Author: Menghan Li Date: Tue Mar 16 14:05:05 2021 -0700 csds: implement CSDS service handler (#4243) commit 1e7119b13689dac5b8fe0a70118434eff96e997b Author: Easwar Swaminathan Date: Mon Mar 15 14:13:13 2021 -0700 xds: support all matchers for SANs (#4246) commit 21976fa3e38a266811384409bc8b25437cc1ff1d Author: Doug Fawley Date: Fri Mar 12 15:19:57 2021 -0800 xds: disable fault injection test on 386 (#4264) commit de3c78e4f1f16ca8e7f8fed6d12dcf3e1de337c1 Author: Easwar Swaminathan Date: Fri Mar 12 13:23:11 2021 -0800 xds: validate 'listener_filters' and 'use_original_dst' fields (#4258) commit d7737376c30e219e2773ebb143c678bdd91810e4 Author: Doug Fawley Date: Fri Mar 12 08:38:49 2021 -0800 xds: implement fault injection HTTP filter (A33) (#4236) commit f168a3cb3bf52b839691623a343b867d9ef7a566 Author: Menghan Li Date: Thu Mar 11 14:17:43 2021 -0800 c2p: add google-c2p resolver (#4204) commit 2f7f1f6c22e925c849e56bbe0823961588f9f9ec Author: Easwar Swaminathan Date: Thu Mar 11 12:07:48 2021 -0800 rls: update pb.gos after https://github.com/grpc/grpc-proto/pull/92 (#4257) commit e8930beb0e042eaf70b0a0a0a87dcaf8daffe782 Author: Easwar Swaminathan Date: Wed Mar 10 21:12:44 2021 -0800 xds: Prepare to support filter chains on the server (#4222) commit a45f13b160731da4f8b356e0449e9754312e5f1a Author: Easwar Swaminathan Date: Wed Mar 10 09:26:23 2021 -0800 xds: Support server_listener_resource_name_template (#4233) --- .github/lock.yml | 2 - .github/mergeable.yml | 39 +- .github/stale.yml | 58 - .github/workflows/codeql-analysis.yml | 10 +- .github/workflows/lock.yml | 20 + .github/workflows/stale.yml | 30 + .github/workflows/testing.yml | 91 +- .travis.yml | 42 - Documentation/server-reflection-tutorial.md | 2 +- MAINTAINERS.md | 5 +- Makefile | 2 - NOTICE.txt | 13 + README.md | 2 +- admin/admin.go | 58 + .../sni_appengine.go => admin/admin_test.go | 20 +- admin/test/admin_test.go | 38 + admin/test/utils.go | 114 + authz/rbac_translator.go | 306 ++ authz/rbac_translator_test.go | 273 ++ authz/sdk_end2end_test.go | 548 +++ authz/sdk_server_interceptors.go | 172 + authz/sdk_server_interceptors_test.go | 121 + balancer/balancer.go | 68 +- balancer/base/balancer.go | 44 +- balancer/base/balancer_test.go | 54 +- .../grpc_lb_v1/load_balancer_grpc.pb.go | 4 + balancer/grpclb/grpclb.go | 40 +- balancer/grpclb/grpclb_config.go | 1 + balancer/grpclb/grpclb_config_test.go | 40 +- balancer/grpclb/grpclb_remote_balancer.go | 43 +- balancer/grpclb/grpclb_test.go | 643 ++-- balancer/grpclb/grpclb_test_util_test.go | 25 +- balancer/rls/internal/balancer.go | 4 + balancer/rls/internal/config.go | 262 +- balancer/rls/internal/config_test.go | 157 +- balancer/rls/internal/keys/builder.go | 2 +- .../internal/proto/grpc_lookup_v1/rls.pb.go | 180 +- .../proto/grpc_lookup_v1/rls_config.pb.go | 319 +- .../proto/grpc_lookup_v1/rls_grpc.pb.go | 4 + balancer/roundrobin/roundrobin.go | 4 +- balancer_conn_wrappers.go | 116 +- balancer_conn_wrappers_test.go | 90 - balancer_switching_test.go | 48 + benchmark/stats/histogram.go | 11 +- call_test.go | 2 +- channelz/grpc_channelz_v1/channelz_grpc.pb.go | 4 + channelz/service/func_linux.go | 54 +- channelz/service/func_nonlinux.go | 3 +- channelz/service/service.go | 8 +- channelz/service/service_sktopt_test.go | 1 + channelz/service/util_sktopt_386_test.go | 1 + channelz/service/util_sktopt_amd64_test.go | 1 + clientconn.go | 355 +- clientconn_authority_test.go | 122 + clientconn_parsed_target_test.go | 183 + clientconn_state_transition_test.go | 25 +- clientconn_test.go | 96 +- cmd/protoc-gen-go-grpc/grpc.go | 23 +- connectivity/connectivity.go | 35 +- credentials/alts/alts.go | 5 +- credentials/alts/alts_test.go | 1 + credentials/alts/internal/conn/record_test.go | 82 +- .../internal/handshaker/handshaker_test.go | 7 +- .../proto/grpc_gcp/handshaker_grpc.pb.go | 4 + credentials/alts/utils.go | 94 - credentials/alts/utils_test.go | 64 +- credentials/credentials.go | 24 +- credentials/go12.go | 30 - credentials/google/google.go | 81 +- credentials/google/google_test.go | 131 + credentials/google/xds.go | 90 + credentials/local/local_test.go | 2 + credentials/oauth/oauth.go | 19 +- credentials/oauth/oauth_test.go | 60 + credentials/sts/sts.go | 2 - credentials/sts/sts_test.go | 6 +- credentials/tls.go | 3 + .../tls/certprovider/distributor_test.go | 2 - .../tls/certprovider/meshca/builder.go | 165 - .../tls/certprovider/meshca/builder_test.go | 177 - credentials/tls/certprovider/meshca/config.go | 310 -- .../tls/certprovider/meshca/config_test.go | 375 --- .../meshca/internal/v1/meshca.pb.go | 276 -- .../meshca/internal/v1/meshca_grpc.pb.go | 106 - credentials/tls/certprovider/meshca/plugin.go | 289 -- .../tls/certprovider/meshca/plugin_test.go | 459 --- .../tls/certprovider/pemfile/watcher_test.go | 16 +- credentials/tls/certprovider/store_test.go | 31 +- credentials/xds/xds.go | 12 +- credentials/xds/xds_client_test.go | 33 +- credentials/xds/xds_server_test.go | 61 +- dialoptions.go | 17 +- examples/examples_test.sh | 3 + examples/features/encryption/README.md | 4 +- examples/features/proto/echo/echo_grpc.pb.go | 4 + examples/features/unix_abstract/README.md | 29 + .../features/unix_abstract/client/main.go | 68 + .../features/unix_abstract/server/main.go | 58 + examples/features/xds/client/main.go | 66 +- examples/features/xds/server/main.go | 93 +- examples/go.mod | 6 +- examples/go.sum | 53 +- examples/helloworld/greeter_server/main.go | 1 + .../helloworld/helloworld_grpc.pb.go | 4 + examples/route_guide/client/client.go | 2 +- .../routeguide/route_guide_grpc.pb.go | 4 + go.mod | 13 +- go.sum | 53 +- grpclog/loggerv2.go | 86 +- grpclog/loggerv2_test.go | 4 +- health/grpc_health_v1/health_grpc.pb.go | 4 + install_gae.sh | 6 - internal/admin/admin.go | 60 + internal/balancer/stub/stub.go | 7 + internal/binarylog/sink.go | 41 +- internal/channelz/funcs.go | 2 +- internal/channelz/types_linux.go | 2 - internal/channelz/types_nonlinux.go | 5 +- internal/channelz/util_linux.go | 2 - internal/channelz/util_nonlinux.go | 3 +- internal/channelz/util_test.go | 3 +- internal/credentials/credentials.go | 49 + internal/credentials/spiffe.go | 2 - internal/credentials/spiffe_appengine.go | 31 - internal/credentials/syscallconn.go | 2 - internal/credentials/syscallconn_test.go | 2 - internal/credentials/util.go | 4 +- internal/credentials/xds/handshake_info.go | 151 +- .../credentials/xds/handshake_info_test.go | 304 ++ internal/envconfig/envconfig.go | 6 +- internal/googlecloud/googlecloud.go | 128 + internal/googlecloud/googlecloud_test.go | 86 + internal/grpcrand/grpcrand.go | 29 +- internal/grpctest/tlogger.go | 57 +- internal/grpcutil/target_test.go | 114 - internal/internal.go | 11 +- internal/leakcheck/leakcheck.go | 1 - internal/pretty/pretty.go | 82 + internal/profiling/buffer/buffer.go | 2 - internal/profiling/buffer/buffer_appengine.go | 43 - internal/profiling/buffer/buffer_test.go | 2 - internal/profiling/goid_modified.go | 1 + internal/profiling/goid_regular.go | 1 + internal/resolver/config_selector.go | 7 +- internal/resolver/config_selector_test.go | 6 +- internal/resolver/dns/dns_resolver.go | 52 +- internal/resolver/dns/dns_resolver_test.go | 411 ++- internal/resolver/dns/go113.go | 33 - internal/serviceconfig/serviceconfig.go | 20 +- internal/serviceconfig/serviceconfig_test.go | 53 +- internal/status/status.go | 14 +- internal/syscall/syscall_linux.go | 2 - internal/syscall/syscall_nonlinux.go | 21 +- .../testutils/marshal_any.go | 26 +- internal/transport/controlbuf.go | 60 +- internal/transport/handler_server.go | 3 +- internal/transport/handler_server_test.go | 13 +- internal/transport/http2_client.go | 322 +- internal/transport/http2_server.go | 275 +- internal/transport/http_util.go | 224 +- internal/transport/http_util_test.go | 65 - internal/transport/keepalive_test.go | 64 +- internal/transport/networktype/networktype.go | 2 +- internal/transport/proxy_test.go | 3 +- internal/transport/transport.go | 19 +- internal/transport/transport_test.go | 558 +++- internal/xds/bootstrap.go | 147 + {xds/internal => internal/xds}/env/env.go | 53 +- internal/xds/matcher/matcher_header.go | 253 ++ .../xds/matcher}/matcher_header_test.go | 36 +- internal/xds/matcher/string_matcher.go | 183 + internal/xds/matcher/string_matcher_test.go | 309 ++ internal/xds/rbac/matchers.go | 426 +++ internal/xds/rbac/rbac_engine.go | 225 ++ internal/xds/rbac/rbac_engine_test.go | 1007 ++++++ internal/xds_handshake_cluster.go | 40 + interop/client/client.go | 72 +- .../grpc_testing/benchmark_service_grpc.pb.go | 4 + .../report_qps_scenario_service_grpc.pb.go | 4 + interop/grpc_testing/test_grpc.pb.go | 4 + .../grpc_testing/worker_service_grpc.pb.go | 4 + .../{client.go => client_linux.go} | 2 - interop/test_utils.go | 88 + interop/xds/client/Dockerfile | 37 + interop/xds/client/client.go | 27 +- interop/xds/server/Dockerfile | 36 + interop/xds/server/server.go | 147 +- metadata/metadata.go | 100 +- metadata/metadata_test.go | 29 + picker_wrapper.go | 4 +- pickfirst.go | 21 +- profiling/proto/service_grpc.pb.go | 4 + .../reflection_grpc.pb.go | 4 + reflection/grpc_testing/test_grpc.pb.go | 4 + reflection/serverreflection.go | 14 +- regenerate.sh | 10 - resolver/manual/manual.go | 20 +- resolver/resolver.go | 2 +- resolver_conn_wrapper.go | 71 +- resolver_conn_wrapper_test.go | 85 - rpc_util.go | 50 +- security/advancedtls/advancedtls.go | 71 +- .../advancedtls_integration_test.go | 4 +- security/advancedtls/advancedtls_test.go | 36 +- security/advancedtls/crl.go | 499 +++ security/advancedtls/crl_test.go | 718 ++++ security/advancedtls/examples/go.mod | 2 +- security/advancedtls/examples/go.sum | 41 +- security/advancedtls/go.mod | 3 +- security/advancedtls/go.sum | 41 +- security/advancedtls/sni.go | 2 - security/advancedtls/sni_beforego114.go | 42 - security/advancedtls/testdata/crl/0b35a562.r0 | 1 + security/advancedtls/testdata/crl/0b35a562.r1 | 1 + security/advancedtls/testdata/crl/1.crl | 10 + security/advancedtls/testdata/crl/1ab871c8.r0 | 1 + security/advancedtls/testdata/crl/2.crl | 10 + security/advancedtls/testdata/crl/3.crl | 11 + security/advancedtls/testdata/crl/4.crl | 10 + security/advancedtls/testdata/crl/5.crl | 10 + security/advancedtls/testdata/crl/6.crl | 11 + security/advancedtls/testdata/crl/71eac5a2.r0 | 1 + security/advancedtls/testdata/crl/7a1799af.r0 | 1 + security/advancedtls/testdata/crl/8828a7e6.r0 | 1 + security/advancedtls/testdata/crl/README.md | 48 + security/advancedtls/testdata/crl/deee447d.r0 | 1 + .../advancedtls/testdata/crl/revokedInt.pem | 58 + .../advancedtls/testdata/crl/revokedLeaf.pem | 59 + .../advancedtls/testdata/crl/unrevoked.pem | 58 + security/authorization/go.mod | 2 +- security/authorization/go.sum | 11 - server.go | 226 +- server_test.go | 57 + stats/stats.go | 11 +- stats/stats_test.go | 70 +- stream.go | 196 +- stress/grpc_testing/metrics_grpc.pb.go | 4 + tap/tap.go | 16 +- test/authority_test.go | 1 + test/balancer_test.go | 60 +- test/bufconn/bufconn.go | 10 + ...x_go110_test.go => channelz_linux_test.go} | 2 - test/channelz_test.go | 18 +- test/end2end_test.go | 1009 +++++- test/go_vet/vet.go | 53 - test/grpc_testing/test_grpc.pb.go | 4 + test/insecure_creds_test.go | 53 +- test/kokoro/xds.cfg | 2 +- test/kokoro/xds.sh | 8 +- test/kokoro/xds_k8s.cfg | 13 + test/kokoro/xds_k8s.sh | 155 + test/kokoro/xds_url_map.cfg | 13 + test/kokoro/xds_url_map.sh | 138 + test/kokoro/xds_v3.cfg | 2 +- test/race.go | 1 + test/retry_test.go | 231 +- test/tools/go.mod | 2 +- test/tools/tools.go | 8 +- xds/go113.go => test/tools/tools_vet.go | 12 +- version.go | 2 +- vet.sh | 56 +- xds/csds/csds.go | 305 ++ xds/csds/csds_test.go | 739 +++++ xds/googledirectpath/googlec2p.go | 178 + xds/googledirectpath/googlec2p_test.go | 242 ++ xds/googledirectpath/utils.go | 96 + xds/internal/balancer/balancer.go | 10 +- .../balancer/balancergroup/balancergroup.go | 38 +- .../balancergroup/balancergroup_test.go | 34 +- .../balancer/cdsbalancer/cdsbalancer.go | 288 +- .../cdsbalancer/cdsbalancer_security_test.go | 99 +- .../balancer/cdsbalancer/cdsbalancer_test.go | 163 +- .../balancer/cdsbalancer/cluster_handler.go | 318 ++ .../cdsbalancer/cluster_handler_test.go | 685 ++++ .../balancer/clusterimpl/balancer_test.go | 555 +++- .../balancer/clusterimpl/clusterimpl.go | 508 ++- xds/internal/balancer/clusterimpl/config.go | 27 +- .../balancer/clusterimpl/config_test.go | 14 +- xds/internal/balancer/clusterimpl/picker.go | 90 +- .../clustermanager/balancerstateaggregator.go | 9 +- .../balancer/clustermanager/clustermanager.go | 18 +- .../clustermanager/clustermanager_test.go | 65 + .../clusterresolver/clusterresolver.go | 378 +++ .../clusterresolver/clusterresolver_test.go | 500 +++ .../balancer/clusterresolver/config.go | 185 ++ .../balancer/clusterresolver/config_test.go | 269 ++ .../balancer/clusterresolver/configbuilder.go | 364 ++ .../clusterresolver/configbuilder_test.go | 979 ++++++ .../balancer/clusterresolver/eds_impl_test.go | 575 ++++ .../{lrs => clusterresolver}/logging.go | 6 +- .../priority_test.go} | 610 ++-- .../clusterresolver/resource_resolver.go | 247 ++ .../clusterresolver/resource_resolver_dns.go | 114 + .../clusterresolver/resource_resolver_test.go | 870 +++++ .../testutil_test.go} | 75 +- xds/internal/balancer/edsbalancer/config.go | 124 - xds/internal/balancer/edsbalancer/eds.go | 392 --- xds/internal/balancer/edsbalancer/eds_impl.go | 571 ---- .../balancer/edsbalancer/eds_impl_priority.go | 358 -- .../balancer/edsbalancer/eds_impl_test.go | 935 ------ xds/internal/balancer/edsbalancer/eds_test.go | 825 ----- .../edsbalancer/load_store_wrapper.go | 88 - xds/internal/balancer/edsbalancer/util.go | 44 - .../balancer/edsbalancer/util_test.go | 88 - .../balancer/edsbalancer/xds_lrs_test.go | 71 - xds/internal/balancer/edsbalancer/xds_old.go | 46 - .../balancer/loadstore/load_store_wrapper.go | 2 +- xds/internal/balancer/lrs/balancer.go | 246 -- xds/internal/balancer/lrs/balancer_test.go | 144 - xds/internal/balancer/lrs/config.go | 54 - xds/internal/balancer/lrs/config_test.go | 127 - xds/internal/balancer/lrs/picker.go | 85 - xds/internal/balancer/priority/balancer.go | 32 +- .../balancer/priority/balancer_child.go | 20 +- .../balancer/priority/balancer_priority.go | 19 +- .../balancer/priority/balancer_test.go | 536 ++- xds/internal/balancer/priority/config.go | 19 +- xds/internal/balancer/priority/config_test.go | 15 +- .../balancer/priority/ignore_resolve_now.go | 73 + .../priority/ignore_resolve_now_test.go | 104 + xds/internal/balancer/ringhash/config.go | 56 + xds/internal/balancer/ringhash/config_test.go | 68 + .../{edsbalancer => ringhash}/logging.go | 8 +- xds/internal/balancer/ringhash/picker.go | 154 + xds/internal/balancer/ringhash/picker_test.go | 285 ++ xds/internal/balancer/ringhash/ring.go | 163 + xds/internal/balancer/ringhash/ring_test.go | 113 + xds/internal/balancer/ringhash/ringhash.go | 434 +++ .../balancer/ringhash/ringhash_test.go | 458 +++ xds/internal/balancer/ringhash/util.go | 40 + .../weightedaggregator/aggregator.go | 8 +- .../balancer/weightedtarget/weightedtarget.go | 81 +- .../weightedtarget/weightedtarget_config.go | 29 +- .../weightedtarget_config_test.go | 32 +- .../weightedtarget/weightedtarget_test.go | 103 +- xds/internal/client/cds_test.go | 833 ----- xds/internal/client/lds_test.go | 1628 --------- xds/internal/client/requests_counter.go | 82 - xds/internal/client/singleton.go | 101 - xds/internal/client/tests/README.md | 1 - xds/internal/client/watchers_listener_test.go | 358 -- xds/internal/client/xds.go | 915 ----- xds/internal/httpfilter/fault/fault.go | 301 ++ xds/internal/httpfilter/fault/fault_test.go | 672 ++++ xds/internal/httpfilter/httpfilter.go | 11 +- xds/internal/httpfilter/rbac/rbac.go | 220 ++ xds/internal/httpfilter/router/router.go | 21 +- xds/internal/internal.go | 18 + xds/internal/resolver/matcher.go | 161 - xds/internal/resolver/matcher_header.go | 188 -- xds/internal/resolver/serviceconfig.go | 120 +- xds/internal/resolver/serviceconfig_test.go | 74 + xds/internal/resolver/watch_service.go | 155 +- xds/internal/resolver/watch_service_test.go | 146 +- xds/internal/resolver/xds_resolver.go | 46 +- xds/internal/resolver/xds_resolver_test.go | 346 +- xds/internal/server/conn_wrapper.go | 165 + xds/internal/server/listener_wrapper.go | 442 +++ xds/internal/server/listener_wrapper_test.go | 484 +++ xds/internal/server/rds_handler.go | 133 + xds/internal/server/rds_handler_test.go | 401 +++ xds/internal/test/e2e/README.md | 19 + xds/internal/test/e2e/controlplane.go | 62 + xds/internal/test/e2e/e2e.go | 178 + xds/internal/test/e2e/e2e_test.go | 257 ++ .../internal/test/e2e/e2e_utils.go | 24 +- xds/internal/test/e2e/run.sh | 6 + xds/internal/test/xds_client_affinity_test.go | 136 + .../test/xds_client_integration_test.go | 212 +- xds/internal/test/xds_integration_test.go | 151 +- .../test/xds_security_config_nack_test.go | 372 +++ .../test/xds_server_integration_test.go | 1304 ++++++-- .../test/xds_server_serving_mode_test.go | 388 +++ xds/internal/testutils/balancer.go | 31 +- xds/internal/testutils/e2e/bootstrap.go | 116 +- xds/internal/testutils/e2e/clientresources.go | 308 +- xds/internal/testutils/e2e/server.go | 48 +- xds/internal/testutils/fakeclient/client.go | 152 +- xds/internal/testutils/protos.go | 4 +- xds/internal/xdsclient/attributes.go | 59 + .../bootstrap/bootstrap.go | 53 +- .../bootstrap/bootstrap_test.go | 30 +- .../bootstrap/logging.go | 0 .../{client => xdsclient}/callback.go | 215 +- xds/internal/xdsclient/cds_test.go | 1590 +++++++++ xds/internal/{client => xdsclient}/client.go | 301 +- .../{client => xdsclient}/client_test.go | 93 +- xds/internal/{client => xdsclient}/dump.go | 2 +- .../{client/tests => xdsclient}/dump_test.go | 137 +- .../{client => xdsclient}/eds_test.go | 105 +- xds/internal/{client => xdsclient}/errors.go | 2 +- xds/internal/xdsclient/filter_chain.go | 852 +++++ xds/internal/xdsclient/filter_chain_test.go | 2939 +++++++++++++++++ xds/internal/xdsclient/lds_test.go | 1944 +++++++++++ .../{client => xdsclient}/load/reporter.go | 0 .../{client => xdsclient}/load/store.go | 0 .../{client => xdsclient}/load/store_test.go | 0 .../{client => xdsclient}/loadreport.go | 4 +- .../tests => xdsclient}/loadreport_test.go | 10 +- xds/internal/{client => xdsclient}/logging.go | 2 +- xds/internal/xdsclient/matcher.go | 278 ++ .../{resolver => xdsclient}/matcher_path.go | 4 +- .../matcher_path_test.go | 2 +- .../{resolver => xdsclient}/matcher_test.go | 77 +- .../{client => xdsclient}/rds_test.go | 739 ++++- xds/internal/xdsclient/requests_counter.go | 107 + .../requests_counter_test.go | 42 +- xds/internal/xdsclient/singleton.go | 198 ++ .../{client => xdsclient}/transport_helper.go | 11 +- .../{client => xdsclient}/v2/ack_test.go | 2 +- .../{client => xdsclient}/v2/cds_test.go | 27 +- .../{client => xdsclient}/v2/client.go | 51 +- .../{client => xdsclient}/v2/client_test.go | 104 +- .../{client => xdsclient}/v2/eds_test.go | 44 +- .../{client => xdsclient}/v2/lds_test.go | 37 +- .../{client => xdsclient}/v2/loadreport.go | 19 +- .../{client => xdsclient}/v2/rds_test.go | 47 +- .../{client => xdsclient}/v3/client.go | 51 +- .../{client => xdsclient}/v3/loadreport.go | 19 +- .../{client => xdsclient}/watchers.go | 23 +- .../watchers_cluster_test.go | 260 +- .../watchers_endpoints_test.go | 199 +- .../xdsclient/watchers_listener_test.go | 591 ++++ .../watchers_route_test.go | 198 +- xds/internal/xdsclient/xds.go | 1334 ++++++++ .../xdsclient_test.go} | 8 +- xds/server.go | 501 ++- xds/server_options.go | 76 + xds/server_test.go | 509 ++- xds/xds.go | 64 +- 430 files changed, 44656 insertions(+), 17635 deletions(-) delete mode 100644 .github/lock.yml delete mode 100644 .github/stale.yml create mode 100644 .github/workflows/lock.yml create mode 100644 .github/workflows/stale.yml delete mode 100644 .travis.yml create mode 100644 NOTICE.txt create mode 100644 admin/admin.go rename security/advancedtls/sni_appengine.go => admin/admin_test.go (63%) create mode 100644 admin/test/admin_test.go create mode 100644 admin/test/utils.go create mode 100644 authz/rbac_translator.go create mode 100644 authz/rbac_translator_test.go create mode 100644 authz/sdk_end2end_test.go create mode 100644 authz/sdk_server_interceptors.go create mode 100644 authz/sdk_server_interceptors_test.go delete mode 100644 balancer_conn_wrappers_test.go create mode 100644 clientconn_authority_test.go create mode 100644 clientconn_parsed_target_test.go delete mode 100644 credentials/go12.go create mode 100644 credentials/google/google_test.go create mode 100644 credentials/google/xds.go create mode 100644 credentials/oauth/oauth_test.go delete mode 100644 credentials/tls/certprovider/meshca/builder.go delete mode 100644 credentials/tls/certprovider/meshca/builder_test.go delete mode 100644 credentials/tls/certprovider/meshca/config.go delete mode 100644 credentials/tls/certprovider/meshca/config_test.go delete mode 100644 credentials/tls/certprovider/meshca/internal/v1/meshca.pb.go delete mode 100644 credentials/tls/certprovider/meshca/internal/v1/meshca_grpc.pb.go delete mode 100644 credentials/tls/certprovider/meshca/plugin.go delete mode 100644 credentials/tls/certprovider/meshca/plugin_test.go create mode 100644 examples/features/unix_abstract/README.md create mode 100644 examples/features/unix_abstract/client/main.go create mode 100644 examples/features/unix_abstract/server/main.go delete mode 100755 install_gae.sh create mode 100644 internal/admin/admin.go create mode 100644 internal/credentials/credentials.go delete mode 100644 internal/credentials/spiffe_appengine.go create mode 100644 internal/credentials/xds/handshake_info_test.go create mode 100644 internal/googlecloud/googlecloud.go create mode 100644 internal/googlecloud/googlecloud_test.go delete mode 100644 internal/grpcutil/target_test.go create mode 100644 internal/pretty/pretty.go delete mode 100644 internal/profiling/buffer/buffer_appengine.go delete mode 100644 internal/resolver/dns/go113.go rename credentials/tls/certprovider/meshca/logging.go => internal/testutils/marshal_any.go (55%) create mode 100644 internal/xds/bootstrap.go rename {xds/internal => internal/xds}/env/env.go (52%) create mode 100644 internal/xds/matcher/matcher_header.go rename {xds/internal/resolver => internal/xds/matcher}/matcher_header_test.go (88%) create mode 100644 internal/xds/matcher/string_matcher.go create mode 100644 internal/xds/matcher/string_matcher_test.go create mode 100644 internal/xds/rbac/matchers.go create mode 100644 internal/xds/rbac/rbac_engine.go create mode 100644 internal/xds/rbac/rbac_engine_test.go create mode 100644 internal/xds_handshake_cluster.go rename interop/grpclb_fallback/{client.go => client_linux.go} (99%) create mode 100644 interop/xds/client/Dockerfile create mode 100644 interop/xds/server/Dockerfile create mode 100644 security/advancedtls/crl.go create mode 100644 security/advancedtls/crl_test.go delete mode 100644 security/advancedtls/sni_beforego114.go create mode 120000 security/advancedtls/testdata/crl/0b35a562.r0 create mode 120000 security/advancedtls/testdata/crl/0b35a562.r1 create mode 100644 security/advancedtls/testdata/crl/1.crl create mode 120000 security/advancedtls/testdata/crl/1ab871c8.r0 create mode 100644 security/advancedtls/testdata/crl/2.crl create mode 100644 security/advancedtls/testdata/crl/3.crl create mode 100644 security/advancedtls/testdata/crl/4.crl create mode 100644 security/advancedtls/testdata/crl/5.crl create mode 100644 security/advancedtls/testdata/crl/6.crl create mode 120000 security/advancedtls/testdata/crl/71eac5a2.r0 create mode 120000 security/advancedtls/testdata/crl/7a1799af.r0 create mode 120000 security/advancedtls/testdata/crl/8828a7e6.r0 create mode 100644 security/advancedtls/testdata/crl/README.md create mode 120000 security/advancedtls/testdata/crl/deee447d.r0 create mode 100644 security/advancedtls/testdata/crl/revokedInt.pem create mode 100644 security/advancedtls/testdata/crl/revokedLeaf.pem create mode 100644 security/advancedtls/testdata/crl/unrevoked.pem rename test/{channelz_linux_go110_test.go => channelz_linux_test.go} (99%) delete mode 100644 test/go_vet/vet.go create mode 100644 test/kokoro/xds_k8s.cfg create mode 100755 test/kokoro/xds_k8s.sh create mode 100644 test/kokoro/xds_url_map.cfg create mode 100755 test/kokoro/xds_url_map.sh rename xds/go113.go => test/tools/tools_vet.go (75%) create mode 100644 xds/csds/csds.go create mode 100644 xds/csds/csds_test.go create mode 100644 xds/googledirectpath/googlec2p.go create mode 100644 xds/googledirectpath/googlec2p_test.go create mode 100644 xds/googledirectpath/utils.go create mode 100644 xds/internal/balancer/cdsbalancer/cluster_handler.go create mode 100644 xds/internal/balancer/cdsbalancer/cluster_handler_test.go create mode 100644 xds/internal/balancer/clusterresolver/clusterresolver.go create mode 100644 xds/internal/balancer/clusterresolver/clusterresolver_test.go create mode 100644 xds/internal/balancer/clusterresolver/config.go create mode 100644 xds/internal/balancer/clusterresolver/config_test.go create mode 100644 xds/internal/balancer/clusterresolver/configbuilder.go create mode 100644 xds/internal/balancer/clusterresolver/configbuilder_test.go create mode 100644 xds/internal/balancer/clusterresolver/eds_impl_test.go rename xds/internal/balancer/{lrs => clusterresolver}/logging.go (84%) rename xds/internal/balancer/{edsbalancer/eds_impl_priority_test.go => clusterresolver/priority_test.go} (54%) create mode 100644 xds/internal/balancer/clusterresolver/resource_resolver.go create mode 100644 xds/internal/balancer/clusterresolver/resource_resolver_dns.go create mode 100644 xds/internal/balancer/clusterresolver/resource_resolver_test.go rename xds/internal/balancer/{edsbalancer/eds_testutil.go => clusterresolver/testutil_test.go} (61%) delete mode 100644 xds/internal/balancer/edsbalancer/config.go delete mode 100644 xds/internal/balancer/edsbalancer/eds.go delete mode 100644 xds/internal/balancer/edsbalancer/eds_impl.go delete mode 100644 xds/internal/balancer/edsbalancer/eds_impl_priority.go delete mode 100644 xds/internal/balancer/edsbalancer/eds_impl_test.go delete mode 100644 xds/internal/balancer/edsbalancer/eds_test.go delete mode 100644 xds/internal/balancer/edsbalancer/load_store_wrapper.go delete mode 100644 xds/internal/balancer/edsbalancer/util.go delete mode 100644 xds/internal/balancer/edsbalancer/util_test.go delete mode 100644 xds/internal/balancer/edsbalancer/xds_lrs_test.go delete mode 100644 xds/internal/balancer/edsbalancer/xds_old.go delete mode 100644 xds/internal/balancer/lrs/balancer.go delete mode 100644 xds/internal/balancer/lrs/balancer_test.go delete mode 100644 xds/internal/balancer/lrs/config.go delete mode 100644 xds/internal/balancer/lrs/config_test.go delete mode 100644 xds/internal/balancer/lrs/picker.go create mode 100644 xds/internal/balancer/priority/ignore_resolve_now.go create mode 100644 xds/internal/balancer/priority/ignore_resolve_now_test.go create mode 100644 xds/internal/balancer/ringhash/config.go create mode 100644 xds/internal/balancer/ringhash/config_test.go rename xds/internal/balancer/{edsbalancer => ringhash}/logging.go (83%) create mode 100644 xds/internal/balancer/ringhash/picker.go create mode 100644 xds/internal/balancer/ringhash/picker_test.go create mode 100644 xds/internal/balancer/ringhash/ring.go create mode 100644 xds/internal/balancer/ringhash/ring_test.go create mode 100644 xds/internal/balancer/ringhash/ringhash.go create mode 100644 xds/internal/balancer/ringhash/ringhash_test.go create mode 100644 xds/internal/balancer/ringhash/util.go delete mode 100644 xds/internal/client/cds_test.go delete mode 100644 xds/internal/client/lds_test.go delete mode 100644 xds/internal/client/requests_counter.go delete mode 100644 xds/internal/client/singleton.go delete mode 100644 xds/internal/client/tests/README.md delete mode 100644 xds/internal/client/watchers_listener_test.go delete mode 100644 xds/internal/client/xds.go create mode 100644 xds/internal/httpfilter/fault/fault.go create mode 100644 xds/internal/httpfilter/fault/fault_test.go create mode 100644 xds/internal/httpfilter/rbac/rbac.go delete mode 100644 xds/internal/resolver/matcher.go delete mode 100644 xds/internal/resolver/matcher_header.go create mode 100644 xds/internal/server/conn_wrapper.go create mode 100644 xds/internal/server/listener_wrapper.go create mode 100644 xds/internal/server/listener_wrapper_test.go create mode 100644 xds/internal/server/rds_handler.go create mode 100644 xds/internal/server/rds_handler_test.go create mode 100644 xds/internal/test/e2e/README.md create mode 100644 xds/internal/test/e2e/controlplane.go create mode 100644 xds/internal/test/e2e/e2e.go create mode 100644 xds/internal/test/e2e/e2e_test.go rename internal/credentials/syscallconn_appengine.go => xds/internal/test/e2e/e2e_utils.go (50%) create mode 100755 xds/internal/test/e2e/run.sh create mode 100644 xds/internal/test/xds_client_affinity_test.go create mode 100644 xds/internal/test/xds_security_config_nack_test.go create mode 100644 xds/internal/test/xds_server_serving_mode_test.go create mode 100644 xds/internal/xdsclient/attributes.go rename xds/internal/{client => xdsclient}/bootstrap/bootstrap.go (88%) rename xds/internal/{client => xdsclient}/bootstrap/bootstrap_test.go (95%) rename xds/internal/{client => xdsclient}/bootstrap/logging.go (100%) rename xds/internal/{client => xdsclient}/callback.go (60%) create mode 100644 xds/internal/xdsclient/cds_test.go rename xds/internal/{client => xdsclient}/client.go (68%) rename xds/internal/{client => xdsclient}/client_test.go (78%) rename xds/internal/{client => xdsclient}/dump.go (99%) rename xds/internal/{client/tests => xdsclient}/dump_test.go (76%) rename xds/internal/{client => xdsclient}/eds_test.go (82%) rename xds/internal/{client => xdsclient}/errors.go (98%) create mode 100644 xds/internal/xdsclient/filter_chain.go create mode 100644 xds/internal/xdsclient/filter_chain_test.go create mode 100644 xds/internal/xdsclient/lds_test.go rename xds/internal/{client => xdsclient}/load/reporter.go (100%) rename xds/internal/{client => xdsclient}/load/store.go (100%) rename xds/internal/{client => xdsclient}/load/store_test.go (100%) rename xds/internal/{client => xdsclient}/loadreport.go (98%) rename xds/internal/{client/tests => xdsclient}/loadreport_test.go (94%) rename xds/internal/{client => xdsclient}/logging.go (98%) create mode 100644 xds/internal/xdsclient/matcher.go rename xds/internal/{resolver => xdsclient}/matcher_path.go (97%) rename xds/internal/{resolver => xdsclient}/matcher_path_test.go (99%) rename xds/internal/{resolver => xdsclient}/matcher_test.go (55%) rename xds/internal/{client => xdsclient}/rds_test.go (55%) create mode 100644 xds/internal/xdsclient/requests_counter.go rename xds/internal/{client => xdsclient}/requests_counter_test.go (76%) create mode 100644 xds/internal/xdsclient/singleton.go rename xds/internal/{client => xdsclient}/transport_helper.go (98%) rename xds/internal/{client => xdsclient}/v2/ack_test.go (99%) rename xds/internal/{client => xdsclient}/v2/cds_test.go (86%) rename xds/internal/{client => xdsclient}/v2/client.go (82%) rename xds/internal/{client => xdsclient}/v2/client_test.go (88%) rename xds/internal/{client => xdsclient}/v2/eds_test.go (85%) rename xds/internal/{client => xdsclient}/v2/lds_test.go (81%) rename xds/internal/{client => xdsclient}/v2/loadreport.go (88%) rename xds/internal/{client => xdsclient}/v2/rds_test.go (77%) rename xds/internal/{client => xdsclient}/v3/client.go (82%) rename xds/internal/{client => xdsclient}/v3/loadreport.go (88%) rename xds/internal/{client => xdsclient}/watchers.go (95%) rename xds/internal/{client => xdsclient}/watchers_cluster_test.go (58%) rename xds/internal/{client => xdsclient}/watchers_endpoints_test.go (57%) create mode 100644 xds/internal/xdsclient/watchers_listener_test.go rename xds/internal/{client => xdsclient}/watchers_route_test.go (54%) create mode 100644 xds/internal/xdsclient/xds.go rename xds/internal/{client/tests/client_test.go => xdsclient/xdsclient_test.go} (92%) create mode 100644 xds/server_options.go diff --git a/.github/lock.yml b/.github/lock.yml deleted file mode 100644 index 78f7b19b71d..00000000000 --- a/.github/lock.yml +++ /dev/null @@ -1,2 +0,0 @@ -daysUntilLock: 180 -lockComment: false diff --git a/.github/mergeable.yml b/.github/mergeable.yml index d647dafb7ab..187de98277b 100644 --- a/.github/mergeable.yml +++ b/.github/mergeable.yml @@ -5,32 +5,17 @@ mergeable: - do: label must_include: regex: '^Type:' - fail: - - do: checks - status: 'failure' - payload: - title: 'Need an appropriate "Type:" label' - summary: 'Need an appropriate "Type:" label' - - when: pull_request.* - # This validator requires either the "no release notes" label OR a "Release" milestone - # to be considered successful. However, validators "pass" in mergeable only if all - # checks pass. So it is implemented in reverse. - # I.e.: !(!no_relnotes && !release_milestone) ==> no_relnotes || release_milestone - # If both validators pass, then it is considered a failure, and if either fails, it is - # considered a success. - validate: - - do: label - must_exclude: - regex: '^no release notes$' + - do: description + must_include: + # Allow: + # RELEASE NOTES: none (case insensitive) + # + # RELEASE NOTES: N/A (case insensitive) + # + # RELEASE NOTES: + # * + regex: '^RELEASE NOTES:\s*([Nn][Oo][Nn][Ee]|[Nn]/[Aa]|\n(\*|-)\s*.+)$' + regex_flag: 'm' - do: milestone - must_exclude: + must_include: regex: 'Release$' - pass: - - do: checks - status: 'failure' # fail on pass - payload: - title: 'Need Release milestone or "no release notes" label' - summary: 'Need Release milestone or "no release notes" label' - fail: - - do: checks - status: 'success' # pass on fail diff --git a/.github/stale.yml b/.github/stale.yml deleted file mode 100644 index 8f69dbc4fe8..00000000000 --- a/.github/stale.yml +++ /dev/null @@ -1,58 +0,0 @@ -# Configuration for probot-stale - https://github.com/probot/stale - -# Number of days of inactivity before an Issue or Pull Request becomes stale -daysUntilStale: 6 - -# Number of days of inactivity before an Issue or Pull Request with the stale label is closed. -# Set to false to disable. If disabled, issues still need to be closed manually, but will remain marked as stale. -daysUntilClose: 7 - -# Only issues or pull requests with all of these labels are check if stale. Defaults to `[]` (disabled) -onlyLabels: - - "Status: Requires Reporter Clarification" - -# Issues or Pull Requests with these labels will never be considered stale. Set to `[]` to disable -exemptLabels: [] - -# Set to true to ignore issues in a project (defaults to false) -exemptProjects: false - -# Set to true to ignore issues in a milestone (defaults to false) -exemptMilestones: false - -# Set to true to ignore issues with an assignee (defaults to false) -exemptAssignees: false - -# Label to use when marking as stale -staleLabel: "stale" - -# Comment to post when marking as stale. Set to `false` to disable -markComment: > - This issue is labeled as requiring an update from the reporter, and no update has been received - after 6 days. If no update is provided in the next 7 days, this issue will be automatically closed. - -# Comment to post when removing the stale label. -# unmarkComment: > -# Your comment here. - -# Comment to post when closing a stale Issue or Pull Request. -# closeComment: > -# Your comment here. - -# Limit the number of actions per hour, from 1-30. Default is 30 -limitPerRun: 1 - -# Limit to only `issues` or `pulls` -# only: issues - -# Optionally, specify configuration settings that are specific to just 'issues' or 'pulls': -# pulls: -# daysUntilStale: 30 -# markComment: > -# This pull request has been automatically marked as stale because it has not had -# recent activity. It will be closed if no further activity occurs. Thank you -# for your contributions. - -# issues: -# exemptLabels: -# - confirmed diff --git a/.github/workflows/codeql-analysis.yml b/.github/workflows/codeql-analysis.yml index 0c3806bdc23..2a73b94079c 100644 --- a/.github/workflows/codeql-analysis.yml +++ b/.github/workflows/codeql-analysis.yml @@ -3,16 +3,20 @@ name: "CodeQL" on: push: branches: [ master ] - pull_request: - # The branches below must be a subset of the branches above - branches: [ master ] schedule: - cron: '24 20 * * 3' +permissions: + contents: read + security-events: write + pull-requests: read + actions: read + jobs: analyze: name: Analyze runs-on: ubuntu-latest + timeout-minutes: 30 strategy: fail-fast: false diff --git a/.github/workflows/lock.yml b/.github/workflows/lock.yml new file mode 100644 index 00000000000..5f49c7900a3 --- /dev/null +++ b/.github/workflows/lock.yml @@ -0,0 +1,20 @@ +name: 'Lock Threads' + +on: + workflow_dispatch: + schedule: + - cron: '22 1 * * *' + +permissions: + issues: write + pull-requests: write + +jobs: + lock: + runs-on: ubuntu-latest + steps: + - uses: dessant/lock-threads@v2 + with: + github-token: ${{ github.token }} + issue-lock-inactive-days: 180 + pr-lock-inactive-days: 180 diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml new file mode 100644 index 00000000000..5e01a1e70c4 --- /dev/null +++ b/.github/workflows/stale.yml @@ -0,0 +1,30 @@ +name: Stale bot + +on: + workflow_dispatch: + schedule: + - cron: "44 */2 * * *" + +jobs: + stale: + runs-on: ubuntu-latest + permissions: + issues: write + pull-requests: write + + steps: + - uses: actions/stale@v4 + with: + repo-token: ${{ secrets.GITHUB_TOKEN }} + days-before-stale: 6 + days-before-close: 7 + only-labels: 'Status: Requires Reporter Clarification' + stale-issue-label: 'stale' + stale-pr-label: 'stale' + operations-per-run: 999 + stale-issue-message: > + This issue is labeled as requiring an update from the reporter, and no update has been received + after 6 days. If no update is provided in the next 7 days, this issue will be automatically closed. + stale-pr-message: > + This PR is labeled as requiring an update from the reporter, and no update has been received + after 6 days. If no update is provided in the next 7 days, this issue will be automatically closed. diff --git a/.github/workflows/testing.yml b/.github/workflows/testing.yml index 378e2846676..b687fdb3dc3 100644 --- a/.github/workflows/testing.yml +++ b/.github/workflows/testing.yml @@ -7,6 +7,9 @@ on: schedule: - cron: 0 0 * * * # daily at 00:00 +permissions: + contents: read + # Always force the use of Go modules env: GO111MODULE: on @@ -15,6 +18,7 @@ jobs: # Check generated protos match their source repos (optional for PRs). vet-proto: runs-on: ubuntu-latest + timeout-minutes: 20 steps: # Setup the environment. - name: Setup Go @@ -34,77 +38,80 @@ jobs: env: VET_SKIP_PROTO: 1 runs-on: ubuntu-latest + timeout-minutes: 20 strategy: matrix: include: - - type: vet - goversion: 1.15 - - type: race - goversion: 1.15 - - type: 386 - goversion: 1.15 - - type: retry - goversion: 1.15 + - type: vet+tests + goversion: 1.17 + + - type: tests + goversion: 1.17 + testflags: -race + - type: extras - goversion: 1.15 + goversion: 1.17 + - type: tests - goversion: 1.14 + goversion: 1.17 + goarch: 386 + + - type: tests + goversion: 1.17 + goarch: arm64 + - type: tests - goversion: 1.13 - - type: tests111 - goversion: 1.11 # Keep until interop tests no longer require Go1.11 + goversion: 1.16 + + - type: tests + goversion: 1.15 steps: # Setup the environment. - - name: Setup GOARCH=386 - if: ${{ matrix.type == '386' }} - run: echo "GOARCH=386" >> $GITHUB_ENV - - name: Setup RETRY - if: ${{ matrix.type == 'retry' }} - run: echo "GRPC_GO_RETRY=on" >> $GITHUB_ENV + - name: Setup GOARCH + if: matrix.goarch != '' + run: echo "GOARCH=${{ matrix.goarch }}" >> $GITHUB_ENV + + - name: Setup qemu emulator + if: matrix.goarch == 'arm64' + # setup qemu-user-static emulator and register it with binfmt_misc so that aarch64 binaries + # are automatically executed using qemu. + run: docker run --rm --privileged multiarch/qemu-user-static:5.2.0-2 --reset --credential yes --persistent yes + + - name: Setup GRPC environment + if: matrix.grpcenv != '' + run: echo "${{ matrix.grpcenv }}" >> $GITHUB_ENV + - name: Setup Go uses: actions/setup-go@v2 with: go-version: ${{ matrix.goversion }} + - name: Checkout repo uses: actions/checkout@v2 # Only run vet for 'vet' runs. - name: Run vet.sh - if: ${{ matrix.type == 'vet' }} + if: startsWith(matrix.type, 'vet') run: ./vet.sh -install && ./vet.sh - # Main tests run for everything except when testing "extras", the race - # detector and Go1.11 (where we run a reduced set of tests). + # Main tests run for everything except when testing "extras" + # (where we run a reduced set of tests). - name: Run tests - if: ${{ matrix.type != 'extras' && matrix.type != 'race' && matrix.type != 'tests111' }} + if: contains(matrix.type, 'tests') run: | go version - go test -cpu 1,4 -timeout 7m google.golang.org/grpc/... + go test ${{ matrix.testflags }} -cpu 1,4 -timeout 7m google.golang.org/grpc/... + cd ${GITHUB_WORKSPACE}/security/advancedtls && go test ${{ matrix.testflags }} -timeout 2m google.golang.org/grpc/security/advancedtls/... + cd ${GITHUB_WORKSPACE}/security/authorization && go test ${{ matrix.testflags }} -timeout 2m google.golang.org/grpc/security/authorization/... - # Race detector tests - - name: Run test race - if: ${{ matrix.TYPE == 'race' }} - run: | - go version - go test -race -cpu 1,4 -timeout 7m google.golang.org/grpc/... # Non-core gRPC tests (examples, interop, etc) - name: Run extras tests - if: ${{ matrix.TYPE == 'extras' }} + if: matrix.type == 'extras' run: | go version examples/examples_test.sh security/advancedtls/examples/examples_test.sh interop/interop_test.sh - cd ${GITHUB_WORKSPACE}/security/advancedtls && go test -cpu 1,4 -timeout 7m google.golang.org/grpc/security/advancedtls/... - cd ${GITHUB_WORKSPACE}/security/authorization && go test -cpu 1,4 -timeout 7m google.golang.org/grpc/security/authorization/... - - # Reduced set of tests for Go 1.11 - - name: Run Go1.11 tests - if: ${{ matrix.type == 'tests111' }} - run: | - go version - tests=$(find ${GITHUB_WORKSPACE} -name '*_test.go' | xargs -n1 dirname | sort -u | sed "s:^${GITHUB_WORKSPACE}:.:" | sed "s:\/$::" | grep -v ^./security | grep -v ^./credentials/sts | grep -v ^./credentials/tls/certprovider | grep -v ^./credentials/xds | grep -v ^./xds ) - echo "Running tests for " ${tests} - go test -cpu 1,4 -timeout 7m ${tests} + xds/internal/test/e2e/run.sh diff --git a/.travis.yml b/.travis.yml deleted file mode 100644 index 5847d94e551..00000000000 --- a/.travis.yml +++ /dev/null @@ -1,42 +0,0 @@ -language: go - -matrix: - include: - - go: 1.14.x - env: VET=1 GO111MODULE=on - - go: 1.14.x - env: RACE=1 GO111MODULE=on - - go: 1.14.x - env: RUN386=1 - - go: 1.14.x - env: GRPC_GO_RETRY=on - - go: 1.14.x - env: TESTEXTRAS=1 - - go: 1.13.x - env: GO111MODULE=on - - go: 1.12.x - env: GO111MODULE=on - - go: 1.11.x # Keep until interop tests no longer require Go1.11 - env: GO111MODULE=on - -go_import_path: google.golang.org/grpc - -before_install: - - if [[ "${GO111MODULE}" = "on" ]]; then mkdir "${HOME}/go"; export GOPATH="${HOME}/go"; fi - - if [[ -n "${RUN386}" ]]; then export GOARCH=386; fi - - if [[ "${TRAVIS_EVENT_TYPE}" = "cron" && -z "${RUN386}" ]]; then RACE=1; fi - - if [[ "${TRAVIS_EVENT_TYPE}" != "cron" ]]; then export VET_SKIP_PROTO=1; fi - -install: - - try3() { eval "$*" || eval "$*" || eval "$*"; } - - try3 'if [[ "${GO111MODULE}" = "on" ]]; then go mod download; else make testdeps; fi' - - if [[ -n "${GAE}" ]]; then source ./install_gae.sh; make testappenginedeps; fi - - if [[ -n "${VET}" ]]; then ./vet.sh -install; fi - -script: - - set -e - - if [[ -n "${TESTEXTRAS}" ]]; then examples/examples_test.sh; security/advancedtls/examples/examples_test.sh; interop/interop_test.sh; make testsubmodule; exit 0; fi - - if [[ -n "${VET}" ]]; then ./vet.sh; fi - - if [[ -n "${GAE}" ]]; then make testappengine; exit 0; fi - - if [[ -n "${RACE}" ]]; then make testrace; exit 0; fi - - make test diff --git a/Documentation/server-reflection-tutorial.md b/Documentation/server-reflection-tutorial.md index b1781fa68dc..9f26656f22b 100644 --- a/Documentation/server-reflection-tutorial.md +++ b/Documentation/server-reflection-tutorial.md @@ -58,7 +58,7 @@ $ go run examples/features/reflection/server/main.go Open a new terminal and make sure you are in the directory where grpc_cli lives: ```sh -$ cd /bins/opt +$ cd /bins/opt ``` ### List services diff --git a/MAINTAINERS.md b/MAINTAINERS.md index 093c82b3afe..c6672c0a3ef 100644 --- a/MAINTAINERS.md +++ b/MAINTAINERS.md @@ -8,17 +8,18 @@ See [CONTRIBUTING.md](https://github.com/grpc/grpc-community/blob/master/CONTRIB for general contribution guidelines. ## Maintainers (in alphabetical order) -- [canguler](https://github.com/canguler), Google LLC + - [cesarghali](https://github.com/cesarghali), Google LLC - [dfawley](https://github.com/dfawley), Google LLC - [easwars](https://github.com/easwars), Google LLC -- [jadekler](https://github.com/jadekler), Google LLC - [menghanl](https://github.com/menghanl), Google LLC - [srini100](https://github.com/srini100), Google LLC ## Emeritus Maintainers (in alphabetical order) - [adelez](https://github.com/adelez), Google LLC +- [canguler](https://github.com/canguler), Google LLC - [iamqizhao](https://github.com/iamqizhao), Google LLC +- [jadekler](https://github.com/jadekler), Google LLC - [jtattermusch](https://github.com/jtattermusch), Google LLC - [lyuxuan](https://github.com/lyuxuan), Google LLC - [makmukhi](https://github.com/makmukhi), Google LLC diff --git a/Makefile b/Makefile index 1f0722f1624..1f8960922b3 100644 --- a/Makefile +++ b/Makefile @@ -41,8 +41,6 @@ vetdeps: clean \ proto \ test \ - testappengine \ - testappenginedeps \ testrace \ vet \ vetdeps diff --git a/NOTICE.txt b/NOTICE.txt new file mode 100644 index 00000000000..530197749e9 --- /dev/null +++ b/NOTICE.txt @@ -0,0 +1,13 @@ +Copyright 2014 gRPC authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. diff --git a/README.md b/README.md index 3949a683fb5..0e6ae69a584 100644 --- a/README.md +++ b/README.md @@ -136,6 +136,6 @@ errors. [Go module]: https://github.com/golang/go/wiki/Modules [gRPC]: https://grpc.io [Go gRPC docs]: https://grpc.io/docs/languages/go -[Performance benchmark]: https://performance-dot-grpc-testing.appspot.com/explore?dashboard=5652536396611584&widget=490377658&container=1286539696 +[Performance benchmark]: https://performance-dot-grpc-testing.appspot.com/explore?dashboard=5180705743044608 [quick start]: https://grpc.io/docs/languages/go/quickstart [go-releases]: https://golang.org/doc/devel/release.html diff --git a/admin/admin.go b/admin/admin.go new file mode 100644 index 00000000000..803a4b93534 --- /dev/null +++ b/admin/admin.go @@ -0,0 +1,58 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// Package admin provides a convenient method for registering a collection of +// administration services to a gRPC server. The services registered are: +// +// - Channelz: https://github.com/grpc/proposal/blob/master/A14-channelz.md +// +// - CSDS: https://github.com/grpc/proposal/blob/master/A40-csds-support.md +// +// Experimental +// +// Notice: All APIs in this package are experimental and may be removed in a +// later release. +package admin + +import ( + "google.golang.org/grpc" + channelzservice "google.golang.org/grpc/channelz/service" + internaladmin "google.golang.org/grpc/internal/admin" +) + +func init() { + // Add a list of default services to admin here. Optional services, like + // CSDS, will be added by other packages. + internaladmin.AddService(func(registrar grpc.ServiceRegistrar) (func(), error) { + channelzservice.RegisterChannelzServiceToServer(registrar) + return nil, nil + }) +} + +// Register registers the set of admin services to the given server. +// +// The returned cleanup function should be called to clean up the resources +// allocated for the service handlers after the server is stopped. +// +// Note that if `s` is not a *grpc.Server or a *xds.GRPCServer, CSDS will not be +// registered because CSDS generated code is old and doesn't support interface +// `grpc.ServiceRegistrar`. +// https://github.com/envoyproxy/go-control-plane/issues/403 +func Register(s grpc.ServiceRegistrar) (cleanup func(), _ error) { + return internaladmin.Register(s) +} diff --git a/security/advancedtls/sni_appengine.go b/admin/admin_test.go similarity index 63% rename from security/advancedtls/sni_appengine.go rename to admin/admin_test.go index fffbb0107dd..0ee4aade0f3 100644 --- a/security/advancedtls/sni_appengine.go +++ b/admin/admin_test.go @@ -1,8 +1,6 @@ -// +build appengine - /* * - * Copyright 2020 gRPC authors. + * Copyright 2021 gRPC authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,13 +16,19 @@ * */ -package advancedtls +package admin_test import ( - "crypto/tls" + "testing" + + "google.golang.org/grpc/admin/test" + "google.golang.org/grpc/codes" ) -// buildGetCertificates is a no-op for appengine builds. -func buildGetCertificates(clientHello *tls.ClientHelloInfo, o *ServerOptions) (*tls.Certificate, error) { - return nil, nil +func TestRegisterNoCSDS(t *testing.T) { + test.RunRegisterTests(t, test.ExpectedStatusCodes{ + ChannelzCode: codes.OK, + // CSDS is not registered because xDS isn't imported. + CSDSCode: codes.Unimplemented, + }) } diff --git a/admin/test/admin_test.go b/admin/test/admin_test.go new file mode 100644 index 00000000000..f0f784bfdf3 --- /dev/null +++ b/admin/test/admin_test.go @@ -0,0 +1,38 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// This file has the same content as admin_test.go, difference is that this is +// in another package, and it imports "xds", so we can test that csds is +// registered when xds is imported. + +package test_test + +import ( + "testing" + + "google.golang.org/grpc/admin/test" + "google.golang.org/grpc/codes" + _ "google.golang.org/grpc/xds" +) + +func TestRegisterWithCSDS(t *testing.T) { + test.RunRegisterTests(t, test.ExpectedStatusCodes{ + ChannelzCode: codes.OK, + CSDSCode: codes.OK, + }) +} diff --git a/admin/test/utils.go b/admin/test/utils.go new file mode 100644 index 00000000000..1add8afa824 --- /dev/null +++ b/admin/test/utils.go @@ -0,0 +1,114 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// Package test contains test only functions for package admin. It's used by +// admin/admin_test.go and admin/test/admin_test.go. +package test + +import ( + "context" + "net" + "testing" + "time" + + v3statusgrpc "github.com/envoyproxy/go-control-plane/envoy/service/status/v3" + v3statuspb "github.com/envoyproxy/go-control-plane/envoy/service/status/v3" + "github.com/google/uuid" + "google.golang.org/grpc" + "google.golang.org/grpc/admin" + channelzpb "google.golang.org/grpc/channelz/grpc_channelz_v1" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/internal/xds" + "google.golang.org/grpc/status" +) + +const ( + defaultTestTimeout = 10 * time.Second +) + +// ExpectedStatusCodes contains the expected status code for each RPC (can be +// OK). +type ExpectedStatusCodes struct { + ChannelzCode codes.Code + CSDSCode codes.Code +} + +// RunRegisterTests makes a client, runs the RPCs, and compares the status +// codes. +func RunRegisterTests(t *testing.T, ec ExpectedStatusCodes) { + nodeID := uuid.New().String() + bootstrapCleanup, err := xds.SetupBootstrapFile(xds.BootstrapOptions{ + Version: xds.TransportV3, + NodeID: nodeID, + ServerURI: "no.need.for.a.server", + }) + if err != nil { + t.Fatal(err) + } + defer bootstrapCleanup() + + lis, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatalf("cannot create listener: %v", err) + } + + server := grpc.NewServer() + defer server.Stop() + cleanup, err := admin.Register(server) + if err != nil { + t.Fatalf("failed to register admin: %v", err) + } + defer cleanup() + go func() { + server.Serve(lis) + }() + + conn, err := grpc.Dial(lis.Addr().String(), grpc.WithInsecure()) + if err != nil { + t.Fatalf("cannot connect to server: %v", err) + } + + t.Run("channelz", func(t *testing.T) { + if err := RunChannelz(conn); status.Code(err) != ec.ChannelzCode { + t.Fatalf("%s RPC failed with error %v, want code %v", "channelz", err, ec.ChannelzCode) + } + }) + t.Run("csds", func(t *testing.T) { + if err := RunCSDS(conn); status.Code(err) != ec.CSDSCode { + t.Fatalf("%s RPC failed with error %v, want code %v", "CSDS", err, ec.CSDSCode) + } + }) +} + +// RunChannelz makes a channelz RPC. +func RunChannelz(conn *grpc.ClientConn) error { + c := channelzpb.NewChannelzClient(conn) + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + _, err := c.GetTopChannels(ctx, &channelzpb.GetTopChannelsRequest{}, grpc.WaitForReady(true)) + return err +} + +// RunCSDS makes a CSDS RPC. +func RunCSDS(conn *grpc.ClientConn) error { + c := v3statusgrpc.NewClientStatusDiscoveryServiceClient(conn) + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + _, err := c.FetchClientStatus(ctx, &v3statuspb.ClientStatusRequest{}, grpc.WaitForReady(true)) + return err +} diff --git a/authz/rbac_translator.go b/authz/rbac_translator.go new file mode 100644 index 00000000000..039d76bc99d --- /dev/null +++ b/authz/rbac_translator.go @@ -0,0 +1,306 @@ +/* + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package authz exposes methods to manage authorization within gRPC. +// +// Experimental +// +// Notice: This package is EXPERIMENTAL and may be changed or removed +// in a later release. +package authz + +import ( + "bytes" + "encoding/json" + "fmt" + "strings" + + v3rbacpb "github.com/envoyproxy/go-control-plane/envoy/config/rbac/v3" + v3routepb "github.com/envoyproxy/go-control-plane/envoy/config/route/v3" + v3matcherpb "github.com/envoyproxy/go-control-plane/envoy/type/matcher/v3" +) + +type header struct { + Key string + Values []string +} + +type peer struct { + Principals []string +} + +type request struct { + Paths []string + Headers []header +} + +type rule struct { + Name string + Source peer + Request request +} + +// Represents the SDK authorization policy provided by user. +type authorizationPolicy struct { + Name string + DenyRules []rule `json:"deny_rules"` + AllowRules []rule `json:"allow_rules"` +} + +func principalOr(principals []*v3rbacpb.Principal) *v3rbacpb.Principal { + return &v3rbacpb.Principal{ + Identifier: &v3rbacpb.Principal_OrIds{ + OrIds: &v3rbacpb.Principal_Set{ + Ids: principals, + }, + }, + } +} + +func permissionOr(permission []*v3rbacpb.Permission) *v3rbacpb.Permission { + return &v3rbacpb.Permission{ + Rule: &v3rbacpb.Permission_OrRules{ + OrRules: &v3rbacpb.Permission_Set{ + Rules: permission, + }, + }, + } +} + +func permissionAnd(permission []*v3rbacpb.Permission) *v3rbacpb.Permission { + return &v3rbacpb.Permission{ + Rule: &v3rbacpb.Permission_AndRules{ + AndRules: &v3rbacpb.Permission_Set{ + Rules: permission, + }, + }, + } +} + +func getStringMatcher(value string) *v3matcherpb.StringMatcher { + switch { + case value == "*": + return &v3matcherpb.StringMatcher{ + MatchPattern: &v3matcherpb.StringMatcher_SafeRegex{}, + } + case strings.HasSuffix(value, "*"): + prefix := strings.TrimSuffix(value, "*") + return &v3matcherpb.StringMatcher{ + MatchPattern: &v3matcherpb.StringMatcher_Prefix{Prefix: prefix}, + } + case strings.HasPrefix(value, "*"): + suffix := strings.TrimPrefix(value, "*") + return &v3matcherpb.StringMatcher{ + MatchPattern: &v3matcherpb.StringMatcher_Suffix{Suffix: suffix}, + } + default: + return &v3matcherpb.StringMatcher{ + MatchPattern: &v3matcherpb.StringMatcher_Exact{Exact: value}, + } + } +} + +func getHeaderMatcher(key, value string) *v3routepb.HeaderMatcher { + switch { + case value == "*": + return &v3routepb.HeaderMatcher{ + Name: key, + HeaderMatchSpecifier: &v3routepb.HeaderMatcher_SafeRegexMatch{}, + } + case strings.HasSuffix(value, "*"): + prefix := strings.TrimSuffix(value, "*") + return &v3routepb.HeaderMatcher{ + Name: key, + HeaderMatchSpecifier: &v3routepb.HeaderMatcher_PrefixMatch{PrefixMatch: prefix}, + } + case strings.HasPrefix(value, "*"): + suffix := strings.TrimPrefix(value, "*") + return &v3routepb.HeaderMatcher{ + Name: key, + HeaderMatchSpecifier: &v3routepb.HeaderMatcher_SuffixMatch{SuffixMatch: suffix}, + } + default: + return &v3routepb.HeaderMatcher{ + Name: key, + HeaderMatchSpecifier: &v3routepb.HeaderMatcher_ExactMatch{ExactMatch: value}, + } + } +} + +func parsePrincipalNames(principalNames []string) []*v3rbacpb.Principal { + ps := make([]*v3rbacpb.Principal, 0, len(principalNames)) + for _, principalName := range principalNames { + newPrincipalName := &v3rbacpb.Principal{ + Identifier: &v3rbacpb.Principal_Authenticated_{ + Authenticated: &v3rbacpb.Principal_Authenticated{ + PrincipalName: getStringMatcher(principalName), + }, + }} + ps = append(ps, newPrincipalName) + } + return ps +} + +func parsePeer(source peer) (*v3rbacpb.Principal, error) { + if len(source.Principals) > 0 { + return principalOr(parsePrincipalNames(source.Principals)), nil + } + return &v3rbacpb.Principal{ + Identifier: &v3rbacpb.Principal_Any{ + Any: true, + }, + }, nil +} + +func parsePaths(paths []string) []*v3rbacpb.Permission { + ps := make([]*v3rbacpb.Permission, 0, len(paths)) + for _, path := range paths { + newPath := &v3rbacpb.Permission{ + Rule: &v3rbacpb.Permission_UrlPath{ + UrlPath: &v3matcherpb.PathMatcher{ + Rule: &v3matcherpb.PathMatcher_Path{Path: getStringMatcher(path)}}}} + ps = append(ps, newPath) + } + return ps +} + +func parseHeaderValues(key string, values []string) []*v3rbacpb.Permission { + vs := make([]*v3rbacpb.Permission, 0, len(values)) + for _, value := range values { + newHeader := &v3rbacpb.Permission{ + Rule: &v3rbacpb.Permission_Header{ + Header: getHeaderMatcher(key, value)}} + vs = append(vs, newHeader) + } + return vs +} + +var unsupportedHeaders = map[string]bool{ + "host": true, + "connection": true, + "keep-alive": true, + "proxy-authenticate": true, + "proxy-authorization": true, + "te": true, + "trailer": true, + "transfer-encoding": true, + "upgrade": true, +} + +func unsupportedHeader(key string) bool { + return key[0] == ':' || strings.HasPrefix(key, "grpc-") || unsupportedHeaders[key] +} + +func parseHeaders(headers []header) ([]*v3rbacpb.Permission, error) { + hs := make([]*v3rbacpb.Permission, 0, len(headers)) + for i, header := range headers { + if header.Key == "" { + return nil, fmt.Errorf(`"headers" %d: "key" is not present`, i) + } + header.Key = strings.ToLower(header.Key) + if unsupportedHeader(header.Key) { + return nil, fmt.Errorf(`"headers" %d: unsupported "key" %s`, i, header.Key) + } + if len(header.Values) == 0 { + return nil, fmt.Errorf(`"headers" %d: "values" is not present`, i) + } + values := parseHeaderValues(header.Key, header.Values) + hs = append(hs, permissionOr(values)) + } + return hs, nil +} + +func parseRequest(request request) (*v3rbacpb.Permission, error) { + var and []*v3rbacpb.Permission + if len(request.Paths) > 0 { + and = append(and, permissionOr(parsePaths(request.Paths))) + } + if len(request.Headers) > 0 { + headers, err := parseHeaders(request.Headers) + if err != nil { + return nil, err + } + and = append(and, permissionAnd(headers)) + } + if len(and) > 0 { + return permissionAnd(and), nil + } + return &v3rbacpb.Permission{ + Rule: &v3rbacpb.Permission_Any{ + Any: true, + }, + }, nil +} + +func parseRules(rules []rule, prefixName string) (map[string]*v3rbacpb.Policy, error) { + policies := make(map[string]*v3rbacpb.Policy) + for i, rule := range rules { + if rule.Name == "" { + return policies, fmt.Errorf(`%d: "name" is not present`, i) + } + principal, err := parsePeer(rule.Source) + if err != nil { + return nil, fmt.Errorf("%d: %v", i, err) + } + permission, err := parseRequest(rule.Request) + if err != nil { + return nil, fmt.Errorf("%d: %v", i, err) + } + policyName := prefixName + "_" + rule.Name + policies[policyName] = &v3rbacpb.Policy{ + Principals: []*v3rbacpb.Principal{principal}, + Permissions: []*v3rbacpb.Permission{permission}, + } + } + return policies, nil +} + +// translatePolicy translates SDK authorization policy in JSON format to two +// Envoy RBAC polices (deny followed by allow policy) or only one Envoy RBAC +// allow policy. If the input policy cannot be parsed or is invalid, an error +// will be returned. +func translatePolicy(policyStr string) ([]*v3rbacpb.RBAC, error) { + policy := &authorizationPolicy{} + d := json.NewDecoder(bytes.NewReader([]byte(policyStr))) + d.DisallowUnknownFields() + if err := d.Decode(policy); err != nil { + return nil, fmt.Errorf("failed to unmarshal policy: %v", err) + } + if policy.Name == "" { + return nil, fmt.Errorf(`"name" is not present`) + } + if len(policy.AllowRules) == 0 { + return nil, fmt.Errorf(`"allow_rules" is not present`) + } + rbacs := make([]*v3rbacpb.RBAC, 0, 2) + if len(policy.DenyRules) > 0 { + denyPolicies, err := parseRules(policy.DenyRules, policy.Name) + if err != nil { + return nil, fmt.Errorf(`"deny_rules" %v`, err) + } + denyRBAC := &v3rbacpb.RBAC{ + Action: v3rbacpb.RBAC_DENY, + Policies: denyPolicies, + } + rbacs = append(rbacs, denyRBAC) + } + allowPolicies, err := parseRules(policy.AllowRules, policy.Name) + if err != nil { + return nil, fmt.Errorf(`"allow_rules" %v`, err) + } + allowRBAC := &v3rbacpb.RBAC{Action: v3rbacpb.RBAC_ALLOW, Policies: allowPolicies} + return append(rbacs, allowRBAC), nil +} diff --git a/authz/rbac_translator_test.go b/authz/rbac_translator_test.go new file mode 100644 index 00000000000..9a883e9d78d --- /dev/null +++ b/authz/rbac_translator_test.go @@ -0,0 +1,273 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package authz + +import ( + "strings" + "testing" + + "github.com/google/go-cmp/cmp" + "google.golang.org/protobuf/testing/protocmp" + + v3rbacpb "github.com/envoyproxy/go-control-plane/envoy/config/rbac/v3" + v3routepb "github.com/envoyproxy/go-control-plane/envoy/config/route/v3" + v3matcherpb "github.com/envoyproxy/go-control-plane/envoy/type/matcher/v3" +) + +func TestTranslatePolicy(t *testing.T) { + tests := map[string]struct { + authzPolicy string + wantErr string + wantPolicies []*v3rbacpb.RBAC + }{ + "valid policy": { + authzPolicy: `{ + "name": "authz", + "deny_rules": [ + { + "name": "deny_policy_1", + "source": { + "principals":[ + "spiffe://foo.abc", + "spiffe://bar*", + "*baz", + "spiffe://abc.*.com" + ] + } + }], + "allow_rules": [ + { + "name": "allow_policy_1", + "source": { + "principals":["*"] + }, + "request": { + "paths": ["path-foo*"] + } + }, + { + "name": "allow_policy_2", + "request": { + "paths": [ + "path-bar", + "*baz" + ], + "headers": [ + { + "key": "key-1", + "values": ["foo", "*bar"] + }, + { + "key": "key-2", + "values": ["baz*"] + } + ] + } + }] + }`, + wantPolicies: []*v3rbacpb.RBAC{ + { + Action: v3rbacpb.RBAC_DENY, + Policies: map[string]*v3rbacpb.Policy{ + "authz_deny_policy_1": { + Principals: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_OrIds{OrIds: &v3rbacpb.Principal_Set{ + Ids: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_Authenticated_{ + Authenticated: &v3rbacpb.Principal_Authenticated{PrincipalName: &v3matcherpb.StringMatcher{ + MatchPattern: &v3matcherpb.StringMatcher_Exact{Exact: "spiffe://foo.abc"}, + }}, + }}, + {Identifier: &v3rbacpb.Principal_Authenticated_{ + Authenticated: &v3rbacpb.Principal_Authenticated{PrincipalName: &v3matcherpb.StringMatcher{ + MatchPattern: &v3matcherpb.StringMatcher_Prefix{Prefix: "spiffe://bar"}, + }}, + }}, + {Identifier: &v3rbacpb.Principal_Authenticated_{ + Authenticated: &v3rbacpb.Principal_Authenticated{PrincipalName: &v3matcherpb.StringMatcher{ + MatchPattern: &v3matcherpb.StringMatcher_Suffix{Suffix: "baz"}, + }}, + }}, + {Identifier: &v3rbacpb.Principal_Authenticated_{ + Authenticated: &v3rbacpb.Principal_Authenticated{PrincipalName: &v3matcherpb.StringMatcher{ + MatchPattern: &v3matcherpb.StringMatcher_Exact{Exact: "spiffe://abc.*.com"}, + }}, + }}, + }, + }}}, + }, + Permissions: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_Any{Any: true}}, + }, + }, + }, + }, + { + Action: v3rbacpb.RBAC_ALLOW, + Policies: map[string]*v3rbacpb.Policy{ + "authz_allow_policy_1": { + Principals: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_OrIds{OrIds: &v3rbacpb.Principal_Set{ + Ids: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_Authenticated_{ + Authenticated: &v3rbacpb.Principal_Authenticated{PrincipalName: &v3matcherpb.StringMatcher{ + MatchPattern: &v3matcherpb.StringMatcher_SafeRegex{}, + }}, + }}, + }, + }}}, + }, + Permissions: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_AndRules{AndRules: &v3rbacpb.Permission_Set{ + Rules: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_OrRules{OrRules: &v3rbacpb.Permission_Set{ + Rules: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_UrlPath{ + UrlPath: &v3matcherpb.PathMatcher{Rule: &v3matcherpb.PathMatcher_Path{Path: &v3matcherpb.StringMatcher{ + MatchPattern: &v3matcherpb.StringMatcher_Prefix{Prefix: "path-foo"}, + }}}, + }}, + }, + }}}, + }, + }}}, + }, + }, + "authz_allow_policy_2": { + Principals: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_Any{Any: true}}, + }, + Permissions: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_AndRules{AndRules: &v3rbacpb.Permission_Set{ + Rules: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_OrRules{OrRules: &v3rbacpb.Permission_Set{ + Rules: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_UrlPath{ + UrlPath: &v3matcherpb.PathMatcher{Rule: &v3matcherpb.PathMatcher_Path{Path: &v3matcherpb.StringMatcher{ + MatchPattern: &v3matcherpb.StringMatcher_Exact{Exact: "path-bar"}, + }}}, + }}, + {Rule: &v3rbacpb.Permission_UrlPath{ + UrlPath: &v3matcherpb.PathMatcher{Rule: &v3matcherpb.PathMatcher_Path{Path: &v3matcherpb.StringMatcher{ + MatchPattern: &v3matcherpb.StringMatcher_Suffix{Suffix: "baz"}, + }}}, + }}, + }, + }}}, + {Rule: &v3rbacpb.Permission_AndRules{AndRules: &v3rbacpb.Permission_Set{ + Rules: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_OrRules{OrRules: &v3rbacpb.Permission_Set{ + Rules: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_Header{ + Header: &v3routepb.HeaderMatcher{ + Name: "key-1", HeaderMatchSpecifier: &v3routepb.HeaderMatcher_ExactMatch{ExactMatch: "foo"}, + }, + }}, + {Rule: &v3rbacpb.Permission_Header{ + Header: &v3routepb.HeaderMatcher{ + Name: "key-1", HeaderMatchSpecifier: &v3routepb.HeaderMatcher_SuffixMatch{SuffixMatch: "bar"}, + }, + }}, + }, + }}}, + {Rule: &v3rbacpb.Permission_OrRules{OrRules: &v3rbacpb.Permission_Set{ + Rules: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_Header{ + Header: &v3routepb.HeaderMatcher{ + Name: "key-2", HeaderMatchSpecifier: &v3routepb.HeaderMatcher_PrefixMatch{PrefixMatch: "baz"}, + }, + }}, + }, + }}}, + }, + }}}, + }, + }}}, + }, + }, + }, + }, + }, + }, + "unknown field": { + authzPolicy: `{"random": 123}`, + wantErr: "failed to unmarshal policy", + }, + "missing name field": { + authzPolicy: `{}`, + wantErr: `"name" is not present`, + }, + "invalid field type": { + authzPolicy: `{"name": 123}`, + wantErr: "failed to unmarshal policy", + }, + "missing allow rules field": { + authzPolicy: `{"name": "authz-foo"}`, + wantErr: `"allow_rules" is not present`, + }, + "missing rule name field": { + authzPolicy: `{ + "name": "authz-foo", + "allow_rules": [{}] + }`, + wantErr: `"allow_rules" 0: "name" is not present`, + }, + "missing header key": { + authzPolicy: `{ + "name": "authz", + "allow_rules": [{ + "name": "allow_policy_1", + "request": {"headers":[{"key":"key-a", "values": ["value-a"]}, {}]} + }] + }`, + wantErr: `"allow_rules" 0: "headers" 1: "key" is not present`, + }, + "missing header values": { + authzPolicy: `{ + "name": "authz", + "allow_rules": [{ + "name": "allow_policy_1", + "request": {"headers":[{"key":"key-a"}]} + }] + }`, + wantErr: `"allow_rules" 0: "headers" 0: "values" is not present`, + }, + "unsupported header": { + authzPolicy: `{ + "name": "authz", + "allow_rules": [{ + "name": "allow_policy_1", + "request": {"headers":[{"key":":method", "values":["GET"]}]} + }] + }`, + wantErr: `"allow_rules" 0: "headers" 0: unsupported "key" :method`, + }, + } + for name, test := range tests { + t.Run(name, func(t *testing.T) { + gotPolicies, gotErr := translatePolicy(test.authzPolicy) + if gotErr != nil && !strings.HasPrefix(gotErr.Error(), test.wantErr) { + t.Fatalf("unexpected error\nwant:%v\ngot:%v", test.wantErr, gotErr) + } + if diff := cmp.Diff(gotPolicies, test.wantPolicies, protocmp.Transform()); diff != "" { + t.Fatalf("unexpected policy\ndiff (-want +got):\n%s", diff) + } + }) + } +} diff --git a/authz/sdk_end2end_test.go b/authz/sdk_end2end_test.go new file mode 100644 index 00000000000..093b2bb437d --- /dev/null +++ b/authz/sdk_end2end_test.go @@ -0,0 +1,548 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package authz_test + +import ( + "context" + "io" + "io/ioutil" + "net" + "os" + "testing" + "time" + + "google.golang.org/grpc" + "google.golang.org/grpc/authz" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/internal/grpctest" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" + pb "google.golang.org/grpc/test/grpc_testing" +) + +type testServer struct { + pb.UnimplementedTestServiceServer +} + +func (s *testServer) UnaryCall(ctx context.Context, req *pb.SimpleRequest) (*pb.SimpleResponse, error) { + return &pb.SimpleResponse{}, nil +} + +func (s *testServer) StreamingInputCall(stream pb.TestService_StreamingInputCallServer) error { + for { + _, err := stream.Recv() + if err == io.EOF { + return stream.SendAndClose(&pb.StreamingInputCallResponse{}) + } + if err != nil { + return err + } + } +} + +type s struct { + grpctest.Tester +} + +func Test(t *testing.T) { + grpctest.RunSubTests(t, s{}) +} + +var sdkTests = map[string]struct { + authzPolicy string + md metadata.MD + wantStatus *status.Status +}{ + "DeniesRpcMatchInDenyNoMatchInAllow": { + authzPolicy: `{ + "name": "authz", + "allow_rules": + [ + { + "name": "allow_StreamingOutputCall", + "request": { + "paths": + [ + "/grpc.testing.TestService/StreamingOutputCall" + ] + } + } + ], + "deny_rules": + [ + { + "name": "deny_TestServiceCalls", + "request": { + "paths": + [ + "/grpc.testing.TestService/UnaryCall", + "/grpc.testing.TestService/StreamingInputCall" + ], + "headers": + [ + { + "key": "key-abc", + "values": + [ + "val-abc", + "val-def" + ] + } + ] + } + } + ] + }`, + md: metadata.Pairs("key-abc", "val-abc"), + wantStatus: status.New(codes.PermissionDenied, "unauthorized RPC request rejected"), + }, + "DeniesRpcMatchInDenyAndAllow": { + authzPolicy: `{ + "name": "authz", + "allow_rules": + [ + { + "name": "allow_TestServiceCalls", + "request": { + "paths": + [ + "/grpc.testing.TestService/*" + ] + } + } + ], + "deny_rules": + [ + { + "name": "deny_TestServiceCalls", + "request": { + "paths": + [ + "/grpc.testing.TestService/*" + ] + } + } + ] + }`, + wantStatus: status.New(codes.PermissionDenied, "unauthorized RPC request rejected"), + }, + "AllowsRpcNoMatchInDenyMatchInAllow": { + authzPolicy: `{ + "name": "authz", + "allow_rules": + [ + { + "name": "allow_all" + } + ], + "deny_rules": + [ + { + "name": "deny_TestServiceCalls", + "request": { + "paths": + [ + "/grpc.testing.TestService/UnaryCall", + "/grpc.testing.TestService/StreamingInputCall" + ], + "headers": + [ + { + "key": "key-abc", + "values": + [ + "val-abc", + "val-def" + ] + } + ] + } + } + ] + }`, + md: metadata.Pairs("key-xyz", "val-xyz"), + wantStatus: status.New(codes.OK, ""), + }, + "DeniesRpcNoMatchInDenyAndAllow": { + authzPolicy: `{ + "name": "authz", + "allow_rules": + [ + { + "name": "allow_some_user", + "source": { + "principals": + [ + "some_user" + ] + } + } + ], + "deny_rules": + [ + { + "name": "deny_StreamingOutputCall", + "request": { + "paths": + [ + "/grpc.testing.TestService/StreamingOutputCall" + ] + } + } + ] + }`, + wantStatus: status.New(codes.PermissionDenied, "unauthorized RPC request rejected"), + }, + "AllowsRpcEmptyDenyMatchInAllow": { + authzPolicy: `{ + "name": "authz", + "allow_rules": + [ + { + "name": "allow_UnaryCall", + "request": + { + "paths": + [ + "/grpc.testing.TestService/UnaryCall" + ] + } + }, + { + "name": "allow_StreamingInputCall", + "request": + { + "paths": + [ + "/grpc.testing.TestService/StreamingInputCall" + ] + } + } + ] + }`, + wantStatus: status.New(codes.OK, ""), + }, + "DeniesRpcEmptyDenyNoMatchInAllow": { + authzPolicy: `{ + "name": "authz", + "allow_rules": + [ + { + "name": "allow_StreamingOutputCall", + "request": + { + "paths": + [ + "/grpc.testing.TestService/StreamingOutputCall" + ] + } + } + ] + }`, + wantStatus: status.New(codes.PermissionDenied, "unauthorized RPC request rejected"), + }, +} + +func (s) TestSDKStaticPolicyEnd2End(t *testing.T) { + for name, test := range sdkTests { + t.Run(name, func(t *testing.T) { + // Start a gRPC server with SDK unary and stream server interceptors. + i, _ := authz.NewStatic(test.authzPolicy) + s := grpc.NewServer( + grpc.ChainUnaryInterceptor(i.UnaryInterceptor), + grpc.ChainStreamInterceptor(i.StreamInterceptor)) + defer s.Stop() + pb.RegisterTestServiceServer(s, &testServer{}) + + lis, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatalf("error listening: %v", err) + } + go s.Serve(lis) + + // Establish a connection to the server. + clientConn, err := grpc.Dial(lis.Addr().String(), grpc.WithInsecure()) + if err != nil { + t.Fatalf("grpc.Dial(%v) failed: %v", lis.Addr().String(), err) + } + defer clientConn.Close() + client := pb.NewTestServiceClient(clientConn) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + ctx = metadata.NewOutgoingContext(ctx, test.md) + + // Verifying authorization decision for Unary RPC. + _, err = client.UnaryCall(ctx, &pb.SimpleRequest{}) + if got := status.Convert(err); got.Code() != test.wantStatus.Code() || got.Message() != test.wantStatus.Message() { + t.Fatalf("[UnaryCall] error want:{%v} got:{%v}", test.wantStatus.Err(), got.Err()) + } + + // Verifying authorization decision for Streaming RPC. + stream, err := client.StreamingInputCall(ctx) + if err != nil { + t.Fatalf("failed StreamingInputCall err: %v", err) + } + req := &pb.StreamingInputCallRequest{ + Payload: &pb.Payload{ + Body: []byte("hi"), + }, + } + if err := stream.Send(req); err != nil && err != io.EOF { + t.Fatalf("failed stream.Send err: %v", err) + } + _, err = stream.CloseAndRecv() + if got := status.Convert(err); got.Code() != test.wantStatus.Code() || got.Message() != test.wantStatus.Message() { + t.Fatalf("[StreamingCall] error want:{%v} got:{%v}", test.wantStatus.Err(), got.Err()) + } + }) + } +} + +func (s) TestSDKFileWatcherEnd2End(t *testing.T) { + for name, test := range sdkTests { + t.Run(name, func(t *testing.T) { + file := createTmpPolicyFile(t, name, []byte(test.authzPolicy)) + i, _ := authz.NewFileWatcher(file, 1*time.Second) + defer i.Close() + + // Start a gRPC server with SDK unary and stream server interceptors. + s := grpc.NewServer( + grpc.ChainUnaryInterceptor(i.UnaryInterceptor), + grpc.ChainStreamInterceptor(i.StreamInterceptor)) + defer s.Stop() + pb.RegisterTestServiceServer(s, &testServer{}) + + lis, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatalf("error listening: %v", err) + } + defer lis.Close() + go s.Serve(lis) + + // Establish a connection to the server. + clientConn, err := grpc.Dial(lis.Addr().String(), grpc.WithInsecure()) + if err != nil { + t.Fatalf("grpc.Dial(%v) failed: %v", lis.Addr().String(), err) + } + defer clientConn.Close() + client := pb.NewTestServiceClient(clientConn) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + ctx = metadata.NewOutgoingContext(ctx, test.md) + + // Verifying authorization decision for Unary RPC. + _, err = client.UnaryCall(ctx, &pb.SimpleRequest{}) + if got := status.Convert(err); got.Code() != test.wantStatus.Code() || got.Message() != test.wantStatus.Message() { + t.Fatalf("[UnaryCall] error want:{%v} got:{%v}", test.wantStatus.Err(), got.Err()) + } + + // Verifying authorization decision for Streaming RPC. + stream, err := client.StreamingInputCall(ctx) + if err != nil { + t.Fatalf("failed StreamingInputCall err: %v", err) + } + req := &pb.StreamingInputCallRequest{ + Payload: &pb.Payload{ + Body: []byte("hi"), + }, + } + if err := stream.Send(req); err != nil && err != io.EOF { + t.Fatalf("failed stream.Send err: %v", err) + } + _, err = stream.CloseAndRecv() + if got := status.Convert(err); got.Code() != test.wantStatus.Code() || got.Message() != test.wantStatus.Message() { + t.Fatalf("[StreamingCall] error want:{%v} got:{%v}", test.wantStatus.Err(), got.Err()) + } + }) + } +} + +func retryUntil(ctx context.Context, tsc pb.TestServiceClient, want *status.Status) (lastErr error) { + for ctx.Err() == nil { + _, lastErr = tsc.UnaryCall(ctx, &pb.SimpleRequest{}) + if s := status.Convert(lastErr); s.Code() == want.Code() && s.Message() == want.Message() { + return nil + } + time.Sleep(20 * time.Millisecond) + } + return lastErr +} + +func (s) TestSDKFileWatcher_ValidPolicyRefresh(t *testing.T) { + valid1 := sdkTests["DeniesRpcMatchInDenyAndAllow"] + file := createTmpPolicyFile(t, "valid_policy_refresh", []byte(valid1.authzPolicy)) + i, _ := authz.NewFileWatcher(file, 100*time.Millisecond) + defer i.Close() + + // Start a gRPC server with SDK unary server interceptor. + s := grpc.NewServer( + grpc.ChainUnaryInterceptor(i.UnaryInterceptor)) + defer s.Stop() + pb.RegisterTestServiceServer(s, &testServer{}) + + lis, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatalf("error listening: %v", err) + } + defer lis.Close() + go s.Serve(lis) + + // Establish a connection to the server. + clientConn, err := grpc.Dial(lis.Addr().String(), grpc.WithInsecure()) + if err != nil { + t.Fatalf("grpc.Dial(%v) failed: %v", lis.Addr().String(), err) + } + defer clientConn.Close() + client := pb.NewTestServiceClient(clientConn) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + // Verifying authorization decision. + _, err = client.UnaryCall(ctx, &pb.SimpleRequest{}) + if got := status.Convert(err); got.Code() != valid1.wantStatus.Code() || got.Message() != valid1.wantStatus.Message() { + t.Fatalf("error want:{%v} got:{%v}", valid1.wantStatus.Err(), got.Err()) + } + + // Rewrite the file with a different valid authorization policy. + valid2 := sdkTests["AllowsRpcEmptyDenyMatchInAllow"] + if err := ioutil.WriteFile(file, []byte(valid2.authzPolicy), os.ModePerm); err != nil { + t.Fatalf("ioutil.WriteFile(%q) failed: %v", file, err) + } + + // Verifying authorization decision. + if got := retryUntil(ctx, client, valid2.wantStatus); got != nil { + t.Fatalf("error want:{%v} got:{%v}", valid2.wantStatus.Err(), got) + } +} + +func (s) TestSDKFileWatcher_InvalidPolicySkipReload(t *testing.T) { + valid := sdkTests["DeniesRpcMatchInDenyAndAllow"] + file := createTmpPolicyFile(t, "invalid_policy_skip_reload", []byte(valid.authzPolicy)) + i, _ := authz.NewFileWatcher(file, 20*time.Millisecond) + defer i.Close() + + // Start a gRPC server with SDK unary server interceptors. + s := grpc.NewServer( + grpc.ChainUnaryInterceptor(i.UnaryInterceptor)) + defer s.Stop() + pb.RegisterTestServiceServer(s, &testServer{}) + + lis, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatalf("error listening: %v", err) + } + defer lis.Close() + go s.Serve(lis) + + // Establish a connection to the server. + clientConn, err := grpc.Dial(lis.Addr().String(), grpc.WithInsecure()) + if err != nil { + t.Fatalf("grpc.Dial(%v) failed: %v", lis.Addr().String(), err) + } + defer clientConn.Close() + client := pb.NewTestServiceClient(clientConn) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + // Verifying authorization decision. + _, err = client.UnaryCall(ctx, &pb.SimpleRequest{}) + if got := status.Convert(err); got.Code() != valid.wantStatus.Code() || got.Message() != valid.wantStatus.Message() { + t.Fatalf("error want:{%v} got:{%v}", valid.wantStatus.Err(), got.Err()) + } + + // Skips the invalid policy update, and continues to use the valid policy. + if err := ioutil.WriteFile(file, []byte("{}"), os.ModePerm); err != nil { + t.Fatalf("ioutil.WriteFile(%q) failed: %v", file, err) + } + + // Wait 40 ms for background go routine to read updated files. + time.Sleep(40 * time.Millisecond) + + // Verifying authorization decision. + _, err = client.UnaryCall(ctx, &pb.SimpleRequest{}) + if got := status.Convert(err); got.Code() != valid.wantStatus.Code() || got.Message() != valid.wantStatus.Message() { + t.Fatalf("error want:{%v} got:{%v}", valid.wantStatus.Err(), got.Err()) + } +} + +func (s) TestSDKFileWatcher_RecoversFromReloadFailure(t *testing.T) { + valid1 := sdkTests["DeniesRpcMatchInDenyAndAllow"] + file := createTmpPolicyFile(t, "recovers_from_reload_failure", []byte(valid1.authzPolicy)) + i, _ := authz.NewFileWatcher(file, 100*time.Millisecond) + defer i.Close() + + // Start a gRPC server with SDK unary server interceptors. + s := grpc.NewServer( + grpc.ChainUnaryInterceptor(i.UnaryInterceptor)) + defer s.Stop() + pb.RegisterTestServiceServer(s, &testServer{}) + + lis, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatalf("error listening: %v", err) + } + defer lis.Close() + go s.Serve(lis) + + // Establish a connection to the server. + clientConn, err := grpc.Dial(lis.Addr().String(), grpc.WithInsecure()) + if err != nil { + t.Fatalf("grpc.Dial(%v) failed: %v", lis.Addr().String(), err) + } + defer clientConn.Close() + client := pb.NewTestServiceClient(clientConn) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + // Verifying authorization decision. + _, err = client.UnaryCall(ctx, &pb.SimpleRequest{}) + if got := status.Convert(err); got.Code() != valid1.wantStatus.Code() || got.Message() != valid1.wantStatus.Message() { + t.Fatalf("error want:{%v} got:{%v}", valid1.wantStatus.Err(), got.Err()) + } + + // Skips the invalid policy update, and continues to use the valid policy. + if err := ioutil.WriteFile(file, []byte("{}"), os.ModePerm); err != nil { + t.Fatalf("ioutil.WriteFile(%q) failed: %v", file, err) + } + + // Wait 120 ms for background go routine to read updated files. + time.Sleep(120 * time.Millisecond) + + // Verifying authorization decision. + _, err = client.UnaryCall(ctx, &pb.SimpleRequest{}) + if got := status.Convert(err); got.Code() != valid1.wantStatus.Code() || got.Message() != valid1.wantStatus.Message() { + t.Fatalf("error want:{%v} got:{%v}", valid1.wantStatus.Err(), got.Err()) + } + + // Rewrite the file with a different valid authorization policy. + valid2 := sdkTests["AllowsRpcEmptyDenyMatchInAllow"] + if err := ioutil.WriteFile(file, []byte(valid2.authzPolicy), os.ModePerm); err != nil { + t.Fatalf("ioutil.WriteFile(%q) failed: %v", file, err) + } + + // Verifying authorization decision. + if got := retryUntil(ctx, client, valid2.wantStatus); got != nil { + t.Fatalf("error want:{%v} got:{%v}", valid2.wantStatus.Err(), got) + } +} diff --git a/authz/sdk_server_interceptors.go b/authz/sdk_server_interceptors.go new file mode 100644 index 00000000000..72dc14ed85e --- /dev/null +++ b/authz/sdk_server_interceptors.go @@ -0,0 +1,172 @@ +/* + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package authz + +import ( + "bytes" + "context" + "fmt" + "io/ioutil" + "sync/atomic" + "time" + "unsafe" + + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/grpclog" + "google.golang.org/grpc/internal/xds/rbac" + "google.golang.org/grpc/status" +) + +var logger = grpclog.Component("authz") + +// StaticInterceptor contains engines used to make authorization decisions. It +// either contains two engines deny engine followed by an allow engine or only +// one allow engine. +type StaticInterceptor struct { + engines rbac.ChainEngine +} + +// NewStatic returns a new StaticInterceptor from a static authorization policy +// JSON string. +func NewStatic(authzPolicy string) (*StaticInterceptor, error) { + rbacs, err := translatePolicy(authzPolicy) + if err != nil { + return nil, err + } + chainEngine, err := rbac.NewChainEngine(rbacs) + if err != nil { + return nil, err + } + return &StaticInterceptor{*chainEngine}, nil +} + +// UnaryInterceptor intercepts incoming Unary RPC requests. +// Only authorized requests are allowed to pass. Otherwise, an unauthorized +// error is returned to the client. +func (i *StaticInterceptor) UnaryInterceptor(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { + err := i.engines.IsAuthorized(ctx) + if err != nil { + if status.Code(err) == codes.PermissionDenied { + return nil, status.Errorf(codes.PermissionDenied, "unauthorized RPC request rejected") + } + return nil, err + } + return handler(ctx, req) +} + +// StreamInterceptor intercepts incoming Stream RPC requests. +// Only authorized requests are allowed to pass. Otherwise, an unauthorized +// error is returned to the client. +func (i *StaticInterceptor) StreamInterceptor(srv interface{}, ss grpc.ServerStream, _ *grpc.StreamServerInfo, handler grpc.StreamHandler) error { + err := i.engines.IsAuthorized(ss.Context()) + if err != nil { + if status.Code(err) == codes.PermissionDenied { + return status.Errorf(codes.PermissionDenied, "unauthorized RPC request rejected") + } + return err + } + return handler(srv, ss) +} + +// FileWatcherInterceptor contains details used to make authorization decisions +// by watching a file path that contains authorization policy in JSON format. +type FileWatcherInterceptor struct { + internalInterceptor unsafe.Pointer // *StaticInterceptor + policyFile string + policyContents []byte + refreshDuration time.Duration + cancel context.CancelFunc +} + +// NewFileWatcher returns a new FileWatcherInterceptor from a policy file +// that contains JSON string of authorization policy and a refresh duration to +// specify the amount of time between policy refreshes. +func NewFileWatcher(file string, duration time.Duration) (*FileWatcherInterceptor, error) { + if file == "" { + return nil, fmt.Errorf("authorization policy file path is empty") + } + if duration <= time.Duration(0) { + return nil, fmt.Errorf("requires refresh interval(%v) greater than 0s", duration) + } + i := &FileWatcherInterceptor{policyFile: file, refreshDuration: duration} + if err := i.updateInternalInterceptor(); err != nil { + return nil, err + } + ctx, cancel := context.WithCancel(context.Background()) + i.cancel = cancel + // Create a background go routine for policy refresh. + go i.run(ctx) + return i, nil +} + +func (i *FileWatcherInterceptor) run(ctx context.Context) { + ticker := time.NewTicker(i.refreshDuration) + for { + if err := i.updateInternalInterceptor(); err != nil { + logger.Warningf("authorization policy reload status err: %v", err) + } + select { + case <-ctx.Done(): + ticker.Stop() + return + case <-ticker.C: + } + } +} + +// updateInternalInterceptor checks if the policy file that is watching has changed, +// and if so, updates the internalInterceptor with the policy. Unlike the +// constructor, if there is an error in reading the file or parsing the policy, the +// previous internalInterceptors will not be replaced. +func (i *FileWatcherInterceptor) updateInternalInterceptor() error { + policyContents, err := ioutil.ReadFile(i.policyFile) + if err != nil { + return fmt.Errorf("policyFile(%s) read failed: %v", i.policyFile, err) + } + if bytes.Equal(i.policyContents, policyContents) { + return nil + } + i.policyContents = policyContents + policyContentsString := string(policyContents) + interceptor, err := NewStatic(policyContentsString) + if err != nil { + return err + } + atomic.StorePointer(&i.internalInterceptor, unsafe.Pointer(interceptor)) + logger.Infof("authorization policy reload status: successfully loaded new policy %v", policyContentsString) + return nil +} + +// Close cleans up resources allocated by the interceptor. +func (i *FileWatcherInterceptor) Close() { + i.cancel() +} + +// UnaryInterceptor intercepts incoming Unary RPC requests. +// Only authorized requests are allowed to pass. Otherwise, an unauthorized +// error is returned to the client. +func (i *FileWatcherInterceptor) UnaryInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { + return ((*StaticInterceptor)(atomic.LoadPointer(&i.internalInterceptor))).UnaryInterceptor(ctx, req, info, handler) +} + +// StreamInterceptor intercepts incoming Stream RPC requests. +// Only authorized requests are allowed to pass. Otherwise, an unauthorized +// error is returned to the client. +func (i *FileWatcherInterceptor) StreamInterceptor(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { + return ((*StaticInterceptor)(atomic.LoadPointer(&i.internalInterceptor))).StreamInterceptor(srv, ss, info, handler) +} diff --git a/authz/sdk_server_interceptors_test.go b/authz/sdk_server_interceptors_test.go new file mode 100644 index 00000000000..f43f9807612 --- /dev/null +++ b/authz/sdk_server_interceptors_test.go @@ -0,0 +1,121 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package authz_test + +import ( + "fmt" + "io/ioutil" + "os" + "path" + "testing" + "time" + + "google.golang.org/grpc/authz" +) + +func createTmpPolicyFile(t *testing.T, dirSuffix string, policy []byte) string { + t.Helper() + + // Create a temp directory. Passing an empty string for the first argument + // uses the system temp directory. + dir, err := ioutil.TempDir("", dirSuffix) + if err != nil { + t.Fatalf("ioutil.TempDir() failed: %v", err) + } + t.Logf("Using tmpdir: %s", dir) + // Write policy into file. + filename := path.Join(dir, "policy.json") + if err := ioutil.WriteFile(filename, policy, os.ModePerm); err != nil { + t.Fatalf("ioutil.WriteFile(%q) failed: %v", filename, err) + } + t.Logf("Wrote policy %s to file at %s", string(policy), filename) + return filename +} + +func (s) TestNewStatic(t *testing.T) { + tests := map[string]struct { + authzPolicy string + wantErr error + }{ + "InvalidPolicyFailsToCreateInterceptor": { + authzPolicy: `{}`, + wantErr: fmt.Errorf(`"name" is not present`), + }, + "ValidPolicyCreatesInterceptor": { + authzPolicy: `{ + "name": "authz", + "allow_rules": + [ + { + "name": "allow_all" + } + ] + }`, + }, + } + for name, test := range tests { + t.Run(name, func(t *testing.T) { + if _, err := authz.NewStatic(test.authzPolicy); fmt.Sprint(err) != fmt.Sprint(test.wantErr) { + t.Fatalf("NewStatic(%v) returned err: %v, want err: %v", test.authzPolicy, err, test.wantErr) + } + }) + } +} + +func (s) TestNewFileWatcher(t *testing.T) { + tests := map[string]struct { + authzPolicy string + refreshDuration time.Duration + wantErr error + }{ + "InvalidRefreshDurationFailsToCreateInterceptor": { + refreshDuration: time.Duration(0), + wantErr: fmt.Errorf("requires refresh interval(0s) greater than 0s"), + }, + "InvalidPolicyFailsToCreateInterceptor": { + authzPolicy: `{}`, + refreshDuration: time.Duration(1), + wantErr: fmt.Errorf(`"name" is not present`), + }, + "ValidPolicyCreatesInterceptor": { + authzPolicy: `{ + "name": "authz", + "allow_rules": + [ + { + "name": "allow_all" + } + ] + }`, + refreshDuration: time.Duration(1), + }, + } + for name, test := range tests { + t.Run(name, func(t *testing.T) { + file := createTmpPolicyFile(t, name, []byte(test.authzPolicy)) + i, err := authz.NewFileWatcher(file, test.refreshDuration) + if fmt.Sprint(err) != fmt.Sprint(test.wantErr) { + t.Fatalf("NewFileWatcher(%v) returned err: %v, want err: %v", test.authzPolicy, err, test.wantErr) + } + if i != nil { + i.Close() + } + }) + } +} diff --git a/balancer/balancer.go b/balancer/balancer.go index ab531f4c0b8..178de0898aa 100644 --- a/balancer/balancer.go +++ b/balancer/balancer.go @@ -75,24 +75,26 @@ func Get(name string) Builder { return nil } -// SubConn represents a gRPC sub connection. -// Each sub connection contains a list of addresses. gRPC will -// try to connect to them (in sequence), and stop trying the -// remainder once one connection is successful. +// A SubConn represents a single connection to a gRPC backend service. // -// The reconnect backoff will be applied on the list, not a single address. -// For example, try_on_all_addresses -> backoff -> try_on_all_addresses. +// Each SubConn contains a list of addresses. // -// All SubConns start in IDLE, and will not try to connect. To trigger -// the connecting, Balancers must call Connect. -// When the connection encounters an error, it will reconnect immediately. -// When the connection becomes IDLE, it will not reconnect unless Connect is -// called. +// All SubConns start in IDLE, and will not try to connect. To trigger the +// connecting, Balancers must call Connect. If a connection re-enters IDLE, +// Balancers must call Connect again to trigger a new connection attempt. // -// This interface is to be implemented by gRPC. Users should not need a -// brand new implementation of this interface. For the situations like -// testing, the new implementation should embed this interface. This allows -// gRPC to add new methods to this interface. +// gRPC will try to connect to the addresses in sequence, and stop trying the +// remainder once the first connection is successful. If an attempt to connect +// to all addresses encounters an error, the SubConn will enter +// TRANSIENT_FAILURE for a backoff period, and then transition to IDLE. +// +// Once established, if a connection is lost, the SubConn will transition +// directly to IDLE. +// +// This interface is to be implemented by gRPC. Users should not need their own +// implementation of this interface. For situations like testing, any +// implementations should embed this interface. This allows gRPC to add new +// methods to this interface. type SubConn interface { // UpdateAddresses updates the addresses used in this SubConn. // gRPC checks if currently-connected address is still in the new list. @@ -326,6 +328,20 @@ type Balancer interface { Close() } +// ExitIdler is an optional interface for balancers to implement. If +// implemented, ExitIdle will be called when ClientConn.Connect is called, if +// the ClientConn is idle. If unimplemented, ClientConn.Connect will cause +// all SubConns to connect. +// +// Notice: it will be required for all balancers to implement this in a future +// release. +type ExitIdler interface { + // ExitIdle instructs the LB policy to reconnect to backends / exit the + // IDLE state, if appropriate and possible. Note that SubConns that enter + // the IDLE state will not reconnect until SubConn.Connect is called. + ExitIdle() +} + // SubConnState describes the state of a SubConn. type SubConnState struct { // ConnectivityState is the connectivity state of the SubConn. @@ -353,8 +369,10 @@ var ErrBadResolverState = errors.New("bad resolver state") // // It's not thread safe. type ConnectivityStateEvaluator struct { - numReady uint64 // Number of addrConns in ready state. - numConnecting uint64 // Number of addrConns in connecting state. + numReady uint64 // Number of addrConns in ready state. + numConnecting uint64 // Number of addrConns in connecting state. + numTransientFailure uint64 // Number of addrConns in transient failure state. + numIdle uint64 // Number of addrConns in idle state. } // RecordTransition records state change happening in subConn and based on that @@ -362,9 +380,11 @@ type ConnectivityStateEvaluator struct { // // - If at least one SubConn in Ready, the aggregated state is Ready; // - Else if at least one SubConn in Connecting, the aggregated state is Connecting; -// - Else the aggregated state is TransientFailure. +// - Else if at least one SubConn is TransientFailure, the aggregated state is Transient Failure; +// - Else if at least one SubConn is Idle, the aggregated state is Idle; +// - Else there are no subconns and the aggregated state is Transient Failure // -// Idle and Shutdown are not considered. +// Shutdown is not considered. func (cse *ConnectivityStateEvaluator) RecordTransition(oldState, newState connectivity.State) connectivity.State { // Update counters. for idx, state := range []connectivity.State{oldState, newState} { @@ -374,6 +394,10 @@ func (cse *ConnectivityStateEvaluator) RecordTransition(oldState, newState conne cse.numReady += updateVal case connectivity.Connecting: cse.numConnecting += updateVal + case connectivity.TransientFailure: + cse.numTransientFailure += updateVal + case connectivity.Idle: + cse.numIdle += updateVal } } @@ -384,5 +408,11 @@ func (cse *ConnectivityStateEvaluator) RecordTransition(oldState, newState conne if cse.numConnecting > 0 { return connectivity.Connecting } + if cse.numTransientFailure > 0 { + return connectivity.TransientFailure + } + if cse.numIdle > 0 { + return connectivity.Idle + } return connectivity.TransientFailure } diff --git a/balancer/base/balancer.go b/balancer/base/balancer.go index 383d02ec2bf..8dd504299fe 100644 --- a/balancer/base/balancer.go +++ b/balancer/base/balancer.go @@ -22,6 +22,7 @@ import ( "errors" "fmt" + "google.golang.org/grpc/attributes" "google.golang.org/grpc/balancer" "google.golang.org/grpc/connectivity" "google.golang.org/grpc/grpclog" @@ -41,7 +42,7 @@ func (bb *baseBuilder) Build(cc balancer.ClientConn, opt balancer.BuildOptions) cc: cc, pickerBuilder: bb.pickerBuilder, - subConns: make(map[resolver.Address]balancer.SubConn), + subConns: make(map[resolver.Address]subConnInfo), scStates: make(map[balancer.SubConn]connectivity.State), csEvltr: &balancer.ConnectivityStateEvaluator{}, config: bb.config, @@ -57,6 +58,11 @@ func (bb *baseBuilder) Name() string { return bb.name } +type subConnInfo struct { + subConn balancer.SubConn + attrs *attributes.Attributes +} + type baseBalancer struct { cc balancer.ClientConn pickerBuilder PickerBuilder @@ -64,7 +70,7 @@ type baseBalancer struct { csEvltr *balancer.ConnectivityStateEvaluator state connectivity.State - subConns map[resolver.Address]balancer.SubConn // `attributes` is stripped from the keys of this map (the addresses) + subConns map[resolver.Address]subConnInfo // `attributes` is stripped from the keys of this map (the addresses) scStates map[balancer.SubConn]connectivity.State picker balancer.Picker config Config @@ -114,7 +120,7 @@ func (b *baseBalancer) UpdateClientConnState(s balancer.ClientConnState) error { aNoAttrs := a aNoAttrs.Attributes = nil addrsSet[aNoAttrs] = struct{}{} - if sc, ok := b.subConns[aNoAttrs]; !ok { + if scInfo, ok := b.subConns[aNoAttrs]; !ok { // a is a new address (not existing in b.subConns). // // When creating SubConn, the original address with attributes is @@ -125,8 +131,9 @@ func (b *baseBalancer) UpdateClientConnState(s balancer.ClientConnState) error { logger.Warningf("base.baseBalancer: failed to create new SubConn: %v", err) continue } - b.subConns[aNoAttrs] = sc + b.subConns[aNoAttrs] = subConnInfo{subConn: sc, attrs: a.Attributes} b.scStates[sc] = connectivity.Idle + b.csEvltr.RecordTransition(connectivity.Shutdown, connectivity.Idle) sc.Connect() } else { // Always update the subconn's address in case the attributes @@ -135,13 +142,15 @@ func (b *baseBalancer) UpdateClientConnState(s balancer.ClientConnState) error { // The SubConn does a reflect.DeepEqual of the new and old // addresses. So this is a noop if the current address is the same // as the old one (including attributes). - b.cc.UpdateAddresses(sc, []resolver.Address{a}) + scInfo.attrs = a.Attributes + b.subConns[aNoAttrs] = scInfo + b.cc.UpdateAddresses(scInfo.subConn, []resolver.Address{a}) } } - for a, sc := range b.subConns { + for a, scInfo := range b.subConns { // a was removed by resolver. if _, ok := addrsSet[a]; !ok { - b.cc.RemoveSubConn(sc) + b.cc.RemoveSubConn(scInfo.subConn) delete(b.subConns, a) // Keep the state of this sc in b.scStates until sc's state becomes Shutdown. // The entry will be deleted in UpdateSubConnState. @@ -184,9 +193,10 @@ func (b *baseBalancer) regeneratePicker() { readySCs := make(map[balancer.SubConn]SubConnInfo) // Filter out all ready SCs from full subConn map. - for addr, sc := range b.subConns { - if st, ok := b.scStates[sc]; ok && st == connectivity.Ready { - readySCs[sc] = SubConnInfo{Address: addr} + for addr, scInfo := range b.subConns { + if st, ok := b.scStates[scInfo.subConn]; ok && st == connectivity.Ready { + addr.Attributes = scInfo.attrs + readySCs[scInfo.subConn] = SubConnInfo{Address: addr} } } b.picker = b.pickerBuilder.Build(PickerBuildInfo{ReadySCs: readySCs}) @@ -204,10 +214,14 @@ func (b *baseBalancer) UpdateSubConnState(sc balancer.SubConn, state balancer.Su } return } - if oldS == connectivity.TransientFailure && s == connectivity.Connecting { - // Once a subconn enters TRANSIENT_FAILURE, ignore subsequent + if oldS == connectivity.TransientFailure && + (s == connectivity.Connecting || s == connectivity.Idle) { + // Once a subconn enters TRANSIENT_FAILURE, ignore subsequent IDLE or // CONNECTING transitions to prevent the aggregated state from being // always CONNECTING when many backends exist but are all down. + if s == connectivity.Idle { + sc.Connect() + } return } b.scStates[sc] = s @@ -233,7 +247,6 @@ func (b *baseBalancer) UpdateSubConnState(sc balancer.SubConn, state balancer.Su b.state == connectivity.TransientFailure { b.regeneratePicker() } - b.cc.UpdateState(balancer.State{ConnectivityState: b.state, Picker: b.picker}) } @@ -242,6 +255,11 @@ func (b *baseBalancer) UpdateSubConnState(sc balancer.SubConn, state balancer.Su func (b *baseBalancer) Close() { } +// ExitIdle is a nop because the base balancer attempts to stay connected to +// all SubConns at all times. +func (b *baseBalancer) ExitIdle() { +} + // NewErrPicker returns a Picker that always returns err on Pick(). func NewErrPicker(err error) balancer.Picker { return &errPicker{err: err} diff --git a/balancer/base/balancer_test.go b/balancer/base/balancer_test.go index 03114251a04..f8ff8cf9844 100644 --- a/balancer/base/balancer_test.go +++ b/balancer/base/balancer_test.go @@ -23,6 +23,7 @@ import ( "google.golang.org/grpc/attributes" "google.golang.org/grpc/balancer" + "google.golang.org/grpc/connectivity" "google.golang.org/grpc/resolver" ) @@ -35,12 +36,24 @@ func (c *testClientConn) NewSubConn(addrs []resolver.Address, opts balancer.NewS return c.newSubConn(addrs, opts) } +func (c *testClientConn) UpdateState(balancer.State) {} + type testSubConn struct{} func (sc *testSubConn) UpdateAddresses(addresses []resolver.Address) {} func (sc *testSubConn) Connect() {} +// testPickBuilder creates balancer.Picker for test. +type testPickBuilder struct { + validate func(info PickerBuildInfo) +} + +func (p *testPickBuilder) Build(info PickerBuildInfo) balancer.Picker { + p.validate(info) + return nil +} + func TestBaseBalancerStripAttributes(t *testing.T) { b := (&baseBuilder{}).Build(&testClientConn{ newSubConn: func(addrs []resolver.Address, _ balancer.NewSubConnOptions) (balancer.SubConn, error) { @@ -64,7 +77,46 @@ func TestBaseBalancerStripAttributes(t *testing.T) { for addr := range b.subConns { if addr.Attributes != nil { - t.Errorf("in b.subConns, got address %+v with nil attributes, want not nil", addr) + t.Errorf("in b.subConns, got address %+v with not nil attributes, want nil", addr) + } + } +} + +func TestBaseBalancerReserveAttributes(t *testing.T) { + var v = func(info PickerBuildInfo) { + for _, sc := range info.ReadySCs { + if sc.Address.Addr == "1.1.1.1" { + if sc.Address.Attributes == nil { + t.Errorf("in picker.validate, got address %+v with nil attributes, want not nil", sc.Address) + } + foo, ok := sc.Address.Attributes.Value("foo").(string) + if !ok || foo != "2233niang" { + t.Errorf("in picker.validate, got address[1.1.1.1] with invalid attributes value %v, want 2233niang", sc.Address.Attributes.Value("foo")) + } + } else if sc.Address.Addr == "2.2.2.2" { + if sc.Address.Attributes != nil { + t.Error("in b.subConns, got address[2.2.2.2] with not nil attributes, want nil") + } + } } } + pickBuilder := &testPickBuilder{validate: v} + b := (&baseBuilder{pickerBuilder: pickBuilder}).Build(&testClientConn{ + newSubConn: func(addrs []resolver.Address, _ balancer.NewSubConnOptions) (balancer.SubConn, error) { + return &testSubConn{}, nil + }, + }, balancer.BuildOptions{}).(*baseBalancer) + + b.UpdateClientConnState(balancer.ClientConnState{ + ResolverState: resolver.State{ + Addresses: []resolver.Address{ + {Addr: "1.1.1.1", Attributes: attributes.New("foo", "2233niang")}, + {Addr: "2.2.2.2", Attributes: nil}, + }, + }, + }) + + for sc := range b.scStates { + b.UpdateSubConnState(sc, balancer.SubConnState{ConnectivityState: connectivity.Ready, ConnectionError: nil}) + } } diff --git a/balancer/grpclb/grpc_lb_v1/load_balancer_grpc.pb.go b/balancer/grpclb/grpc_lb_v1/load_balancer_grpc.pb.go index d56b77cca63..50cc9da4a90 100644 --- a/balancer/grpclb/grpc_lb_v1/load_balancer_grpc.pb.go +++ b/balancer/grpclb/grpc_lb_v1/load_balancer_grpc.pb.go @@ -1,4 +1,8 @@ // Code generated by protoc-gen-go-grpc. DO NOT EDIT. +// versions: +// - protoc-gen-go-grpc v1.1.0 +// - protoc v3.14.0 +// source: grpc/lb/v1/load_balancer.proto package grpc_lb_v1 diff --git a/balancer/grpclb/grpclb.go b/balancer/grpclb/grpclb.go index a43d8964119..fe423af182a 100644 --- a/balancer/grpclb/grpclb.go +++ b/balancer/grpclb/grpclb.go @@ -25,6 +25,7 @@ package grpclb import ( "context" "errors" + "fmt" "sync" "time" @@ -134,6 +135,7 @@ func (b *lbBuilder) Build(cc balancer.ClientConn, opt balancer.BuildOptions) bal lb := &lbBalancer{ cc: newLBCacheClientConn(cc), + dialTarget: opt.Target.Endpoint, target: opt.Target.Endpoint, opt: opt, fallbackTimeout: b.fallbackTimeout, @@ -163,9 +165,10 @@ func (b *lbBuilder) Build(cc balancer.ClientConn, opt balancer.BuildOptions) bal } type lbBalancer struct { - cc *lbCacheClientConn - target string - opt balancer.BuildOptions + cc *lbCacheClientConn + dialTarget string // user's dial target + target string // same as dialTarget unless overridden in service config + opt balancer.BuildOptions usePickFirst bool @@ -221,6 +224,7 @@ type lbBalancer struct { // when resolved address updates are received, and read in the goroutine // handling fallback. resolvedBackendAddrs []resolver.Address + connErr error // the last connection error } // regeneratePicker takes a snapshot of the balancer, and generates a picker from @@ -230,7 +234,7 @@ type lbBalancer struct { // Caller must hold lb.mu. func (lb *lbBalancer) regeneratePicker(resetDrop bool) { if lb.state == connectivity.TransientFailure { - lb.picker = &errPicker{err: balancer.ErrTransientFailure} + lb.picker = &errPicker{err: fmt.Errorf("all SubConns are in TransientFailure, last connection error: %v", lb.connErr)} return } @@ -336,6 +340,8 @@ func (lb *lbBalancer) UpdateSubConnState(sc balancer.SubConn, scs balancer.SubCo // When an address was removed by resolver, b called RemoveSubConn but // kept the sc's state in scStates. Remove state for this sc here. delete(lb.scStates, sc) + case connectivity.TransientFailure: + lb.connErr = scs.ConnectionError } // Force regenerate picker if // - this sc became ready from not-ready @@ -394,6 +400,30 @@ func (lb *lbBalancer) handleServiceConfig(gc *grpclbServiceConfig) { lb.mu.Lock() defer lb.mu.Unlock() + // grpclb uses the user's dial target to populate the `Name` field of the + // `InitialLoadBalanceRequest` message sent to the remote balancer. But when + // grpclb is used a child policy in the context of RLS, we want the `Name` + // field to be populated with the value received from the RLS server. To + // support this use case, an optional "target_name" field has been added to + // the grpclb LB policy's config. If specified, it overrides the name of + // the target to be sent to the remote balancer; if not, the target to be + // sent to the balancer will continue to be obtained from the target URI + // passed to the gRPC client channel. Whenever that target to be sent to the + // balancer is updated, we need to restart the stream to the balancer as + // this target is sent in the first message on the stream. + if gc != nil { + target := lb.dialTarget + if gc.TargetName != "" { + target = gc.TargetName + } + if target != lb.target { + lb.target = target + if lb.ccRemoteLB != nil { + lb.ccRemoteLB.cancelRemoteBalancerCall() + } + } + } + newUsePickFirst := childIsPickFirst(gc) if lb.usePickFirst == newUsePickFirst { return @@ -484,3 +514,5 @@ func (lb *lbBalancer) Close() { } lb.cc.close() } + +func (lb *lbBalancer) ExitIdle() {} diff --git a/balancer/grpclb/grpclb_config.go b/balancer/grpclb/grpclb_config.go index aac3719631b..b4e23dee017 100644 --- a/balancer/grpclb/grpclb_config.go +++ b/balancer/grpclb/grpclb_config.go @@ -34,6 +34,7 @@ const ( type grpclbServiceConfig struct { serviceconfig.LoadBalancingConfig ChildPolicy *[]map[string]json.RawMessage + TargetName string } func (b *lbBuilder) ParseConfig(lbConfig json.RawMessage) (serviceconfig.LoadBalancingConfig, error) { diff --git a/balancer/grpclb/grpclb_config_test.go b/balancer/grpclb/grpclb_config_test.go index 5a45de90494..0db2299157e 100644 --- a/balancer/grpclb/grpclb_config_test.go +++ b/balancer/grpclb/grpclb_config_test.go @@ -20,52 +20,68 @@ package grpclb import ( "encoding/json" - "errors" - "fmt" - "reflect" - "strings" "testing" + "github.com/google/go-cmp/cmp" "google.golang.org/grpc/serviceconfig" ) func (s) TestParse(t *testing.T) { tests := []struct { name string - s string + sc string want serviceconfig.LoadBalancingConfig - wantErr error + wantErr bool }{ { name: "empty", - s: "", + sc: "", want: nil, - wantErr: errors.New("unexpected end of JSON input"), + wantErr: true, }, { name: "success1", - s: `{"childPolicy":[{"pick_first":{}}]}`, + sc: ` +{ + "childPolicy": [ + {"pick_first":{}} + ], + "targetName": "foo-service" +}`, want: &grpclbServiceConfig{ ChildPolicy: &[]map[string]json.RawMessage{ {"pick_first": json.RawMessage("{}")}, }, + TargetName: "foo-service", }, }, { name: "success2", - s: `{"childPolicy":[{"round_robin":{}},{"pick_first":{}}]}`, + sc: ` +{ + "childPolicy": [ + {"round_robin":{}}, + {"pick_first":{}} + ], + "targetName": "foo-service" +}`, want: &grpclbServiceConfig{ ChildPolicy: &[]map[string]json.RawMessage{ {"round_robin": json.RawMessage("{}")}, {"pick_first": json.RawMessage("{}")}, }, + TargetName: "foo-service", }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if got, err := (&lbBuilder{}).ParseConfig(json.RawMessage(tt.s)); !reflect.DeepEqual(got, tt.want) || !strings.Contains(fmt.Sprint(err), fmt.Sprint(tt.wantErr)) { - t.Errorf("parseFullServiceConfig() = %+v, %+v, want %+v, ", got, err, tt.want, tt.wantErr) + got, err := (&lbBuilder{}).ParseConfig(json.RawMessage(tt.sc)) + if (err != nil) != (tt.wantErr) { + t.Fatalf("ParseConfig(%q) returned error: %v, wantErr: %v", tt.sc, err, tt.wantErr) + } + if diff := cmp.Diff(tt.want, got); diff != "" { + t.Fatalf("ParseConfig(%q) returned unexpected difference (-want +got):\n%s", tt.sc, diff) } }) } diff --git a/balancer/grpclb/grpclb_remote_balancer.go b/balancer/grpclb/grpclb_remote_balancer.go index 5ac8d86bd57..0210c012d7b 100644 --- a/balancer/grpclb/grpclb_remote_balancer.go +++ b/balancer/grpclb/grpclb_remote_balancer.go @@ -206,6 +206,9 @@ type remoteBalancerCCWrapper struct { backoff backoff.Strategy done chan struct{} + streamMu sync.Mutex + streamCancel func() + // waitgroup to wait for all goroutines to exit. wg sync.WaitGroup } @@ -319,10 +322,8 @@ func (ccw *remoteBalancerCCWrapper) sendLoadReport(s *balanceLoadClientStream, i } } -func (ccw *remoteBalancerCCWrapper) callRemoteBalancer() (backoff bool, _ error) { +func (ccw *remoteBalancerCCWrapper) callRemoteBalancer(ctx context.Context) (backoff bool, _ error) { lbClient := &loadBalancerClient{cc: ccw.cc} - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() stream, err := lbClient.BalanceLoad(ctx, grpc.WaitForReady(true)) if err != nil { return true, fmt.Errorf("grpclb: failed to perform RPC to the remote balancer %v", err) @@ -362,11 +363,43 @@ func (ccw *remoteBalancerCCWrapper) callRemoteBalancer() (backoff bool, _ error) return false, ccw.readServerList(stream) } +// cancelRemoteBalancerCall cancels the context used by the stream to the remote +// balancer. watchRemoteBalancer() takes care of restarting this call after the +// stream fails. +func (ccw *remoteBalancerCCWrapper) cancelRemoteBalancerCall() { + ccw.streamMu.Lock() + if ccw.streamCancel != nil { + ccw.streamCancel() + ccw.streamCancel = nil + } + ccw.streamMu.Unlock() +} + func (ccw *remoteBalancerCCWrapper) watchRemoteBalancer() { - defer ccw.wg.Done() + defer func() { + ccw.wg.Done() + ccw.streamMu.Lock() + if ccw.streamCancel != nil { + // This is to make sure that we don't leak the context when we are + // directly returning from inside of the below `for` loop. + ccw.streamCancel() + ccw.streamCancel = nil + } + ccw.streamMu.Unlock() + }() + var retryCount int + var ctx context.Context for { - doBackoff, err := ccw.callRemoteBalancer() + ccw.streamMu.Lock() + if ccw.streamCancel != nil { + ccw.streamCancel() + ccw.streamCancel = nil + } + ctx, ccw.streamCancel = context.WithCancel(context.Background()) + ccw.streamMu.Unlock() + + doBackoff, err := ccw.callRemoteBalancer(ctx) select { case <-ccw.done: return diff --git a/balancer/grpclb/grpclb_test.go b/balancer/grpclb/grpclb_test.go index 9cbb338c241..3b666764728 100644 --- a/balancer/grpclb/grpclb_test.go +++ b/balancer/grpclb/grpclb_test.go @@ -31,12 +31,16 @@ import ( "testing" "time" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "google.golang.org/grpc" "google.golang.org/grpc/balancer" grpclbstate "google.golang.org/grpc/balancer/grpclb/state" "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials" "google.golang.org/grpc/internal/grpctest" + "google.golang.org/grpc/internal/testutils" "google.golang.org/grpc/metadata" "google.golang.org/grpc/peer" "google.golang.org/grpc/resolver" @@ -60,6 +64,13 @@ var ( fakeName = "fake.Name" ) +const ( + defaultTestTimeout = 10 * time.Second + defaultTestShortTimeout = 10 * time.Millisecond + testUserAgent = "test-user-agent" + grpclbConfig = `{"loadBalancingConfig": [{"grpclb": {}}]}` +) + type s struct { grpctest.Tester } @@ -136,18 +147,6 @@ func (s *rpcStats) merge(cs *lbpb.ClientStats) { s.mu.Unlock() } -func mapsEqual(a, b map[string]int64) bool { - if len(a) != len(b) { - return false - } - for k, v1 := range a { - if v2, ok := b[k]; !ok || v1 != v2 { - return false - } - } - return true -} - func atomicEqual(a, b *int64) bool { return atomic.LoadInt64(a) == atomic.LoadInt64(b) } @@ -172,7 +171,7 @@ func (s *rpcStats) equal(o *rpcStats) bool { defer s.mu.Unlock() o.mu.Lock() defer o.mu.Unlock() - return mapsEqual(s.numCallsDropped, o.numCallsDropped) + return cmp.Equal(s.numCallsDropped, o.numCallsDropped, cmpopts.EquateEmpty()) } func (s *rpcStats) String() string { @@ -188,24 +187,28 @@ func (s *rpcStats) String() string { type remoteBalancer struct { lbgrpc.UnimplementedLoadBalancerServer - sls chan *lbpb.ServerList - statsDura time.Duration - done chan struct{} - stats *rpcStats - statsChan chan *lbpb.ClientStats - fbChan chan struct{} - - customUserAgent string + sls chan *lbpb.ServerList + statsDura time.Duration + done chan struct{} + stats *rpcStats + statsChan chan *lbpb.ClientStats + fbChan chan struct{} + balanceLoadCh chan struct{} // notify successful invocation of BalanceLoad + + wantUserAgent string // expected user-agent in metadata of BalancerLoad + wantServerName string // expected server name in InitialLoadBalanceRequest } -func newRemoteBalancer(customUserAgent string, statsChan chan *lbpb.ClientStats) *remoteBalancer { +func newRemoteBalancer(wantUserAgent, wantServerName string, statsChan chan *lbpb.ClientStats) *remoteBalancer { return &remoteBalancer{ - sls: make(chan *lbpb.ServerList, 1), - done: make(chan struct{}), - stats: newRPCStats(), - statsChan: statsChan, - fbChan: make(chan struct{}), - customUserAgent: customUserAgent, + sls: make(chan *lbpb.ServerList, 1), + done: make(chan struct{}), + stats: newRPCStats(), + statsChan: statsChan, + fbChan: make(chan struct{}), + balanceLoadCh: make(chan struct{}, 1), + wantUserAgent: wantUserAgent, + wantServerName: wantServerName, } } @@ -218,15 +221,18 @@ func (b *remoteBalancer) fallbackNow() { b.fbChan <- struct{}{} } +func (b *remoteBalancer) updateServerName(name string) { + b.wantServerName = name +} + func (b *remoteBalancer) BalanceLoad(stream lbgrpc.LoadBalancer_BalanceLoadServer) error { md, ok := metadata.FromIncomingContext(stream.Context()) if !ok { return status.Error(codes.Internal, "failed to receive metadata") } - if b.customUserAgent != "" { - ua := md["user-agent"] - if len(ua) == 0 || !strings.HasPrefix(ua[0], b.customUserAgent) { - return status.Errorf(codes.InvalidArgument, "received unexpected user-agent: %v, want prefix %q", ua, b.customUserAgent) + if b.wantUserAgent != "" { + if ua := md["user-agent"]; len(ua) == 0 || !strings.HasPrefix(ua[0], b.wantUserAgent) { + return status.Errorf(codes.InvalidArgument, "received unexpected user-agent: %v, want prefix %q", ua, b.wantUserAgent) } } @@ -235,9 +241,10 @@ func (b *remoteBalancer) BalanceLoad(stream lbgrpc.LoadBalancer_BalanceLoadServe return err } initReq := req.GetInitialRequest() - if initReq.Name != beServerName { - return status.Errorf(codes.InvalidArgument, "invalid service name: %v", initReq.Name) + if initReq.Name != b.wantServerName { + return status.Errorf(codes.InvalidArgument, "invalid service name: %q, want: %q", initReq.Name, b.wantServerName) } + b.balanceLoadCh <- struct{}{} resp := &lbpb.LoadBalanceResponse{ LoadBalanceResponseType: &lbpb.LoadBalanceResponse_InitialResponse{ InitialResponse: &lbpb.InitialLoadBalanceResponse{ @@ -253,11 +260,8 @@ func (b *remoteBalancer) BalanceLoad(stream lbgrpc.LoadBalancer_BalanceLoadServe } go func() { for { - var ( - req *lbpb.LoadBalanceRequest - err error - ) - if req, err = stream.Recv(); err != nil { + req, err := stream.Recv() + if err != nil { return } b.stats.merge(req.GetClientStats()) @@ -347,7 +351,7 @@ type testServers struct { beListeners []net.Listener } -func newLoadBalancer(numberOfBackends int, customUserAgent string, statsChan chan *lbpb.ClientStats) (tss *testServers, cleanup func(), err error) { +func startBackendsAndRemoteLoadBalancer(numberOfBackends int, customUserAgent string, statsChan chan *lbpb.ClientStats) (tss *testServers, cleanup func(), err error) { var ( beListeners []net.Listener ls *remoteBalancer @@ -380,7 +384,7 @@ func newLoadBalancer(numberOfBackends int, customUserAgent string, statsChan cha sn: lbServerName, } lb = grpc.NewServer(grpc.Creds(lbCreds)) - ls = newRemoteBalancer(customUserAgent, statsChan) + ls = newRemoteBalancer(customUserAgent, beServerName, statsChan) lbgrpc.RegisterLoadBalancerServer(lb, ls) go func() { lb.Serve(lbLis) @@ -407,34 +411,29 @@ func newLoadBalancer(numberOfBackends int, customUserAgent string, statsChan cha return } -var grpclbConfig = `{"loadBalancingConfig": [{"grpclb": {}}]}` - func (s) TestGRPCLB(t *testing.T) { r := manual.NewBuilderWithScheme("whatever") - const testUserAgent = "test-user-agent" - tss, cleanup, err := newLoadBalancer(1, testUserAgent, nil) + tss, cleanup, err := startBackendsAndRemoteLoadBalancer(1, testUserAgent, nil) if err != nil { t.Fatalf("failed to create new load balancer: %v", err) } defer cleanup() - be := &lbpb.Server{ - IpAddress: tss.beIPs[0], - Port: int32(tss.bePorts[0]), - LoadBalanceToken: lbToken, - } - var bes []*lbpb.Server - bes = append(bes, be) - sl := &lbpb.ServerList{ - Servers: bes, + tss.ls.sls <- &lbpb.ServerList{ + Servers: []*lbpb.Server{ + { + IpAddress: tss.beIPs[0], + Port: int32(tss.bePorts[0]), + LoadBalanceToken: lbToken, + }, + }, } - tss.ls.sls <- sl - creds := serverNameCheckCreds{} - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - cc, err := grpc.DialContext(ctx, r.Scheme()+":///"+beServerName, grpc.WithResolvers(r), - grpc.WithTransportCredentials(&creds), grpc.WithContextDialer(fakeNameDialer), + + cc, err := grpc.Dial(r.Scheme()+":///"+beServerName, + grpc.WithResolvers(r), + grpc.WithTransportCredentials(&serverNameCheckCreds{}), + grpc.WithContextDialer(fakeNameDialer), grpc.WithUserAgent(testUserAgent)) if err != nil { t.Fatalf("Failed to dial to the backend %v", err) @@ -445,12 +444,11 @@ func (s) TestGRPCLB(t *testing.T) { rs := grpclbstate.Set(resolver.State{ServiceConfig: r.CC.ParseServiceConfig(grpclbConfig)}, &grpclbstate.State{BalancerAddresses: []resolver.Address{{ Addr: tss.lbAddr, - Type: resolver.Backend, ServerName: lbServerName, }}}) r.UpdateState(rs) - ctx, cancel = context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() if _, err := testC.EmptyCall(ctx, &testpb.Empty{}); err != nil { t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, ", testC, err) @@ -461,7 +459,7 @@ func (s) TestGRPCLB(t *testing.T) { func (s) TestGRPCLBWeighted(t *testing.T) { r := manual.NewBuilderWithScheme("whatever") - tss, cleanup, err := newLoadBalancer(2, "", nil) + tss, cleanup, err := startBackendsAndRemoteLoadBalancer(2, "", nil) if err != nil { t.Fatalf("failed to create new load balancer: %v", err) } @@ -481,23 +479,25 @@ func (s) TestGRPCLBWeighted(t *testing.T) { portsToIndex[tss.bePorts[i]] = i } - creds := serverNameCheckCreds{} - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - cc, err := grpc.DialContext(ctx, r.Scheme()+":///"+beServerName, grpc.WithResolvers(r), - grpc.WithTransportCredentials(&creds), grpc.WithContextDialer(fakeNameDialer)) + cc, err := grpc.Dial(r.Scheme()+":///"+beServerName, + grpc.WithResolvers(r), + grpc.WithTransportCredentials(&serverNameCheckCreds{}), + grpc.WithContextDialer(fakeNameDialer)) if err != nil { t.Fatalf("Failed to dial to the backend %v", err) } defer cc.Close() testC := testpb.NewTestServiceClient(cc) - r.UpdateState(resolver.State{Addresses: []resolver.Address{{ - Addr: tss.lbAddr, - Type: resolver.GRPCLB, - ServerName: lbServerName, - }}}) + rs := grpclbstate.Set(resolver.State{ServiceConfig: r.CC.ParseServiceConfig(grpclbConfig)}, + &grpclbstate.State{BalancerAddresses: []resolver.Address{{ + Addr: tss.lbAddr, + ServerName: lbServerName, + }}}) + r.UpdateState(rs) + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() sequences := []string{"00101", "00011"} for _, seq := range sequences { var ( @@ -526,7 +526,7 @@ func (s) TestGRPCLBWeighted(t *testing.T) { func (s) TestDropRequest(t *testing.T) { r := manual.NewBuilderWithScheme("whatever") - tss, cleanup, err := newLoadBalancer(2, "", nil) + tss, cleanup, err := startBackendsAndRemoteLoadBalancer(2, "", nil) if err != nil { t.Fatalf("failed to create new load balancer: %v", err) } @@ -546,22 +546,23 @@ func (s) TestDropRequest(t *testing.T) { Drop: true, }}, } - creds := serverNameCheckCreds{} - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - cc, err := grpc.DialContext(ctx, r.Scheme()+":///"+beServerName, grpc.WithResolvers(r), - grpc.WithTransportCredentials(&creds), grpc.WithContextDialer(fakeNameDialer)) + + cc, err := grpc.Dial(r.Scheme()+":///"+beServerName, + grpc.WithResolvers(r), + grpc.WithTransportCredentials(&serverNameCheckCreds{}), + grpc.WithContextDialer(fakeNameDialer)) if err != nil { t.Fatalf("Failed to dial to the backend %v", err) } defer cc.Close() testC := testpb.NewTestServiceClient(cc) - r.UpdateState(resolver.State{Addresses: []resolver.Address{{ - Addr: tss.lbAddr, - Type: resolver.GRPCLB, - ServerName: lbServerName, - }}}) + rs := grpclbstate.Set(resolver.State{ServiceConfig: r.CC.ParseServiceConfig(grpclbConfig)}, + &grpclbstate.State{BalancerAddresses: []resolver.Address{{ + Addr: tss.lbAddr, + ServerName: lbServerName, + }}}) + r.UpdateState(rs) var ( i int @@ -573,6 +574,8 @@ func (s) TestDropRequest(t *testing.T) { sleepEachLoop = time.Millisecond loopCount = int(time.Second / sleepEachLoop) ) + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() // Make a non-fail-fast RPC and wait for it to succeed. for i = 0; i < loopCount; i++ { if _, err := testC.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(true), grpc.Peer(&p)); err == nil { @@ -681,49 +684,51 @@ func (s) TestBalancerDisconnects(t *testing.T) { lbs []*grpc.Server ) for i := 0; i < 2; i++ { - tss, cleanup, err := newLoadBalancer(1, "", nil) + tss, cleanup, err := startBackendsAndRemoteLoadBalancer(1, "", nil) if err != nil { t.Fatalf("failed to create new load balancer: %v", err) } defer cleanup() - be := &lbpb.Server{ - IpAddress: tss.beIPs[0], - Port: int32(tss.bePorts[0]), - LoadBalanceToken: lbToken, - } - var bes []*lbpb.Server - bes = append(bes, be) - sl := &lbpb.ServerList{ - Servers: bes, + tss.ls.sls <- &lbpb.ServerList{ + Servers: []*lbpb.Server{ + { + IpAddress: tss.beIPs[0], + Port: int32(tss.bePorts[0]), + LoadBalanceToken: lbToken, + }, + }, } - tss.ls.sls <- sl tests = append(tests, tss) lbs = append(lbs, tss.lb) } - creds := serverNameCheckCreds{} - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - cc, err := grpc.DialContext(ctx, r.Scheme()+":///"+beServerName, grpc.WithResolvers(r), - grpc.WithTransportCredentials(&creds), grpc.WithContextDialer(fakeNameDialer)) + cc, err := grpc.Dial(r.Scheme()+":///"+beServerName, + grpc.WithResolvers(r), + grpc.WithTransportCredentials(&serverNameCheckCreds{}), + grpc.WithContextDialer(fakeNameDialer)) if err != nil { t.Fatalf("Failed to dial to the backend %v", err) } defer cc.Close() testC := testpb.NewTestServiceClient(cc) - r.UpdateState(resolver.State{Addresses: []resolver.Address{{ - Addr: tests[0].lbAddr, - Type: resolver.GRPCLB, - ServerName: lbServerName, - }, { - Addr: tests[1].lbAddr, - Type: resolver.GRPCLB, - ServerName: lbServerName, - }}}) + rs := grpclbstate.Set(resolver.State{ServiceConfig: r.CC.ParseServiceConfig(grpclbConfig)}, + &grpclbstate.State{BalancerAddresses: []resolver.Address{ + { + Addr: tests[0].lbAddr, + ServerName: lbServerName, + }, + { + Addr: tests[1].lbAddr, + ServerName: lbServerName, + }, + }}) + r.UpdateState(rs) + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() var p peer.Peer if _, err := testC.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(true), grpc.Peer(&p)); err != nil { t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, ", testC, err) @@ -750,16 +755,27 @@ func (s) TestBalancerDisconnects(t *testing.T) { func (s) TestFallback(t *testing.T) { balancer.Register(newLBBuilderWithFallbackTimeout(100 * time.Millisecond)) defer balancer.Register(newLBBuilder()) - r := manual.NewBuilderWithScheme("whatever") - tss, cleanup, err := newLoadBalancer(1, "", nil) + // Start a remote balancer and a backend. Push the backend address to the + // remote balancer. + tss, cleanup, err := startBackendsAndRemoteLoadBalancer(1, "", nil) if err != nil { t.Fatalf("failed to create new load balancer: %v", err) } defer cleanup() + sl := &lbpb.ServerList{ + Servers: []*lbpb.Server{ + { + IpAddress: tss.beIPs[0], + Port: int32(tss.bePorts[0]), + LoadBalanceToken: lbToken, + }, + }, + } + tss.ls.sls <- sl - // Start a standalone backend. + // Start a standalone backend for fallback. beLis, err := net.Listen("tcp", "localhost:0") if err != nil { t.Fatalf("Failed to listen %v", err) @@ -768,37 +784,29 @@ func (s) TestFallback(t *testing.T) { standaloneBEs := startBackends(beServerName, true, beLis) defer stopBackends(standaloneBEs) - be := &lbpb.Server{ - IpAddress: tss.beIPs[0], - Port: int32(tss.bePorts[0]), - LoadBalanceToken: lbToken, - } - var bes []*lbpb.Server - bes = append(bes, be) - sl := &lbpb.ServerList{ - Servers: bes, - } - tss.ls.sls <- sl - creds := serverNameCheckCreds{} - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - cc, err := grpc.DialContext(ctx, r.Scheme()+":///"+beServerName, grpc.WithResolvers(r), - grpc.WithTransportCredentials(&creds), grpc.WithContextDialer(fakeNameDialer)) + cc, err := grpc.Dial(r.Scheme()+":///"+beServerName, + grpc.WithResolvers(r), + grpc.WithTransportCredentials(&serverNameCheckCreds{}), + grpc.WithContextDialer(fakeNameDialer)) if err != nil { t.Fatalf("Failed to dial to the backend %v", err) } defer cc.Close() testC := testpb.NewTestServiceClient(cc) - r.UpdateState(resolver.State{Addresses: []resolver.Address{{ - Addr: "invalid.address", - Type: resolver.GRPCLB, - ServerName: lbServerName, - }, { - Addr: beLis.Addr().String(), - Type: resolver.Backend, - }}}) + // Push an update to the resolver with fallback backend address stored in + // the `Addresses` field and an invalid remote balancer address stored in + // attributes, which will cause fallback behavior to be invoked. + rs := resolver.State{ + Addresses: []resolver.Address{{Addr: beLis.Addr().String()}}, + ServiceConfig: r.CC.ParseServiceConfig(grpclbConfig), + } + rs = grpclbstate.Set(rs, &grpclbstate.State{BalancerAddresses: []resolver.Address{{Addr: "invalid.address", ServerName: lbServerName}}}) + r.UpdateState(rs) + // Make an RPC and verify that it got routed to the fallback backend. + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() var p peer.Peer if _, err := testC.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(true), grpc.Peer(&p)); err != nil { t.Fatalf("_.EmptyCall(_, _) = _, %v, want _, ", err) @@ -807,15 +815,21 @@ func (s) TestFallback(t *testing.T) { t.Fatalf("got peer: %v, want peer: %v", p.Addr, beLis.Addr()) } - r.UpdateState(resolver.State{Addresses: []resolver.Address{{ - Addr: tss.lbAddr, - Type: resolver.GRPCLB, - ServerName: lbServerName, - }, { - Addr: beLis.Addr().String(), - Type: resolver.Backend, - }}}) + // Push another update to the resolver, this time with a valid balancer + // address in the attributes field. + rs = resolver.State{ + ServiceConfig: r.CC.ParseServiceConfig(grpclbConfig), + Addresses: []resolver.Address{{Addr: beLis.Addr().String()}}, + } + rs = grpclbstate.Set(rs, &grpclbstate.State{BalancerAddresses: []resolver.Address{{Addr: tss.lbAddr, ServerName: lbServerName}}}) + r.UpdateState(rs) + select { + case <-ctx.Done(): + t.Fatalf("timeout when waiting for BalanceLoad RPC to be called on the remote balancer") + case <-tss.ls.balanceLoadCh: + } + // Wait for RPCs to get routed to the backend behind the remote balancer. var backendUsed bool for i := 0; i < 1000; i++ { if _, err := testC.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(true), grpc.Peer(&p)); err != nil { @@ -856,7 +870,7 @@ func (s) TestFallback(t *testing.T) { t.Fatalf("No RPC sent to fallback after 2 seconds") } - // Restart backend and remote balancer, should not use backends. + // Restart backend and remote balancer, should not use fallback backend. tss.beListeners[0].(*restartableListener).restart() tss.lbListener.(*restartableListener).restart() tss.ls.sls <- sl @@ -880,13 +894,25 @@ func (s) TestFallback(t *testing.T) { func (s) TestExplicitFallback(t *testing.T) { r := manual.NewBuilderWithScheme("whatever") - tss, cleanup, err := newLoadBalancer(1, "", nil) + // Start a remote balancer and a backend. Push the backend address to the + // remote balancer. + tss, cleanup, err := startBackendsAndRemoteLoadBalancer(1, "", nil) if err != nil { t.Fatalf("failed to create new load balancer: %v", err) } defer cleanup() + sl := &lbpb.ServerList{ + Servers: []*lbpb.Server{ + { + IpAddress: tss.beIPs[0], + Port: int32(tss.bePorts[0]), + LoadBalanceToken: lbToken, + }, + }, + } + tss.ls.sls <- sl - // Start a standalone backend. + // Start a standalone backend for fallback. beLis, err := net.Listen("tcp", "localhost:0") if err != nil { t.Fatalf("Failed to listen %v", err) @@ -895,37 +921,25 @@ func (s) TestExplicitFallback(t *testing.T) { standaloneBEs := startBackends(beServerName, true, beLis) defer stopBackends(standaloneBEs) - be := &lbpb.Server{ - IpAddress: tss.beIPs[0], - Port: int32(tss.bePorts[0]), - LoadBalanceToken: lbToken, - } - var bes []*lbpb.Server - bes = append(bes, be) - sl := &lbpb.ServerList{ - Servers: bes, - } - tss.ls.sls <- sl - creds := serverNameCheckCreds{} - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - cc, err := grpc.DialContext(ctx, r.Scheme()+":///"+beServerName, grpc.WithResolvers(r), - grpc.WithTransportCredentials(&creds), grpc.WithContextDialer(fakeNameDialer)) + cc, err := grpc.Dial(r.Scheme()+":///"+beServerName, + grpc.WithResolvers(r), + grpc.WithTransportCredentials(&serverNameCheckCreds{}), + grpc.WithContextDialer(fakeNameDialer)) if err != nil { t.Fatalf("Failed to dial to the backend %v", err) } defer cc.Close() testC := testpb.NewTestServiceClient(cc) - r.UpdateState(resolver.State{Addresses: []resolver.Address{{ - Addr: tss.lbAddr, - Type: resolver.GRPCLB, - ServerName: lbServerName, - }, { - Addr: beLis.Addr().String(), - Type: resolver.Backend, - }}}) + rs := resolver.State{ + Addresses: []resolver.Address{{Addr: beLis.Addr().String()}}, + ServiceConfig: r.CC.ParseServiceConfig(grpclbConfig), + } + rs = grpclbstate.Set(rs, &grpclbstate.State{BalancerAddresses: []resolver.Address{{Addr: tss.lbAddr, ServerName: lbServerName}}}) + r.UpdateState(rs) + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() var p peer.Peer var backendUsed bool for i := 0; i < 2000; i++ { @@ -980,23 +994,34 @@ func (s) TestExplicitFallback(t *testing.T) { } func (s) TestFallBackWithNoServerAddress(t *testing.T) { - resolveNowCh := make(chan struct{}, 1) + resolveNowCh := testutils.NewChannel() r := manual.NewBuilderWithScheme("whatever") r.ResolveNowCallback = func(resolver.ResolveNowOptions) { - select { - case <-resolveNowCh: - default: + ctx, cancel := context.WithTimeout(context.Background(), defaultTestShortTimeout) + defer cancel() + if err := resolveNowCh.SendContext(ctx, nil); err != nil { + t.Error("timeout when attemping to send on resolverNowCh") } - resolveNowCh <- struct{}{} } - tss, cleanup, err := newLoadBalancer(1, "", nil) + // Start a remote balancer and a backend. Push the backend address to the + // remote balancer yet. + tss, cleanup, err := startBackendsAndRemoteLoadBalancer(1, "", nil) if err != nil { t.Fatalf("failed to create new load balancer: %v", err) } defer cleanup() + sl := &lbpb.ServerList{ + Servers: []*lbpb.Server{ + { + IpAddress: tss.beIPs[0], + Port: int32(tss.bePorts[0]), + LoadBalanceToken: lbToken, + }, + }, + } - // Start a standalone backend. + // Start a standalone backend for fallback. beLis, err := net.Listen("tcp", "localhost:0") if err != nil { t.Fatalf("Failed to listen %v", err) @@ -1005,81 +1030,61 @@ func (s) TestFallBackWithNoServerAddress(t *testing.T) { standaloneBEs := startBackends(beServerName, true, beLis) defer stopBackends(standaloneBEs) - be := &lbpb.Server{ - IpAddress: tss.beIPs[0], - Port: int32(tss.bePorts[0]), - LoadBalanceToken: lbToken, - } - var bes []*lbpb.Server - bes = append(bes, be) - sl := &lbpb.ServerList{ - Servers: bes, - } - creds := serverNameCheckCreds{} - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - cc, err := grpc.DialContext(ctx, r.Scheme()+":///"+beServerName, grpc.WithResolvers(r), - grpc.WithTransportCredentials(&creds), grpc.WithContextDialer(fakeNameDialer)) + cc, err := grpc.Dial(r.Scheme()+":///"+beServerName, + grpc.WithResolvers(r), + grpc.WithTransportCredentials(&serverNameCheckCreds{}), + grpc.WithContextDialer(fakeNameDialer)) if err != nil { t.Fatalf("Failed to dial to the backend %v", err) } defer cc.Close() testC := testpb.NewTestServiceClient(cc) - // Select grpclb with service config. - const pfc = `{"loadBalancingConfig":[{"grpclb":{"childPolicy":[{"round_robin":{}}]}}]}` - scpr := r.CC.ParseServiceConfig(pfc) - if scpr.Err != nil { - t.Fatalf("Error parsing config %q: %v", pfc, scpr.Err) - } - + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() for i := 0; i < 2; i++ { - // Send an update with only backend address. grpclb should enter fallback - // and use the fallback backend. + // Send an update with only backend address. grpclb should enter + // fallback and use the fallback backend. r.UpdateState(resolver.State{ - Addresses: []resolver.Address{{ - Addr: beLis.Addr().String(), - Type: resolver.Backend, - }}, - ServiceConfig: scpr, + Addresses: []resolver.Address{{Addr: beLis.Addr().String()}}, + ServiceConfig: r.CC.ParseServiceConfig(grpclbConfig), }) - select { - case <-resolveNowCh: - t.Errorf("unexpected resolveNow when grpclb gets no balancer address 1111, %d", i) - case <-time.After(time.Second): + sCtx, sCancel := context.WithTimeout(context.Background(), defaultTestShortTimeout) + defer sCancel() + if _, err := resolveNowCh.Receive(sCtx); err != context.DeadlineExceeded { + t.Fatalf("unexpected resolveNow when grpclb gets no balancer address 1111, %d", i) } var p peer.Peer - rpcCtx, rpcCancel := context.WithTimeout(context.Background(), time.Second) - defer rpcCancel() - if _, err := testC.EmptyCall(rpcCtx, &testpb.Empty{}, grpc.WaitForReady(true), grpc.Peer(&p)); err != nil { + if _, err := testC.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(true), grpc.Peer(&p)); err != nil { t.Fatalf("_.EmptyCall(_, _) = _, %v, want _, ", err) } if p.Addr.String() != beLis.Addr().String() { t.Fatalf("got peer: %v, want peer: %v", p.Addr, beLis.Addr()) } - select { - case <-resolveNowCh: + sCtx, sCancel = context.WithTimeout(context.Background(), defaultTestShortTimeout) + defer sCancel() + if _, err := resolveNowCh.Receive(sCtx); err != context.DeadlineExceeded { t.Errorf("unexpected resolveNow when grpclb gets no balancer address 2222, %d", i) - case <-time.After(time.Second): } tss.ls.sls <- sl // Send an update with balancer address. The backends behind grpclb should // be used. - r.UpdateState(resolver.State{ - Addresses: []resolver.Address{{ - Addr: tss.lbAddr, - Type: resolver.GRPCLB, - ServerName: lbServerName, - }, { - Addr: beLis.Addr().String(), - Type: resolver.Backend, - }}, - ServiceConfig: scpr, - }) + rs := resolver.State{ + Addresses: []resolver.Address{{Addr: beLis.Addr().String()}}, + ServiceConfig: r.CC.ParseServiceConfig(grpclbConfig), + } + rs = grpclbstate.Set(rs, &grpclbstate.State{BalancerAddresses: []resolver.Address{{Addr: tss.lbAddr, ServerName: lbServerName}}}) + r.UpdateState(rs) + + select { + case <-ctx.Done(): + t.Fatalf("timeout when waiting for BalanceLoad RPC to be called on the remote balancer") + case <-tss.ls.balanceLoadCh: + } var backendUsed bool for i := 0; i < 1000; i++ { @@ -1101,7 +1106,7 @@ func (s) TestFallBackWithNoServerAddress(t *testing.T) { func (s) TestGRPCLBPickFirst(t *testing.T) { r := manual.NewBuilderWithScheme("whatever") - tss, cleanup, err := newLoadBalancer(3, "", nil) + tss, cleanup, err := startBackendsAndRemoteLoadBalancer(3, "", nil) if err != nil { t.Fatalf("failed to create new load balancer: %v", err) } @@ -1125,11 +1130,10 @@ func (s) TestGRPCLBPickFirst(t *testing.T) { portsToIndex[tss.bePorts[i]] = i } - creds := serverNameCheckCreds{} - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - cc, err := grpc.DialContext(ctx, r.Scheme()+":///"+beServerName, grpc.WithResolvers(r), - grpc.WithTransportCredentials(&creds), grpc.WithContextDialer(fakeNameDialer)) + cc, err := grpc.Dial(r.Scheme()+":///"+beServerName, + grpc.WithResolvers(r), + grpc.WithTransportCredentials(&serverNameCheckCreds{}), + grpc.WithContextDialer(fakeNameDialer)) if err != nil { t.Fatalf("Failed to dial to the backend %v", err) } @@ -1143,21 +1147,11 @@ func (s) TestGRPCLBPickFirst(t *testing.T) { tss.ls.sls <- &lbpb.ServerList{Servers: beServers[0:3]} // Start with sub policy pick_first. - const pfc = `{"loadBalancingConfig":[{"grpclb":{"childPolicy":[{"pick_first":{}}]}}]}` - scpr := r.CC.ParseServiceConfig(pfc) - if scpr.Err != nil { - t.Fatalf("Error parsing config %q: %v", pfc, scpr.Err) - } - - r.UpdateState(resolver.State{ - Addresses: []resolver.Address{{ - Addr: tss.lbAddr, - Type: resolver.GRPCLB, - ServerName: lbServerName, - }}, - ServiceConfig: scpr, - }) + rs := resolver.State{ServiceConfig: r.CC.ParseServiceConfig(`{"loadBalancingConfig":[{"grpclb":{"childPolicy":[{"pick_first":{}}]}}]}`)} + r.UpdateState(grpclbstate.Set(rs, &grpclbstate.State{BalancerAddresses: []resolver.Address{{Addr: tss.lbAddr, ServerName: lbServerName}}})) + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() result = "" for i := 0; i < 1000; i++ { if _, err := testC.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(true), grpc.Peer(&p)); err != nil { @@ -1194,19 +1188,12 @@ func (s) TestGRPCLBPickFirst(t *testing.T) { } // Switch sub policy to roundrobin. - grpclbServiceConfigEmpty := r.CC.ParseServiceConfig(`{}`) - if grpclbServiceConfigEmpty.Err != nil { - t.Fatalf("Error parsing config %q: %v", `{}`, grpclbServiceConfigEmpty.Err) - } - - r.UpdateState(resolver.State{ - Addresses: []resolver.Address{{ + rs = grpclbstate.Set(resolver.State{ServiceConfig: r.CC.ParseServiceConfig(grpclbConfig)}, + &grpclbstate.State{BalancerAddresses: []resolver.Address{{ Addr: tss.lbAddr, - Type: resolver.GRPCLB, ServerName: lbServerName, - }}, - ServiceConfig: grpclbServiceConfigEmpty, - }) + }}}) + r.UpdateState(rs) result = "" for i := 0; i < 1000; i++ { @@ -1232,6 +1219,142 @@ func (s) TestGRPCLBPickFirst(t *testing.T) { } } +func (s) TestGRPCLBBackendConnectionErrorPropagation(t *testing.T) { + r := manual.NewBuilderWithScheme("whatever") + + // Start up an LB which will tells the client to fall back right away. + tss, cleanup, err := startBackendsAndRemoteLoadBalancer(0, "", nil) + if err != nil { + t.Fatalf("failed to create new load balancer: %v", err) + } + defer cleanup() + + // Start a standalone backend, to be used during fallback. The creds + // are intentionally misconfigured in order to simulate failure of a + // security handshake. + beLis, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatalf("Failed to listen %v", err) + } + defer beLis.Close() + standaloneBEs := startBackends("arbitrary.invalid.name", true, beLis) + defer stopBackends(standaloneBEs) + + cc, err := grpc.Dial(r.Scheme()+":///"+beServerName, + grpc.WithResolvers(r), + grpc.WithTransportCredentials(&serverNameCheckCreds{}), + grpc.WithContextDialer(fakeNameDialer)) + if err != nil { + t.Fatalf("Failed to dial to the backend %v", err) + } + defer cc.Close() + testC := testpb.NewTestServiceClient(cc) + + rs := resolver.State{ + Addresses: []resolver.Address{{Addr: beLis.Addr().String()}}, + ServiceConfig: r.CC.ParseServiceConfig(grpclbConfig), + } + rs = grpclbstate.Set(rs, &grpclbstate.State{BalancerAddresses: []resolver.Address{{Addr: tss.lbAddr, ServerName: lbServerName}}}) + r.UpdateState(rs) + + // If https://github.com/grpc/grpc-go/blob/65cabd74d8e18d7347fecd414fa8d83a00035f5f/balancer/grpclb/grpclb_test.go#L103 + // changes, then expectedErrMsg may need to be updated. + const expectedErrMsg = "received unexpected server name" + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + var wg sync.WaitGroup + wg.Add(1) + go func() { + tss.ls.fallbackNow() + wg.Done() + }() + if _, err := testC.EmptyCall(ctx, &testpb.Empty{}); err == nil || !strings.Contains(err.Error(), expectedErrMsg) { + t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, rpc error containing substring: %q", testC, err, expectedErrMsg) + } + wg.Wait() +} + +func (s) TestGRPCLBWithTargetNameFieldInConfig(t *testing.T) { + r := manual.NewBuilderWithScheme("whatever") + + // Start a remote balancer and a backend. Push the backend address to the + // remote balancer. + tss, cleanup, err := startBackendsAndRemoteLoadBalancer(1, "", nil) + if err != nil { + t.Fatalf("failed to create new load balancer: %v", err) + } + defer cleanup() + sl := &lbpb.ServerList{ + Servers: []*lbpb.Server{ + { + IpAddress: tss.beIPs[0], + Port: int32(tss.bePorts[0]), + LoadBalanceToken: lbToken, + }, + }, + } + tss.ls.sls <- sl + + cc, err := grpc.Dial(r.Scheme()+":///"+beServerName, + grpc.WithResolvers(r), + grpc.WithTransportCredentials(&serverNameCheckCreds{}), + grpc.WithContextDialer(fakeNameDialer), + grpc.WithUserAgent(testUserAgent)) + if err != nil { + t.Fatalf("Failed to dial to the backend %v", err) + } + defer cc.Close() + testC := testpb.NewTestServiceClient(cc) + + // Push a resolver update with grpclb configuration which does not contain the + // target_name field. Our fake remote balancer is configured to always + // expect `beServerName` as the server name in the initial request. + rs := grpclbstate.Set(resolver.State{ServiceConfig: r.CC.ParseServiceConfig(grpclbConfig)}, + &grpclbstate.State{BalancerAddresses: []resolver.Address{{ + Addr: tss.lbAddr, + ServerName: lbServerName, + }}}) + r.UpdateState(rs) + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + select { + case <-ctx.Done(): + t.Fatalf("timeout when waiting for BalanceLoad RPC to be called on the remote balancer") + case <-tss.ls.balanceLoadCh: + } + if _, err := testC.EmptyCall(ctx, &testpb.Empty{}); err != nil { + t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, ", testC, err) + } + + // When the value of target_field changes, grpclb will recreate the stream + // to the remote balancer. So, we need to update the fake remote balancer to + // expect a new server name in the initial request. + const newServerName = "new-server-name" + tss.ls.updateServerName(newServerName) + tss.ls.sls <- sl + + // Push the resolver update with target_field changed. + // Push a resolver update with grpclb configuration containing the + // target_name field. Our fake remote balancer has been updated above to expect the newServerName in the initial request. + lbCfg := fmt.Sprintf(`{"loadBalancingConfig": [{"grpclb": {"targetName": "%s"}}]}`, newServerName) + rs = grpclbstate.Set(resolver.State{ServiceConfig: r.CC.ParseServiceConfig(lbCfg)}, + &grpclbstate.State{BalancerAddresses: []resolver.Address{{ + Addr: tss.lbAddr, + ServerName: lbServerName, + }}}) + r.UpdateState(rs) + select { + case <-ctx.Done(): + t.Fatalf("timeout when waiting for BalanceLoad RPC to be called on the remote balancer") + case <-tss.ls.balanceLoadCh: + } + + if _, err := testC.EmptyCall(ctx, &testpb.Empty{}); err != nil { + t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, ", testC, err) + } +} + type failPreRPCCred struct{} func (failPreRPCCred) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) { @@ -1255,7 +1378,7 @@ func checkStats(stats, expected *rpcStats) error { func runAndCheckStats(t *testing.T, drop bool, statsChan chan *lbpb.ClientStats, runRPCs func(*grpc.ClientConn), statsWant *rpcStats) error { r := manual.NewBuilderWithScheme("whatever") - tss, cleanup, err := newLoadBalancer(1, "", statsChan) + tss, cleanup, err := startBackendsAndRemoteLoadBalancer(1, "", statsChan) if err != nil { t.Fatalf("failed to create new load balancer: %v", err) } diff --git a/balancer/grpclb/grpclb_test_util_test.go b/balancer/grpclb/grpclb_test_util_test.go index 5d3e6ba7fed..c143e961754 100644 --- a/balancer/grpclb/grpclb_test_util_test.go +++ b/balancer/grpclb/grpclb_test_util_test.go @@ -48,19 +48,20 @@ func newRestartableListener(l net.Listener) *restartableListener { } } -func (l *restartableListener) Accept() (conn net.Conn, err error) { - conn, err = l.Listener.Accept() - if err == nil { - l.mu.Lock() - if l.closed { - conn.Close() - l.mu.Unlock() - return nil, &tempError{} - } - l.conns = append(l.conns, conn) - l.mu.Unlock() +func (l *restartableListener) Accept() (net.Conn, error) { + conn, err := l.Listener.Accept() + if err != nil { + return nil, err } - return + + l.mu.Lock() + defer l.mu.Unlock() + if l.closed { + conn.Close() + return nil, &tempError{} + } + l.conns = append(l.conns, conn) + return conn, nil } func (l *restartableListener) Close() error { diff --git a/balancer/rls/internal/balancer.go b/balancer/rls/internal/balancer.go index 7af97b76faf..b23783bf9da 100644 --- a/balancer/rls/internal/balancer.go +++ b/balancer/rls/internal/balancer.go @@ -129,6 +129,10 @@ func (lb *rlsBalancer) Close() { } } +func (lb *rlsBalancer) ExitIdle() { + // TODO: are we 100% sure this should be a nop? +} + // updateControlChannel updates the RLS client if required. // Caller must hold lb.mu. func (lb *rlsBalancer) updateControlChannel(newCfg *lbConfig) { diff --git a/balancer/rls/internal/config.go b/balancer/rls/internal/config.go index a3deb8906c9..b27a2970f08 100644 --- a/balancer/rls/internal/config.go +++ b/balancer/rls/internal/config.go @@ -22,15 +22,16 @@ import ( "bytes" "encoding/json" "fmt" + "net/url" "time" "github.com/golang/protobuf/jsonpb" "github.com/golang/protobuf/ptypes" durationpb "github.com/golang/protobuf/ptypes/duration" + "google.golang.org/grpc/balancer" "google.golang.org/grpc/balancer/rls/internal/keys" rlspb "google.golang.org/grpc/balancer/rls/internal/proto/grpc_lookup_v1" - "google.golang.org/grpc/internal/grpcutil" "google.golang.org/grpc/resolver" "google.golang.org/grpc/serviceconfig" ) @@ -62,9 +63,10 @@ type lbConfig struct { staleAge time.Duration cacheSizeBytes int64 defaultTarget string - cpName string - cpTargetField string - cpConfig map[string]json.RawMessage + + childPolicyName string + childPolicyConfig map[string]json.RawMessage + childPolicyTargetField string } func (lbCfg *lbConfig) Equal(other *lbConfig) bool { @@ -75,21 +77,21 @@ func (lbCfg *lbConfig) Equal(other *lbConfig) bool { lbCfg.staleAge == other.staleAge && lbCfg.cacheSizeBytes == other.cacheSizeBytes && lbCfg.defaultTarget == other.defaultTarget && - lbCfg.cpName == other.cpName && - lbCfg.cpTargetField == other.cpTargetField && - cpConfigEqual(lbCfg.cpConfig, other.cpConfig) + lbCfg.childPolicyName == other.childPolicyName && + lbCfg.childPolicyTargetField == other.childPolicyTargetField && + childPolicyConfigEqual(lbCfg.childPolicyConfig, other.childPolicyConfig) } -func cpConfigEqual(am, bm map[string]json.RawMessage) bool { - if (bm == nil) != (am == nil) { +func childPolicyConfigEqual(a, b map[string]json.RawMessage) bool { + if (b == nil) != (a == nil) { return false } - if len(bm) != len(am) { + if len(b) != len(a) { return false } - for k, jsonA := range am { - jsonB, ok := bm[k] + for k, jsonA := range a { + jsonB, ok := b[k] if !ok { return false } @@ -100,71 +102,18 @@ func cpConfigEqual(am, bm map[string]json.RawMessage) bool { return true } -// This struct resembles the JSON respresentation of the loadBalancing config +// This struct resembles the JSON representation of the loadBalancing config // and makes it easier to unmarshal. type lbConfigJSON struct { RouteLookupConfig json.RawMessage - ChildPolicy []*loadBalancingConfig + ChildPolicy []map[string]json.RawMessage ChildPolicyConfigTargetFieldName string } -// loadBalancingConfig represents a single load balancing config, -// stored in JSON format. -// -// TODO(easwars): This code seems to be repeated in a few places -// (service_config.go and in the xds code as well). Refactor and re-use. -type loadBalancingConfig struct { - Name string - Config json.RawMessage -} - -// MarshalJSON returns a JSON encoding of l. -func (l *loadBalancingConfig) MarshalJSON() ([]byte, error) { - return nil, fmt.Errorf("rls: loadBalancingConfig.MarshalJSON() is unimplemented") -} - -// UnmarshalJSON parses the JSON-encoded byte slice in data and stores it in l. -func (l *loadBalancingConfig) UnmarshalJSON(data []byte) error { - var cfg map[string]json.RawMessage - if err := json.Unmarshal(data, &cfg); err != nil { - return err - } - for name, config := range cfg { - l.Name = name - l.Config = config - } - return nil -} - // ParseConfig parses and validates the JSON representation of the service // config and returns the loadBalancingConfig to be used by the RLS LB policy. // // Helps implement the balancer.ConfigParser interface. -// -// The following validation checks are performed: -// * routeLookupConfig: -// ** grpc_keybuilders field: -// - must have at least one entry -// - must not have two entries with the same Name -// - must not have any entry with a Name with the service field unset or -// empty -// - must not have any entries without a Name -// - must not have a headers entry that has required_match set -// - must not have two headers entries with the same key within one entry -// ** lookup_service field: -// - must be set and non-empty and must parse as a target URI -// ** max_age field: -// - if not specified or is greater than maxMaxAge, it will be reset to -// maxMaxAge -// ** stale_age field: -// - if the value is greater than or equal to max_age, it is ignored -// - if set, then max_age must also be set -// ** valid_targets field: -// - will be ignored -// ** cache_size_bytes field: -// - must be greater than zero -// - TODO(easwars): Define a minimum value for this field, to be used when -// left unspecified // * childPolicy field: // - must find a valid child policy with a valid config (the child policy must // be able to parse the provided config successfully when we pass it a dummy @@ -178,20 +127,58 @@ func (*rlsBB) ParseConfig(c json.RawMessage) (serviceconfig.LoadBalancingConfig, return nil, fmt.Errorf("rls: json unmarshal failed for service config {%+v}: %v", string(c), err) } + // Unmarshal and validate contents of the RLS proto. m := jsonpb.Unmarshaler{AllowUnknownFields: true} rlsProto := &rlspb.RouteLookupConfig{} if err := m.Unmarshal(bytes.NewReader(cfgJSON.RouteLookupConfig), rlsProto); err != nil { return nil, fmt.Errorf("rls: bad RouteLookupConfig proto {%+v}: %v", string(cfgJSON.RouteLookupConfig), err) } + lbCfg, err := parseRLSProto(rlsProto) + if err != nil { + return nil, err + } - var childPolicy *loadBalancingConfig - for _, lbcfg := range cfgJSON.ChildPolicy { - if balancer.Get(lbcfg.Name) != nil { - childPolicy = lbcfg - break - } + // Unmarshal and validate child policy configs. + if cfgJSON.ChildPolicyConfigTargetFieldName == "" { + return nil, fmt.Errorf("rls: childPolicyConfigTargetFieldName field is not set in service config {%+v}", string(c)) } + name, config, err := parseChildPolicyConfigs(cfgJSON.ChildPolicy, cfgJSON.ChildPolicyConfigTargetFieldName) + if err != nil { + return nil, err + } + lbCfg.childPolicyName = name + lbCfg.childPolicyConfig = config + lbCfg.childPolicyTargetField = cfgJSON.ChildPolicyConfigTargetFieldName + return lbCfg, nil +} +// parseRLSProto fetches relevant information from the RouteLookupConfig proto +// and validates the values in the process. +// +// The following validation checks are performed: +// ** grpc_keybuilders field: +// - must have at least one entry +// - must not have two entries with the same Name +// - must not have any entry with a Name with the service field unset or +// empty +// - must not have any entries without a Name +// - must not have a headers entry that has required_match set +// - must not have two headers entries with the same key within one entry +// ** lookup_service field: +// - must be set and non-empty and must parse as a target URI +// ** max_age field: +// - if not specified or is greater than maxMaxAge, it will be reset to +// maxMaxAge +// ** stale_age field: +// - if the value is greater than or equal to max_age, it is ignored +// - if set, then max_age must also be set +// ** valid_targets field: +// - will be ignored +// ** cache_size_bytes field: +// - must be greater than zero +// - TODO(easwars): Define a minimum value for this field, to be used when +// left unspecified +func parseRLSProto(rlsProto *rlspb.RouteLookupConfig) (*lbConfig, error) { kbMap, err := keys.MakeBuilderMap(rlsProto) if err != nil { return nil, err @@ -199,64 +186,54 @@ func (*rlsBB) ParseConfig(c json.RawMessage) (serviceconfig.LoadBalancingConfig, lookupService := rlsProto.GetLookupService() if lookupService == "" { - return nil, fmt.Errorf("rls: empty lookup_service in service config {%+v}", string(c)) + return nil, fmt.Errorf("rls: empty lookup_service in route lookup config {%+v}", rlsProto) + } + parsedTarget, err := url.Parse(lookupService) + if err != nil { + // If the first attempt failed because of a missing scheme, try again + // with the default scheme. + parsedTarget, err = url.Parse(resolver.GetDefaultScheme() + ":///" + lookupService) + if err != nil { + return nil, fmt.Errorf("rls: invalid target URI in lookup_service {%s}", lookupService) + } } - parsedTarget := grpcutil.ParseTarget(lookupService, false) if parsedTarget.Scheme == "" { parsedTarget.Scheme = resolver.GetDefaultScheme() } if resolver.Get(parsedTarget.Scheme) == nil { - return nil, fmt.Errorf("rls: invalid target URI in lookup_service {%s}", lookupService) + return nil, fmt.Errorf("rls: unregistered scheme in lookup_service {%s}", lookupService) } lookupServiceTimeout, err := convertDuration(rlsProto.GetLookupServiceTimeout()) if err != nil { - return nil, fmt.Errorf("rls: failed to parse lookup_service_timeout in service config {%+v}: %v", string(c), err) + return nil, fmt.Errorf("rls: failed to parse lookup_service_timeout in route lookup config {%+v}: %v", rlsProto, err) } if lookupServiceTimeout == 0 { lookupServiceTimeout = defaultLookupServiceTimeout } maxAge, err := convertDuration(rlsProto.GetMaxAge()) if err != nil { - return nil, fmt.Errorf("rls: failed to parse max_age in service config {%+v}: %v", string(c), err) + return nil, fmt.Errorf("rls: failed to parse max_age in route lookup config {%+v}: %v", rlsProto, err) } staleAge, err := convertDuration(rlsProto.GetStaleAge()) if err != nil { - return nil, fmt.Errorf("rls: failed to parse staleAge in service config {%+v}: %v", string(c), err) + return nil, fmt.Errorf("rls: failed to parse staleAge in route lookup config {%+v}: %v", rlsProto, err) } if staleAge != 0 && maxAge == 0 { - return nil, fmt.Errorf("rls: stale_age is set, but max_age is not in service config {%+v}", string(c)) + return nil, fmt.Errorf("rls: stale_age is set, but max_age is not in route lookup config {%+v}", rlsProto) } if staleAge >= maxAge { logger.Info("rls: stale_age {%v} is greater than max_age {%v}, ignoring it", staleAge, maxAge) staleAge = 0 } if maxAge == 0 || maxAge > maxMaxAge { - logger.Infof("rls: max_age in service config is %v, using %v", maxAge, maxMaxAge) + logger.Infof("rls: max_age in route lookup config is %v, using %v", maxAge, maxMaxAge) maxAge = maxMaxAge } cacheSizeBytes := rlsProto.GetCacheSizeBytes() if cacheSizeBytes <= 0 { - return nil, fmt.Errorf("rls: cache_size_bytes must be greater than 0 in service config {%+v}", string(c)) + return nil, fmt.Errorf("rls: cache_size_bytes must be greater than 0 in route lookup config {%+v}", rlsProto) } - if childPolicy == nil { - return nil, fmt.Errorf("rls: childPolicy is invalid in service config {%+v}", string(c)) - } - if cfgJSON.ChildPolicyConfigTargetFieldName == "" { - return nil, fmt.Errorf("rls: childPolicyConfigTargetFieldName field is not set in service config {%+v}", string(c)) - } - // TODO(easwars): When we start instantiating the child policy from the - // parent RLS LB policy, we could make this function a method on the - // lbConfig object and share the code. We would be parsing the child policy - // config again during that time. The only difference betweeen now and then - // would be that we would be using real targetField name instead of the - // dummy. So, we could make the targetName field a parameter to this - // function during the refactor. - cpCfg, err := validateChildPolicyConfig(childPolicy, cfgJSON.ChildPolicyConfigTargetFieldName) - if err != nil { - return nil, err - } - return &lbConfig{ kbMap: kbMap, lookupService: lookupService, @@ -265,57 +242,50 @@ func (*rlsBB) ParseConfig(c json.RawMessage) (serviceconfig.LoadBalancingConfig, staleAge: staleAge, cacheSizeBytes: cacheSizeBytes, defaultTarget: rlsProto.GetDefaultTarget(), - // TODO(easwars): Once we refactor validateChildPolicyConfig and make - // it a method on the lbConfig object, we could directly store the - // balancer.Builder and/or balancer.ConfigParser here instead of the - // Name. That would mean that we would have to create the lbConfig - // object here first before validating the childPolicy config, but - // that's a minor detail. - cpName: childPolicy.Name, - cpTargetField: cfgJSON.ChildPolicyConfigTargetFieldName, - cpConfig: cpCfg, }, nil } -// validateChildPolicyConfig validates the child policy config received in the -// service config. This makes it possible for us to reject service configs -// which contain invalid child policy configs which we know will fail for sure. -// -// It does the following: -// * Unmarshals the provided child policy config into a map of string to -// json.RawMessage. This allows us to add an entry to the map corresponding -// to the targetFieldName that we received in the service config. -// * Marshals the map back into JSON, finds the config parser associated with -// the child policy and asks it to validate the config. -// * If the validation succeeded, removes the dummy entry from the map and -// returns it. If any of the above steps failed, it returns an error. -func validateChildPolicyConfig(cp *loadBalancingConfig, cpTargetField string) (map[string]json.RawMessage, error) { - var childConfig map[string]json.RawMessage - if err := json.Unmarshal(cp.Config, &childConfig); err != nil { - return nil, fmt.Errorf("rls: json unmarshal failed for child policy config {%+v}: %v", cp.Config, err) - } - childConfig[cpTargetField], _ = json.Marshal(dummyChildPolicyTarget) +// parseChildPolicyConfigs iterates through the list of child policies and picks +// the first registered policy and validates its config. +func parseChildPolicyConfigs(childPolicies []map[string]json.RawMessage, targetFieldName string) (string, map[string]json.RawMessage, error) { + for i, config := range childPolicies { + if len(config) != 1 { + return "", nil, fmt.Errorf("rls: invalid childPolicy: entry %v does not contain exactly 1 policy/config pair: %q", i, config) + } - jsonCfg, err := json.Marshal(childConfig) - if err != nil { - return nil, fmt.Errorf("rls: json marshal failed for child policy config {%+v}: %v", childConfig, err) - } - builder := balancer.Get(cp.Name) - if builder == nil { - // This should never happen since we already made sure that the child - // policy name mentioned in the service config is a valid one. - return nil, fmt.Errorf("rls: balancer builder not found for child_policy %v", cp.Name) - } - parser, ok := builder.(balancer.ConfigParser) - if !ok { - return nil, fmt.Errorf("rls: balancer builder for child_policy does not implement balancer.ConfigParser: %v", cp.Name) - } - _, err = parser.ParseConfig(jsonCfg) - if err != nil { - return nil, fmt.Errorf("rls: childPolicy config validation failed: %v", err) + var name string + var rawCfg json.RawMessage + for name, rawCfg = range config { + } + builder := balancer.Get(name) + if builder == nil { + continue + } + parser, ok := builder.(balancer.ConfigParser) + if !ok { + return "", nil, fmt.Errorf("rls: childPolicy %q with config %q does not support config parsing", name, string(rawCfg)) + } + + // To validate child policy configs we do the following: + // - unmarshal the raw JSON bytes of the child policy config into a map + // - add an entry with key set to `target_field_name` and a dummy value + // - marshal the map back to JSON and parse the config using the parser + // retrieved previously + var childConfig map[string]json.RawMessage + if err := json.Unmarshal(rawCfg, &childConfig); err != nil { + return "", nil, fmt.Errorf("rls: json unmarshal failed for child policy config %q: %v", string(rawCfg), err) + } + childConfig[targetFieldName], _ = json.Marshal(dummyChildPolicyTarget) + jsonCfg, err := json.Marshal(childConfig) + if err != nil { + return "", nil, fmt.Errorf("rls: json marshal failed for child policy config {%+v}: %v", childConfig, err) + } + if _, err := parser.ParseConfig(jsonCfg); err != nil { + return "", nil, fmt.Errorf("rls: childPolicy config validation failed: %v", err) + } + return name, childConfig, nil } - delete(childConfig, cpTargetField) - return childConfig, nil + return "", nil, fmt.Errorf("rls: invalid childPolicy config: no supported policies found in %+v", childPolicies) } func convertDuration(d *durationpb.Duration) (time.Duration, error) { diff --git a/balancer/rls/internal/config_test.go b/balancer/rls/internal/config_test.go index 1efd054512b..41d330c604e 100644 --- a/balancer/rls/internal/config_test.go +++ b/balancer/rls/internal/config_test.go @@ -25,8 +25,6 @@ import ( "testing" "time" - "github.com/google/go-cmp/cmp" - "google.golang.org/grpc/balancer" _ "google.golang.org/grpc/balancer/grpclb" // grpclb for config parsing. _ "google.golang.org/grpc/internal/resolver/passthrough" // passthrough resolver. @@ -58,12 +56,13 @@ func testEqual(a, b *lbConfig) bool { a.staleAge == b.staleAge && a.cacheSizeBytes == b.cacheSizeBytes && a.defaultTarget == b.defaultTarget && - a.cpName == b.cpName && - a.cpTargetField == b.cpTargetField && - cmp.Equal(a.cpConfig, b.cpConfig) + a.childPolicyName == b.childPolicyName && + a.childPolicyTargetField == b.childPolicyTargetField && + childPolicyConfigEqual(a.childPolicyConfig, b.childPolicyConfig) } func TestParseConfig(t *testing.T) { + childPolicyTargetFieldVal, _ := json.Marshal(dummyChildPolicyTarget) tests := []struct { desc string input []byte @@ -85,7 +84,7 @@ func TestParseConfig(t *testing.T) { "names": [{"service": "service", "method": "method"}], "headers": [{"key": "k1", "names": ["v1"]}] }], - "lookupService": "passthrough:///target", + "lookupService": ":///target", "maxAge" : "500s", "staleAge": "600s", "cacheSizeBytes": 1000, @@ -99,15 +98,18 @@ func TestParseConfig(t *testing.T) { "childPolicyConfigTargetFieldName": "service_name" }`), wantCfg: &lbConfig{ - lookupService: "passthrough:///target", - lookupServiceTimeout: 10 * time.Second, // This is the default value. - maxAge: 5 * time.Minute, // This is max maxAge. - staleAge: time.Duration(0), // StaleAge is ignore because it was higher than maxAge. - cacheSizeBytes: 1000, - defaultTarget: "passthrough:///default", - cpName: "grpclb", - cpTargetField: "service_name", - cpConfig: map[string]json.RawMessage{"childPolicy": json.RawMessage(`[{"pickfirst": {}}]`)}, + lookupService: ":///target", + lookupServiceTimeout: 10 * time.Second, // This is the default value. + maxAge: 5 * time.Minute, // This is max maxAge. + staleAge: time.Duration(0), // StaleAge is ignore because it was higher than maxAge. + cacheSizeBytes: 1000, + defaultTarget: "passthrough:///default", + childPolicyName: "grpclb", + childPolicyTargetField: "service_name", + childPolicyConfig: map[string]json.RawMessage{ + "childPolicy": json.RawMessage(`[{"pickfirst": {}}]`), + "service_name": json.RawMessage(childPolicyTargetFieldVal), + }, }, }, { @@ -118,7 +120,7 @@ func TestParseConfig(t *testing.T) { "names": [{"service": "service", "method": "method"}], "headers": [{"key": "k1", "names": ["v1"]}] }], - "lookupService": "passthrough:///target", + "lookupService": "target", "lookupServiceTimeout" : "100s", "maxAge": "60s", "staleAge" : "50s", @@ -129,15 +131,18 @@ func TestParseConfig(t *testing.T) { "childPolicyConfigTargetFieldName": "service_name" }`), wantCfg: &lbConfig{ - lookupService: "passthrough:///target", - lookupServiceTimeout: 100 * time.Second, - maxAge: 60 * time.Second, - staleAge: 50 * time.Second, - cacheSizeBytes: 1000, - defaultTarget: "passthrough:///default", - cpName: "grpclb", - cpTargetField: "service_name", - cpConfig: map[string]json.RawMessage{"childPolicy": json.RawMessage(`[{"pickfirst": {}}]`)}, + lookupService: "target", + lookupServiceTimeout: 100 * time.Second, + maxAge: 60 * time.Second, + staleAge: 50 * time.Second, + cacheSizeBytes: 1000, + defaultTarget: "passthrough:///default", + childPolicyName: "grpclb", + childPolicyTargetField: "service_name", + childPolicyConfig: map[string]json.RawMessage{ + "childPolicy": json.RawMessage(`[{"pickfirst": {}}]`), + "service_name": json.RawMessage(childPolicyTargetFieldVal), + }, }, }, } @@ -191,10 +196,10 @@ func TestParseConfigErrors(t *testing.T) { }] } }`), - wantErr: "rls: empty lookup_service in service config", + wantErr: "rls: empty lookup_service in route lookup config", }, { - desc: "invalid lookup service URI", + desc: "unregistered scheme in lookup service URI", input: []byte(`{ "routeLookupConfig": { "grpcKeybuilders": [{ @@ -204,7 +209,7 @@ func TestParseConfigErrors(t *testing.T) { "lookupService": "badScheme:///target" } }`), - wantErr: "rls: invalid target URI in lookup_service", + wantErr: "rls: unregistered scheme in lookup_service", }, { desc: "invalid lookup service timeout", @@ -264,7 +269,7 @@ func TestParseConfigErrors(t *testing.T) { "staleAge" : "10s" } }`), - wantErr: "rls: stale_age is set, but max_age is not in service config", + wantErr: "rls: stale_age is set, but max_age is not in route lookup config", }, { desc: "invalid cache size", @@ -280,7 +285,7 @@ func TestParseConfigErrors(t *testing.T) { "staleAge" : "25s" } }`), - wantErr: "rls: cache_size_bytes must be greater than 0 in service config", + wantErr: "rls: cache_size_bytes must be greater than 0 in route lookup config", }, { desc: "no child policy", @@ -296,9 +301,10 @@ func TestParseConfigErrors(t *testing.T) { "staleAge" : "25s", "cacheSizeBytes": 1000, "defaultTarget": "passthrough:///default" - } + }, + "childPolicyConfigTargetFieldName": "service_name" }`), - wantErr: "rls: childPolicy is invalid in service config", + wantErr: "rls: invalid childPolicy config: no supported policies found", }, { desc: "no known child policy", @@ -318,9 +324,35 @@ func TestParseConfigErrors(t *testing.T) { "childPolicy": [ {"cds_experimental": {"Cluster": "my-fav-cluster"}}, {"unknown-policy": {"unknown-field": "unknown-value"}} - ] + ], + "childPolicyConfigTargetFieldName": "service_name" + }`), + wantErr: "rls: invalid childPolicy config: no supported policies found", + }, + { + desc: "invalid child policy config - more than one entry in map", + input: []byte(`{ + "routeLookupConfig": { + "grpcKeybuilders": [{ + "names": [{"service": "service", "method": "method"}], + "headers": [{"key": "k1", "names": ["v1"]}] + }], + "lookupService": "passthrough:///target", + "lookupServiceTimeout" : "10s", + "maxAge": "30s", + "staleAge" : "25s", + "cacheSizeBytes": 1000, + "defaultTarget": "passthrough:///default" + }, + "childPolicy": [ + { + "cds_experimental": {"Cluster": "my-fav-cluster"}, + "unknown-policy": {"unknown-field": "unknown-value"} + } + ], + "childPolicyConfigTargetFieldName": "service_name" }`), - wantErr: "rls: childPolicy is invalid in service config", + wantErr: "does not contain exactly 1 policy/config pair", }, { desc: "no childPolicyConfigTargetFieldName", @@ -381,60 +413,3 @@ func TestParseConfigErrors(t *testing.T) { }) } } - -func TestValidateChildPolicyConfig(t *testing.T) { - jsonCfg := json.RawMessage(`[{"round_robin" : {}}, {"pick_first" : {}}]`) - wantChildConfig := map[string]json.RawMessage{"childPolicy": jsonCfg} - cp := &loadBalancingConfig{ - Name: "grpclb", - Config: []byte(`{"childPolicy": [{"round_robin" : {}}, {"pick_first" : {}}]}`), - } - cpTargetField := "serviceName" - - gotChildConfig, err := validateChildPolicyConfig(cp, cpTargetField) - if err != nil || !cmp.Equal(gotChildConfig, wantChildConfig) { - t.Errorf("validateChildPolicyConfig(%v, %v) = {%v, %v}, want {%v, nil}", cp, cpTargetField, gotChildConfig, err, wantChildConfig) - } -} - -func TestValidateChildPolicyConfigErrors(t *testing.T) { - tests := []struct { - desc string - cp *loadBalancingConfig - wantErrPrefix string - }{ - { - desc: "unknown child policy", - cp: &loadBalancingConfig{ - Name: "unknown", - Config: []byte(`{}`), - }, - wantErrPrefix: "rls: balancer builder not found for child_policy", - }, - { - desc: "balancer builder does not implement ConfigParser", - cp: &loadBalancingConfig{ - Name: balancerWithoutConfigParserName, - Config: []byte(`{}`), - }, - wantErrPrefix: "rls: balancer builder for child_policy does not implement balancer.ConfigParser", - }, - { - desc: "child policy config parsing failure", - cp: &loadBalancingConfig{ - Name: "grpclb", - Config: []byte(`{"childPolicy": "not-an-array"}`), - }, - wantErrPrefix: "rls: childPolicy config validation failed", - }, - } - - for _, test := range tests { - t.Run(test.desc, func(t *testing.T) { - gotChildConfig, gotErr := validateChildPolicyConfig(test.cp, "") - if gotChildConfig != nil || !strings.HasPrefix(fmt.Sprint(gotErr), test.wantErrPrefix) { - t.Errorf("validateChildPolicyConfig(%v) = {%v, %v}, want {nil, %v}", test.cp, gotChildConfig, gotErr, test.wantErrPrefix) - } - }) - } -} diff --git a/balancer/rls/internal/keys/builder.go b/balancer/rls/internal/keys/builder.go index 5ce5a9da508..24767b405f0 100644 --- a/balancer/rls/internal/keys/builder.go +++ b/balancer/rls/internal/keys/builder.go @@ -218,7 +218,7 @@ func (b builder) keys(md metadata.MD) KeyMap { } func mapToString(kv map[string]string) string { - var keys []string + keys := make([]string, 0, len(kv)) for k := range kv { keys = append(keys, k) } diff --git a/balancer/rls/internal/proto/grpc_lookup_v1/rls.pb.go b/balancer/rls/internal/proto/grpc_lookup_v1/rls.pb.go index d48a1a6de84..9f063fd3d62 100644 --- a/balancer/rls/internal/proto/grpc_lookup_v1/rls.pb.go +++ b/balancer/rls/internal/proto/grpc_lookup_v1/rls.pb.go @@ -39,6 +39,56 @@ const ( // of the legacy proto package is being used. const _ = proto.ProtoPackageIsVersion4 +// Possible reasons for making a request. +type RouteLookupRequest_Reason int32 + +const ( + RouteLookupRequest_REASON_UNKNOWN RouteLookupRequest_Reason = 0 // Unused + RouteLookupRequest_REASON_MISS RouteLookupRequest_Reason = 1 // No data available in local cache + RouteLookupRequest_REASON_STALE RouteLookupRequest_Reason = 2 // Data in local cache is stale +) + +// Enum value maps for RouteLookupRequest_Reason. +var ( + RouteLookupRequest_Reason_name = map[int32]string{ + 0: "REASON_UNKNOWN", + 1: "REASON_MISS", + 2: "REASON_STALE", + } + RouteLookupRequest_Reason_value = map[string]int32{ + "REASON_UNKNOWN": 0, + "REASON_MISS": 1, + "REASON_STALE": 2, + } +) + +func (x RouteLookupRequest_Reason) Enum() *RouteLookupRequest_Reason { + p := new(RouteLookupRequest_Reason) + *p = x + return p +} + +func (x RouteLookupRequest_Reason) String() string { + return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x)) +} + +func (RouteLookupRequest_Reason) Descriptor() protoreflect.EnumDescriptor { + return file_grpc_lookup_v1_rls_proto_enumTypes[0].Descriptor() +} + +func (RouteLookupRequest_Reason) Type() protoreflect.EnumType { + return &file_grpc_lookup_v1_rls_proto_enumTypes[0] +} + +func (x RouteLookupRequest_Reason) Number() protoreflect.EnumNumber { + return protoreflect.EnumNumber(x) +} + +// Deprecated: Use RouteLookupRequest_Reason.Descriptor instead. +func (RouteLookupRequest_Reason) EnumDescriptor() ([]byte, []int) { + return file_grpc_lookup_v1_rls_proto_rawDescGZIP(), []int{0, 0} +} + type RouteLookupRequest struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache @@ -46,13 +96,23 @@ type RouteLookupRequest struct { // Full host name of the target server, e.g. firestore.googleapis.com. // Only set for gRPC requests; HTTP requests must use key_map explicitly. + // Deprecated in favor of setting key_map keys with GrpcKeyBuilder.extra_keys. + // + // Deprecated: Do not use. Server string `protobuf:"bytes,1,opt,name=server,proto3" json:"server,omitempty"` // Full path of the request, i.e. "/service/method". // Only set for gRPC requests; HTTP requests must use key_map explicitly. + // Deprecated in favor of setting key_map keys with GrpcKeyBuilder.extra_keys. + // + // Deprecated: Do not use. Path string `protobuf:"bytes,2,opt,name=path,proto3" json:"path,omitempty"` // Target type allows the client to specify what kind of target format it // would like from RLS to allow it to find the regional server, e.g. "grpc". TargetType string `protobuf:"bytes,3,opt,name=target_type,json=targetType,proto3" json:"target_type,omitempty"` + // Reason for making this request. + Reason RouteLookupRequest_Reason `protobuf:"varint,5,opt,name=reason,proto3,enum=grpc.lookup.v1.RouteLookupRequest_Reason" json:"reason,omitempty"` + // For REASON_STALE, the header_data from the stale response, if any. + StaleHeaderData string `protobuf:"bytes,6,opt,name=stale_header_data,json=staleHeaderData,proto3" json:"stale_header_data,omitempty"` // Map of key values extracted via key builders for the gRPC or HTTP request. KeyMap map[string]string `protobuf:"bytes,4,rep,name=key_map,json=keyMap,proto3" json:"key_map,omitempty" protobuf_key:"bytes,1,opt,name=key,proto3" protobuf_val:"bytes,2,opt,name=value,proto3"` } @@ -89,6 +149,7 @@ func (*RouteLookupRequest) Descriptor() ([]byte, []int) { return file_grpc_lookup_v1_rls_proto_rawDescGZIP(), []int{0} } +// Deprecated: Do not use. func (x *RouteLookupRequest) GetServer() string { if x != nil { return x.Server @@ -96,6 +157,7 @@ func (x *RouteLookupRequest) GetServer() string { return "" } +// Deprecated: Do not use. func (x *RouteLookupRequest) GetPath() string { if x != nil { return x.Path @@ -110,6 +172,20 @@ func (x *RouteLookupRequest) GetTargetType() string { return "" } +func (x *RouteLookupRequest) GetReason() RouteLookupRequest_Reason { + if x != nil { + return x.Reason + } + return RouteLookupRequest_REASON_UNKNOWN +} + +func (x *RouteLookupRequest) GetStaleHeaderData() string { + if x != nil { + return x.StaleHeaderData + } + return "" +} + func (x *RouteLookupRequest) GetKeyMap() map[string]string { if x != nil { return x.KeyMap @@ -183,40 +259,52 @@ var File_grpc_lookup_v1_rls_proto protoreflect.FileDescriptor var file_grpc_lookup_v1_rls_proto_rawDesc = []byte{ 0x0a, 0x18, 0x67, 0x72, 0x70, 0x63, 0x2f, 0x6c, 0x6f, 0x6f, 0x6b, 0x75, 0x70, 0x2f, 0x76, 0x31, 0x2f, 0x72, 0x6c, 0x73, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x0e, 0x67, 0x72, 0x70, 0x63, - 0x2e, 0x6c, 0x6f, 0x6f, 0x6b, 0x75, 0x70, 0x2e, 0x76, 0x31, 0x22, 0xe5, 0x01, 0x0a, 0x12, 0x52, + 0x2e, 0x6c, 0x6f, 0x6f, 0x6b, 0x75, 0x70, 0x2e, 0x76, 0x31, 0x22, 0x9d, 0x03, 0x0a, 0x12, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x4c, 0x6f, 0x6f, 0x6b, 0x75, 0x70, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, - 0x74, 0x12, 0x16, 0x0a, 0x06, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x18, 0x01, 0x20, 0x01, 0x28, - 0x09, 0x52, 0x06, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x12, 0x12, 0x0a, 0x04, 0x70, 0x61, 0x74, - 0x68, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x70, 0x61, 0x74, 0x68, 0x12, 0x1f, 0x0a, - 0x0b, 0x74, 0x61, 0x72, 0x67, 0x65, 0x74, 0x5f, 0x74, 0x79, 0x70, 0x65, 0x18, 0x03, 0x20, 0x01, - 0x28, 0x09, 0x52, 0x0a, 0x74, 0x61, 0x72, 0x67, 0x65, 0x74, 0x54, 0x79, 0x70, 0x65, 0x12, 0x47, - 0x0a, 0x07, 0x6b, 0x65, 0x79, 0x5f, 0x6d, 0x61, 0x70, 0x18, 0x04, 0x20, 0x03, 0x28, 0x0b, 0x32, - 0x2e, 0x2e, 0x67, 0x72, 0x70, 0x63, 0x2e, 0x6c, 0x6f, 0x6f, 0x6b, 0x75, 0x70, 0x2e, 0x76, 0x31, + 0x74, 0x12, 0x1a, 0x0a, 0x06, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x18, 0x01, 0x20, 0x01, 0x28, + 0x09, 0x42, 0x02, 0x18, 0x01, 0x52, 0x06, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x12, 0x16, 0x0a, + 0x04, 0x70, 0x61, 0x74, 0x68, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x42, 0x02, 0x18, 0x01, 0x52, + 0x04, 0x70, 0x61, 0x74, 0x68, 0x12, 0x1f, 0x0a, 0x0b, 0x74, 0x61, 0x72, 0x67, 0x65, 0x74, 0x5f, + 0x74, 0x79, 0x70, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a, 0x74, 0x61, 0x72, 0x67, + 0x65, 0x74, 0x54, 0x79, 0x70, 0x65, 0x12, 0x41, 0x0a, 0x06, 0x72, 0x65, 0x61, 0x73, 0x6f, 0x6e, + 0x18, 0x05, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x29, 0x2e, 0x67, 0x72, 0x70, 0x63, 0x2e, 0x6c, 0x6f, + 0x6f, 0x6b, 0x75, 0x70, 0x2e, 0x76, 0x31, 0x2e, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x4c, 0x6f, 0x6f, + 0x6b, 0x75, 0x70, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x2e, 0x52, 0x65, 0x61, 0x73, 0x6f, + 0x6e, 0x52, 0x06, 0x72, 0x65, 0x61, 0x73, 0x6f, 0x6e, 0x12, 0x2a, 0x0a, 0x11, 0x73, 0x74, 0x61, + 0x6c, 0x65, 0x5f, 0x68, 0x65, 0x61, 0x64, 0x65, 0x72, 0x5f, 0x64, 0x61, 0x74, 0x61, 0x18, 0x06, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x0f, 0x73, 0x74, 0x61, 0x6c, 0x65, 0x48, 0x65, 0x61, 0x64, 0x65, + 0x72, 0x44, 0x61, 0x74, 0x61, 0x12, 0x47, 0x0a, 0x07, 0x6b, 0x65, 0x79, 0x5f, 0x6d, 0x61, 0x70, + 0x18, 0x04, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x2e, 0x2e, 0x67, 0x72, 0x70, 0x63, 0x2e, 0x6c, 0x6f, + 0x6f, 0x6b, 0x75, 0x70, 0x2e, 0x76, 0x31, 0x2e, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x4c, 0x6f, 0x6f, + 0x6b, 0x75, 0x70, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x2e, 0x4b, 0x65, 0x79, 0x4d, 0x61, + 0x70, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x52, 0x06, 0x6b, 0x65, 0x79, 0x4d, 0x61, 0x70, 0x1a, 0x39, + 0x0a, 0x0b, 0x4b, 0x65, 0x79, 0x4d, 0x61, 0x70, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, + 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, + 0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, + 0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, 0x02, 0x38, 0x01, 0x22, 0x3f, 0x0a, 0x06, 0x52, 0x65, 0x61, + 0x73, 0x6f, 0x6e, 0x12, 0x12, 0x0a, 0x0e, 0x52, 0x45, 0x41, 0x53, 0x4f, 0x4e, 0x5f, 0x55, 0x4e, + 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, 0x00, 0x12, 0x0f, 0x0a, 0x0b, 0x52, 0x45, 0x41, 0x53, 0x4f, + 0x4e, 0x5f, 0x4d, 0x49, 0x53, 0x53, 0x10, 0x01, 0x12, 0x10, 0x0a, 0x0c, 0x52, 0x45, 0x41, 0x53, + 0x4f, 0x4e, 0x5f, 0x53, 0x54, 0x41, 0x4c, 0x45, 0x10, 0x02, 0x22, 0x5e, 0x0a, 0x13, 0x52, 0x6f, + 0x75, 0x74, 0x65, 0x4c, 0x6f, 0x6f, 0x6b, 0x75, 0x70, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, + 0x65, 0x12, 0x18, 0x0a, 0x07, 0x74, 0x61, 0x72, 0x67, 0x65, 0x74, 0x73, 0x18, 0x03, 0x20, 0x03, + 0x28, 0x09, 0x52, 0x07, 0x74, 0x61, 0x72, 0x67, 0x65, 0x74, 0x73, 0x12, 0x1f, 0x0a, 0x0b, 0x68, + 0x65, 0x61, 0x64, 0x65, 0x72, 0x5f, 0x64, 0x61, 0x74, 0x61, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x0a, 0x68, 0x65, 0x61, 0x64, 0x65, 0x72, 0x44, 0x61, 0x74, 0x61, 0x4a, 0x04, 0x08, 0x01, + 0x10, 0x02, 0x52, 0x06, 0x74, 0x61, 0x72, 0x67, 0x65, 0x74, 0x32, 0x6e, 0x0a, 0x12, 0x52, 0x6f, + 0x75, 0x74, 0x65, 0x4c, 0x6f, 0x6f, 0x6b, 0x75, 0x70, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, + 0x12, 0x58, 0x0a, 0x0b, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x4c, 0x6f, 0x6f, 0x6b, 0x75, 0x70, 0x12, + 0x22, 0x2e, 0x67, 0x72, 0x70, 0x63, 0x2e, 0x6c, 0x6f, 0x6f, 0x6b, 0x75, 0x70, 0x2e, 0x76, 0x31, 0x2e, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x4c, 0x6f, 0x6f, 0x6b, 0x75, 0x70, 0x52, 0x65, 0x71, 0x75, - 0x65, 0x73, 0x74, 0x2e, 0x4b, 0x65, 0x79, 0x4d, 0x61, 0x70, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x52, - 0x06, 0x6b, 0x65, 0x79, 0x4d, 0x61, 0x70, 0x1a, 0x39, 0x0a, 0x0b, 0x4b, 0x65, 0x79, 0x4d, 0x61, - 0x70, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, - 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, 0x02, - 0x38, 0x01, 0x22, 0x5e, 0x0a, 0x13, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x4c, 0x6f, 0x6f, 0x6b, 0x75, - 0x70, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x74, 0x61, 0x72, - 0x67, 0x65, 0x74, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x74, 0x61, 0x72, 0x67, - 0x65, 0x74, 0x73, 0x12, 0x1f, 0x0a, 0x0b, 0x68, 0x65, 0x61, 0x64, 0x65, 0x72, 0x5f, 0x64, 0x61, - 0x74, 0x61, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a, 0x68, 0x65, 0x61, 0x64, 0x65, 0x72, - 0x44, 0x61, 0x74, 0x61, 0x4a, 0x04, 0x08, 0x01, 0x10, 0x02, 0x52, 0x06, 0x74, 0x61, 0x72, 0x67, - 0x65, 0x74, 0x32, 0x6e, 0x0a, 0x12, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x4c, 0x6f, 0x6f, 0x6b, 0x75, - 0x70, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x58, 0x0a, 0x0b, 0x52, 0x6f, 0x75, 0x74, - 0x65, 0x4c, 0x6f, 0x6f, 0x6b, 0x75, 0x70, 0x12, 0x22, 0x2e, 0x67, 0x72, 0x70, 0x63, 0x2e, 0x6c, - 0x6f, 0x6f, 0x6b, 0x75, 0x70, 0x2e, 0x76, 0x31, 0x2e, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x4c, 0x6f, - 0x6f, 0x6b, 0x75, 0x70, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x23, 0x2e, 0x67, 0x72, - 0x70, 0x63, 0x2e, 0x6c, 0x6f, 0x6f, 0x6b, 0x75, 0x70, 0x2e, 0x76, 0x31, 0x2e, 0x52, 0x6f, 0x75, - 0x74, 0x65, 0x4c, 0x6f, 0x6f, 0x6b, 0x75, 0x70, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, - 0x22, 0x00, 0x42, 0x4d, 0x0a, 0x11, 0x69, 0x6f, 0x2e, 0x67, 0x72, 0x70, 0x63, 0x2e, 0x6c, 0x6f, - 0x6f, 0x6b, 0x75, 0x70, 0x2e, 0x76, 0x31, 0x42, 0x08, 0x52, 0x6c, 0x73, 0x50, 0x72, 0x6f, 0x74, - 0x6f, 0x50, 0x01, 0x5a, 0x2c, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x67, 0x6f, 0x6c, 0x61, - 0x6e, 0x67, 0x2e, 0x6f, 0x72, 0x67, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x2f, 0x6c, 0x6f, 0x6f, 0x6b, - 0x75, 0x70, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x5f, 0x6c, 0x6f, 0x6f, 0x6b, 0x75, 0x70, 0x5f, 0x76, - 0x31, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x65, 0x73, 0x74, 0x1a, 0x23, 0x2e, 0x67, 0x72, 0x70, 0x63, 0x2e, 0x6c, 0x6f, 0x6f, 0x6b, 0x75, + 0x70, 0x2e, 0x76, 0x31, 0x2e, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x4c, 0x6f, 0x6f, 0x6b, 0x75, 0x70, + 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x42, 0x4d, 0x0a, 0x11, 0x69, 0x6f, + 0x2e, 0x67, 0x72, 0x70, 0x63, 0x2e, 0x6c, 0x6f, 0x6f, 0x6b, 0x75, 0x70, 0x2e, 0x76, 0x31, 0x42, + 0x08, 0x52, 0x6c, 0x73, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x50, 0x01, 0x5a, 0x2c, 0x67, 0x6f, 0x6f, + 0x67, 0x6c, 0x65, 0x2e, 0x67, 0x6f, 0x6c, 0x61, 0x6e, 0x67, 0x2e, 0x6f, 0x72, 0x67, 0x2f, 0x67, + 0x72, 0x70, 0x63, 0x2f, 0x6c, 0x6f, 0x6f, 0x6b, 0x75, 0x70, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x5f, + 0x6c, 0x6f, 0x6f, 0x6b, 0x75, 0x70, 0x5f, 0x76, 0x31, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, + 0x33, } var ( @@ -231,21 +319,24 @@ func file_grpc_lookup_v1_rls_proto_rawDescGZIP() []byte { return file_grpc_lookup_v1_rls_proto_rawDescData } +var file_grpc_lookup_v1_rls_proto_enumTypes = make([]protoimpl.EnumInfo, 1) var file_grpc_lookup_v1_rls_proto_msgTypes = make([]protoimpl.MessageInfo, 3) var file_grpc_lookup_v1_rls_proto_goTypes = []interface{}{ - (*RouteLookupRequest)(nil), // 0: grpc.lookup.v1.RouteLookupRequest - (*RouteLookupResponse)(nil), // 1: grpc.lookup.v1.RouteLookupResponse - nil, // 2: grpc.lookup.v1.RouteLookupRequest.KeyMapEntry + (RouteLookupRequest_Reason)(0), // 0: grpc.lookup.v1.RouteLookupRequest.Reason + (*RouteLookupRequest)(nil), // 1: grpc.lookup.v1.RouteLookupRequest + (*RouteLookupResponse)(nil), // 2: grpc.lookup.v1.RouteLookupResponse + nil, // 3: grpc.lookup.v1.RouteLookupRequest.KeyMapEntry } var file_grpc_lookup_v1_rls_proto_depIdxs = []int32{ - 2, // 0: grpc.lookup.v1.RouteLookupRequest.key_map:type_name -> grpc.lookup.v1.RouteLookupRequest.KeyMapEntry - 0, // 1: grpc.lookup.v1.RouteLookupService.RouteLookup:input_type -> grpc.lookup.v1.RouteLookupRequest - 1, // 2: grpc.lookup.v1.RouteLookupService.RouteLookup:output_type -> grpc.lookup.v1.RouteLookupResponse - 2, // [2:3] is the sub-list for method output_type - 1, // [1:2] is the sub-list for method input_type - 1, // [1:1] is the sub-list for extension type_name - 1, // [1:1] is the sub-list for extension extendee - 0, // [0:1] is the sub-list for field type_name + 0, // 0: grpc.lookup.v1.RouteLookupRequest.reason:type_name -> grpc.lookup.v1.RouteLookupRequest.Reason + 3, // 1: grpc.lookup.v1.RouteLookupRequest.key_map:type_name -> grpc.lookup.v1.RouteLookupRequest.KeyMapEntry + 1, // 2: grpc.lookup.v1.RouteLookupService.RouteLookup:input_type -> grpc.lookup.v1.RouteLookupRequest + 2, // 3: grpc.lookup.v1.RouteLookupService.RouteLookup:output_type -> grpc.lookup.v1.RouteLookupResponse + 3, // [3:4] is the sub-list for method output_type + 2, // [2:3] is the sub-list for method input_type + 2, // [2:2] is the sub-list for extension type_name + 2, // [2:2] is the sub-list for extension extendee + 0, // [0:2] is the sub-list for field type_name } func init() { file_grpc_lookup_v1_rls_proto_init() } @@ -284,13 +375,14 @@ func file_grpc_lookup_v1_rls_proto_init() { File: protoimpl.DescBuilder{ GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: file_grpc_lookup_v1_rls_proto_rawDesc, - NumEnums: 0, + NumEnums: 1, NumMessages: 3, NumExtensions: 0, NumServices: 1, }, GoTypes: file_grpc_lookup_v1_rls_proto_goTypes, DependencyIndexes: file_grpc_lookup_v1_rls_proto_depIdxs, + EnumInfos: file_grpc_lookup_v1_rls_proto_enumTypes, MessageInfos: file_grpc_lookup_v1_rls_proto_msgTypes, }.Build() File_grpc_lookup_v1_rls_proto = out.File diff --git a/balancer/rls/internal/proto/grpc_lookup_v1/rls_config.pb.go b/balancer/rls/internal/proto/grpc_lookup_v1/rls_config.pb.go index 6b0924b335f..414b74cdb3b 100644 --- a/balancer/rls/internal/proto/grpc_lookup_v1/rls_config.pb.go +++ b/balancer/rls/internal/proto/grpc_lookup_v1/rls_config.pb.go @@ -50,6 +50,9 @@ type NameMatcher struct { unknownFields protoimpl.UnknownFields // The name that will be used in the RLS key_map to refer to this value. + // If required_match is true, you may omit this field or set it to an empty + // string, in which case the matcher will require a match, but won't update + // the key_map. Key string `protobuf:"bytes,1,opt,name=key,proto3" json:"key,omitempty"` // Ordered list of names (headers or query parameter names) that can supply // this value; the first one with a non-empty value is used. @@ -118,11 +121,17 @@ type GrpcKeyBuilder struct { sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - Names []*GrpcKeyBuilder_Name `protobuf:"bytes,1,rep,name=names,proto3" json:"names,omitempty"` + Names []*GrpcKeyBuilder_Name `protobuf:"bytes,1,rep,name=names,proto3" json:"names,omitempty"` + ExtraKeys *GrpcKeyBuilder_ExtraKeys `protobuf:"bytes,3,opt,name=extra_keys,json=extraKeys,proto3" json:"extra_keys,omitempty"` // Extract keys from all listed headers. // For gRPC, it is an error to specify "required_match" on the NameMatcher // protos. Headers []*NameMatcher `protobuf:"bytes,2,rep,name=headers,proto3" json:"headers,omitempty"` + // You can optionally set one or more specific key/value pairs to be added to + // the key_map. This can be useful to identify which builder built the key, + // for example if you are suppressing the actual method, but need to + // separately cache and request all the matched methods. + ConstantKeys map[string]string `protobuf:"bytes,4,rep,name=constant_keys,json=constantKeys,proto3" json:"constant_keys,omitempty" protobuf_key:"bytes,1,opt,name=key,proto3" protobuf_val:"bytes,2,opt,name=value,proto3"` } func (x *GrpcKeyBuilder) Reset() { @@ -164,6 +173,13 @@ func (x *GrpcKeyBuilder) GetNames() []*GrpcKeyBuilder_Name { return nil } +func (x *GrpcKeyBuilder) GetExtraKeys() *GrpcKeyBuilder_ExtraKeys { + if x != nil { + return x.ExtraKeys + } + return nil +} + func (x *GrpcKeyBuilder) GetHeaders() []*NameMatcher { if x != nil { return x.Headers @@ -171,6 +187,13 @@ func (x *GrpcKeyBuilder) GetHeaders() []*NameMatcher { return nil } +func (x *GrpcKeyBuilder) GetConstantKeys() map[string]string { + if x != nil { + return x.ConstantKeys + } + return nil +} + // An HttpKeyBuilder applies to a given HTTP URL and headers. // // Path and host patterns use the matching syntax from gRPC transcoding to @@ -245,6 +268,11 @@ type HttpKeyBuilder struct { // to match. If a given header appears multiple times in the request we will // report it as a comma-separated string, in standard HTTP fashion. Headers []*NameMatcher `protobuf:"bytes,4,rep,name=headers,proto3" json:"headers,omitempty"` + // You can optionally set one or more specific key/value pairs to be added to + // the key_map. This can be useful to identify which builder built the key, + // for example if you are suppressing a lot of information from the URL, but + // need to separately cache and request URLs with that content. + ConstantKeys map[string]string `protobuf:"bytes,5,rep,name=constant_keys,json=constantKeys,proto3" json:"constant_keys,omitempty" protobuf_key:"bytes,1,opt,name=key,proto3" protobuf_val:"bytes,2,opt,name=value,proto3"` } func (x *HttpKeyBuilder) Reset() { @@ -307,6 +335,13 @@ func (x *HttpKeyBuilder) GetHeaders() []*NameMatcher { return nil } +func (x *HttpKeyBuilder) GetConstantKeys() map[string]string { + if x != nil { + return x.ConstantKeys + } + return nil +} + type RouteLookupConfig struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache @@ -510,6 +545,75 @@ func (x *GrpcKeyBuilder_Name) GetMethod() string { return "" } +// If you wish to include the host, service, or method names as keys in the +// generated RouteLookupRequest, specify key names to use in the extra_keys +// submessage. If a key name is empty, no key will be set for that value. +// If this submessage is specified, the normal host/path fields will be left +// unset in the RouteLookupRequest. We are deprecating host/path in the +// RouteLookupRequest, so services should migrate to the ExtraKeys approach. +type GrpcKeyBuilder_ExtraKeys struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Host string `protobuf:"bytes,1,opt,name=host,proto3" json:"host,omitempty"` + Service string `protobuf:"bytes,2,opt,name=service,proto3" json:"service,omitempty"` + Method string `protobuf:"bytes,3,opt,name=method,proto3" json:"method,omitempty"` +} + +func (x *GrpcKeyBuilder_ExtraKeys) Reset() { + *x = GrpcKeyBuilder_ExtraKeys{} + if protoimpl.UnsafeEnabled { + mi := &file_grpc_lookup_v1_rls_config_proto_msgTypes[5] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *GrpcKeyBuilder_ExtraKeys) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*GrpcKeyBuilder_ExtraKeys) ProtoMessage() {} + +func (x *GrpcKeyBuilder_ExtraKeys) ProtoReflect() protoreflect.Message { + mi := &file_grpc_lookup_v1_rls_config_proto_msgTypes[5] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use GrpcKeyBuilder_ExtraKeys.ProtoReflect.Descriptor instead. +func (*GrpcKeyBuilder_ExtraKeys) Descriptor() ([]byte, []int) { + return file_grpc_lookup_v1_rls_config_proto_rawDescGZIP(), []int{1, 1} +} + +func (x *GrpcKeyBuilder_ExtraKeys) GetHost() string { + if x != nil { + return x.Host + } + return "" +} + +func (x *GrpcKeyBuilder_ExtraKeys) GetService() string { + if x != nil { + return x.Service + } + return "" +} + +func (x *GrpcKeyBuilder_ExtraKeys) GetMethod() string { + if x != nil { + return x.Method + } + return "" +} + var File_grpc_lookup_v1_rls_config_proto protoreflect.FileDescriptor var file_grpc_lookup_v1_rls_config_proto_rawDesc = []byte{ @@ -524,72 +628,101 @@ var file_grpc_lookup_v1_rls_config_proto_rawDesc = []byte{ 0x09, 0x52, 0x05, 0x6e, 0x61, 0x6d, 0x65, 0x73, 0x12, 0x25, 0x0a, 0x0e, 0x72, 0x65, 0x71, 0x75, 0x69, 0x72, 0x65, 0x64, 0x5f, 0x6d, 0x61, 0x74, 0x63, 0x68, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x72, 0x65, 0x71, 0x75, 0x69, 0x72, 0x65, 0x64, 0x4d, 0x61, 0x74, 0x63, 0x68, 0x22, - 0xbc, 0x01, 0x0a, 0x0e, 0x47, 0x72, 0x70, 0x63, 0x4b, 0x65, 0x79, 0x42, 0x75, 0x69, 0x6c, 0x64, + 0xf0, 0x03, 0x0a, 0x0e, 0x47, 0x72, 0x70, 0x63, 0x4b, 0x65, 0x79, 0x42, 0x75, 0x69, 0x6c, 0x64, 0x65, 0x72, 0x12, 0x39, 0x0a, 0x05, 0x6e, 0x61, 0x6d, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x23, 0x2e, 0x67, 0x72, 0x70, 0x63, 0x2e, 0x6c, 0x6f, 0x6f, 0x6b, 0x75, 0x70, 0x2e, 0x76, 0x31, 0x2e, 0x47, 0x72, 0x70, 0x63, 0x4b, 0x65, 0x79, 0x42, 0x75, 0x69, 0x6c, 0x64, 0x65, - 0x72, 0x2e, 0x4e, 0x61, 0x6d, 0x65, 0x52, 0x05, 0x6e, 0x61, 0x6d, 0x65, 0x73, 0x12, 0x35, 0x0a, - 0x07, 0x68, 0x65, 0x61, 0x64, 0x65, 0x72, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1b, - 0x2e, 0x67, 0x72, 0x70, 0x63, 0x2e, 0x6c, 0x6f, 0x6f, 0x6b, 0x75, 0x70, 0x2e, 0x76, 0x31, 0x2e, - 0x4e, 0x61, 0x6d, 0x65, 0x4d, 0x61, 0x74, 0x63, 0x68, 0x65, 0x72, 0x52, 0x07, 0x68, 0x65, 0x61, - 0x64, 0x65, 0x72, 0x73, 0x1a, 0x38, 0x0a, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x18, 0x0a, 0x07, + 0x72, 0x2e, 0x4e, 0x61, 0x6d, 0x65, 0x52, 0x05, 0x6e, 0x61, 0x6d, 0x65, 0x73, 0x12, 0x47, 0x0a, + 0x0a, 0x65, 0x78, 0x74, 0x72, 0x61, 0x5f, 0x6b, 0x65, 0x79, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, + 0x0b, 0x32, 0x28, 0x2e, 0x67, 0x72, 0x70, 0x63, 0x2e, 0x6c, 0x6f, 0x6f, 0x6b, 0x75, 0x70, 0x2e, + 0x76, 0x31, 0x2e, 0x47, 0x72, 0x70, 0x63, 0x4b, 0x65, 0x79, 0x42, 0x75, 0x69, 0x6c, 0x64, 0x65, + 0x72, 0x2e, 0x45, 0x78, 0x74, 0x72, 0x61, 0x4b, 0x65, 0x79, 0x73, 0x52, 0x09, 0x65, 0x78, 0x74, + 0x72, 0x61, 0x4b, 0x65, 0x79, 0x73, 0x12, 0x35, 0x0a, 0x07, 0x68, 0x65, 0x61, 0x64, 0x65, 0x72, + 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1b, 0x2e, 0x67, 0x72, 0x70, 0x63, 0x2e, 0x6c, + 0x6f, 0x6f, 0x6b, 0x75, 0x70, 0x2e, 0x76, 0x31, 0x2e, 0x4e, 0x61, 0x6d, 0x65, 0x4d, 0x61, 0x74, + 0x63, 0x68, 0x65, 0x72, 0x52, 0x07, 0x68, 0x65, 0x61, 0x64, 0x65, 0x72, 0x73, 0x12, 0x55, 0x0a, + 0x0d, 0x63, 0x6f, 0x6e, 0x73, 0x74, 0x61, 0x6e, 0x74, 0x5f, 0x6b, 0x65, 0x79, 0x73, 0x18, 0x04, + 0x20, 0x03, 0x28, 0x0b, 0x32, 0x30, 0x2e, 0x67, 0x72, 0x70, 0x63, 0x2e, 0x6c, 0x6f, 0x6f, 0x6b, + 0x75, 0x70, 0x2e, 0x76, 0x31, 0x2e, 0x47, 0x72, 0x70, 0x63, 0x4b, 0x65, 0x79, 0x42, 0x75, 0x69, + 0x6c, 0x64, 0x65, 0x72, 0x2e, 0x43, 0x6f, 0x6e, 0x73, 0x74, 0x61, 0x6e, 0x74, 0x4b, 0x65, 0x79, + 0x73, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x52, 0x0c, 0x63, 0x6f, 0x6e, 0x73, 0x74, 0x61, 0x6e, 0x74, + 0x4b, 0x65, 0x79, 0x73, 0x1a, 0x38, 0x0a, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x6d, 0x65, 0x74, 0x68, 0x6f, 0x64, - 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x6d, 0x65, 0x74, 0x68, 0x6f, 0x64, 0x22, 0xd9, - 0x01, 0x0a, 0x0e, 0x48, 0x74, 0x74, 0x70, 0x4b, 0x65, 0x79, 0x42, 0x75, 0x69, 0x6c, 0x64, 0x65, - 0x72, 0x12, 0x23, 0x0a, 0x0d, 0x68, 0x6f, 0x73, 0x74, 0x5f, 0x70, 0x61, 0x74, 0x74, 0x65, 0x72, - 0x6e, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0c, 0x68, 0x6f, 0x73, 0x74, 0x50, 0x61, - 0x74, 0x74, 0x65, 0x72, 0x6e, 0x73, 0x12, 0x23, 0x0a, 0x0d, 0x70, 0x61, 0x74, 0x68, 0x5f, 0x70, - 0x61, 0x74, 0x74, 0x65, 0x72, 0x6e, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0c, 0x70, - 0x61, 0x74, 0x68, 0x50, 0x61, 0x74, 0x74, 0x65, 0x72, 0x6e, 0x73, 0x12, 0x46, 0x0a, 0x10, 0x71, - 0x75, 0x65, 0x72, 0x79, 0x5f, 0x70, 0x61, 0x72, 0x61, 0x6d, 0x65, 0x74, 0x65, 0x72, 0x73, 0x18, - 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1b, 0x2e, 0x67, 0x72, 0x70, 0x63, 0x2e, 0x6c, 0x6f, 0x6f, - 0x6b, 0x75, 0x70, 0x2e, 0x76, 0x31, 0x2e, 0x4e, 0x61, 0x6d, 0x65, 0x4d, 0x61, 0x74, 0x63, 0x68, - 0x65, 0x72, 0x52, 0x0f, 0x71, 0x75, 0x65, 0x72, 0x79, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x65, 0x74, - 0x65, 0x72, 0x73, 0x12, 0x35, 0x0a, 0x07, 0x68, 0x65, 0x61, 0x64, 0x65, 0x72, 0x73, 0x18, 0x04, - 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1b, 0x2e, 0x67, 0x72, 0x70, 0x63, 0x2e, 0x6c, 0x6f, 0x6f, 0x6b, - 0x75, 0x70, 0x2e, 0x76, 0x31, 0x2e, 0x4e, 0x61, 0x6d, 0x65, 0x4d, 0x61, 0x74, 0x63, 0x68, 0x65, - 0x72, 0x52, 0x07, 0x68, 0x65, 0x61, 0x64, 0x65, 0x72, 0x73, 0x22, 0xa6, 0x04, 0x0a, 0x11, 0x52, - 0x6f, 0x75, 0x74, 0x65, 0x4c, 0x6f, 0x6f, 0x6b, 0x75, 0x70, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, - 0x12, 0x49, 0x0a, 0x10, 0x68, 0x74, 0x74, 0x70, 0x5f, 0x6b, 0x65, 0x79, 0x62, 0x75, 0x69, 0x6c, - 0x64, 0x65, 0x72, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1e, 0x2e, 0x67, 0x72, 0x70, - 0x63, 0x2e, 0x6c, 0x6f, 0x6f, 0x6b, 0x75, 0x70, 0x2e, 0x76, 0x31, 0x2e, 0x48, 0x74, 0x74, 0x70, - 0x4b, 0x65, 0x79, 0x42, 0x75, 0x69, 0x6c, 0x64, 0x65, 0x72, 0x52, 0x0f, 0x68, 0x74, 0x74, 0x70, - 0x4b, 0x65, 0x79, 0x62, 0x75, 0x69, 0x6c, 0x64, 0x65, 0x72, 0x73, 0x12, 0x49, 0x0a, 0x10, 0x67, - 0x72, 0x70, 0x63, 0x5f, 0x6b, 0x65, 0x79, 0x62, 0x75, 0x69, 0x6c, 0x64, 0x65, 0x72, 0x73, 0x18, - 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1e, 0x2e, 0x67, 0x72, 0x70, 0x63, 0x2e, 0x6c, 0x6f, 0x6f, - 0x6b, 0x75, 0x70, 0x2e, 0x76, 0x31, 0x2e, 0x47, 0x72, 0x70, 0x63, 0x4b, 0x65, 0x79, 0x42, 0x75, - 0x69, 0x6c, 0x64, 0x65, 0x72, 0x52, 0x0f, 0x67, 0x72, 0x70, 0x63, 0x4b, 0x65, 0x79, 0x62, 0x75, - 0x69, 0x6c, 0x64, 0x65, 0x72, 0x73, 0x12, 0x25, 0x0a, 0x0e, 0x6c, 0x6f, 0x6f, 0x6b, 0x75, 0x70, - 0x5f, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, - 0x6c, 0x6f, 0x6f, 0x6b, 0x75, 0x70, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x4f, 0x0a, - 0x16, 0x6c, 0x6f, 0x6f, 0x6b, 0x75, 0x70, 0x5f, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x5f, - 0x74, 0x69, 0x6d, 0x65, 0x6f, 0x75, 0x74, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x19, 0x2e, - 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, - 0x44, 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x14, 0x6c, 0x6f, 0x6f, 0x6b, 0x75, 0x70, - 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x54, 0x69, 0x6d, 0x65, 0x6f, 0x75, 0x74, 0x12, 0x32, - 0x0a, 0x07, 0x6d, 0x61, 0x78, 0x5f, 0x61, 0x67, 0x65, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0b, 0x32, - 0x19, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, - 0x66, 0x2e, 0x44, 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x06, 0x6d, 0x61, 0x78, 0x41, - 0x67, 0x65, 0x12, 0x36, 0x0a, 0x09, 0x73, 0x74, 0x61, 0x6c, 0x65, 0x5f, 0x61, 0x67, 0x65, 0x18, - 0x06, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x19, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, - 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x44, 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, - 0x52, 0x08, 0x73, 0x74, 0x61, 0x6c, 0x65, 0x41, 0x67, 0x65, 0x12, 0x28, 0x0a, 0x10, 0x63, 0x61, - 0x63, 0x68, 0x65, 0x5f, 0x73, 0x69, 0x7a, 0x65, 0x5f, 0x62, 0x79, 0x74, 0x65, 0x73, 0x18, 0x07, - 0x20, 0x01, 0x28, 0x03, 0x52, 0x0e, 0x63, 0x61, 0x63, 0x68, 0x65, 0x53, 0x69, 0x7a, 0x65, 0x42, - 0x79, 0x74, 0x65, 0x73, 0x12, 0x23, 0x0a, 0x0d, 0x76, 0x61, 0x6c, 0x69, 0x64, 0x5f, 0x74, 0x61, - 0x72, 0x67, 0x65, 0x74, 0x73, 0x18, 0x08, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0c, 0x76, 0x61, 0x6c, - 0x69, 0x64, 0x54, 0x61, 0x72, 0x67, 0x65, 0x74, 0x73, 0x12, 0x25, 0x0a, 0x0e, 0x64, 0x65, 0x66, - 0x61, 0x75, 0x6c, 0x74, 0x5f, 0x74, 0x61, 0x72, 0x67, 0x65, 0x74, 0x18, 0x09, 0x20, 0x01, 0x28, - 0x09, 0x52, 0x0d, 0x64, 0x65, 0x66, 0x61, 0x75, 0x6c, 0x74, 0x54, 0x61, 0x72, 0x67, 0x65, 0x74, - 0x4a, 0x04, 0x08, 0x0a, 0x10, 0x0b, 0x52, 0x1b, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x5f, - 0x70, 0x72, 0x6f, 0x63, 0x65, 0x73, 0x73, 0x69, 0x6e, 0x67, 0x5f, 0x73, 0x74, 0x72, 0x61, 0x74, - 0x65, 0x67, 0x79, 0x42, 0x53, 0x0a, 0x11, 0x69, 0x6f, 0x2e, 0x67, 0x72, 0x70, 0x63, 0x2e, 0x6c, - 0x6f, 0x6f, 0x6b, 0x75, 0x70, 0x2e, 0x76, 0x31, 0x42, 0x0e, 0x52, 0x6c, 0x73, 0x43, 0x6f, 0x6e, - 0x66, 0x69, 0x67, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x50, 0x01, 0x5a, 0x2c, 0x67, 0x6f, 0x6f, 0x67, - 0x6c, 0x65, 0x2e, 0x67, 0x6f, 0x6c, 0x61, 0x6e, 0x67, 0x2e, 0x6f, 0x72, 0x67, 0x2f, 0x67, 0x72, - 0x70, 0x63, 0x2f, 0x6c, 0x6f, 0x6f, 0x6b, 0x75, 0x70, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x5f, 0x6c, - 0x6f, 0x6f, 0x6b, 0x75, 0x70, 0x5f, 0x76, 0x31, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x6d, 0x65, 0x74, 0x68, 0x6f, 0x64, 0x1a, 0x51, + 0x0a, 0x09, 0x45, 0x78, 0x74, 0x72, 0x61, 0x4b, 0x65, 0x79, 0x73, 0x12, 0x12, 0x0a, 0x04, 0x68, + 0x6f, 0x73, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x68, 0x6f, 0x73, 0x74, 0x12, + 0x18, 0x0a, 0x07, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x07, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x6d, 0x65, 0x74, + 0x68, 0x6f, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x6d, 0x65, 0x74, 0x68, 0x6f, + 0x64, 0x1a, 0x3f, 0x0a, 0x11, 0x43, 0x6f, 0x6e, 0x73, 0x74, 0x61, 0x6e, 0x74, 0x4b, 0x65, 0x79, + 0x73, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, + 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, 0x02, + 0x38, 0x01, 0x22, 0xf1, 0x02, 0x0a, 0x0e, 0x48, 0x74, 0x74, 0x70, 0x4b, 0x65, 0x79, 0x42, 0x75, + 0x69, 0x6c, 0x64, 0x65, 0x72, 0x12, 0x23, 0x0a, 0x0d, 0x68, 0x6f, 0x73, 0x74, 0x5f, 0x70, 0x61, + 0x74, 0x74, 0x65, 0x72, 0x6e, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0c, 0x68, 0x6f, + 0x73, 0x74, 0x50, 0x61, 0x74, 0x74, 0x65, 0x72, 0x6e, 0x73, 0x12, 0x23, 0x0a, 0x0d, 0x70, 0x61, + 0x74, 0x68, 0x5f, 0x70, 0x61, 0x74, 0x74, 0x65, 0x72, 0x6e, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, + 0x09, 0x52, 0x0c, 0x70, 0x61, 0x74, 0x68, 0x50, 0x61, 0x74, 0x74, 0x65, 0x72, 0x6e, 0x73, 0x12, + 0x46, 0x0a, 0x10, 0x71, 0x75, 0x65, 0x72, 0x79, 0x5f, 0x70, 0x61, 0x72, 0x61, 0x6d, 0x65, 0x74, + 0x65, 0x72, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1b, 0x2e, 0x67, 0x72, 0x70, 0x63, + 0x2e, 0x6c, 0x6f, 0x6f, 0x6b, 0x75, 0x70, 0x2e, 0x76, 0x31, 0x2e, 0x4e, 0x61, 0x6d, 0x65, 0x4d, + 0x61, 0x74, 0x63, 0x68, 0x65, 0x72, 0x52, 0x0f, 0x71, 0x75, 0x65, 0x72, 0x79, 0x50, 0x61, 0x72, + 0x61, 0x6d, 0x65, 0x74, 0x65, 0x72, 0x73, 0x12, 0x35, 0x0a, 0x07, 0x68, 0x65, 0x61, 0x64, 0x65, + 0x72, 0x73, 0x18, 0x04, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1b, 0x2e, 0x67, 0x72, 0x70, 0x63, 0x2e, + 0x6c, 0x6f, 0x6f, 0x6b, 0x75, 0x70, 0x2e, 0x76, 0x31, 0x2e, 0x4e, 0x61, 0x6d, 0x65, 0x4d, 0x61, + 0x74, 0x63, 0x68, 0x65, 0x72, 0x52, 0x07, 0x68, 0x65, 0x61, 0x64, 0x65, 0x72, 0x73, 0x12, 0x55, + 0x0a, 0x0d, 0x63, 0x6f, 0x6e, 0x73, 0x74, 0x61, 0x6e, 0x74, 0x5f, 0x6b, 0x65, 0x79, 0x73, 0x18, + 0x05, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x30, 0x2e, 0x67, 0x72, 0x70, 0x63, 0x2e, 0x6c, 0x6f, 0x6f, + 0x6b, 0x75, 0x70, 0x2e, 0x76, 0x31, 0x2e, 0x48, 0x74, 0x74, 0x70, 0x4b, 0x65, 0x79, 0x42, 0x75, + 0x69, 0x6c, 0x64, 0x65, 0x72, 0x2e, 0x43, 0x6f, 0x6e, 0x73, 0x74, 0x61, 0x6e, 0x74, 0x4b, 0x65, + 0x79, 0x73, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x52, 0x0c, 0x63, 0x6f, 0x6e, 0x73, 0x74, 0x61, 0x6e, + 0x74, 0x4b, 0x65, 0x79, 0x73, 0x1a, 0x3f, 0x0a, 0x11, 0x43, 0x6f, 0x6e, 0x73, 0x74, 0x61, 0x6e, + 0x74, 0x4b, 0x65, 0x79, 0x73, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, + 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x14, 0x0a, 0x05, + 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x76, 0x61, 0x6c, + 0x75, 0x65, 0x3a, 0x02, 0x38, 0x01, 0x22, 0xa6, 0x04, 0x0a, 0x11, 0x52, 0x6f, 0x75, 0x74, 0x65, + 0x4c, 0x6f, 0x6f, 0x6b, 0x75, 0x70, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x49, 0x0a, 0x10, + 0x68, 0x74, 0x74, 0x70, 0x5f, 0x6b, 0x65, 0x79, 0x62, 0x75, 0x69, 0x6c, 0x64, 0x65, 0x72, 0x73, + 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1e, 0x2e, 0x67, 0x72, 0x70, 0x63, 0x2e, 0x6c, 0x6f, + 0x6f, 0x6b, 0x75, 0x70, 0x2e, 0x76, 0x31, 0x2e, 0x48, 0x74, 0x74, 0x70, 0x4b, 0x65, 0x79, 0x42, + 0x75, 0x69, 0x6c, 0x64, 0x65, 0x72, 0x52, 0x0f, 0x68, 0x74, 0x74, 0x70, 0x4b, 0x65, 0x79, 0x62, + 0x75, 0x69, 0x6c, 0x64, 0x65, 0x72, 0x73, 0x12, 0x49, 0x0a, 0x10, 0x67, 0x72, 0x70, 0x63, 0x5f, + 0x6b, 0x65, 0x79, 0x62, 0x75, 0x69, 0x6c, 0x64, 0x65, 0x72, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, + 0x0b, 0x32, 0x1e, 0x2e, 0x67, 0x72, 0x70, 0x63, 0x2e, 0x6c, 0x6f, 0x6f, 0x6b, 0x75, 0x70, 0x2e, + 0x76, 0x31, 0x2e, 0x47, 0x72, 0x70, 0x63, 0x4b, 0x65, 0x79, 0x42, 0x75, 0x69, 0x6c, 0x64, 0x65, + 0x72, 0x52, 0x0f, 0x67, 0x72, 0x70, 0x63, 0x4b, 0x65, 0x79, 0x62, 0x75, 0x69, 0x6c, 0x64, 0x65, + 0x72, 0x73, 0x12, 0x25, 0x0a, 0x0e, 0x6c, 0x6f, 0x6f, 0x6b, 0x75, 0x70, 0x5f, 0x73, 0x65, 0x72, + 0x76, 0x69, 0x63, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x6c, 0x6f, 0x6f, 0x6b, + 0x75, 0x70, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x4f, 0x0a, 0x16, 0x6c, 0x6f, 0x6f, + 0x6b, 0x75, 0x70, 0x5f, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x5f, 0x74, 0x69, 0x6d, 0x65, + 0x6f, 0x75, 0x74, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x19, 0x2e, 0x67, 0x6f, 0x6f, 0x67, + 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x44, 0x75, 0x72, 0x61, + 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x14, 0x6c, 0x6f, 0x6f, 0x6b, 0x75, 0x70, 0x53, 0x65, 0x72, 0x76, + 0x69, 0x63, 0x65, 0x54, 0x69, 0x6d, 0x65, 0x6f, 0x75, 0x74, 0x12, 0x32, 0x0a, 0x07, 0x6d, 0x61, + 0x78, 0x5f, 0x61, 0x67, 0x65, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x19, 0x2e, 0x67, 0x6f, + 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x44, 0x75, + 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x06, 0x6d, 0x61, 0x78, 0x41, 0x67, 0x65, 0x12, 0x36, + 0x0a, 0x09, 0x73, 0x74, 0x61, 0x6c, 0x65, 0x5f, 0x61, 0x67, 0x65, 0x18, 0x06, 0x20, 0x01, 0x28, + 0x0b, 0x32, 0x19, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, + 0x62, 0x75, 0x66, 0x2e, 0x44, 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x08, 0x73, 0x74, + 0x61, 0x6c, 0x65, 0x41, 0x67, 0x65, 0x12, 0x28, 0x0a, 0x10, 0x63, 0x61, 0x63, 0x68, 0x65, 0x5f, + 0x73, 0x69, 0x7a, 0x65, 0x5f, 0x62, 0x79, 0x74, 0x65, 0x73, 0x18, 0x07, 0x20, 0x01, 0x28, 0x03, + 0x52, 0x0e, 0x63, 0x61, 0x63, 0x68, 0x65, 0x53, 0x69, 0x7a, 0x65, 0x42, 0x79, 0x74, 0x65, 0x73, + 0x12, 0x23, 0x0a, 0x0d, 0x76, 0x61, 0x6c, 0x69, 0x64, 0x5f, 0x74, 0x61, 0x72, 0x67, 0x65, 0x74, + 0x73, 0x18, 0x08, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0c, 0x76, 0x61, 0x6c, 0x69, 0x64, 0x54, 0x61, + 0x72, 0x67, 0x65, 0x74, 0x73, 0x12, 0x25, 0x0a, 0x0e, 0x64, 0x65, 0x66, 0x61, 0x75, 0x6c, 0x74, + 0x5f, 0x74, 0x61, 0x72, 0x67, 0x65, 0x74, 0x18, 0x09, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x64, + 0x65, 0x66, 0x61, 0x75, 0x6c, 0x74, 0x54, 0x61, 0x72, 0x67, 0x65, 0x74, 0x4a, 0x04, 0x08, 0x0a, + 0x10, 0x0b, 0x52, 0x1b, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x5f, 0x70, 0x72, 0x6f, 0x63, + 0x65, 0x73, 0x73, 0x69, 0x6e, 0x67, 0x5f, 0x73, 0x74, 0x72, 0x61, 0x74, 0x65, 0x67, 0x79, 0x42, + 0x53, 0x0a, 0x11, 0x69, 0x6f, 0x2e, 0x67, 0x72, 0x70, 0x63, 0x2e, 0x6c, 0x6f, 0x6f, 0x6b, 0x75, + 0x70, 0x2e, 0x76, 0x31, 0x42, 0x0e, 0x52, 0x6c, 0x73, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x50, + 0x72, 0x6f, 0x74, 0x6f, 0x50, 0x01, 0x5a, 0x2c, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x67, + 0x6f, 0x6c, 0x61, 0x6e, 0x67, 0x2e, 0x6f, 0x72, 0x67, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x2f, 0x6c, + 0x6f, 0x6f, 0x6b, 0x75, 0x70, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x5f, 0x6c, 0x6f, 0x6f, 0x6b, 0x75, + 0x70, 0x5f, 0x76, 0x31, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( @@ -604,30 +737,36 @@ func file_grpc_lookup_v1_rls_config_proto_rawDescGZIP() []byte { return file_grpc_lookup_v1_rls_config_proto_rawDescData } -var file_grpc_lookup_v1_rls_config_proto_msgTypes = make([]protoimpl.MessageInfo, 5) +var file_grpc_lookup_v1_rls_config_proto_msgTypes = make([]protoimpl.MessageInfo, 8) var file_grpc_lookup_v1_rls_config_proto_goTypes = []interface{}{ - (*NameMatcher)(nil), // 0: grpc.lookup.v1.NameMatcher - (*GrpcKeyBuilder)(nil), // 1: grpc.lookup.v1.GrpcKeyBuilder - (*HttpKeyBuilder)(nil), // 2: grpc.lookup.v1.HttpKeyBuilder - (*RouteLookupConfig)(nil), // 3: grpc.lookup.v1.RouteLookupConfig - (*GrpcKeyBuilder_Name)(nil), // 4: grpc.lookup.v1.GrpcKeyBuilder.Name - (*durationpb.Duration)(nil), // 5: google.protobuf.Duration + (*NameMatcher)(nil), // 0: grpc.lookup.v1.NameMatcher + (*GrpcKeyBuilder)(nil), // 1: grpc.lookup.v1.GrpcKeyBuilder + (*HttpKeyBuilder)(nil), // 2: grpc.lookup.v1.HttpKeyBuilder + (*RouteLookupConfig)(nil), // 3: grpc.lookup.v1.RouteLookupConfig + (*GrpcKeyBuilder_Name)(nil), // 4: grpc.lookup.v1.GrpcKeyBuilder.Name + (*GrpcKeyBuilder_ExtraKeys)(nil), // 5: grpc.lookup.v1.GrpcKeyBuilder.ExtraKeys + nil, // 6: grpc.lookup.v1.GrpcKeyBuilder.ConstantKeysEntry + nil, // 7: grpc.lookup.v1.HttpKeyBuilder.ConstantKeysEntry + (*durationpb.Duration)(nil), // 8: google.protobuf.Duration } var file_grpc_lookup_v1_rls_config_proto_depIdxs = []int32{ - 4, // 0: grpc.lookup.v1.GrpcKeyBuilder.names:type_name -> grpc.lookup.v1.GrpcKeyBuilder.Name - 0, // 1: grpc.lookup.v1.GrpcKeyBuilder.headers:type_name -> grpc.lookup.v1.NameMatcher - 0, // 2: grpc.lookup.v1.HttpKeyBuilder.query_parameters:type_name -> grpc.lookup.v1.NameMatcher - 0, // 3: grpc.lookup.v1.HttpKeyBuilder.headers:type_name -> grpc.lookup.v1.NameMatcher - 2, // 4: grpc.lookup.v1.RouteLookupConfig.http_keybuilders:type_name -> grpc.lookup.v1.HttpKeyBuilder - 1, // 5: grpc.lookup.v1.RouteLookupConfig.grpc_keybuilders:type_name -> grpc.lookup.v1.GrpcKeyBuilder - 5, // 6: grpc.lookup.v1.RouteLookupConfig.lookup_service_timeout:type_name -> google.protobuf.Duration - 5, // 7: grpc.lookup.v1.RouteLookupConfig.max_age:type_name -> google.protobuf.Duration - 5, // 8: grpc.lookup.v1.RouteLookupConfig.stale_age:type_name -> google.protobuf.Duration - 9, // [9:9] is the sub-list for method output_type - 9, // [9:9] is the sub-list for method input_type - 9, // [9:9] is the sub-list for extension type_name - 9, // [9:9] is the sub-list for extension extendee - 0, // [0:9] is the sub-list for field type_name + 4, // 0: grpc.lookup.v1.GrpcKeyBuilder.names:type_name -> grpc.lookup.v1.GrpcKeyBuilder.Name + 5, // 1: grpc.lookup.v1.GrpcKeyBuilder.extra_keys:type_name -> grpc.lookup.v1.GrpcKeyBuilder.ExtraKeys + 0, // 2: grpc.lookup.v1.GrpcKeyBuilder.headers:type_name -> grpc.lookup.v1.NameMatcher + 6, // 3: grpc.lookup.v1.GrpcKeyBuilder.constant_keys:type_name -> grpc.lookup.v1.GrpcKeyBuilder.ConstantKeysEntry + 0, // 4: grpc.lookup.v1.HttpKeyBuilder.query_parameters:type_name -> grpc.lookup.v1.NameMatcher + 0, // 5: grpc.lookup.v1.HttpKeyBuilder.headers:type_name -> grpc.lookup.v1.NameMatcher + 7, // 6: grpc.lookup.v1.HttpKeyBuilder.constant_keys:type_name -> grpc.lookup.v1.HttpKeyBuilder.ConstantKeysEntry + 2, // 7: grpc.lookup.v1.RouteLookupConfig.http_keybuilders:type_name -> grpc.lookup.v1.HttpKeyBuilder + 1, // 8: grpc.lookup.v1.RouteLookupConfig.grpc_keybuilders:type_name -> grpc.lookup.v1.GrpcKeyBuilder + 8, // 9: grpc.lookup.v1.RouteLookupConfig.lookup_service_timeout:type_name -> google.protobuf.Duration + 8, // 10: grpc.lookup.v1.RouteLookupConfig.max_age:type_name -> google.protobuf.Duration + 8, // 11: grpc.lookup.v1.RouteLookupConfig.stale_age:type_name -> google.protobuf.Duration + 12, // [12:12] is the sub-list for method output_type + 12, // [12:12] is the sub-list for method input_type + 12, // [12:12] is the sub-list for extension type_name + 12, // [12:12] is the sub-list for extension extendee + 0, // [0:12] is the sub-list for field type_name } func init() { file_grpc_lookup_v1_rls_config_proto_init() } @@ -696,6 +835,18 @@ func file_grpc_lookup_v1_rls_config_proto_init() { return nil } } + file_grpc_lookup_v1_rls_config_proto_msgTypes[5].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*GrpcKeyBuilder_ExtraKeys); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } } type x struct{} out := protoimpl.TypeBuilder{ @@ -703,7 +854,7 @@ func file_grpc_lookup_v1_rls_config_proto_init() { GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: file_grpc_lookup_v1_rls_config_proto_rawDesc, NumEnums: 0, - NumMessages: 5, + NumMessages: 8, NumExtensions: 0, NumServices: 0, }, diff --git a/balancer/rls/internal/proto/grpc_lookup_v1/rls_grpc.pb.go b/balancer/rls/internal/proto/grpc_lookup_v1/rls_grpc.pb.go index b469089ed57..39d79e13343 100644 --- a/balancer/rls/internal/proto/grpc_lookup_v1/rls_grpc.pb.go +++ b/balancer/rls/internal/proto/grpc_lookup_v1/rls_grpc.pb.go @@ -1,4 +1,8 @@ // Code generated by protoc-gen-go-grpc. DO NOT EDIT. +// versions: +// - protoc-gen-go-grpc v1.1.0 +// - protoc v3.14.0 +// source: grpc/lookup/v1/rls.proto package grpc_lookup_v1 diff --git a/balancer/roundrobin/roundrobin.go b/balancer/roundrobin/roundrobin.go index 43c2a15373a..274eb2f8580 100644 --- a/balancer/roundrobin/roundrobin.go +++ b/balancer/roundrobin/roundrobin.go @@ -47,11 +47,11 @@ func init() { type rrPickerBuilder struct{} func (*rrPickerBuilder) Build(info base.PickerBuildInfo) balancer.Picker { - logger.Infof("roundrobinPicker: newPicker called with info: %v", info) + logger.Infof("roundrobinPicker: Build called with info: %v", info) if len(info.ReadySCs) == 0 { return base.NewErrPicker(balancer.ErrNoSubConnAvailable) } - var scs []balancer.SubConn + scs := make([]balancer.SubConn, 0, len(info.ReadySCs)) for sc := range info.ReadySCs { scs = append(scs, sc) } diff --git a/balancer_conn_wrappers.go b/balancer_conn_wrappers.go index 41061d6d3dc..f4ea6174682 100644 --- a/balancer_conn_wrappers.go +++ b/balancer_conn_wrappers.go @@ -37,14 +37,20 @@ type scStateUpdate struct { err error } +// exitIdle contains no data and is just a signal sent on the updateCh in +// ccBalancerWrapper to instruct the balancer to exit idle. +type exitIdle struct{} + // ccBalancerWrapper is a wrapper on top of cc for balancers. // It implements balancer.ClientConn interface. type ccBalancerWrapper struct { - cc *ClientConn - balancerMu sync.Mutex // synchronizes calls to the balancer - balancer balancer.Balancer - scBuffer *buffer.Unbounded - done *grpcsync.Event + cc *ClientConn + balancerMu sync.Mutex // synchronizes calls to the balancer + balancer balancer.Balancer + hasExitIdle bool + updateCh *buffer.Unbounded + closed *grpcsync.Event + done *grpcsync.Event mu sync.Mutex subConns map[*acBalancerWrapper]struct{} @@ -53,12 +59,14 @@ type ccBalancerWrapper struct { func newCCBalancerWrapper(cc *ClientConn, b balancer.Builder, bopts balancer.BuildOptions) *ccBalancerWrapper { ccb := &ccBalancerWrapper{ cc: cc, - scBuffer: buffer.NewUnbounded(), + updateCh: buffer.NewUnbounded(), + closed: grpcsync.NewEvent(), done: grpcsync.NewEvent(), subConns: make(map[*acBalancerWrapper]struct{}), } go ccb.watcher() ccb.balancer = b.Build(ccb, bopts) + _, ccb.hasExitIdle = ccb.balancer.(balancer.ExitIdler) return ccb } @@ -67,35 +75,72 @@ func newCCBalancerWrapper(cc *ClientConn, b balancer.Builder, bopts balancer.Bui func (ccb *ccBalancerWrapper) watcher() { for { select { - case t := <-ccb.scBuffer.Get(): - ccb.scBuffer.Load() - if ccb.done.HasFired() { + case t := <-ccb.updateCh.Get(): + ccb.updateCh.Load() + if ccb.closed.HasFired() { break } - ccb.balancerMu.Lock() - su := t.(*scStateUpdate) - ccb.balancer.UpdateSubConnState(su.sc, balancer.SubConnState{ConnectivityState: su.state, ConnectionError: su.err}) - ccb.balancerMu.Unlock() - case <-ccb.done.Done(): + switch u := t.(type) { + case *scStateUpdate: + ccb.balancerMu.Lock() + ccb.balancer.UpdateSubConnState(u.sc, balancer.SubConnState{ConnectivityState: u.state, ConnectionError: u.err}) + ccb.balancerMu.Unlock() + case *acBalancerWrapper: + ccb.mu.Lock() + if ccb.subConns != nil { + delete(ccb.subConns, u) + ccb.cc.removeAddrConn(u.getAddrConn(), errConnDrain) + } + ccb.mu.Unlock() + case exitIdle: + if ccb.cc.GetState() == connectivity.Idle { + if ei, ok := ccb.balancer.(balancer.ExitIdler); ok { + // We already checked that the balancer implements + // ExitIdle before pushing the event to updateCh, but + // check conditionally again as defensive programming. + ccb.balancerMu.Lock() + ei.ExitIdle() + ccb.balancerMu.Unlock() + } + } + default: + logger.Errorf("ccBalancerWrapper.watcher: unknown update %+v, type %T", t, t) + } + case <-ccb.closed.Done(): } - if ccb.done.HasFired() { + if ccb.closed.HasFired() { + ccb.balancerMu.Lock() ccb.balancer.Close() + ccb.balancerMu.Unlock() ccb.mu.Lock() scs := ccb.subConns ccb.subConns = nil ccb.mu.Unlock() + ccb.UpdateState(balancer.State{ConnectivityState: connectivity.Connecting, Picker: nil}) + ccb.done.Fire() + // Fire done before removing the addr conns. We can safely unblock + // ccb.close and allow the removeAddrConns to happen + // asynchronously. for acbw := range scs { ccb.cc.removeAddrConn(acbw.getAddrConn(), errConnDrain) } - ccb.UpdateState(balancer.State{ConnectivityState: connectivity.Connecting, Picker: nil}) return } } } func (ccb *ccBalancerWrapper) close() { - ccb.done.Fire() + ccb.closed.Fire() + <-ccb.done.Done() +} + +func (ccb *ccBalancerWrapper) exitIdle() bool { + if !ccb.hasExitIdle { + return false + } + ccb.updateCh.Put(exitIdle{}) + return true } func (ccb *ccBalancerWrapper) handleSubConnStateChange(sc balancer.SubConn, s connectivity.State, err error) { @@ -109,7 +154,7 @@ func (ccb *ccBalancerWrapper) handleSubConnStateChange(sc balancer.SubConn, s co if sc == nil { return } - ccb.scBuffer.Put(&scStateUpdate{ + ccb.updateCh.Put(&scStateUpdate{ sc: sc, state: s, err: err, @@ -124,8 +169,8 @@ func (ccb *ccBalancerWrapper) updateClientConnState(ccs *balancer.ClientConnStat func (ccb *ccBalancerWrapper) resolverError(err error) { ccb.balancerMu.Lock() + defer ccb.balancerMu.Unlock() ccb.balancer.ResolverError(err) - ccb.balancerMu.Unlock() } func (ccb *ccBalancerWrapper) NewSubConn(addrs []resolver.Address, opts balancer.NewSubConnOptions) (balancer.SubConn, error) { @@ -150,17 +195,10 @@ func (ccb *ccBalancerWrapper) NewSubConn(addrs []resolver.Address, opts balancer } func (ccb *ccBalancerWrapper) RemoveSubConn(sc balancer.SubConn) { - acbw, ok := sc.(*acBalancerWrapper) - if !ok { - return - } - ccb.mu.Lock() - defer ccb.mu.Unlock() - if ccb.subConns == nil { - return - } - delete(ccb.subConns, acbw) - ccb.cc.removeAddrConn(acbw.getAddrConn(), errConnDrain) + // The RemoveSubConn() is handled in the run() goroutine, to avoid deadlock + // during switchBalancer() if the old balancer calls RemoveSubConn() in its + // Close(). + ccb.updateCh.Put(sc) } func (ccb *ccBalancerWrapper) UpdateAddresses(sc balancer.SubConn, addrs []resolver.Address) { @@ -205,7 +243,7 @@ func (acbw *acBalancerWrapper) UpdateAddresses(addrs []resolver.Address) { acbw.mu.Lock() defer acbw.mu.Unlock() if len(addrs) <= 0 { - acbw.ac.tearDown(errConnDrain) + acbw.ac.cc.removeAddrConn(acbw.ac, errConnDrain) return } if !acbw.ac.tryUpdateAddrs(addrs) { @@ -220,23 +258,23 @@ func (acbw *acBalancerWrapper) UpdateAddresses(addrs []resolver.Address) { acbw.ac.acbw = nil acbw.ac.mu.Unlock() acState := acbw.ac.getState() - acbw.ac.tearDown(errConnDrain) + acbw.ac.cc.removeAddrConn(acbw.ac, errConnDrain) if acState == connectivity.Shutdown { return } - ac, err := cc.newAddrConn(addrs, opts) + newAC, err := cc.newAddrConn(addrs, opts) if err != nil { channelz.Warningf(logger, acbw.ac.channelzID, "acBalancerWrapper: UpdateAddresses: failed to newAddrConn: %v", err) return } - acbw.ac = ac - ac.mu.Lock() - ac.acbw = acbw - ac.mu.Unlock() + acbw.ac = newAC + newAC.mu.Lock() + newAC.acbw = acbw + newAC.mu.Unlock() if acState != connectivity.Idle { - ac.connect() + go newAC.connect() } } } @@ -244,7 +282,7 @@ func (acbw *acBalancerWrapper) UpdateAddresses(addrs []resolver.Address) { func (acbw *acBalancerWrapper) Connect() { acbw.mu.Lock() defer acbw.mu.Unlock() - acbw.ac.connect() + go acbw.ac.connect() } func (acbw *acBalancerWrapper) getAddrConn() *addrConn { diff --git a/balancer_conn_wrappers_test.go b/balancer_conn_wrappers_test.go deleted file mode 100644 index 935d11d1d39..00000000000 --- a/balancer_conn_wrappers_test.go +++ /dev/null @@ -1,90 +0,0 @@ -/* - * - * Copyright 2019 gRPC authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - */ - -package grpc - -import ( - "fmt" - "net" - "testing" - - "google.golang.org/grpc/balancer" - "google.golang.org/grpc/balancer/roundrobin" - "google.golang.org/grpc/internal/balancer/stub" - "google.golang.org/grpc/resolver" - "google.golang.org/grpc/resolver/manual" -) - -// TestBalancerErrorResolverPolling injects balancer errors and verifies -// ResolveNow is called on the resolver with the appropriate backoff strategy -// being consulted between ResolveNow calls. -func (s) TestBalancerErrorResolverPolling(t *testing.T) { - // The test balancer will return ErrBadResolverState iff the - // ClientConnState contains no addresses. - bf := stub.BalancerFuncs{ - UpdateClientConnState: func(_ *stub.BalancerData, s balancer.ClientConnState) error { - if len(s.ResolverState.Addresses) == 0 { - return balancer.ErrBadResolverState - } - return nil - }, - } - const balName = "BalancerErrorResolverPolling" - stub.Register(balName, bf) - - testResolverErrorPolling(t, - func(r *manual.Resolver) { - // No addresses so the balancer will fail. - r.CC.UpdateState(resolver.State{}) - }, func(r *manual.Resolver) { - // UpdateState will block if ResolveNow is being called (which blocks on - // rn), so call it in a goroutine. Include some address so the balancer - // will be happy. - go r.CC.UpdateState(resolver.State{Addresses: []resolver.Address{{Addr: "x"}}}) - }, - WithDefaultServiceConfig(fmt.Sprintf(`{ "loadBalancingConfig": [{"%v": {}}] }`, balName))) -} - -// TestRoundRobinZeroAddressesResolverPolling reports no addresses to the round -// robin balancer and verifies ResolveNow is called on the resolver with the -// appropriate backoff strategy being consulted between ResolveNow calls. -func (s) TestRoundRobinZeroAddressesResolverPolling(t *testing.T) { - // We need to start a real server or else the connecting loop will call - // ResolveNow after every iteration, even after a valid resolver result is - // returned. - lis, err := net.Listen("tcp", "localhost:0") - if err != nil { - t.Fatalf("Error while listening. Err: %v", err) - } - defer lis.Close() - s := NewServer() - defer s.Stop() - go s.Serve(lis) - - testResolverErrorPolling(t, - func(r *manual.Resolver) { - // No addresses so the balancer will fail. - r.CC.UpdateState(resolver.State{}) - }, func(r *manual.Resolver) { - // UpdateState will block if ResolveNow is being called (which - // blocks on rn), so call it in a goroutine. Include a valid - // address so the balancer will be happy. - go r.CC.UpdateState(resolver.State{Addresses: []resolver.Address{{Addr: lis.Addr().String()}}}) - }, - WithDefaultServiceConfig(fmt.Sprintf(`{ "loadBalancingConfig": [{"%v": {}}] }`, roundrobin.Name))) -} diff --git a/balancer_switching_test.go b/balancer_switching_test.go index 2c6ed576620..5d9a1f9fffc 100644 --- a/balancer_switching_test.go +++ b/balancer_switching_test.go @@ -28,6 +28,7 @@ import ( "google.golang.org/grpc/balancer" "google.golang.org/grpc/balancer/roundrobin" "google.golang.org/grpc/internal" + "google.golang.org/grpc/internal/balancer/stub" "google.golang.org/grpc/resolver" "google.golang.org/grpc/resolver/manual" "google.golang.org/grpc/serviceconfig" @@ -57,6 +58,8 @@ func (b *magicalLB) UpdateClientConnState(balancer.ClientConnState) error { func (b *magicalLB) Close() {} +func (b *magicalLB) ExitIdle() {} + func init() { balancer.Register(&magicalLB{}) } @@ -531,6 +534,51 @@ func (s) TestSwitchBalancerGRPCLBWithGRPCLBNotRegistered(t *testing.T) { } } +const inlineRemoveSubConnBalancerName = "test-inline-remove-subconn-balancer" + +func init() { + stub.Register(inlineRemoveSubConnBalancerName, stub.BalancerFuncs{ + Close: func(data *stub.BalancerData) { + data.ClientConn.RemoveSubConn(&acBalancerWrapper{}) + }, + }) +} + +// Test that when switching to balancers, the old balancer calls RemoveSubConn +// in Close. +// +// This test is to make sure this close doesn't cause a deadlock. +func (s) TestSwitchBalancerOldRemoveSubConn(t *testing.T) { + r := manual.NewBuilderWithScheme("whatever") + cc, err := Dial(r.Scheme()+":///test.server", WithInsecure(), WithResolvers(r)) + if err != nil { + t.Fatalf("failed to dial: %v", err) + } + defer cc.Close() + cc.updateResolverState(resolver.State{ServiceConfig: parseCfg(r, fmt.Sprintf(`{"loadBalancingPolicy": "%v"}`, inlineRemoveSubConnBalancerName))}, nil) + // This service config update will switch balancer from + // "test-inline-remove-subconn-balancer" to "pick_first". The test balancer + // will be closed, which will call cc.RemoveSubConn() inline (this + // RemoveSubConn is not required by the API, but some balancers might do + // it). + // + // This is to make sure the cc.RemoveSubConn() from Close() doesn't cause a + // deadlock (e.g. trying to grab a mutex while it's already locked). + // + // Do it in a goroutine so this test will fail with a helpful message + // (though the goroutine will still leak). + done := make(chan struct{}) + go func() { + cc.updateResolverState(resolver.State{ServiceConfig: parseCfg(r, `{"loadBalancingPolicy": "pick_first"}`)}, nil) + close(done) + }() + select { + case <-time.After(defaultTestTimeout): + t.Fatalf("timeout waiting for updateResolverState to finish") + case <-done: + } +} + func parseCfg(r *manual.Resolver, s string) *serviceconfig.ParseResult { scpr := r.CC.ParseServiceConfig(s) if scpr.Err != nil { diff --git a/benchmark/stats/histogram.go b/benchmark/stats/histogram.go index f038d26ed0a..461135f0125 100644 --- a/benchmark/stats/histogram.go +++ b/benchmark/stats/histogram.go @@ -118,10 +118,6 @@ func (h *Histogram) PrintWithUnit(w io.Writer, unit float64) { } maxBucketDigitLen := len(strconv.FormatFloat(h.Buckets[len(h.Buckets)-1].LowBound, 'f', 6, 64)) - if maxBucketDigitLen < 3 { - // For "inf". - maxBucketDigitLen = 3 - } maxCountDigitLen := len(strconv.FormatInt(h.Count, 10)) percentMulti := 100 / float64(h.Count) @@ -131,9 +127,9 @@ func (h *Histogram) PrintWithUnit(w io.Writer, unit float64) { if i+1 < len(h.Buckets) { fmt.Fprintf(w, "%*f)", maxBucketDigitLen, h.Buckets[i+1].LowBound/unit) } else { - fmt.Fprintf(w, "%*s)", maxBucketDigitLen, "inf") + upperBound := float64(h.opts.MinValue) + (b.LowBound-float64(h.opts.MinValue))*(1.0+h.opts.GrowthFactor) + fmt.Fprintf(w, "%*f)", maxBucketDigitLen, upperBound/unit) } - accCount += b.Count fmt.Fprintf(w, " %*d %5.1f%% %5.1f%%", maxCountDigitLen, b.Count, float64(b.Count)*percentMulti, float64(accCount)*percentMulti) @@ -188,6 +184,9 @@ func (h *Histogram) Add(value int64) error { func (h *Histogram) findBucket(value int64) (int, error) { delta := float64(value - h.opts.MinValue) + if delta < 0 { + return 0, fmt.Errorf("no bucket for value: %d", value) + } var b int if delta >= h.opts.BaseBucketSize { // b = log_{1+growthFactor} (delta / baseBucketSize) + 1 diff --git a/call_test.go b/call_test.go index abc4537ddb7..8fdbc9b7eb7 100644 --- a/call_test.go +++ b/call_test.go @@ -160,7 +160,7 @@ func (s *server) start(t *testing.T, port int, maxStreams uint32) { config := &transport.ServerConfig{ MaxStreams: maxStreams, } - st, err := transport.NewServerTransport("http2", conn, config) + st, err := transport.NewServerTransport(conn, config) if err != nil { continue } diff --git a/channelz/grpc_channelz_v1/channelz_grpc.pb.go b/channelz/grpc_channelz_v1/channelz_grpc.pb.go index 051d1ac440c..ee425c21994 100644 --- a/channelz/grpc_channelz_v1/channelz_grpc.pb.go +++ b/channelz/grpc_channelz_v1/channelz_grpc.pb.go @@ -1,4 +1,8 @@ // Code generated by protoc-gen-go-grpc. DO NOT EDIT. +// versions: +// - protoc-gen-go-grpc v1.1.0 +// - protoc v3.14.0 +// source: grpc/channelz/v1/channelz.proto package grpc_channelz_v1 diff --git a/channelz/service/func_linux.go b/channelz/service/func_linux.go index ce38a921b97..2e52d5f5a98 100644 --- a/channelz/service/func_linux.go +++ b/channelz/service/func_linux.go @@ -25,6 +25,7 @@ import ( durpb "github.com/golang/protobuf/ptypes/duration" channelzpb "google.golang.org/grpc/channelz/grpc_channelz_v1" "google.golang.org/grpc/internal/channelz" + "google.golang.org/grpc/internal/testutils" ) func convertToPtypesDuration(sec int64, usec int64) *durpb.Duration { @@ -34,41 +35,32 @@ func convertToPtypesDuration(sec int64, usec int64) *durpb.Duration { func sockoptToProto(skopts *channelz.SocketOptionData) []*channelzpb.SocketOption { var opts []*channelzpb.SocketOption if skopts.Linger != nil { - additional, err := ptypes.MarshalAny(&channelzpb.SocketOptionLinger{ - Active: skopts.Linger.Onoff != 0, - Duration: convertToPtypesDuration(int64(skopts.Linger.Linger), 0), + opts = append(opts, &channelzpb.SocketOption{ + Name: "SO_LINGER", + Additional: testutils.MarshalAny(&channelzpb.SocketOptionLinger{ + Active: skopts.Linger.Onoff != 0, + Duration: convertToPtypesDuration(int64(skopts.Linger.Linger), 0), + }), }) - if err == nil { - opts = append(opts, &channelzpb.SocketOption{ - Name: "SO_LINGER", - Additional: additional, - }) - } } if skopts.RecvTimeout != nil { - additional, err := ptypes.MarshalAny(&channelzpb.SocketOptionTimeout{ - Duration: convertToPtypesDuration(int64(skopts.RecvTimeout.Sec), int64(skopts.RecvTimeout.Usec)), + opts = append(opts, &channelzpb.SocketOption{ + Name: "SO_RCVTIMEO", + Additional: testutils.MarshalAny(&channelzpb.SocketOptionTimeout{ + Duration: convertToPtypesDuration(int64(skopts.RecvTimeout.Sec), int64(skopts.RecvTimeout.Usec)), + }), }) - if err == nil { - opts = append(opts, &channelzpb.SocketOption{ - Name: "SO_RCVTIMEO", - Additional: additional, - }) - } } if skopts.SendTimeout != nil { - additional, err := ptypes.MarshalAny(&channelzpb.SocketOptionTimeout{ - Duration: convertToPtypesDuration(int64(skopts.SendTimeout.Sec), int64(skopts.SendTimeout.Usec)), + opts = append(opts, &channelzpb.SocketOption{ + Name: "SO_SNDTIMEO", + Additional: testutils.MarshalAny(&channelzpb.SocketOptionTimeout{ + Duration: convertToPtypesDuration(int64(skopts.SendTimeout.Sec), int64(skopts.SendTimeout.Usec)), + }), }) - if err == nil { - opts = append(opts, &channelzpb.SocketOption{ - Name: "SO_SNDTIMEO", - Additional: additional, - }) - } } if skopts.TCPInfo != nil { - additional, err := ptypes.MarshalAny(&channelzpb.SocketOptionTcpInfo{ + additional := testutils.MarshalAny(&channelzpb.SocketOptionTcpInfo{ TcpiState: uint32(skopts.TCPInfo.State), TcpiCaState: uint32(skopts.TCPInfo.Ca_state), TcpiRetransmits: uint32(skopts.TCPInfo.Retransmits), @@ -99,12 +91,10 @@ func sockoptToProto(skopts *channelz.SocketOptionData) []*channelzpb.SocketOptio TcpiAdvmss: skopts.TCPInfo.Advmss, TcpiReordering: skopts.TCPInfo.Reordering, }) - if err == nil { - opts = append(opts, &channelzpb.SocketOption{ - Name: "TCP_INFO", - Additional: additional, - }) - } + opts = append(opts, &channelzpb.SocketOption{ + Name: "TCP_INFO", + Additional: additional, + }) } return opts } diff --git a/channelz/service/func_nonlinux.go b/channelz/service/func_nonlinux.go index eb53334ed0d..473495d6655 100644 --- a/channelz/service/func_nonlinux.go +++ b/channelz/service/func_nonlinux.go @@ -1,4 +1,5 @@ -// +build !linux appengine +//go:build !linux +// +build !linux /* * diff --git a/channelz/service/service.go b/channelz/service/service.go index 4d175fef823..9e325376f6c 100644 --- a/channelz/service/service.go +++ b/channelz/service/service.go @@ -43,7 +43,11 @@ func init() { var logger = grpclog.Component("channelz") // RegisterChannelzServiceToServer registers the channelz service to the given server. -func RegisterChannelzServiceToServer(s *grpc.Server) { +// +// Note: it is preferred to use the admin API +// (https://pkg.go.dev/google.golang.org/grpc/admin#Register) instead to +// register Channelz and other administrative services. +func RegisterChannelzServiceToServer(s grpc.ServiceRegistrar) { channelzgrpc.RegisterChannelzServer(s, newCZServer()) } @@ -78,7 +82,7 @@ func channelTraceToProto(ct *channelz.ChannelTrace) *channelzpb.ChannelTrace { if ts, err := ptypes.TimestampProto(ct.CreationTime); err == nil { pbt.CreationTimestamp = ts } - var events []*channelzpb.ChannelTraceEvent + events := make([]*channelzpb.ChannelTraceEvent, 0, len(ct.Events)) for _, e := range ct.Events { cte := &channelzpb.ChannelTraceEvent{ Description: e.Desc, diff --git a/channelz/service/service_sktopt_test.go b/channelz/service/service_sktopt_test.go index ecd4a2ad05f..4ea6b20cd6a 100644 --- a/channelz/service/service_sktopt_test.go +++ b/channelz/service/service_sktopt_test.go @@ -1,3 +1,4 @@ +//go:build linux && (386 || amd64) // +build linux // +build 386 amd64 diff --git a/channelz/service/util_sktopt_386_test.go b/channelz/service/util_sktopt_386_test.go index d9c98127136..3ba3dc96e7c 100644 --- a/channelz/service/util_sktopt_386_test.go +++ b/channelz/service/util_sktopt_386_test.go @@ -1,3 +1,4 @@ +//go:build 386 && linux // +build 386,linux /* diff --git a/channelz/service/util_sktopt_amd64_test.go b/channelz/service/util_sktopt_amd64_test.go index 0ff06d12833..124d7b75819 100644 --- a/channelz/service/util_sktopt_amd64_test.go +++ b/channelz/service/util_sktopt_amd64_test.go @@ -1,3 +1,4 @@ +//go:build amd64 && linux // +build amd64,linux /* diff --git a/clientconn.go b/clientconn.go index 77a08fd33bf..34cc4c948db 100644 --- a/clientconn.go +++ b/clientconn.go @@ -143,6 +143,7 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn * firstResolveEvent: grpcsync.NewEvent(), } cc.retryThrottler.Store((*retryThrottler)(nil)) + cc.safeConfigSelector.UpdateConfigSelector(&defaultConfigSelector{nil}) cc.ctx, cc.cancel = context.WithCancel(context.Background()) for _, opt := range opts { @@ -321,6 +322,7 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn * // A blocking dial blocks until the clientConn is ready. if cc.dopts.block { for { + cc.Connect() s := cc.GetState() if s == connectivity.Ready { break @@ -538,12 +540,31 @@ func (cc *ClientConn) WaitForStateChange(ctx context.Context, sourceState connec // // Experimental // -// Notice: This API is EXPERIMENTAL and may be changed or removed in a -// later release. +// Notice: This API is EXPERIMENTAL and may be changed or removed in a later +// release. func (cc *ClientConn) GetState() connectivity.State { return cc.csMgr.getState() } +// Connect causes all subchannels in the ClientConn to attempt to connect if +// the channel is idle. Does not wait for the connection attempts to begin +// before returning. +// +// Experimental +// +// Notice: This API is EXPERIMENTAL and may be changed or removed in a later +// release. +func (cc *ClientConn) Connect() { + cc.mu.Lock() + defer cc.mu.Unlock() + if cc.balancerWrapper != nil && cc.balancerWrapper.exitIdle() { + return + } + for ac := range cc.conns { + go ac.connect() + } +} + func (cc *ClientConn) scWatcher() { for { select { @@ -710,7 +731,12 @@ func (cc *ClientConn) switchBalancer(name string) { return } if cc.balancerWrapper != nil { + // Don't hold cc.mu while closing the balancers. The balancers may call + // methods that require cc.mu (e.g. cc.NewSubConn()). Holding the mutex + // would cause a deadlock in that case. + cc.mu.Unlock() cc.balancerWrapper.close() + cc.mu.Lock() } builder := balancer.Get(name) @@ -839,8 +865,7 @@ func (ac *addrConn) connect() error { ac.updateConnectivityState(connectivity.Connecting, nil) ac.mu.Unlock() - // Start a goroutine connecting to the server asynchronously. - go ac.resetTransport() + ac.resetTransport() return nil } @@ -877,6 +902,10 @@ func (ac *addrConn) tryUpdateAddrs(addrs []resolver.Address) bool { // ac.state is Ready, try to find the connected address. var curAddrFound bool for _, a := range addrs { + // a.ServerName takes precedent over ClientConn authority, if present. + if a.ServerName == "" { + a.ServerName = ac.cc.authority + } if reflect.DeepEqual(ac.curAddr, a) { curAddrFound = true break @@ -1045,12 +1074,12 @@ func (cc *ClientConn) Close() error { cc.blockingpicker.close() - if rWrapper != nil { - rWrapper.close() - } if bWrapper != nil { bWrapper.close() } + if rWrapper != nil { + rWrapper.close() + } for ac := range conns { ac.tearDown(ErrClientConnClosing) @@ -1129,112 +1158,86 @@ func (ac *addrConn) adjustParams(r transport.GoAwayReason) { } func (ac *addrConn) resetTransport() { - for i := 0; ; i++ { - if i > 0 { - ac.cc.resolveNow(resolver.ResolveNowOptions{}) - } + ac.mu.Lock() + if ac.state == connectivity.Shutdown { + ac.mu.Unlock() + return + } + addrs := ac.addrs + backoffFor := ac.dopts.bs.Backoff(ac.backoffIdx) + // This will be the duration that dial gets to finish. + dialDuration := minConnectTimeout + if ac.dopts.minConnectTimeout != nil { + dialDuration = ac.dopts.minConnectTimeout() + } + + if dialDuration < backoffFor { + // Give dial more time as we keep failing to connect. + dialDuration = backoffFor + } + // We can potentially spend all the time trying the first address, and + // if the server accepts the connection and then hangs, the following + // addresses will never be tried. + // + // The spec doesn't mention what should be done for multiple addresses. + // https://github.com/grpc/grpc/blob/master/doc/connection-backoff.md#proposed-backoff-algorithm + connectDeadline := time.Now().Add(dialDuration) + + ac.updateConnectivityState(connectivity.Connecting, nil) + ac.mu.Unlock() + + if err := ac.tryAllAddrs(addrs, connectDeadline); err != nil { + ac.cc.resolveNow(resolver.ResolveNowOptions{}) + // After exhausting all addresses, the addrConn enters + // TRANSIENT_FAILURE. ac.mu.Lock() if ac.state == connectivity.Shutdown { ac.mu.Unlock() return } + ac.updateConnectivityState(connectivity.TransientFailure, err) - addrs := ac.addrs - backoffFor := ac.dopts.bs.Backoff(ac.backoffIdx) - // This will be the duration that dial gets to finish. - dialDuration := minConnectTimeout - if ac.dopts.minConnectTimeout != nil { - dialDuration = ac.dopts.minConnectTimeout() - } - - if dialDuration < backoffFor { - // Give dial more time as we keep failing to connect. - dialDuration = backoffFor - } - // We can potentially spend all the time trying the first address, and - // if the server accepts the connection and then hangs, the following - // addresses will never be tried. - // - // The spec doesn't mention what should be done for multiple addresses. - // https://github.com/grpc/grpc/blob/master/doc/connection-backoff.md#proposed-backoff-algorithm - connectDeadline := time.Now().Add(dialDuration) - - ac.updateConnectivityState(connectivity.Connecting, nil) - ac.transport = nil + // Backoff. + b := ac.resetBackoff ac.mu.Unlock() - newTr, addr, reconnect, err := ac.tryAllAddrs(addrs, connectDeadline) - if err != nil { - // After exhausting all addresses, the addrConn enters - // TRANSIENT_FAILURE. + timer := time.NewTimer(backoffFor) + select { + case <-timer.C: ac.mu.Lock() - if ac.state == connectivity.Shutdown { - ac.mu.Unlock() - return - } - ac.updateConnectivityState(connectivity.TransientFailure, err) - - // Backoff. - b := ac.resetBackoff + ac.backoffIdx++ ac.mu.Unlock() - - timer := time.NewTimer(backoffFor) - select { - case <-timer.C: - ac.mu.Lock() - ac.backoffIdx++ - ac.mu.Unlock() - case <-b: - timer.Stop() - case <-ac.ctx.Done(): - timer.Stop() - return - } - continue + case <-b: + timer.Stop() + case <-ac.ctx.Done(): + timer.Stop() + return } ac.mu.Lock() - if ac.state == connectivity.Shutdown { - ac.mu.Unlock() - newTr.Close() - return + if ac.state != connectivity.Shutdown { + ac.updateConnectivityState(connectivity.Idle, err) } - ac.curAddr = addr - ac.transport = newTr - ac.backoffIdx = 0 - - hctx, hcancel := context.WithCancel(ac.ctx) - ac.startHealthCheck(hctx) ac.mu.Unlock() - - // Block until the created transport is down. And when this happens, - // we restart from the top of the addr list. - <-reconnect.Done() - hcancel() - // restart connecting - the top of the loop will set state to - // CONNECTING. This is against the current connectivity semantics doc, - // however it allows for graceful behavior for RPCs not yet dispatched - // - unfortunate timing would otherwise lead to the RPC failing even - // though the TRANSIENT_FAILURE state (called for by the doc) would be - // instantaneous. - // - // Ideally we should transition to Idle here and block until there is - // RPC activity that leads to the balancer requesting a reconnect of - // the associated SubConn. + return } + // Success; reset backoff. + ac.mu.Lock() + ac.backoffIdx = 0 + ac.mu.Unlock() } -// tryAllAddrs tries to creates a connection to the addresses, and stop when at the -// first successful one. It returns the transport, the address and a Event in -// the successful case. The Event fires when the returned transport disconnects. -func (ac *addrConn) tryAllAddrs(addrs []resolver.Address, connectDeadline time.Time) (transport.ClientTransport, resolver.Address, *grpcsync.Event, error) { +// tryAllAddrs tries to creates a connection to the addresses, and stop when at +// the first successful one. It returns an error if no address was successfully +// connected, or updates ac appropriately with the new transport. +func (ac *addrConn) tryAllAddrs(addrs []resolver.Address, connectDeadline time.Time) error { var firstConnErr error for _, addr := range addrs { ac.mu.Lock() if ac.state == connectivity.Shutdown { ac.mu.Unlock() - return nil, resolver.Address{}, nil, errConnClosing + return errConnClosing } ac.cc.mu.RLock() @@ -1249,9 +1252,9 @@ func (ac *addrConn) tryAllAddrs(addrs []resolver.Address, connectDeadline time.T channelz.Infof(logger, ac.channelzID, "Subchannel picks a new address %q to connect", addr.Addr) - newTr, reconnect, err := ac.createTransport(addr, copts, connectDeadline) + err := ac.createTransport(addr, copts, connectDeadline) if err == nil { - return newTr, addr, reconnect, nil + return nil } if firstConnErr == nil { firstConnErr = err @@ -1260,57 +1263,54 @@ func (ac *addrConn) tryAllAddrs(addrs []resolver.Address, connectDeadline time.T } // Couldn't connect to any address. - return nil, resolver.Address{}, nil, firstConnErr + return firstConnErr } -// createTransport creates a connection to addr. It returns the transport and a -// Event in the successful case. The Event fires when the returned transport -// disconnects. -func (ac *addrConn) createTransport(addr resolver.Address, copts transport.ConnectOptions, connectDeadline time.Time) (transport.ClientTransport, *grpcsync.Event, error) { - prefaceReceived := make(chan struct{}) - onCloseCalled := make(chan struct{}) - reconnect := grpcsync.NewEvent() +// createTransport creates a connection to addr. It returns an error if the +// address was not successfully connected, or updates ac appropriately with the +// new transport. +func (ac *addrConn) createTransport(addr resolver.Address, copts transport.ConnectOptions, connectDeadline time.Time) error { + // TODO: Delete prefaceReceived and move the logic to wait for it into the + // transport. + prefaceReceived := grpcsync.NewEvent() + connClosed := grpcsync.NewEvent() // addr.ServerName takes precedent over ClientConn authority, if present. if addr.ServerName == "" { addr.ServerName = ac.cc.authority } - once := sync.Once{} - onGoAway := func(r transport.GoAwayReason) { - ac.mu.Lock() - ac.adjustParams(r) - once.Do(func() { - if ac.state == connectivity.Ready { - // Prevent this SubConn from being used for new RPCs by setting its - // state to Connecting. - // - // TODO: this should be Idle when grpc-go properly supports it. - ac.updateConnectivityState(connectivity.Connecting, nil) - } - }) - ac.mu.Unlock() - reconnect.Fire() - } + hctx, hcancel := context.WithCancel(ac.ctx) + hcStarted := false // protected by ac.mu onClose := func() { ac.mu.Lock() - once.Do(func() { - if ac.state == connectivity.Ready { - // Prevent this SubConn from being used for new RPCs by setting its - // state to Connecting. - // - // TODO: this should be Idle when grpc-go properly supports it. - ac.updateConnectivityState(connectivity.Connecting, nil) - } - }) - ac.mu.Unlock() - close(onCloseCalled) - reconnect.Fire() + defer ac.mu.Unlock() + defer connClosed.Fire() + if !hcStarted || hctx.Err() != nil { + // We didn't start the health check or set the state to READY, so + // no need to do anything else here. + // + // OR, we have already cancelled the health check context, meaning + // we have already called onClose once for this transport. In this + // case it would be dangerous to clear the transport and update the + // state, since there may be a new transport in this addrConn. + return + } + hcancel() + ac.transport = nil + // Refresh the name resolver + ac.cc.resolveNow(resolver.ResolveNowOptions{}) + if ac.state != connectivity.Shutdown { + ac.updateConnectivityState(connectivity.Idle, nil) + } } - onPrefaceReceipt := func() { - close(prefaceReceived) + onGoAway := func(r transport.GoAwayReason) { + ac.mu.Lock() + ac.adjustParams(r) + ac.mu.Unlock() + onClose() } connectCtx, cancel := context.WithDeadline(ac.ctx, connectDeadline) @@ -1319,27 +1319,67 @@ func (ac *addrConn) createTransport(addr resolver.Address, copts transport.Conne copts.ChannelzParentID = ac.channelzID } - newTr, err := transport.NewClientTransport(connectCtx, ac.cc.ctx, addr, copts, onPrefaceReceipt, onGoAway, onClose) + newTr, err := transport.NewClientTransport(connectCtx, ac.cc.ctx, addr, copts, func() { prefaceReceived.Fire() }, onGoAway, onClose) if err != nil { // newTr is either nil, or closed. - channelz.Warningf(logger, ac.channelzID, "grpc: addrConn.createTransport failed to connect to %v. Err: %v. Reconnecting...", addr, err) - return nil, nil, err + channelz.Warningf(logger, ac.channelzID, "grpc: addrConn.createTransport failed to connect to %v. Err: %v", addr, err) + return err } select { - case <-time.After(time.Until(connectDeadline)): + case <-connectCtx.Done(): // We didn't get the preface in time. - newTr.Close() - channelz.Warningf(logger, ac.channelzID, "grpc: addrConn.createTransport failed to connect to %v: didn't receive server preface in time. Reconnecting...", addr) - return nil, nil, errors.New("timed out waiting for server handshake") - case <-prefaceReceived: + // The error we pass to Close() is immaterial since there are no open + // streams at this point, so no trailers with error details will be sent + // out. We just need to pass a non-nil error. + newTr.Close(transport.ErrConnClosing) + if connectCtx.Err() == context.DeadlineExceeded { + err := errors.New("failed to receive server preface within timeout") + channelz.Warningf(logger, ac.channelzID, "grpc: addrConn.createTransport failed to connect to %v: %v", addr, err) + return err + } + return nil + case <-prefaceReceived.Done(): // We got the preface - huzzah! things are good. - case <-onCloseCalled: - // The transport has already closed - noop. - return nil, nil, errors.New("connection closed") - // TODO(deklerk) this should bail on ac.ctx.Done(). Add a test and fix. + ac.mu.Lock() + defer ac.mu.Unlock() + if connClosed.HasFired() { + // onClose called first; go idle but do nothing else. + if ac.state != connectivity.Shutdown { + ac.updateConnectivityState(connectivity.Idle, nil) + } + return nil + } + if ac.state == connectivity.Shutdown { + // This can happen if the subConn was removed while in `Connecting` + // state. tearDown() would have set the state to `Shutdown`, but + // would not have closed the transport since ac.transport would not + // been set at that point. + // + // We run this in a goroutine because newTr.Close() calls onClose() + // inline, which requires locking ac.mu. + // + // The error we pass to Close() is immaterial since there are no open + // streams at this point, so no trailers with error details will be sent + // out. We just need to pass a non-nil error. + go newTr.Close(transport.ErrConnClosing) + return nil + } + ac.curAddr = addr + ac.transport = newTr + hcStarted = true + ac.startHealthCheck(hctx) // Will set state to READY if appropriate. + return nil + case <-connClosed.Done(): + // The transport has already closed. If we received the preface, too, + // this is not an error. + select { + case <-prefaceReceived.Done(): + return nil + default: + return errors.New("connection closed before server preface received") + } } - return newTr, reconnect, nil } // startHealthCheck starts the health checking stream (RPC) to watch the health @@ -1423,33 +1463,20 @@ func (ac *addrConn) resetConnectBackoff() { ac.mu.Unlock() } -// getReadyTransport returns the transport if ac's state is READY. -// Otherwise it returns nil, false. -// If ac's state is IDLE, it will trigger ac to connect. -func (ac *addrConn) getReadyTransport() (transport.ClientTransport, bool) { +// getReadyTransport returns the transport if ac's state is READY or nil if not. +func (ac *addrConn) getReadyTransport() transport.ClientTransport { ac.mu.Lock() - if ac.state == connectivity.Ready && ac.transport != nil { - t := ac.transport - ac.mu.Unlock() - return t, true - } - var idle bool - if ac.state == connectivity.Idle { - idle = true - } - ac.mu.Unlock() - // Trigger idle ac to connect. - if idle { - ac.connect() + defer ac.mu.Unlock() + if ac.state == connectivity.Ready { + return ac.transport } - return nil, false + return nil } // tearDown starts to tear down the addrConn. -// TODO(zhaoq): Make this synchronous to avoid unbounded memory consumption in -// some edge cases (e.g., the caller opens and closes many addrConn's in a -// tight loop. -// tearDown doesn't remove ac from ac.cc.conns. +// +// Note that tearDown doesn't remove ac from ac.cc.conns, so the addrConn struct +// will leak. In most cases, call cc.removeAddrConn() instead. func (ac *addrConn) tearDown(err error) { ac.mu.Lock() if ac.state == connectivity.Shutdown { diff --git a/clientconn_authority_test.go b/clientconn_authority_test.go new file mode 100644 index 00000000000..5cd705e2d4f --- /dev/null +++ b/clientconn_authority_test.go @@ -0,0 +1,122 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package grpc + +import ( + "context" + "net" + "testing" + + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/testdata" +) + +func (s) TestClientConnAuthority(t *testing.T) { + serverNameOverride := "over.write.server.name" + creds, err := credentials.NewClientTLSFromFile(testdata.Path("x509/server_ca_cert.pem"), serverNameOverride) + if err != nil { + t.Fatalf("credentials.NewClientTLSFromFile(_, %q) failed: %v", err, serverNameOverride) + } + + tests := []struct { + name string + target string + opts []DialOption + wantAuthority string + }{ + { + name: "default", + target: "Non-Existent.Server:8080", + opts: []DialOption{WithInsecure()}, + wantAuthority: "Non-Existent.Server:8080", + }, + { + name: "override-via-creds", + target: "Non-Existent.Server:8080", + opts: []DialOption{WithTransportCredentials(creds)}, + wantAuthority: serverNameOverride, + }, + { + name: "override-via-WithAuthority", + target: "Non-Existent.Server:8080", + opts: []DialOption{WithInsecure(), WithAuthority("authority-override")}, + wantAuthority: "authority-override", + }, + { + name: "override-via-creds-and-WithAuthority", + target: "Non-Existent.Server:8080", + // WithAuthority override works only for insecure creds. + opts: []DialOption{WithTransportCredentials(creds), WithAuthority("authority-override")}, + wantAuthority: serverNameOverride, + }, + { + name: "unix relative", + target: "unix:sock.sock", + opts: []DialOption{WithInsecure()}, + wantAuthority: "localhost", + }, + { + name: "unix relative with custom dialer", + target: "unix:sock.sock", + opts: []DialOption{WithInsecure(), WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) { + return (&net.Dialer{}).DialContext(ctx, "", addr) + })}, + wantAuthority: "localhost", + }, + { + name: "unix absolute", + target: "unix:/sock.sock", + opts: []DialOption{WithInsecure()}, + wantAuthority: "localhost", + }, + { + name: "unix absolute with custom dialer", + target: "unix:///sock.sock", + opts: []DialOption{WithInsecure(), WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) { + return (&net.Dialer{}).DialContext(ctx, "", addr) + })}, + wantAuthority: "localhost", + }, + { + name: "localhost colon port", + target: "localhost:50051", + opts: []DialOption{WithInsecure()}, + wantAuthority: "localhost:50051", + }, + { + name: "colon port", + target: ":50051", + opts: []DialOption{WithInsecure()}, + wantAuthority: "localhost:50051", + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + cc, err := Dial(test.target, test.opts...) + if err != nil { + t.Fatalf("Dial(%q) failed: %v", test.target, err) + } + defer cc.Close() + if cc.authority != test.wantAuthority { + t.Fatalf("cc.authority = %q, want %q", cc.authority, test.wantAuthority) + } + }) + } +} diff --git a/clientconn_parsed_target_test.go b/clientconn_parsed_target_test.go new file mode 100644 index 00000000000..fda06f9fa14 --- /dev/null +++ b/clientconn_parsed_target_test.go @@ -0,0 +1,183 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package grpc + +import ( + "context" + "errors" + "net" + "testing" + "time" + + "google.golang.org/grpc/resolver" +) + +func (s) TestParsedTarget_Success_WithoutCustomDialer(t *testing.T) { + defScheme := resolver.GetDefaultScheme() + tests := []struct { + target string + wantParsed resolver.Target + }{ + // No scheme is specified. + {target: "", wantParsed: resolver.Target{Scheme: defScheme, Authority: "", Endpoint: ""}}, + {target: "://", wantParsed: resolver.Target{Scheme: defScheme, Authority: "", Endpoint: "://"}}, + {target: ":///", wantParsed: resolver.Target{Scheme: defScheme, Authority: "", Endpoint: ":///"}}, + {target: "://a/", wantParsed: resolver.Target{Scheme: defScheme, Authority: "", Endpoint: "://a/"}}, + {target: ":///a", wantParsed: resolver.Target{Scheme: defScheme, Authority: "", Endpoint: ":///a"}}, + {target: "://a/b", wantParsed: resolver.Target{Scheme: defScheme, Authority: "", Endpoint: "://a/b"}}, + {target: "/", wantParsed: resolver.Target{Scheme: defScheme, Authority: "", Endpoint: "/"}}, + {target: "a/b", wantParsed: resolver.Target{Scheme: defScheme, Authority: "", Endpoint: "a/b"}}, + {target: "a//b", wantParsed: resolver.Target{Scheme: defScheme, Authority: "", Endpoint: "a//b"}}, + {target: "google.com", wantParsed: resolver.Target{Scheme: defScheme, Authority: "", Endpoint: "google.com"}}, + {target: "google.com/?a=b", wantParsed: resolver.Target{Scheme: defScheme, Authority: "", Endpoint: "google.com/?a=b"}}, + {target: "/unix/socket/address", wantParsed: resolver.Target{Scheme: defScheme, Authority: "", Endpoint: "/unix/socket/address"}}, + + // An unregistered scheme is specified. + {target: "a:///", wantParsed: resolver.Target{Scheme: defScheme, Authority: "", Endpoint: "a:///"}}, + {target: "a://b/", wantParsed: resolver.Target{Scheme: defScheme, Authority: "", Endpoint: "a://b/"}}, + {target: "a:///b", wantParsed: resolver.Target{Scheme: defScheme, Authority: "", Endpoint: "a:///b"}}, + {target: "a://b/c", wantParsed: resolver.Target{Scheme: defScheme, Authority: "", Endpoint: "a://b/c"}}, + {target: "a:b", wantParsed: resolver.Target{Scheme: defScheme, Authority: "", Endpoint: "a:b"}}, + {target: "a:/b", wantParsed: resolver.Target{Scheme: defScheme, Authority: "", Endpoint: "a:/b"}}, + {target: "a://b", wantParsed: resolver.Target{Scheme: defScheme, Authority: "", Endpoint: "a://b"}}, + + // A registered scheme is specified. + {target: "dns:///google.com", wantParsed: resolver.Target{Scheme: "dns", Authority: "", Endpoint: "google.com"}}, + {target: "dns://a.server.com/google.com", wantParsed: resolver.Target{Scheme: "dns", Authority: "a.server.com", Endpoint: "google.com"}}, + {target: "dns://a.server.com/google.com/?a=b", wantParsed: resolver.Target{Scheme: "dns", Authority: "a.server.com", Endpoint: "google.com/?a=b"}}, + {target: "unix:///a/b/c", wantParsed: resolver.Target{Scheme: "unix", Authority: "", Endpoint: "/a/b/c"}}, + {target: "unix-abstract:a/b/c", wantParsed: resolver.Target{Scheme: "unix-abstract", Authority: "", Endpoint: "a/b/c"}}, + {target: "unix-abstract:a b", wantParsed: resolver.Target{Scheme: "unix-abstract", Authority: "", Endpoint: "a b"}}, + {target: "unix-abstract:a:b", wantParsed: resolver.Target{Scheme: "unix-abstract", Authority: "", Endpoint: "a:b"}}, + {target: "unix-abstract:a-b", wantParsed: resolver.Target{Scheme: "unix-abstract", Authority: "", Endpoint: "a-b"}}, + {target: "unix-abstract:/ a///://::!@#$%^&*()b", wantParsed: resolver.Target{Scheme: "unix-abstract", Authority: "", Endpoint: "/ a///://::!@#$%^&*()b"}}, + {target: "unix-abstract:passthrough:abc", wantParsed: resolver.Target{Scheme: "unix-abstract", Authority: "", Endpoint: "passthrough:abc"}}, + {target: "unix-abstract:unix:///abc", wantParsed: resolver.Target{Scheme: "unix-abstract", Authority: "", Endpoint: "unix:///abc"}}, + {target: "unix-abstract:///a/b/c", wantParsed: resolver.Target{Scheme: "unix-abstract", Authority: "", Endpoint: "/a/b/c"}}, + {target: "unix-abstract:///", wantParsed: resolver.Target{Scheme: "unix-abstract", Authority: "", Endpoint: "/"}}, + {target: "unix-abstract://authority", wantParsed: resolver.Target{Scheme: "unix-abstract", Authority: "", Endpoint: "//authority"}}, + {target: "unix://domain", wantParsed: resolver.Target{Scheme: defScheme, Authority: "", Endpoint: "unix://domain"}}, + {target: "passthrough:///unix:///a/b/c", wantParsed: resolver.Target{Scheme: "passthrough", Authority: "", Endpoint: "unix:///a/b/c"}}, + } + + for _, test := range tests { + t.Run(test.target, func(t *testing.T) { + cc, err := Dial(test.target, WithInsecure()) + if err != nil { + t.Fatalf("Dial(%q) failed: %v", test.target, err) + } + defer cc.Close() + + if gotParsed := cc.parsedTarget; gotParsed != test.wantParsed { + t.Errorf("cc.parsedTarget = %+v, want %+v", gotParsed, test.wantParsed) + } + }) + } +} + +func (s) TestParsedTarget_Failure_WithoutCustomDialer(t *testing.T) { + targets := []string{ + "unix://a/b/c", + "unix-abstract://authority/a/b/c", + } + + for _, target := range targets { + t.Run(target, func(t *testing.T) { + if cc, err := Dial(target, WithInsecure()); err == nil { + defer cc.Close() + t.Fatalf("Dial(%q) succeeded cc.parsedTarget = %+v, expected to fail", target, cc.parsedTarget) + } + }) + } +} + +func (s) TestParsedTarget_WithCustomDialer(t *testing.T) { + defScheme := resolver.GetDefaultScheme() + tests := []struct { + target string + wantParsed resolver.Target + wantDialerAddress string + }{ + // unix:[local_path], unix:[/absolute], and unix://[/absolute] have + // different behaviors with a custom dialer. + { + target: "unix:a/b/c", + wantParsed: resolver.Target{Scheme: defScheme, Authority: "", Endpoint: "unix:a/b/c"}, + wantDialerAddress: "unix:a/b/c", + }, + { + target: "unix:/a/b/c", + wantParsed: resolver.Target{Scheme: defScheme, Authority: "", Endpoint: "unix:/a/b/c"}, + wantDialerAddress: "unix:/a/b/c", + }, + { + target: "unix:///a/b/c", + wantParsed: resolver.Target{Scheme: "unix", Authority: "", Endpoint: "/a/b/c"}, + wantDialerAddress: "unix:///a/b/c", + }, + { + target: "dns:///127.0.0.1:50051", + wantParsed: resolver.Target{Scheme: "dns", Authority: "", Endpoint: "127.0.0.1:50051"}, + wantDialerAddress: "127.0.0.1:50051", + }, + { + target: ":///127.0.0.1:50051", + wantParsed: resolver.Target{Scheme: defScheme, Authority: "", Endpoint: ":///127.0.0.1:50051"}, + wantDialerAddress: ":///127.0.0.1:50051", + }, + { + target: "dns://authority/127.0.0.1:50051", + wantParsed: resolver.Target{Scheme: "dns", Authority: "authority", Endpoint: "127.0.0.1:50051"}, + wantDialerAddress: "127.0.0.1:50051", + }, + { + target: "://authority/127.0.0.1:50051", + wantParsed: resolver.Target{Scheme: defScheme, Authority: "", Endpoint: "://authority/127.0.0.1:50051"}, + wantDialerAddress: "://authority/127.0.0.1:50051", + }, + } + + for _, test := range tests { + t.Run(test.target, func(t *testing.T) { + addrCh := make(chan string, 1) + dialer := func(ctx context.Context, address string) (net.Conn, error) { + addrCh <- address + return nil, errors.New("dialer error") + } + + cc, err := Dial(test.target, WithInsecure(), WithContextDialer(dialer)) + if err != nil { + t.Fatalf("Dial(%q) failed: %v", test.target, err) + } + defer cc.Close() + + select { + case addr := <-addrCh: + if addr != test.wantDialerAddress { + t.Fatalf("address in custom dialer is %q, want %q", addr, test.wantDialerAddress) + } + case <-time.After(time.Second): + t.Fatal("timeout when waiting for custom dialer to be invoked") + } + if gotParsed := cc.parsedTarget; gotParsed != test.wantParsed { + t.Errorf("cc.parsedTarget for dial target %q = %+v, want %+v", test.target, gotParsed, test.wantParsed) + } + }) + } +} diff --git a/clientconn_state_transition_test.go b/clientconn_state_transition_test.go index 0c58131a1c6..2090c8de689 100644 --- a/clientconn_state_transition_test.go +++ b/clientconn_state_transition_test.go @@ -75,7 +75,7 @@ func (s) TestStateTransitions_SingleAddress(t *testing.T) { }, }, { - desc: "When the connection is closed, the client enters TRANSIENT FAILURE.", + desc: "When the connection is closed before the preface is sent, the client enters TRANSIENT FAILURE.", want: []connectivity.State{ connectivity.Connecting, connectivity.TransientFailure, @@ -167,6 +167,7 @@ func testStateTransitionSingleAddress(t *testing.T, want []connectivity.State, s t.Fatal(err) } defer client.Close() + go stayConnected(client) stateNotifications := testBalancerBuilder.nextStateNotifier() @@ -193,11 +194,12 @@ func testStateTransitionSingleAddress(t *testing.T, want []connectivity.State, s } } -// When a READY connection is closed, the client enters CONNECTING. +// When a READY connection is closed, the client enters IDLE then CONNECTING. func (s) TestStateTransitions_ReadyToConnecting(t *testing.T) { want := []connectivity.State{ connectivity.Connecting, connectivity.Ready, + connectivity.Idle, connectivity.Connecting, } @@ -210,7 +212,8 @@ func (s) TestStateTransitions_ReadyToConnecting(t *testing.T) { } defer lis.Close() - sawReady := make(chan struct{}) + sawReady := make(chan struct{}, 1) + defer close(sawReady) // Launch the server. go func() { @@ -239,6 +242,7 @@ func (s) TestStateTransitions_ReadyToConnecting(t *testing.T) { t.Fatal(err) } defer client.Close() + go stayConnected(client) stateNotifications := testBalancerBuilder.nextStateNotifier() @@ -250,7 +254,7 @@ func (s) TestStateTransitions_ReadyToConnecting(t *testing.T) { t.Fatalf("timed out waiting for state %d (%v) in flow %v", i, want[i], want) case seen := <-stateNotifications: if seen == connectivity.Ready { - close(sawReady) + sawReady <- struct{}{} } if seen != want[i] { t.Fatalf("expected to see %v at position %d in flow %v, got %v", want[i], i, want, seen) @@ -358,6 +362,7 @@ func (s) TestStateTransitions_MultipleAddrsEntersReady(t *testing.T) { want := []connectivity.State{ connectivity.Connecting, connectivity.Ready, + connectivity.Idle, connectivity.Connecting, } @@ -378,7 +383,8 @@ func (s) TestStateTransitions_MultipleAddrsEntersReady(t *testing.T) { defer lis2.Close() server1Done := make(chan struct{}) - sawReady := make(chan struct{}) + sawReady := make(chan struct{}, 1) + defer close(sawReady) // Launch server 1. go func() { @@ -400,12 +406,6 @@ func (s) TestStateTransitions_MultipleAddrsEntersReady(t *testing.T) { conn.Close() - _, err = lis1.Accept() - if err != nil { - t.Error(err) - return - } - close(server1Done) }() @@ -419,6 +419,7 @@ func (s) TestStateTransitions_MultipleAddrsEntersReady(t *testing.T) { t.Fatal(err) } defer client.Close() + go stayConnected(client) stateNotifications := testBalancerBuilder.nextStateNotifier() @@ -430,7 +431,7 @@ func (s) TestStateTransitions_MultipleAddrsEntersReady(t *testing.T) { t.Fatalf("timed out waiting for state %d (%v) in flow %v", i, want[i], want) case seen := <-stateNotifications: if seen == connectivity.Ready { - close(sawReady) + sawReady <- struct{}{} } if seen != want[i] { t.Fatalf("expected to see %v at position %d in flow %v, got %v", want[i], i, want, seen) diff --git a/clientconn_test.go b/clientconn_test.go index 6c61666b7ef..d276c7b5f2f 100644 --- a/clientconn_test.go +++ b/clientconn_test.go @@ -217,7 +217,7 @@ func (s) TestDialWaitsForServerSettingsAndFails(t *testing.T) { client.Close() t.Fatalf("Unexpected success (err=nil) while dialing") } - expectedMsg := "server handshake" + expectedMsg := "server preface" if !strings.Contains(err.Error(), context.DeadlineExceeded.Error()) || !strings.Contains(err.Error(), expectedMsg) { t.Fatalf("DialContext(_) = %v; want a message that includes both %q and %q", err, context.DeadlineExceeded.Error(), expectedMsg) } @@ -289,6 +289,9 @@ func (s) TestCloseConnectionWhenServerPrefaceNotReceived(t *testing.T) { if err != nil { t.Fatalf("Error while dialing. Err: %v", err) } + + go stayConnected(client) + // wait for connection to be accepted on the server. timer := time.NewTimer(time.Second * 10) select { @@ -311,9 +314,7 @@ func (s) TestBackoffWhenNoServerPrefaceReceived(t *testing.T) { defer lis.Close() done := make(chan struct{}) go func() { // Launch the server. - defer func() { - close(done) - }() + defer close(done) conn, err := lis.Accept() // Accept the connection only to close it immediately. if err != nil { t.Errorf("Error while accepting. Err: %v", err) @@ -340,13 +341,13 @@ func (s) TestBackoffWhenNoServerPrefaceReceived(t *testing.T) { prevAt = meow } }() - client, err := Dial(lis.Addr().String(), WithInsecure()) + cc, err := Dial(lis.Addr().String(), WithInsecure()) if err != nil { t.Fatalf("Error while dialing. Err: %v", err) } - defer client.Close() + defer cc.Close() + go stayConnected(cc) <-done - } func (s) TestWithTimeout(t *testing.T) { @@ -375,62 +376,6 @@ func (s) TestWithTransportCredentialsTLS(t *testing.T) { } } -func (s) TestDefaultAuthority(t *testing.T) { - target := "Non-Existent.Server:8080" - conn, err := Dial(target, WithInsecure()) - if err != nil { - t.Fatalf("Dial(_, _) = _, %v, want _, ", err) - } - defer conn.Close() - if conn.authority != target { - t.Fatalf("%v.authority = %v, want %v", conn, conn.authority, target) - } -} - -func (s) TestTLSServerNameOverwrite(t *testing.T) { - overwriteServerName := "over.write.server.name" - creds, err := credentials.NewClientTLSFromFile(testdata.Path("x509/server_ca_cert.pem"), overwriteServerName) - if err != nil { - t.Fatalf("Failed to create credentials %v", err) - } - conn, err := Dial("passthrough:///Non-Existent.Server:80", WithTransportCredentials(creds)) - if err != nil { - t.Fatalf("Dial(_, _) = _, %v, want _, ", err) - } - defer conn.Close() - if conn.authority != overwriteServerName { - t.Fatalf("%v.authority = %v, want %v", conn, conn.authority, overwriteServerName) - } -} - -func (s) TestWithAuthority(t *testing.T) { - overwriteServerName := "over.write.server.name" - conn, err := Dial("passthrough:///Non-Existent.Server:80", WithInsecure(), WithAuthority(overwriteServerName)) - if err != nil { - t.Fatalf("Dial(_, _) = _, %v, want _, ", err) - } - defer conn.Close() - if conn.authority != overwriteServerName { - t.Fatalf("%v.authority = %v, want %v", conn, conn.authority, overwriteServerName) - } -} - -func (s) TestWithAuthorityAndTLS(t *testing.T) { - overwriteServerName := "over.write.server.name" - creds, err := credentials.NewClientTLSFromFile(testdata.Path("x509/server_ca_cert.pem"), overwriteServerName) - if err != nil { - t.Fatalf("Failed to create credentials %v", err) - } - conn, err := Dial("passthrough:///Non-Existent.Server:80", WithTransportCredentials(creds), WithAuthority("no.effect.authority")) - if err != nil { - t.Fatalf("Dial(_, _) = _, %v, want _, ", err) - } - defer conn.Close() - if conn.authority != overwriteServerName { - t.Fatalf("%v.authority = %v, want %v", conn, conn.authority, overwriteServerName) - } -} - // When creating a transport configured with n addresses, only calculate the // backoff once per "round" of attempts instead of once per address (n times // per "round" of attempts). @@ -735,16 +680,15 @@ func (s) TestClientUpdatesParamsAfterGoAway(t *testing.T) { time.Sleep(10 * time.Millisecond) cc.mu.RLock() v := cc.mkp.Time + cc.mu.RUnlock() if v == 20*time.Second { // Success - cc.mu.RUnlock() return } if ctx.Err() != nil { // Timeout t.Fatalf("cc.dopts.copts.Keepalive.Time = %v , want 20s", v) } - cc.mu.RUnlock() } } @@ -832,6 +776,7 @@ func (s) TestResetConnectBackoff(t *testing.T) { t.Fatalf("Dial() = _, %v; want _, nil", err) } defer cc.Close() + go stayConnected(cc) select { case <-dials: case <-time.NewTimer(10 * time.Second).C: @@ -986,6 +931,7 @@ func (s) TestUpdateAddresses_RetryFromFirstAddr(t *testing.T) { t.Fatal(err) } defer client.Close() + go stayConnected(client) timeout := time.After(5 * time.Second) @@ -1113,3 +1059,23 @@ func testDefaultServiceConfigWhenResolverReturnInvalidServiceConfig(t *testing.T t.Fatal("default service config failed to be applied after 1s") } } + +// stayConnected makes cc stay connected by repeatedly calling cc.Connect() +// until the state becomes Shutdown or until 10 seconds elapses. +func stayConnected(cc *ClientConn) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + for { + state := cc.GetState() + switch state { + case connectivity.Idle: + cc.Connect() + case connectivity.Shutdown: + return + } + if !cc.WaitForStateChange(ctx, state) { + return + } + } +} diff --git a/cmd/protoc-gen-go-grpc/grpc.go b/cmd/protoc-gen-go-grpc/grpc.go index 1e787344ebc..f45e0403fd4 100644 --- a/cmd/protoc-gen-go-grpc/grpc.go +++ b/cmd/protoc-gen-go-grpc/grpc.go @@ -24,7 +24,6 @@ import ( "strings" "google.golang.org/protobuf/compiler/protogen" - "google.golang.org/protobuf/types/descriptorpb" ) @@ -43,6 +42,14 @@ func generateFile(gen *protogen.Plugin, file *protogen.File) *protogen.Generated filename := file.GeneratedFilenamePrefix + "_grpc.pb.go" g := gen.NewGeneratedFile(filename, file.GoImportPath) g.P("// Code generated by protoc-gen-go-grpc. DO NOT EDIT.") + g.P("// versions:") + g.P("// - protoc-gen-go-grpc v", version) + g.P("// - protoc ", protocVersion(gen)) + if file.Proto.GetOptions().GetDeprecated() { + g.P("// ", file.Desc.Path(), " is a deprecated file.") + } else { + g.P("// source: ", file.Desc.Path()) + } g.P() g.P("package ", file.GoPackageName) g.P() @@ -50,6 +57,18 @@ func generateFile(gen *protogen.Plugin, file *protogen.File) *protogen.Generated return g } +func protocVersion(gen *protogen.Plugin) string { + v := gen.Request.GetCompilerVersion() + if v == nil { + return "(unknown)" + } + var suffix string + if s := v.GetSuffix(); s != "" { + suffix = "-" + s + } + return fmt.Sprintf("v%d.%d.%d%s", v.GetMajor(), v.GetMinor(), v.GetPatch(), suffix) +} + // generateFileContent generates the gRPC service definitions, excluding the package statement. func generateFileContent(gen *protogen.Plugin, file *protogen.File, g *protogen.GeneratedFile) { if len(file.Services) == 0 { @@ -188,7 +207,7 @@ func genService(gen *protogen.Plugin, file *protogen.File, g *protogen.Generated g.P() // Server handler implementations. - var handlerNames []string + handlerNames := make([]string, 0, len(service.Methods)) for _, method := range service.Methods { hname := genServerMethod(gen, file, g, method) handlerNames = append(handlerNames, hname) diff --git a/connectivity/connectivity.go b/connectivity/connectivity.go index 01015626150..4a89926422b 100644 --- a/connectivity/connectivity.go +++ b/connectivity/connectivity.go @@ -18,7 +18,6 @@ // Package connectivity defines connectivity semantics. // For details, see https://github.com/grpc/grpc/blob/master/doc/connectivity-semantics-and-api.md. -// All APIs in this package are experimental. package connectivity import ( @@ -45,7 +44,7 @@ func (s State) String() string { return "SHUTDOWN" default: logger.Errorf("unknown connectivity state: %d", s) - return "Invalid-State" + return "INVALID_STATE" } } @@ -61,3 +60,35 @@ const ( // Shutdown indicates the ClientConn has started shutting down. Shutdown ) + +// ServingMode indicates the current mode of operation of the server. +// +// Only xDS enabled gRPC servers currently report their serving mode. +type ServingMode int + +const ( + // ServingModeStarting indicates that the server is starting up. + ServingModeStarting ServingMode = iota + // ServingModeServing indicates that the server contains all required + // configuration and is serving RPCs. + ServingModeServing + // ServingModeNotServing indicates that the server is not accepting new + // connections. Existing connections will be closed gracefully, allowing + // in-progress RPCs to complete. A server enters this mode when it does not + // contain the required configuration to serve RPCs. + ServingModeNotServing +) + +func (s ServingMode) String() string { + switch s { + case ServingModeStarting: + return "STARTING" + case ServingModeServing: + return "SERVING" + case ServingModeNotServing: + return "NOT_SERVING" + default: + logger.Errorf("unknown serving mode: %d", s) + return "INVALID_MODE" + } +} diff --git a/credentials/alts/alts.go b/credentials/alts/alts.go index 729c4b43b5f..579adf210c4 100644 --- a/credentials/alts/alts.go +++ b/credentials/alts/alts.go @@ -37,6 +37,7 @@ import ( "google.golang.org/grpc/credentials/alts/internal/handshaker/service" altspb "google.golang.org/grpc/credentials/alts/internal/proto/grpc_gcp" "google.golang.org/grpc/grpclog" + "google.golang.org/grpc/internal/googlecloud" ) const ( @@ -54,6 +55,7 @@ const ( ) var ( + vmOnGCP bool once sync.Once maxRPCVersion = &altspb.RpcProtocolVersions_Version{ Major: protocolVersionMaxMajor, @@ -149,9 +151,8 @@ func NewServerCreds(opts *ServerOptions) credentials.TransportCredentials { func newALTS(side core.Side, accounts []string, hsAddress string) credentials.TransportCredentials { once.Do(func() { - vmOnGCP = isRunningOnGCP() + vmOnGCP = googlecloud.OnGCE() }) - if hsAddress == "" { hsAddress = hypervisorHandshakerServiceAddress } diff --git a/credentials/alts/alts_test.go b/credentials/alts/alts_test.go index cbb1656d20c..22ad5a48b09 100644 --- a/credentials/alts/alts_test.go +++ b/credentials/alts/alts_test.go @@ -1,3 +1,4 @@ +//go:build linux || windows // +build linux windows /* diff --git a/credentials/alts/internal/conn/record_test.go b/credentials/alts/internal/conn/record_test.go index 59d4f41e9e1..c18f902b401 100644 --- a/credentials/alts/internal/conn/record_test.go +++ b/credentials/alts/internal/conn/record_test.go @@ -40,11 +40,15 @@ func Test(t *testing.T) { grpctest.RunSubTests(t, s{}) } +const ( + rekeyRecordProtocol = "ALTSRP_GCM_AES128_REKEY" +) + var ( - nextProtocols = []string{"ALTSRP_GCM_AES128"} + recordProtocols = []string{rekeyRecordProtocol} altsRecordFuncs = map[string]ALTSRecordFunc{ // ALTS handshaker protocols. - "ALTSRP_GCM_AES128": func(s core.Side, keyData []byte) (ALTSRecordCrypto, error) { + rekeyRecordProtocol: func(s core.Side, keyData []byte) (ALTSRecordCrypto, error) { return NewAES128GCM(s, keyData) }, } @@ -77,7 +81,7 @@ func (c *testConn) Close() error { return nil } -func newTestALTSRecordConn(in, out *bytes.Buffer, side core.Side, np string, protected []byte) *conn { +func newTestALTSRecordConn(in, out *bytes.Buffer, side core.Side, rp string, protected []byte) *conn { key := []byte{ // 16 arbitrary bytes. 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x09, 0x6e, 0x88, 0x02, 0xff, 0xe2, 0xd2, 0x4c, 0xce, 0x4f, 0x49} @@ -85,23 +89,23 @@ func newTestALTSRecordConn(in, out *bytes.Buffer, side core.Side, np string, pro in: in, out: out, } - c, err := NewConn(&tc, side, np, key, protected) + c, err := NewConn(&tc, side, rp, key, protected) if err != nil { panic(fmt.Sprintf("Unexpected error creating test ALTS record connection: %v", err)) } return c.(*conn) } -func newConnPair(np string, clientProtected []byte, serverProtected []byte) (client, server *conn) { +func newConnPair(rp string, clientProtected []byte, serverProtected []byte) (client, server *conn) { clientBuf := new(bytes.Buffer) serverBuf := new(bytes.Buffer) - clientConn := newTestALTSRecordConn(clientBuf, serverBuf, core.ClientSide, np, clientProtected) - serverConn := newTestALTSRecordConn(serverBuf, clientBuf, core.ServerSide, np, serverProtected) + clientConn := newTestALTSRecordConn(clientBuf, serverBuf, core.ClientSide, rp, clientProtected) + serverConn := newTestALTSRecordConn(serverBuf, clientBuf, core.ServerSide, rp, serverProtected) return clientConn, serverConn } -func testPingPong(t *testing.T, np string) { - clientConn, serverConn := newConnPair(np, nil, nil) +func testPingPong(t *testing.T, rp string) { + clientConn, serverConn := newConnPair(rp, nil, nil) clientMsg := []byte("Client Message") if n, err := clientConn.Write(clientMsg); n != len(clientMsg) || err != nil { t.Fatalf("Client Write() = %v, %v; want %v, ", n, err, len(clientMsg)) @@ -128,13 +132,13 @@ func testPingPong(t *testing.T, np string) { } func (s) TestPingPong(t *testing.T) { - for _, np := range nextProtocols { - testPingPong(t, np) + for _, rp := range recordProtocols { + testPingPong(t, rp) } } -func testSmallReadBuffer(t *testing.T, np string) { - clientConn, serverConn := newConnPair(np, nil, nil) +func testSmallReadBuffer(t *testing.T, rp string) { + clientConn, serverConn := newConnPair(rp, nil, nil) msg := []byte("Very Important Message") if n, err := clientConn.Write(msg); err != nil { t.Fatalf("Write() = %v, %v; want %v, ", n, err, len(msg)) @@ -155,13 +159,13 @@ func testSmallReadBuffer(t *testing.T, np string) { } func (s) TestSmallReadBuffer(t *testing.T) { - for _, np := range nextProtocols { - testSmallReadBuffer(t, np) + for _, rp := range recordProtocols { + testSmallReadBuffer(t, rp) } } -func testLargeMsg(t *testing.T, np string) { - clientConn, serverConn := newConnPair(np, nil, nil) +func testLargeMsg(t *testing.T, rp string) { + clientConn, serverConn := newConnPair(rp, nil, nil) // msgLen is such that the length in the framing is larger than the // default size of one frame. msgLen := altsRecordDefaultLength - msgTypeFieldSize - clientConn.crypto.EncryptionOverhead() + 1 @@ -179,12 +183,12 @@ func testLargeMsg(t *testing.T, np string) { } func (s) TestLargeMsg(t *testing.T) { - for _, np := range nextProtocols { - testLargeMsg(t, np) + for _, rp := range recordProtocols { + testLargeMsg(t, rp) } } -func testIncorrectMsgType(t *testing.T, np string) { +func testIncorrectMsgType(t *testing.T, rp string) { // framedMsg is an empty ciphertext with correct framing but wrong // message type. framedMsg := make([]byte, MsgLenFieldSize+msgTypeFieldSize) @@ -193,7 +197,7 @@ func testIncorrectMsgType(t *testing.T, np string) { binary.LittleEndian.PutUint32(framedMsg[MsgLenFieldSize:], wrongMsgType) in := bytes.NewBuffer(framedMsg) - c := newTestALTSRecordConn(in, nil, core.ClientSide, np, nil) + c := newTestALTSRecordConn(in, nil, core.ClientSide, rp, nil) b := make([]byte, 1) if n, err := c.Read(b); n != 0 || err == nil { t.Fatalf("Read() = , want %v", fmt.Errorf("received frame with incorrect message type %v", wrongMsgType)) @@ -201,15 +205,15 @@ func testIncorrectMsgType(t *testing.T, np string) { } func (s) TestIncorrectMsgType(t *testing.T) { - for _, np := range nextProtocols { - testIncorrectMsgType(t, np) + for _, rp := range recordProtocols { + testIncorrectMsgType(t, rp) } } -func testFrameTooLarge(t *testing.T, np string) { +func testFrameTooLarge(t *testing.T, rp string) { buf := new(bytes.Buffer) - clientConn := newTestALTSRecordConn(nil, buf, core.ClientSide, np, nil) - serverConn := newTestALTSRecordConn(buf, nil, core.ServerSide, np, nil) + clientConn := newTestALTSRecordConn(nil, buf, core.ClientSide, rp, nil) + serverConn := newTestALTSRecordConn(buf, nil, core.ServerSide, rp, nil) // payloadLen is such that the length in the framing is larger than // allowed in one frame. payloadLen := altsRecordLengthLimit - msgTypeFieldSize - clientConn.crypto.EncryptionOverhead() + 1 @@ -234,15 +238,15 @@ func testFrameTooLarge(t *testing.T, np string) { } func (s) TestFrameTooLarge(t *testing.T) { - for _, np := range nextProtocols { - testFrameTooLarge(t, np) + for _, rp := range recordProtocols { + testFrameTooLarge(t, rp) } } -func testWriteLargeData(t *testing.T, np string) { +func testWriteLargeData(t *testing.T, rp string) { // Test sending and receiving messages larger than the maximum write // buffer size. - clientConn, serverConn := newConnPair(np, nil, nil) + clientConn, serverConn := newConnPair(rp, nil, nil) // Message size is intentionally chosen to not be multiple of // payloadLengthLimtit. msgSize := altsWriteBufferMaxSize + (100 * 1024) @@ -277,25 +281,25 @@ func testWriteLargeData(t *testing.T, np string) { } func (s) TestWriteLargeData(t *testing.T) { - for _, np := range nextProtocols { - testWriteLargeData(t, np) + for _, rp := range recordProtocols { + testWriteLargeData(t, rp) } } -func testProtectedBuffer(t *testing.T, np string) { +func testProtectedBuffer(t *testing.T, rp string) { key := []byte{ // 16 arbitrary bytes. 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x09, 0x6e, 0x88, 0x02, 0xff, 0xe2, 0xd2, 0x4c, 0xce, 0x4f, 0x49} // Encrypt a message to be passed to NewConn as a client-side protected // buffer. - newCrypto := protocols[np] + newCrypto := protocols[rp] if newCrypto == nil { - t.Fatalf("Unknown next protocol %q", np) + t.Fatalf("Unknown record protocol %q", rp) } crypto, err := newCrypto(core.ClientSide, key) if err != nil { - t.Fatalf("Failed to create a crypter for protocol %q: %v", np, err) + t.Fatalf("Failed to create a crypter for protocol %q: %v", rp, err) } msg := []byte("Client Protected Message") encryptedMsg, err := crypto.Encrypt(nil, msg) @@ -307,7 +311,7 @@ func testProtectedBuffer(t *testing.T, np string) { binary.LittleEndian.PutUint32(protectedMsg[4:], altsRecordMsgType) protectedMsg = append(protectedMsg, encryptedMsg...) - _, serverConn := newConnPair(np, nil, protectedMsg) + _, serverConn := newConnPair(rp, nil, protectedMsg) rcvClientMsg := make([]byte, len(msg)) if n, err := serverConn.Read(rcvClientMsg); n != len(rcvClientMsg) || err != nil { t.Fatalf("Server Read() = %v, %v; want %v, ", n, err, len(rcvClientMsg)) @@ -318,7 +322,7 @@ func testProtectedBuffer(t *testing.T, np string) { } func (s) TestProtectedBuffer(t *testing.T) { - for _, np := range nextProtocols { - testProtectedBuffer(t, np) + for _, rp := range recordProtocols { + testProtectedBuffer(t, rp) } } diff --git a/credentials/alts/internal/handshaker/handshaker_test.go b/credentials/alts/internal/handshaker/handshaker_test.go index bf516dc53c8..14a0721054f 100644 --- a/credentials/alts/internal/handshaker/handshaker_test.go +++ b/credentials/alts/internal/handshaker/handshaker_test.go @@ -21,6 +21,7 @@ package handshaker import ( "bytes" "context" + "errors" "testing" "time" @@ -163,7 +164,8 @@ func (s) TestClientHandshake(t *testing.T) { go func() { _, context, err := chs.ClientHandshake(ctx) if err == nil && context == nil { - panic("expected non-nil ALTS context") + errc <- errors.New("expected non-nil ALTS context") + return } errc <- err chs.Close() @@ -219,7 +221,8 @@ func (s) TestServerHandshake(t *testing.T) { go func() { _, context, err := shs.ServerHandshake(ctx) if err == nil && context == nil { - panic("expected non-nil ALTS context") + errc <- errors.New("expected non-nil ALTS context") + return } errc <- err shs.Close() diff --git a/credentials/alts/internal/proto/grpc_gcp/handshaker_grpc.pb.go b/credentials/alts/internal/proto/grpc_gcp/handshaker_grpc.pb.go index efdbd13fa30..a02c4582815 100644 --- a/credentials/alts/internal/proto/grpc_gcp/handshaker_grpc.pb.go +++ b/credentials/alts/internal/proto/grpc_gcp/handshaker_grpc.pb.go @@ -1,4 +1,8 @@ // Code generated by protoc-gen-go-grpc. DO NOT EDIT. +// versions: +// - protoc-gen-go-grpc v1.1.0 +// - protoc v3.14.0 +// source: grpc/gcp/handshaker.proto package grpc_gcp diff --git a/credentials/alts/utils.go b/credentials/alts/utils.go index 9a300bc19aa..cbfd056cfb1 100644 --- a/credentials/alts/utils.go +++ b/credentials/alts/utils.go @@ -21,14 +21,6 @@ package alts import ( "context" "errors" - "fmt" - "io" - "io/ioutil" - "log" - "os" - "os/exec" - "regexp" - "runtime" "strings" "google.golang.org/grpc/codes" @@ -36,92 +28,6 @@ import ( "google.golang.org/grpc/status" ) -const ( - linuxProductNameFile = "/sys/class/dmi/id/product_name" - windowsCheckCommand = "powershell.exe" - windowsCheckCommandArgs = "Get-WmiObject -Class Win32_BIOS" - powershellOutputFilter = "Manufacturer" - windowsManufacturerRegex = ":(.*)" -) - -type platformError string - -func (k platformError) Error() string { - return fmt.Sprintf("%s is not supported", string(k)) -} - -var ( - // The following two variables will be reassigned in tests. - runningOS = runtime.GOOS - manufacturerReader = func() (io.Reader, error) { - switch runningOS { - case "linux": - return os.Open(linuxProductNameFile) - case "windows": - cmd := exec.Command(windowsCheckCommand, windowsCheckCommandArgs) - out, err := cmd.Output() - if err != nil { - return nil, err - } - - for _, line := range strings.Split(strings.TrimSuffix(string(out), "\n"), "\n") { - if strings.HasPrefix(line, powershellOutputFilter) { - re := regexp.MustCompile(windowsManufacturerRegex) - name := re.FindString(line) - name = strings.TrimLeft(name, ":") - return strings.NewReader(name), nil - } - } - - return nil, errors.New("cannot determine the machine's manufacturer") - default: - return nil, platformError(runningOS) - } - } - vmOnGCP bool -) - -// isRunningOnGCP checks whether the local system, without doing a network request is -// running on GCP. -func isRunningOnGCP() bool { - manufacturer, err := readManufacturer() - if os.IsNotExist(err) { - return false - } - if err != nil { - log.Fatalf("failure to read manufacturer information: %v", err) - } - name := string(manufacturer) - switch runningOS { - case "linux": - name = strings.TrimSpace(name) - return name == "Google" || name == "Google Compute Engine" - case "windows": - name = strings.Replace(name, " ", "", -1) - name = strings.Replace(name, "\n", "", -1) - name = strings.Replace(name, "\r", "", -1) - return name == "Google" - default: - log.Fatal(platformError(runningOS)) - } - return false -} - -func readManufacturer() ([]byte, error) { - reader, err := manufacturerReader() - if err != nil { - return nil, err - } - if reader == nil { - return nil, errors.New("got nil reader") - } - manufacturer, err := ioutil.ReadAll(reader) - if err != nil { - return nil, fmt.Errorf("failed reading %v: %v", linuxProductNameFile, err) - } - return manufacturer, nil -} - // AuthInfoFromContext extracts the alts.AuthInfo object from the given context, // if it exists. This API should be used by gRPC server RPC handlers to get // information about the communicating peer. For client-side, use grpc.Peer() diff --git a/credentials/alts/utils_test.go b/credentials/alts/utils_test.go index 5b54b1d5f77..531cdfce6e3 100644 --- a/credentials/alts/utils_test.go +++ b/credentials/alts/utils_test.go @@ -1,3 +1,4 @@ +//go:build linux || windows // +build linux windows /* @@ -22,8 +23,6 @@ package alts import ( "context" - "io" - "os" "strings" "testing" "time" @@ -42,67 +41,6 @@ const ( defaultTestTimeout = 10 * time.Second ) -func setupManufacturerReader(testOS string, reader func() (io.Reader, error)) func() { - tmpOS := runningOS - tmpReader := manufacturerReader - - // Set test OS and reader function. - runningOS = testOS - manufacturerReader = reader - return func() { - runningOS = tmpOS - manufacturerReader = tmpReader - } - -} - -func setup(testOS string, testReader io.Reader) func() { - reader := func() (io.Reader, error) { - return testReader, nil - } - return setupManufacturerReader(testOS, reader) -} - -func setupError(testOS string, err error) func() { - reader := func() (io.Reader, error) { - return nil, err - } - return setupManufacturerReader(testOS, reader) -} - -func (s) TestIsRunningOnGCP(t *testing.T) { - for _, tc := range []struct { - description string - testOS string - testReader io.Reader - out bool - }{ - // Linux tests. - {"linux: not a GCP platform", "linux", strings.NewReader("not GCP"), false}, - {"Linux: GCP platform (Google)", "linux", strings.NewReader("Google"), true}, - {"Linux: GCP platform (Google Compute Engine)", "linux", strings.NewReader("Google Compute Engine"), true}, - {"Linux: GCP platform (Google Compute Engine) with extra spaces", "linux", strings.NewReader(" Google Compute Engine "), true}, - // Windows tests. - {"windows: not a GCP platform", "windows", strings.NewReader("not GCP"), false}, - {"windows: GCP platform (Google)", "windows", strings.NewReader("Google"), true}, - {"windows: GCP platform (Google) with extra spaces", "windows", strings.NewReader(" Google "), true}, - } { - reverseFunc := setup(tc.testOS, tc.testReader) - if got, want := isRunningOnGCP(), tc.out; got != want { - t.Errorf("%v: isRunningOnGCP()=%v, want %v", tc.description, got, want) - } - reverseFunc() - } -} - -func (s) TestIsRunningOnGCPNoProductNameFile(t *testing.T) { - reverseFunc := setupError("linux", os.ErrNotExist) - if isRunningOnGCP() { - t.Errorf("ErrNotExist: isRunningOnGCP()=true, want false") - } - reverseFunc() -} - func (s) TestAuthInfoFromContext(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() diff --git a/credentials/credentials.go b/credentials/credentials.go index e69562e7878..7eee7e4ec12 100644 --- a/credentials/credentials.go +++ b/credentials/credentials.go @@ -30,7 +30,7 @@ import ( "github.com/golang/protobuf/proto" "google.golang.org/grpc/attributes" - "google.golang.org/grpc/internal" + icredentials "google.golang.org/grpc/internal/credentials" ) // PerRPCCredentials defines the common interface for the credentials which need to @@ -188,15 +188,12 @@ type RequestInfo struct { AuthInfo AuthInfo } -// requestInfoKey is a struct to be used as the key when attaching a RequestInfo to a context object. -type requestInfoKey struct{} - // RequestInfoFromContext extracts the RequestInfo from the context if it exists. // // This API is experimental. func RequestInfoFromContext(ctx context.Context) (ri RequestInfo, ok bool) { - ri, ok = ctx.Value(requestInfoKey{}).(RequestInfo) - return + ri, ok = icredentials.RequestInfoFromContext(ctx).(RequestInfo) + return ri, ok } // ClientHandshakeInfo holds data to be passed to ClientHandshake. This makes @@ -211,16 +208,12 @@ type ClientHandshakeInfo struct { Attributes *attributes.Attributes } -// clientHandshakeInfoKey is a struct used as the key to store -// ClientHandshakeInfo in a context. -type clientHandshakeInfoKey struct{} - // ClientHandshakeInfoFromContext returns the ClientHandshakeInfo struct stored // in ctx. // // This API is experimental. func ClientHandshakeInfoFromContext(ctx context.Context) ClientHandshakeInfo { - chi, _ := ctx.Value(clientHandshakeInfoKey{}).(ClientHandshakeInfo) + chi, _ := icredentials.ClientHandshakeInfoFromContext(ctx).(ClientHandshakeInfo) return chi } @@ -249,15 +242,6 @@ func CheckSecurityLevel(ai AuthInfo, level SecurityLevel) error { return nil } -func init() { - internal.NewRequestInfoContext = func(ctx context.Context, ri RequestInfo) context.Context { - return context.WithValue(ctx, requestInfoKey{}, ri) - } - internal.NewClientHandshakeInfoContext = func(ctx context.Context, chi ClientHandshakeInfo) context.Context { - return context.WithValue(ctx, clientHandshakeInfoKey{}, chi) - } -} - // ChannelzSecurityInfo defines the interface that security protocols should implement // in order to provide security info to channelz. // diff --git a/credentials/go12.go b/credentials/go12.go deleted file mode 100644 index ccbf35b3312..00000000000 --- a/credentials/go12.go +++ /dev/null @@ -1,30 +0,0 @@ -// +build go1.12 - -/* - * - * Copyright 2019 gRPC authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - */ - -package credentials - -import "crypto/tls" - -// This init function adds cipher suite constants only defined in Go 1.12. -func init() { - cipherSuiteLookup[tls.TLS_AES_128_GCM_SHA256] = "TLS_AES_128_GCM_SHA256" - cipherSuiteLookup[tls.TLS_AES_256_GCM_SHA384] = "TLS_AES_256_GCM_SHA384" - cipherSuiteLookup[tls.TLS_CHACHA20_POLY1305_SHA256] = "TLS_CHACHA20_POLY1305_SHA256" -} diff --git a/credentials/google/google.go b/credentials/google/google.go index 7f3e240e475..63625a4b680 100644 --- a/credentials/google/google.go +++ b/credentials/google/google.go @@ -35,57 +35,63 @@ const tokenRequestTimeout = 30 * time.Second var logger = grpclog.Component("credentials") -// NewDefaultCredentials returns a credentials bundle that is configured to work -// with google services. +// DefaultCredentialsOptions constructs options to build DefaultCredentials. +type DefaultCredentialsOptions struct { + // PerRPCCreds is a per RPC credentials that is passed to a bundle. + PerRPCCreds credentials.PerRPCCredentials +} + +// NewDefaultCredentialsWithOptions returns a credentials bundle that is +// configured to work with google services. // // This API is experimental. -func NewDefaultCredentials() credentials.Bundle { - c := &creds{ - newPerRPCCreds: func() credentials.PerRPCCredentials { - ctx, cancel := context.WithTimeout(context.Background(), tokenRequestTimeout) - defer cancel() - perRPCCreds, err := oauth.NewApplicationDefault(ctx) - if err != nil { - logger.Warningf("google default creds: failed to create application oauth: %v", err) - } - return perRPCCreds - }, +func NewDefaultCredentialsWithOptions(opts DefaultCredentialsOptions) credentials.Bundle { + if opts.PerRPCCreds == nil { + ctx, cancel := context.WithTimeout(context.Background(), tokenRequestTimeout) + defer cancel() + var err error + opts.PerRPCCreds, err = oauth.NewApplicationDefault(ctx) + if err != nil { + logger.Warningf("NewDefaultCredentialsWithOptions: failed to create application oauth: %v", err) + } } + c := &creds{opts: opts} bundle, err := c.NewWithMode(internal.CredsBundleModeFallback) if err != nil { - logger.Warningf("google default creds: failed to create new creds: %v", err) + logger.Warningf("NewDefaultCredentialsWithOptions: failed to create new creds: %v", err) } return bundle } +// NewDefaultCredentials returns a credentials bundle that is configured to work +// with google services. +// +// This API is experimental. +func NewDefaultCredentials() credentials.Bundle { + return NewDefaultCredentialsWithOptions(DefaultCredentialsOptions{}) +} + // NewComputeEngineCredentials returns a credentials bundle that is configured to work // with google services. This API must only be used when running on GCE. Authentication configured // by this API represents the GCE VM's default service account. // // This API is experimental. func NewComputeEngineCredentials() credentials.Bundle { - c := &creds{ - newPerRPCCreds: func() credentials.PerRPCCredentials { - return oauth.NewComputeEngine() - }, - } - bundle, err := c.NewWithMode(internal.CredsBundleModeFallback) - if err != nil { - logger.Warningf("compute engine creds: failed to create new creds: %v", err) - } - return bundle + return NewDefaultCredentialsWithOptions(DefaultCredentialsOptions{ + PerRPCCreds: oauth.NewComputeEngine(), + }) } // creds implements credentials.Bundle. type creds struct { + opts DefaultCredentialsOptions + // Supported modes are defined in internal/internal.go. mode string - // The transport credentials associated with this bundle. + // The active transport credentials associated with this bundle. transportCreds credentials.TransportCredentials - // The per RPC credentials associated with this bundle. + // The active per RPC credentials associated with this bundle. perRPCCreds credentials.PerRPCCredentials - // Creates new per RPC credentials - newPerRPCCreds func() credentials.PerRPCCredentials } func (c *creds) TransportCredentials() credentials.TransportCredentials { @@ -99,28 +105,37 @@ func (c *creds) PerRPCCredentials() credentials.PerRPCCredentials { return c.perRPCCreds } +var ( + newTLS = func() credentials.TransportCredentials { + return credentials.NewTLS(nil) + } + newALTS = func() credentials.TransportCredentials { + return alts.NewClientCreds(alts.DefaultClientOptions()) + } +) + // NewWithMode should make a copy of Bundle, and switch mode. Modifying the // existing Bundle may cause races. func (c *creds) NewWithMode(mode string) (credentials.Bundle, error) { newCreds := &creds{ - mode: mode, - newPerRPCCreds: c.newPerRPCCreds, + opts: c.opts, + mode: mode, } // Create transport credentials. switch mode { case internal.CredsBundleModeFallback: - newCreds.transportCreds = credentials.NewTLS(nil) + newCreds.transportCreds = newClusterTransportCreds(newTLS(), newALTS()) case internal.CredsBundleModeBackendFromBalancer, internal.CredsBundleModeBalancer: // Only the clients can use google default credentials, so we only need // to create new ALTS client creds here. - newCreds.transportCreds = alts.NewClientCreds(alts.DefaultClientOptions()) + newCreds.transportCreds = newALTS() default: return nil, fmt.Errorf("unsupported mode: %v", mode) } if mode == internal.CredsBundleModeFallback || mode == internal.CredsBundleModeBackendFromBalancer { - newCreds.perRPCCreds = newCreds.newPerRPCCreds() + newCreds.perRPCCreds = newCreds.opts.PerRPCCreds } return newCreds, nil diff --git a/credentials/google/google_test.go b/credentials/google/google_test.go new file mode 100644 index 00000000000..6a6e492ee77 --- /dev/null +++ b/credentials/google/google_test.go @@ -0,0 +1,131 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package google + +import ( + "context" + "net" + "testing" + + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/internal" + icredentials "google.golang.org/grpc/internal/credentials" + "google.golang.org/grpc/resolver" +) + +type testCreds struct { + credentials.TransportCredentials + typ string +} + +func (c *testCreds) ClientHandshake(ctx context.Context, authority string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { + return nil, &testAuthInfo{typ: c.typ}, nil +} + +func (c *testCreds) ServerHandshake(conn net.Conn) (net.Conn, credentials.AuthInfo, error) { + return nil, &testAuthInfo{typ: c.typ}, nil +} + +type testAuthInfo struct { + typ string +} + +func (t *testAuthInfo) AuthType() string { + return t.typ +} + +var ( + testTLS = &testCreds{typ: "tls"} + testALTS = &testCreds{typ: "alts"} +) + +func overrideNewCredsFuncs() func() { + oldNewTLS := newTLS + newTLS = func() credentials.TransportCredentials { + return testTLS + } + oldNewALTS := newALTS + newALTS = func() credentials.TransportCredentials { + return testALTS + } + return func() { + newTLS = oldNewTLS + newALTS = oldNewALTS + } +} + +// TestClientHandshakeBasedOnClusterName that by default (without switching +// modes), ClientHandshake does either tls or alts base on the cluster name in +// attributes. +func TestClientHandshakeBasedOnClusterName(t *testing.T) { + defer overrideNewCredsFuncs()() + for bundleTyp, tc := range map[string]credentials.Bundle{ + "defaultCredsWithOptions": NewDefaultCredentialsWithOptions(DefaultCredentialsOptions{}), + "defaultCreds": NewDefaultCredentials(), + "computeCreds": NewComputeEngineCredentials(), + } { + tests := []struct { + name string + ctx context.Context + wantTyp string + }{ + { + name: "no cluster name", + ctx: context.Background(), + wantTyp: "tls", + }, + { + name: "with non-CFE cluster name", + ctx: icredentials.NewClientHandshakeInfoContext(context.Background(), credentials.ClientHandshakeInfo{ + Attributes: internal.SetXDSHandshakeClusterName(resolver.Address{}, "lalala").Attributes, + }), + // non-CFE backends should use alts. + wantTyp: "alts", + }, + { + name: "with CFE cluster name", + ctx: icredentials.NewClientHandshakeInfoContext(context.Background(), credentials.ClientHandshakeInfo{ + Attributes: internal.SetXDSHandshakeClusterName(resolver.Address{}, cfeClusterName).Attributes, + }), + // CFE should use tls. + wantTyp: "tls", + }, + } + for _, tt := range tests { + t.Run(bundleTyp+" "+tt.name, func(t *testing.T) { + _, info, err := tc.TransportCredentials().ClientHandshake(tt.ctx, "", nil) + if err != nil { + t.Fatalf("ClientHandshake failed: %v", err) + } + if gotType := info.AuthType(); gotType != tt.wantTyp { + t.Fatalf("unexpected authtype: %v, want: %v", gotType, tt.wantTyp) + } + + _, infoServer, err := tc.TransportCredentials().ServerHandshake(nil) + if err != nil { + t.Fatalf("ClientHandshake failed: %v", err) + } + // ServerHandshake should always do TLS. + if gotType := infoServer.AuthType(); gotType != "tls" { + t.Fatalf("unexpected server authtype: %v, want: %v", gotType, "tls") + } + }) + } + } +} diff --git a/credentials/google/xds.go b/credentials/google/xds.go new file mode 100644 index 00000000000..588c685e259 --- /dev/null +++ b/credentials/google/xds.go @@ -0,0 +1,90 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package google + +import ( + "context" + "net" + + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/internal" +) + +const cfeClusterName = "google-cfe" + +// clusterTransportCreds is a combo of TLS + ALTS. +// +// On the client, ClientHandshake picks TLS or ALTS based on address attributes. +// - if attributes has cluster name +// - if cluster name is "google_cfe", use TLS +// - otherwise, use ALTS +// - else, do TLS +// +// On the server, ServerHandshake always does TLS. +type clusterTransportCreds struct { + tls credentials.TransportCredentials + alts credentials.TransportCredentials +} + +func newClusterTransportCreds(tls, alts credentials.TransportCredentials) *clusterTransportCreds { + return &clusterTransportCreds{ + tls: tls, + alts: alts, + } +} + +func (c *clusterTransportCreds) ClientHandshake(ctx context.Context, authority string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { + chi := credentials.ClientHandshakeInfoFromContext(ctx) + if chi.Attributes == nil { + return c.tls.ClientHandshake(ctx, authority, rawConn) + } + cn, ok := internal.GetXDSHandshakeClusterName(chi.Attributes) + if !ok || cn == cfeClusterName { + return c.tls.ClientHandshake(ctx, authority, rawConn) + } + // If attributes have cluster name, and cluster name is not cfe, it's a + // backend address, use ALTS. + return c.alts.ClientHandshake(ctx, authority, rawConn) +} + +func (c *clusterTransportCreds) ServerHandshake(conn net.Conn) (net.Conn, credentials.AuthInfo, error) { + return c.tls.ServerHandshake(conn) +} + +func (c *clusterTransportCreds) Info() credentials.ProtocolInfo { + // TODO: this always returns tls.Info now, because we don't have a cluster + // name to check when this method is called. This method doesn't affect + // anything important now. We may want to revisit this if it becomes more + // important later. + return c.tls.Info() +} + +func (c *clusterTransportCreds) Clone() credentials.TransportCredentials { + return &clusterTransportCreds{ + tls: c.tls.Clone(), + alts: c.alts.Clone(), + } +} + +func (c *clusterTransportCreds) OverrideServerName(s string) error { + if err := c.tls.OverrideServerName(s); err != nil { + return err + } + return c.alts.OverrideServerName(s) +} diff --git a/credentials/local/local_test.go b/credentials/local/local_test.go index 00ae39f07e5..47f8dbb4ec8 100644 --- a/credentials/local/local_test.go +++ b/credentials/local/local_test.go @@ -131,11 +131,13 @@ func serverHandle(hs serverHandshake, done chan testServerHandleResult, lis net. serverRawConn, err := lis.Accept() if err != nil { done <- testServerHandleResult{authInfo: nil, err: fmt.Errorf("Server failed to accept connection. Error: %v", err)} + return } serverAuthInfo, err := hs(serverRawConn) if err != nil { serverRawConn.Close() done <- testServerHandleResult{authInfo: nil, err: fmt.Errorf("Server failed while handshake. Error: %v", err)} + return } done <- testServerHandleResult{authInfo: serverAuthInfo, err: nil} } diff --git a/credentials/oauth/oauth.go b/credentials/oauth/oauth.go index 852ae375cfc..c748fd21ce2 100644 --- a/credentials/oauth/oauth.go +++ b/credentials/oauth/oauth.go @@ -23,6 +23,7 @@ import ( "context" "fmt" "io/ioutil" + "net/url" "sync" "golang.org/x/oauth2" @@ -56,6 +57,16 @@ func (ts TokenSource) RequireTransportSecurity() bool { return true } +// removeServiceNameFromJWTURI removes RPC service name from URI. +func removeServiceNameFromJWTURI(uri string) (string, error) { + parsed, err := url.Parse(uri) + if err != nil { + return "", err + } + parsed.Path = "/" + return parsed.String(), nil +} + type jwtAccess struct { jsonKey []byte } @@ -75,9 +86,15 @@ func NewJWTAccessFromKey(jsonKey []byte) (credentials.PerRPCCredentials, error) } func (j jwtAccess) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) { + // Remove RPC service name from URI that will be used as audience + // in a self-signed JWT token. It follows https://google.aip.dev/auth/4111. + aud, err := removeServiceNameFromJWTURI(uri[0]) + if err != nil { + return nil, err + } // TODO: the returned TokenSource is reusable. Store it in a sync.Map, with // uri as the key, to avoid recreating for every RPC. - ts, err := google.JWTAccessTokenSourceFromJSON(j.jsonKey, uri[0]) + ts, err := google.JWTAccessTokenSourceFromJSON(j.jsonKey, aud) if err != nil { return nil, err } diff --git a/credentials/oauth/oauth_test.go b/credentials/oauth/oauth_test.go new file mode 100644 index 00000000000..7e62ecb36c1 --- /dev/null +++ b/credentials/oauth/oauth_test.go @@ -0,0 +1,60 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package oauth + +import ( + "strings" + "testing" +) + +func checkErrorMsg(err error, msg string) bool { + if err == nil && msg == "" { + return true + } else if err != nil { + return strings.Contains(err.Error(), msg) + } + return false +} + +func TestRemoveServiceNameFromJWTURI(t *testing.T) { + tests := []struct { + name string + uri string + wantedURI string + wantedErrMsg string + }{ + { + name: "invalid URI", + uri: "ht tp://foo.com", + wantedErrMsg: "first path segment in URL cannot contain colon", + }, + { + name: "valid URI", + uri: "https://foo.com/go/", + wantedURI: "https://foo.com/", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got, err := removeServiceNameFromJWTURI(tt.uri); got != tt.wantedURI || !checkErrorMsg(err, tt.wantedErrMsg) { + t.Errorf("RemoveServiceNameFromJWTURI() = %s, %v, want %s, %v", got, err, tt.wantedURI, tt.wantedErrMsg) + } + }) + } +} diff --git a/credentials/sts/sts.go b/credentials/sts/sts.go index 9285192a8eb..da5fa1ad16f 100644 --- a/credentials/sts/sts.go +++ b/credentials/sts/sts.go @@ -1,5 +1,3 @@ -// +build go1.13 - /* * * Copyright 2020 gRPC authors. diff --git a/credentials/sts/sts_test.go b/credentials/sts/sts_test.go index ac680e00111..dd634361d7c 100644 --- a/credentials/sts/sts_test.go +++ b/credentials/sts/sts_test.go @@ -1,5 +1,3 @@ -// +build go1.13 - /* * * Copyright 2020 gRPC authors. @@ -37,7 +35,7 @@ import ( "github.com/google/go-cmp/cmp" "google.golang.org/grpc/credentials" - "google.golang.org/grpc/internal" + icredentials "google.golang.org/grpc/internal/credentials" "google.golang.org/grpc/internal/grpctest" "google.golang.org/grpc/internal/testutils" ) @@ -104,7 +102,7 @@ func createTestContext(ctx context.Context, s credentials.SecurityLevel) context Method: "testInfo", AuthInfo: auth, } - return internal.NewRequestInfoContext.(func(context.Context, credentials.RequestInfo) context.Context)(ctx, ri) + return icredentials.NewRequestInfoContext(ctx, ri) } // errReader implements the io.Reader interface and returns an error from the diff --git a/credentials/tls.go b/credentials/tls.go index 8ee7124f226..784822d0560 100644 --- a/credentials/tls.go +++ b/credentials/tls.go @@ -230,4 +230,7 @@ var cipherSuiteLookup = map[uint16]string{ tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256: "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256", tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305: "TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305", tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305: "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305", + tls.TLS_AES_128_GCM_SHA256: "TLS_AES_128_GCM_SHA256", + tls.TLS_AES_256_GCM_SHA384: "TLS_AES_256_GCM_SHA384", + tls.TLS_CHACHA20_POLY1305_SHA256: "TLS_CHACHA20_POLY1305_SHA256", } diff --git a/credentials/tls/certprovider/distributor_test.go b/credentials/tls/certprovider/distributor_test.go index bec00e919bc..48d51375616 100644 --- a/credentials/tls/certprovider/distributor_test.go +++ b/credentials/tls/certprovider/distributor_test.go @@ -1,5 +1,3 @@ -// +build go1.13 - /* * * Copyright 2020 gRPC authors. diff --git a/credentials/tls/certprovider/meshca/builder.go b/credentials/tls/certprovider/meshca/builder.go deleted file mode 100644 index 4b8af7c9b3c..00000000000 --- a/credentials/tls/certprovider/meshca/builder.go +++ /dev/null @@ -1,165 +0,0 @@ -// +build go1.13 - -/* - * - * Copyright 2020 gRPC authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - */ - -package meshca - -import ( - "crypto/x509" - "encoding/json" - "fmt" - "sync" - - "google.golang.org/grpc" - "google.golang.org/grpc/credentials" - "google.golang.org/grpc/credentials/sts" - "google.golang.org/grpc/credentials/tls/certprovider" - "google.golang.org/grpc/internal/backoff" -) - -const pluginName = "mesh_ca" - -// For overriding in unit tests. -var ( - grpcDialFunc = grpc.Dial - backoffFunc = backoff.DefaultExponential.Backoff -) - -func init() { - certprovider.Register(newPluginBuilder()) -} - -func newPluginBuilder() *pluginBuilder { - return &pluginBuilder{clients: make(map[ccMapKey]*refCountedCC)} -} - -// Key for the map containing ClientConns to the MeshCA server. Only the server -// name and the STS options (which is used to create call creds) from the plugin -// configuration determine if two configs can share the same ClientConn. Hence -// only those form the key to this map. -type ccMapKey struct { - name string - stsOpts sts.Options -} - -// refCountedCC wraps a grpc.ClientConn to MeshCA along with a reference count. -type refCountedCC struct { - cc *grpc.ClientConn - refCnt int -} - -// pluginBuilder is an implementation of the certprovider.Builder interface, -// which builds certificate provider instances to get certificates signed from -// the MeshCA. -type pluginBuilder struct { - // A collection of ClientConns to the MeshCA server along with a reference - // count. Provider instances whose config point to the same server name will - // end up sharing the ClientConn. - mu sync.Mutex - clients map[ccMapKey]*refCountedCC -} - -// ParseConfig parses the configuration to be passed to the MeshCA plugin -// implementation. Expects the config to be a json.RawMessage which contains a -// serialized JSON representation of the meshca_experimental.GoogleMeshCaConfig -// proto message. -// -// Takes care of sharing the ClientConn to the MeshCA server among -// different plugin instantiations. -func (b *pluginBuilder) ParseConfig(c interface{}) (*certprovider.BuildableConfig, error) { - data, ok := c.(json.RawMessage) - if !ok { - return nil, fmt.Errorf("meshca: unsupported config type: %T", c) - } - cfg, err := pluginConfigFromJSON(data) - if err != nil { - return nil, err - } - return certprovider.NewBuildableConfig(pluginName, cfg.canonical(), func(opts certprovider.BuildOptions) certprovider.Provider { - return b.buildFromConfig(cfg, opts) - }), nil -} - -// buildFromConfig builds a certificate provider instance for the given config -// and options. Provider instances are shared wherever possible. -func (b *pluginBuilder) buildFromConfig(cfg *pluginConfig, opts certprovider.BuildOptions) certprovider.Provider { - b.mu.Lock() - defer b.mu.Unlock() - - ccmk := ccMapKey{ - name: cfg.serverURI, - stsOpts: cfg.stsOpts, - } - rcc, ok := b.clients[ccmk] - if !ok { - // STS call credentials take care of exchanging a locally provisioned - // JWT token for an access token which will be accepted by the MeshCA. - callCreds, err := sts.NewCredentials(cfg.stsOpts) - if err != nil { - logger.Errorf("sts.NewCredentials() failed: %v", err) - return nil - } - - // MeshCA is a public endpoint whose certificate is Web-PKI compliant. - // So, we just need to use the system roots to authenticate the MeshCA. - cp, err := x509.SystemCertPool() - if err != nil { - logger.Errorf("x509.SystemCertPool() failed: %v", err) - return nil - } - transportCreds := credentials.NewClientTLSFromCert(cp, "") - - cc, err := grpcDialFunc(cfg.serverURI, grpc.WithTransportCredentials(transportCreds), grpc.WithPerRPCCredentials(callCreds)) - if err != nil { - logger.Errorf("grpc.Dial(%s) failed: %v", cfg.serverURI, err) - return nil - } - - rcc = &refCountedCC{cc: cc} - b.clients[ccmk] = rcc - } - rcc.refCnt++ - - p := newProviderPlugin(providerParams{ - cc: rcc.cc, - cfg: cfg, - opts: opts, - backoff: backoffFunc, - doneFunc: func() { - // The plugin implementation will invoke this function when it is - // being closed, and here we take care of closing the ClientConn - // when there are no more plugins using it. We need to acquire the - // lock before accessing the rcc from the enclosing function. - b.mu.Lock() - defer b.mu.Unlock() - rcc.refCnt-- - if rcc.refCnt == 0 { - logger.Infof("Closing grpc.ClientConn to %s", ccmk.name) - rcc.cc.Close() - delete(b.clients, ccmk) - } - }, - }) - return p -} - -// Name returns the MeshCA plugin name. -func (b *pluginBuilder) Name() string { - return pluginName -} diff --git a/credentials/tls/certprovider/meshca/builder_test.go b/credentials/tls/certprovider/meshca/builder_test.go deleted file mode 100644 index 79035d008d9..00000000000 --- a/credentials/tls/certprovider/meshca/builder_test.go +++ /dev/null @@ -1,177 +0,0 @@ -// +build go1.13 - -/* - * - * Copyright 2020 gRPC authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - */ - -package meshca - -import ( - "context" - "encoding/json" - "fmt" - "testing" - - "google.golang.org/grpc" - "google.golang.org/grpc/connectivity" - "google.golang.org/grpc/credentials/tls/certprovider" - "google.golang.org/grpc/internal/testutils" -) - -func overrideHTTPFuncs() func() { - // Directly override the functions which are used to read the zone and - // audience instead of overriding the http.Client. - origReadZone := readZoneFunc - readZoneFunc = func(httpDoer) string { return "test-zone" } - origReadAudience := readAudienceFunc - readAudienceFunc = func(httpDoer) string { return "test-audience" } - return func() { - readZoneFunc = origReadZone - readAudienceFunc = origReadAudience - } -} - -func (s) TestBuildSameConfig(t *testing.T) { - defer overrideHTTPFuncs()() - - // We will attempt to create `cnt` number of providers. So we create a - // channel of the same size here, even though we expect only one ClientConn - // to be pushed into this channel. This makes sure that even if more than - // one ClientConn ends up being created, the Build() call does not block. - const cnt = 5 - ccChan := testutils.NewChannelWithSize(cnt) - - // Override the dial func to dial a dummy MeshCA endpoint, and also push the - // returned ClientConn on a channel to be inspected by the test. - origDialFunc := grpcDialFunc - grpcDialFunc = func(string, ...grpc.DialOption) (*grpc.ClientConn, error) { - cc, err := grpc.Dial("dummy-meshca-endpoint", grpc.WithInsecure()) - ccChan.Send(cc) - return cc, err - } - defer func() { grpcDialFunc = origDialFunc }() - - // Parse a good config to generate a stable config which will be passed to - // invocations of Build(). - builder := newPluginBuilder() - buildableConfig, err := builder.ParseConfig(goodConfigFullySpecified) - if err != nil { - t.Fatalf("builder.ParseConfig(%q) failed: %v", goodConfigFullySpecified, err) - } - - // Create multiple providers with the same config. All these providers must - // end up sharing the same ClientConn. - providers := []certprovider.Provider{} - for i := 0; i < cnt; i++ { - p, err := buildableConfig.Build(certprovider.BuildOptions{}) - if err != nil { - t.Fatalf("Build(%+v) failed: %v", buildableConfig, err) - } - providers = append(providers, p) - } - - // Make sure only one ClientConn is created. - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() - val, err := ccChan.Receive(ctx) - if err != nil { - t.Fatalf("Failed to create ClientConn: %v", err) - } - testCC := val.(*grpc.ClientConn) - - // Attempt to read the second ClientConn should timeout. - ctx, cancel = context.WithTimeout(context.Background(), defaultTestShortTimeout) - defer cancel() - if _, err := ccChan.Receive(ctx); err != context.DeadlineExceeded { - t.Fatal("Builder created more than one ClientConn") - } - - for _, p := range providers { - p.Close() - } - - for { - state := testCC.GetState() - if state == connectivity.Shutdown { - break - } - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() - if !testCC.WaitForStateChange(ctx, state) { - t.Fatalf("timeout waiting for clientConn state to change from %s", state) - } - } -} - -func (s) TestBuildDifferentConfig(t *testing.T) { - defer overrideHTTPFuncs()() - - // We will attempt to create two providers with different configs. So we - // expect two ClientConns to be pushed on to this channel. - const cnt = 2 - ccChan := testutils.NewChannelWithSize(cnt) - - // Override the dial func to dial a dummy MeshCA endpoint, and also push the - // returned ClientConn on a channel to be inspected by the test. - origDialFunc := grpcDialFunc - grpcDialFunc = func(string, ...grpc.DialOption) (*grpc.ClientConn, error) { - cc, err := grpc.Dial("dummy-meshca-endpoint", grpc.WithInsecure()) - ccChan.Send(cc) - return cc, err - } - defer func() { grpcDialFunc = origDialFunc }() - - builder := newPluginBuilder() - providers := []certprovider.Provider{} - for i := 0; i < cnt; i++ { - // Copy the good test config and modify the serverURI to make sure that - // a new provider is created for the config. - inputConfig := json.RawMessage(fmt.Sprintf(goodConfigFormatStr, fmt.Sprintf("test-mesh-ca:%d", i))) - buildableConfig, err := builder.ParseConfig(inputConfig) - if err != nil { - t.Fatalf("builder.ParseConfig(%q) failed: %v", inputConfig, err) - } - - p, err := buildableConfig.Build(certprovider.BuildOptions{}) - if err != nil { - t.Fatalf("Build(%+v) failed: %v", buildableConfig, err) - } - providers = append(providers, p) - } - - // Make sure two ClientConns are created. - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() - for i := 0; i < cnt; i++ { - if _, err := ccChan.Receive(ctx); err != nil { - t.Fatalf("Failed to create ClientConn: %v", err) - } - } - - // Close the first provider, and attempt to read key material from the - // second provider. The call to read key material should timeout, but it - // should not return certprovider.errProviderClosed. - providers[0].Close() - ctx, cancel = context.WithTimeout(context.Background(), defaultTestShortTimeout) - defer cancel() - if _, err := providers[1].KeyMaterial(ctx); err != context.DeadlineExceeded { - t.Fatalf("provider.KeyMaterial(ctx) = %v, want contextDeadlineExceeded", err) - } - - // Close the second provider to make sure that the leakchecker is happy. - providers[1].Close() -} diff --git a/credentials/tls/certprovider/meshca/config.go b/credentials/tls/certprovider/meshca/config.go deleted file mode 100644 index c0772b3bb7e..00000000000 --- a/credentials/tls/certprovider/meshca/config.go +++ /dev/null @@ -1,310 +0,0 @@ -// +build go1.13 - -/* - * - * Copyright 2020 gRPC authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - */ - -package meshca - -import ( - "encoding/json" - "errors" - "fmt" - "io/ioutil" - "net/http" - "net/http/httputil" - "path" - "strings" - "time" - - v3corepb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" - "github.com/golang/protobuf/ptypes" - "google.golang.org/protobuf/encoding/protojson" - "google.golang.org/protobuf/types/known/durationpb" - - "google.golang.org/grpc/credentials/sts" -) - -const ( - // GKE metadata server endpoint. - mdsBaseURI = "http://metadata.google.internal/" - mdsRequestTimeout = 5 * time.Second - - // The following are default values used in the interaction with MeshCA. - defaultMeshCaEndpoint = "meshca.googleapis.com" - defaultCallTimeout = 10 * time.Second - defaultCertLifetimeSecs = 86400 // 24h in seconds - defaultCertGraceTimeSecs = 43200 // 12h in seconds - defaultKeyTypeRSA = "RSA" - defaultKeySize = 2048 - - // The following are default values used in the interaction with STS or - // Secure Token Service, which is used to exchange the JWT token for an - // access token. - defaultSTSEndpoint = "securetoken.googleapis.com" - defaultCloudPlatformScope = "https://www.googleapis.com/auth/cloud-platform" - defaultRequestedTokenType = "urn:ietf:params:oauth:token-type:access_token" - defaultSubjectTokenType = "urn:ietf:params:oauth:token-type:jwt" -) - -// For overriding in unit tests. -var ( - makeHTTPDoer = makeHTTPClient - readZoneFunc = readZone - readAudienceFunc = readAudience -) - -type pluginConfig struct { - serverURI string - stsOpts sts.Options - callTimeout time.Duration - certLifetime time.Duration - certGraceTime time.Duration - keyType string - keySize int - location string -} - -// Type of key to be embedded in CSRs sent to the MeshCA. -const ( - keyTypeUnknown = 0 - keyTypeRSA = 1 -) - -// pluginConfigFromJSON parses the provided config in JSON. -// -// For certain values missing in the config, we use default values defined at -// the top of this file. -// -// If the location field or STS audience field is missing, we try talking to the -// GKE Metadata server and try to infer these values. If this attempt does not -// succeed, we let those fields have empty values. -func pluginConfigFromJSON(data json.RawMessage) (*pluginConfig, error) { - // This anonymous struct corresponds to the expected JSON config. - cfgJSON := &struct { - Server json.RawMessage `json:"server,omitempty"` // Expect a v3corepb.ApiConfigSource - CertificateLifetime json.RawMessage `json:"certificate_lifetime,omitempty"` // Expect a durationpb.Duration - RenewalGracePeriod json.RawMessage `json:"renewal_grace_period,omitempty"` // Expect a durationpb.Duration - KeyType int `json:"key_type,omitempty"` - KeySize int `json:"key_size,omitempty"` - Location string `json:"location,omitempty"` - }{} - if err := json.Unmarshal(data, cfgJSON); err != nil { - return nil, fmt.Errorf("meshca: failed to unmarshal config: %v", err) - } - - // Further unmarshal fields represented as json.RawMessage in the above - // anonymous struct, and use default values if not specified. - server := &v3corepb.ApiConfigSource{} - if cfgJSON.Server != nil { - if err := protojson.Unmarshal(cfgJSON.Server, server); err != nil { - return nil, fmt.Errorf("meshca: protojson.Unmarshal(%+v) failed: %v", cfgJSON.Server, err) - } - } - certLifetime := &durationpb.Duration{Seconds: defaultCertLifetimeSecs} - if cfgJSON.CertificateLifetime != nil { - if err := protojson.Unmarshal(cfgJSON.CertificateLifetime, certLifetime); err != nil { - return nil, fmt.Errorf("meshca: protojson.Unmarshal(%+v) failed: %v", cfgJSON.CertificateLifetime, err) - } - } - certGraceTime := &durationpb.Duration{Seconds: defaultCertGraceTimeSecs} - if cfgJSON.RenewalGracePeriod != nil { - if err := protojson.Unmarshal(cfgJSON.RenewalGracePeriod, certGraceTime); err != nil { - return nil, fmt.Errorf("meshca: protojson.Unmarshal(%+v) failed: %v", cfgJSON.RenewalGracePeriod, err) - } - } - - if api := server.GetApiType(); api != v3corepb.ApiConfigSource_GRPC { - return nil, fmt.Errorf("meshca: server has apiType %s, want %s", api, v3corepb.ApiConfigSource_GRPC) - } - - pc := &pluginConfig{ - certLifetime: certLifetime.AsDuration(), - certGraceTime: certGraceTime.AsDuration(), - } - gs := server.GetGrpcServices() - if l := len(gs); l != 1 { - return nil, fmt.Errorf("meshca: number of gRPC services in config is %d, expected 1", l) - } - grpcService := gs[0] - googGRPC := grpcService.GetGoogleGrpc() - if googGRPC == nil { - return nil, errors.New("meshca: missing google gRPC service in config") - } - pc.serverURI = googGRPC.GetTargetUri() - if pc.serverURI == "" { - pc.serverURI = defaultMeshCaEndpoint - } - - callCreds := googGRPC.GetCallCredentials() - if len(callCreds) == 0 { - return nil, errors.New("meshca: missing call credentials in config") - } - var stsCallCreds *v3corepb.GrpcService_GoogleGrpc_CallCredentials_StsService - for _, cc := range callCreds { - if stsCallCreds = cc.GetStsService(); stsCallCreds != nil { - break - } - } - if stsCallCreds == nil { - return nil, errors.New("meshca: missing STS call credentials in config") - } - if stsCallCreds.GetSubjectTokenPath() == "" { - return nil, errors.New("meshca: missing subjectTokenPath in STS call credentials config") - } - pc.stsOpts = makeStsOptsWithDefaults(stsCallCreds) - - var err error - if pc.callTimeout, err = ptypes.Duration(grpcService.GetTimeout()); err != nil { - pc.callTimeout = defaultCallTimeout - } - switch cfgJSON.KeyType { - case keyTypeUnknown, keyTypeRSA: - pc.keyType = defaultKeyTypeRSA - default: - return nil, fmt.Errorf("meshca: unsupported key type: %s, only support RSA keys", pc.keyType) - } - pc.keySize = cfgJSON.KeySize - if pc.keySize == 0 { - pc.keySize = defaultKeySize - } - pc.location = cfgJSON.Location - if pc.location == "" { - pc.location = readZoneFunc(makeHTTPDoer()) - } - - return pc, nil -} - -func (pc *pluginConfig) canonical() []byte { - return []byte(fmt.Sprintf("%s:%s:%s:%s:%s:%s:%d:%s", pc.serverURI, pc.stsOpts, pc.callTimeout, pc.certLifetime, pc.certGraceTime, pc.keyType, pc.keySize, pc.location)) -} - -func makeStsOptsWithDefaults(stsCallCreds *v3corepb.GrpcService_GoogleGrpc_CallCredentials_StsService) sts.Options { - opts := sts.Options{ - TokenExchangeServiceURI: stsCallCreds.GetTokenExchangeServiceUri(), - Resource: stsCallCreds.GetResource(), - Audience: stsCallCreds.GetAudience(), - Scope: stsCallCreds.GetScope(), - RequestedTokenType: stsCallCreds.GetRequestedTokenType(), - SubjectTokenPath: stsCallCreds.GetSubjectTokenPath(), - SubjectTokenType: stsCallCreds.GetSubjectTokenType(), - ActorTokenPath: stsCallCreds.GetActorTokenPath(), - ActorTokenType: stsCallCreds.GetActorTokenType(), - } - - // Use sane defaults for unspecified fields. - if opts.TokenExchangeServiceURI == "" { - opts.TokenExchangeServiceURI = defaultSTSEndpoint - } - if opts.Audience == "" { - opts.Audience = readAudienceFunc(makeHTTPDoer()) - } - if opts.Scope == "" { - opts.Scope = defaultCloudPlatformScope - } - if opts.RequestedTokenType == "" { - opts.RequestedTokenType = defaultRequestedTokenType - } - if opts.SubjectTokenType == "" { - opts.SubjectTokenType = defaultSubjectTokenType - } - return opts -} - -// httpDoer wraps the single method on the http.Client type that we use. This -// helps with overriding in unit tests. -type httpDoer interface { - Do(req *http.Request) (*http.Response, error) -} - -func makeHTTPClient() httpDoer { - return &http.Client{Timeout: mdsRequestTimeout} -} - -func readMetadata(client httpDoer, uriPath string) (string, error) { - req, err := http.NewRequest("GET", mdsBaseURI+uriPath, nil) - if err != nil { - return "", err - } - req.Header.Add("Metadata-Flavor", "Google") - - resp, err := client.Do(req) - if err != nil { - return "", err - } - defer resp.Body.Close() - body, err := ioutil.ReadAll(resp.Body) - if err != nil { - return "", err - } - if resp.StatusCode != http.StatusOK { - dump, err := httputil.DumpRequestOut(req, false) - if err != nil { - logger.Warningf("Failed to dump HTTP request: %v", err) - } - logger.Warningf("Request %q returned status %v", dump, resp.StatusCode) - } - return string(body), err -} - -func readZone(client httpDoer) string { - zoneURI := "computeMetadata/v1/instance/zone" - data, err := readMetadata(client, zoneURI) - if err != nil { - logger.Warningf("GET %s failed: %v", path.Join(mdsBaseURI, zoneURI), err) - return "" - } - - // The output returned by the metadata server looks like this: - // projects//zones/ - parts := strings.Split(data, "/") - if len(parts) == 0 { - logger.Warningf("GET %s returned {%s}, does not match expected format {projects//zones/}", path.Join(mdsBaseURI, zoneURI)) - return "" - } - return parts[len(parts)-1] -} - -// readAudience constructs the audience field to be used in the STS request, if -// it is not specified in the plugin configuration. -// -// "identitynamespace:{TRUST_DOMAIN}:{GKE_CLUSTER_URL}" is the format of the -// audience field. When workload identity is enabled on a GCP project, a default -// trust domain is created whose value is "{PROJECT_ID}.svc.id.goog". The format -// of the GKE_CLUSTER_URL is: -// https://container.googleapis.com/v1/projects/{PROJECT_ID}/zones/{ZONE}/clusters/{CLUSTER_NAME}. -func readAudience(client httpDoer) string { - projURI := "computeMetadata/v1/project/project-id" - project, err := readMetadata(client, projURI) - if err != nil { - logger.Warningf("GET %s failed: %v", path.Join(mdsBaseURI, projURI), err) - return "" - } - trustDomain := fmt.Sprintf("%s.svc.id.goog", project) - - clusterURI := "computeMetadata/v1/instance/attributes/cluster-name" - cluster, err := readMetadata(client, clusterURI) - if err != nil { - logger.Warningf("GET %s failed: %v", path.Join(mdsBaseURI, clusterURI), err) - return "" - } - zone := readZoneFunc(client) - clusterURL := fmt.Sprintf("https://container.googleapis.com/v1/projects/%s/zones/%s/clusters/%s", project, zone, cluster) - audience := fmt.Sprintf("identitynamespace:%s:%s", trustDomain, clusterURL) - return audience -} diff --git a/credentials/tls/certprovider/meshca/config_test.go b/credentials/tls/certprovider/meshca/config_test.go deleted file mode 100644 index 5deb484f341..00000000000 --- a/credentials/tls/certprovider/meshca/config_test.go +++ /dev/null @@ -1,375 +0,0 @@ -// +build go1.13 - -/* - * - * Copyright 2020 gRPC authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - */ - -package meshca - -import ( - "bytes" - "context" - "encoding/json" - "fmt" - "io/ioutil" - "net/http" - "strings" - "testing" - - "github.com/google/go-cmp/cmp" - - "google.golang.org/grpc/internal/grpctest" - "google.golang.org/grpc/internal/testutils" -) - -const ( - testProjectID = "test-project-id" - testGKECluster = "test-gke-cluster" - testGCEZone = "test-zone" -) - -type s struct { - grpctest.Tester -} - -func Test(t *testing.T) { - grpctest.RunSubTests(t, s{}) -} - -var ( - goodConfigFormatStr = ` - { - "server": { - "api_type": 2, - "grpc_services": [ - { - "googleGrpc": { - "target_uri": %q, - "call_credentials": [ - { - "access_token": "foo" - }, - { - "sts_service": { - "token_exchange_service_uri": "http://test-sts", - "resource": "test-resource", - "audience": "test-audience", - "scope": "test-scope", - "requested_token_type": "test-requested-token-type", - "subject_token_path": "test-subject-token-path", - "subject_token_type": "test-subject-token-type", - "actor_token_path": "test-actor-token-path", - "actor_token_type": "test-actor-token-type" - } - } - ] - }, - "timeout": "10s" - } - ] - }, - "certificate_lifetime": "86400s", - "renewal_grace_period": "43200s", - "key_type": 1, - "key_size": 2048, - "location": "us-west1-b" - }` - goodConfigWithDefaults = json.RawMessage(` - { - "server": { - "api_type": 2, - "grpc_services": [ - { - "googleGrpc": { - "call_credentials": [ - { - "sts_service": { - "subject_token_path": "test-subject-token-path" - } - } - ] - }, - "timeout": "10s" - } - ] - } - }`) -) - -var goodConfigFullySpecified = json.RawMessage(fmt.Sprintf(goodConfigFormatStr, "test-meshca")) - -// verifyReceivedRequest reads the HTTP request received by the fake client -// (exposed through a channel), and verifies that it matches the expected -// request. -func verifyReceivedRequest(fc *testutils.FakeHTTPClient, wantURI string) error { - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() - val, err := fc.ReqChan.Receive(ctx) - if err != nil { - return err - } - gotReq := val.(*http.Request) - if gotURI := gotReq.URL.String(); gotURI != wantURI { - return fmt.Errorf("request contains URL %q want %q", gotURI, wantURI) - } - if got, want := gotReq.Header.Get("Metadata-Flavor"), "Google"; got != want { - return fmt.Errorf("request contains flavor %q want %q", got, want) - } - return nil -} - -// TestParseConfigSuccessFullySpecified tests the case where the config is fully -// specified and no defaults are required. -func (s) TestParseConfigSuccessFullySpecified(t *testing.T) { - wantConfig := "test-meshca:http://test-sts:test-resource:test-audience:test-scope:test-requested-token-type:test-subject-token-path:test-subject-token-type:test-actor-token-path:test-actor-token-type:10s:24h0m0s:12h0m0s:RSA:2048:us-west1-b" - - cfg, err := pluginConfigFromJSON(goodConfigFullySpecified) - if err != nil { - t.Fatalf("pluginConfigFromJSON(%q) failed: %v", goodConfigFullySpecified, err) - } - gotConfig := cfg.canonical() - if diff := cmp.Diff(wantConfig, string(gotConfig)); diff != "" { - t.Errorf("pluginConfigFromJSON(%q) returned config does not match expected (-want +got):\n%s", string(goodConfigFullySpecified), diff) - } -} - -// TestParseConfigSuccessWithDefaults tests cases where the config is not fully -// specified, and we end up using some sane defaults. -func (s) TestParseConfigSuccessWithDefaults(t *testing.T) { - wantConfig := fmt.Sprintf("%s:%s:%s:%s:%s:%s:%s:%s:%s:%s:%s:%s:%s:%s:%s:%s", - "meshca.googleapis.com", // Mesh CA Server URI. - "securetoken.googleapis.com", // STS Server URI. - "", // STS Resource Name. - "identitynamespace:test-project-id.svc.id.goog:https://container.googleapis.com/v1/projects/test-project-id/zones/test-zone/clusters/test-gke-cluster", // STS Audience. - "https://www.googleapis.com/auth/cloud-platform", // STS Scope. - "urn:ietf:params:oauth:token-type:access_token", // STS requested token type. - "test-subject-token-path", // STS subject token path. - "urn:ietf:params:oauth:token-type:jwt", // STS subject token type. - "", // STS actor token path. - "", // STS actor token type. - "10s", // Call timeout. - "24h0m0s", // Cert life time. - "12h0m0s", // Cert grace time. - "RSA", // Key type - "2048", // Key size - "test-zone", // Zone - ) - - // We expect the config parser to make four HTTP requests and receive four - // responses. Hence we setup the request and response channels in the fake - // client with appropriate buffer size. - fc := &testutils.FakeHTTPClient{ - ReqChan: testutils.NewChannelWithSize(4), - RespChan: testutils.NewChannelWithSize(4), - } - // Set up the responses to be delivered to the config parser by the fake - // client. The config parser expects responses with project_id, - // gke_cluster_id and gce_zone. The zone is read twice, once as part of - // reading the STS audience and once to get location metadata. - fc.RespChan.Send(&http.Response{ - Status: "200 OK", - StatusCode: http.StatusOK, - Body: ioutil.NopCloser(bytes.NewReader([]byte(testProjectID))), - }) - fc.RespChan.Send(&http.Response{ - Status: "200 OK", - StatusCode: http.StatusOK, - Body: ioutil.NopCloser(bytes.NewReader([]byte(testGKECluster))), - }) - fc.RespChan.Send(&http.Response{ - Status: "200 OK", - StatusCode: http.StatusOK, - Body: ioutil.NopCloser(bytes.NewReader([]byte(fmt.Sprintf("projects/%s/zones/%s", testProjectID, testGCEZone)))), - }) - fc.RespChan.Send(&http.Response{ - Status: "200 OK", - StatusCode: http.StatusOK, - Body: ioutil.NopCloser(bytes.NewReader([]byte(fmt.Sprintf("projects/%s/zones/%s", testProjectID, testGCEZone)))), - }) - // Override the http.Client with our fakeClient. - origMakeHTTPDoer := makeHTTPDoer - makeHTTPDoer = func() httpDoer { return fc } - defer func() { makeHTTPDoer = origMakeHTTPDoer }() - - // Spawn a goroutine to verify the HTTP requests sent out as part of the - // config parsing. - errCh := make(chan error, 1) - go func() { - if err := verifyReceivedRequest(fc, "http://metadata.google.internal/computeMetadata/v1/project/project-id"); err != nil { - errCh <- err - return - } - if err := verifyReceivedRequest(fc, "http://metadata.google.internal/computeMetadata/v1/instance/attributes/cluster-name"); err != nil { - errCh <- err - return - } - if err := verifyReceivedRequest(fc, "http://metadata.google.internal/computeMetadata/v1/instance/zone"); err != nil { - errCh <- err - return - } - errCh <- nil - }() - - cfg, err := pluginConfigFromJSON(goodConfigWithDefaults) - if err != nil { - t.Fatalf("pluginConfigFromJSON(%q) failed: %v", goodConfigWithDefaults, err) - } - gotConfig := cfg.canonical() - if diff := cmp.Diff(wantConfig, string(gotConfig)); diff != "" { - t.Errorf("builder.ParseConfig(%q) returned config does not match expected (-want +got):\n%s", goodConfigWithDefaults, diff) - } - - if err := <-errCh; err != nil { - t.Fatal(err) - } -} - -// TestParseConfigFailureCases tests several invalid configs which all result in -// config parsing failures. -func (s) TestParseConfigFailureCases(t *testing.T) { - tests := []struct { - desc string - inputConfig json.RawMessage - wantErr string - }{ - { - desc: "invalid JSON", - inputConfig: json.RawMessage(`bad bad json`), - wantErr: "failed to unmarshal config", - }, - { - desc: "bad apiType", - inputConfig: json.RawMessage(` - { - "server": { - "api_type": 1 - } - }`), - wantErr: "server has apiType REST, want GRPC", - }, - { - desc: "no grpc services", - inputConfig: json.RawMessage(` - { - "server": { - "api_type": 2 - } - }`), - wantErr: "number of gRPC services in config is 0, expected 1", - }, - { - desc: "too many grpc services", - inputConfig: json.RawMessage(` - { - "server": { - "api_type": 2, - "grpc_services": [{}, {}] - } - }`), - wantErr: "number of gRPC services in config is 2, expected 1", - }, - { - desc: "missing google grpc service", - inputConfig: json.RawMessage(` - { - "server": { - "api_type": 2, - "grpc_services": [ - { - "envoyGrpc": {} - } - ] - } - }`), - wantErr: "missing google gRPC service in config", - }, - { - desc: "missing call credentials", - inputConfig: json.RawMessage(` - { - "server": { - "api_type": 2, - "grpc_services": [ - { - "googleGrpc": { - "target_uri": "foo" - } - } - ] - } - }`), - wantErr: "missing call credentials in config", - }, - { - desc: "missing STS call credentials", - inputConfig: json.RawMessage(` - { - "server": { - "api_type": 2, - "grpc_services": [ - { - "googleGrpc": { - "target_uri": "foo", - "call_credentials": [ - { - "access_token": "foo" - } - ] - } - } - ] - } - }`), - wantErr: "missing STS call credentials in config", - }, - { - desc: "with no defaults", - inputConfig: json.RawMessage(` - { - "server": { - "api_type": 2, - "grpc_services": [ - { - "googleGrpc": { - "target_uri": "foo", - "call_credentials": [ - { - "sts_service": {} - } - ] - } - } - ] - } - }`), - wantErr: "missing subjectTokenPath in STS call credentials config", - }, - } - - for _, test := range tests { - t.Run(test.desc, func(t *testing.T) { - cfg, err := pluginConfigFromJSON(test.inputConfig) - if err == nil { - t.Fatalf("pluginConfigFromJSON(%q) = %v, expected to return error (%v)", test.inputConfig, string(cfg.canonical()), test.wantErr) - - } - if !strings.Contains(err.Error(), test.wantErr) { - t.Fatalf("builder.ParseConfig(%q) = (%v), want error (%v)", test.inputConfig, err, test.wantErr) - } - }) - } -} diff --git a/credentials/tls/certprovider/meshca/internal/v1/meshca.pb.go b/credentials/tls/certprovider/meshca/internal/v1/meshca.pb.go deleted file mode 100644 index 387f8c55abc..00000000000 --- a/credentials/tls/certprovider/meshca/internal/v1/meshca.pb.go +++ /dev/null @@ -1,276 +0,0 @@ -// Copyright 2019 Istio Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Code generated by protoc-gen-go. DO NOT EDIT. -// versions: -// protoc-gen-go v1.25.0 -// protoc v3.14.0 -// source: istio/google/security/meshca/v1/meshca.proto - -package google_security_meshca_v1 - -import ( - proto "github.com/golang/protobuf/proto" - protoreflect "google.golang.org/protobuf/reflect/protoreflect" - protoimpl "google.golang.org/protobuf/runtime/protoimpl" - durationpb "google.golang.org/protobuf/types/known/durationpb" - reflect "reflect" - sync "sync" -) - -const ( - // Verify that this generated code is sufficiently up-to-date. - _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) - // Verify that runtime/protoimpl is sufficiently up-to-date. - _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) -) - -// This is a compile-time assertion that a sufficiently up-to-date version -// of the legacy proto package is being used. -const _ = proto.ProtoPackageIsVersion4 - -// Certificate request message. -type MeshCertificateRequest struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache - unknownFields protoimpl.UnknownFields - - // The request ID must be a valid UUID with the exception that zero UUID is - // not supported (00000000-0000-0000-0000-000000000000). - RequestId string `protobuf:"bytes,1,opt,name=request_id,json=requestId,proto3" json:"request_id,omitempty"` - // PEM-encoded certificate request. - Csr string `protobuf:"bytes,2,opt,name=csr,proto3" json:"csr,omitempty"` - // Optional: requested certificate validity period. - Validity *durationpb.Duration `protobuf:"bytes,3,opt,name=validity,proto3" json:"validity,omitempty"` // Reserved 4 -} - -func (x *MeshCertificateRequest) Reset() { - *x = MeshCertificateRequest{} - if protoimpl.UnsafeEnabled { - mi := &file_istio_google_security_meshca_v1_meshca_proto_msgTypes[0] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } -} - -func (x *MeshCertificateRequest) String() string { - return protoimpl.X.MessageStringOf(x) -} - -func (*MeshCertificateRequest) ProtoMessage() {} - -func (x *MeshCertificateRequest) ProtoReflect() protoreflect.Message { - mi := &file_istio_google_security_meshca_v1_meshca_proto_msgTypes[0] - if protoimpl.UnsafeEnabled && x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - ms.StoreMessageInfo(mi) - } - return ms - } - return mi.MessageOf(x) -} - -// Deprecated: Use MeshCertificateRequest.ProtoReflect.Descriptor instead. -func (*MeshCertificateRequest) Descriptor() ([]byte, []int) { - return file_istio_google_security_meshca_v1_meshca_proto_rawDescGZIP(), []int{0} -} - -func (x *MeshCertificateRequest) GetRequestId() string { - if x != nil { - return x.RequestId - } - return "" -} - -func (x *MeshCertificateRequest) GetCsr() string { - if x != nil { - return x.Csr - } - return "" -} - -func (x *MeshCertificateRequest) GetValidity() *durationpb.Duration { - if x != nil { - return x.Validity - } - return nil -} - -// Certificate response message. -type MeshCertificateResponse struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache - unknownFields protoimpl.UnknownFields - - // PEM-encoded certificate chain. - // Leaf cert is element '0'. Root cert is element 'n'. - CertChain []string `protobuf:"bytes,1,rep,name=cert_chain,json=certChain,proto3" json:"cert_chain,omitempty"` -} - -func (x *MeshCertificateResponse) Reset() { - *x = MeshCertificateResponse{} - if protoimpl.UnsafeEnabled { - mi := &file_istio_google_security_meshca_v1_meshca_proto_msgTypes[1] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } -} - -func (x *MeshCertificateResponse) String() string { - return protoimpl.X.MessageStringOf(x) -} - -func (*MeshCertificateResponse) ProtoMessage() {} - -func (x *MeshCertificateResponse) ProtoReflect() protoreflect.Message { - mi := &file_istio_google_security_meshca_v1_meshca_proto_msgTypes[1] - if protoimpl.UnsafeEnabled && x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - ms.StoreMessageInfo(mi) - } - return ms - } - return mi.MessageOf(x) -} - -// Deprecated: Use MeshCertificateResponse.ProtoReflect.Descriptor instead. -func (*MeshCertificateResponse) Descriptor() ([]byte, []int) { - return file_istio_google_security_meshca_v1_meshca_proto_rawDescGZIP(), []int{1} -} - -func (x *MeshCertificateResponse) GetCertChain() []string { - if x != nil { - return x.CertChain - } - return nil -} - -var File_istio_google_security_meshca_v1_meshca_proto protoreflect.FileDescriptor - -var file_istio_google_security_meshca_v1_meshca_proto_rawDesc = []byte{ - 0x0a, 0x2c, 0x69, 0x73, 0x74, 0x69, 0x6f, 0x2f, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x73, - 0x65, 0x63, 0x75, 0x72, 0x69, 0x74, 0x79, 0x2f, 0x6d, 0x65, 0x73, 0x68, 0x63, 0x61, 0x2f, 0x76, - 0x31, 0x2f, 0x6d, 0x65, 0x73, 0x68, 0x63, 0x61, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x19, - 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x73, 0x65, 0x63, 0x75, 0x72, 0x69, 0x74, 0x79, 0x2e, - 0x6d, 0x65, 0x73, 0x68, 0x63, 0x61, 0x2e, 0x76, 0x31, 0x1a, 0x1e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, - 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x64, 0x75, 0x72, 0x61, 0x74, - 0x69, 0x6f, 0x6e, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0x80, 0x01, 0x0a, 0x16, 0x4d, 0x65, - 0x73, 0x68, 0x43, 0x65, 0x72, 0x74, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, - 0x75, 0x65, 0x73, 0x74, 0x12, 0x1d, 0x0a, 0x0a, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x5f, - 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, - 0x74, 0x49, 0x64, 0x12, 0x10, 0x0a, 0x03, 0x63, 0x73, 0x72, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x03, 0x63, 0x73, 0x72, 0x12, 0x35, 0x0a, 0x08, 0x76, 0x61, 0x6c, 0x69, 0x64, 0x69, 0x74, - 0x79, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x19, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, - 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x44, 0x75, 0x72, 0x61, 0x74, 0x69, - 0x6f, 0x6e, 0x52, 0x08, 0x76, 0x61, 0x6c, 0x69, 0x64, 0x69, 0x74, 0x79, 0x22, 0x38, 0x0a, 0x17, - 0x4d, 0x65, 0x73, 0x68, 0x43, 0x65, 0x72, 0x74, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x65, 0x52, - 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x1d, 0x0a, 0x0a, 0x63, 0x65, 0x72, 0x74, 0x5f, - 0x63, 0x68, 0x61, 0x69, 0x6e, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x09, 0x63, 0x65, 0x72, - 0x74, 0x43, 0x68, 0x61, 0x69, 0x6e, 0x32, 0x96, 0x01, 0x0a, 0x16, 0x4d, 0x65, 0x73, 0x68, 0x43, - 0x65, 0x72, 0x74, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x65, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, - 0x65, 0x12, 0x7c, 0x0a, 0x11, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x43, 0x65, 0x72, 0x74, 0x69, - 0x66, 0x69, 0x63, 0x61, 0x74, 0x65, 0x12, 0x31, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, - 0x73, 0x65, 0x63, 0x75, 0x72, 0x69, 0x74, 0x79, 0x2e, 0x6d, 0x65, 0x73, 0x68, 0x63, 0x61, 0x2e, - 0x76, 0x31, 0x2e, 0x4d, 0x65, 0x73, 0x68, 0x43, 0x65, 0x72, 0x74, 0x69, 0x66, 0x69, 0x63, 0x61, - 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x32, 0x2e, 0x67, 0x6f, 0x6f, 0x67, - 0x6c, 0x65, 0x2e, 0x73, 0x65, 0x63, 0x75, 0x72, 0x69, 0x74, 0x79, 0x2e, 0x6d, 0x65, 0x73, 0x68, - 0x63, 0x61, 0x2e, 0x76, 0x31, 0x2e, 0x4d, 0x65, 0x73, 0x68, 0x43, 0x65, 0x72, 0x74, 0x69, 0x66, - 0x69, 0x63, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x42, - 0x2e, 0x0a, 0x1d, 0x63, 0x6f, 0x6d, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x73, 0x65, - 0x63, 0x75, 0x72, 0x69, 0x74, 0x79, 0x2e, 0x6d, 0x65, 0x73, 0x68, 0x63, 0x61, 0x2e, 0x76, 0x31, - 0x42, 0x0b, 0x4d, 0x65, 0x73, 0x68, 0x43, 0x61, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x50, 0x01, 0x62, - 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, -} - -var ( - file_istio_google_security_meshca_v1_meshca_proto_rawDescOnce sync.Once - file_istio_google_security_meshca_v1_meshca_proto_rawDescData = file_istio_google_security_meshca_v1_meshca_proto_rawDesc -) - -func file_istio_google_security_meshca_v1_meshca_proto_rawDescGZIP() []byte { - file_istio_google_security_meshca_v1_meshca_proto_rawDescOnce.Do(func() { - file_istio_google_security_meshca_v1_meshca_proto_rawDescData = protoimpl.X.CompressGZIP(file_istio_google_security_meshca_v1_meshca_proto_rawDescData) - }) - return file_istio_google_security_meshca_v1_meshca_proto_rawDescData -} - -var file_istio_google_security_meshca_v1_meshca_proto_msgTypes = make([]protoimpl.MessageInfo, 2) -var file_istio_google_security_meshca_v1_meshca_proto_goTypes = []interface{}{ - (*MeshCertificateRequest)(nil), // 0: google.security.meshca.v1.MeshCertificateRequest - (*MeshCertificateResponse)(nil), // 1: google.security.meshca.v1.MeshCertificateResponse - (*durationpb.Duration)(nil), // 2: google.protobuf.Duration -} -var file_istio_google_security_meshca_v1_meshca_proto_depIdxs = []int32{ - 2, // 0: google.security.meshca.v1.MeshCertificateRequest.validity:type_name -> google.protobuf.Duration - 0, // 1: google.security.meshca.v1.MeshCertificateService.CreateCertificate:input_type -> google.security.meshca.v1.MeshCertificateRequest - 1, // 2: google.security.meshca.v1.MeshCertificateService.CreateCertificate:output_type -> google.security.meshca.v1.MeshCertificateResponse - 2, // [2:3] is the sub-list for method output_type - 1, // [1:2] is the sub-list for method input_type - 1, // [1:1] is the sub-list for extension type_name - 1, // [1:1] is the sub-list for extension extendee - 0, // [0:1] is the sub-list for field type_name -} - -func init() { file_istio_google_security_meshca_v1_meshca_proto_init() } -func file_istio_google_security_meshca_v1_meshca_proto_init() { - if File_istio_google_security_meshca_v1_meshca_proto != nil { - return - } - if !protoimpl.UnsafeEnabled { - file_istio_google_security_meshca_v1_meshca_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*MeshCertificateRequest); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_istio_google_security_meshca_v1_meshca_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*MeshCertificateResponse); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - } - type x struct{} - out := protoimpl.TypeBuilder{ - File: protoimpl.DescBuilder{ - GoPackagePath: reflect.TypeOf(x{}).PkgPath(), - RawDescriptor: file_istio_google_security_meshca_v1_meshca_proto_rawDesc, - NumEnums: 0, - NumMessages: 2, - NumExtensions: 0, - NumServices: 1, - }, - GoTypes: file_istio_google_security_meshca_v1_meshca_proto_goTypes, - DependencyIndexes: file_istio_google_security_meshca_v1_meshca_proto_depIdxs, - MessageInfos: file_istio_google_security_meshca_v1_meshca_proto_msgTypes, - }.Build() - File_istio_google_security_meshca_v1_meshca_proto = out.File - file_istio_google_security_meshca_v1_meshca_proto_rawDesc = nil - file_istio_google_security_meshca_v1_meshca_proto_goTypes = nil - file_istio_google_security_meshca_v1_meshca_proto_depIdxs = nil -} diff --git a/credentials/tls/certprovider/meshca/internal/v1/meshca_grpc.pb.go b/credentials/tls/certprovider/meshca/internal/v1/meshca_grpc.pb.go deleted file mode 100644 index e53a61598ab..00000000000 --- a/credentials/tls/certprovider/meshca/internal/v1/meshca_grpc.pb.go +++ /dev/null @@ -1,106 +0,0 @@ -// Code generated by protoc-gen-go-grpc. DO NOT EDIT. - -package google_security_meshca_v1 - -import ( - context "context" - grpc "google.golang.org/grpc" - codes "google.golang.org/grpc/codes" - status "google.golang.org/grpc/status" -) - -// This is a compile-time assertion to ensure that this generated file -// is compatible with the grpc package it is being compiled against. -// Requires gRPC-Go v1.32.0 or later. -const _ = grpc.SupportPackageIsVersion7 - -// MeshCertificateServiceClient is the client API for MeshCertificateService service. -// -// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. -type MeshCertificateServiceClient interface { - // Using provided CSR, returns a signed certificate that represents a GCP - // service account identity. - CreateCertificate(ctx context.Context, in *MeshCertificateRequest, opts ...grpc.CallOption) (*MeshCertificateResponse, error) -} - -type meshCertificateServiceClient struct { - cc grpc.ClientConnInterface -} - -func NewMeshCertificateServiceClient(cc grpc.ClientConnInterface) MeshCertificateServiceClient { - return &meshCertificateServiceClient{cc} -} - -func (c *meshCertificateServiceClient) CreateCertificate(ctx context.Context, in *MeshCertificateRequest, opts ...grpc.CallOption) (*MeshCertificateResponse, error) { - out := new(MeshCertificateResponse) - err := c.cc.Invoke(ctx, "/google.security.meshca.v1.MeshCertificateService/CreateCertificate", in, out, opts...) - if err != nil { - return nil, err - } - return out, nil -} - -// MeshCertificateServiceServer is the server API for MeshCertificateService service. -// All implementations must embed UnimplementedMeshCertificateServiceServer -// for forward compatibility -type MeshCertificateServiceServer interface { - // Using provided CSR, returns a signed certificate that represents a GCP - // service account identity. - CreateCertificate(context.Context, *MeshCertificateRequest) (*MeshCertificateResponse, error) - mustEmbedUnimplementedMeshCertificateServiceServer() -} - -// UnimplementedMeshCertificateServiceServer must be embedded to have forward compatible implementations. -type UnimplementedMeshCertificateServiceServer struct { -} - -func (UnimplementedMeshCertificateServiceServer) CreateCertificate(context.Context, *MeshCertificateRequest) (*MeshCertificateResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method CreateCertificate not implemented") -} -func (UnimplementedMeshCertificateServiceServer) mustEmbedUnimplementedMeshCertificateServiceServer() { -} - -// UnsafeMeshCertificateServiceServer may be embedded to opt out of forward compatibility for this service. -// Use of this interface is not recommended, as added methods to MeshCertificateServiceServer will -// result in compilation errors. -type UnsafeMeshCertificateServiceServer interface { - mustEmbedUnimplementedMeshCertificateServiceServer() -} - -func RegisterMeshCertificateServiceServer(s grpc.ServiceRegistrar, srv MeshCertificateServiceServer) { - s.RegisterService(&MeshCertificateService_ServiceDesc, srv) -} - -func _MeshCertificateService_CreateCertificate_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { - in := new(MeshCertificateRequest) - if err := dec(in); err != nil { - return nil, err - } - if interceptor == nil { - return srv.(MeshCertificateServiceServer).CreateCertificate(ctx, in) - } - info := &grpc.UnaryServerInfo{ - Server: srv, - FullMethod: "/google.security.meshca.v1.MeshCertificateService/CreateCertificate", - } - handler := func(ctx context.Context, req interface{}) (interface{}, error) { - return srv.(MeshCertificateServiceServer).CreateCertificate(ctx, req.(*MeshCertificateRequest)) - } - return interceptor(ctx, in, info, handler) -} - -// MeshCertificateService_ServiceDesc is the grpc.ServiceDesc for MeshCertificateService service. -// It's only intended for direct use with grpc.RegisterService, -// and not to be introspected or modified (even as a copy) -var MeshCertificateService_ServiceDesc = grpc.ServiceDesc{ - ServiceName: "google.security.meshca.v1.MeshCertificateService", - HandlerType: (*MeshCertificateServiceServer)(nil), - Methods: []grpc.MethodDesc{ - { - MethodName: "CreateCertificate", - Handler: _MeshCertificateService_CreateCertificate_Handler, - }, - }, - Streams: []grpc.StreamDesc{}, - Metadata: "istio/google/security/meshca/v1/meshca.proto", -} diff --git a/credentials/tls/certprovider/meshca/plugin.go b/credentials/tls/certprovider/meshca/plugin.go deleted file mode 100644 index ab1958ac1fd..00000000000 --- a/credentials/tls/certprovider/meshca/plugin.go +++ /dev/null @@ -1,289 +0,0 @@ -// +build go1.13 - -/* - * - * Copyright 2020 gRPC authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - */ - -// Package meshca provides an implementation of the Provider interface which -// communicates with MeshCA to get certificates signed. -package meshca - -import ( - "context" - "crypto" - "crypto/rand" - "crypto/rsa" - "crypto/tls" - "crypto/x509" - "encoding/pem" - "fmt" - "time" - - durationpb "github.com/golang/protobuf/ptypes/duration" - "github.com/google/uuid" - - "google.golang.org/grpc" - "google.golang.org/grpc/credentials/tls/certprovider" - meshgrpc "google.golang.org/grpc/credentials/tls/certprovider/meshca/internal/v1" - meshpb "google.golang.org/grpc/credentials/tls/certprovider/meshca/internal/v1" - "google.golang.org/grpc/internal/grpclog" - "google.golang.org/grpc/metadata" -) - -// In requests sent to the MeshCA, we add a metadata header with this key and -// the value being the GCE zone in which the workload is running in. -const locationMetadataKey = "x-goog-request-params" - -// For overriding from unit tests. -var newDistributorFunc = func() distributor { return certprovider.NewDistributor() } - -// distributor wraps the methods on certprovider.Distributor which are used by -// the plugin. This is very useful in tests which need to know exactly when the -// plugin updates its key material. -type distributor interface { - KeyMaterial(ctx context.Context) (*certprovider.KeyMaterial, error) - Set(km *certprovider.KeyMaterial, err error) - Stop() -} - -// providerPlugin is an implementation of the certprovider.Provider interface, -// which gets certificates signed by communicating with the MeshCA. -type providerPlugin struct { - distributor // Holds the key material. - cancel context.CancelFunc - cc *grpc.ClientConn // Connection to MeshCA server. - cfg *pluginConfig // Plugin configuration. - opts certprovider.BuildOptions // Key material options. - logger *grpclog.PrefixLogger // Plugin instance specific prefix. - backoff func(int) time.Duration // Exponential backoff. - doneFunc func() // Notify the builder when done. -} - -// providerParams wraps params passed to the provider plugin at creation time. -type providerParams struct { - // This ClientConn to the MeshCA server is owned by the builder. - cc *grpc.ClientConn - cfg *pluginConfig - opts certprovider.BuildOptions - backoff func(int) time.Duration - doneFunc func() -} - -func newProviderPlugin(params providerParams) *providerPlugin { - ctx, cancel := context.WithCancel(context.Background()) - p := &providerPlugin{ - cancel: cancel, - cc: params.cc, - cfg: params.cfg, - opts: params.opts, - backoff: params.backoff, - doneFunc: params.doneFunc, - distributor: newDistributorFunc(), - } - p.logger = prefixLogger((p)) - p.logger.Infof("plugin created") - go p.run(ctx) - return p -} - -func (p *providerPlugin) Close() { - p.logger.Infof("plugin closed") - p.Stop() // Stop the embedded distributor. - p.cancel() - p.doneFunc() -} - -// run is a long running goroutine which periodically sends out CSRs to the -// MeshCA, and updates the underlying Distributor with the new key material. -func (p *providerPlugin) run(ctx context.Context) { - // We need to start fetching key material right away. The next attempt will - // be triggered by the timer firing. - for { - certValidity, err := p.updateKeyMaterial(ctx) - if err != nil { - return - } - - // We request a certificate with the configured validity duration (which - // is usually twice as much as the grace period). But the server is free - // to return a certificate with whatever validity time it deems right. - refreshAfter := p.cfg.certGraceTime - if refreshAfter > certValidity { - // The default value of cert grace time is half that of the default - // cert validity time. So here, when we have to use a non-default - // cert life time, we will set the grace time again to half that of - // the validity time. - refreshAfter = certValidity / 2 - } - timer := time.NewTimer(refreshAfter) - select { - case <-ctx.Done(): - return - case <-timer.C: - } - } -} - -// updateKeyMaterial generates a CSR and attempts to get it signed from the -// MeshCA. It retries with an exponential backoff till it succeeds or the -// deadline specified in ctx expires. Once it gets the CSR signed from the -// MeshCA, it updates the Distributor with the new key material. -// -// It returns the amount of time the new certificate is valid for. -func (p *providerPlugin) updateKeyMaterial(ctx context.Context) (time.Duration, error) { - client := meshgrpc.NewMeshCertificateServiceClient(p.cc) - retries := 0 - for { - if ctx.Err() != nil { - return 0, ctx.Err() - } - - if retries != 0 { - bi := p.backoff(retries) - p.logger.Warningf("Backing off for %s before attempting the next CreateCertificate() request", bi) - timer := time.NewTimer(bi) - select { - case <-timer.C: - case <-ctx.Done(): - return 0, ctx.Err() - } - } - retries++ - - privKey, err := rsa.GenerateKey(rand.Reader, p.cfg.keySize) - if err != nil { - p.logger.Warningf("RSA key generation failed: %v", err) - continue - } - // We do not set any fields in the CSR (we use an empty - // x509.CertificateRequest as the template) because the MeshCA discards - // them anyways, and uses the workload identity from the access token - // that we present (as part of the STS call creds). - csrBytes, err := x509.CreateCertificateRequest(rand.Reader, &x509.CertificateRequest{}, crypto.PrivateKey(privKey)) - if err != nil { - p.logger.Warningf("CSR creation failed: %v", err) - continue - } - csrPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE REQUEST", Bytes: csrBytes}) - - // Send out the CSR with a call timeout and location metadata, as - // specified in the plugin configuration. - req := &meshpb.MeshCertificateRequest{ - RequestId: uuid.New().String(), - Csr: string(csrPEM), - Validity: &durationpb.Duration{Seconds: int64(p.cfg.certLifetime / time.Second)}, - } - p.logger.Debugf("Sending CreateCertificate() request: %v", req) - - callCtx, ctxCancel := context.WithTimeout(context.Background(), p.cfg.callTimeout) - callCtx = metadata.NewOutgoingContext(callCtx, metadata.Pairs(locationMetadataKey, p.cfg.location)) - resp, err := client.CreateCertificate(callCtx, req) - if err != nil { - p.logger.Warningf("CreateCertificate request failed: %v", err) - ctxCancel() - continue - } - ctxCancel() - - // The returned cert chain must contain more than one cert. Leaf cert is - // element '0', while root cert is element 'n', and the intermediate - // entries form the chain from the root to the leaf. - certChain := resp.GetCertChain() - if l := len(certChain); l <= 1 { - p.logger.Errorf("Received certificate chain contains %d certificates, need more than one", l) - continue - } - - // We need to explicitly parse the PEM cert contents as an - // x509.Certificate to read the certificate validity period. We use this - // to decide when to refresh the cert. Even though the call to - // tls.X509KeyPair actually parses the PEM contents into an - // x509.Certificate, it does not store that in the `Leaf` field. See: - // https://golang.org/pkg/crypto/tls/#X509KeyPair. - identity, intermediates, roots, err := parseCertChain(certChain) - if err != nil { - p.logger.Errorf(err.Error()) - continue - } - _, err = identity.Verify(x509.VerifyOptions{ - Intermediates: intermediates, - Roots: roots, - }) - if err != nil { - p.logger.Errorf("Certificate verification failed for return certChain: %v", err) - continue - } - - key := x509.MarshalPKCS1PrivateKey(privKey) - keyPEM := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: key}) - certPair, err := tls.X509KeyPair([]byte(certChain[0]), keyPEM) - if err != nil { - p.logger.Errorf("Failed to create x509 key pair: %v", err) - continue - } - - // At this point, the received response has been deemed good. - retries = 0 - - // All certs signed by the MeshCA roll up to the same root. And treating - // the last element of the returned chain as the root is the only - // supported option to get the root certificate. So, we ignore the - // options specified in the call to Build(), which contain certificate - // name and whether the caller is interested in identity or root cert. - p.Set(&certprovider.KeyMaterial{Certs: []tls.Certificate{certPair}, Roots: roots}, nil) - return time.Until(identity.NotAfter), nil - } -} - -// ParseCertChain parses the result returned by the MeshCA which consists of a -// list of PEM encoded certs. The first element in the list is the leaf or -// identity cert, while the last element is the root, and everything in between -// form the chain of trust. -// -// Caller needs to make sure that certChain has at least two elements. -func parseCertChain(certChain []string) (*x509.Certificate, *x509.CertPool, *x509.CertPool, error) { - identity, err := parseCert([]byte(certChain[0])) - if err != nil { - return nil, nil, nil, err - } - - intermediates := x509.NewCertPool() - for _, cert := range certChain[1 : len(certChain)-1] { - i, err := parseCert([]byte(cert)) - if err != nil { - return nil, nil, nil, err - } - intermediates.AddCert(i) - } - - roots := x509.NewCertPool() - root, err := parseCert([]byte(certChain[len(certChain)-1])) - if err != nil { - return nil, nil, nil, err - } - roots.AddCert(root) - - return identity, intermediates, roots, nil -} - -func parseCert(certPEM []byte) (*x509.Certificate, error) { - block, _ := pem.Decode(certPEM) - if block == nil { - return nil, fmt.Errorf("failed to decode received PEM data: %v", certPEM) - } - return x509.ParseCertificate(block.Bytes) -} diff --git a/credentials/tls/certprovider/meshca/plugin_test.go b/credentials/tls/certprovider/meshca/plugin_test.go deleted file mode 100644 index 51f545d6a0e..00000000000 --- a/credentials/tls/certprovider/meshca/plugin_test.go +++ /dev/null @@ -1,459 +0,0 @@ -// +build go1.13 - -/* - * - * Copyright 2020 gRPC authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - */ - -package meshca - -import ( - "context" - "crypto/rand" - "crypto/rsa" - "crypto/x509" - "crypto/x509/pkix" - "encoding/json" - "encoding/pem" - "errors" - "fmt" - "math/big" - "net" - "reflect" - "testing" - "time" - - "google.golang.org/grpc" - "google.golang.org/grpc/credentials/tls/certprovider" - meshgrpc "google.golang.org/grpc/credentials/tls/certprovider/meshca/internal/v1" - meshpb "google.golang.org/grpc/credentials/tls/certprovider/meshca/internal/v1" - "google.golang.org/grpc/internal/testutils" -) - -const ( - // Used when waiting for something that is expected to *not* happen. - defaultTestShortTimeout = 10 * time.Millisecond - defaultTestTimeout = 5 * time.Second - defaultTestCertLife = time.Hour - shortTestCertLife = 2 * time.Second - maxErrCount = 2 -) - -// fakeCA provides a very simple fake implementation of the certificate signing -// service as exported by the MeshCA. -type fakeCA struct { - meshgrpc.UnimplementedMeshCertificateServiceServer - - withErrors bool // Whether the CA returns errors to begin with. - withShortLife bool // Whether to create certs with short lifetime - - ccChan *testutils.Channel // Channel to get notified about CreateCertificate calls. - errors int // Error count. - key *rsa.PrivateKey // Private key of CA. - cert *x509.Certificate // Signing certificate. - certPEM []byte // PEM encoding of signing certificate. -} - -// Returns a new instance of the fake Mesh CA. It generates a new RSA key and a -// self-signed certificate which will be used to sign CSRs received in incoming -// requests. -// withErrors controls whether the fake returns errors before succeeding, while -// withShortLife controls whether the fake returns certs with very small -// lifetimes (to test plugin refresh behavior). Every time a CreateCertificate() -// call succeeds, an event is pushed on the ccChan. -func newFakeMeshCA(ccChan *testutils.Channel, withErrors, withShortLife bool) (*fakeCA, error) { - key, err := rsa.GenerateKey(rand.Reader, 2048) - if err != nil { - return nil, fmt.Errorf("RSA key generation failed: %v", err) - } - - now := time.Now() - tmpl := &x509.Certificate{ - Subject: pkix.Name{CommonName: "my-fake-ca"}, - SerialNumber: big.NewInt(10), - NotBefore: now.Add(-time.Hour), - NotAfter: now.Add(time.Hour), - KeyUsage: x509.KeyUsageCertSign, - IsCA: true, - BasicConstraintsValid: true, - } - certDER, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &key.PublicKey, key) - if err != nil { - return nil, fmt.Errorf("x509.CreateCertificate(%v) failed: %v", tmpl, err) - } - // The PEM encoding of the self-signed certificate is stored because we need - // to return a chain of certificates in the response, starting with the - // client certificate and ending in the root. - certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER}) - cert, err := x509.ParseCertificate(certDER) - if err != nil { - return nil, fmt.Errorf("x509.ParseCertificate(%v) failed: %v", certDER, err) - } - - return &fakeCA{ - withErrors: withErrors, - withShortLife: withShortLife, - ccChan: ccChan, - key: key, - cert: cert, - certPEM: certPEM, - }, nil -} - -// CreateCertificate helps implement the MeshCA service. -// -// If the fakeMeshCA was created with `withErrors` set to true, the first -// `maxErrCount` number of RPC return errors. Subsequent requests are signed and -// returned without error. -func (f *fakeCA) CreateCertificate(ctx context.Context, req *meshpb.MeshCertificateRequest) (*meshpb.MeshCertificateResponse, error) { - if f.withErrors { - if f.errors < maxErrCount { - f.errors++ - return nil, errors.New("fake Mesh CA error") - - } - } - - csrPEM := []byte(req.GetCsr()) - block, _ := pem.Decode(csrPEM) - if block == nil { - return nil, fmt.Errorf("failed to decode received CSR: %v", csrPEM) - } - csr, err := x509.ParseCertificateRequest(block.Bytes) - if err != nil { - return nil, fmt.Errorf("failed to parse received CSR: %v", csrPEM) - } - - // By default, we create certs which are valid for an hour. But if - // `withShortLife` is set, we create certs which are valid only for a couple - // of seconds. - now := time.Now() - notBefore, notAfter := now.Add(-defaultTestCertLife), now.Add(defaultTestCertLife) - if f.withShortLife { - notBefore, notAfter = now.Add(-shortTestCertLife), now.Add(shortTestCertLife) - } - tmpl := &x509.Certificate{ - Subject: pkix.Name{CommonName: "signed-cert"}, - SerialNumber: big.NewInt(10), - NotBefore: notBefore, - NotAfter: notAfter, - KeyUsage: x509.KeyUsageDigitalSignature, - } - certDER, err := x509.CreateCertificate(rand.Reader, tmpl, f.cert, csr.PublicKey, f.key) - if err != nil { - return nil, fmt.Errorf("x509.CreateCertificate(%v) failed: %v", tmpl, err) - } - certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER}) - - // Push to ccChan to indicate that the RPC is processed. - f.ccChan.Send(nil) - - certChain := []string{ - string(certPEM), // Signed certificate corresponding to CSR - string(f.certPEM), // Root certificate - } - return &meshpb.MeshCertificateResponse{CertChain: certChain}, nil -} - -// opts wraps the options to be passed to setup. -type opts struct { - // Whether the CA returns certs with short lifetime. Used to test client refresh. - withShortLife bool - // Whether the CA returns errors to begin with. Used to test client backoff. - withbackoff bool -} - -// events wraps channels which indicate different events. -type events struct { - // Pushed to when the plugin dials the MeshCA. - dialDone *testutils.Channel - // Pushed to when CreateCertifcate() succeeds on the MeshCA. - createCertDone *testutils.Channel - // Pushed to when the plugin updates the distributor with new key material. - keyMaterialDone *testutils.Channel - // Pushed to when the client backs off after a failed CreateCertificate(). - backoffDone *testutils.Channel -} - -// setup performs tasks common to all tests in this file. -func setup(t *testing.T, o opts) (events, string, func()) { - t.Helper() - - // Create a fake MeshCA which pushes events on the passed channel for - // successful RPCs. - createCertDone := testutils.NewChannel() - fs, err := newFakeMeshCA(createCertDone, o.withbackoff, o.withShortLife) - if err != nil { - t.Fatal(err) - } - - // Create a gRPC server and register the fake MeshCA on it. - server := grpc.NewServer() - meshgrpc.RegisterMeshCertificateServiceServer(server, fs) - - // Start a net.Listener on a local port, and pass it to the gRPC server - // created above and start serving. - lis, err := net.Listen("tcp", "localhost:0") - if err != nil { - t.Fatal(err) - } - addr := lis.Addr().String() - go server.Serve(lis) - - // Override the plugin's dial function and perform a blocking dial. Also - // push on dialDone once the dial is complete so that test can block on this - // event before verifying other things. - dialDone := testutils.NewChannel() - origDialFunc := grpcDialFunc - grpcDialFunc = func(uri string, _ ...grpc.DialOption) (*grpc.ClientConn, error) { - if uri != addr { - t.Fatalf("plugin dialing MeshCA at %s, want %s", uri, addr) - } - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() - cc, err := grpc.DialContext(ctx, uri, grpc.WithInsecure(), grpc.WithBlock()) - if err != nil { - t.Fatalf("grpc.DialContext(%s) failed: %v", addr, err) - } - dialDone.Send(nil) - return cc, nil - } - - // Override the plugin's newDistributorFunc and return a wrappedDistributor - // which allows the test to be notified whenever the plugin pushes new key - // material into the distributor. - origDistributorFunc := newDistributorFunc - keyMaterialDone := testutils.NewChannel() - d := newWrappedDistributor(keyMaterialDone) - newDistributorFunc = func() distributor { return d } - - // Override the plugin's backoff function to perform no real backoff, but - // push on a channel so that the test can verifiy that backoff actually - // happened. - backoffDone := testutils.NewChannelWithSize(maxErrCount) - origBackoffFunc := backoffFunc - if o.withbackoff { - // Override the plugin's backoff function with this, so that we can verify - // that a backoff actually was triggered. - backoffFunc = func(v int) time.Duration { - backoffDone.Send(v) - return 0 - } - } - - // Return all the channels, and a cancel function to undo all the overrides. - e := events{ - dialDone: dialDone, - createCertDone: createCertDone, - keyMaterialDone: keyMaterialDone, - backoffDone: backoffDone, - } - done := func() { - server.Stop() - grpcDialFunc = origDialFunc - newDistributorFunc = origDistributorFunc - backoffFunc = origBackoffFunc - } - return e, addr, done -} - -// wrappedDistributor wraps a distributor and pushes on a channel whenever new -// key material is pushed to the distributor. -type wrappedDistributor struct { - *certprovider.Distributor - kmChan *testutils.Channel -} - -func newWrappedDistributor(kmChan *testutils.Channel) *wrappedDistributor { - return &wrappedDistributor{ - kmChan: kmChan, - Distributor: certprovider.NewDistributor(), - } -} - -func (wd *wrappedDistributor) Set(km *certprovider.KeyMaterial, err error) { - wd.Distributor.Set(km, err) - wd.kmChan.Send(nil) -} - -// TestCreateCertificate verifies the simple case where the MeshCA server -// returns a good certificate. -func (s) TestCreateCertificate(t *testing.T) { - e, addr, cancel := setup(t, opts{}) - defer cancel() - - // Set the MeshCA targetURI to point to our fake MeshCA. - inputConfig := json.RawMessage(fmt.Sprintf(goodConfigFormatStr, addr)) - - // Lookup MeshCA plugin builder, parse config and start the plugin. - prov, err := certprovider.GetProvider(pluginName, inputConfig, certprovider.BuildOptions{}) - if err != nil { - t.Fatalf("GetProvider(%s, %s) failed: %v", pluginName, string(inputConfig), err) - } - defer prov.Close() - - // Wait till the plugin dials the MeshCA server. - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() - if _, err := e.dialDone.Receive(ctx); err != nil { - t.Fatal("timeout waiting for plugin to dial MeshCA") - } - - // Wait till the plugin makes a CreateCertificate() call. - ctx, cancel = context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() - if _, err := e.createCertDone.Receive(ctx); err != nil { - t.Fatal("timeout waiting for plugin to make CreateCertificate RPC") - } - - // We don't really care about the exact key material returned here. All we - // care about is whether we get any key material at all, and that we don't - // get any errors. - ctx, cancel = context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() - if _, err = prov.KeyMaterial(ctx); err != nil { - t.Fatalf("provider.KeyMaterial(ctx) failed: %v", err) - } -} - -// TestCreateCertificateWithBackoff verifies the case where the MeshCA server -// returns errors initially and then returns a good certificate. The test makes -// sure that the client backs off when the server returns errors. -func (s) TestCreateCertificateWithBackoff(t *testing.T) { - e, addr, cancel := setup(t, opts{withbackoff: true}) - defer cancel() - - // Set the MeshCA targetURI to point to our fake MeshCA. - inputConfig := json.RawMessage(fmt.Sprintf(goodConfigFormatStr, addr)) - - // Lookup MeshCA plugin builder, parse config and start the plugin. - prov, err := certprovider.GetProvider(pluginName, inputConfig, certprovider.BuildOptions{}) - if err != nil { - t.Fatalf("GetProvider(%s, %s) failed: %v", pluginName, string(inputConfig), err) - } - defer prov.Close() - - // Wait till the plugin dials the MeshCA server. - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() - if _, err := e.dialDone.Receive(ctx); err != nil { - t.Fatal("timeout waiting for plugin to dial MeshCA") - } - - // Making the CreateCertificateRPC involves generating the keys, creating - // the CSR etc which seem to take reasonable amount of time. And in this - // test, the first two attempts will fail. Hence we give it a reasonable - // deadline here. - ctx, cancel = context.WithTimeout(context.Background(), 3*defaultTestTimeout) - defer cancel() - if _, err := e.createCertDone.Receive(ctx); err != nil { - t.Fatal("timeout waiting for plugin to make CreateCertificate RPC") - } - - // The first `maxErrCount` calls to CreateCertificate end in failure, and - // should lead to a backoff. - for i := 0; i < maxErrCount; i++ { - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() - if _, err := e.backoffDone.Receive(ctx); err != nil { - t.Fatalf("plugin failed to backoff after error from fake server: %v", err) - } - } - - // We don't really care about the exact key material returned here. All we - // care about is whether we get any key material at all, and that we don't - // get any errors. - ctx, cancel = context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() - if _, err = prov.KeyMaterial(ctx); err != nil { - t.Fatalf("provider.KeyMaterial(ctx) failed: %v", err) - } -} - -// TestCreateCertificateWithRefresh verifies the case where the MeshCA returns a -// certificate with a really short lifetime, and makes sure that the plugin -// refreshes the cert in time. -func (s) TestCreateCertificateWithRefresh(t *testing.T) { - e, addr, cancel := setup(t, opts{withShortLife: true}) - defer cancel() - - // Set the MeshCA targetURI to point to our fake MeshCA. - inputConfig := json.RawMessage(fmt.Sprintf(goodConfigFormatStr, addr)) - - // Lookup MeshCA plugin builder, parse config and start the plugin. - prov, err := certprovider.GetProvider(pluginName, inputConfig, certprovider.BuildOptions{}) - if err != nil { - t.Fatalf("GetProvider(%s, %s) failed: %v", pluginName, string(inputConfig), err) - } - defer prov.Close() - - // Wait till the plugin dials the MeshCA server. - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() - if _, err := e.dialDone.Receive(ctx); err != nil { - t.Fatal("timeout waiting for plugin to dial MeshCA") - } - - // Wait till the plugin makes a CreateCertificate() call. - ctx, cancel = context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() - if _, err := e.createCertDone.Receive(ctx); err != nil { - t.Fatal("timeout waiting for plugin to make CreateCertificate RPC") - } - - ctx, cancel = context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() - km1, err := prov.KeyMaterial(ctx) - if err != nil { - t.Fatalf("provider.KeyMaterial(ctx) failed: %v", err) - } - - // At this point, we have read the first key material, and since the - // returned key material has a really short validity period, we expect the - // key material to be refreshed quite soon. We drain the channel on which - // the event corresponding to setting of new key material is pushed. This - // enables us to block on the same channel, waiting for refreshed key - // material. - // Since we do not expect this call to block, it is OK to pass the - // background context. - e.keyMaterialDone.Receive(context.Background()) - - // Wait for the next call to CreateCertificate() to refresh the certificate - // returned earlier. - ctx, cancel = context.WithTimeout(context.Background(), 2*shortTestCertLife) - defer cancel() - if _, err := e.keyMaterialDone.Receive(ctx); err != nil { - t.Fatalf("CreateCertificate() RPC not made: %v", err) - } - - ctx, cancel = context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() - km2, err := prov.KeyMaterial(ctx) - if err != nil { - t.Fatalf("provider.KeyMaterial(ctx) failed: %v", err) - } - - // TODO(easwars): Remove all references to reflect.DeepEqual and use - // cmp.Equal instead. Currently, the later panics because x509.Certificate - // type defines an Equal method, but does not check for nil. This has been - // fixed in - // https://github.com/golang/go/commit/89865f8ba64ccb27f439cce6daaa37c9aa38f351, - // but this is only available starting go1.14. So, once we remove support - // for go1.13, we can make the switch. - if reflect.DeepEqual(km1, km2) { - t.Error("certificate refresh did not happen in the background") - } -} diff --git a/credentials/tls/certprovider/pemfile/watcher_test.go b/credentials/tls/certprovider/pemfile/watcher_test.go index e43cf7358ec..6cc65bd5000 100644 --- a/credentials/tls/certprovider/pemfile/watcher_test.go +++ b/credentials/tls/certprovider/pemfile/watcher_test.go @@ -22,7 +22,6 @@ import ( "context" "fmt" "io/ioutil" - "math/big" "os" "path" "testing" @@ -30,7 +29,6 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" - "google.golang.org/grpc/credentials/tls/certprovider" "google.golang.org/grpc/internal/grpctest" "google.golang.org/grpc/internal/testutils" @@ -57,17 +55,15 @@ func Test(t *testing.T) { } func compareKeyMaterial(got, want *certprovider.KeyMaterial) error { - // x509.Certificate type defines an Equal() method, but does not check for - // nil. This has been fixed in - // https://github.com/golang/go/commit/89865f8ba64ccb27f439cce6daaa37c9aa38f351, - // but this is only available starting go1.14. - // TODO(easwars): Remove this check once we remove support for go1.13. - if (got.Certs == nil && want.Certs != nil) || (want.Certs == nil && got.Certs != nil) { + if len(got.Certs) != len(want.Certs) { return fmt.Errorf("keyMaterial certs = %+v, want %+v", got, want) } - if !cmp.Equal(got.Certs, want.Certs, cmp.AllowUnexported(big.Int{})) { - return fmt.Errorf("keyMaterial certs = %+v, want %+v", got, want) + for i := 0; i < len(got.Certs); i++ { + if !got.Certs[i].Leaf.Equal(want.Certs[i].Leaf) { + return fmt.Errorf("keyMaterial certs = %+v, want %+v", got, want) + } } + // x509.CertPool contains only unexported fields some of which contain other // unexported fields. So usage of cmp.AllowUnexported() or // cmpopts.IgnoreUnexported() does not help us much here. Also, the standard diff --git a/credentials/tls/certprovider/store_test.go b/credentials/tls/certprovider/store_test.go index 00d33a2be87..ee1f4a358ba 100644 --- a/credentials/tls/certprovider/store_test.go +++ b/credentials/tls/certprovider/store_test.go @@ -1,5 +1,3 @@ -// +build go1.13 - /* * * Copyright 2020 gRPC authors. @@ -27,10 +25,11 @@ import ( "errors" "fmt" "io/ioutil" - "reflect" "testing" "time" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" "google.golang.org/grpc/internal/grpctest" "google.golang.org/grpc/internal/testutils" "google.golang.org/grpc/testdata" @@ -154,15 +153,23 @@ func readAndVerifyKeyMaterial(ctx context.Context, kmr kmReader, wantKM *KeyMate } func compareKeyMaterial(got, want *KeyMaterial) error { - // TODO(easwars): Remove all references to reflect.DeepEqual and use - // cmp.Equal instead. Currently, the later panics because x509.Certificate - // type defines an Equal method, but does not check for nil. This has been - // fixed in - // https://github.com/golang/go/commit/89865f8ba64ccb27f439cce6daaa37c9aa38f351, - // but this is only available starting go1.14. So, once we remove support - // for go1.13, we can make the switch. - if !reflect.DeepEqual(got, want) { - return fmt.Errorf("provider.KeyMaterial() = %+v, want %+v", got, want) + if len(got.Certs) != len(want.Certs) { + return fmt.Errorf("keyMaterial certs = %+v, want %+v", got, want) + } + for i := 0; i < len(got.Certs); i++ { + if !got.Certs[i].Leaf.Equal(want.Certs[i].Leaf) { + return fmt.Errorf("keyMaterial certs = %+v, want %+v", got, want) + } + } + + // x509.CertPool contains only unexported fields some of which contain other + // unexported fields. So usage of cmp.AllowUnexported() or + // cmpopts.IgnoreUnexported() does not help us much here. Also, the standard + // library does not provide a way to compare CertPool values. Comparing the + // subjects field of the certs in the CertPool seems like a reasonable + // approach. + if gotR, wantR := got.Roots.Subjects(), want.Roots.Subjects(); !cmp.Equal(gotR, wantR, cmpopts.EquateEmpty()) { + return fmt.Errorf("keyMaterial roots = %v, want %v", gotR, wantR) } return nil } diff --git a/credentials/xds/xds.go b/credentials/xds/xds.go index ede0806d70d..680ea9cfa10 100644 --- a/credentials/xds/xds.go +++ b/credentials/xds/xds.go @@ -18,11 +18,6 @@ // Package xds provides a transport credentials implementation where the // security configuration is pushed by a management server using xDS APIs. -// -// Experimental -// -// Notice: All APIs in this package are EXPERIMENTAL and may be removed in a -// later release. package xds import ( @@ -216,12 +211,15 @@ func (c *credsImpl) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.Aut // `HandshakeInfo` does not contain the information we are looking for, we // delegate the handshake to the fallback credentials. hiConn, ok := rawConn.(interface { - XDSHandshakeInfo() *xdsinternal.HandshakeInfo + XDSHandshakeInfo() (*xdsinternal.HandshakeInfo, error) }) if !ok { return c.fallback.ServerHandshake(rawConn) } - hi := hiConn.XDSHandshakeInfo() + hi, err := hiConn.XDSHandshakeInfo() + if err != nil { + return nil, nil, err + } if hi.UseFallbackCreds() { return c.fallback.ServerHandshake(rawConn) } diff --git a/credentials/xds/xds_client_test.go b/credentials/xds/xds_client_test.go index 219d0aefcba..f4b86df060b 100644 --- a/credentials/xds/xds_client_test.go +++ b/credentials/xds/xds_client_test.go @@ -32,18 +32,19 @@ import ( "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/tls/certprovider" - "google.golang.org/grpc/internal" + icredentials "google.golang.org/grpc/internal/credentials" xdsinternal "google.golang.org/grpc/internal/credentials/xds" "google.golang.org/grpc/internal/grpctest" "google.golang.org/grpc/internal/testutils" + "google.golang.org/grpc/internal/xds/matcher" "google.golang.org/grpc/resolver" "google.golang.org/grpc/testdata" ) const ( - defaultTestTimeout = 10 * time.Second + defaultTestTimeout = 1 * time.Second defaultTestShortTimeout = 10 * time.Millisecond - defaultTestCertSAN = "*.test.example.com" + defaultTestCertSAN = "abc.test.example.com" authority = "authority" ) @@ -214,18 +215,20 @@ func makeRootProvider(t *testing.T, caPath string) *fakeProvider { // newTestContextWithHandshakeInfo returns a copy of parent with HandshakeInfo // context value added to it. -func newTestContextWithHandshakeInfo(parent context.Context, root, identity certprovider.Provider, sans ...string) context.Context { +func newTestContextWithHandshakeInfo(parent context.Context, root, identity certprovider.Provider, sanExactMatch string) context.Context { // Creating the HandshakeInfo and adding it to the attributes is very // similar to what the CDS balancer would do when it intercepts calls to // NewSubConn(). - info := xdsinternal.NewHandshakeInfo(root, identity, sans...) + info := xdsinternal.NewHandshakeInfo(root, identity) + if sanExactMatch != "" { + info.SetSANMatchers([]matcher.StringMatcher{matcher.StringMatcherForTesting(newStringP(sanExactMatch), nil, nil, nil, nil, false)}) + } addr := xdsinternal.SetHandshakeInfo(resolver.Address{}, info) // Moving the attributes from the resolver.Address to the context passed to // the handshaker is done in the transport layer. Since we directly call the // handshaker in these tests, we need to do the same here. - contextWithHandshakeInfo := internal.NewClientHandshakeInfoContext.(func(context.Context, credentials.ClientHandshakeInfo) context.Context) - return contextWithHandshakeInfo(parent, credentials.ClientHandshakeInfo{Attributes: addr.Attributes}) + return icredentials.NewClientHandshakeInfoContext(parent, credentials.ClientHandshakeInfo{Attributes: addr.Attributes}) } // compareAuthInfo compares the AuthInfo received on the client side after a @@ -292,7 +295,7 @@ func (s) TestClientCredsInvalidHandshakeInfo(t *testing.T) { pCtx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() - ctx := newTestContextWithHandshakeInfo(pCtx, nil, &fakeProvider{}) + ctx := newTestContextWithHandshakeInfo(pCtx, nil, &fakeProvider{}, "") if _, _, err := creds.ClientHandshake(ctx, authority, nil); err == nil { t.Fatal("ClientHandshake succeeded without root certificate provider in HandshakeInfo") } @@ -329,7 +332,7 @@ func (s) TestClientCredsProviderFailure(t *testing.T) { t.Run(test.desc, func(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() - ctx = newTestContextWithHandshakeInfo(ctx, test.rootProvider, test.identityProvider) + ctx = newTestContextWithHandshakeInfo(ctx, test.rootProvider, test.identityProvider, "") if _, _, err := creds.ClientHandshake(ctx, authority, nil); err == nil || !strings.Contains(err.Error(), test.wantErr) { t.Fatalf("ClientHandshake() returned error: %q, wantErr: %q", err, test.wantErr) } @@ -371,7 +374,7 @@ func (s) TestClientCredsSuccess(t *testing.T) { desc: "mTLS with no acceptedSANs specified", handshakeFunc: testServerMutualTLSHandshake, handshakeInfoCtx: func(ctx context.Context) context.Context { - return newTestContextWithHandshakeInfo(ctx, makeRootProvider(t, "x509/server_ca_cert.pem"), makeIdentityProvider(t, "x509/server1_cert.pem", "x509/server1_key.pem")) + return newTestContextWithHandshakeInfo(ctx, makeRootProvider(t, "x509/server_ca_cert.pem"), makeIdentityProvider(t, "x509/server1_cert.pem", "x509/server1_key.pem"), "") }, }, } @@ -530,14 +533,14 @@ func (s) TestClientCredsProviderSwitch(t *testing.T) { // Create a root provider which will fail the handshake because it does not // use the correct trust roots. root1 := makeRootProvider(t, "x509/client_ca_cert.pem") - handshakeInfo := xdsinternal.NewHandshakeInfo(root1, nil, defaultTestCertSAN) + handshakeInfo := xdsinternal.NewHandshakeInfo(root1, nil) + handshakeInfo.SetSANMatchers([]matcher.StringMatcher{matcher.StringMatcherForTesting(newStringP(defaultTestCertSAN), nil, nil, nil, nil, false)}) // We need to repeat most of what newTestContextWithHandshakeInfo() does // here because we need access to the underlying HandshakeInfo so that we // can update it before the next call to ClientHandshake(). addr := xdsinternal.SetHandshakeInfo(resolver.Address{}, handshakeInfo) - contextWithHandshakeInfo := internal.NewClientHandshakeInfoContext.(func(context.Context, credentials.ClientHandshakeInfo) context.Context) - ctx = contextWithHandshakeInfo(ctx, credentials.ClientHandshakeInfo{Attributes: addr.Attributes}) + ctx = icredentials.NewClientHandshakeInfoContext(ctx, credentials.ClientHandshakeInfo{Attributes: addr.Attributes}) if _, _, err := creds.ClientHandshake(ctx, authority, conn); err == nil { t.Fatal("ClientHandshake() succeeded when expected to fail") } @@ -582,3 +585,7 @@ func (s) TestClientClone(t *testing.T) { t.Fatal("return value from Clone() doesn't point to new credentials instance") } } + +func newStringP(s string) *string { + return &s +} diff --git a/credentials/xds/xds_server_test.go b/credentials/xds/xds_server_test.go index 68d92b28e28..5c29ba38c28 100644 --- a/credentials/xds/xds_server_test.go +++ b/credentials/xds/xds_server_test.go @@ -95,12 +95,13 @@ func (s) TestServerCredsWithoutFallback(t *testing.T) { type wrapperConn struct { net.Conn - xdsHI *xdsinternal.HandshakeInfo - deadline time.Time + xdsHI *xdsinternal.HandshakeInfo + deadline time.Time + handshakeInfoErr error } -func (wc *wrapperConn) XDSHandshakeInfo() *xdsinternal.HandshakeInfo { - return wc.xdsHI +func (wc *wrapperConn) XDSHandshakeInfo() (*xdsinternal.HandshakeInfo, error) { + return wc.xdsHI, wc.handshakeInfoErr } func (wc *wrapperConn) GetDeadline() time.Time { @@ -166,6 +167,58 @@ func (s) TestServerCredsProviderFailure(t *testing.T) { } } +// TestServerCredsHandshake_XDSHandshakeInfoError verifies the case where the +// call to XDSHandshakeInfo() from the ServerHandshake() method returns an +// error, and the test verifies that the ServerHandshake() fails with the +// expected error. +func (s) TestServerCredsHandshake_XDSHandshakeInfoError(t *testing.T) { + opts := ServerOptions{FallbackCreds: &errorCreds{}} + creds, err := NewServerCredentials(opts) + if err != nil { + t.Fatalf("NewServerCredentials(%v) failed: %v", opts, err) + } + + // Create a test server which uses the xDS server credentials created above + // to perform TLS handshake on incoming connections. + ts := newTestServerWithHandshakeFunc(func(rawConn net.Conn) handshakeResult { + // Create a wrapped conn which returns a nil HandshakeInfo and a non-nil error. + conn := newWrappedConn(rawConn, nil, time.Now().Add(defaultTestTimeout)) + hiErr := errors.New("xdsHandshakeInfo error") + conn.handshakeInfoErr = hiErr + + // Invoke the ServerHandshake() method on the xDS credentials and verify + // that the error returned by the XDSHandshakeInfo() method on the + // wrapped conn is returned here. + _, _, err := creds.ServerHandshake(conn) + if !errors.Is(err, hiErr) { + return handshakeResult{err: fmt.Errorf("ServerHandshake() returned err: %v, wantErr: %v", err, hiErr)} + } + return handshakeResult{} + }) + defer ts.stop() + + // Dial the test server, but don't trigger the TLS handshake. This will + // cause ServerHandshake() to fail. + rawConn, err := net.Dial("tcp", ts.address) + if err != nil { + t.Fatalf("net.Dial(%s) failed: %v", ts.address, err) + } + defer rawConn.Close() + + // Read handshake result from the testServer which will return an error if + // the handshake succeeded. + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + val, err := ts.hsResult.Receive(ctx) + if err != nil { + t.Fatalf("testServer failed to return handshake result: %v", err) + } + hsr := val.(handshakeResult) + if hsr.err != nil { + t.Fatalf("testServer handshake failure: %v", hsr.err) + } +} + // TestServerCredsHandshakeTimeout verifies the case where the client does not // send required handshake data before the deadline set on the net.Conn passed // to ServerHandshake(). diff --git a/dialoptions.go b/dialoptions.go index e7f86e6d7c8..7a497237bbd 100644 --- a/dialoptions.go +++ b/dialoptions.go @@ -66,11 +66,7 @@ type dialOptions struct { minConnectTimeout func() time.Duration defaultServiceConfig *ServiceConfig // defaultServiceConfig is parsed from defaultServiceConfigRawJSON. defaultServiceConfigRawJSON *string - // This is used by ccResolverWrapper to backoff between successive calls to - // resolver.ResolveNow(). The user will have no need to configure this, but - // we need to be able to configure this in tests. - resolveNowBackoff func(int) time.Duration - resolvers []resolver.Builder + resolvers []resolver.Builder } // DialOption configures how we set up the connection. @@ -596,7 +592,6 @@ func defaultDialOptions() dialOptions { ReadBufferSize: defaultReadBufSize, UseProxy: true, }, - resolveNowBackoff: internalbackoff.DefaultExponential.Backoff, } } @@ -611,16 +606,6 @@ func withMinConnectDeadline(f func() time.Duration) DialOption { }) } -// withResolveNowBackoff specifies the function that clientconn uses to backoff -// between successive calls to resolver.ResolveNow(). -// -// For testing purpose only. -func withResolveNowBackoff(f func(int) time.Duration) DialOption { - return newFuncDialOption(func(o *dialOptions) { - o.resolveNowBackoff = f - }) -} - // WithResolvers allows a list of resolver implementations to be registered // locally with the ClientConn without needing to be globally registered via // resolver.Register. They will be matched against the scheme used for the diff --git a/examples/examples_test.sh b/examples/examples_test.sh index 9015272f33e..f5c82d062b2 100755 --- a/examples/examples_test.sh +++ b/examples/examples_test.sh @@ -58,6 +58,7 @@ EXAMPLES=( "features/metadata" "features/multiplex" "features/name_resolving" + "features/unix_abstract" ) declare -A EXPECTED_SERVER_OUTPUT=( @@ -73,6 +74,7 @@ declare -A EXPECTED_SERVER_OUTPUT=( ["features/metadata"]="message:\"this is examples/metadata\", sending echo" ["features/multiplex"]=":50051" ["features/name_resolving"]="serving on localhost:50051" + ["features/unix_abstract"]="serving on @abstract-unix-socket" ) declare -A EXPECTED_CLIENT_OUTPUT=( @@ -88,6 +90,7 @@ declare -A EXPECTED_CLIENT_OUTPUT=( ["features/metadata"]="this is examples/metadata" ["features/multiplex"]="Greeting: Hello multiplex" ["features/name_resolving"]="calling helloworld.Greeter/SayHello to \"example:///resolver.example.grpc.io\"" + ["features/unix_abstract"]="calling echo.Echo/UnaryEcho to unix-abstract:abstract-unix-socket" ) cd ./examples diff --git a/examples/features/encryption/README.md b/examples/features/encryption/README.md index a00188d66a2..2afca1d785f 100644 --- a/examples/features/encryption/README.md +++ b/examples/features/encryption/README.md @@ -42,8 +42,8 @@ configure TLS and create the server credential using On client side, we provide the path to the "ca_cert.pem" to configure TLS and create the client credential using [`credentials.NewClientTLSFromFile`](https://godoc.org/google.golang.org/grpc/credentials#NewClientTLSFromFile). -Note that we override the server name with "x.test.youtube.com", as the server -certificate is valid for *.test.youtube.com but not localhost. It is solely for +Note that we override the server name with "x.test.example.com", as the server +certificate is valid for *.test.example.com but not localhost. It is solely for the convenience of making an example. Once the credentials have been created at both sides, we can start the server diff --git a/examples/features/proto/echo/echo_grpc.pb.go b/examples/features/proto/echo/echo_grpc.pb.go index 052087dae36..e1d24b1e830 100644 --- a/examples/features/proto/echo/echo_grpc.pb.go +++ b/examples/features/proto/echo/echo_grpc.pb.go @@ -1,4 +1,8 @@ // Code generated by protoc-gen-go-grpc. DO NOT EDIT. +// versions: +// - protoc-gen-go-grpc v1.1.0 +// - protoc v3.14.0 +// source: examples/features/proto/echo/echo.proto package echo diff --git a/examples/features/unix_abstract/README.md b/examples/features/unix_abstract/README.md new file mode 100644 index 00000000000..32b3bd5f262 --- /dev/null +++ b/examples/features/unix_abstract/README.md @@ -0,0 +1,29 @@ +# Unix abstract sockets + +This examples shows how to start a gRPC server listening on a unix abstract +socket and how to get a gRPC client to connect to it. + +## What is a unix abstract socket + +An abstract socket address is distinguished from a regular unix socket by the +fact that the first byte of the address is a null byte ('\0'). The address has +no connection with filesystem path names. + +## Try it + +``` +go run server/main.go +``` + +``` +go run client/main.go +``` + +## Explanation + +The gRPC server in this example listens on an address starting with a null byte +and the network is `unix`. The client uses the `unix-abstract` scheme with the +endpoint set to the abstract unix socket address without the null byte. The +`unix` resolver takes care of adding the null byte on the client. See +https://github.com/grpc/grpc/blob/master/doc/naming.md for the more details. + diff --git a/examples/features/unix_abstract/client/main.go b/examples/features/unix_abstract/client/main.go new file mode 100644 index 00000000000..4f48aca9bdf --- /dev/null +++ b/examples/features/unix_abstract/client/main.go @@ -0,0 +1,68 @@ +//go:build linux +// +build linux + +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// Binary client is an example client which dials a server on an abstract unix +// socket. +package main + +import ( + "context" + "fmt" + "log" + "time" + + "google.golang.org/grpc" + ecpb "google.golang.org/grpc/examples/features/proto/echo" +) + +func callUnaryEcho(c ecpb.EchoClient, message string) { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + r, err := c.UnaryEcho(ctx, &ecpb.EchoRequest{Message: message}) + if err != nil { + log.Fatalf("could not greet: %v", err) + } + fmt.Println(r.Message) +} + +func makeRPCs(cc *grpc.ClientConn, n int) { + hwc := ecpb.NewEchoClient(cc) + for i := 0; i < n; i++ { + callUnaryEcho(hwc, "this is examples/unix_abstract") + } +} + +func main() { + // A dial target of `unix:@abstract-unix-socket` should also work fine for + // this example because of golang conventions (net.Dial behavior). But we do + // not recommend this since we explicitly added the `unix-abstract` scheme + // for cross-language compatibility. + addr := "unix-abstract:abstract-unix-socket" + cc, err := grpc.Dial(addr, grpc.WithInsecure(), grpc.WithBlock()) + if err != nil { + log.Fatalf("grpc.Dial(%q) failed: %v", addr, err) + } + defer cc.Close() + + fmt.Printf("--- calling echo.Echo/UnaryEcho to %s\n", addr) + makeRPCs(cc, 10) + fmt.Println() +} diff --git a/examples/features/unix_abstract/server/main.go b/examples/features/unix_abstract/server/main.go new file mode 100644 index 00000000000..a82b957c1f0 --- /dev/null +++ b/examples/features/unix_abstract/server/main.go @@ -0,0 +1,58 @@ +//go:build linux +// +build linux + +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// Binary server is an example server listening for gRPC connections on an +// abstract unix socket. +package main + +import ( + "context" + "fmt" + "log" + "net" + + "google.golang.org/grpc" + + pb "google.golang.org/grpc/examples/features/proto/echo" +) + +type ecServer struct { + pb.UnimplementedEchoServer + addr string +} + +func (s *ecServer) UnaryEcho(ctx context.Context, req *pb.EchoRequest) (*pb.EchoResponse, error) { + return &pb.EchoResponse{Message: fmt.Sprintf("%s (from %s)", req.Message, s.addr)}, nil +} + +func main() { + netw, addr := "unix", "\x00abstract-unix-socket" + lis, err := net.Listen(netw, addr) + if err != nil { + log.Fatalf("net.Listen(%q, %q) failed: %v", netw, addr, err) + } + s := grpc.NewServer() + pb.RegisterEchoServer(s, &ecServer{addr: addr}) + log.Printf("serving on %s\n", lis.Addr().String()) + if err := s.Serve(lis); err != nil { + log.Fatalf("failed to serve: %v", err) + } +} diff --git a/examples/features/xds/client/main.go b/examples/features/xds/client/main.go index b1daa1cae9c..97918faa224 100644 --- a/examples/features/xds/client/main.go +++ b/examples/features/xds/client/main.go @@ -16,78 +16,56 @@ * */ -// Package main implements a client for Greeter service. +// Binary main implements a client for Greeter service using gRPC's client-side +// support for xDS APIs. package main import ( "context" "flag" - "fmt" "log" + "strings" "time" "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + xdscreds "google.golang.org/grpc/credentials/xds" pb "google.golang.org/grpc/examples/helloworld/helloworld" _ "google.golang.org/grpc/xds" // To install the xds resolvers and balancers. ) -const ( - defaultTarget = "localhost:50051" - defaultName = "world" +var ( + target = flag.String("target", "xds:///localhost:50051", "uri of the Greeter Server, e.g. 'xds:///helloworld-service:8080'") + name = flag.String("name", "world", "name you wished to be greeted by the server") + xdsCreds = flag.Bool("xds_creds", false, "whether the server should use xDS APIs to receive security configuration") ) -var help = flag.Bool("help", false, "Print usage information") - -func init() { - flag.Usage = func() { - fmt.Fprintf(flag.CommandLine.Output(), ` -Usage: client [name [target]] - - name - The name you wish to be greeted by. Defaults to %q - target - The URI of the server, e.g. "xds:///helloworld-service". Defaults to %q -`, defaultName, defaultTarget) - - flag.PrintDefaults() - } -} - func main() { flag.Parse() - if *help { - flag.Usage() - return - } - args := flag.Args() - - if len(args) > 2 { - flag.Usage() - return - } - name := defaultName - if len(args) > 0 { - name = args[0] + if !strings.HasPrefix(*target, "xds:///") { + log.Fatalf("-target must use a URI with scheme set to 'xds'") } - target := defaultTarget - if len(args) > 1 { - target = args[1] + creds := insecure.NewCredentials() + if *xdsCreds { + log.Println("Using xDS credentials...") + var err error + if creds, err = xdscreds.NewClientCredentials(xdscreds.ClientOptions{FallbackCreds: insecure.NewCredentials()}); err != nil { + log.Fatalf("failed to create client-side xDS credentials: %v", err) + } } - - // Set up a connection to the server. - conn, err := grpc.Dial(target, grpc.WithInsecure()) + conn, err := grpc.Dial(*target, grpc.WithTransportCredentials(creds)) if err != nil { - log.Fatalf("did not connect: %v", err) + log.Fatalf("grpc.Dial(%s) failed: %v", *target, err) } defer conn.Close() - c := pb.NewGreeterClient(conn) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - r, err := c.SayHello(ctx, &pb.HelloRequest{Name: name}) + c := pb.NewGreeterClient(conn) + r, err := c.SayHello(ctx, &pb.HelloRequest{Name: *name}) if err != nil { log.Fatalf("could not greet: %v", err) } diff --git a/examples/features/xds/server/main.go b/examples/features/xds/server/main.go index 7e0815645e5..0367060f4b5 100644 --- a/examples/features/xds/server/main.go +++ b/examples/features/xds/server/main.go @@ -16,7 +16,8 @@ * */ -// Package main starts Greeter service that will response with the hostname. +// Binary server demonstrated gRPC's support for xDS APIs on the server-side. It +// exposes the Greeter service that will response with the hostname. package main import ( @@ -27,36 +28,29 @@ import ( "math/rand" "net" "os" - "strconv" "time" "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + xdscreds "google.golang.org/grpc/credentials/xds" pb "google.golang.org/grpc/examples/helloworld/helloworld" "google.golang.org/grpc/health" healthpb "google.golang.org/grpc/health/grpc_health_v1" - "google.golang.org/grpc/reflection" + "google.golang.org/grpc/xds" ) -var help = flag.Bool("help", false, "Print usage information") - -const ( - defaultPort = 50051 +var ( + port = flag.Int("port", 50051, "the port to serve Greeter service requests on. Health service will be served on `port+1`") + xdsCreds = flag.Bool("xds_creds", false, "whether the server should use xDS APIs to receive security configuration") ) -// server is used to implement helloworld.GreeterServer. +// server implements helloworld.GreeterServer interface. type server struct { pb.UnimplementedGreeterServer - serverName string } -func newServer(serverName string) *server { - return &server{ - serverName: serverName, - } -} - -// SayHello implements helloworld.GreeterServer +// SayHello implements helloworld.GreeterServer interface. func (s *server) SayHello(ctx context.Context, in *pb.HelloRequest) (*pb.HelloReply, error) { log.Printf("Received: %v", in.GetName()) return &pb.HelloReply{Message: "Hello " + in.GetName() + ", from " + s.serverName}, nil @@ -72,65 +66,40 @@ func determineHostname() string { return hostname } -func init() { - flag.Usage = func() { - fmt.Fprintf(flag.CommandLine.Output(), ` -Usage: server [port [hostname]] - - port - The listen port. Defaults to %d - hostname - The name clients will see in greet responses. Defaults to the machine's hostname -`, defaultPort) - - flag.PrintDefaults() - } -} - func main() { flag.Parse() - if *help { - flag.Usage() - return - } - args := flag.Args() - if len(args) > 2 { - flag.Usage() - return + greeterPort := fmt.Sprintf(":%d", *port) + greeterLis, err := net.Listen("tcp4", greeterPort) + if err != nil { + log.Fatalf("net.Listen(tcp4, %q) failed: %v", greeterPort, err) } - port := defaultPort - if len(args) > 0 { + creds := insecure.NewCredentials() + if *xdsCreds { + log.Println("Using xDS credentials...") var err error - port, err = strconv.Atoi(args[0]) - if err != nil { - log.Printf("Invalid port number: %v", err) - flag.Usage() - return + if creds, err = xdscreds.NewServerCredentials(xdscreds.ServerOptions{FallbackCreds: insecure.NewCredentials()}); err != nil { + log.Fatalf("failed to create server-side xDS credentials: %v", err) } } - var hostname string - if len(args) > 1 { - hostname = args[1] - } - if hostname == "" { - hostname = determineHostname() - } + greeterServer := xds.NewGRPCServer(grpc.Creds(creds)) + pb.RegisterGreeterServer(greeterServer, &server{serverName: determineHostname()}) - lis, err := net.Listen("tcp", fmt.Sprintf("localhost:%d", port)) + healthPort := fmt.Sprintf(":%d", *port+1) + healthLis, err := net.Listen("tcp4", healthPort) if err != nil { - log.Fatalf("failed to listen: %v", err) + log.Fatalf("net.Listen(tcp4, %q) failed: %v", healthPort, err) } - s := grpc.NewServer() - pb.RegisterGreeterServer(s, newServer(hostname)) - - reflection.Register(s) + grpcServer := grpc.NewServer() healthServer := health.NewServer() healthServer.SetServingStatus("", healthpb.HealthCheckResponse_SERVING) - healthpb.RegisterHealthServer(s, healthServer) + healthpb.RegisterHealthServer(grpcServer, healthServer) - log.Printf("serving on %s, hostname %s", lis.Addr(), hostname) - s.Serve(lis) + log.Printf("Serving GreeterService on %s and HealthService on %s", greeterLis.Addr().String(), healthLis.Addr().String()) + go func() { + greeterServer.Serve(greeterLis) + }() + grpcServer.Serve(healthLis) } diff --git a/examples/go.mod b/examples/go.mod index 18c67afed96..4f19b852edd 100644 --- a/examples/go.mod +++ b/examples/go.mod @@ -1,12 +1,12 @@ module google.golang.org/grpc/examples -go 1.11 +go 1.14 require ( - github.com/golang/protobuf v1.4.2 + github.com/golang/protobuf v1.4.3 golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d google.golang.org/genproto v0.0.0-20200806141610-86f49bd18e98 - google.golang.org/grpc v1.31.0 + google.golang.org/grpc v1.36.0 google.golang.org/protobuf v1.25.0 ) diff --git a/examples/go.sum b/examples/go.sum index c6aa163b01a..a359cfc183f 100644 --- a/examples/go.sum +++ b/examples/go.sum @@ -1,66 +1,77 @@ cloud.google.com/go v0.34.0 h1:eOI3/cP2VTU6uZLDYAoic+eyzzB9YyGmJ7eIjl8rOPg= cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= +github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY= github.com/census-instrumentation/opencensus-proto v0.2.1 h1:glEXhBS5PSLLv4IXzLA5yPRVX4bilULVyxxbrfOtDAk= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= +github.com/cespare/xxhash/v2 v2.1.1 h1:6MnRN8NT7+YBpUIWxHtefFZOKTAPgGjpQSxqLNn0+qY= +github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403 h1:cqQfy1jclcSy/FwLjemeg3SR1yaINm74aQyupQ0Bl8M= github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= +github.com/cncf/xds/go v0.0.0-20210805033703-aa0b78936158 h1:CevA8fI91PAnP8vpnXuB8ZYAZ5wqY86nAbxfgK8tWO4= +github.com/cncf/xds/go v0.0.0-20210805033703-aa0b78936158/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= +github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/envoyproxy/go-control-plane v0.9.9-0.20210217033140-668b12f5399d h1:QyzYnTnPE15SQyUeqU6qLbWxMkwyAyu+vGksa0b7j00= -github.com/envoyproxy/go-control-plane v0.9.9-0.20210217033140-668b12f5399d/go.mod h1:cXg6YxExXjJnVBQHBLXeUAgxn2UodCpnH306RInaBQk= +github.com/envoyproxy/go-control-plane v0.9.10-0.20210907150352-cf90f659a021 h1:fP+fF0up6oPY49OrjPrhIJ8yQfdIM85NXMLkMg1EXVs= +github.com/envoyproxy/go-control-plane v0.9.10-0.20210907150352-cf90f659a021/go.mod h1:AFq3mo9L8Lqqiid3OhADV3RfLJnjiw63cSpi+fDTRC0= github.com/envoyproxy/protoc-gen-validate v0.1.0 h1:EQciDnbrYxy13PgWoY8AqoxGiPrpgBZ1R8UNe3ddc+A= github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= +github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA= github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs= github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w= -github.com/golang/protobuf v1.4.0 h1:oOuy+ugB+P/kBdUnG5QaMXSIyJ1q38wWSojYCb3z5VQ= github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QDs8UjoX8= -github.com/golang/protobuf v1.4.2 h1:+Z5KGCizgyZCbGh1KZqA0fcLLkwbsjIzS4aV2v7wJX0= -github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= +github.com/golang/protobuf v1.4.3 h1:JjCZWpVbqXDqFVmTfYWEVTMIYrL/NPdPSCHPJ0T/raM= +github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= -github.com/google/go-cmp v0.4.0 h1:xsAVV57WRhGj6kEIi8ReJzQlHHqcBYCElAvkovg3B/4= github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.0 h1:/QaMHBdZ26BB3SSst0Iwl10Epc+xhTquomWX0oZEB6w= github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/uuid v1.1.2 h1:EVhdT+1Kseyi1/pUmXKaFxYsDNy9RQYkMWRH68J/W7Y= github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= +github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= +github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +go.opentelemetry.io/proto/otlp v0.7.0/go.mod h1:PqfVotwruBrMGOCsRd/89rSnXhoiJIqeYNgFYFoEGnI= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= -golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20190311183353-d8887717615a h1:oWX7TPOiFAMXLq8o0ikBYfCJVlRHBcsciT5bXOrH628= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20200822124328-c89045814202 h1:VvcQYSHwXgi7W+TpUR6A9g6Up98WAHf3f/ulnJ62IyA= +golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d h1:TzXSXBo42m9gQenoE3b9BGiEpg5IG2JkU5FkPIawgtw= golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a h1:1BGLXjeY4akVXGgbC9HugT3Jv3hCI0z56oJR5vAMgBU= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd h1:xhmwyvizuTgC2qz7ZlMluP20uW+C3Rm0FD/WLDX8884= +golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= -golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/appengine v1.4.0 h1:/wp5JvzpHIxhs/dumFmF7BXTf3Z+dd4uXta4kVyO508= google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= -google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55 h1:gSJIx1SDwno+2ElGhA4+qG2zF97qiUzTM+rQ0klBOcE= -google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= +google.golang.org/genproto v0.0.0-20200513103714-09dca8ec2884/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo= google.golang.org/genproto v0.0.0-20200806141610-86f49bd18e98 h1:LCO0fg4kb6WwkXQXRQQgUYsFeFb5taTX5WAx5O/Vt28= google.golang.org/genproto v0.0.0-20200806141610-86f49bd18e98/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= @@ -68,16 +79,16 @@ google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLY google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE= -google.golang.org/protobuf v1.21.0 h1:qdOKuR/EIArgaWNjetjgTzgVTAZ+S/WXVrq9HW9zimw= google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= google.golang.org/protobuf v1.22.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= -google.golang.org/protobuf v1.24.0 h1:UhZDfRO8JRQru4/+LlLE0BRKGF8L+PICnvYZmx/fEGA= google.golang.org/protobuf v1.24.0/go.mod h1:r/3tXBNzIEhYS9I1OUVjXDlt8tc493IdKGjtUeSXeh4= google.golang.org/protobuf v1.25.0 h1:Ejskq+SyPohKW+1uil0JJMtmHCgJPJ/qWTxr8qp+R4c= google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +gopkg.in/yaml.v2 v2.2.3/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= diff --git a/examples/helloworld/greeter_server/main.go b/examples/helloworld/greeter_server/main.go index 15604f9fc1f..4d077db92cf 100644 --- a/examples/helloworld/greeter_server/main.go +++ b/examples/helloworld/greeter_server/main.go @@ -50,6 +50,7 @@ func main() { } s := grpc.NewServer() pb.RegisterGreeterServer(s, &server{}) + log.Printf("server listening at %v", lis.Addr()) if err := s.Serve(lis); err != nil { log.Fatalf("failed to serve: %v", err) } diff --git a/examples/helloworld/helloworld/helloworld_grpc.pb.go b/examples/helloworld/helloworld/helloworld_grpc.pb.go index 39a0301c16b..ae27dfa3cfe 100644 --- a/examples/helloworld/helloworld/helloworld_grpc.pb.go +++ b/examples/helloworld/helloworld/helloworld_grpc.pb.go @@ -1,4 +1,8 @@ // Code generated by protoc-gen-go-grpc. DO NOT EDIT. +// versions: +// - protoc-gen-go-grpc v1.1.0 +// - protoc v3.14.0 +// source: examples/helloworld/helloworld/helloworld.proto package helloworld diff --git a/examples/route_guide/client/client.go b/examples/route_guide/client/client.go index 172f10fb308..f18c10af8b1 100644 --- a/examples/route_guide/client/client.go +++ b/examples/route_guide/client/client.go @@ -40,7 +40,7 @@ var ( tls = flag.Bool("tls", false, "Connection uses TLS if true, else plain TCP") caFile = flag.String("ca_file", "", "The file containing the CA root cert file") serverAddr = flag.String("server_addr", "localhost:10000", "The server address in the format of host:port") - serverHostOverride = flag.String("server_host_override", "x.test.youtube.com", "The server name used to verify the hostname returned by the TLS handshake") + serverHostOverride = flag.String("server_host_override", "x.test.example.com", "The server name used to verify the hostname returned by the TLS handshake") ) // printFeature gets the feature for the given point. diff --git a/examples/route_guide/routeguide/route_guide_grpc.pb.go b/examples/route_guide/routeguide/route_guide_grpc.pb.go index 66860e63c47..efa7c28ce6f 100644 --- a/examples/route_guide/routeguide/route_guide_grpc.pb.go +++ b/examples/route_guide/routeguide/route_guide_grpc.pb.go @@ -1,4 +1,8 @@ // Code generated by protoc-gen-go-grpc. DO NOT EDIT. +// versions: +// - protoc-gen-go-grpc v1.1.0 +// - protoc v3.14.0 +// source: examples/route_guide/routeguide/route_guide.proto package routeguide diff --git a/go.mod b/go.mod index b177cfa66df..022cc9828fe 100644 --- a/go.mod +++ b/go.mod @@ -1,17 +1,18 @@ module google.golang.org/grpc -go 1.11 +go 1.14 require ( + github.com/cespare/xxhash/v2 v2.1.1 github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403 - github.com/envoyproxy/go-control-plane v0.9.9-0.20210217033140-668b12f5399d + github.com/envoyproxy/go-control-plane v0.9.10-0.20210907150352-cf90f659a021 github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b - github.com/golang/protobuf v1.4.2 + github.com/golang/protobuf v1.4.3 github.com/google/go-cmp v0.5.0 github.com/google/uuid v1.1.2 - golang.org/x/net v0.0.0-20190311183353-d8887717615a - golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be - golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a + golang.org/x/net v0.0.0-20200822124328-c89045814202 + golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d + golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013 google.golang.org/protobuf v1.25.0 ) diff --git a/go.sum b/go.sum index bb25cd49156..6e7ae0db2b3 100644 --- a/go.sum +++ b/go.sum @@ -1,34 +1,44 @@ -cloud.google.com/go v0.26.0 h1:e0WKqKTd5BnrG8aKH3J3h+QvEIQtSUcf2n5UZ5ZgLtQ= cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= -github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ= +cloud.google.com/go v0.34.0 h1:eOI3/cP2VTU6uZLDYAoic+eyzzB9YyGmJ7eIjl8rOPg= +cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= +github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY= github.com/census-instrumentation/opencensus-proto v0.2.1 h1:glEXhBS5PSLLv4IXzLA5yPRVX4bilULVyxxbrfOtDAk= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= -github.com/client9/misspell v0.3.4 h1:ta993UF76GwbvJcIo3Y68y/M3WxlpEHPWIGDkJYwzJI= +github.com/cespare/xxhash/v2 v2.1.1 h1:6MnRN8NT7+YBpUIWxHtefFZOKTAPgGjpQSxqLNn0+qY= +github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= +github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403 h1:cqQfy1jclcSy/FwLjemeg3SR1yaINm74aQyupQ0Bl8M= github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= +github.com/cncf/xds/go v0.0.0-20210805033703-aa0b78936158 h1:CevA8fI91PAnP8vpnXuB8ZYAZ5wqY86nAbxfgK8tWO4= +github.com/cncf/xds/go v0.0.0-20210805033703-aa0b78936158/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= -github.com/envoyproxy/go-control-plane v0.9.9-0.20210217033140-668b12f5399d h1:QyzYnTnPE15SQyUeqU6qLbWxMkwyAyu+vGksa0b7j00= -github.com/envoyproxy/go-control-plane v0.9.9-0.20210217033140-668b12f5399d/go.mod h1:cXg6YxExXjJnVBQHBLXeUAgxn2UodCpnH306RInaBQk= +github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= +github.com/envoyproxy/go-control-plane v0.9.9-0.20201210154907-fd9021fe5dad/go.mod h1:cXg6YxExXjJnVBQHBLXeUAgxn2UodCpnH306RInaBQk= +github.com/envoyproxy/go-control-plane v0.9.10-0.20210907150352-cf90f659a021 h1:fP+fF0up6oPY49OrjPrhIJ8yQfdIM85NXMLkMg1EXVs= +github.com/envoyproxy/go-control-plane v0.9.10-0.20210907150352-cf90f659a021/go.mod h1:AFq3mo9L8Lqqiid3OhADV3RfLJnjiw63cSpi+fDTRC0= github.com/envoyproxy/protoc-gen-validate v0.1.0 h1:EQciDnbrYxy13PgWoY8AqoxGiPrpgBZ1R8UNe3ddc+A= github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= +github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b h1:VKtxabqXZkF25pY9ekfRL6a582T4P37/31XEstQ5p58= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA= github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs= github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w= github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QDs8UjoX8= -github.com/golang/protobuf v1.4.2 h1:+Z5KGCizgyZCbGh1KZqA0fcLLkwbsjIzS4aV2v7wJX0= github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= +github.com/golang/protobuf v1.4.3 h1:JjCZWpVbqXDqFVmTfYWEVTMIYrL/NPdPSCHPJ0T/raM= +github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= @@ -37,50 +47,65 @@ github.com/google/go-cmp v0.5.0 h1:/QaMHBdZ26BB3SSst0Iwl10Epc+xhTquomWX0oZEB6w= github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/uuid v1.1.2 h1:EVhdT+1Kseyi1/pUmXKaFxYsDNy9RQYkMWRH68J/W7Y= github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= +github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/testify v1.5.1 h1:nOGnQDM7FYENwehXlg/kFVnos3rEvtKTjRvOWSzb6H4= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= -golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2 h1:VklqNMn3ovrHsnt90PveolxSbWFaJdECFbxSq0Mqo2M= +github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +go.opentelemetry.io/proto/otlp v0.7.0/go.mod h1:PqfVotwruBrMGOCsRd/89rSnXhoiJIqeYNgFYFoEGnI= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20190311183353-d8887717615a h1:oWX7TPOiFAMXLq8o0ikBYfCJVlRHBcsciT5bXOrH628= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be h1:vEDujvNQGv4jgYKudGeI/+DAX4Jffq6hpD55MmoEvKs= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20200822124328-c89045814202 h1:VvcQYSHwXgi7W+TpUR6A9g6Up98WAHf3f/ulnJ62IyA= +golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= +golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d h1:TzXSXBo42m9gQenoE3b9BGiEpg5IG2JkU5FkPIawgtw= +golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a h1:1BGLXjeY4akVXGgbC9HugT3Jv3hCI0z56oJR5vAMgBU= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd h1:xhmwyvizuTgC2qz7ZlMluP20uW+C3Rm0FD/WLDX8884= +golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= -golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= google.golang.org/appengine v1.4.0 h1:/wp5JvzpHIxhs/dumFmF7BXTf3Z+dd4uXta4kVyO508= google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= +google.golang.org/genproto v0.0.0-20200513103714-09dca8ec2884/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013 h1:+kGHl1aib/qcwaRi1CbqBZ1rk19r85MNUf8HaBghugY= google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo= google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY= google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= +google.golang.org/grpc v1.33.1/go.mod h1:fr5YgcSWrqhRRxogOsw7RzIpsmvOZ6IcH4kBYTpR3n0= +google.golang.org/grpc v1.36.0/go.mod h1:qjiiYl8FncCW8feJPdyg3v6XW24KsRHe+dy9BAGRRjU= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= @@ -93,7 +118,9 @@ google.golang.org/protobuf v1.25.0 h1:Ejskq+SyPohKW+1uil0JJMtmHCgJPJ/qWTxr8qp+R4 google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.3/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= diff --git a/grpclog/loggerv2.go b/grpclog/loggerv2.go index 4ee33171e00..34098bb8eb5 100644 --- a/grpclog/loggerv2.go +++ b/grpclog/loggerv2.go @@ -19,11 +19,14 @@ package grpclog import ( + "encoding/json" + "fmt" "io" "io/ioutil" "log" "os" "strconv" + "strings" "google.golang.org/grpc/internal/grpclog" ) @@ -95,8 +98,9 @@ var severityName = []string{ // loggerT is the default logger used by grpclog. type loggerT struct { - m []*log.Logger - v int + m []*log.Logger + v int + jsonFormat bool } // NewLoggerV2 creates a loggerV2 with the provided writers. @@ -105,19 +109,32 @@ type loggerT struct { // Warning logs will be written to warningW and infoW. // Info logs will be written to infoW. func NewLoggerV2(infoW, warningW, errorW io.Writer) LoggerV2 { - return NewLoggerV2WithVerbosity(infoW, warningW, errorW, 0) + return newLoggerV2WithConfig(infoW, warningW, errorW, loggerV2Config{}) } // NewLoggerV2WithVerbosity creates a loggerV2 with the provided writers and // verbosity level. func NewLoggerV2WithVerbosity(infoW, warningW, errorW io.Writer, v int) LoggerV2 { + return newLoggerV2WithConfig(infoW, warningW, errorW, loggerV2Config{verbose: v}) +} + +type loggerV2Config struct { + verbose int + jsonFormat bool +} + +func newLoggerV2WithConfig(infoW, warningW, errorW io.Writer, c loggerV2Config) LoggerV2 { var m []*log.Logger - m = append(m, log.New(infoW, severityName[infoLog]+": ", log.LstdFlags)) - m = append(m, log.New(io.MultiWriter(infoW, warningW), severityName[warningLog]+": ", log.LstdFlags)) + flag := log.LstdFlags + if c.jsonFormat { + flag = 0 + } + m = append(m, log.New(infoW, "", flag)) + m = append(m, log.New(io.MultiWriter(infoW, warningW), "", flag)) ew := io.MultiWriter(infoW, warningW, errorW) // ew will be used for error and fatal. - m = append(m, log.New(ew, severityName[errorLog]+": ", log.LstdFlags)) - m = append(m, log.New(ew, severityName[fatalLog]+": ", log.LstdFlags)) - return &loggerT{m: m, v: v} + m = append(m, log.New(ew, "", flag)) + m = append(m, log.New(ew, "", flag)) + return &loggerT{m: m, v: c.verbose, jsonFormat: c.jsonFormat} } // newLoggerV2 creates a loggerV2 to be used as default logger. @@ -142,58 +159,79 @@ func newLoggerV2() LoggerV2 { if vl, err := strconv.Atoi(vLevel); err == nil { v = vl } - return NewLoggerV2WithVerbosity(infoW, warningW, errorW, v) + + jsonFormat := strings.EqualFold(os.Getenv("GRPC_GO_LOG_FORMATTER"), "json") + + return newLoggerV2WithConfig(infoW, warningW, errorW, loggerV2Config{ + verbose: v, + jsonFormat: jsonFormat, + }) +} + +func (g *loggerT) output(severity int, s string) { + sevStr := severityName[severity] + if !g.jsonFormat { + g.m[severity].Output(2, fmt.Sprintf("%v: %v", sevStr, s)) + return + } + // TODO: we can also include the logging component, but that needs more + // (API) changes. + b, _ := json.Marshal(map[string]string{ + "severity": sevStr, + "message": s, + }) + g.m[severity].Output(2, string(b)) } func (g *loggerT) Info(args ...interface{}) { - g.m[infoLog].Print(args...) + g.output(infoLog, fmt.Sprint(args...)) } func (g *loggerT) Infoln(args ...interface{}) { - g.m[infoLog].Println(args...) + g.output(infoLog, fmt.Sprintln(args...)) } func (g *loggerT) Infof(format string, args ...interface{}) { - g.m[infoLog].Printf(format, args...) + g.output(infoLog, fmt.Sprintf(format, args...)) } func (g *loggerT) Warning(args ...interface{}) { - g.m[warningLog].Print(args...) + g.output(warningLog, fmt.Sprint(args...)) } func (g *loggerT) Warningln(args ...interface{}) { - g.m[warningLog].Println(args...) + g.output(warningLog, fmt.Sprintln(args...)) } func (g *loggerT) Warningf(format string, args ...interface{}) { - g.m[warningLog].Printf(format, args...) + g.output(warningLog, fmt.Sprintf(format, args...)) } func (g *loggerT) Error(args ...interface{}) { - g.m[errorLog].Print(args...) + g.output(errorLog, fmt.Sprint(args...)) } func (g *loggerT) Errorln(args ...interface{}) { - g.m[errorLog].Println(args...) + g.output(errorLog, fmt.Sprintln(args...)) } func (g *loggerT) Errorf(format string, args ...interface{}) { - g.m[errorLog].Printf(format, args...) + g.output(errorLog, fmt.Sprintf(format, args...)) } func (g *loggerT) Fatal(args ...interface{}) { - g.m[fatalLog].Fatal(args...) - // No need to call os.Exit() again because log.Logger.Fatal() calls os.Exit(). + g.output(fatalLog, fmt.Sprint(args...)) + os.Exit(1) } func (g *loggerT) Fatalln(args ...interface{}) { - g.m[fatalLog].Fatalln(args...) - // No need to call os.Exit() again because log.Logger.Fatal() calls os.Exit(). + g.output(fatalLog, fmt.Sprintln(args...)) + os.Exit(1) } func (g *loggerT) Fatalf(format string, args ...interface{}) { - g.m[fatalLog].Fatalf(format, args...) - // No need to call os.Exit() again because log.Logger.Fatal() calls os.Exit(). + g.output(fatalLog, fmt.Sprintf(format, args...)) + os.Exit(1) } func (g *loggerT) V(l int) bool { diff --git a/grpclog/loggerv2_test.go b/grpclog/loggerv2_test.go index 756f215f9c8..0b2c8b23d66 100644 --- a/grpclog/loggerv2_test.go +++ b/grpclog/loggerv2_test.go @@ -52,9 +52,9 @@ func TestLoggerV2Severity(t *testing.T) { } // check if b is in the format of: -// WARNING: 2017/04/07 14:55:42 WARNING +// 2017/04/07 14:55:42 WARNING: WARNING func checkLogForSeverity(s int, b []byte) error { - expected := regexp.MustCompile(fmt.Sprintf(`^%s: [0-9]{4}/[0-9]{2}/[0-9]{2} [0-9]{2}:[0-9]{2}:[0-9]{2} %s\n$`, severityName[s], severityName[s])) + expected := regexp.MustCompile(fmt.Sprintf(`^[0-9]{4}/[0-9]{2}/[0-9]{2} [0-9]{2}:[0-9]{2}:[0-9]{2} %s: %s\n$`, severityName[s], severityName[s])) if m := expected.Match(b); !m { return fmt.Errorf("got: %v, want string in format of: %v", string(b), severityName[s]+": 2016/10/05 17:09:26 "+severityName[s]) } diff --git a/health/grpc_health_v1/health_grpc.pb.go b/health/grpc_health_v1/health_grpc.pb.go index 386d16ce62d..bdc3ae284e7 100644 --- a/health/grpc_health_v1/health_grpc.pb.go +++ b/health/grpc_health_v1/health_grpc.pb.go @@ -1,4 +1,8 @@ // Code generated by protoc-gen-go-grpc. DO NOT EDIT. +// versions: +// - protoc-gen-go-grpc v1.1.0 +// - protoc v3.14.0 +// source: grpc/health/v1/health.proto package grpc_health_v1 diff --git a/install_gae.sh b/install_gae.sh deleted file mode 100755 index 15ff9facdd7..00000000000 --- a/install_gae.sh +++ /dev/null @@ -1,6 +0,0 @@ -#!/bin/bash - -TMP=$(mktemp -d /tmp/sdk.XXX) \ -&& curl -o $TMP.zip "https://storage.googleapis.com/appengine-sdks/featured/go_appengine_sdk_linux_amd64-1.9.68.zip" \ -&& unzip -q $TMP.zip -d $TMP \ -&& export PATH="$PATH:$TMP/go_appengine" \ No newline at end of file diff --git a/internal/admin/admin.go b/internal/admin/admin.go new file mode 100644 index 00000000000..a9285ee7484 --- /dev/null +++ b/internal/admin/admin.go @@ -0,0 +1,60 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// Package admin contains internal implementation for admin service. +package admin + +import "google.golang.org/grpc" + +// services is a map from name to service register functions. +var services []func(grpc.ServiceRegistrar) (func(), error) + +// AddService adds a service to the list of admin services. +// +// NOTE: this function must only be called during initialization time (i.e. in +// an init() function), and is not thread-safe. +// +// If multiple services with the same service name are added (e.g. two services +// for `grpc.channelz.v1.Channelz`), the server will panic on `Register()`. +func AddService(f func(grpc.ServiceRegistrar) (func(), error)) { + services = append(services, f) +} + +// Register registers the set of admin services to the given server. +func Register(s grpc.ServiceRegistrar) (cleanup func(), _ error) { + var cleanups []func() + for _, f := range services { + cleanup, err := f(s) + if err != nil { + callFuncs(cleanups) + return nil, err + } + if cleanup != nil { + cleanups = append(cleanups, cleanup) + } + } + return func() { + callFuncs(cleanups) + }, nil +} + +func callFuncs(fs []func()) { + for _, f := range fs { + f() + } +} diff --git a/internal/balancer/stub/stub.go b/internal/balancer/stub/stub.go index e3757c1a50b..950eaaa0278 100644 --- a/internal/balancer/stub/stub.go +++ b/internal/balancer/stub/stub.go @@ -33,6 +33,7 @@ type BalancerFuncs struct { ResolverError func(*BalancerData, error) UpdateSubConnState func(*BalancerData, balancer.SubConn, balancer.SubConnState) Close func(*BalancerData) + ExitIdle func(*BalancerData) } // BalancerData contains data relevant to a stub balancer. @@ -75,6 +76,12 @@ func (b *bal) Close() { } } +func (b *bal) ExitIdle() { + if b.bf.ExitIdle != nil { + b.bf.ExitIdle(b.bd) + } +} + type bb struct { name string bf BalancerFuncs diff --git a/internal/binarylog/sink.go b/internal/binarylog/sink.go index 7d7a3056b71..c2fdd58b319 100644 --- a/internal/binarylog/sink.go +++ b/internal/binarylog/sink.go @@ -69,7 +69,8 @@ type writerSink struct { func (ws *writerSink) Write(e *pb.GrpcLogEntry) error { b, err := proto.Marshal(e) if err != nil { - grpclogLogger.Infof("binary logging: failed to marshal proto message: %v", err) + grpclogLogger.Errorf("binary logging: failed to marshal proto message: %v", err) + return err } hdr := make([]byte, 4) binary.BigEndian.PutUint32(hdr, uint32(len(b))) @@ -85,24 +86,27 @@ func (ws *writerSink) Write(e *pb.GrpcLogEntry) error { func (ws *writerSink) Close() error { return nil } type bufferedSink struct { - mu sync.Mutex - closer io.Closer - out Sink // out is built on buf. - buf *bufio.Writer // buf is kept for flush. - - writeStartOnce sync.Once - writeTicker *time.Ticker + mu sync.Mutex + closer io.Closer + out Sink // out is built on buf. + buf *bufio.Writer // buf is kept for flush. + flusherStarted bool + + writeTicker *time.Ticker + done chan struct{} } func (fs *bufferedSink) Write(e *pb.GrpcLogEntry) error { - // Start the write loop when Write is called. - fs.writeStartOnce.Do(fs.startFlushGoroutine) fs.mu.Lock() + defer fs.mu.Unlock() + if !fs.flusherStarted { + // Start the write loop when Write is called. + fs.startFlushGoroutine() + fs.flusherStarted = true + } if err := fs.out.Write(e); err != nil { - fs.mu.Unlock() return err } - fs.mu.Unlock() return nil } @@ -113,7 +117,12 @@ const ( func (fs *bufferedSink) startFlushGoroutine() { fs.writeTicker = time.NewTicker(bufFlushDuration) go func() { - for range fs.writeTicker.C { + for { + select { + case <-fs.done: + return + case <-fs.writeTicker.C: + } fs.mu.Lock() if err := fs.buf.Flush(); err != nil { grpclogLogger.Warningf("failed to flush to Sink: %v", err) @@ -124,10 +133,12 @@ func (fs *bufferedSink) startFlushGoroutine() { } func (fs *bufferedSink) Close() error { + fs.mu.Lock() + defer fs.mu.Unlock() if fs.writeTicker != nil { fs.writeTicker.Stop() } - fs.mu.Lock() + close(fs.done) if err := fs.buf.Flush(); err != nil { grpclogLogger.Warningf("failed to flush to Sink: %v", err) } @@ -137,7 +148,6 @@ func (fs *bufferedSink) Close() error { if err := fs.out.Close(); err != nil { grpclogLogger.Warningf("failed to close the Sink: %v", err) } - fs.mu.Unlock() return nil } @@ -155,5 +165,6 @@ func NewBufferedSink(o io.WriteCloser) Sink { closer: o, out: newWriterSink(bufW), buf: bufW, + done: make(chan struct{}), } } diff --git a/internal/channelz/funcs.go b/internal/channelz/funcs.go index f7314139303..6d5760d9514 100644 --- a/internal/channelz/funcs.go +++ b/internal/channelz/funcs.go @@ -630,7 +630,7 @@ func (c *channelMap) GetServerSockets(id int64, startID int64, maxResults int64) if count == 0 { end = true } - var s []*SocketMetric + s := make([]*SocketMetric, 0, len(sks)) for _, ns := range sks { sm := &SocketMetric{} sm.SocketData = ns.s.ChannelzMetric() diff --git a/internal/channelz/types_linux.go b/internal/channelz/types_linux.go index 692dd618177..1b1c4cce34a 100644 --- a/internal/channelz/types_linux.go +++ b/internal/channelz/types_linux.go @@ -1,5 +1,3 @@ -// +build !appengine - /* * * Copyright 2018 gRPC authors. diff --git a/internal/channelz/types_nonlinux.go b/internal/channelz/types_nonlinux.go index 19c2fc521dc..8b06eed1ab8 100644 --- a/internal/channelz/types_nonlinux.go +++ b/internal/channelz/types_nonlinux.go @@ -1,4 +1,5 @@ -// +build !linux appengine +//go:build !linux +// +build !linux /* * @@ -37,6 +38,6 @@ type SocketOptionData struct { // Windows OS doesn't support Socket Option func (s *SocketOptionData) Getsockopt(fd uintptr) { once.Do(func() { - logger.Warning("Channelz: socket options are not supported on non-linux os and appengine.") + logger.Warning("Channelz: socket options are not supported on non-linux environments") }) } diff --git a/internal/channelz/util_linux.go b/internal/channelz/util_linux.go index fdf409d55de..8d194e44e1d 100644 --- a/internal/channelz/util_linux.go +++ b/internal/channelz/util_linux.go @@ -1,5 +1,3 @@ -// +build linux,!appengine - /* * * Copyright 2018 gRPC authors. diff --git a/internal/channelz/util_nonlinux.go b/internal/channelz/util_nonlinux.go index 8864a081116..837ddc40240 100644 --- a/internal/channelz/util_nonlinux.go +++ b/internal/channelz/util_nonlinux.go @@ -1,4 +1,5 @@ -// +build !linux appengine +//go:build !linux +// +build !linux /* * diff --git a/internal/channelz/util_test.go b/internal/channelz/util_test.go index 3d1a1183fa4..9de6679043d 100644 --- a/internal/channelz/util_test.go +++ b/internal/channelz/util_test.go @@ -1,4 +1,5 @@ -// +build linux,!appengine +//go:build linux +// +build linux /* * diff --git a/internal/credentials/credentials.go b/internal/credentials/credentials.go new file mode 100644 index 00000000000..32c9b59033c --- /dev/null +++ b/internal/credentials/credentials.go @@ -0,0 +1,49 @@ +/* + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package credentials + +import ( + "context" +) + +// requestInfoKey is a struct to be used as the key to store RequestInfo in a +// context. +type requestInfoKey struct{} + +// NewRequestInfoContext creates a context with ri. +func NewRequestInfoContext(ctx context.Context, ri interface{}) context.Context { + return context.WithValue(ctx, requestInfoKey{}, ri) +} + +// RequestInfoFromContext extracts the RequestInfo from ctx. +func RequestInfoFromContext(ctx context.Context) interface{} { + return ctx.Value(requestInfoKey{}) +} + +// clientHandshakeInfoKey is a struct used as the key to store +// ClientHandshakeInfo in a context. +type clientHandshakeInfoKey struct{} + +// ClientHandshakeInfoFromContext extracts the ClientHandshakeInfo from ctx. +func ClientHandshakeInfoFromContext(ctx context.Context) interface{} { + return ctx.Value(clientHandshakeInfoKey{}) +} + +// NewClientHandshakeInfoContext creates a context with chi. +func NewClientHandshakeInfoContext(ctx context.Context, chi interface{}) context.Context { + return context.WithValue(ctx, clientHandshakeInfoKey{}, chi) +} diff --git a/internal/credentials/spiffe.go b/internal/credentials/spiffe.go index be70b6cdfc3..25ade623058 100644 --- a/internal/credentials/spiffe.go +++ b/internal/credentials/spiffe.go @@ -1,5 +1,3 @@ -// +build !appengine - /* * * Copyright 2020 gRPC authors. diff --git a/internal/credentials/spiffe_appengine.go b/internal/credentials/spiffe_appengine.go deleted file mode 100644 index af6f5771976..00000000000 --- a/internal/credentials/spiffe_appengine.go +++ /dev/null @@ -1,31 +0,0 @@ -// +build appengine - -/* - * - * Copyright 2020 gRPC authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - */ - -package credentials - -import ( - "crypto/tls" - "net/url" -) - -// SPIFFEIDFromState is a no-op for appengine builds. -func SPIFFEIDFromState(state tls.ConnectionState) *url.URL { - return nil -} diff --git a/internal/credentials/syscallconn.go b/internal/credentials/syscallconn.go index f499a614c20..2919632d657 100644 --- a/internal/credentials/syscallconn.go +++ b/internal/credentials/syscallconn.go @@ -1,5 +1,3 @@ -// +build !appengine - /* * * Copyright 2018 gRPC authors. diff --git a/internal/credentials/syscallconn_test.go b/internal/credentials/syscallconn_test.go index ee17a0ca67b..b229a47d116 100644 --- a/internal/credentials/syscallconn_test.go +++ b/internal/credentials/syscallconn_test.go @@ -1,5 +1,3 @@ -// +build !appengine - /* * * Copyright 2018 gRPC authors. diff --git a/internal/credentials/util.go b/internal/credentials/util.go index 55664fa46b8..f792fd22caf 100644 --- a/internal/credentials/util.go +++ b/internal/credentials/util.go @@ -18,7 +18,9 @@ package credentials -import "crypto/tls" +import ( + "crypto/tls" +) const alpnProtoStrH2 = "h2" diff --git a/internal/credentials/xds/handshake_info.go b/internal/credentials/xds/handshake_info.go index 8b203566063..6ef43cc89fa 100644 --- a/internal/credentials/xds/handshake_info.go +++ b/internal/credentials/xds/handshake_info.go @@ -25,11 +25,13 @@ import ( "crypto/x509" "errors" "fmt" + "strings" "sync" "google.golang.org/grpc/attributes" "google.golang.org/grpc/credentials/tls/certprovider" "google.golang.org/grpc/internal" + "google.golang.org/grpc/internal/xds/matcher" "google.golang.org/grpc/resolver" ) @@ -64,8 +66,8 @@ type HandshakeInfo struct { mu sync.Mutex rootProvider certprovider.Provider identityProvider certprovider.Provider - acceptedSANs map[string]bool // Only on the client side. - requireClientCert bool // Only on server side. + sanMatchers []matcher.StringMatcher // Only on the client side. + requireClientCert bool // Only on server side. } // SetRootCertProvider updates the root certificate provider. @@ -82,13 +84,10 @@ func (hi *HandshakeInfo) SetIdentityCertProvider(identity certprovider.Provider) hi.mu.Unlock() } -// SetAcceptedSANs updates the list of accepted SANs. -func (hi *HandshakeInfo) SetAcceptedSANs(sans []string) { +// SetSANMatchers updates the list of SAN matchers. +func (hi *HandshakeInfo) SetSANMatchers(sanMatchers []matcher.StringMatcher) { hi.mu.Lock() - hi.acceptedSANs = make(map[string]bool, len(sans)) - for _, san := range sans { - hi.acceptedSANs[san] = true - } + hi.sanMatchers = sanMatchers hi.mu.Unlock() } @@ -112,6 +111,14 @@ func (hi *HandshakeInfo) UseFallbackCreds() bool { return hi.identityProvider == nil && hi.rootProvider == nil } +// GetSANMatchersForTesting returns the SAN matchers stored in HandshakeInfo. +// To be used only for testing purposes. +func (hi *HandshakeInfo) GetSANMatchersForTesting() []matcher.StringMatcher { + hi.mu.Lock() + defer hi.mu.Unlock() + return append([]matcher.StringMatcher{}, hi.sanMatchers...) +} + // ClientSideTLSConfig constructs a tls.Config to be used in a client-side // handshake based on the contents of the HandshakeInfo. func (hi *HandshakeInfo) ClientSideTLSConfig(ctx context.Context) (*tls.Config, error) { @@ -131,7 +138,10 @@ func (hi *HandshakeInfo) ClientSideTLSConfig(ctx context.Context) (*tls.Config, // Currently the Go stdlib does complete verification of the cert (which // includes hostname verification) or none. We are forced to go with the // latter and perform the normal cert validation ourselves. - cfg := &tls.Config{InsecureSkipVerify: true} + cfg := &tls.Config{ + InsecureSkipVerify: true, + NextProtos: []string{"h2"}, + } km, err := rootProv.KeyMaterial(ctx) if err != nil { @@ -152,7 +162,10 @@ func (hi *HandshakeInfo) ClientSideTLSConfig(ctx context.Context) (*tls.Config, // ServerSideTLSConfig constructs a tls.Config to be used in a server-side // handshake based on the contents of the HandshakeInfo. func (hi *HandshakeInfo) ServerSideTLSConfig(ctx context.Context) (*tls.Config, error) { - cfg := &tls.Config{ClientAuth: tls.NoClientCert} + cfg := &tls.Config{ + ClientAuth: tls.NoClientCert, + NextProtos: []string{"h2"}, + } hi.mu.Lock() // On the server side, identityProvider is mandatory. RootProvider is // optional based on whether the server is doing TLS or mTLS. @@ -184,47 +197,115 @@ func (hi *HandshakeInfo) ServerSideTLSConfig(ctx context.Context) (*tls.Config, return cfg, nil } -// MatchingSANExists returns true if the SAN contained in the passed in -// certificate is present in the list of accepted SANs in the HandshakeInfo. +// MatchingSANExists returns true if the SANs contained in cert match the +// criteria enforced by the list of SAN matchers in HandshakeInfo. // -// If the list of accepted SANs in the HandshakeInfo is empty, this function +// If the list of SAN matchers in the HandshakeInfo is empty, this function // returns true for all input certificates. func (hi *HandshakeInfo) MatchingSANExists(cert *x509.Certificate) bool { - if len(hi.acceptedSANs) == 0 { + hi.mu.Lock() + defer hi.mu.Unlock() + if len(hi.sanMatchers) == 0 { return true } - var sans []string // SANs can be specified in any of these four fields on the parsed cert. - sans = append(sans, cert.DNSNames...) - sans = append(sans, cert.EmailAddresses...) - for _, ip := range cert.IPAddresses { - sans = append(sans, ip.String()) + for _, san := range cert.DNSNames { + if hi.matchSAN(san, true) { + return true + } } - for _, uri := range cert.URIs { - sans = append(sans, uri.String()) + for _, san := range cert.EmailAddresses { + if hi.matchSAN(san, false) { + return true + } + } + for _, san := range cert.IPAddresses { + if hi.matchSAN(san.String(), false) { + return true + } + } + for _, san := range cert.URIs { + if hi.matchSAN(san.String(), false) { + return true + } } + return false +} - hi.mu.Lock() - defer hi.mu.Unlock() - for _, san := range sans { - if hi.acceptedSANs[san] { +// Caller must hold mu. +func (hi *HandshakeInfo) matchSAN(san string, isDNS bool) bool { + for _, matcher := range hi.sanMatchers { + if em := matcher.ExactMatch(); em != "" && isDNS { + // This is a special case which is documented in the xDS protos. + // If the DNS SAN is a wildcard entry, and the match criteria is + // `exact`, then we need to perform DNS wildcard matching + // instead of regular string comparison. + if dnsMatch(em, san) { + return true + } + continue + } + if matcher.Match(san) { return true } } return false } -// NewHandshakeInfo returns a new instance of HandshakeInfo with the given root -// and identity certificate providers. -func NewHandshakeInfo(root, identity certprovider.Provider, sans ...string) *HandshakeInfo { - acceptedSANs := make(map[string]bool, len(sans)) - for _, san := range sans { - acceptedSANs[san] = true +// dnsMatch implements a DNS wildcard matching algorithm based on RFC2828 and +// grpc-java's implementation in `OkHostnameVerifier` class. +// +// NOTE: Here the `host` argument is the one from the set of string matchers in +// the xDS proto and the `san` argument is a DNS SAN from the certificate, and +// this is the one which can potentially contain a wildcard pattern. +func dnsMatch(host, san string) bool { + // Add trailing "." and turn them into absolute domain names. + if !strings.HasSuffix(host, ".") { + host += "." + } + if !strings.HasSuffix(san, ".") { + san += "." } - return &HandshakeInfo{ - rootProvider: root, - identityProvider: identity, - acceptedSANs: acceptedSANs, + // Domain names are case-insensitive. + host = strings.ToLower(host) + san = strings.ToLower(san) + + // If san does not contain a wildcard, do exact match. + if !strings.Contains(san, "*") { + return host == san + } + + // Wildcard dns matching rules + // - '*' is only permitted in the left-most label and must be the only + // character in that label. For example, *.example.com is permitted, while + // *a.example.com, a*.example.com, a*b.example.com, a.*.example.com are + // not permitted. + // - '*' matches a single domain name component. For example, *.example.com + // matches test.example.com but does not match sub.test.example.com. + // - Wildcard patterns for single-label domain names are not permitted. + if san == "*." || !strings.HasPrefix(san, "*.") || strings.Contains(san[1:], "*") { + return false + } + // Optimization: at this point, we know that the san contains a '*' and + // is the first domain component of san. So, the host name must be at + // least as long as the san to be able to match. + if len(host) < len(san) { + return false + } + // Hostname must end with the non-wildcard portion of san. + if !strings.HasSuffix(host, san[1:]) { + return false } + // At this point we know that the hostName and san share the same suffix + // (the non-wildcard portion of san). Now, we just need to make sure + // that the '*' does not match across domain components. + hostPrefix := strings.TrimSuffix(host, san[1:]) + return !strings.Contains(hostPrefix, ".") +} + +// NewHandshakeInfo returns a new instance of HandshakeInfo with the given root +// and identity certificate providers. +func NewHandshakeInfo(root, identity certprovider.Provider) *HandshakeInfo { + return &HandshakeInfo{rootProvider: root, identityProvider: identity} } diff --git a/internal/credentials/xds/handshake_info_test.go b/internal/credentials/xds/handshake_info_test.go new file mode 100644 index 00000000000..91257a1925d --- /dev/null +++ b/internal/credentials/xds/handshake_info_test.go @@ -0,0 +1,304 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package xds + +import ( + "crypto/x509" + "net" + "net/url" + "regexp" + "testing" + + "google.golang.org/grpc/internal/xds/matcher" +) + +func TestDNSMatch(t *testing.T) { + tests := []struct { + desc string + host string + pattern string + wantMatch bool + }{ + { + desc: "invalid wildcard 1", + host: "aa.example.com", + pattern: "*a.example.com", + wantMatch: false, + }, + { + desc: "invalid wildcard 2", + host: "aa.example.com", + pattern: "a*.example.com", + wantMatch: false, + }, + { + desc: "invalid wildcard 3", + host: "abc.example.com", + pattern: "a*c.example.com", + wantMatch: false, + }, + { + desc: "wildcard in one of the middle components", + host: "abc.test.example.com", + pattern: "abc.*.example.com", + wantMatch: false, + }, + { + desc: "single component wildcard", + host: "a.example.com", + pattern: "*", + wantMatch: false, + }, + { + desc: "short host name", + host: "a.com", + pattern: "*.example.com", + wantMatch: false, + }, + { + desc: "suffix mismatch", + host: "a.notexample.com", + pattern: "*.example.com", + wantMatch: false, + }, + { + desc: "wildcard match across components", + host: "sub.test.example.com", + pattern: "*.example.com.", + wantMatch: false, + }, + { + desc: "host doesn't end in period", + host: "test.example.com", + pattern: "test.example.com.", + wantMatch: true, + }, + { + desc: "pattern doesn't end in period", + host: "test.example.com.", + pattern: "test.example.com", + wantMatch: true, + }, + { + desc: "case insensitive", + host: "TEST.EXAMPLE.COM.", + pattern: "test.example.com.", + wantMatch: true, + }, + { + desc: "simple match", + host: "test.example.com", + pattern: "test.example.com", + wantMatch: true, + }, + { + desc: "good wildcard", + host: "a.example.com", + pattern: "*.example.com", + wantMatch: true, + }, + } + + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + gotMatch := dnsMatch(test.host, test.pattern) + if gotMatch != test.wantMatch { + t.Fatalf("dnsMatch(%s, %s) = %v, want %v", test.host, test.pattern, gotMatch, test.wantMatch) + } + }) + } +} + +func TestMatchingSANExists_FailureCases(t *testing.T) { + url1, err := url.Parse("http://golang.org") + if err != nil { + t.Fatalf("url.Parse() failed: %v", err) + } + url2, err := url.Parse("https://github.com/grpc/grpc-go") + if err != nil { + t.Fatalf("url.Parse() failed: %v", err) + } + inputCert := &x509.Certificate{ + DNSNames: []string{"foo.bar.example.com", "bar.baz.test.com", "*.example.com"}, + EmailAddresses: []string{"foobar@example.com", "barbaz@test.com"}, + IPAddresses: []net.IP{net.ParseIP("192.0.0.1"), net.ParseIP("2001:db8::68")}, + URIs: []*url.URL{url1, url2}, + } + + tests := []struct { + desc string + sanMatchers []matcher.StringMatcher + }{ + { + desc: "exact match", + sanMatchers: []matcher.StringMatcher{ + matcher.StringMatcherForTesting(newStringP("abcd.test.com"), nil, nil, nil, nil, false), + matcher.StringMatcherForTesting(newStringP("http://golang"), nil, nil, nil, nil, false), + matcher.StringMatcherForTesting(newStringP("HTTP://GOLANG.ORG"), nil, nil, nil, nil, false), + }, + }, + { + desc: "prefix match", + sanMatchers: []matcher.StringMatcher{ + matcher.StringMatcherForTesting(nil, newStringP("i-aint-the-one"), nil, nil, nil, false), + matcher.StringMatcherForTesting(nil, newStringP("192.168.1.1"), nil, nil, nil, false), + matcher.StringMatcherForTesting(nil, newStringP("FOO.BAR"), nil, nil, nil, false), + }, + }, + { + desc: "suffix match", + sanMatchers: []matcher.StringMatcher{ + matcher.StringMatcherForTesting(nil, nil, newStringP("i-aint-the-one"), nil, nil, false), + matcher.StringMatcherForTesting(nil, nil, newStringP("1::68"), nil, nil, false), + matcher.StringMatcherForTesting(nil, nil, newStringP(".COM"), nil, nil, false), + }, + }, + { + desc: "regex match", + sanMatchers: []matcher.StringMatcher{ + matcher.StringMatcherForTesting(nil, nil, nil, nil, regexp.MustCompile(`.*\.examples\.com`), false), + matcher.StringMatcherForTesting(nil, nil, nil, nil, regexp.MustCompile(`192\.[0-9]{1,3}\.1\.1`), false), + }, + }, + { + desc: "contains match", + sanMatchers: []matcher.StringMatcher{ + matcher.StringMatcherForTesting(nil, nil, nil, newStringP("i-aint-the-one"), nil, false), + matcher.StringMatcherForTesting(nil, nil, nil, newStringP("2001:db8:1:1::68"), nil, false), + matcher.StringMatcherForTesting(nil, nil, nil, newStringP("GRPC"), nil, false), + }, + }, + } + + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + hi := NewHandshakeInfo(nil, nil) + hi.SetSANMatchers(test.sanMatchers) + + if hi.MatchingSANExists(inputCert) { + t.Fatalf("hi.MatchingSANExists(%+v) with SAN matchers +%v succeeded when expected to fail", inputCert, test.sanMatchers) + } + }) + } +} + +func TestMatchingSANExists_Success(t *testing.T) { + url1, err := url.Parse("http://golang.org") + if err != nil { + t.Fatalf("url.Parse() failed: %v", err) + } + url2, err := url.Parse("https://github.com/grpc/grpc-go") + if err != nil { + t.Fatalf("url.Parse() failed: %v", err) + } + inputCert := &x509.Certificate{ + DNSNames: []string{"baz.test.com", "*.example.com"}, + EmailAddresses: []string{"foobar@example.com", "barbaz@test.com"}, + IPAddresses: []net.IP{net.ParseIP("192.0.0.1"), net.ParseIP("2001:db8::68")}, + URIs: []*url.URL{url1, url2}, + } + + tests := []struct { + desc string + sanMatchers []matcher.StringMatcher + }{ + { + desc: "no san matchers", + }, + { + desc: "exact match dns wildcard", + sanMatchers: []matcher.StringMatcher{ + matcher.StringMatcherForTesting(nil, newStringP("192.168.1.1"), nil, nil, nil, false), + matcher.StringMatcherForTesting(newStringP("https://github.com/grpc/grpc-java"), nil, nil, nil, nil, false), + matcher.StringMatcherForTesting(newStringP("abc.example.com"), nil, nil, nil, nil, false), + }, + }, + { + desc: "exact match ignore case", + sanMatchers: []matcher.StringMatcher{ + matcher.StringMatcherForTesting(newStringP("FOOBAR@EXAMPLE.COM"), nil, nil, nil, nil, true), + }, + }, + { + desc: "prefix match", + sanMatchers: []matcher.StringMatcher{ + matcher.StringMatcherForTesting(nil, nil, newStringP(".co.in"), nil, nil, false), + matcher.StringMatcherForTesting(nil, newStringP("192.168.1.1"), nil, nil, nil, false), + matcher.StringMatcherForTesting(nil, newStringP("baz.test"), nil, nil, nil, false), + }, + }, + { + desc: "prefix match ignore case", + sanMatchers: []matcher.StringMatcher{ + matcher.StringMatcherForTesting(nil, newStringP("BAZ.test"), nil, nil, nil, true), + }, + }, + { + desc: "suffix match", + sanMatchers: []matcher.StringMatcher{ + matcher.StringMatcherForTesting(nil, nil, nil, nil, regexp.MustCompile(`192\.[0-9]{1,3}\.1\.1`), false), + matcher.StringMatcherForTesting(nil, nil, newStringP("192.168.1.1"), nil, nil, false), + matcher.StringMatcherForTesting(nil, nil, newStringP("@test.com"), nil, nil, false), + }, + }, + { + desc: "suffix match ignore case", + sanMatchers: []matcher.StringMatcher{ + matcher.StringMatcherForTesting(nil, nil, newStringP("@test.COM"), nil, nil, true), + }, + }, + { + desc: "regex match", + sanMatchers: []matcher.StringMatcher{ + matcher.StringMatcherForTesting(nil, nil, nil, newStringP("https://github.com/grpc/grpc-java"), nil, false), + matcher.StringMatcherForTesting(nil, nil, nil, nil, regexp.MustCompile(`192\.[0-9]{1,3}\.1\.1`), false), + matcher.StringMatcherForTesting(nil, nil, nil, nil, regexp.MustCompile(`.*\.test\.com`), false), + }, + }, + { + desc: "contains match", + sanMatchers: []matcher.StringMatcher{ + matcher.StringMatcherForTesting(newStringP("https://github.com/grpc/grpc-java"), nil, nil, nil, nil, false), + matcher.StringMatcherForTesting(nil, nil, nil, newStringP("2001:68::db8"), nil, false), + matcher.StringMatcherForTesting(nil, nil, nil, newStringP("192.0.0"), nil, false), + }, + }, + { + desc: "contains match ignore case", + sanMatchers: []matcher.StringMatcher{ + matcher.StringMatcherForTesting(nil, nil, nil, newStringP("GRPC"), nil, true), + }, + }, + } + + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + hi := NewHandshakeInfo(nil, nil) + hi.SetSANMatchers(test.sanMatchers) + + if !hi.MatchingSANExists(inputCert) { + t.Fatalf("hi.MatchingSANExists(%+v) with SAN matchers +%v failed when expected to succeed", inputCert, test.sanMatchers) + } + }) + } +} + +func newStringP(s string) *string { + return &s +} diff --git a/internal/envconfig/envconfig.go b/internal/envconfig/envconfig.go index 73931a94bca..e766ac04af2 100644 --- a/internal/envconfig/envconfig.go +++ b/internal/envconfig/envconfig.go @@ -22,6 +22,8 @@ package envconfig import ( "os" "strings" + + xdsenv "google.golang.org/grpc/internal/xds/env" ) const ( @@ -31,8 +33,8 @@ const ( ) var ( - // Retry is set if retry is explicitly enabled via "GRPC_GO_RETRY=on". - Retry = strings.EqualFold(os.Getenv(retryStr), "on") + // Retry is set if retry is explicitly enabled via "GRPC_GO_RETRY=on" or if XDS retry support is enabled. + Retry = strings.EqualFold(os.Getenv(retryStr), "on") || xdsenv.RetrySupport // TXTErrIgnore is set if TXT errors should be ignored ("GRPC_GO_IGNORE_TXT_ERRORS" is not "false"). TXTErrIgnore = !strings.EqualFold(os.Getenv(txtErrIgnoreStr), "false") ) diff --git a/internal/googlecloud/googlecloud.go b/internal/googlecloud/googlecloud.go new file mode 100644 index 00000000000..d6c9e03fc4c --- /dev/null +++ b/internal/googlecloud/googlecloud.go @@ -0,0 +1,128 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// Package googlecloud contains internal helpful functions for google cloud. +package googlecloud + +import ( + "errors" + "fmt" + "io" + "io/ioutil" + "os" + "os/exec" + "regexp" + "runtime" + "strings" + "sync" + + "google.golang.org/grpc/grpclog" + internalgrpclog "google.golang.org/grpc/internal/grpclog" +) + +const ( + linuxProductNameFile = "/sys/class/dmi/id/product_name" + windowsCheckCommand = "powershell.exe" + windowsCheckCommandArgs = "Get-WmiObject -Class Win32_BIOS" + powershellOutputFilter = "Manufacturer" + windowsManufacturerRegex = ":(.*)" + + logPrefix = "[googlecloud]" +) + +var ( + // The following two variables will be reassigned in tests. + runningOS = runtime.GOOS + manufacturerReader = func() (io.Reader, error) { + switch runningOS { + case "linux": + return os.Open(linuxProductNameFile) + case "windows": + cmd := exec.Command(windowsCheckCommand, windowsCheckCommandArgs) + out, err := cmd.Output() + if err != nil { + return nil, err + } + for _, line := range strings.Split(strings.TrimSuffix(string(out), "\n"), "\n") { + if strings.HasPrefix(line, powershellOutputFilter) { + re := regexp.MustCompile(windowsManufacturerRegex) + name := re.FindString(line) + name = strings.TrimLeft(name, ":") + return strings.NewReader(name), nil + } + } + return nil, errors.New("cannot determine the machine's manufacturer") + default: + return nil, fmt.Errorf("%s is not supported", runningOS) + } + } + + vmOnGCEOnce sync.Once + vmOnGCE bool + + logger = internalgrpclog.NewPrefixLogger(grpclog.Component("googlecloud"), logPrefix) +) + +// OnGCE returns whether the client is running on GCE. +// +// It provides similar functionality as metadata.OnGCE from the cloud library +// package. We keep this to avoid depending on the cloud library module. +func OnGCE() bool { + vmOnGCEOnce.Do(func() { + vmOnGCE = isRunningOnGCE() + }) + return vmOnGCE +} + +// isRunningOnGCE checks whether the local system, without doing a network request is +// running on GCP. +func isRunningOnGCE() bool { + manufacturer, err := readManufacturer() + if err != nil { + logger.Infof("failed to read manufacturer %v, returning OnGCE=false", err) + return false + } + name := string(manufacturer) + switch runningOS { + case "linux": + name = strings.TrimSpace(name) + return name == "Google" || name == "Google Compute Engine" + case "windows": + name = strings.Replace(name, " ", "", -1) + name = strings.Replace(name, "\n", "", -1) + name = strings.Replace(name, "\r", "", -1) + return name == "Google" + default: + return false + } +} + +func readManufacturer() ([]byte, error) { + reader, err := manufacturerReader() + if err != nil { + return nil, err + } + if reader == nil { + return nil, errors.New("got nil reader") + } + manufacturer, err := ioutil.ReadAll(reader) + if err != nil { + return nil, fmt.Errorf("failed reading %v: %v", linuxProductNameFile, err) + } + return manufacturer, nil +} diff --git a/internal/googlecloud/googlecloud_test.go b/internal/googlecloud/googlecloud_test.go new file mode 100644 index 00000000000..bd5a42ffab9 --- /dev/null +++ b/internal/googlecloud/googlecloud_test.go @@ -0,0 +1,86 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package googlecloud + +import ( + "io" + "os" + "strings" + "testing" +) + +func setupManufacturerReader(testOS string, reader func() (io.Reader, error)) func() { + tmpOS := runningOS + tmpReader := manufacturerReader + + // Set test OS and reader function. + runningOS = testOS + manufacturerReader = reader + return func() { + runningOS = tmpOS + manufacturerReader = tmpReader + } +} + +func setup(testOS string, testReader io.Reader) func() { + reader := func() (io.Reader, error) { + return testReader, nil + } + return setupManufacturerReader(testOS, reader) +} + +func setupError(testOS string, err error) func() { + reader := func() (io.Reader, error) { + return nil, err + } + return setupManufacturerReader(testOS, reader) +} + +func TestIsRunningOnGCE(t *testing.T) { + for _, tc := range []struct { + description string + testOS string + testReader io.Reader + out bool + }{ + // Linux tests. + {"linux: not a GCP platform", "linux", strings.NewReader("not GCP"), false}, + {"Linux: GCP platform (Google)", "linux", strings.NewReader("Google"), true}, + {"Linux: GCP platform (Google Compute Engine)", "linux", strings.NewReader("Google Compute Engine"), true}, + {"Linux: GCP platform (Google Compute Engine) with extra spaces", "linux", strings.NewReader(" Google Compute Engine "), true}, + // Windows tests. + {"windows: not a GCP platform", "windows", strings.NewReader("not GCP"), false}, + {"windows: GCP platform (Google)", "windows", strings.NewReader("Google"), true}, + {"windows: GCP platform (Google) with extra spaces", "windows", strings.NewReader(" Google "), true}, + } { + reverseFunc := setup(tc.testOS, tc.testReader) + if got, want := isRunningOnGCE(), tc.out; got != want { + t.Errorf("%v: isRunningOnGCE()=%v, want %v", tc.description, got, want) + } + reverseFunc() + } +} + +func TestIsRunningOnGCENoProductNameFile(t *testing.T) { + reverseFunc := setupError("linux", os.ErrNotExist) + if isRunningOnGCE() { + t.Errorf("ErrNotExist: isRunningOnGCE()=true, want false") + } + reverseFunc() +} diff --git a/internal/grpcrand/grpcrand.go b/internal/grpcrand/grpcrand.go index 200b115ca20..740f83c2b76 100644 --- a/internal/grpcrand/grpcrand.go +++ b/internal/grpcrand/grpcrand.go @@ -31,26 +31,37 @@ var ( mu sync.Mutex ) +// Int implements rand.Int on the grpcrand global source. +func Int() int { + mu.Lock() + defer mu.Unlock() + return r.Int() +} + // Int63n implements rand.Int63n on the grpcrand global source. func Int63n(n int64) int64 { mu.Lock() - res := r.Int63n(n) - mu.Unlock() - return res + defer mu.Unlock() + return r.Int63n(n) } // Intn implements rand.Intn on the grpcrand global source. func Intn(n int) int { mu.Lock() - res := r.Intn(n) - mu.Unlock() - return res + defer mu.Unlock() + return r.Intn(n) } // Float64 implements rand.Float64 on the grpcrand global source. func Float64() float64 { mu.Lock() - res := r.Float64() - mu.Unlock() - return res + defer mu.Unlock() + return r.Float64() +} + +// Uint64 implements rand.Uint64 on the grpcrand global source. +func Uint64() uint64 { + mu.Lock() + defer mu.Unlock() + return r.Uint64() } diff --git a/internal/grpctest/tlogger.go b/internal/grpctest/tlogger.go index 95c3598d1d5..bbb2a2ff4fb 100644 --- a/internal/grpctest/tlogger.go +++ b/internal/grpctest/tlogger.go @@ -41,19 +41,34 @@ const callingFrame = 4 type logType int +func (l logType) String() string { + switch l { + case infoLog: + return "INFO" + case warningLog: + return "WARNING" + case errorLog: + return "ERROR" + case fatalLog: + return "FATAL" + } + return "UNKNOWN" +} + const ( - logLog logType = iota + infoLog logType = iota + warningLog errorLog fatalLog ) type tLogger struct { v int - t *testing.T - start time.Time initialized bool - m sync.Mutex // protects errors + mu sync.Mutex // guards t, start, and errors + t *testing.T + start time.Time errors map[*regexp.Regexp]int } @@ -76,12 +91,14 @@ func getCallingPrefix(depth int) (string, error) { // log logs the message with the specified parameters to the tLogger. func (g *tLogger) log(ltype logType, depth int, format string, args ...interface{}) { + g.mu.Lock() + defer g.mu.Unlock() prefix, err := getCallingPrefix(callingFrame + depth) if err != nil { g.t.Error(err) return } - args = append([]interface{}{prefix}, args...) + args = append([]interface{}{ltype.String() + " " + prefix}, args...) args = append(args, fmt.Sprintf(" (t=+%s)", time.Since(g.start))) if format == "" { @@ -119,14 +136,14 @@ func (g *tLogger) log(ltype logType, depth int, format string, args ...interface // Update updates the testing.T that the testing logger logs to. Should be done // before every test. It also initializes the tLogger if it has not already. func (g *tLogger) Update(t *testing.T) { + g.mu.Lock() + defer g.mu.Unlock() if !g.initialized { grpclog.SetLoggerV2(TLogger) g.initialized = true } g.t = t g.start = time.Now() - g.m.Lock() - defer g.m.Unlock() g.errors = map[*regexp.Regexp]int{} } @@ -141,20 +158,20 @@ func (g *tLogger) ExpectError(expr string) { // ExpectErrorN declares an error to be expected n times. func (g *tLogger) ExpectErrorN(expr string, n int) { + g.mu.Lock() + defer g.mu.Unlock() re, err := regexp.Compile(expr) if err != nil { g.t.Error(err) return } - g.m.Lock() - defer g.m.Unlock() g.errors[re] += n } // EndTest checks if expected errors were not encountered. func (g *tLogger) EndTest(t *testing.T) { - g.m.Lock() - defer g.m.Unlock() + g.mu.Lock() + defer g.mu.Unlock() for re, count := range g.errors { if count > 0 { t.Errorf("Expected error '%v' not encountered", re.String()) @@ -165,8 +182,6 @@ func (g *tLogger) EndTest(t *testing.T) { // expected determines if the error string is protected or not. func (g *tLogger) expected(s string) bool { - g.m.Lock() - defer g.m.Unlock() for re, count := range g.errors { if re.FindStringIndex(s) != nil { g.errors[re]-- @@ -180,35 +195,35 @@ func (g *tLogger) expected(s string) bool { } func (g *tLogger) Info(args ...interface{}) { - g.log(logLog, 0, "", args...) + g.log(infoLog, 0, "", args...) } func (g *tLogger) Infoln(args ...interface{}) { - g.log(logLog, 0, "", args...) + g.log(infoLog, 0, "", args...) } func (g *tLogger) Infof(format string, args ...interface{}) { - g.log(logLog, 0, format, args...) + g.log(infoLog, 0, format, args...) } func (g *tLogger) InfoDepth(depth int, args ...interface{}) { - g.log(logLog, depth, "", args...) + g.log(infoLog, depth, "", args...) } func (g *tLogger) Warning(args ...interface{}) { - g.log(logLog, 0, "", args...) + g.log(warningLog, 0, "", args...) } func (g *tLogger) Warningln(args ...interface{}) { - g.log(logLog, 0, "", args...) + g.log(warningLog, 0, "", args...) } func (g *tLogger) Warningf(format string, args ...interface{}) { - g.log(logLog, 0, format, args...) + g.log(warningLog, 0, format, args...) } func (g *tLogger) WarningDepth(depth int, args ...interface{}) { - g.log(logLog, depth, "", args...) + g.log(warningLog, depth, "", args...) } func (g *tLogger) Error(args ...interface{}) { diff --git a/internal/grpcutil/target_test.go b/internal/grpcutil/target_test.go deleted file mode 100644 index f6c586dd080..00000000000 --- a/internal/grpcutil/target_test.go +++ /dev/null @@ -1,114 +0,0 @@ -/* - * - * Copyright 2020 gRPC authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - */ - -package grpcutil - -import ( - "testing" - - "google.golang.org/grpc/resolver" -) - -func TestParseTarget(t *testing.T) { - for _, test := range []resolver.Target{ - {Scheme: "dns", Authority: "", Endpoint: "google.com"}, - {Scheme: "dns", Authority: "a.server.com", Endpoint: "google.com"}, - {Scheme: "dns", Authority: "a.server.com", Endpoint: "google.com/?a=b"}, - {Scheme: "passthrough", Authority: "", Endpoint: "/unix/socket/address"}, - } { - str := test.Scheme + "://" + test.Authority + "/" + test.Endpoint - got := ParseTarget(str, false) - if got != test { - t.Errorf("ParseTarget(%q, false) = %+v, want %+v", str, got, test) - } - got = ParseTarget(str, true) - if got != test { - t.Errorf("ParseTarget(%q, true) = %+v, want %+v", str, got, test) - } - } -} - -func TestParseTargetString(t *testing.T) { - for _, test := range []struct { - targetStr string - want resolver.Target - wantWithDialer resolver.Target - }{ - {targetStr: "", want: resolver.Target{Scheme: "", Authority: "", Endpoint: ""}}, - {targetStr: ":///", want: resolver.Target{Scheme: "", Authority: "", Endpoint: ""}}, - {targetStr: "a:///", want: resolver.Target{Scheme: "a", Authority: "", Endpoint: ""}}, - {targetStr: "://a/", want: resolver.Target{Scheme: "", Authority: "a", Endpoint: ""}}, - {targetStr: ":///a", want: resolver.Target{Scheme: "", Authority: "", Endpoint: "a"}}, - {targetStr: "a://b/", want: resolver.Target{Scheme: "a", Authority: "b", Endpoint: ""}}, - {targetStr: "a:///b", want: resolver.Target{Scheme: "a", Authority: "", Endpoint: "b"}}, - {targetStr: "://a/b", want: resolver.Target{Scheme: "", Authority: "a", Endpoint: "b"}}, - {targetStr: "a://b/c", want: resolver.Target{Scheme: "a", Authority: "b", Endpoint: "c"}}, - {targetStr: "dns:///google.com", want: resolver.Target{Scheme: "dns", Authority: "", Endpoint: "google.com"}}, - {targetStr: "dns://a.server.com/google.com", want: resolver.Target{Scheme: "dns", Authority: "a.server.com", Endpoint: "google.com"}}, - {targetStr: "dns://a.server.com/google.com/?a=b", want: resolver.Target{Scheme: "dns", Authority: "a.server.com", Endpoint: "google.com/?a=b"}}, - - {targetStr: "/", want: resolver.Target{Scheme: "", Authority: "", Endpoint: "/"}}, - {targetStr: "google.com", want: resolver.Target{Scheme: "", Authority: "", Endpoint: "google.com"}}, - {targetStr: "google.com/?a=b", want: resolver.Target{Scheme: "", Authority: "", Endpoint: "google.com/?a=b"}}, - {targetStr: "/unix/socket/address", want: resolver.Target{Scheme: "", Authority: "", Endpoint: "/unix/socket/address"}}, - - // If we can only parse part of the target. - {targetStr: "://", want: resolver.Target{Scheme: "", Authority: "", Endpoint: "://"}}, - {targetStr: "unix://domain", want: resolver.Target{Scheme: "", Authority: "", Endpoint: "unix://domain"}}, - {targetStr: "unix://a/b/c", want: resolver.Target{Scheme: "unix", Authority: "a", Endpoint: "/b/c"}}, - {targetStr: "a:b", want: resolver.Target{Scheme: "", Authority: "", Endpoint: "a:b"}}, - {targetStr: "a/b", want: resolver.Target{Scheme: "", Authority: "", Endpoint: "a/b"}}, - {targetStr: "a:/b", want: resolver.Target{Scheme: "", Authority: "", Endpoint: "a:/b"}}, - {targetStr: "a//b", want: resolver.Target{Scheme: "", Authority: "", Endpoint: "a//b"}}, - {targetStr: "a://b", want: resolver.Target{Scheme: "", Authority: "", Endpoint: "a://b"}}, - - // Unix cases without custom dialer. - // unix:[local_path], unix:[/absolute], and unix://[/absolute] have different - // behaviors with a custom dialer, to prevent behavior changes with custom dialers. - {targetStr: "unix:a/b/c", want: resolver.Target{Scheme: "unix", Authority: "", Endpoint: "a/b/c"}, wantWithDialer: resolver.Target{Scheme: "", Authority: "", Endpoint: "unix:a/b/c"}}, - {targetStr: "unix:/a/b/c", want: resolver.Target{Scheme: "unix", Authority: "", Endpoint: "/a/b/c"}, wantWithDialer: resolver.Target{Scheme: "", Authority: "", Endpoint: "unix:/a/b/c"}}, - {targetStr: "unix:///a/b/c", want: resolver.Target{Scheme: "unix", Authority: "", Endpoint: "/a/b/c"}}, - - {targetStr: "unix-abstract:a/b/c", want: resolver.Target{Scheme: "unix-abstract", Authority: "", Endpoint: "a/b/c"}}, - {targetStr: "unix-abstract:a b", want: resolver.Target{Scheme: "unix-abstract", Authority: "", Endpoint: "a b"}}, - {targetStr: "unix-abstract:a:b", want: resolver.Target{Scheme: "unix-abstract", Authority: "", Endpoint: "a:b"}}, - {targetStr: "unix-abstract:a-b", want: resolver.Target{Scheme: "unix-abstract", Authority: "", Endpoint: "a-b"}}, - {targetStr: "unix-abstract:/ a///://::!@#$%^&*()b", want: resolver.Target{Scheme: "unix-abstract", Authority: "", Endpoint: "/ a///://::!@#$%^&*()b"}}, - {targetStr: "unix-abstract:passthrough:abc", want: resolver.Target{Scheme: "unix-abstract", Authority: "", Endpoint: "passthrough:abc"}}, - {targetStr: "unix-abstract:unix:///abc", want: resolver.Target{Scheme: "unix-abstract", Authority: "", Endpoint: "unix:///abc"}}, - {targetStr: "unix-abstract:///a/b/c", want: resolver.Target{Scheme: "unix-abstract", Authority: "", Endpoint: "/a/b/c"}}, - {targetStr: "unix-abstract://authority/a/b/c", want: resolver.Target{Scheme: "unix-abstract", Authority: "authority", Endpoint: "/a/b/c"}}, - {targetStr: "unix-abstract:///", want: resolver.Target{Scheme: "unix-abstract", Authority: "", Endpoint: "/"}}, - {targetStr: "unix-abstract://authority", want: resolver.Target{Scheme: "unix-abstract", Authority: "", Endpoint: "//authority"}}, - - {targetStr: "passthrough:///unix:///a/b/c", want: resolver.Target{Scheme: "passthrough", Authority: "", Endpoint: "unix:///a/b/c"}}, - } { - got := ParseTarget(test.targetStr, false) - if got != test.want { - t.Errorf("ParseTarget(%q, false) = %+v, want %+v", test.targetStr, got, test.want) - } - wantWithDialer := test.wantWithDialer - if wantWithDialer == (resolver.Target{}) { - wantWithDialer = test.want - } - got = ParseTarget(test.targetStr, true) - if got != wantWithDialer { - t.Errorf("ParseTarget(%q, true) = %+v, want %+v", test.targetStr, got, wantWithDialer) - } - } -} diff --git a/internal/internal.go b/internal/internal.go index 1e2834c70f6..1b596bf3579 100644 --- a/internal/internal.go +++ b/internal/internal.go @@ -38,12 +38,6 @@ var ( // KeepaliveMinPingTime is the minimum ping interval. This must be 10s by // default, but tests may wish to set it lower for convenience. KeepaliveMinPingTime = 10 * time.Second - // NewRequestInfoContext creates a new context based on the argument context attaching - // the passed in RequestInfo to the new context. - NewRequestInfoContext interface{} // func(context.Context, credentials.RequestInfo) context.Context - // NewClientHandshakeInfoContext returns a copy of the input context with - // the passed in ClientHandshakeInfo struct added to it. - NewClientHandshakeInfoContext interface{} // func(context.Context, credentials.ClientHandshakeInfo) context.Context // ParseServiceConfigForTesting is for creating a fake // ClientConn for resolver testing only ParseServiceConfigForTesting interface{} // func(string) *serviceconfig.ParseResult @@ -65,6 +59,11 @@ var ( // gRPC server. An xDS-enabled server needs to know what type of credentials // is configured on the underlying gRPC server. This is set by server.go. GetServerCredentials interface{} // func (*grpc.Server) credentials.TransportCredentials + // DrainServerTransports initiates a graceful close of existing connections + // on a gRPC server accepted on the provided listener address. An + // xDS-enabled server invokes this method on a grpc.Server when a particular + // listener moves to "not-serving" mode. + DrainServerTransports interface{} // func(*grpc.Server, string) ) // HealthChecker defines the signature of the client-side LB channel health checking function. diff --git a/internal/leakcheck/leakcheck.go b/internal/leakcheck/leakcheck.go index 1d4fcef994b..946c575f140 100644 --- a/internal/leakcheck/leakcheck.go +++ b/internal/leakcheck/leakcheck.go @@ -42,7 +42,6 @@ var goroutinesToIgnore = []string{ "runtime_mcall", "(*loggingT).flushDaemon", "goroutine in C code", - "httputil.DumpRequestOut", // TODO: Remove this once Go1.13 support is removed. https://github.com/golang/go/issues/37669. } // RegisterIgnoreGoroutine appends s into the ignore goroutine list. The diff --git a/internal/pretty/pretty.go b/internal/pretty/pretty.go new file mode 100644 index 00000000000..0177af4b511 --- /dev/null +++ b/internal/pretty/pretty.go @@ -0,0 +1,82 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// Package pretty defines helper functions to pretty-print structs for logging. +package pretty + +import ( + "bytes" + "encoding/json" + "fmt" + + "github.com/golang/protobuf/jsonpb" + protov1 "github.com/golang/protobuf/proto" + "google.golang.org/protobuf/encoding/protojson" + protov2 "google.golang.org/protobuf/proto" +) + +const jsonIndent = " " + +// ToJSON marshals the input into a json string. +// +// If marshal fails, it falls back to fmt.Sprintf("%+v"). +func ToJSON(e interface{}) string { + switch ee := e.(type) { + case protov1.Message: + mm := jsonpb.Marshaler{Indent: jsonIndent} + ret, err := mm.MarshalToString(ee) + if err != nil { + // This may fail for proto.Anys, e.g. for xDS v2, LDS, the v2 + // messages are not imported, and this will fail because the message + // is not found. + return fmt.Sprintf("%+v", ee) + } + return ret + case protov2.Message: + mm := protojson.MarshalOptions{ + Multiline: true, + Indent: jsonIndent, + } + ret, err := mm.Marshal(ee) + if err != nil { + // This may fail for proto.Anys, e.g. for xDS v2, LDS, the v2 + // messages are not imported, and this will fail because the message + // is not found. + return fmt.Sprintf("%+v", ee) + } + return string(ret) + default: + ret, err := json.MarshalIndent(ee, "", jsonIndent) + if err != nil { + return fmt.Sprintf("%+v", ee) + } + return string(ret) + } +} + +// FormatJSON formats the input json bytes with indentation. +// +// If Indent fails, it returns the unchanged input as string. +func FormatJSON(b []byte) string { + var out bytes.Buffer + err := json.Indent(&out, b, "", jsonIndent) + if err != nil { + return string(b) + } + return out.String() +} diff --git a/internal/profiling/buffer/buffer.go b/internal/profiling/buffer/buffer.go index 45745cd0919..f4cd4201de1 100644 --- a/internal/profiling/buffer/buffer.go +++ b/internal/profiling/buffer/buffer.go @@ -1,5 +1,3 @@ -// +build !appengine - /* * * Copyright 2019 gRPC authors. diff --git a/internal/profiling/buffer/buffer_appengine.go b/internal/profiling/buffer/buffer_appengine.go deleted file mode 100644 index c92599e5b9c..00000000000 --- a/internal/profiling/buffer/buffer_appengine.go +++ /dev/null @@ -1,43 +0,0 @@ -// +build appengine - -/* - * - * Copyright 2019 gRPC authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - */ - -package buffer - -// CircularBuffer is a no-op implementation for appengine builds. -// -// Appengine does not support stats because of lack of the support for unsafe -// pointers, which are necessary to efficiently store and retrieve things into -// and from a circular buffer. As a result, Push does not do anything and Drain -// returns an empty slice. -type CircularBuffer struct{} - -// NewCircularBuffer returns a no-op for appengine builds. -func NewCircularBuffer(size uint32) (*CircularBuffer, error) { - return nil, nil -} - -// Push returns a no-op for appengine builds. -func (cb *CircularBuffer) Push(x interface{}) { -} - -// Drain returns a no-op for appengine builds. -func (cb *CircularBuffer) Drain() []interface{} { - return nil -} diff --git a/internal/profiling/buffer/buffer_test.go b/internal/profiling/buffer/buffer_test.go index 86bd77d4a2e..a7f3b61e4af 100644 --- a/internal/profiling/buffer/buffer_test.go +++ b/internal/profiling/buffer/buffer_test.go @@ -1,5 +1,3 @@ -// +build !appengine - /* * * Copyright 2019 gRPC authors. diff --git a/internal/profiling/goid_modified.go b/internal/profiling/goid_modified.go index b186499cd0d..ff1a5f5933b 100644 --- a/internal/profiling/goid_modified.go +++ b/internal/profiling/goid_modified.go @@ -1,3 +1,4 @@ +//go:build grpcgoid // +build grpcgoid /* diff --git a/internal/profiling/goid_regular.go b/internal/profiling/goid_regular.go index 891c2e98f9d..042933227d8 100644 --- a/internal/profiling/goid_regular.go +++ b/internal/profiling/goid_regular.go @@ -1,3 +1,4 @@ +//go:build !grpcgoid // +build !grpcgoid /* diff --git a/internal/resolver/config_selector.go b/internal/resolver/config_selector.go index 5e7f36703d4..be7e13d5859 100644 --- a/internal/resolver/config_selector.go +++ b/internal/resolver/config_selector.go @@ -117,9 +117,12 @@ type ClientInterceptor interface { NewStream(ctx context.Context, ri RPCInfo, done func(), newStream func(ctx context.Context, done func()) (ClientStream, error)) (ClientStream, error) } -// ServerInterceptor is unimplementable; do not use. +// ServerInterceptor is an interceptor for incoming RPC's on gRPC server side. type ServerInterceptor interface { - notDefined() + // AllowRPC checks if an incoming RPC is allowed to proceed based on + // information about connection RPC was received on, and HTTP Headers. This + // information will be piped into context. + AllowRPC(ctx context.Context) error // TODO: Make this a real interceptor for filters such as rate limiting. } type csKeyType string diff --git a/internal/resolver/config_selector_test.go b/internal/resolver/config_selector_test.go index e5a50995df1..e1dae8bde27 100644 --- a/internal/resolver/config_selector_test.go +++ b/internal/resolver/config_selector_test.go @@ -48,6 +48,8 @@ func (s) TestSafeConfigSelector(t *testing.T) { retChan1 := make(chan *RPCConfig) retChan2 := make(chan *RPCConfig) + defer close(retChan1) + defer close(retChan2) one := 1 two := 2 @@ -55,8 +57,8 @@ func (s) TestSafeConfigSelector(t *testing.T) { resp1 := &RPCConfig{MethodConfig: serviceconfig.MethodConfig{MaxReqSize: &one}} resp2 := &RPCConfig{MethodConfig: serviceconfig.MethodConfig{MaxReqSize: &two}} - cs1Called := make(chan struct{}) - cs2Called := make(chan struct{}) + cs1Called := make(chan struct{}, 1) + cs2Called := make(chan struct{}, 1) cs1 := &fakeConfigSelector{ selectConfig: func(r RPCInfo) (*RPCConfig, error) { diff --git a/internal/resolver/dns/dns_resolver.go b/internal/resolver/dns/dns_resolver.go index 30423556658..75301c51491 100644 --- a/internal/resolver/dns/dns_resolver.go +++ b/internal/resolver/dns/dns_resolver.go @@ -34,6 +34,7 @@ import ( grpclbstate "google.golang.org/grpc/balancer/grpclb/state" "google.golang.org/grpc/grpclog" + "google.golang.org/grpc/internal/backoff" "google.golang.org/grpc/internal/envconfig" "google.golang.org/grpc/internal/grpcrand" "google.golang.org/grpc/resolver" @@ -46,6 +47,13 @@ var EnableSRVLookups = false var logger = grpclog.Component("dns") +// Globals to stub out in tests. TODO: Perhaps these two can be combined into a +// single variable for testing the resolver? +var ( + newTimer = time.NewTimer + newTimerDNSResRate = time.NewTimer +) + func init() { resolver.Register(NewBuilder()) } @@ -143,7 +151,6 @@ func (b *dnsBuilder) Build(target resolver.Target, cc resolver.ClientConn, opts d.wg.Add(1) go d.watcher() - d.ResolveNow(resolver.ResolveNowOptions{}) return d, nil } @@ -201,28 +208,38 @@ func (d *dnsResolver) Close() { func (d *dnsResolver) watcher() { defer d.wg.Done() + backoffIndex := 1 for { - select { - case <-d.ctx.Done(): - return - case <-d.rn: - } - state, err := d.lookup() if err != nil { + // Report error to the underlying grpc.ClientConn. d.cc.ReportError(err) } else { - d.cc.UpdateState(*state) + err = d.cc.UpdateState(*state) } - // Sleep to prevent excessive re-resolutions. Incoming resolution requests - // will be queued in d.rn. - t := time.NewTimer(minDNSResRate) + var timer *time.Timer + if err == nil { + // Success resolving, wait for the next ResolveNow. However, also wait 30 seconds at the very least + // to prevent constantly re-resolving. + backoffIndex = 1 + timer = newTimerDNSResRate(minDNSResRate) + select { + case <-d.ctx.Done(): + timer.Stop() + return + case <-d.rn: + } + } else { + // Poll on an error found in DNS Resolver or an error received from ClientConn. + timer = newTimer(backoff.DefaultExponential.Backoff(backoffIndex)) + backoffIndex++ + } select { - case <-t.C: case <-d.ctx.Done(): - t.Stop() + timer.Stop() return + case <-timer.C: } } } @@ -260,18 +277,13 @@ func (d *dnsResolver) lookupSRV() ([]resolver.Address, error) { return newAddrs, nil } -var filterError = func(err error) error { +func handleDNSError(err error, lookupType string) error { if dnsErr, ok := err.(*net.DNSError); ok && !dnsErr.IsTimeout && !dnsErr.IsTemporary { // Timeouts and temporary errors should be communicated to gRPC to // attempt another DNS query (with backoff). Other errors should be // suppressed (they may represent the absence of a TXT record). return nil } - return err -} - -func handleDNSError(err error, lookupType string) error { - err = filterError(err) if err != nil { err = fmt.Errorf("dns: %v record lookup error: %v", lookupType, err) logger.Info(err) @@ -306,12 +318,12 @@ func (d *dnsResolver) lookupTXT() *serviceconfig.ParseResult { } func (d *dnsResolver) lookupHost() ([]resolver.Address, error) { - var newAddrs []resolver.Address addrs, err := d.resolver.LookupHost(d.ctx, d.host) if err != nil { err = handleDNSError(err, "A") return nil, err } + newAddrs := make([]resolver.Address, 0, len(addrs)) for _, a := range addrs { ip, ok := formatIP(a) if !ok { diff --git a/internal/resolver/dns/dns_resolver_test.go b/internal/resolver/dns/dns_resolver_test.go index 1c8469a275a..69307a981cf 100644 --- a/internal/resolver/dns/dns_resolver_test.go +++ b/internal/resolver/dns/dns_resolver_test.go @@ -30,9 +30,13 @@ import ( "testing" "time" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "google.golang.org/grpc/balancer" grpclbstate "google.golang.org/grpc/balancer/grpclb/state" "google.golang.org/grpc/internal/envconfig" "google.golang.org/grpc/internal/leakcheck" + "google.golang.org/grpc/internal/testutils" "google.golang.org/grpc/resolver" "google.golang.org/grpc/serviceconfig" ) @@ -41,13 +45,15 @@ func TestMain(m *testing.M) { // Set a non-zero duration only for tests which are actually testing that // feature. replaceDNSResRate(time.Duration(0)) // No nead to clean up since we os.Exit - replaceNetFunc(nil) // No nead to clean up since we os.Exit + overrideDefaultResolver(false) // No nead to clean up since we os.Exit code := m.Run() os.Exit(code) } const ( - txtBytesLimit = 255 + txtBytesLimit = 255 + defaultTestTimeout = 10 * time.Second + defaultTestShortTimeout = 10 * time.Millisecond ) type testClientConn struct { @@ -57,13 +63,17 @@ type testClientConn struct { state resolver.State updateStateCalls int errChan chan error + updateStateErr error } -func (t *testClientConn) UpdateState(s resolver.State) { +func (t *testClientConn) UpdateState(s resolver.State) error { t.m1.Lock() defer t.m1.Unlock() t.state = s t.updateStateCalls++ + // This error determines whether DNS Resolver actually decides to exponentially backoff or not. + // This can be any error. + return t.updateStateErr } func (t *testClientConn) getState() (resolver.State, int) { @@ -99,12 +109,12 @@ type testResolver struct { // A write to this channel is made when this resolver receives a resolution // request. Tests can rely on reading from this channel to be notified about // resolution requests instead of sleeping for a predefined period of time. - ch chan struct{} + lookupHostCh *testutils.Channel } func (tr *testResolver) LookupHost(ctx context.Context, host string) ([]string, error) { - if tr.ch != nil { - tr.ch <- struct{}{} + if tr.lookupHostCh != nil { + tr.lookupHostCh.Send(nil) } return hostLookup(host) } @@ -117,9 +127,17 @@ func (*testResolver) LookupTXT(ctx context.Context, host string) ([]string, erro return txtLookup(host) } -func replaceNetFunc(ch chan struct{}) func() { +// overrideDefaultResolver overrides the defaultResolver used by the code with +// an instance of the testResolver. pushOnLookup controls whether the +// testResolver created here pushes lookupHost events on its channel. +func overrideDefaultResolver(pushOnLookup bool) func() { oldResolver := defaultResolver - defaultResolver = &testResolver{ch: ch} + + var lookupHostCh *testutils.Channel + if pushOnLookup { + lookupHostCh = testutils.NewChannel() + } + defaultResolver = &testResolver{lookupHostCh: lookupHostCh} return func() { defaultResolver = oldResolver @@ -669,6 +687,13 @@ func TestResolve(t *testing.T) { func testDNSResolver(t *testing.T) { defer leakcheck.Check(t) + defer func(nt func(d time.Duration) *time.Timer) { + newTimer = nt + }(newTimer) + newTimer = func(_ time.Duration) *time.Timer { + // Will never fire on its own, will protect from triggering exponential backoff. + return time.NewTimer(time.Hour) + } tests := []struct { target string addrWant []resolver.Address @@ -725,7 +750,7 @@ func testDNSResolver(t *testing.T) { if cnt == 0 { t.Fatalf("UpdateState not called after 2s; aborting") } - if !reflect.DeepEqual(a.addrWant, state.Addresses) { + if !cmp.Equal(a.addrWant, state.Addresses, cmpopts.EquateEmpty()) { t.Errorf("Resolved addresses of target: %q = %+v, want %+v", a.target, state.Addresses, a.addrWant) } sc := scFromState(state) @@ -736,12 +761,151 @@ func testDNSResolver(t *testing.T) { } } +// DNS Resolver immediately starts polling on an error from grpc. This should continue until the ClientConn doesn't +// send back an error from updating the DNS Resolver's state. +func TestDNSResolverExponentialBackoff(t *testing.T) { + defer leakcheck.Check(t) + defer func(nt func(d time.Duration) *time.Timer) { + newTimer = nt + }(newTimer) + timerChan := testutils.NewChannel() + newTimer = func(d time.Duration) *time.Timer { + // Will never fire on its own, allows this test to call timer immediately. + t := time.NewTimer(time.Hour) + timerChan.Send(t) + return t + } + tests := []struct { + name string + target string + addrWant []resolver.Address + scWant string + }{ + { + "happy case default port", + "foo.bar.com", + []resolver.Address{{Addr: "1.2.3.4" + colonDefaultPort}, {Addr: "5.6.7.8" + colonDefaultPort}}, + generateSC("foo.bar.com"), + }, + { + "happy case specified port", + "foo.bar.com:1234", + []resolver.Address{{Addr: "1.2.3.4:1234"}, {Addr: "5.6.7.8:1234"}}, + generateSC("foo.bar.com"), + }, + { + "happy case another default port", + "srv.ipv4.single.fake", + []resolver.Address{{Addr: "2.4.6.8" + colonDefaultPort}}, + generateSC("srv.ipv4.single.fake"), + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + b := NewBuilder() + cc := &testClientConn{target: test.target} + // Cause ClientConn to return an error. + cc.updateStateErr = balancer.ErrBadResolverState + r, err := b.Build(resolver.Target{Endpoint: test.target}, cc, resolver.BuildOptions{}) + if err != nil { + t.Fatalf("Error building resolver for target %v: %v", test.target, err) + } + var state resolver.State + var cnt int + for i := 0; i < 2000; i++ { + state, cnt = cc.getState() + if cnt > 0 { + break + } + time.Sleep(time.Millisecond) + } + if cnt == 0 { + t.Fatalf("UpdateState not called after 2s; aborting") + } + if !reflect.DeepEqual(test.addrWant, state.Addresses) { + t.Errorf("Resolved addresses of target: %q = %+v, want %+v", test.target, state.Addresses, test.addrWant) + } + sc := scFromState(state) + if test.scWant != sc { + t.Errorf("Resolved service config of target: %q = %+v, want %+v", test.target, sc, test.scWant) + } + ctx, ctxCancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer ctxCancel() + // Cause timer to go off 10 times, and see if it calls updateState() correctly. + for i := 0; i < 10; i++ { + timer, err := timerChan.Receive(ctx) + if err != nil { + t.Fatalf("Error receiving timer from mock NewTimer call: %v", err) + } + timerPointer := timer.(*time.Timer) + timerPointer.Reset(0) + } + // Poll to see if DNS Resolver updated state the correct number of times, which allows time for the DNS Resolver to call + // ClientConn update state. + deadline := time.Now().Add(defaultTestTimeout) + for { + cc.m1.Lock() + got := cc.updateStateCalls + cc.m1.Unlock() + if got == 11 { + break + } + + if time.Now().After(deadline) { + t.Fatalf("Exponential backoff is not working as expected - should update state 11 times instead of %d", got) + } + + time.Sleep(time.Millisecond) + } + + // Update resolver.ClientConn to not return an error anymore - this should stop it from backing off. + cc.updateStateErr = nil + timer, err := timerChan.Receive(ctx) + if err != nil { + t.Fatalf("Error receiving timer from mock NewTimer call: %v", err) + } + timerPointer := timer.(*time.Timer) + timerPointer.Reset(0) + // Poll to see if DNS Resolver updated state the correct number of times, which allows time for the DNS Resolver to call + // ClientConn update state the final time. The DNS Resolver should then stop polling. + deadline = time.Now().Add(defaultTestTimeout) + for { + cc.m1.Lock() + got := cc.updateStateCalls + cc.m1.Unlock() + if got == 12 { + break + } + + if time.Now().After(deadline) { + t.Fatalf("Exponential backoff is not working as expected - should stop backing off at 12 total UpdateState calls instead of %d", got) + } + + _, err := timerChan.ReceiveOrFail() + if err { + t.Fatalf("Should not poll again after Client Conn stops returning error.") + } + + time.Sleep(time.Millisecond) + } + r.Close() + }) + } +} + func testDNSResolverWithSRV(t *testing.T) { EnableSRVLookups = true defer func() { EnableSRVLookups = false }() defer leakcheck.Check(t) + defer func(nt func(d time.Duration) *time.Timer) { + newTimer = nt + }(newTimer) + newTimer = func(_ time.Duration) *time.Timer { + // Will never fire on its own, will protect from triggering exponential backoff. + return time.NewTimer(time.Hour) + } tests := []struct { target string addrWant []resolver.Address @@ -814,7 +978,7 @@ func testDNSResolverWithSRV(t *testing.T) { if cnt == 0 { t.Fatalf("UpdateState not called after 2s; aborting") } - if !reflect.DeepEqual(a.addrWant, state.Addresses) { + if !cmp.Equal(a.addrWant, state.Addresses, cmpopts.EquateEmpty()) { t.Errorf("Resolved addresses of target: %q = %+v, want %+v", a.target, state.Addresses, a.addrWant) } gs := grpclbstate.Get(state) @@ -855,6 +1019,13 @@ func mutateTbl(target string) func() { func testDNSResolveNow(t *testing.T) { defer leakcheck.Check(t) + defer func(nt func(d time.Duration) *time.Timer) { + newTimer = nt + }(newTimer) + newTimer = func(_ time.Duration) *time.Timer { + // Will never fire on its own, will protect from triggering exponential backoff. + return time.NewTimer(time.Hour) + } tests := []struct { target string addrWant []resolver.Address @@ -926,6 +1097,13 @@ const colonDefaultPort = ":" + defaultPort func testIPResolver(t *testing.T) { defer leakcheck.Check(t) + defer func(nt func(d time.Duration) *time.Timer) { + newTimer = nt + }(newTimer) + newTimer = func(_ time.Duration) *time.Timer { + // Will never fire on its own, will protect from triggering exponential backoff. + return time.NewTimer(time.Hour) + } tests := []struct { target string want []resolver.Address @@ -975,6 +1153,13 @@ func testIPResolver(t *testing.T) { func TestResolveFunc(t *testing.T) { defer leakcheck.Check(t) + defer func(nt func(d time.Duration) *time.Timer) { + newTimer = nt + }(newTimer) + newTimer = func(d time.Duration) *time.Timer { + // Will never fire on its own, will protect from triggering exponential backoff. + return time.NewTimer(time.Hour) + } tests := []struct { addr string want error @@ -1013,6 +1198,13 @@ func TestResolveFunc(t *testing.T) { func TestDisableServiceConfig(t *testing.T) { defer leakcheck.Check(t) + defer func(nt func(d time.Duration) *time.Timer) { + newTimer = nt + }(newTimer) + newTimer = func(d time.Duration) *time.Timer { + // Will never fire on its own, will protect from triggering exponential backoff. + return time.NewTimer(time.Hour) + } tests := []struct { target string scWant string @@ -1059,6 +1251,13 @@ func TestDisableServiceConfig(t *testing.T) { func TestTXTError(t *testing.T) { defer leakcheck.Check(t) + defer func(nt func(d time.Duration) *time.Timer) { + newTimer = nt + }(newTimer) + newTimer = func(d time.Duration) *time.Timer { + // Will never fire on its own, will protect from triggering exponential backoff. + return time.NewTimer(time.Hour) + } defer func(v bool) { envconfig.TXTErrIgnore = v }(envconfig.TXTErrIgnore) for _, ignore := range []bool{false, true} { envconfig.TXTErrIgnore = ignore @@ -1090,6 +1289,13 @@ func TestTXTError(t *testing.T) { } func TestDNSResolverRetry(t *testing.T) { + defer func(nt func(d time.Duration) *time.Timer) { + newTimer = nt + }(newTimer) + newTimer = func(d time.Duration) *time.Timer { + // Will never fire on its own, will protect from triggering exponential backoff. + return time.NewTimer(time.Hour) + } b := NewBuilder() target := "ipv4.single.fake" cc := &testClientConn{target: target} @@ -1144,6 +1350,13 @@ func TestDNSResolverRetry(t *testing.T) { func TestCustomAuthority(t *testing.T) { defer leakcheck.Check(t) + defer func(nt func(d time.Duration) *time.Timer) { + newTimer = nt + }(newTimer) + newTimer = func(d time.Duration) *time.Timer { + // Will never fire on its own, will protect from triggering exponential backoff. + return time.NewTimer(time.Hour) + } tests := []struct { authority string @@ -1249,16 +1462,33 @@ func TestCustomAuthority(t *testing.T) { // requests. It sets the re-resolution rate to a small value and repeatedly // calls ResolveNow() and ensures only the expected number of resolution // requests are made. + func TestRateLimitedResolve(t *testing.T) { defer leakcheck.Check(t) - - const dnsResRate = 10 * time.Millisecond - dc := replaceDNSResRate(dnsResRate) - defer dc() + defer func(nt func(d time.Duration) *time.Timer) { + newTimer = nt + }(newTimer) + newTimer = func(d time.Duration) *time.Timer { + // Will never fire on its own, will protect from triggering exponential + // backoff. + return time.NewTimer(time.Hour) + } + defer func(nt func(d time.Duration) *time.Timer) { + newTimerDNSResRate = nt + }(newTimerDNSResRate) + + timerChan := testutils.NewChannel() + newTimerDNSResRate = func(d time.Duration) *time.Timer { + // Will never fire on its own, allows this test to call timer + // immediately. + t := time.NewTimer(time.Hour) + timerChan.Send(t) + return t + } // Create a new testResolver{} for this test because we want the exact count // of the number of times the resolver was invoked. - nc := replaceNetFunc(make(chan struct{})) + nc := overrideDefaultResolver(true) defer nc() target := "foo.bar.com" @@ -1281,55 +1511,65 @@ func TestRateLimitedResolve(t *testing.T) { t.Fatalf("delegate resolver returned unexpected type: %T\n", tr) } - // Observe the time before unblocking the lookupHost call. The 100ms rate - // limiting timer will begin immediately after that. This means the next - // resolution could happen less than 100ms if we read the time *after* - // receiving from tr.ch - start := time.Now() + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() // Wait for the first resolution request to be done. This happens as part - // of the first iteration of the for loop in watcher() because we call - // ResolveNow in Build. - <-tr.ch - - // Here we start a couple of goroutines. One repeatedly calls ResolveNow() - // until asked to stop, and the other waits for two resolution requests to be - // made to our testResolver and stops the former. We measure the start and - // end times, and expect the duration elapsed to be in the interval - // {wantCalls*dnsResRate, wantCalls*dnsResRate} - done := make(chan struct{}) - go func() { - for { - select { - case <-done: - return - default: - r.ResolveNow(resolver.ResolveNowOptions{}) - time.Sleep(1 * time.Millisecond) - } - } - }() + // of the first iteration of the for loop in watcher(). + if _, err := tr.lookupHostCh.Receive(ctx); err != nil { + t.Fatalf("Timed out waiting for lookup() call.") + } - gotCalls := 0 - const wantCalls = 3 - min, max := wantCalls*dnsResRate, (wantCalls+1)*dnsResRate - tMax := time.NewTimer(max) - for gotCalls != wantCalls { - select { - case <-tr.ch: - gotCalls++ - case <-tMax.C: - t.Fatalf("Timed out waiting for %v calls after %v; got %v", wantCalls, max, gotCalls) - } + // Call Resolve Now 100 times, shouldn't continue onto next iteration of + // watcher, thus shouldn't lookup again. + for i := 0; i <= 100; i++ { + r.ResolveNow(resolver.ResolveNowOptions{}) } - close(done) - elapsed := time.Since(start) - if gotCalls != wantCalls { - t.Fatalf("resolve count mismatch for target: %q = %+v, want %+v\n", target, gotCalls, wantCalls) + continueCtx, continueCancel := context.WithTimeout(context.Background(), defaultTestShortTimeout) + defer continueCancel() + + if _, err := tr.lookupHostCh.Receive(continueCtx); err == nil { + t.Fatalf("Should not have looked up again as DNS Min Res Rate timer has not gone off.") } - if elapsed < min { - t.Fatalf("elapsed time: %v, wanted it to be between {%v and %v}", elapsed, min, max) + + // Make the DNSMinResRate timer fire immediately (by receiving it, then + // resetting to 0), this will unblock the resolver which is currently + // blocked on the DNS Min Res Rate timer going off, which will allow it to + // continue to the next iteration of the watcher loop. + timer, err := timerChan.Receive(ctx) + if err != nil { + t.Fatalf("Error receiving timer from mock NewTimer call: %v", err) + } + timerPointer := timer.(*time.Timer) + timerPointer.Reset(0) + + // Now that DNS Min Res Rate timer has gone off, it should lookup again. + if _, err := tr.lookupHostCh.Receive(ctx); err != nil { + t.Fatalf("Timed out waiting for lookup() call.") + } + + // Resolve Now 1000 more times, shouldn't lookup again as DNS Min Res Rate + // timer has not gone off. + for i := 0; i < 1000; i++ { + r.ResolveNow(resolver.ResolveNowOptions{}) + } + + if _, err = tr.lookupHostCh.Receive(continueCtx); err == nil { + t.Fatalf("Should not have looked up again as DNS Min Res Rate timer has not gone off.") + } + + // Make the DNSMinResRate timer fire immediately again. + timer, err = timerChan.Receive(ctx) + if err != nil { + t.Fatalf("Error receiving timer from mock NewTimer call: %v", err) + } + timerPointer = timer.(*time.Timer) + timerPointer.Reset(0) + + // Now that DNS Min Res Rate timer has gone off, it should lookup again. + if _, err = tr.lookupHostCh.Receive(ctx); err != nil { + t.Fatalf("Timed out waiting for lookup() call.") } wantAddrs := []resolver.Address{{Addr: "1.2.3.4" + colonDefaultPort}, {Addr: "5.6.7.8" + colonDefaultPort}} @@ -1347,21 +1587,66 @@ func TestRateLimitedResolve(t *testing.T) { } } +// DNS Resolver immediately starts polling on an error. This will cause the re-resolution to return another error. +// Thus, test that it constantly sends errors to the grpc.ClientConn. func TestReportError(t *testing.T) { const target = "notfoundaddress" + defer func(nt func(d time.Duration) *time.Timer) { + newTimer = nt + }(newTimer) + timerChan := testutils.NewChannel() + newTimer = func(d time.Duration) *time.Timer { + // Will never fire on its own, allows this test to call timer immediately. + t := time.NewTimer(time.Hour) + timerChan.Send(t) + return t + } cc := &testClientConn{target: target, errChan: make(chan error)} + totalTimesCalledError := 0 b := NewBuilder() r, err := b.Build(resolver.Target{Endpoint: target}, cc, resolver.BuildOptions{}) if err != nil { - t.Fatalf("%v\n", err) + t.Fatalf("Error building resolver for target %v: %v", target, err) + } + // Should receive first error. + err = <-cc.errChan + if !strings.Contains(err.Error(), "hostLookup error") { + t.Fatalf(`ReportError(err=%v) called; want err contains "hostLookupError"`, err) } + totalTimesCalledError++ + ctx, ctxCancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer ctxCancel() + timer, err := timerChan.Receive(ctx) + if err != nil { + t.Fatalf("Error receiving timer from mock NewTimer call: %v", err) + } + timerPointer := timer.(*time.Timer) + timerPointer.Reset(0) defer r.Close() - select { - case err := <-cc.errChan: + + // Cause timer to go off 10 times, and see if it matches DNS Resolver updating Error. + for i := 0; i < 10; i++ { + // Should call ReportError(). + err = <-cc.errChan if !strings.Contains(err.Error(), "hostLookup error") { t.Fatalf(`ReportError(err=%v) called; want err contains "hostLookupError"`, err) } - case <-time.After(time.Second): - t.Fatalf("did not receive error after 1s") + totalTimesCalledError++ + timer, err := timerChan.Receive(ctx) + if err != nil { + t.Fatalf("Error receiving timer from mock NewTimer call: %v", err) + } + timerPointer := timer.(*time.Timer) + timerPointer.Reset(0) + } + + if totalTimesCalledError != 11 { + t.Errorf("ReportError() not called 11 times, instead called %d times.", totalTimesCalledError) + } + // Clean up final watcher iteration. + <-cc.errChan + _, err = timerChan.Receive(ctx) + if err != nil { + t.Fatalf("Error receiving timer from mock NewTimer call: %v", err) } } diff --git a/internal/resolver/dns/go113.go b/internal/resolver/dns/go113.go deleted file mode 100644 index 8783a8cf821..00000000000 --- a/internal/resolver/dns/go113.go +++ /dev/null @@ -1,33 +0,0 @@ -// +build go1.13 - -/* - * - * Copyright 2019 gRPC authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - */ - -package dns - -import "net" - -func init() { - filterError = func(err error) error { - if dnsErr, ok := err.(*net.DNSError); ok && dnsErr.IsNotFound { - // The name does not exist; not an error. - return nil - } - return err - } -} diff --git a/internal/serviceconfig/serviceconfig.go b/internal/serviceconfig/serviceconfig.go index bd4b8875f1a..badbdbf597f 100644 --- a/internal/serviceconfig/serviceconfig.go +++ b/internal/serviceconfig/serviceconfig.go @@ -46,6 +46,22 @@ type BalancerConfig struct { type intermediateBalancerConfig []map[string]json.RawMessage +// MarshalJSON implements the json.Marshaler interface. +// +// It marshals the balancer and config into a length-1 slice +// ([]map[string]config). +func (bc *BalancerConfig) MarshalJSON() ([]byte, error) { + if bc.Config == nil { + // If config is nil, return empty config `{}`. + return []byte(fmt.Sprintf(`[{%q: %v}]`, bc.Name, "{}")), nil + } + c, err := json.Marshal(bc.Config) + if err != nil { + return nil, err + } + return []byte(fmt.Sprintf(`[{%q: %s}]`, bc.Name, c)), nil +} + // UnmarshalJSON implements the json.Unmarshaler interface. // // ServiceConfig contains a list of loadBalancingConfigs, each with a name and @@ -62,6 +78,7 @@ func (bc *BalancerConfig) UnmarshalJSON(b []byte) error { return err } + var names []string for i, lbcfg := range ir { if len(lbcfg) != 1 { return fmt.Errorf("invalid loadBalancingConfig: entry %v does not contain exactly 1 policy/config pair: %q", i, lbcfg) @@ -76,6 +93,7 @@ func (bc *BalancerConfig) UnmarshalJSON(b []byte) error { for name, jsonCfg = range lbcfg { } + names = append(names, name) builder := balancer.Get(name) if builder == nil { // If the balancer is not registered, move on to the next config. @@ -104,7 +122,7 @@ func (bc *BalancerConfig) UnmarshalJSON(b []byte) error { // return. This means we had a loadBalancingConfig slice but did not // encounter a registered policy. The config is considered invalid in this // case. - return fmt.Errorf("invalid loadBalancingConfig: no supported policies found") + return fmt.Errorf("invalid loadBalancingConfig: no supported policies found in %v", names) } // MethodConfig defines the configuration recommended by the service providers for a diff --git a/internal/serviceconfig/serviceconfig_test.go b/internal/serviceconfig/serviceconfig_test.go index b8abaae027e..770ee2efeb8 100644 --- a/internal/serviceconfig/serviceconfig_test.go +++ b/internal/serviceconfig/serviceconfig_test.go @@ -29,16 +29,18 @@ import ( ) type testBalancerConfigType struct { - externalserviceconfig.LoadBalancingConfig + externalserviceconfig.LoadBalancingConfig `json:"-"` + + Check bool `json:"check"` } -var testBalancerConfig = testBalancerConfigType{} +var testBalancerConfig = testBalancerConfigType{Check: true} const ( testBalancerBuilderName = "test-bb" testBalancerBuilderNotParserName = "test-bb-not-parser" - testBalancerConfigJSON = `{"test-balancer-config":"true"}` + testBalancerConfigJSON = `{"check":true}` ) type testBalancerBuilder struct { @@ -133,3 +135,48 @@ func TestBalancerConfigUnmarshalJSON(t *testing.T) { }) } } + +func TestBalancerConfigMarshalJSON(t *testing.T) { + tests := []struct { + name string + bc BalancerConfig + wantJSON string + }{ + { + name: "OK", + bc: BalancerConfig{ + Name: testBalancerBuilderName, + Config: testBalancerConfig, + }, + wantJSON: fmt.Sprintf(`[{"test-bb": {"check":true}}]`), + }, + { + name: "OK config is nil", + bc: BalancerConfig{ + Name: testBalancerBuilderNotParserName, + Config: nil, // nil should be marshalled to an empty config "{}". + }, + wantJSON: fmt.Sprintf(`[{"test-bb-not-parser": {}}]`), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + b, err := tt.bc.MarshalJSON() + if err != nil { + t.Fatalf("failed to marshal: %v", err) + } + + if str := string(b); str != tt.wantJSON { + t.Fatalf("got str %q, want %q", str, tt.wantJSON) + } + + var bc BalancerConfig + if err := bc.UnmarshalJSON(b); err != nil { + t.Errorf("failed to mnmarshal: %v", err) + } + if !cmp.Equal(bc, tt.bc) { + t.Errorf("diff: %v", cmp.Diff(bc, tt.bc)) + } + }) + } +} diff --git a/internal/status/status.go b/internal/status/status.go index 710223b8ded..e5c6513edd1 100644 --- a/internal/status/status.go +++ b/internal/status/status.go @@ -97,7 +97,7 @@ func (s *Status) Err() error { if s.Code() == codes.OK { return nil } - return &Error{e: s.Proto()} + return &Error{s: s} } // WithDetails returns a new status with the provided details messages appended to the status. @@ -136,19 +136,23 @@ func (s *Status) Details() []interface{} { return details } +func (s *Status) String() string { + return fmt.Sprintf("rpc error: code = %s desc = %s", s.Code(), s.Message()) +} + // Error wraps a pointer of a status proto. It implements error and Status, // and a nil *Error should never be returned by this package. type Error struct { - e *spb.Status + s *Status } func (e *Error) Error() string { - return fmt.Sprintf("rpc error: code = %s desc = %s", codes.Code(e.e.GetCode()), e.e.GetMessage()) + return e.s.String() } // GRPCStatus returns the Status represented by se. func (e *Error) GRPCStatus() *Status { - return FromProto(e.e) + return e.s } // Is implements future error.Is functionality. @@ -158,5 +162,5 @@ func (e *Error) Is(target error) bool { if !ok { return false } - return proto.Equal(e.e, tse.e) + return proto.Equal(e.s.s, tse.s.s) } diff --git a/internal/syscall/syscall_linux.go b/internal/syscall/syscall_linux.go index 4b2964f2a1e..b3a72276dee 100644 --- a/internal/syscall/syscall_linux.go +++ b/internal/syscall/syscall_linux.go @@ -1,5 +1,3 @@ -// +build !appengine - /* * * Copyright 2018 gRPC authors. diff --git a/internal/syscall/syscall_nonlinux.go b/internal/syscall/syscall_nonlinux.go index 7913ef1dbfb..999f52cd75b 100644 --- a/internal/syscall/syscall_nonlinux.go +++ b/internal/syscall/syscall_nonlinux.go @@ -1,4 +1,5 @@ -// +build !linux appengine +//go:build !linux +// +build !linux /* * @@ -35,41 +36,41 @@ var logger = grpclog.Component("core") func log() { once.Do(func() { - logger.Info("CPU time info is unavailable on non-linux or appengine environment.") + logger.Info("CPU time info is unavailable on non-linux environments.") }) } -// GetCPUTime returns the how much CPU time has passed since the start of this process. -// It always returns 0 under non-linux or appengine environment. +// GetCPUTime returns the how much CPU time has passed since the start of this +// process. It always returns 0 under non-linux environments. func GetCPUTime() int64 { log() return 0 } -// Rusage is an empty struct under non-linux or appengine environment. +// Rusage is an empty struct under non-linux environments. type Rusage struct{} -// GetRusage is a no-op function under non-linux or appengine environment. +// GetRusage is a no-op function under non-linux environments. func GetRusage() *Rusage { log() return nil } // CPUTimeDiff returns the differences of user CPU time and system CPU time used -// between two Rusage structs. It a no-op function for non-linux or appengine environment. +// between two Rusage structs. It a no-op function for non-linux environments. func CPUTimeDiff(first *Rusage, latest *Rusage) (float64, float64) { log() return 0, 0 } -// SetTCPUserTimeout is a no-op function under non-linux or appengine environments +// SetTCPUserTimeout is a no-op function under non-linux environments. func SetTCPUserTimeout(conn net.Conn, timeout time.Duration) error { log() return nil } -// GetTCPUserTimeout is a no-op function under non-linux or appengine environments -// a negative return value indicates the operation is not supported +// GetTCPUserTimeout is a no-op function under non-linux environments. +// A negative return value indicates the operation is not supported func GetTCPUserTimeout(conn net.Conn) (int, error) { log() return -1, nil diff --git a/credentials/tls/certprovider/meshca/logging.go b/internal/testutils/marshal_any.go similarity index 55% rename from credentials/tls/certprovider/meshca/logging.go rename to internal/testutils/marshal_any.go index ae20059c4f7..9ddef6de15d 100644 --- a/credentials/tls/certprovider/meshca/logging.go +++ b/internal/testutils/marshal_any.go @@ -1,8 +1,6 @@ -// +build go1.13 - /* * - * Copyright 2020 gRPC authors. + * Copyright 2021 gRPC authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,22 +13,24 @@ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. - * */ -package meshca +package testutils import ( "fmt" - "google.golang.org/grpc/grpclog" - internalgrpclog "google.golang.org/grpc/internal/grpclog" + "github.com/golang/protobuf/proto" + "github.com/golang/protobuf/ptypes" + "google.golang.org/protobuf/types/known/anypb" ) -const prefix = "[%p] " - -var logger = grpclog.Component("meshca") - -func prefixLogger(p *providerPlugin) *internalgrpclog.PrefixLogger { - return internalgrpclog.NewPrefixLogger(logger, fmt.Sprintf(prefix, p)) +// MarshalAny is a convenience function to marshal protobuf messages into any +// protos. It will panic if the marshaling fails. +func MarshalAny(m proto.Message) *anypb.Any { + a, err := ptypes.MarshalAny(m) + if err != nil { + panic(fmt.Sprintf("ptypes.MarshalAny(%+v) failed: %v", m, err)) + } + return a } diff --git a/internal/transport/controlbuf.go b/internal/transport/controlbuf.go index 40ef23923fd..8394d252df0 100644 --- a/internal/transport/controlbuf.go +++ b/internal/transport/controlbuf.go @@ -20,13 +20,17 @@ package transport import ( "bytes" + "errors" "fmt" "runtime" + "strconv" "sync" "sync/atomic" "golang.org/x/net/http2" "golang.org/x/net/http2/hpack" + "google.golang.org/grpc/internal/grpcutil" + "google.golang.org/grpc/status" ) var updateHeaderTblSize = func(e *hpack.Encoder, v uint32) { @@ -128,6 +132,15 @@ type cleanupStream struct { func (c *cleanupStream) isTransportResponseFrame() bool { return c.rst } // Results in a RST_STREAM +type earlyAbortStream struct { + httpStatus uint32 + streamID uint32 + contentSubtype string + status *status.Status +} + +func (*earlyAbortStream) isTransportResponseFrame() bool { return false } + type dataFrame struct { streamID uint32 endStream bool @@ -284,7 +297,7 @@ type controlBuffer struct { // closed and nilled when transportResponseFrames drops below the // threshold. Both fields are protected by mu. transportResponseFrames int - trfChan atomic.Value // *chan struct{} + trfChan atomic.Value // chan struct{} } func newControlBuffer(done <-chan struct{}) *controlBuffer { @@ -298,10 +311,10 @@ func newControlBuffer(done <-chan struct{}) *controlBuffer { // throttle blocks if there are too many incomingSettings/cleanupStreams in the // controlbuf. func (c *controlBuffer) throttle() { - ch, _ := c.trfChan.Load().(*chan struct{}) + ch, _ := c.trfChan.Load().(chan struct{}) if ch != nil { select { - case <-*ch: + case <-ch: case <-c.done: } } @@ -335,8 +348,7 @@ func (c *controlBuffer) executeAndPut(f func(it interface{}) bool, it cbItem) (b if c.transportResponseFrames == maxQueuedTransportResponseFrames { // We are adding the frame that puts us over the threshold; create // a throttling channel. - ch := make(chan struct{}) - c.trfChan.Store(&ch) + c.trfChan.Store(make(chan struct{})) } } c.mu.Unlock() @@ -377,9 +389,9 @@ func (c *controlBuffer) get(block bool) (interface{}, error) { if c.transportResponseFrames == maxQueuedTransportResponseFrames { // We are removing the frame that put us over the // threshold; close and clear the throttling channel. - ch := c.trfChan.Load().(*chan struct{}) - close(*ch) - c.trfChan.Store((*chan struct{})(nil)) + ch := c.trfChan.Load().(chan struct{}) + close(ch) + c.trfChan.Store((chan struct{})(nil)) } c.transportResponseFrames-- } @@ -395,7 +407,6 @@ func (c *controlBuffer) get(block bool) (interface{}, error) { select { case <-c.ch: case <-c.done: - c.finish() return nil, ErrConnClosing } } @@ -420,6 +431,14 @@ func (c *controlBuffer) finish() { hdr.onOrphaned(ErrConnClosing) } } + // In case throttle() is currently in flight, it needs to be unblocked. + // Otherwise, the transport may not close, since the transport is closed by + // the reader encountering the connection error. + ch, _ := c.trfChan.Load().(chan struct{}) + if ch != nil { + close(ch) + } + c.trfChan.Store((chan struct{})(nil)) c.mu.Unlock() } @@ -749,6 +768,27 @@ func (l *loopyWriter) cleanupStreamHandler(c *cleanupStream) error { return nil } +func (l *loopyWriter) earlyAbortStreamHandler(eas *earlyAbortStream) error { + if l.side == clientSide { + return errors.New("earlyAbortStream not handled on client") + } + // In case the caller forgets to set the http status, default to 200. + if eas.httpStatus == 0 { + eas.httpStatus = 200 + } + headerFields := []hpack.HeaderField{ + {Name: ":status", Value: strconv.Itoa(int(eas.httpStatus))}, + {Name: "content-type", Value: grpcutil.ContentType(eas.contentSubtype)}, + {Name: "grpc-status", Value: strconv.Itoa(int(eas.status.Code()))}, + {Name: "grpc-message", Value: encodeGrpcMessage(eas.status.Message())}, + } + + if err := l.writeHeader(eas.streamID, true, headerFields, nil); err != nil { + return err + } + return nil +} + func (l *loopyWriter) incomingGoAwayHandler(*incomingGoAway) error { if l.side == clientSide { l.draining = true @@ -787,6 +827,8 @@ func (l *loopyWriter) handle(i interface{}) error { return l.registerStreamHandler(i) case *cleanupStream: return l.cleanupStreamHandler(i) + case *earlyAbortStream: + return l.earlyAbortStreamHandler(i) case *incomingGoAway: return l.incomingGoAwayHandler(i) case *dataFrame: diff --git a/internal/transport/handler_server.go b/internal/transport/handler_server.go index 05d3871e628..1c3459c2b4c 100644 --- a/internal/transport/handler_server.go +++ b/internal/transport/handler_server.go @@ -141,9 +141,8 @@ type serverHandlerTransport struct { stats stats.Handler } -func (ht *serverHandlerTransport) Close() error { +func (ht *serverHandlerTransport) Close() { ht.closeOnce.Do(ht.closeCloseChanOnce) - return nil } func (ht *serverHandlerTransport) closeCloseChanOnce() { close(ht.closedCh) } diff --git a/internal/transport/handler_server_test.go b/internal/transport/handler_server_test.go index f9efdfb0716..b08dcaaf3c4 100644 --- a/internal/transport/handler_server_test.go +++ b/internal/transport/handler_server_test.go @@ -62,7 +62,6 @@ func (s) TestHandlerTransport_NewServerHandlerTransport(t *testing.T) { ProtoMajor: 2, Method: "GET", Header: http.Header{}, - RequestURI: "/", }, wantErr: "invalid gRPC request method", }, @@ -74,7 +73,6 @@ func (s) TestHandlerTransport_NewServerHandlerTransport(t *testing.T) { Header: http.Header{ "Content-Type": {"application/foo"}, }, - RequestURI: "/service/foo.bar", }, wantErr: "invalid gRPC request content-type", }, @@ -86,7 +84,6 @@ func (s) TestHandlerTransport_NewServerHandlerTransport(t *testing.T) { Header: http.Header{ "Content-Type": {"application/grpc"}, }, - RequestURI: "/service/foo.bar", }, modrw: func(w http.ResponseWriter) http.ResponseWriter { // Return w without its Flush method @@ -109,7 +106,6 @@ func (s) TestHandlerTransport_NewServerHandlerTransport(t *testing.T) { URL: &url.URL{ Path: "/service/foo.bar", }, - RequestURI: "/service/foo.bar", }, check: func(t *serverHandlerTransport, tt *testCase) error { if t.req != tt.req { @@ -133,7 +129,6 @@ func (s) TestHandlerTransport_NewServerHandlerTransport(t *testing.T) { URL: &url.URL{ Path: "/service/foo.bar", }, - RequestURI: "/service/foo.bar", }, check: func(t *serverHandlerTransport, tt *testCase) error { if !t.timeoutSet { @@ -157,7 +152,6 @@ func (s) TestHandlerTransport_NewServerHandlerTransport(t *testing.T) { URL: &url.URL{ Path: "/service/foo.bar", }, - RequestURI: "/service/foo.bar", }, wantErr: `rpc error: code = Internal desc = malformed time-out: transport: timeout unit is not recognized: "tomorrow"`, }, @@ -175,7 +169,6 @@ func (s) TestHandlerTransport_NewServerHandlerTransport(t *testing.T) { URL: &url.URL{ Path: "/service/foo.bar", }, - RequestURI: "/service/foo.bar", }, check: func(ht *serverHandlerTransport, tt *testCase) error { want := metadata.MD{ @@ -247,8 +240,7 @@ func newHandleStreamTest(t *testing.T) *handleStreamTest { URL: &url.URL{ Path: "/service/foo.bar", }, - RequestURI: "/service/foo.bar", - Body: bodyr, + Body: bodyr, } rw := newTestHandlerResponseWriter().(testHandlerResponseWriter) ht, err := NewServerHandlerTransport(rw, req, nil) @@ -359,8 +351,7 @@ func (s) TestHandlerTransport_HandleStreams_Timeout(t *testing.T) { URL: &url.URL{ Path: "/service/foo.bar", }, - RequestURI: "/service/foo.bar", - Body: bodyr, + Body: bodyr, } rw := newTestHandlerResponseWriter().(testHandlerResponseWriter) ht, err := NewServerHandlerTransport(rw, req, nil) diff --git a/internal/transport/http2_client.go b/internal/transport/http2_client.go index a76310c6e13..e3203adbd0b 100644 --- a/internal/transport/http2_client.go +++ b/internal/transport/http2_client.go @@ -24,6 +24,7 @@ import ( "io" "math" "net" + "net/http" "strconv" "strings" "sync" @@ -32,15 +33,14 @@ import ( "golang.org/x/net/http2" "golang.org/x/net/http2/hpack" - "google.golang.org/grpc/internal/grpcutil" - imetadata "google.golang.org/grpc/internal/metadata" - "google.golang.org/grpc/internal/transport/networktype" - "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials" - "google.golang.org/grpc/internal" "google.golang.org/grpc/internal/channelz" + icredentials "google.golang.org/grpc/internal/credentials" + "google.golang.org/grpc/internal/grpcutil" + imetadata "google.golang.org/grpc/internal/metadata" "google.golang.org/grpc/internal/syscall" + "google.golang.org/grpc/internal/transport/networktype" "google.golang.org/grpc/keepalive" "google.golang.org/grpc/metadata" "google.golang.org/grpc/peer" @@ -116,6 +116,9 @@ type http2Client struct { // goAwayReason records the http2.ErrCode and debug data received with the // GoAway frame. goAwayReason GoAwayReason + // goAwayDebugMessage contains a detailed human readable string about a + // GoAway frame, useful for error messages. + goAwayDebugMessage string // A condition variable used to signal when the keepalive goroutine should // go dormant. The condition for dormancy is based on the number of active // streams and the `PermitWithoutStream` keepalive client parameter. And @@ -238,9 +241,16 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts // Attributes field of resolver.Address, which is shoved into connectCtx // and passed to the credential handshaker. This makes it possible for // address specific arbitrary data to reach the credential handshaker. - contextWithHandshakeInfo := internal.NewClientHandshakeInfoContext.(func(context.Context, credentials.ClientHandshakeInfo) context.Context) - connectCtx = contextWithHandshakeInfo(connectCtx, credentials.ClientHandshakeInfo{Attributes: addr.Attributes}) - conn, authInfo, err = transportCreds.ClientHandshake(connectCtx, addr.ServerName, conn) + connectCtx = icredentials.NewClientHandshakeInfoContext(connectCtx, credentials.ClientHandshakeInfo{Attributes: addr.Attributes}) + rawConn := conn + // Pull the deadline from the connectCtx, which will be used for + // timeouts in the authentication protocol handshake. Can ignore the + // boolean as the deadline will return the zero value, which will make + // the conn not timeout on I/O operations. + deadline, _ := connectCtx.Deadline() + rawConn.SetDeadline(deadline) + conn, authInfo, err = transportCreds.ClientHandshake(connectCtx, addr.ServerName, rawConn) + rawConn.SetDeadline(time.Time{}) if err != nil { return nil, connectionErrorf(isTemporary(err), err, "transport: authentication handshake failed: %v", err) } @@ -347,12 +357,14 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts // Send connection preface to server. n, err := t.conn.Write(clientPreface) if err != nil { - t.Close() - return nil, connectionErrorf(true, err, "transport: failed to write client preface: %v", err) + err = connectionErrorf(true, err, "transport: failed to write client preface: %v", err) + t.Close(err) + return nil, err } if n != len(clientPreface) { - t.Close() - return nil, connectionErrorf(true, err, "transport: preface mismatch, wrote %d bytes; want %d", n, len(clientPreface)) + err = connectionErrorf(true, nil, "transport: preface mismatch, wrote %d bytes; want %d", n, len(clientPreface)) + t.Close(err) + return nil, err } var ss []http2.Setting @@ -370,14 +382,16 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts } err = t.framer.fr.WriteSettings(ss...) if err != nil { - t.Close() - return nil, connectionErrorf(true, err, "transport: failed to write initial settings frame: %v", err) + err = connectionErrorf(true, err, "transport: failed to write initial settings frame: %v", err) + t.Close(err) + return nil, err } // Adjust the connection flow control window if needed. if delta := uint32(icwz - defaultWindowSize); delta > 0 { if err := t.framer.fr.WriteWindowUpdate(0, delta); err != nil { - t.Close() - return nil, connectionErrorf(true, err, "transport: failed to write window update: %v", err) + err = connectionErrorf(true, err, "transport: failed to write window update: %v", err) + t.Close(err) + return nil, err } } @@ -403,11 +417,10 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts logger.Errorf("transport: loopyWriter.run returning. Err: %v", err) } } - // If it's a connection error, let reader goroutine handle it - // since there might be data in the buffers. - if _, ok := err.(net.Error); !ok { - t.conn.Close() - } + // Do not close the transport. Let reader goroutine handle it since + // there might be data in the buffers. + t.conn.Close() + t.controlBuf.finish() close(t.writerDone) }() return t, nil @@ -463,7 +476,7 @@ func (t *http2Client) createHeaderFields(ctx context.Context, callHdr *CallHdr) Method: callHdr.Method, AuthInfo: t.authInfo, } - ctxWithRequestInfo := internal.NewRequestInfoContext.(func(context.Context, credentials.RequestInfo) context.Context)(ctx, ri) + ctxWithRequestInfo := icredentials.NewRequestInfoContext(ctx, ri) authData, err := t.getTrAuthData(ctxWithRequestInfo, aud) if err != nil { return nil, err @@ -612,26 +625,35 @@ func (t *http2Client) getCallAuthData(ctx context.Context, audience string, call return callAuthData, nil } -// PerformedIOError wraps an error to indicate IO may have been performed -// before the error occurred. -type PerformedIOError struct { +// NewStreamError wraps an error and reports additional information. Typically +// NewStream errors result in transparent retry, as they mean nothing went onto +// the wire. However, there are two notable exceptions: +// +// 1. If the stream headers violate the max header list size allowed by the +// server. In this case there is no reason to retry at all, as it is +// assumed the RPC would continue to fail on subsequent attempts. +// 2. If the credentials errored when requesting their headers. In this case, +// it's possible a retry can fix the problem, but indefinitely transparently +// retrying is not appropriate as it is likely the credentials, if they can +// eventually succeed, would need I/O to do so. +type NewStreamError struct { Err error + + DoNotRetry bool + DoNotTransparentRetry bool } -// Error implements error. -func (p PerformedIOError) Error() string { - return p.Err.Error() +func (e NewStreamError) Error() string { + return e.Err.Error() } // NewStream creates a stream and registers it into the transport as "active" -// streams. +// streams. All non-nil errors returned will be *NewStreamError. func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Stream, err error) { ctx = peer.NewContext(ctx, t.getPeer()) headerFields, err := t.createHeaderFields(ctx, callHdr) if err != nil { - // We may have performed I/O in the per-RPC creds callback, so do not - // allow transparent retry. - return nil, PerformedIOError{err} + return nil, &NewStreamError{Err: err, DoNotTransparentRetry: true} } s := t.newStream(ctx, callHdr) cleanup := func(err error) { @@ -731,23 +753,23 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea return true }, hdr) if err != nil { - return nil, err + return nil, &NewStreamError{Err: err} } if success { break } if hdrListSizeErr != nil { - return nil, hdrListSizeErr + return nil, &NewStreamError{Err: hdrListSizeErr, DoNotRetry: true} } firstTry = false select { case <-ch: - case <-s.ctx.Done(): - return nil, ContextErr(s.ctx.Err()) + case <-ctx.Done(): + return nil, &NewStreamError{Err: ContextErr(ctx.Err())} case <-t.goAway: - return nil, errStreamDrain + return nil, &NewStreamError{Err: errStreamDrain} case <-t.ctx.Done(): - return nil, ErrConnClosing + return nil, &NewStreamError{Err: ErrConnClosing} } } if t.statsHandler != nil { @@ -854,12 +876,12 @@ func (t *http2Client) closeStream(s *Stream, err error, rst bool, rstCode http2. // This method blocks until the addrConn that initiated this transport is // re-connected. This happens because t.onClose() begins reconnect logic at the // addrConn level and blocks until the addrConn is successfully connected. -func (t *http2Client) Close() error { +func (t *http2Client) Close(err error) { t.mu.Lock() // Make sure we only Close once. if t.state == closing { t.mu.Unlock() - return nil + return } // Call t.onClose before setting the state to closing to prevent the client // from attempting to create new streams ASAP. @@ -875,13 +897,25 @@ func (t *http2Client) Close() error { t.mu.Unlock() t.controlBuf.finish() t.cancel() - err := t.conn.Close() + t.conn.Close() if channelz.IsOn() { channelz.RemoveEntry(t.channelzID) } + // Append info about previous goaways if there were any, since this may be important + // for understanding the root cause for this connection to be closed. + _, goAwayDebugMessage := t.GetGoAwayReason() + + var st *status.Status + if len(goAwayDebugMessage) > 0 { + st = status.Newf(codes.Unavailable, "closing transport due to: %v, received prior goaway: %v", err, goAwayDebugMessage) + err = st.Err() + } else { + st = status.New(codes.Unavailable, err.Error()) + } + // Notify all active streams. for _, s := range streams { - t.closeStream(s, ErrConnClosing, false, http2.ErrCodeNo, status.New(codes.Unavailable, ErrConnClosing.Desc), nil, false) + t.closeStream(s, err, false, http2.ErrCodeNo, st, nil, false) } if t.statsHandler != nil { connEnd := &stats.ConnEnd{ @@ -889,7 +923,6 @@ func (t *http2Client) Close() error { } t.statsHandler.HandleConn(t.ctx, connEnd) } - return err } // GracefulClose sets the state to draining, which prevents new streams from @@ -908,7 +941,7 @@ func (t *http2Client) GracefulClose() { active := len(t.activeStreams) t.mu.Unlock() if active == 0 { - t.Close() + t.Close(ErrConnClosing) return } t.controlBuf.put(&incomingGoAway{}) @@ -1049,7 +1082,7 @@ func (t *http2Client) handleData(f *http2.DataFrame) { } // The server has closed the stream without sending trailers. Record that // the read direction is closed, and set the status appropriately. - if f.FrameHeader.Flags.Has(http2.FlagDataEndStream) { + if f.StreamEnded() { t.closeStream(s, io.EOF, false, http2.ErrCodeNo, status.New(codes.Internal, "server closed the stream without sending trailers"), nil, true) } } @@ -1154,9 +1187,9 @@ func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) { } } id := f.LastStreamID - if id > 0 && id%2 != 1 { + if id > 0 && id%2 == 0 { t.mu.Unlock() - t.Close() + t.Close(connectionErrorf(true, nil, "received goaway with non-zero even-numbered numbered stream id: %v", id)) return } // A client can receive multiple GoAways from the server (see @@ -1174,7 +1207,7 @@ func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) { // If there are multiple GoAways the first one should always have an ID greater than the following ones. if id > t.prevGoAwayID { t.mu.Unlock() - t.Close() + t.Close(connectionErrorf(true, nil, "received goaway with stream id: %v, which exceeds stream id of previous goaway: %v", id, t.prevGoAwayID)) return } default: @@ -1204,7 +1237,7 @@ func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) { active := len(t.activeStreams) t.mu.Unlock() if active == 0 { - t.Close() + t.Close(connectionErrorf(true, nil, "received goaway and there are no active streams")) } } @@ -1220,12 +1253,17 @@ func (t *http2Client) setGoAwayReason(f *http2.GoAwayFrame) { t.goAwayReason = GoAwayTooManyPings } } + if len(f.DebugData()) == 0 { + t.goAwayDebugMessage = fmt.Sprintf("code: %s", f.ErrCode) + } else { + t.goAwayDebugMessage = fmt.Sprintf("code: %s, debug data: %q", f.ErrCode, string(f.DebugData())) + } } -func (t *http2Client) GetGoAwayReason() GoAwayReason { +func (t *http2Client) GetGoAwayReason() (GoAwayReason, string) { t.mu.Lock() defer t.mu.Unlock() - return t.goAwayReason + return t.goAwayReason, t.goAwayDebugMessage } func (t *http2Client) handleWindowUpdate(f *http2.WindowUpdateFrame) { @@ -1252,35 +1290,128 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) { return } - state := &decodeState{} - // Initialize isGRPC value to be !initialHeader, since if a gRPC Response-Headers has already been received, then it means that the peer is speaking gRPC and we are in gRPC mode. - state.data.isGRPC = !initialHeader - if h2code, err := state.decodeHeader(frame); err != nil { - t.closeStream(s, err, true, h2code, status.Convert(err), nil, endStream) + // frame.Truncated is set to true when framer detects that the current header + // list size hits MaxHeaderListSize limit. + if frame.Truncated { + se := status.New(codes.Internal, "peer header list size exceeded limit") + t.closeStream(s, se.Err(), true, http2.ErrCodeFrameSize, se, nil, endStream) return } - isHeader := false - defer func() { - if t.statsHandler != nil { - if isHeader { - inHeader := &stats.InHeader{ - Client: true, - WireLength: int(frame.Header().Length), - Header: s.header.Copy(), - Compression: s.recvCompress, - } - t.statsHandler.HandleRPC(s.ctx, inHeader) - } else { - inTrailer := &stats.InTrailer{ - Client: true, - WireLength: int(frame.Header().Length), - Trailer: s.trailer.Copy(), - } - t.statsHandler.HandleRPC(s.ctx, inTrailer) + var ( + // If a gRPC Response-Headers has already been received, then it means + // that the peer is speaking gRPC and we are in gRPC mode. + isGRPC = !initialHeader + mdata = make(map[string][]string) + contentTypeErr = "malformed header: missing HTTP content-type" + grpcMessage string + statusGen *status.Status + recvCompress string + httpStatusCode *int + httpStatusErr string + rawStatusCode = codes.Unknown + // headerError is set if an error is encountered while parsing the headers + headerError string + ) + + if initialHeader { + httpStatusErr = "malformed header: missing HTTP status" + } + + for _, hf := range frame.Fields { + switch hf.Name { + case "content-type": + if _, validContentType := grpcutil.ContentSubtype(hf.Value); !validContentType { + contentTypeErr = fmt.Sprintf("transport: received unexpected content-type %q", hf.Value) + break + } + contentTypeErr = "" + mdata[hf.Name] = append(mdata[hf.Name], hf.Value) + isGRPC = true + case "grpc-encoding": + recvCompress = hf.Value + case "grpc-status": + code, err := strconv.ParseInt(hf.Value, 10, 32) + if err != nil { + se := status.New(codes.Internal, fmt.Sprintf("transport: malformed grpc-status: %v", err)) + t.closeStream(s, se.Err(), true, http2.ErrCodeProtocol, se, nil, endStream) + return } + rawStatusCode = codes.Code(uint32(code)) + case "grpc-message": + grpcMessage = decodeGrpcMessage(hf.Value) + case "grpc-status-details-bin": + var err error + statusGen, err = decodeGRPCStatusDetails(hf.Value) + if err != nil { + headerError = fmt.Sprintf("transport: malformed grpc-status-details-bin: %v", err) + } + case ":status": + if hf.Value == "200" { + httpStatusErr = "" + statusCode := 200 + httpStatusCode = &statusCode + break + } + + c, err := strconv.ParseInt(hf.Value, 10, 32) + if err != nil { + se := status.New(codes.Internal, fmt.Sprintf("transport: malformed http-status: %v", err)) + t.closeStream(s, se.Err(), true, http2.ErrCodeProtocol, se, nil, endStream) + return + } + statusCode := int(c) + httpStatusCode = &statusCode + + httpStatusErr = fmt.Sprintf( + "unexpected HTTP status code received from server: %d (%s)", + statusCode, + http.StatusText(statusCode), + ) + default: + if isReservedHeader(hf.Name) && !isWhitelistedHeader(hf.Name) { + break + } + v, err := decodeMetadataHeader(hf.Name, hf.Value) + if err != nil { + headerError = fmt.Sprintf("transport: malformed %s: %v", hf.Name, err) + logger.Warningf("Failed to decode metadata header (%q, %q): %v", hf.Name, hf.Value, err) + break + } + mdata[hf.Name] = append(mdata[hf.Name], v) } - }() + } + + if !isGRPC || httpStatusErr != "" { + var code = codes.Internal // when header does not include HTTP status, return INTERNAL + + if httpStatusCode != nil { + var ok bool + code, ok = HTTPStatusConvTab[*httpStatusCode] + if !ok { + code = codes.Unknown + } + } + var errs []string + if httpStatusErr != "" { + errs = append(errs, httpStatusErr) + } + if contentTypeErr != "" { + errs = append(errs, contentTypeErr) + } + // Verify the HTTP response is a 200. + se := status.New(code, strings.Join(errs, "; ")) + t.closeStream(s, se.Err(), true, http2.ErrCodeProtocol, se, nil, endStream) + return + } + + if headerError != "" { + se := status.New(codes.Internal, headerError) + t.closeStream(s, se.Err(), true, http2.ErrCodeProtocol, se, nil, endStream) + return + } + + isHeader := false // If headerChan hasn't been closed yet if atomic.CompareAndSwapUint32(&s.headerChanClosed, 0, 1) { @@ -1291,9 +1422,9 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) { // These values can be set without any synchronization because // stream goroutine will read it only after seeing a closed // headerChan which we'll close after setting this. - s.recvCompress = state.data.encoding - if len(state.data.mdata) > 0 { - s.header = state.data.mdata + s.recvCompress = recvCompress + if len(mdata) > 0 { + s.header = mdata } } else { // HEADERS frame block carries a Trailers-Only. @@ -1302,13 +1433,36 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) { close(s.headerChan) } + if t.statsHandler != nil { + if isHeader { + inHeader := &stats.InHeader{ + Client: true, + WireLength: int(frame.Header().Length), + Header: metadata.MD(mdata).Copy(), + Compression: s.recvCompress, + } + t.statsHandler.HandleRPC(s.ctx, inHeader) + } else { + inTrailer := &stats.InTrailer{ + Client: true, + WireLength: int(frame.Header().Length), + Trailer: metadata.MD(mdata).Copy(), + } + t.statsHandler.HandleRPC(s.ctx, inTrailer) + } + } + if !endStream { return } + if statusGen == nil { + statusGen = status.New(rawStatusCode, grpcMessage) + } + // if client received END_STREAM from server while stream was still active, send RST_STREAM rst := s.getState() == streamActive - t.closeStream(s, io.EOF, rst, http2.ErrCodeNo, state.status(), state.data.mdata, true) + t.closeStream(s, io.EOF, rst, http2.ErrCodeNo, statusGen, mdata, true) } // reader runs as a separate goroutine in charge of reading data from network @@ -1322,7 +1476,8 @@ func (t *http2Client) reader() { // Check the validity of server preface. frame, err := t.framer.fr.ReadFrame() if err != nil { - t.Close() // this kicks off resetTransport, so must be last before return + err = connectionErrorf(true, err, "error reading server preface: %v", err) + t.Close(err) // this kicks off resetTransport, so must be last before return return } t.conn.SetReadDeadline(time.Time{}) // reset deadline once we get the settings frame (we didn't time out, yay!) @@ -1331,7 +1486,8 @@ func (t *http2Client) reader() { } sf, ok := frame.(*http2.SettingsFrame) if !ok { - t.Close() // this kicks off resetTransport, so must be last before return + // this kicks off resetTransport, so must be last before return + t.Close(connectionErrorf(true, nil, "initial http2 frame from server is not a settings frame: %T", frame)) return } t.onPrefaceReceipt() @@ -1367,7 +1523,7 @@ func (t *http2Client) reader() { continue } else { // Transport error. - t.Close() + t.Close(connectionErrorf(true, err, "error reading from server: %v", err)) return } } @@ -1426,7 +1582,7 @@ func (t *http2Client) keepalive() { continue } if outstandingPing && timeoutLeft <= 0 { - t.Close() + t.Close(connectionErrorf(true, nil, "keepalive ping failed to receive ACK within timeout")) return } t.mu.Lock() diff --git a/internal/transport/http2_server.go b/internal/transport/http2_server.go index 7c6c89d4f9b..f2cad9ebc31 100644 --- a/internal/transport/http2_server.go +++ b/internal/transport/http2_server.go @@ -102,11 +102,11 @@ type http2Server struct { mu sync.Mutex // guard the following - // drainChan is initialized when drain(...) is called the first time. + // drainChan is initialized when Drain() is called the first time. // After which the server writes out the first GoAway(with ID 2^31-1) frame. // Then an independent goroutine will be launched to later send the second GoAway. // During this time we don't want to write another first GoAway(with ID 2^31 -1) frame. - // Thus call to drain(...) will be a no-op if drainChan is already initialized since draining is + // Thus call to Drain() will be a no-op if drainChan is already initialized since draining is // already underway. drainChan chan struct{} state transportState @@ -125,9 +125,30 @@ type http2Server struct { connectionID uint64 } -// newHTTP2Server constructs a ServerTransport based on HTTP2. ConnectionError is -// returned if something goes wrong. -func newHTTP2Server(conn net.Conn, config *ServerConfig) (_ ServerTransport, err error) { +// NewServerTransport creates a http2 transport with conn and configuration +// options from config. +// +// It returns a non-nil transport and a nil error on success. On failure, it +// returns a nil transport and a non-nil error. For a special case where the +// underlying conn gets closed before the client preface could be read, it +// returns a nil transport and a nil error. +func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport, err error) { + var authInfo credentials.AuthInfo + rawConn := conn + if config.Credentials != nil { + var err error + conn, authInfo, err = config.Credentials.ServerHandshake(rawConn) + if err != nil { + // ErrConnDispatched means that the connection was dispatched away + // from gRPC; those connections should be left open. io.EOF means + // the connection was closed before handshaking completed, which can + // happen naturally from probers. Return these errors directly. + if err == credentials.ErrConnDispatched || err == io.EOF { + return nil, err + } + return nil, connectionErrorf(false, err, "ServerHandshake(%q) failed: %v", rawConn.RemoteAddr(), err) + } + } writeBufSize := config.WriteBufferSize readBufSize := config.ReadBufferSize maxHeaderListSize := defaultServerMaxHeaderListSize @@ -210,14 +231,15 @@ func newHTTP2Server(conn net.Conn, config *ServerConfig) (_ ServerTransport, err if kep.MinTime == 0 { kep.MinTime = defaultKeepalivePolicyMinTime } + done := make(chan struct{}) t := &http2Server{ - ctx: context.Background(), + ctx: setConnection(context.Background(), rawConn), done: done, conn: conn, remoteAddr: conn.RemoteAddr(), localAddr: conn.LocalAddr(), - authInfo: config.AuthInfo, + authInfo: authInfo, framer: framer, readerDone: make(chan struct{}), writerDone: make(chan struct{}), @@ -266,6 +288,14 @@ func newHTTP2Server(conn net.Conn, config *ServerConfig) (_ ServerTransport, err // Check the validity of client preface. preface := make([]byte, len(clientPreface)) if _, err := io.ReadFull(t.conn, preface); err != nil { + // In deployments where a gRPC server runs behind a cloud load balancer + // which performs regular TCP level health checks, the connection is + // closed immediately by the latter. Returning io.EOF here allows the + // grpc server implementation to recognize this scenario and suppress + // logging to reduce spam. + if err == io.EOF { + return nil, io.EOF + } return nil, connectionErrorf(false, err, "transport: http2Server.HandleStreams failed to receive the preface from client: %v", err) } if !bytes.Equal(preface, clientPreface) { @@ -295,6 +325,7 @@ func newHTTP2Server(conn net.Conn, config *ServerConfig) (_ ServerTransport, err } } t.conn.Close() + t.controlBuf.finish() close(t.writerDone) }() go t.keepalive() @@ -304,37 +335,131 @@ func newHTTP2Server(conn net.Conn, config *ServerConfig) (_ ServerTransport, err // operateHeader takes action on the decoded headers. func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(*Stream), traceCtx func(context.Context, string) context.Context) (fatal bool) { streamID := frame.Header().StreamID - state := &decodeState{ - serverSide: true, - } - if h2code, err := state.decodeHeader(frame); err != nil { - if _, ok := status.FromError(err); ok { - t.controlBuf.put(&cleanupStream{ - streamID: streamID, - rst: true, - rstCode: h2code, - onWrite: func() {}, - }) - } + + // frame.Truncated is set to true when framer detects that the current header + // list size hits MaxHeaderListSize limit. + if frame.Truncated { + t.controlBuf.put(&cleanupStream{ + streamID: streamID, + rst: true, + rstCode: http2.ErrCodeFrameSize, + onWrite: func() {}, + }) return false } buf := newRecvBuffer() s := &Stream{ - id: streamID, - st: t, - buf: buf, - fc: &inFlow{limit: uint32(t.initialWindowSize)}, - recvCompress: state.data.encoding, - method: state.data.method, - contentSubtype: state.data.contentSubtype, + id: streamID, + st: t, + buf: buf, + fc: &inFlow{limit: uint32(t.initialWindowSize)}, + } + + var ( + // If a gRPC Response-Headers has already been received, then it means + // that the peer is speaking gRPC and we are in gRPC mode. + isGRPC = false + mdata = make(map[string][]string) + httpMethod string + // headerError is set if an error is encountered while parsing the headers + headerError bool + + timeoutSet bool + timeout time.Duration + ) + + for _, hf := range frame.Fields { + switch hf.Name { + case "content-type": + contentSubtype, validContentType := grpcutil.ContentSubtype(hf.Value) + if !validContentType { + break + } + mdata[hf.Name] = append(mdata[hf.Name], hf.Value) + s.contentSubtype = contentSubtype + isGRPC = true + case "grpc-encoding": + s.recvCompress = hf.Value + case ":method": + httpMethod = hf.Value + case ":path": + s.method = hf.Value + case "grpc-timeout": + timeoutSet = true + var err error + if timeout, err = decodeTimeout(hf.Value); err != nil { + headerError = true + } + // "Transports must consider requests containing the Connection header + // as malformed." - A41 + case "connection": + if logger.V(logLevel) { + logger.Errorf("transport: http2Server.operateHeaders parsed a :connection header which makes a request malformed as per the HTTP/2 spec") + } + headerError = true + default: + if isReservedHeader(hf.Name) && !isWhitelistedHeader(hf.Name) { + break + } + v, err := decodeMetadataHeader(hf.Name, hf.Value) + if err != nil { + headerError = true + logger.Warningf("Failed to decode metadata header (%q, %q): %v", hf.Name, hf.Value, err) + break + } + mdata[hf.Name] = append(mdata[hf.Name], v) + } + } + + // "If multiple Host headers or multiple :authority headers are present, the + // request must be rejected with an HTTP status code 400 as required by Host + // validation in RFC 7230 §5.4, gRPC status code INTERNAL, or RST_STREAM + // with HTTP/2 error code PROTOCOL_ERROR." - A41. Since this is a HTTP/2 + // error, this takes precedence over a client not speaking gRPC. + if len(mdata[":authority"]) > 1 || len(mdata["host"]) > 1 { + errMsg := fmt.Sprintf("num values of :authority: %v, num values of host: %v, both must only have 1 value as per HTTP/2 spec", len(mdata[":authority"]), len(mdata["host"])) + if logger.V(logLevel) { + logger.Errorf("transport: %v", errMsg) + } + t.controlBuf.put(&earlyAbortStream{ + httpStatus: 400, + streamID: streamID, + contentSubtype: s.contentSubtype, + status: status.New(codes.Internal, errMsg), + }) + return false + } + + if !isGRPC || headerError { + t.controlBuf.put(&cleanupStream{ + streamID: streamID, + rst: true, + rstCode: http2.ErrCodeProtocol, + onWrite: func() {}, + }) + return false + } + + // "If :authority is missing, Host must be renamed to :authority." - A41 + if len(mdata[":authority"]) == 0 { + // No-op if host isn't present, no eventual :authority header is a valid + // RPC. + if host, ok := mdata["host"]; ok { + mdata[":authority"] = host + delete(mdata, "host") + } + } else { + // "If :authority is present, Host must be discarded" - A41 + delete(mdata, "host") } + if frame.StreamEnded() { // s is just created by the caller. No lock needed. s.state = streamReadDone } - if state.data.timeoutSet { - s.ctx, s.cancel = context.WithTimeout(t.ctx, state.data.timeout) + if timeoutSet { + s.ctx, s.cancel = context.WithTimeout(t.ctx, timeout) } else { s.ctx, s.cancel = context.WithCancel(t.ctx) } @@ -347,33 +472,13 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func( } s.ctx = peer.NewContext(s.ctx, pr) // Attach the received metadata to the context. - if len(state.data.mdata) > 0 { - s.ctx = metadata.NewIncomingContext(s.ctx, state.data.mdata) - } - if state.data.statsTags != nil { - s.ctx = stats.SetIncomingTags(s.ctx, state.data.statsTags) - } - if state.data.statsTrace != nil { - s.ctx = stats.SetIncomingTrace(s.ctx, state.data.statsTrace) - } - if t.inTapHandle != nil { - var err error - info := &tap.Info{ - FullMethodName: state.data.method, + if len(mdata) > 0 { + s.ctx = metadata.NewIncomingContext(s.ctx, mdata) + if statsTags := mdata["grpc-tags-bin"]; len(statsTags) > 0 { + s.ctx = stats.SetIncomingTags(s.ctx, []byte(statsTags[len(statsTags)-1])) } - s.ctx, err = t.inTapHandle(s.ctx, info) - if err != nil { - if logger.V(logLevel) { - logger.Warningf("transport: http2Server.operateHeaders got an error from InTapHandle: %v", err) - } - t.controlBuf.put(&cleanupStream{ - streamID: s.id, - rst: true, - rstCode: http2.ErrCodeRefusedStream, - onWrite: func() {}, - }) - s.cancel() - return false + if statsTrace := mdata["grpc-trace-bin"]; len(statsTrace) > 0 { + s.ctx = stats.SetIncomingTrace(s.ctx, []byte(statsTrace[len(statsTrace)-1])) } } t.mu.Lock() @@ -403,10 +508,10 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func( return true } t.maxStreamID = streamID - if state.data.httpMethod != http.MethodPost { + if httpMethod != http.MethodPost { t.mu.Unlock() if logger.V(logLevel) { - logger.Warningf("transport: http2Server.operateHeaders parsed a :method field: %v which should be POST", state.data.httpMethod) + logger.Infof("transport: http2Server.operateHeaders parsed a :method field: %v which should be POST", httpMethod) } t.controlBuf.put(&cleanupStream{ streamID: streamID, @@ -417,6 +522,26 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func( s.cancel() return false } + if t.inTapHandle != nil { + var err error + if s.ctx, err = t.inTapHandle(s.ctx, &tap.Info{FullMethodName: s.method}); err != nil { + t.mu.Unlock() + if logger.V(logLevel) { + logger.Infof("transport: http2Server.operateHeaders got an error from InTapHandle: %v", err) + } + stat, ok := status.FromError(err) + if !ok { + stat = status.New(codes.PermissionDenied, err.Error()) + } + t.controlBuf.put(&earlyAbortStream{ + httpStatus: 200, + streamID: s.id, + contentSubtype: s.contentSubtype, + status: stat, + }) + return false + } + } t.activeStreams[streamID] = s if len(t.activeStreams) == 1 { t.idle = time.Time{} @@ -438,7 +563,7 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func( LocalAddr: t.localAddr, Compression: s.recvCompress, WireLength: int(frame.Header().Length), - Header: metadata.MD(state.data.mdata).Copy(), + Header: metadata.MD(mdata).Copy(), } t.stats.HandleRPC(s.ctx, inHeader) } @@ -650,7 +775,7 @@ func (t *http2Server) handleData(f *http2.DataFrame) { s.write(recvMsg{buffer: buffer}) } } - if f.Header().Flags.Has(http2.FlagDataEndStream) { + if f.StreamEnded() { // Received the end of stream from the client. s.compareAndSwapState(streamActive, streamReadDone) s.write(recvMsg{err: io.EOF}) @@ -1005,12 +1130,12 @@ func (t *http2Server) keepalive() { if val <= 0 { // The connection has been idle for a duration of keepalive.MaxConnectionIdle or more. // Gracefully close the connection. - t.drain(http2.ErrCodeNo, []byte{}) + t.Drain() return } idleTimer.Reset(val) case <-ageTimer.C: - t.drain(http2.ErrCodeNo, []byte{}) + t.Drain() ageTimer.Reset(t.kp.MaxConnectionAgeGrace) select { case <-ageTimer.C: @@ -1064,11 +1189,11 @@ func (t *http2Server) keepalive() { // Close starts shutting down the http2Server transport. // TODO(zhaoq): Now the destruction is not blocked on any pending streams. This // could cause some resource issue. Revisit this later. -func (t *http2Server) Close() error { +func (t *http2Server) Close() { t.mu.Lock() if t.state == closing { t.mu.Unlock() - return errors.New("transport: Close() was already called") + return } t.state = closing streams := t.activeStreams @@ -1076,7 +1201,9 @@ func (t *http2Server) Close() error { t.mu.Unlock() t.controlBuf.finish() close(t.done) - err := t.conn.Close() + if err := t.conn.Close(); err != nil && logger.V(logLevel) { + logger.Infof("transport: error closing conn during Close: %v", err) + } if channelz.IsOn() { channelz.RemoveEntry(t.channelzID) } @@ -1088,7 +1215,6 @@ func (t *http2Server) Close() error { connEnd := &stats.ConnEnd{} t.stats.HandleConn(t.ctx, connEnd) } - return err } // deleteStream deletes the stream s from transport's active streams. @@ -1153,17 +1279,13 @@ func (t *http2Server) RemoteAddr() net.Addr { } func (t *http2Server) Drain() { - t.drain(http2.ErrCodeNo, []byte{}) -} - -func (t *http2Server) drain(code http2.ErrCode, debugData []byte) { t.mu.Lock() defer t.mu.Unlock() if t.drainChan != nil { return } t.drainChan = make(chan struct{}) - t.controlBuf.put(&goAway{code: code, debugData: debugData, headsUp: true}) + t.controlBuf.put(&goAway{code: http2.ErrCodeNo, debugData: []byte{}, headsUp: true}) } var goAwayPing = &ping{data: [8]byte{1, 6, 1, 8, 0, 3, 3, 9}} @@ -1281,3 +1403,18 @@ func getJitter(v time.Duration) time.Duration { j := grpcrand.Int63n(2*r) - r return time.Duration(j) } + +type connectionKey struct{} + +// GetConnection gets the connection from the context. +func GetConnection(ctx context.Context) net.Conn { + conn, _ := ctx.Value(connectionKey{}).(net.Conn) + return conn +} + +// SetConnection adds the connection to the context to be able to get +// information about the destination ip and port for an incoming RPC. This also +// allows any unary or streaming interceptors to see the connection. +func setConnection(ctx context.Context, conn net.Conn) context.Context { + return context.WithValue(ctx, connectionKey{}, conn) +} diff --git a/internal/transport/http_util.go b/internal/transport/http_util.go index c7dee140cf1..d8247bcdf69 100644 --- a/internal/transport/http_util.go +++ b/internal/transport/http_util.go @@ -39,7 +39,6 @@ import ( spb "google.golang.org/genproto/googleapis/rpc/status" "google.golang.org/grpc/codes" "google.golang.org/grpc/grpclog" - "google.golang.org/grpc/internal/grpcutil" "google.golang.org/grpc/status" ) @@ -96,53 +95,6 @@ var ( logger = grpclog.Component("transport") ) -type parsedHeaderData struct { - encoding string - // statusGen caches the stream status received from the trailer the server - // sent. Client side only. Do not access directly. After all trailers are - // parsed, use the status method to retrieve the status. - statusGen *status.Status - // rawStatusCode and rawStatusMsg are set from the raw trailer fields and are not - // intended for direct access outside of parsing. - rawStatusCode *int - rawStatusMsg string - httpStatus *int - // Server side only fields. - timeoutSet bool - timeout time.Duration - method string - httpMethod string - // key-value metadata map from the peer. - mdata map[string][]string - statsTags []byte - statsTrace []byte - contentSubtype string - - // isGRPC field indicates whether the peer is speaking gRPC (otherwise HTTP). - // - // We are in gRPC mode (peer speaking gRPC) if: - // * We are client side and have already received a HEADER frame that indicates gRPC peer. - // * The header contains valid a content-type, i.e. a string starts with "application/grpc" - // And we should handle error specific to gRPC. - // - // Otherwise (i.e. a content-type string starts without "application/grpc", or does not exist), we - // are in HTTP fallback mode, and should handle error specific to HTTP. - isGRPC bool - grpcErr error - httpErr error - contentTypeErr string -} - -// decodeState configures decoding criteria and records the decoded data. -type decodeState struct { - // whether decoding on server side or not - serverSide bool - - // Records the states during HPACK decoding. It will be filled with info parsed from HTTP HEADERS - // frame once decodeHeader function has been invoked and returned. - data parsedHeaderData -} - // isReservedHeader checks whether hdr belongs to HTTP2 headers // reserved by gRPC protocol. Any other headers are classified as the // user-specified metadata. @@ -180,14 +132,6 @@ func isWhitelistedHeader(hdr string) bool { } } -func (d *decodeState) status() *status.Status { - if d.data.statusGen == nil { - // No status-details were provided; generate status using code/msg. - d.data.statusGen = status.New(codes.Code(int32(*(d.data.rawStatusCode))), d.data.rawStatusMsg) - } - return d.data.statusGen -} - const binHdrSuffix = "-bin" func encodeBinHeader(v []byte) string { @@ -217,168 +161,16 @@ func decodeMetadataHeader(k, v string) (string, error) { return v, nil } -func (d *decodeState) decodeHeader(frame *http2.MetaHeadersFrame) (http2.ErrCode, error) { - // frame.Truncated is set to true when framer detects that the current header - // list size hits MaxHeaderListSize limit. - if frame.Truncated { - return http2.ErrCodeFrameSize, status.Error(codes.Internal, "peer header list size exceeded limit") - } - - for _, hf := range frame.Fields { - d.processHeaderField(hf) - } - - if d.data.isGRPC { - if d.data.grpcErr != nil { - return http2.ErrCodeProtocol, d.data.grpcErr - } - if d.serverSide { - return http2.ErrCodeNo, nil - } - if d.data.rawStatusCode == nil && d.data.statusGen == nil { - // gRPC status doesn't exist. - // Set rawStatusCode to be unknown and return nil error. - // So that, if the stream has ended this Unknown status - // will be propagated to the user. - // Otherwise, it will be ignored. In which case, status from - // a later trailer, that has StreamEnded flag set, is propagated. - code := int(codes.Unknown) - d.data.rawStatusCode = &code - } - return http2.ErrCodeNo, nil - } - - // HTTP fallback mode - if d.data.httpErr != nil { - return http2.ErrCodeProtocol, d.data.httpErr - } - - var ( - code = codes.Internal // when header does not include HTTP status, return INTERNAL - ok bool - ) - - if d.data.httpStatus != nil { - code, ok = HTTPStatusConvTab[*(d.data.httpStatus)] - if !ok { - code = codes.Unknown - } - } - - return http2.ErrCodeProtocol, status.Error(code, d.constructHTTPErrMsg()) -} - -// constructErrMsg constructs error message to be returned in HTTP fallback mode. -// Format: HTTP status code and its corresponding message + content-type error message. -func (d *decodeState) constructHTTPErrMsg() string { - var errMsgs []string - - if d.data.httpStatus == nil { - errMsgs = append(errMsgs, "malformed header: missing HTTP status") - } else { - errMsgs = append(errMsgs, fmt.Sprintf("%s: HTTP status code %d", http.StatusText(*(d.data.httpStatus)), *d.data.httpStatus)) - } - - if d.data.contentTypeErr == "" { - errMsgs = append(errMsgs, "transport: missing content-type field") - } else { - errMsgs = append(errMsgs, d.data.contentTypeErr) - } - - return strings.Join(errMsgs, "; ") -} - -func (d *decodeState) addMetadata(k, v string) { - if d.data.mdata == nil { - d.data.mdata = make(map[string][]string) +func decodeGRPCStatusDetails(rawDetails string) (*status.Status, error) { + v, err := decodeBinHeader(rawDetails) + if err != nil { + return nil, err } - d.data.mdata[k] = append(d.data.mdata[k], v) -} - -func (d *decodeState) processHeaderField(f hpack.HeaderField) { - switch f.Name { - case "content-type": - contentSubtype, validContentType := grpcutil.ContentSubtype(f.Value) - if !validContentType { - d.data.contentTypeErr = fmt.Sprintf("transport: received the unexpected content-type %q", f.Value) - return - } - d.data.contentSubtype = contentSubtype - // TODO: do we want to propagate the whole content-type in the metadata, - // or come up with a way to just propagate the content-subtype if it was set? - // ie {"content-type": "application/grpc+proto"} or {"content-subtype": "proto"} - // in the metadata? - d.addMetadata(f.Name, f.Value) - d.data.isGRPC = true - case "grpc-encoding": - d.data.encoding = f.Value - case "grpc-status": - code, err := strconv.Atoi(f.Value) - if err != nil { - d.data.grpcErr = status.Errorf(codes.Internal, "transport: malformed grpc-status: %v", err) - return - } - d.data.rawStatusCode = &code - case "grpc-message": - d.data.rawStatusMsg = decodeGrpcMessage(f.Value) - case "grpc-status-details-bin": - v, err := decodeBinHeader(f.Value) - if err != nil { - d.data.grpcErr = status.Errorf(codes.Internal, "transport: malformed grpc-status-details-bin: %v", err) - return - } - s := &spb.Status{} - if err := proto.Unmarshal(v, s); err != nil { - d.data.grpcErr = status.Errorf(codes.Internal, "transport: malformed grpc-status-details-bin: %v", err) - return - } - d.data.statusGen = status.FromProto(s) - case "grpc-timeout": - d.data.timeoutSet = true - var err error - if d.data.timeout, err = decodeTimeout(f.Value); err != nil { - d.data.grpcErr = status.Errorf(codes.Internal, "transport: malformed time-out: %v", err) - } - case ":path": - d.data.method = f.Value - case ":status": - code, err := strconv.Atoi(f.Value) - if err != nil { - d.data.httpErr = status.Errorf(codes.Internal, "transport: malformed http-status: %v", err) - return - } - d.data.httpStatus = &code - case "grpc-tags-bin": - v, err := decodeBinHeader(f.Value) - if err != nil { - d.data.grpcErr = status.Errorf(codes.Internal, "transport: malformed grpc-tags-bin: %v", err) - return - } - d.data.statsTags = v - d.addMetadata(f.Name, string(v)) - case "grpc-trace-bin": - v, err := decodeBinHeader(f.Value) - if err != nil { - d.data.grpcErr = status.Errorf(codes.Internal, "transport: malformed grpc-trace-bin: %v", err) - return - } - d.data.statsTrace = v - d.addMetadata(f.Name, string(v)) - case ":method": - d.data.httpMethod = f.Value - default: - if isReservedHeader(f.Name) && !isWhitelistedHeader(f.Name) { - break - } - v, err := decodeMetadataHeader(f.Name, f.Value) - if err != nil { - if logger.V(logLevel) { - logger.Errorf("Failed to decode metadata header (%q, %q): %v", f.Name, f.Value, err) - } - return - } - d.addMetadata(f.Name, v) + st := &spb.Status{} + if err = proto.Unmarshal(v, st); err != nil { + return nil, err } + return status.FromProto(st), nil } type timeoutUnit uint8 diff --git a/internal/transport/http_util_test.go b/internal/transport/http_util_test.go index 2205050acea..bbd53180471 100644 --- a/internal/transport/http_util_test.go +++ b/internal/transport/http_util_test.go @@ -23,9 +23,6 @@ import ( "reflect" "testing" "time" - - "golang.org/x/net/http2" - "golang.org/x/net/http2/hpack" ) func (s) TestTimeoutDecode(t *testing.T) { @@ -189,68 +186,6 @@ func (s) TestDecodeMetadataHeader(t *testing.T) { } } -func (s) TestDecodeHeaderH2ErrCode(t *testing.T) { - for _, test := range []struct { - name string - // input - metaHeaderFrame *http2.MetaHeadersFrame - serverSide bool - // output - wantCode http2.ErrCode - }{ - { - name: "valid header", - metaHeaderFrame: &http2.MetaHeadersFrame{Fields: []hpack.HeaderField{ - {Name: "content-type", Value: "application/grpc"}, - }}, - wantCode: http2.ErrCodeNo, - }, - { - name: "valid header serverSide", - metaHeaderFrame: &http2.MetaHeadersFrame{Fields: []hpack.HeaderField{ - {Name: "content-type", Value: "application/grpc"}, - }}, - serverSide: true, - wantCode: http2.ErrCodeNo, - }, - { - name: "invalid grpc status header field", - metaHeaderFrame: &http2.MetaHeadersFrame{Fields: []hpack.HeaderField{ - {Name: "content-type", Value: "application/grpc"}, - {Name: "grpc-status", Value: "xxxx"}, - }}, - wantCode: http2.ErrCodeProtocol, - }, - { - name: "invalid http content type", - metaHeaderFrame: &http2.MetaHeadersFrame{Fields: []hpack.HeaderField{ - {Name: "content-type", Value: "application/json"}, - }}, - wantCode: http2.ErrCodeProtocol, - }, - { - name: "http fallback and invalid http status", - metaHeaderFrame: &http2.MetaHeadersFrame{Fields: []hpack.HeaderField{ - // No content type provided then fallback into handling http error. - {Name: ":status", Value: "xxxx"}, - }}, - wantCode: http2.ErrCodeProtocol, - }, - { - name: "http2 frame size exceeds", - metaHeaderFrame: &http2.MetaHeadersFrame{Fields: nil, Truncated: true}, - wantCode: http2.ErrCodeFrameSize, - }, - } { - t.Run(test.name, func(t *testing.T) { - state := &decodeState{serverSide: test.serverSide} - if h2code, _ := state.decodeHeader(test.metaHeaderFrame); h2code != test.wantCode { - t.Fatalf("decodeState.decodeHeader(%v) = %v, want %v", test.metaHeaderFrame, h2code, test.wantCode) - } - }) - } -} - func (s) TestParseDialTarget(t *testing.T) { for _, test := range []struct { target, wantNet, wantAddr string diff --git a/internal/transport/keepalive_test.go b/internal/transport/keepalive_test.go index c8f177fecf1..c4021925f32 100644 --- a/internal/transport/keepalive_test.go +++ b/internal/transport/keepalive_test.go @@ -24,6 +24,7 @@ package transport import ( "context" + "fmt" "io" "net" "testing" @@ -47,7 +48,7 @@ func (s) TestMaxConnectionIdle(t *testing.T) { } server, client, cancel := setUpWithOptions(t, 0, serverConfig, suspended, ConnectOptions{}) defer func() { - client.Close() + client.Close(fmt.Errorf("closed manually by test")) server.stop() cancel() }() @@ -68,7 +69,7 @@ func (s) TestMaxConnectionIdle(t *testing.T) { if !timeout.Stop() { <-timeout.C } - if reason := client.GetGoAwayReason(); reason != GoAwayNoReason { + if reason, _ := client.GetGoAwayReason(); reason != GoAwayNoReason { t.Fatalf("GoAwayReason is %v, want %v", reason, GoAwayNoReason) } case <-timeout.C: @@ -86,7 +87,7 @@ func (s) TestMaxConnectionIdleBusyClient(t *testing.T) { } server, client, cancel := setUpWithOptions(t, 0, serverConfig, suspended, ConnectOptions{}) defer func() { - client.Close() + client.Close(fmt.Errorf("closed manually by test")) server.stop() cancel() }() @@ -122,7 +123,7 @@ func (s) TestMaxConnectionAge(t *testing.T) { } server, client, cancel := setUpWithOptions(t, 0, serverConfig, suspended, ConnectOptions{}) defer func() { - client.Close() + client.Close(fmt.Errorf("closed manually by test")) server.stop() cancel() }() @@ -142,7 +143,7 @@ func (s) TestMaxConnectionAge(t *testing.T) { if !timeout.Stop() { <-timeout.C } - if reason := client.GetGoAwayReason(); reason != GoAwayNoReason { + if reason, _ := client.GetGoAwayReason(); reason != GoAwayNoReason { t.Fatalf("GoAwayReason is %v, want %v", reason, GoAwayNoReason) } case <-timeout.C: @@ -169,7 +170,7 @@ func (s) TestKeepaliveServerClosesUnresponsiveClient(t *testing.T) { } server, client, cancel := setUpWithOptions(t, 0, serverConfig, suspended, ConnectOptions{}) defer func() { - client.Close() + client.Close(fmt.Errorf("closed manually by test")) server.stop() cancel() }() @@ -192,7 +193,7 @@ func (s) TestKeepaliveServerClosesUnresponsiveClient(t *testing.T) { // We read from the net.Conn till we get an error, which is expected when // the server closes the connection as part of the keepalive logic. - errCh := make(chan error) + errCh := make(chan error, 1) go func() { b := make([]byte, 24) for { @@ -228,7 +229,7 @@ func (s) TestKeepaliveServerWithResponsiveClient(t *testing.T) { } server, client, cancel := setUpWithOptions(t, 0, serverConfig, suspended, ConnectOptions{}) defer func() { - client.Close() + client.Close(fmt.Errorf("closed manually by test")) server.stop() cancel() }() @@ -257,7 +258,7 @@ func (s) TestKeepaliveClientClosesUnresponsiveServer(t *testing.T) { PermitWithoutStream: true, }}, connCh) defer cancel() - defer client.Close() + defer client.Close(fmt.Errorf("closed manually by test")) conn, ok := <-connCh if !ok { @@ -288,7 +289,7 @@ func (s) TestKeepaliveClientOpenWithUnresponsiveServer(t *testing.T) { Timeout: 1 * time.Second, }}, connCh) defer cancel() - defer client.Close() + defer client.Close(fmt.Errorf("closed manually by test")) conn, ok := <-connCh if !ok { @@ -317,7 +318,7 @@ func (s) TestKeepaliveClientClosesWithActiveStreams(t *testing.T) { Timeout: 1 * time.Second, }}, connCh) defer cancel() - defer client.Close() + defer client.Close(fmt.Errorf("closed manually by test")) conn, ok := <-connCh if !ok { @@ -345,14 +346,21 @@ func (s) TestKeepaliveClientClosesWithActiveStreams(t *testing.T) { // responds to keepalive pings, and makes sure than a client transport stays // healthy without any active streams. func (s) TestKeepaliveClientStaysHealthyWithResponsiveServer(t *testing.T) { - server, client, cancel := setUpWithOptions(t, 0, &ServerConfig{}, normal, ConnectOptions{ - KeepaliveParams: keepalive.ClientParameters{ - Time: 1 * time.Second, - Timeout: 1 * time.Second, - PermitWithoutStream: true, - }}) + server, client, cancel := setUpWithOptions(t, 0, + &ServerConfig{ + KeepalivePolicy: keepalive.EnforcementPolicy{ + PermitWithoutStream: true, + }, + }, + normal, + ConnectOptions{ + KeepaliveParams: keepalive.ClientParameters{ + Time: 1 * time.Second, + Timeout: 1 * time.Second, + PermitWithoutStream: true, + }}) defer func() { - client.Close() + client.Close(fmt.Errorf("closed manually by test")) server.stop() cancel() }() @@ -391,7 +399,7 @@ func (s) TestKeepaliveClientFrequency(t *testing.T) { } server, client, cancel := setUpWithOptions(t, 0, serverConfig, normal, clientOptions) defer func() { - client.Close() + client.Close(fmt.Errorf("closed manually by test")) server.stop() cancel() }() @@ -402,7 +410,7 @@ func (s) TestKeepaliveClientFrequency(t *testing.T) { if !timeout.Stop() { <-timeout.C } - if reason := client.GetGoAwayReason(); reason != GoAwayTooManyPings { + if reason, _ := client.GetGoAwayReason(); reason != GoAwayTooManyPings { t.Fatalf("GoAwayReason is %v, want %v", reason, GoAwayTooManyPings) } case <-timeout.C: @@ -436,7 +444,7 @@ func (s) TestKeepaliveServerEnforcementWithAbusiveClientNoRPC(t *testing.T) { } server, client, cancel := setUpWithOptions(t, 0, serverConfig, normal, clientOptions) defer func() { - client.Close() + client.Close(fmt.Errorf("closed manually by test")) server.stop() cancel() }() @@ -447,7 +455,7 @@ func (s) TestKeepaliveServerEnforcementWithAbusiveClientNoRPC(t *testing.T) { if !timeout.Stop() { <-timeout.C } - if reason := client.GetGoAwayReason(); reason != GoAwayTooManyPings { + if reason, _ := client.GetGoAwayReason(); reason != GoAwayTooManyPings { t.Fatalf("GoAwayReason is %v, want %v", reason, GoAwayTooManyPings) } case <-timeout.C: @@ -480,7 +488,7 @@ func (s) TestKeepaliveServerEnforcementWithAbusiveClientWithRPC(t *testing.T) { } server, client, cancel := setUpWithOptions(t, 0, serverConfig, suspended, clientOptions) defer func() { - client.Close() + client.Close(fmt.Errorf("closed manually by test")) server.stop() cancel() }() @@ -497,7 +505,7 @@ func (s) TestKeepaliveServerEnforcementWithAbusiveClientWithRPC(t *testing.T) { if !timeout.Stop() { <-timeout.C } - if reason := client.GetGoAwayReason(); reason != GoAwayTooManyPings { + if reason, _ := client.GetGoAwayReason(); reason != GoAwayTooManyPings { t.Fatalf("GoAwayReason is %v, want %v", reason, GoAwayTooManyPings) } case <-timeout.C: @@ -530,7 +538,7 @@ func (s) TestKeepaliveServerEnforcementWithObeyingClientNoRPC(t *testing.T) { } server, client, cancel := setUpWithOptions(t, 0, serverConfig, normal, clientOptions) defer func() { - client.Close() + client.Close(fmt.Errorf("closed manually by test")) server.stop() cancel() }() @@ -564,7 +572,7 @@ func (s) TestKeepaliveServerEnforcementWithObeyingClientWithRPC(t *testing.T) { } server, client, cancel := setUpWithOptions(t, 0, serverConfig, suspended, clientOptions) defer func() { - client.Close() + client.Close(fmt.Errorf("closed manually by test")) server.stop() cancel() }() @@ -604,7 +612,7 @@ func (s) TestKeepaliveServerEnforcementWithDormantKeepaliveOnClient(t *testing.T } server, client, cancel := setUpWithOptions(t, 0, serverConfig, normal, clientOptions) defer func() { - client.Close() + client.Close(fmt.Errorf("closed manually by test")) server.stop() cancel() }() @@ -658,7 +666,7 @@ func (s) TestTCPUserTimeout(t *testing.T) { }, ) defer func() { - client.Close() + client.Close(fmt.Errorf("closed manually by test")) server.stop() cancel() }() diff --git a/internal/transport/networktype/networktype.go b/internal/transport/networktype/networktype.go index 96967428b51..7bb53cff101 100644 --- a/internal/transport/networktype/networktype.go +++ b/internal/transport/networktype/networktype.go @@ -17,7 +17,7 @@ */ // Package networktype declares the network type to be used in the default -// dailer. Attribute of a resolver.Address. +// dialer. Attribute of a resolver.Address. package networktype import ( diff --git a/internal/transport/proxy_test.go b/internal/transport/proxy_test.go index a2f1aa43854..404354a19db 100644 --- a/internal/transport/proxy_test.go +++ b/internal/transport/proxy_test.go @@ -1,3 +1,4 @@ +//go:build !race // +build !race /* @@ -119,7 +120,7 @@ func testHTTPConnect(t *testing.T, proxyURLModify func(*url.URL) *url.URL, proxy msg := []byte{4, 3, 5, 2} recvBuf := make([]byte, len(msg)) - done := make(chan error) + done := make(chan error, 1) go func() { in, err := blis.Accept() if err != nil { diff --git a/internal/transport/transport.go b/internal/transport/transport.go index 5cf7c5f80fe..d3bf65b2bdf 100644 --- a/internal/transport/transport.go +++ b/internal/transport/transport.go @@ -30,6 +30,7 @@ import ( "net" "sync" "sync/atomic" + "time" "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials" @@ -518,7 +519,8 @@ const ( // ServerConfig consists of all the configurations to establish a server transport. type ServerConfig struct { MaxStreams uint32 - AuthInfo credentials.AuthInfo + ConnectionTimeout time.Duration + Credentials credentials.TransportCredentials InTapHandle tap.ServerInHandle StatsHandler stats.Handler KeepaliveParams keepalive.ServerParameters @@ -532,12 +534,6 @@ type ServerConfig struct { HeaderTableSize *uint32 } -// NewServerTransport creates a ServerTransport with conn or non-nil error -// if it fails. -func NewServerTransport(protocol string, conn net.Conn, config *ServerConfig) (ServerTransport, error) { - return newHTTP2Server(conn, config) -} - // ConnectOptions covers all relevant options for communicating with the server. type ConnectOptions struct { // UserAgent is the application user agent. @@ -622,7 +618,7 @@ type ClientTransport interface { // Close tears down this transport. Once it returns, the transport // should not be accessed any more. The caller must make sure this // is called only once. - Close() error + Close(err error) // GracefulClose starts to tear down the transport: the transport will stop // accepting new RPCs and NewStream will return error. Once all streams are @@ -656,8 +652,9 @@ type ClientTransport interface { // HTTP/2). GoAway() <-chan struct{} - // GetGoAwayReason returns the reason why GoAway frame was received. - GetGoAwayReason() GoAwayReason + // GetGoAwayReason returns the reason why GoAway frame was received, along + // with a human readable string with debug info. + GetGoAwayReason() (GoAwayReason, string) // RemoteAddr returns the remote network address. RemoteAddr() net.Addr @@ -693,7 +690,7 @@ type ServerTransport interface { // Close tears down the transport. Once it is called, the transport // should not be accessed any more. All the pending streams and their // handlers will be terminated asynchronously. - Close() error + Close() // RemoteAddr returns the remote network address. RemoteAddr() net.Addr diff --git a/internal/transport/transport_test.go b/internal/transport/transport_test.go index 1d8d3ed355d..4e561a73c4c 100644 --- a/internal/transport/transport_test.go +++ b/internal/transport/transport_test.go @@ -323,7 +323,7 @@ func (s *server) start(t *testing.T, port int, serverConfig *ServerConfig, ht hT if err != nil { return } - transport, err := NewServerTransport("http2", conn, serverConfig) + transport, err := NewServerTransport(conn, serverConfig) if err != nil { return } @@ -481,7 +481,7 @@ func (s) TestInflightStreamClosing(t *testing.T) { server, client, cancel := setUpWithOptions(t, 0, serverConfig, suspended, ConnectOptions{}) defer cancel() defer server.stop() - defer client.Close() + defer client.Close(fmt.Errorf("closed manually by test")) ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() @@ -550,7 +550,7 @@ func (s) TestClientSendAndReceive(t *testing.T) { if recvErr != io.EOF { t.Fatalf("Error: %v; want ", recvErr) } - ct.Close() + ct.Close(fmt.Errorf("closed manually by test")) server.stop() } @@ -560,7 +560,7 @@ func (s) TestClientErrorNotify(t *testing.T) { go server.stop() // ct.reader should detect the error and activate ct.Error(). <-ct.Error() - ct.Close() + ct.Close(fmt.Errorf("closed manually by test")) } func performOneRPC(ct ClientTransport) { @@ -597,7 +597,7 @@ func (s) TestClientMix(t *testing.T) { }(s) go func(ct ClientTransport) { <-ct.Error() - ct.Close() + ct.Close(fmt.Errorf("closed manually by test")) }(ct) for i := 0; i < 1000; i++ { time.Sleep(10 * time.Millisecond) @@ -636,7 +636,7 @@ func (s) TestLargeMessage(t *testing.T) { }() } wg.Wait() - ct.Close() + ct.Close(fmt.Errorf("closed manually by test")) server.stop() } @@ -653,7 +653,7 @@ func (s) TestLargeMessageWithDelayRead(t *testing.T) { server, ct, cancel := setUpWithOptions(t, 0, sc, delayRead, co) defer cancel() defer server.stop() - defer ct.Close() + defer ct.Close(fmt.Errorf("closed manually by test")) server.mu.Lock() ready := server.ready server.mu.Unlock() @@ -780,7 +780,7 @@ func (s) TestGracefulClose(t *testing.T) { go func() { defer wg.Done() str, err := ct.NewStream(ctx, &CallHdr{}) - if err == ErrConnClosing { + if err != nil && err.(*NewStreamError).Err == ErrConnClosing { return } else if err != nil { t.Errorf("_.NewStream(_, _) = _, %v, want _, %v", err, ErrConnClosing) @@ -831,7 +831,7 @@ func (s) TestLargeMessageSuspension(t *testing.T) { if _, err := s.Read(make([]byte, 8)); err.Error() != expectedErr.Error() { t.Fatalf("Read got %v of type %T, want %v", err, err, expectedErr) } - ct.Close() + ct.Close(fmt.Errorf("closed manually by test")) server.stop() } @@ -841,7 +841,7 @@ func (s) TestMaxStreams(t *testing.T) { } server, ct, cancel := setUpWithOptions(t, 0, serverConfig, suspended, ConnectOptions{}) defer cancel() - defer ct.Close() + defer ct.Close(fmt.Errorf("closed manually by test")) defer server.stop() callHdr := &CallHdr{ Host: "localhost", @@ -901,7 +901,7 @@ func (s) TestMaxStreams(t *testing.T) { // Close the first stream created so that the new stream can finally be created. ct.CloseStream(s, nil) <-done - ct.Close() + ct.Close(fmt.Errorf("closed manually by test")) <-ct.writerDone if ct.maxConcurrentStreams != 1 { t.Fatalf("ct.maxConcurrentStreams: %d, want 1", ct.maxConcurrentStreams) @@ -960,7 +960,7 @@ func (s) TestServerContextCanceledOnClosedConnection(t *testing.T) { sc.mu.Unlock() break } - ct.Close() + ct.Close(fmt.Errorf("closed manually by test")) select { case <-ss.Context().Done(): if ss.Context().Err() != context.Canceled { @@ -980,7 +980,7 @@ func (s) TestClientConnDecoupledFromApplicationRead(t *testing.T) { server, client, cancel := setUpWithOptions(t, 0, &ServerConfig{}, notifyCall, connectOptions) defer cancel() defer server.stop() - defer client.Close() + defer client.Close(fmt.Errorf("closed manually by test")) waitWhileTrue(t, func() (bool, error) { server.mu.Lock() @@ -1069,7 +1069,7 @@ func (s) TestServerConnDecoupledFromApplicationRead(t *testing.T) { server, client, cancel := setUpWithOptions(t, 0, serverConfig, suspended, ConnectOptions{}) defer cancel() defer server.stop() - defer client.Close() + defer client.Close(fmt.Errorf("closed manually by test")) waitWhileTrue(t, func() (bool, error) { server.mu.Lock() defer server.mu.Unlock() @@ -1302,7 +1302,7 @@ func (s) TestClientWithMisbehavedServer(t *testing.T) { if err != nil { t.Fatalf("Error while creating client transport: %v", err) } - defer ct.Close() + defer ct.Close(fmt.Errorf("closed manually by test")) str, err := ct.NewStream(connectCtx, &CallHdr{}) if err != nil { t.Fatalf("Error while creating stream: %v", err) @@ -1345,7 +1345,7 @@ func (s) TestEncodingRequiredStatus(t *testing.T) { if !testutils.StatusErrEqual(s.Status().Err(), encodingTestStatus.Err()) { t.Fatalf("stream with status %v, want %v", s.Status(), encodingTestStatus) } - ct.Close() + ct.Close(fmt.Errorf("closed manually by test")) server.stop() } @@ -1367,7 +1367,7 @@ func (s) TestInvalidHeaderField(t *testing.T) { if se, ok := status.FromError(err); !ok || se.Code() != codes.Internal || !strings.Contains(err.Error(), expectedInvalidHeaderField) { t.Fatalf("Read got error %v, want error with code %s and contains %q", err, codes.Internal, expectedInvalidHeaderField) } - ct.Close() + ct.Close(fmt.Errorf("closed manually by test")) server.stop() } @@ -1375,7 +1375,7 @@ func (s) TestHeaderChanClosedAfterReceivingAnInvalidHeader(t *testing.T) { server, ct, cancel := setUp(t, 0, math.MaxUint32, invalidHeaderField) defer cancel() defer server.stop() - defer ct.Close() + defer ct.Close(fmt.Errorf("closed manually by test")) ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() s, err := ct.NewStream(ctx, &CallHdr{Host: "localhost", Method: "foo"}) @@ -1481,7 +1481,7 @@ func testFlowControlAccountCheck(t *testing.T, msgSize int, wc windowSizeConfig) server, client, cancel := setUpWithOptions(t, 0, sc, pingpong, co) defer cancel() defer server.stop() - defer client.Close() + defer client.Close(fmt.Errorf("closed manually by test")) waitWhileTrue(t, func() (bool, error) { server.mu.Lock() defer server.mu.Unlock() @@ -1563,7 +1563,7 @@ func testFlowControlAccountCheck(t *testing.T, msgSize int, wc windowSizeConfig) } // Close down both server and client so that their internals can be read without data // races. - client.Close() + client.Close(fmt.Errorf("closed manually by test")) st.Close() <-st.readerDone <-st.writerDone @@ -1663,81 +1663,291 @@ func (s) TestReadGivesSameErrorAfterAnyErrorOccurs(t *testing.T) { } } -// If the client sends an HTTP/2 request with a :method header with a value other than POST, as specified in -// the gRPC over HTTP/2 specification, the server should close the stream. -func (s) TestServerWithClientSendingWrongMethod(t *testing.T) { - server := setUpServerOnly(t, 0, &ServerConfig{}, suspended) - defer server.stop() - // Create a client directly to not couple what you can send to API of http2_client.go. - mconn, err := net.Dial("tcp", server.lis.Addr().String()) - if err != nil { - t.Fatalf("Client failed to dial: %v", err) +// TestHeadersCausingStreamError tests headers that should cause a stream protocol +// error, which would end up with a RST_STREAM being sent to the client and also +// the server closing the stream. +func (s) TestHeadersCausingStreamError(t *testing.T) { + tests := []struct { + name string + headers []struct { + name string + values []string + } + }{ + // If the client sends an HTTP/2 request with a :method header with a + // value other than POST, as specified in the gRPC over HTTP/2 + // specification, the server should close the stream. + { + name: "Client Sending Wrong Method", + headers: []struct { + name string + values []string + }{ + {name: ":method", values: []string{"PUT"}}, + {name: ":path", values: []string{"foo"}}, + {name: ":authority", values: []string{"localhost"}}, + {name: "content-type", values: []string{"application/grpc"}}, + }, + }, + // "Transports must consider requests containing the Connection header + // as malformed" - A41 Malformed requests map to a stream error of type + // PROTOCOL_ERROR. + { + name: "Connection header present", + headers: []struct { + name string + values []string + }{ + {name: ":method", values: []string{"POST"}}, + {name: ":path", values: []string{"foo"}}, + {name: ":authority", values: []string{"localhost"}}, + {name: "content-type", values: []string{"application/grpc"}}, + {name: "connection", values: []string{"not-supported"}}, + }, + }, + // multiple :authority or multiple Host headers would make the eventual + // :authority ambiguous as per A41. Since these headers won't have a + // content-type that corresponds to a grpc-client, the server should + // simply write a RST_STREAM to the wire. + { + // Note: multiple authority headers are handled by the framer + // itself, which will cause a stream error. Thus, it will never get + // to operateHeaders with the check in operateHeaders for stream + // error, but the server transport will still send a stream error. + name: "Multiple authority headers", + headers: []struct { + name string + values []string + }{ + {name: ":method", values: []string{"POST"}}, + {name: ":path", values: []string{"foo"}}, + {name: ":authority", values: []string{"localhost", "localhost2"}}, + {name: "host", values: []string{"localhost"}}, + }, + }, } - defer mconn.Close() + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + server := setUpServerOnly(t, 0, &ServerConfig{}, suspended) + defer server.stop() + // Create a client directly to not tie what you can send to API of + // http2_client.go (i.e. control headers being sent). + mconn, err := net.Dial("tcp", server.lis.Addr().String()) + if err != nil { + t.Fatalf("Client failed to dial: %v", err) + } + defer mconn.Close() - if n, err := mconn.Write(clientPreface); err != nil || n != len(clientPreface) { - t.Fatalf("mconn.Write(clientPreface) = %d, %v, want %d, ", n, err, len(clientPreface)) + if n, err := mconn.Write(clientPreface); err != nil || n != len(clientPreface) { + t.Fatalf("mconn.Write(clientPreface) = %d, %v, want %d, ", n, err, len(clientPreface)) + } + + framer := http2.NewFramer(mconn, mconn) + if err := framer.WriteSettings(); err != nil { + t.Fatalf("Error while writing settings: %v", err) + } + + // result chan indicates that reader received a RSTStream from server. + // An error will be passed on it if any other frame is received. + result := testutils.NewChannel() + + // Launch a reader goroutine. + go func() { + for { + frame, err := framer.ReadFrame() + if err != nil { + return + } + switch frame := frame.(type) { + case *http2.SettingsFrame: + // Do nothing. A settings frame is expected from server preface. + case *http2.RSTStreamFrame: + if frame.Header().StreamID != 1 || http2.ErrCode(frame.ErrCode) != http2.ErrCodeProtocol { + // Client only created a single stream, so RST Stream should be for that single stream. + result.Send(fmt.Errorf("RST stream received with streamID: %d and code %v, want streamID: 1 and code: http.ErrCodeFlowControl", frame.Header().StreamID, http2.ErrCode(frame.ErrCode))) + } + // Records that client successfully received RST Stream frame. + result.Send(nil) + return + default: + // The server should send nothing but a single RST Stream frame. + result.Send(errors.New("the client received a frame other than RST Stream")) + } + } + }() + + var buf bytes.Buffer + henc := hpack.NewEncoder(&buf) + + // Needs to build headers deterministically to conform to gRPC over + // HTTP/2 spec. + for _, header := range test.headers { + for _, value := range header.values { + if err := henc.WriteField(hpack.HeaderField{Name: header.name, Value: value}); err != nil { + t.Fatalf("Error while encoding header: %v", err) + } + } + } + + if err := framer.WriteHeaders(http2.HeadersFrameParam{StreamID: 1, BlockFragment: buf.Bytes(), EndHeaders: true}); err != nil { + t.Fatalf("Error while writing headers: %v", err) + } + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + r, err := result.Receive(ctx) + if err != nil { + t.Fatalf("Error receiving from channel: %v", err) + } + if r != nil { + t.Fatalf("want nil, got %v", r) + } + }) } +} - framer := http2.NewFramer(mconn, mconn) - if err := framer.WriteSettings(); err != nil { - t.Fatalf("Error while writing settings: %v", err) +// TestHeadersMultipleHosts tests that a request with multiple hosts gets +// rejected with HTTP Status 400 and gRPC status Internal, regardless of whether +// the client is speaking gRPC or not. +func (s) TestHeadersMultipleHosts(t *testing.T) { + tests := []struct { + name string + headers []struct { + name string + values []string + } + }{ + // Note: multiple authority headers are handled by the framer itself, + // which will cause a stream error. Thus, it will never get to + // operateHeaders with the check in operateHeaders for possible grpc-status sent back. + + // multiple :authority or multiple Host headers would make the eventual + // :authority ambiguous as per A41. This takes precedence even over the + // fact a request is non grpc. All of these requests should be rejected + // with grpc-status Internal. + { + name: "Multiple host headers non grpc", + headers: []struct { + name string + values []string + }{ + {name: ":method", values: []string{"POST"}}, + {name: ":path", values: []string{"foo"}}, + {name: ":authority", values: []string{"localhost"}}, + {name: "host", values: []string{"localhost", "localhost2"}}, + }, + }, + { + name: "Multiple host headers grpc", + headers: []struct { + name string + values []string + }{ + {name: ":method", values: []string{"POST"}}, + {name: ":path", values: []string{"foo"}}, + {name: ":authority", values: []string{"localhost"}}, + {name: "content-type", values: []string{"application/grpc"}}, + {name: "host", values: []string{"localhost", "localhost2"}}, + }, + }, } + for _, test := range tests { + server := setUpServerOnly(t, 0, &ServerConfig{}, suspended) + defer server.stop() + // Create a client directly to not tie what you can send to API of + // http2_client.go (i.e. control headers being sent). + mconn, err := net.Dial("tcp", server.lis.Addr().String()) + if err != nil { + t.Fatalf("Client failed to dial: %v", err) + } + defer mconn.Close() - // success chan indicates that reader received a RSTStream from server. - // An error will be passed on it if any other frame is received. - success := testutils.NewChannel() + if n, err := mconn.Write(clientPreface); err != nil || n != len(clientPreface) { + t.Fatalf("mconn.Write(clientPreface) = %d, %v, want %d, ", n, err, len(clientPreface)) + } - // Launch a reader goroutine. - go func() { - for { - frame, err := framer.ReadFrame() - if err != nil { - return + framer := http2.NewFramer(mconn, mconn) + framer.ReadMetaHeaders = hpack.NewDecoder(4096, nil) + if err := framer.WriteSettings(); err != nil { + t.Fatalf("Error while writing settings: %v", err) + } + + // result chan indicates that reader received a Headers Frame with + // desired grpc status and message from server. An error will be passed + // on it if any other frame is received. + result := testutils.NewChannel() + + // Launch a reader goroutine. + go func() { + for { + frame, err := framer.ReadFrame() + if err != nil { + return + } + switch frame := frame.(type) { + case *http2.SettingsFrame: + // Do nothing. A settings frame is expected from server preface. + case *http2.MetaHeadersFrame: + var status, grpcStatus, grpcMessage string + for _, header := range frame.Fields { + if header.Name == ":status" { + status = header.Value + } + if header.Name == "grpc-status" { + grpcStatus = header.Value + } + if header.Name == "grpc-message" { + grpcMessage = header.Value + } + } + if status != "400" { + result.Send(fmt.Errorf("incorrect HTTP Status got %v, want 200", status)) + return + } + if grpcStatus != "13" { // grpc status code internal + result.Send(fmt.Errorf("incorrect gRPC Status got %v, want 13", grpcStatus)) + return + } + if !strings.Contains(grpcMessage, "both must only have 1 value as per HTTP/2 spec") { + result.Send(fmt.Errorf("incorrect gRPC message")) + return + } + + // Records that client successfully received a HeadersFrame + // with expected Trailers-Only response. + result.Send(nil) + return + default: + // The server should send nothing but a single Settings and Headers frame. + result.Send(errors.New("the client received a frame other than Settings or Headers")) + } } - switch frame := frame.(type) { - case *http2.SettingsFrame: - // Do nothing. A settings frame is expected from server preface. - case *http2.RSTStreamFrame: - if frame.Header().StreamID != 1 || http2.ErrCode(frame.ErrCode) != http2.ErrCodeProtocol { - // Client only created a single stream, so RST Stream should be for that single stream. - t.Errorf("RST stream received with streamID: %d and code %v, want streamID: 1 and code: http.ErrCodeFlowControl", frame.Header().StreamID, http2.ErrCode(frame.ErrCode)) + }() + + var buf bytes.Buffer + henc := hpack.NewEncoder(&buf) + + // Needs to build headers deterministically to conform to gRPC over + // HTTP/2 spec. + for _, header := range test.headers { + for _, value := range header.values { + if err := henc.WriteField(hpack.HeaderField{Name: header.name, Value: value}); err != nil { + t.Fatalf("Error while encoding header: %v", err) } - // Records that client successfully received RST Stream frame. - success.Send(nil) - return - default: - // The server should send nothing but a single RST Stream frame. - success.Send(errors.New("The client received a frame other than RST Stream")) } } - }() - - // Done with HTTP/2 setup - now create a stream with a bad method header. - var buf bytes.Buffer - henc := hpack.NewEncoder(&buf) - // Method is required to be POST in a gRPC call. - if err := henc.WriteField(hpack.HeaderField{Name: ":method", Value: "PUT"}); err != nil { - t.Fatalf("Error while encoding header: %v", err) - } - // Have the rest of the headers be ok and within the gRPC over HTTP/2 spec. - if err := henc.WriteField(hpack.HeaderField{Name: ":path", Value: "foo"}); err != nil { - t.Fatalf("Error while encoding header: %v", err) - } - if err := henc.WriteField(hpack.HeaderField{Name: ":authority", Value: "localhost"}); err != nil { - t.Fatalf("Error while encoding header: %v", err) - } - if err := henc.WriteField(hpack.HeaderField{Name: "content-type", Value: "application/grpc"}); err != nil { - t.Fatalf("Error while encoding header: %v", err) - } - if err := framer.WriteHeaders(http2.HeadersFrameParam{StreamID: 1, BlockFragment: buf.Bytes(), EndHeaders: true}); err != nil { - t.Fatalf("Error while writing headers: %v", err) - } - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() - if e, err := success.Receive(ctx); e != nil || err != nil { - t.Fatalf("Error in frame server should send: %v. Error receiving from channel: %v", e, err) + if err := framer.WriteHeaders(http2.HeadersFrameParam{StreamID: 1, BlockFragment: buf.Bytes(), EndHeaders: true}); err != nil { + t.Fatalf("Error while writing headers: %v", err) + } + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + r, err := result.Receive(ctx) + if err != nil { + t.Fatalf("Error receiving from channel: %v", err) + } + if r != nil { + t.Fatalf("want nil, got %v", r) + } } } @@ -1762,7 +1972,7 @@ func runPingPongTest(t *testing.T, msgSize int) { server, client, cancel := setUp(t, 0, 0, pingpong) defer cancel() defer server.stop() - defer client.Close() + defer client.Close(fmt.Errorf("closed manually by test")) waitWhileTrue(t, func() (bool, error) { server.mu.Lock() defer server.mu.Unlock() @@ -1850,7 +2060,7 @@ func (s) TestHeaderTblSize(t *testing.T) { server, ct, cancel := setUp(t, 0, math.MaxUint32, normal) defer cancel() - defer ct.Close() + defer ct.Close(fmt.Errorf("closed manually by test")) defer server.stop() ctx, ctxCancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer ctxCancel() @@ -1969,10 +2179,186 @@ func (s) TestClientHandshakeInfo(t *testing.T) { if err != nil { t.Fatalf("NewClientTransport(): %v", err) } - defer tr.Close() + defer tr.Close(fmt.Errorf("closed manually by test")) wantAttr := attributes.New(testAttrKey, testAttrVal) if gotAttr := creds.attr; !cmp.Equal(gotAttr, wantAttr, cmp.AllowUnexported(attributes.Attributes{})) { t.Fatalf("received attributes %v in creds, want %v", gotAttr, wantAttr) } } + +func (s) TestClientDecodeHeaderStatusErr(t *testing.T) { + testStream := func() *Stream { + return &Stream{ + done: make(chan struct{}), + headerChan: make(chan struct{}), + buf: &recvBuffer{ + c: make(chan recvMsg), + mu: sync.Mutex{}, + }, + } + } + + testClient := func(ts *Stream) *http2Client { + return &http2Client{ + mu: sync.Mutex{}, + activeStreams: map[uint32]*Stream{ + 0: ts, + }, + controlBuf: &controlBuffer{ + ch: make(chan struct{}), + done: make(chan struct{}), + list: &itemList{}, + }, + } + } + + for _, test := range []struct { + name string + // input + metaHeaderFrame *http2.MetaHeadersFrame + // output + wantStatus *status.Status + }{ + { + name: "valid header", + metaHeaderFrame: &http2.MetaHeadersFrame{ + Fields: []hpack.HeaderField{ + {Name: "content-type", Value: "application/grpc"}, + {Name: "grpc-status", Value: "0"}, + {Name: ":status", Value: "200"}, + }, + }, + // no error + wantStatus: status.New(codes.OK, ""), + }, + { + name: "missing content-type header", + metaHeaderFrame: &http2.MetaHeadersFrame{ + Fields: []hpack.HeaderField{ + {Name: "grpc-status", Value: "0"}, + {Name: ":status", Value: "200"}, + }, + }, + wantStatus: status.New( + codes.Unknown, + "malformed header: missing HTTP content-type", + ), + }, + { + name: "invalid grpc status header field", + metaHeaderFrame: &http2.MetaHeadersFrame{ + Fields: []hpack.HeaderField{ + {Name: "content-type", Value: "application/grpc"}, + {Name: "grpc-status", Value: "xxxx"}, + {Name: ":status", Value: "200"}, + }, + }, + wantStatus: status.New( + codes.Internal, + "transport: malformed grpc-status: strconv.ParseInt: parsing \"xxxx\": invalid syntax", + ), + }, + { + name: "invalid http content type", + metaHeaderFrame: &http2.MetaHeadersFrame{ + Fields: []hpack.HeaderField{ + {Name: "content-type", Value: "application/json"}, + }, + }, + wantStatus: status.New( + codes.Internal, + "malformed header: missing HTTP status; transport: received unexpected content-type \"application/json\"", + ), + }, + { + name: "http fallback and invalid http status", + metaHeaderFrame: &http2.MetaHeadersFrame{ + Fields: []hpack.HeaderField{ + // No content type provided then fallback into handling http error. + {Name: ":status", Value: "xxxx"}, + }, + }, + wantStatus: status.New( + codes.Internal, + "transport: malformed http-status: strconv.ParseInt: parsing \"xxxx\": invalid syntax", + ), + }, + { + name: "http2 frame size exceeds", + metaHeaderFrame: &http2.MetaHeadersFrame{ + Fields: nil, + Truncated: true, + }, + wantStatus: status.New( + codes.Internal, + "peer header list size exceeded limit", + ), + }, + { + name: "bad status in grpc mode", + metaHeaderFrame: &http2.MetaHeadersFrame{ + Fields: []hpack.HeaderField{ + {Name: "content-type", Value: "application/grpc"}, + {Name: "grpc-status", Value: "0"}, + {Name: ":status", Value: "504"}, + }, + }, + wantStatus: status.New( + codes.Unavailable, + "unexpected HTTP status code received from server: 504 (Gateway Timeout)", + ), + }, + { + name: "missing http status", + metaHeaderFrame: &http2.MetaHeadersFrame{ + Fields: []hpack.HeaderField{ + {Name: "content-type", Value: "application/grpc"}, + }, + }, + wantStatus: status.New( + codes.Internal, + "malformed header: missing HTTP status", + ), + }, + } { + + t.Run(test.name, func(t *testing.T) { + ts := testStream() + s := testClient(ts) + + test.metaHeaderFrame.HeadersFrame = &http2.HeadersFrame{ + FrameHeader: http2.FrameHeader{ + StreamID: 0, + }, + } + + s.operateHeaders(test.metaHeaderFrame) + + got := ts.status + want := test.wantStatus + if got.Code() != want.Code() || got.Message() != want.Message() { + t.Fatalf("operateHeaders(%v); status = \ngot: %s\nwant: %s", test.metaHeaderFrame, got, want) + } + }) + t.Run(fmt.Sprintf("%s-end_stream", test.name), func(t *testing.T) { + ts := testStream() + s := testClient(ts) + + test.metaHeaderFrame.HeadersFrame = &http2.HeadersFrame{ + FrameHeader: http2.FrameHeader{ + StreamID: 0, + Flags: http2.FlagHeadersEndStream, + }, + } + + s.operateHeaders(test.metaHeaderFrame) + + got := ts.status + want := test.wantStatus + if got.Code() != want.Code() || got.Message() != want.Message() { + t.Fatalf("operateHeaders(%v); status = \ngot: %s\nwant: %s", test.metaHeaderFrame, got, want) + } + }) + } +} diff --git a/internal/xds/bootstrap.go b/internal/xds/bootstrap.go new file mode 100644 index 00000000000..1d74ab46a11 --- /dev/null +++ b/internal/xds/bootstrap.go @@ -0,0 +1,147 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// Package xds contains types that need to be shared between code under +// google.golang.org/grpc/xds/... and the rest of gRPC. +package xds + +import ( + "encoding/json" + "fmt" + "io/ioutil" + "os" + + "google.golang.org/grpc/grpclog" + "google.golang.org/grpc/internal/xds/env" +) + +var logger = grpclog.Component("internal/xds") + +// TransportAPI refers to the API version for xDS transport protocol. +type TransportAPI int + +const ( + // TransportV2 refers to the v2 xDS transport protocol. + TransportV2 TransportAPI = iota + // TransportV3 refers to the v3 xDS transport protocol. + TransportV3 +) + +// BootstrapOptions wraps the parameters passed to SetupBootstrapFile. +type BootstrapOptions struct { + // Version is the xDS transport protocol version. + Version TransportAPI + // NodeID is the node identifier of the gRPC client/server node in the + // proxyless service mesh. + NodeID string + // ServerURI is the address of the management server. + ServerURI string + // ServerListenerResourceNameTemplate is the Listener resource name to fetch. + ServerListenerResourceNameTemplate string + // CertificateProviders is the certificate providers configuration. + CertificateProviders map[string]json.RawMessage +} + +// SetupBootstrapFile creates a temporary file with bootstrap contents, based on +// the passed in options, and updates the bootstrap environment variable to +// point to this file. +// +// Returns a cleanup function which will be non-nil if the setup process was +// completed successfully. It is the responsibility of the caller to invoke the +// cleanup function at the end of the test. +func SetupBootstrapFile(opts BootstrapOptions) (func(), error) { + bootstrapContents, err := BootstrapContents(opts) + if err != nil { + return nil, err + } + f, err := ioutil.TempFile("", "test_xds_bootstrap_*") + if err != nil { + return nil, fmt.Errorf("failed to created bootstrap file: %v", err) + } + + if err := ioutil.WriteFile(f.Name(), bootstrapContents, 0644); err != nil { + return nil, fmt.Errorf("failed to created bootstrap file: %v", err) + } + logger.Infof("Created bootstrap file at %q with contents: %s\n", f.Name(), bootstrapContents) + + origBootstrapFileName := env.BootstrapFileName + env.BootstrapFileName = f.Name() + return func() { + os.Remove(f.Name()) + env.BootstrapFileName = origBootstrapFileName + }, nil +} + +// BootstrapContents returns the contents to go into a bootstrap file, +// environment, or configuration passed to +// xds.NewXDSResolverWithConfigForTesting. +func BootstrapContents(opts BootstrapOptions) ([]byte, error) { + cfg := &bootstrapConfig{ + XdsServers: []server{ + { + ServerURI: opts.ServerURI, + ChannelCreds: []creds{ + { + Type: "insecure", + }, + }, + }, + }, + Node: node{ + ID: opts.NodeID, + }, + CertificateProviders: opts.CertificateProviders, + ServerListenerResourceNameTemplate: opts.ServerListenerResourceNameTemplate, + } + switch opts.Version { + case TransportV2: + // TODO: Add any v2 specific fields. + case TransportV3: + cfg.XdsServers[0].ServerFeatures = append(cfg.XdsServers[0].ServerFeatures, "xds_v3") + default: + return nil, fmt.Errorf("unsupported xDS transport protocol version: %v", opts.Version) + } + + bootstrapContents, err := json.MarshalIndent(cfg, "", " ") + if err != nil { + return nil, fmt.Errorf("failed to created bootstrap file: %v", err) + } + return bootstrapContents, nil +} + +type bootstrapConfig struct { + XdsServers []server `json:"xds_servers,omitempty"` + Node node `json:"node,omitempty"` + CertificateProviders map[string]json.RawMessage `json:"certificate_providers,omitempty"` + ServerListenerResourceNameTemplate string `json:"server_listener_resource_name_template,omitempty"` +} + +type server struct { + ServerURI string `json:"server_uri,omitempty"` + ChannelCreds []creds `json:"channel_creds,omitempty"` + ServerFeatures []string `json:"server_features,omitempty"` +} + +type creds struct { + Type string `json:"type,omitempty"` + Config interface{} `json:"config,omitempty"` +} + +type node struct { + ID string `json:"id,omitempty"` +} diff --git a/xds/internal/env/env.go b/internal/xds/env/env.go similarity index 52% rename from xds/internal/env/env.go rename to internal/xds/env/env.go index a28d741f356..2977bfa6285 100644 --- a/xds/internal/env/env.go +++ b/internal/xds/env/env.go @@ -37,42 +37,59 @@ const ( // and kept in variable BootstrapFileName. // // When both bootstrap FileName and FileContent are set, FileName is used. - BootstrapFileContentEnv = "GRPC_XDS_BOOTSTRAP_CONFIG" - circuitBreakingSupportEnv = "GRPC_XDS_EXPERIMENTAL_CIRCUIT_BREAKING" - timeoutSupportEnv = "GRPC_XDS_EXPERIMENTAL_ENABLE_TIMEOUT" - faultInjectionSupportEnv = "GRPC_XDS_EXPERIMENTAL_FAULT_INJECTION" + BootstrapFileContentEnv = "GRPC_XDS_BOOTSTRAP_CONFIG" + + ringHashSupportEnv = "GRPC_XDS_EXPERIMENTAL_ENABLE_RING_HASH" clientSideSecuritySupportEnv = "GRPC_XDS_EXPERIMENTAL_SECURITY_SUPPORT" + aggregateAndDNSSupportEnv = "GRPC_XDS_EXPERIMENTAL_ENABLE_AGGREGATE_AND_LOGICAL_DNS_CLUSTER" + retrySupportEnv = "GRPC_XDS_EXPERIMENTAL_ENABLE_RETRY" + rbacSupportEnv = "GRPC_XDS_EXPERIMENTAL_RBAC" + + c2pResolverSupportEnv = "GRPC_EXPERIMENTAL_GOOGLE_C2P_RESOLVER" + c2pResolverTestOnlyTrafficDirectorURIEnv = "GRPC_TEST_ONLY_GOOGLE_C2P_RESOLVER_TRAFFIC_DIRECTOR_URI" ) var ( // BootstrapFileName holds the name of the file which contains xDS bootstrap // configuration. Users can specify the location of the bootstrap file by - // setting the environment variable "GRPC_XDS_BOOSTRAP". + // setting the environment variable "GRPC_XDS_BOOTSTRAP". // // When both bootstrap FileName and FileContent are set, FileName is used. BootstrapFileName = os.Getenv(BootstrapFileNameEnv) // BootstrapFileContent holds the content of the xDS bootstrap // configuration. Users can specify the bootstrap config by - // setting the environment variable "GRPC_XDS_BOOSTRAP_CONFIG". + // setting the environment variable "GRPC_XDS_BOOTSTRAP_CONFIG". // // When both bootstrap FileName and FileContent are set, FileName is used. BootstrapFileContent = os.Getenv(BootstrapFileContentEnv) - // CircuitBreakingSupport indicates whether circuit breaking support is - // enabled, which can be done by setting the environment variable - // "GRPC_XDS_EXPERIMENTAL_CIRCUIT_BREAKING" to "true". - CircuitBreakingSupport = strings.EqualFold(os.Getenv(circuitBreakingSupportEnv), "true") - // TimeoutSupport indicates whether support for max_stream_duration in - // route actions is enabled. This can be enabled by setting the - // environment variable "GRPC_XDS_EXPERIMENTAL_ENABLE_TIMEOUT" to "true". - TimeoutSupport = strings.EqualFold(os.Getenv(timeoutSupportEnv), "true") - // FaultInjectionSupport is used to control both fault injection and HTTP - // filter support. - FaultInjectionSupport = strings.EqualFold(os.Getenv(faultInjectionSupportEnv), "true") + // RingHashSupport indicates whether ring hash support is enabled, which can + // be disabled by setting the environment variable + // "GRPC_XDS_EXPERIMENTAL_ENABLE_RING_HASH" to "false". + RingHashSupport = !strings.EqualFold(os.Getenv(ringHashSupportEnv), "false") // ClientSideSecuritySupport is used to control processing of security // configuration on the client-side. // // Note that there is no env var protection for the server-side because we // have a brand new API on the server-side and users explicitly need to use // the new API to get security integration on the server. - ClientSideSecuritySupport = strings.EqualFold(os.Getenv(clientSideSecuritySupportEnv), "true") + ClientSideSecuritySupport = !strings.EqualFold(os.Getenv(clientSideSecuritySupportEnv), "false") + // AggregateAndDNSSupportEnv indicates whether processing of aggregated + // cluster and DNS cluster is enabled, which can be enabled by setting the + // environment variable + // "GRPC_XDS_EXPERIMENTAL_ENABLE_AGGREGATE_AND_LOGICAL_DNS_CLUSTER" to + // "true". + AggregateAndDNSSupportEnv = strings.EqualFold(os.Getenv(aggregateAndDNSSupportEnv), "true") + + // RetrySupport indicates whether xDS retry is enabled. + RetrySupport = !strings.EqualFold(os.Getenv(retrySupportEnv), "false") + + // RBACSupport indicates whether xDS configured RBAC HTTP Filter is enabled. + RBACSupport = strings.EqualFold(os.Getenv(rbacSupportEnv), "true") + + // C2PResolverSupport indicates whether support for C2P resolver is enabled. + // This can be enabled by setting the environment variable + // "GRPC_EXPERIMENTAL_GOOGLE_C2P_RESOLVER" to "true". + C2PResolverSupport = strings.EqualFold(os.Getenv(c2pResolverSupportEnv), "true") + // C2PResolverTestOnlyTrafficDirectorURI is the TD URI for testing. + C2PResolverTestOnlyTrafficDirectorURI = os.Getenv(c2pResolverTestOnlyTrafficDirectorURIEnv) ) diff --git a/internal/xds/matcher/matcher_header.go b/internal/xds/matcher/matcher_header.go new file mode 100644 index 00000000000..35a22adadcf --- /dev/null +++ b/internal/xds/matcher/matcher_header.go @@ -0,0 +1,253 @@ +/* + * + * Copyright 2020 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package matcher + +import ( + "fmt" + "regexp" + "strconv" + "strings" + + "google.golang.org/grpc/metadata" +) + +// HeaderMatcher is an interface for header matchers. These are +// documented in (EnvoyProxy link here?). These matchers will match on different +// aspects of HTTP header name/value pairs. +type HeaderMatcher interface { + Match(metadata.MD) bool + String() string +} + +// mdValuesFromOutgoingCtx retrieves metadata from context. If there are +// multiple values, the values are concatenated with "," (comma and no space). +// +// All header matchers only match against the comma-concatenated string. +func mdValuesFromOutgoingCtx(md metadata.MD, key string) (string, bool) { + vs, ok := md[key] + if !ok { + return "", false + } + return strings.Join(vs, ","), true +} + +// HeaderExactMatcher matches on an exact match of the value of the header. +type HeaderExactMatcher struct { + key string + exact string +} + +// NewHeaderExactMatcher returns a new HeaderExactMatcher. +func NewHeaderExactMatcher(key, exact string) *HeaderExactMatcher { + return &HeaderExactMatcher{key: key, exact: exact} +} + +// Match returns whether the passed in HTTP Headers match according to the +// HeaderExactMatcher. +func (hem *HeaderExactMatcher) Match(md metadata.MD) bool { + v, ok := mdValuesFromOutgoingCtx(md, hem.key) + if !ok { + return false + } + return v == hem.exact +} + +func (hem *HeaderExactMatcher) String() string { + return fmt.Sprintf("headerExact:%v:%v", hem.key, hem.exact) +} + +// HeaderRegexMatcher matches on whether the entire request header value matches +// the regex. +type HeaderRegexMatcher struct { + key string + re *regexp.Regexp +} + +// NewHeaderRegexMatcher returns a new HeaderRegexMatcher. +func NewHeaderRegexMatcher(key string, re *regexp.Regexp) *HeaderRegexMatcher { + return &HeaderRegexMatcher{key: key, re: re} +} + +// Match returns whether the passed in HTTP Headers match according to the +// HeaderRegexMatcher. +func (hrm *HeaderRegexMatcher) Match(md metadata.MD) bool { + v, ok := mdValuesFromOutgoingCtx(md, hrm.key) + if !ok { + return false + } + return hrm.re.MatchString(v) +} + +func (hrm *HeaderRegexMatcher) String() string { + return fmt.Sprintf("headerRegex:%v:%v", hrm.key, hrm.re.String()) +} + +// HeaderRangeMatcher matches on whether the request header value is within the +// range. The header value must be an integer in base 10 notation. +type HeaderRangeMatcher struct { + key string + start, end int64 // represents [start, end). +} + +// NewHeaderRangeMatcher returns a new HeaderRangeMatcher. +func NewHeaderRangeMatcher(key string, start, end int64) *HeaderRangeMatcher { + return &HeaderRangeMatcher{key: key, start: start, end: end} +} + +// Match returns whether the passed in HTTP Headers match according to the +// HeaderRangeMatcher. +func (hrm *HeaderRangeMatcher) Match(md metadata.MD) bool { + v, ok := mdValuesFromOutgoingCtx(md, hrm.key) + if !ok { + return false + } + if i, err := strconv.ParseInt(v, 10, 64); err == nil && i >= hrm.start && i < hrm.end { + return true + } + return false +} + +func (hrm *HeaderRangeMatcher) String() string { + return fmt.Sprintf("headerRange:%v:[%d,%d)", hrm.key, hrm.start, hrm.end) +} + +// HeaderPresentMatcher will match based on whether the header is present in the +// whole request. +type HeaderPresentMatcher struct { + key string + present bool +} + +// NewHeaderPresentMatcher returns a new HeaderPresentMatcher. +func NewHeaderPresentMatcher(key string, present bool) *HeaderPresentMatcher { + return &HeaderPresentMatcher{key: key, present: present} +} + +// Match returns whether the passed in HTTP Headers match according to the +// HeaderPresentMatcher. +func (hpm *HeaderPresentMatcher) Match(md metadata.MD) bool { + vs, ok := mdValuesFromOutgoingCtx(md, hpm.key) + present := ok && len(vs) > 0 + return present == hpm.present +} + +func (hpm *HeaderPresentMatcher) String() string { + return fmt.Sprintf("headerPresent:%v:%v", hpm.key, hpm.present) +} + +// HeaderPrefixMatcher matches on whether the prefix of the header value matches +// the prefix passed into this struct. +type HeaderPrefixMatcher struct { + key string + prefix string +} + +// NewHeaderPrefixMatcher returns a new HeaderPrefixMatcher. +func NewHeaderPrefixMatcher(key string, prefix string) *HeaderPrefixMatcher { + return &HeaderPrefixMatcher{key: key, prefix: prefix} +} + +// Match returns whether the passed in HTTP Headers match according to the +// HeaderPrefixMatcher. +func (hpm *HeaderPrefixMatcher) Match(md metadata.MD) bool { + v, ok := mdValuesFromOutgoingCtx(md, hpm.key) + if !ok { + return false + } + return strings.HasPrefix(v, hpm.prefix) +} + +func (hpm *HeaderPrefixMatcher) String() string { + return fmt.Sprintf("headerPrefix:%v:%v", hpm.key, hpm.prefix) +} + +// HeaderSuffixMatcher matches on whether the suffix of the header value matches +// the suffix passed into this struct. +type HeaderSuffixMatcher struct { + key string + suffix string +} + +// NewHeaderSuffixMatcher returns a new HeaderSuffixMatcher. +func NewHeaderSuffixMatcher(key string, suffix string) *HeaderSuffixMatcher { + return &HeaderSuffixMatcher{key: key, suffix: suffix} +} + +// Match returns whether the passed in HTTP Headers match according to the +// HeaderSuffixMatcher. +func (hsm *HeaderSuffixMatcher) Match(md metadata.MD) bool { + v, ok := mdValuesFromOutgoingCtx(md, hsm.key) + if !ok { + return false + } + return strings.HasSuffix(v, hsm.suffix) +} + +func (hsm *HeaderSuffixMatcher) String() string { + return fmt.Sprintf("headerSuffix:%v:%v", hsm.key, hsm.suffix) +} + +// HeaderContainsMatcher matches on whether the header value contains the +// value passed into this struct. +type HeaderContainsMatcher struct { + key string + contains string +} + +// NewHeaderContainsMatcher returns a new HeaderContainsMatcher. key is the HTTP +// Header key to match on, and contains is the value that the header should +// should contain for a successful match. An empty contains string does not +// work, use HeaderPresentMatcher in that case. +func NewHeaderContainsMatcher(key string, contains string) *HeaderContainsMatcher { + return &HeaderContainsMatcher{key: key, contains: contains} +} + +// Match returns whether the passed in HTTP Headers match according to the +// HeaderContainsMatcher. +func (hcm *HeaderContainsMatcher) Match(md metadata.MD) bool { + v, ok := mdValuesFromOutgoingCtx(md, hcm.key) + if !ok { + return false + } + return strings.Contains(v, hcm.contains) +} + +func (hcm *HeaderContainsMatcher) String() string { + return fmt.Sprintf("headerContains:%v%v", hcm.key, hcm.contains) +} + +// InvertMatcher inverts the match result of the underlying header matcher. +type InvertMatcher struct { + m HeaderMatcher +} + +// NewInvertMatcher returns a new InvertMatcher. +func NewInvertMatcher(m HeaderMatcher) *InvertMatcher { + return &InvertMatcher{m: m} +} + +// Match returns whether the passed in HTTP Headers match according to the +// InvertMatcher. +func (i *InvertMatcher) Match(md metadata.MD) bool { + return !i.m.Match(md) +} + +func (i *InvertMatcher) String() string { + return fmt.Sprintf("invert{%s}", i.m) +} diff --git a/xds/internal/resolver/matcher_header_test.go b/internal/xds/matcher/matcher_header_test.go similarity index 88% rename from xds/internal/resolver/matcher_header_test.go rename to internal/xds/matcher/matcher_header_test.go index fb87cc5dd32..9a0d51300d0 100644 --- a/xds/internal/resolver/matcher_header_test.go +++ b/internal/xds/matcher/matcher_header_test.go @@ -16,7 +16,7 @@ * */ -package resolver +package matcher import ( "regexp" @@ -64,8 +64,8 @@ func TestHeaderExactMatcherMatch(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - hem := newHeaderExactMatcher(tt.key, tt.exact) - if got := hem.match(tt.md); got != tt.want { + hem := NewHeaderExactMatcher(tt.key, tt.exact) + if got := hem.Match(tt.md); got != tt.want { t.Errorf("match() = %v, want %v", got, tt.want) } }) @@ -110,8 +110,8 @@ func TestHeaderRegexMatcherMatch(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - hrm := newHeaderRegexMatcher(tt.key, regexp.MustCompile(tt.regexStr)) - if got := hrm.match(tt.md); got != tt.want { + hrm := NewHeaderRegexMatcher(tt.key, regexp.MustCompile(tt.regexStr)) + if got := hrm.Match(tt.md); got != tt.want { t.Errorf("match() = %v, want %v", got, tt.want) } }) @@ -157,8 +157,8 @@ func TestHeaderRangeMatcherMatch(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - hrm := newHeaderRangeMatcher(tt.key, tt.start, tt.end) - if got := hrm.match(tt.md); got != tt.want { + hrm := NewHeaderRangeMatcher(tt.key, tt.start, tt.end) + if got := hrm.Match(tt.md); got != tt.want { t.Errorf("match() = %v, want %v", got, tt.want) } }) @@ -204,8 +204,8 @@ func TestHeaderPresentMatcherMatch(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - hpm := newHeaderPresentMatcher(tt.key, tt.present) - if got := hpm.match(tt.md); got != tt.want { + hpm := NewHeaderPresentMatcher(tt.key, tt.present) + if got := hpm.Match(tt.md); got != tt.want { t.Errorf("match() = %v, want %v", got, tt.want) } }) @@ -250,8 +250,8 @@ func TestHeaderPrefixMatcherMatch(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - hpm := newHeaderPrefixMatcher(tt.key, tt.prefix) - if got := hpm.match(tt.md); got != tt.want { + hpm := NewHeaderPrefixMatcher(tt.key, tt.prefix) + if got := hpm.Match(tt.md); got != tt.want { t.Errorf("match() = %v, want %v", got, tt.want) } }) @@ -296,8 +296,8 @@ func TestHeaderSuffixMatcherMatch(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - hsm := newHeaderSuffixMatcher(tt.key, tt.suffix) - if got := hsm.match(tt.md); got != tt.want { + hsm := NewHeaderSuffixMatcher(tt.key, tt.suffix) + if got := hsm.Match(tt.md); got != tt.want { t.Errorf("match() = %v, want %v", got, tt.want) } }) @@ -307,24 +307,24 @@ func TestHeaderSuffixMatcherMatch(t *testing.T) { func TestInvertMatcherMatch(t *testing.T) { tests := []struct { name string - m headerMatcherInterface + m HeaderMatcher md metadata.MD }{ { name: "true->false", - m: newHeaderExactMatcher("th", "tv"), + m: NewHeaderExactMatcher("th", "tv"), md: metadata.Pairs("th", "tv"), }, { name: "false->true", - m: newHeaderExactMatcher("th", "abc"), + m: NewHeaderExactMatcher("th", "abc"), md: metadata.Pairs("th", "tv"), }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := newInvertMatcher(tt.m).match(tt.md) - want := !tt.m.match(tt.md) + got := NewInvertMatcher(tt.m).Match(tt.md) + want := !tt.m.Match(tt.md) if got != want { t.Errorf("match() = %v, want %v", got, want) } diff --git a/internal/xds/matcher/string_matcher.go b/internal/xds/matcher/string_matcher.go new file mode 100644 index 00000000000..d7df6a1e2b4 --- /dev/null +++ b/internal/xds/matcher/string_matcher.go @@ -0,0 +1,183 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// Package matcher contains types that need to be shared between code under +// google.golang.org/grpc/xds/... and the rest of gRPC. +package matcher + +import ( + "errors" + "fmt" + "regexp" + "strings" + + v3matcherpb "github.com/envoyproxy/go-control-plane/envoy/type/matcher/v3" +) + +// StringMatcher contains match criteria for matching a string, and is an +// internal representation of the `StringMatcher` proto defined at +// https://github.com/envoyproxy/envoy/blob/main/api/envoy/type/matcher/v3/string.proto. +type StringMatcher struct { + // Since these match fields are part of a `oneof` in the corresponding xDS + // proto, only one of them is expected to be set. + exactMatch *string + prefixMatch *string + suffixMatch *string + regexMatch *regexp.Regexp + containsMatch *string + // If true, indicates the exact/prefix/suffix/contains matching should be + // case insensitive. This has no effect on the regex match. + ignoreCase bool +} + +// Match returns true if input matches the criteria in the given StringMatcher. +func (sm StringMatcher) Match(input string) bool { + if sm.ignoreCase { + input = strings.ToLower(input) + } + switch { + case sm.exactMatch != nil: + return input == *sm.exactMatch + case sm.prefixMatch != nil: + return strings.HasPrefix(input, *sm.prefixMatch) + case sm.suffixMatch != nil: + return strings.HasSuffix(input, *sm.suffixMatch) + case sm.regexMatch != nil: + return sm.regexMatch.MatchString(input) + case sm.containsMatch != nil: + return strings.Contains(input, *sm.containsMatch) + } + return false +} + +// StringMatcherFromProto is a helper function to create a StringMatcher from +// the corresponding StringMatcher proto. +// +// Returns a non-nil error if matcherProto is invalid. +func StringMatcherFromProto(matcherProto *v3matcherpb.StringMatcher) (StringMatcher, error) { + if matcherProto == nil { + return StringMatcher{}, errors.New("input StringMatcher proto is nil") + } + + matcher := StringMatcher{ignoreCase: matcherProto.GetIgnoreCase()} + switch mt := matcherProto.GetMatchPattern().(type) { + case *v3matcherpb.StringMatcher_Exact: + matcher.exactMatch = &mt.Exact + if matcher.ignoreCase { + *matcher.exactMatch = strings.ToLower(*matcher.exactMatch) + } + case *v3matcherpb.StringMatcher_Prefix: + if matcherProto.GetPrefix() == "" { + return StringMatcher{}, errors.New("empty prefix is not allowed in StringMatcher") + } + matcher.prefixMatch = &mt.Prefix + if matcher.ignoreCase { + *matcher.prefixMatch = strings.ToLower(*matcher.prefixMatch) + } + case *v3matcherpb.StringMatcher_Suffix: + if matcherProto.GetSuffix() == "" { + return StringMatcher{}, errors.New("empty suffix is not allowed in StringMatcher") + } + matcher.suffixMatch = &mt.Suffix + if matcher.ignoreCase { + *matcher.suffixMatch = strings.ToLower(*matcher.suffixMatch) + } + case *v3matcherpb.StringMatcher_SafeRegex: + regex := matcherProto.GetSafeRegex().GetRegex() + re, err := regexp.Compile(regex) + if err != nil { + return StringMatcher{}, fmt.Errorf("safe_regex matcher %q is invalid", regex) + } + matcher.regexMatch = re + case *v3matcherpb.StringMatcher_Contains: + if matcherProto.GetContains() == "" { + return StringMatcher{}, errors.New("empty contains is not allowed in StringMatcher") + } + matcher.containsMatch = &mt.Contains + if matcher.ignoreCase { + *matcher.containsMatch = strings.ToLower(*matcher.containsMatch) + } + default: + return StringMatcher{}, fmt.Errorf("unrecognized string matcher: %+v", matcherProto) + } + return matcher, nil +} + +// StringMatcherForTesting is a helper function to create a StringMatcher based +// on the given arguments. Intended only for testing purposes. +func StringMatcherForTesting(exact, prefix, suffix, contains *string, regex *regexp.Regexp, ignoreCase bool) StringMatcher { + sm := StringMatcher{ + exactMatch: exact, + prefixMatch: prefix, + suffixMatch: suffix, + regexMatch: regex, + containsMatch: contains, + ignoreCase: ignoreCase, + } + if ignoreCase { + switch { + case sm.exactMatch != nil: + *sm.exactMatch = strings.ToLower(*exact) + case sm.prefixMatch != nil: + *sm.prefixMatch = strings.ToLower(*prefix) + case sm.suffixMatch != nil: + *sm.suffixMatch = strings.ToLower(*suffix) + case sm.containsMatch != nil: + *sm.containsMatch = strings.ToLower(*contains) + } + } + return sm +} + +// ExactMatch returns the value of the configured exact match or an empty string +// if exact match criteria was not specified. +func (sm StringMatcher) ExactMatch() string { + if sm.exactMatch != nil { + return *sm.exactMatch + } + return "" +} + +// Equal returns true if other and sm are equivalent to each other. +func (sm StringMatcher) Equal(other StringMatcher) bool { + if sm.ignoreCase != other.ignoreCase { + return false + } + + if (sm.exactMatch != nil) != (other.exactMatch != nil) || + (sm.prefixMatch != nil) != (other.prefixMatch != nil) || + (sm.suffixMatch != nil) != (other.suffixMatch != nil) || + (sm.regexMatch != nil) != (other.regexMatch != nil) || + (sm.containsMatch != nil) != (other.containsMatch != nil) { + return false + } + + switch { + case sm.exactMatch != nil: + return *sm.exactMatch == *other.exactMatch + case sm.prefixMatch != nil: + return *sm.prefixMatch == *other.prefixMatch + case sm.suffixMatch != nil: + return *sm.suffixMatch == *other.suffixMatch + case sm.regexMatch != nil: + return sm.regexMatch.String() == other.regexMatch.String() + case sm.containsMatch != nil: + return *sm.containsMatch == *other.containsMatch + } + return true +} diff --git a/internal/xds/matcher/string_matcher_test.go b/internal/xds/matcher/string_matcher_test.go new file mode 100644 index 00000000000..389963b94e9 --- /dev/null +++ b/internal/xds/matcher/string_matcher_test.go @@ -0,0 +1,309 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package matcher + +import ( + "regexp" + "testing" + + v3matcherpb "github.com/envoyproxy/go-control-plane/envoy/type/matcher/v3" + "github.com/google/go-cmp/cmp" +) + +func TestStringMatcherFromProto(t *testing.T) { + tests := []struct { + desc string + inputProto *v3matcherpb.StringMatcher + wantMatcher StringMatcher + wantErr bool + }{ + { + desc: "nil proto", + wantErr: true, + }, + { + desc: "empty prefix", + inputProto: &v3matcherpb.StringMatcher{ + MatchPattern: &v3matcherpb.StringMatcher_Prefix{Prefix: ""}, + }, + wantErr: true, + }, + { + desc: "empty suffix", + inputProto: &v3matcherpb.StringMatcher{ + MatchPattern: &v3matcherpb.StringMatcher_Suffix{Suffix: ""}, + }, + wantErr: true, + }, + { + desc: "empty contains", + inputProto: &v3matcherpb.StringMatcher{ + MatchPattern: &v3matcherpb.StringMatcher_Contains{Contains: ""}, + }, + wantErr: true, + }, + { + desc: "invalid regex", + inputProto: &v3matcherpb.StringMatcher{ + MatchPattern: &v3matcherpb.StringMatcher_SafeRegex{ + SafeRegex: &v3matcherpb.RegexMatcher{Regex: "??"}, + }, + }, + wantErr: true, + }, + { + desc: "happy case exact", + inputProto: &v3matcherpb.StringMatcher{ + MatchPattern: &v3matcherpb.StringMatcher_Exact{Exact: "exact"}, + }, + wantMatcher: StringMatcher{exactMatch: newStringP("exact")}, + }, + { + desc: "happy case exact ignore case", + inputProto: &v3matcherpb.StringMatcher{ + MatchPattern: &v3matcherpb.StringMatcher_Exact{Exact: "EXACT"}, + IgnoreCase: true, + }, + wantMatcher: StringMatcher{ + exactMatch: newStringP("exact"), + ignoreCase: true, + }, + }, + { + desc: "happy case prefix", + inputProto: &v3matcherpb.StringMatcher{ + MatchPattern: &v3matcherpb.StringMatcher_Prefix{Prefix: "prefix"}, + }, + wantMatcher: StringMatcher{prefixMatch: newStringP("prefix")}, + }, + { + desc: "happy case prefix ignore case", + inputProto: &v3matcherpb.StringMatcher{ + MatchPattern: &v3matcherpb.StringMatcher_Prefix{Prefix: "PREFIX"}, + IgnoreCase: true, + }, + wantMatcher: StringMatcher{ + prefixMatch: newStringP("prefix"), + ignoreCase: true, + }, + }, + { + desc: "happy case suffix", + inputProto: &v3matcherpb.StringMatcher{ + MatchPattern: &v3matcherpb.StringMatcher_Suffix{Suffix: "suffix"}, + }, + wantMatcher: StringMatcher{suffixMatch: newStringP("suffix")}, + }, + { + desc: "happy case suffix ignore case", + inputProto: &v3matcherpb.StringMatcher{ + MatchPattern: &v3matcherpb.StringMatcher_Suffix{Suffix: "SUFFIX"}, + IgnoreCase: true, + }, + wantMatcher: StringMatcher{ + suffixMatch: newStringP("suffix"), + ignoreCase: true, + }, + }, + { + desc: "happy case regex", + inputProto: &v3matcherpb.StringMatcher{ + MatchPattern: &v3matcherpb.StringMatcher_SafeRegex{ + SafeRegex: &v3matcherpb.RegexMatcher{Regex: "good?regex?"}, + }, + }, + wantMatcher: StringMatcher{regexMatch: regexp.MustCompile("good?regex?")}, + }, + { + desc: "happy case contains", + inputProto: &v3matcherpb.StringMatcher{ + MatchPattern: &v3matcherpb.StringMatcher_Contains{Contains: "contains"}, + }, + wantMatcher: StringMatcher{containsMatch: newStringP("contains")}, + }, + { + desc: "happy case contains ignore case", + inputProto: &v3matcherpb.StringMatcher{ + MatchPattern: &v3matcherpb.StringMatcher_Contains{Contains: "CONTAINS"}, + IgnoreCase: true, + }, + wantMatcher: StringMatcher{ + containsMatch: newStringP("contains"), + ignoreCase: true, + }, + }, + } + + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + gotMatcher, err := StringMatcherFromProto(test.inputProto) + if (err != nil) != test.wantErr { + t.Fatalf("StringMatcherFromProto(%+v) returned err: %v, wantErr: %v", test.inputProto, err, test.wantErr) + } + if diff := cmp.Diff(gotMatcher, test.wantMatcher, cmp.AllowUnexported(regexp.Regexp{})); diff != "" { + t.Fatalf("StringMatcherFromProto(%+v) returned unexpected diff (-got, +want):\n%s", test.inputProto, diff) + } + }) + } +} + +func TestMatch(t *testing.T) { + var ( + exactMatcher, _ = StringMatcherFromProto(&v3matcherpb.StringMatcher{MatchPattern: &v3matcherpb.StringMatcher_Exact{Exact: "exact"}}) + prefixMatcher, _ = StringMatcherFromProto(&v3matcherpb.StringMatcher{MatchPattern: &v3matcherpb.StringMatcher_Prefix{Prefix: "prefix"}}) + suffixMatcher, _ = StringMatcherFromProto(&v3matcherpb.StringMatcher{MatchPattern: &v3matcherpb.StringMatcher_Suffix{Suffix: "suffix"}}) + regexMatcher, _ = StringMatcherFromProto(&v3matcherpb.StringMatcher{MatchPattern: &v3matcherpb.StringMatcher_SafeRegex{SafeRegex: &v3matcherpb.RegexMatcher{Regex: "good?regex?"}}}) + containsMatcher, _ = StringMatcherFromProto(&v3matcherpb.StringMatcher{MatchPattern: &v3matcherpb.StringMatcher_Contains{Contains: "contains"}}) + exactMatcherIgnoreCase, _ = StringMatcherFromProto(&v3matcherpb.StringMatcher{ + MatchPattern: &v3matcherpb.StringMatcher_Exact{Exact: "exact"}, + IgnoreCase: true, + }) + prefixMatcherIgnoreCase, _ = StringMatcherFromProto(&v3matcherpb.StringMatcher{ + MatchPattern: &v3matcherpb.StringMatcher_Prefix{Prefix: "prefix"}, + IgnoreCase: true, + }) + suffixMatcherIgnoreCase, _ = StringMatcherFromProto(&v3matcherpb.StringMatcher{ + MatchPattern: &v3matcherpb.StringMatcher_Suffix{Suffix: "suffix"}, + IgnoreCase: true, + }) + containsMatcherIgnoreCase, _ = StringMatcherFromProto(&v3matcherpb.StringMatcher{ + MatchPattern: &v3matcherpb.StringMatcher_Contains{Contains: "contains"}, + IgnoreCase: true, + }) + ) + + tests := []struct { + desc string + matcher StringMatcher + input string + wantMatch bool + }{ + { + desc: "exact match success", + matcher: exactMatcher, + input: "exact", + wantMatch: true, + }, + { + desc: "exact match failure", + matcher: exactMatcher, + input: "not-exact", + }, + { + desc: "exact match success with ignore case", + matcher: exactMatcherIgnoreCase, + input: "EXACT", + wantMatch: true, + }, + { + desc: "exact match failure with ignore case", + matcher: exactMatcherIgnoreCase, + input: "not-exact", + }, + { + desc: "prefix match success", + matcher: prefixMatcher, + input: "prefixIsHere", + wantMatch: true, + }, + { + desc: "prefix match failure", + matcher: prefixMatcher, + input: "not-prefix", + }, + { + desc: "prefix match success with ignore case", + matcher: prefixMatcherIgnoreCase, + input: "PREFIXisHere", + wantMatch: true, + }, + { + desc: "prefix match failure with ignore case", + matcher: prefixMatcherIgnoreCase, + input: "not-PREFIX", + }, + { + desc: "suffix match success", + matcher: suffixMatcher, + input: "hereIsThesuffix", + wantMatch: true, + }, + { + desc: "suffix match failure", + matcher: suffixMatcher, + input: "suffix-is-not-here", + }, + { + desc: "suffix match success with ignore case", + matcher: suffixMatcherIgnoreCase, + input: "hereIsTheSuFFix", + wantMatch: true, + }, + { + desc: "suffix match failure with ignore case", + matcher: suffixMatcherIgnoreCase, + input: "SUFFIX-is-not-here", + }, + { + desc: "regex match success", + matcher: regexMatcher, + input: "goodregex", + wantMatch: true, + }, + { + desc: "regex match failure", + matcher: regexMatcher, + input: "regex-is-not-here", + }, + { + desc: "contains match success", + matcher: containsMatcher, + input: "IScontainsHERE", + wantMatch: true, + }, + { + desc: "contains match failure", + matcher: containsMatcher, + input: "con-tains-is-not-here", + }, + { + desc: "contains match success with ignore case", + matcher: containsMatcherIgnoreCase, + input: "isCONTAINShere", + wantMatch: true, + }, + { + desc: "contains match failure with ignore case", + matcher: containsMatcherIgnoreCase, + input: "CON-TAINS-is-not-here", + }, + } + + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + if gotMatch := test.matcher.Match(test.input); gotMatch != test.wantMatch { + t.Errorf("StringMatcher.Match(%s) returned %v, want %v", test.input, gotMatch, test.wantMatch) + } + }) + } +} + +func newStringP(s string) *string { + return &s +} diff --git a/internal/xds/rbac/matchers.go b/internal/xds/rbac/matchers.go new file mode 100644 index 00000000000..28dabf46591 --- /dev/null +++ b/internal/xds/rbac/matchers.go @@ -0,0 +1,426 @@ +/* + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package rbac + +import ( + "errors" + "fmt" + "net" + "regexp" + + v3corepb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" + v3rbacpb "github.com/envoyproxy/go-control-plane/envoy/config/rbac/v3" + v3route_componentspb "github.com/envoyproxy/go-control-plane/envoy/config/route/v3" + v3matcherpb "github.com/envoyproxy/go-control-plane/envoy/type/matcher/v3" + internalmatcher "google.golang.org/grpc/internal/xds/matcher" +) + +// matcher is an interface that takes data about incoming RPC's and returns +// whether it matches with whatever matcher implements this interface. +type matcher interface { + match(data *rpcData) bool +} + +// policyMatcher helps determine whether an incoming RPC call matches a policy. +// A policy is a logical role (e.g. Service Admin), which is comprised of +// permissions and principals. A principal is an identity (or identities) for a +// downstream subject which are assigned the policy (role), and a permission is +// an action(s) that a principal(s) can take. A policy matches if both a +// permission and a principal match, which will be determined by the child or +// permissions and principal matchers. policyMatcher implements the matcher +// interface. +type policyMatcher struct { + permissions *orMatcher + principals *orMatcher +} + +func newPolicyMatcher(policy *v3rbacpb.Policy) (*policyMatcher, error) { + permissions, err := matchersFromPermissions(policy.Permissions) + if err != nil { + return nil, err + } + principals, err := matchersFromPrincipals(policy.Principals) + if err != nil { + return nil, err + } + return &policyMatcher{ + permissions: &orMatcher{matchers: permissions}, + principals: &orMatcher{matchers: principals}, + }, nil +} + +func (pm *policyMatcher) match(data *rpcData) bool { + // A policy matches if and only if at least one of its permissions match the + // action taking place AND at least one if its principals match the + // downstream peer. + return pm.permissions.match(data) && pm.principals.match(data) +} + +// matchersFromPermissions takes a list of permissions (can also be +// a single permission, e.g. from a not matcher which is logically !permission) +// and returns a list of matchers which correspond to that permission. This will +// be called in many instances throughout the initial construction of the RBAC +// engine from the AND and OR matchers and also from the NOT matcher. +func matchersFromPermissions(permissions []*v3rbacpb.Permission) ([]matcher, error) { + var matchers []matcher + for _, permission := range permissions { + switch permission.GetRule().(type) { + case *v3rbacpb.Permission_AndRules: + mList, err := matchersFromPermissions(permission.GetAndRules().Rules) + if err != nil { + return nil, err + } + matchers = append(matchers, &andMatcher{matchers: mList}) + case *v3rbacpb.Permission_OrRules: + mList, err := matchersFromPermissions(permission.GetOrRules().Rules) + if err != nil { + return nil, err + } + matchers = append(matchers, &orMatcher{matchers: mList}) + case *v3rbacpb.Permission_Any: + matchers = append(matchers, &alwaysMatcher{}) + case *v3rbacpb.Permission_Header: + m, err := newHeaderMatcher(permission.GetHeader()) + if err != nil { + return nil, err + } + matchers = append(matchers, m) + case *v3rbacpb.Permission_UrlPath: + m, err := newURLPathMatcher(permission.GetUrlPath()) + if err != nil { + return nil, err + } + matchers = append(matchers, m) + case *v3rbacpb.Permission_DestinationIp: + // Due to this being on server side, the destination IP is the local + // IP. + m, err := newLocalIPMatcher(permission.GetDestinationIp()) + if err != nil { + return nil, err + } + matchers = append(matchers, m) + case *v3rbacpb.Permission_DestinationPort: + matchers = append(matchers, newPortMatcher(permission.GetDestinationPort())) + case *v3rbacpb.Permission_NotRule: + mList, err := matchersFromPermissions([]*v3rbacpb.Permission{{Rule: permission.GetNotRule().Rule}}) + if err != nil { + return nil, err + } + matchers = append(matchers, ¬Matcher{matcherToNot: mList[0]}) + case *v3rbacpb.Permission_Metadata: + // Not supported in gRPC RBAC currently - a permission typed as + // Metadata in the initial config will be a no-op. + case *v3rbacpb.Permission_RequestedServerName: + // Not supported in gRPC RBAC currently - a permission typed as + // requested server name in the initial config will be a no-op. + } + } + return matchers, nil +} + +func matchersFromPrincipals(principals []*v3rbacpb.Principal) ([]matcher, error) { + var matchers []matcher + for _, principal := range principals { + switch principal.GetIdentifier().(type) { + case *v3rbacpb.Principal_AndIds: + mList, err := matchersFromPrincipals(principal.GetAndIds().Ids) + if err != nil { + return nil, err + } + matchers = append(matchers, &andMatcher{matchers: mList}) + case *v3rbacpb.Principal_OrIds: + mList, err := matchersFromPrincipals(principal.GetOrIds().Ids) + if err != nil { + return nil, err + } + matchers = append(matchers, &orMatcher{matchers: mList}) + case *v3rbacpb.Principal_Any: + matchers = append(matchers, &alwaysMatcher{}) + case *v3rbacpb.Principal_Authenticated_: + authenticatedMatcher, err := newAuthenticatedMatcher(principal.GetAuthenticated()) + if err != nil { + return nil, err + } + matchers = append(matchers, authenticatedMatcher) + case *v3rbacpb.Principal_DirectRemoteIp: + m, err := newRemoteIPMatcher(principal.GetDirectRemoteIp()) + if err != nil { + return nil, err + } + matchers = append(matchers, m) + case *v3rbacpb.Principal_Header: + // Do we need an error here? + m, err := newHeaderMatcher(principal.GetHeader()) + if err != nil { + return nil, err + } + matchers = append(matchers, m) + case *v3rbacpb.Principal_UrlPath: + m, err := newURLPathMatcher(principal.GetUrlPath()) + if err != nil { + return nil, err + } + matchers = append(matchers, m) + case *v3rbacpb.Principal_NotId: + mList, err := matchersFromPrincipals([]*v3rbacpb.Principal{{Identifier: principal.GetNotId().Identifier}}) + if err != nil { + return nil, err + } + matchers = append(matchers, ¬Matcher{matcherToNot: mList[0]}) + case *v3rbacpb.Principal_SourceIp: + // The source ip principal identifier is deprecated. Thus, a + // principal typed as a source ip in the identifier will be a no-op. + // The config should use DirectRemoteIp instead. + case *v3rbacpb.Principal_RemoteIp: + // RBAC in gRPC treats direct_remote_ip and remote_ip as logically + // equivalent, as per A41. + m, err := newRemoteIPMatcher(principal.GetRemoteIp()) + if err != nil { + return nil, err + } + matchers = append(matchers, m) + case *v3rbacpb.Principal_Metadata: + // Not supported in gRPC RBAC currently - a principal typed as + // Metadata in the initial config will be a no-op. + } + } + return matchers, nil +} + +// orMatcher is a matcher where it successfully matches if one of it's +// children successfully match. It also logically represents a principal or +// permission, but can also be it's own entity further down the tree of +// matchers. orMatcher implements the matcher interface. +type orMatcher struct { + matchers []matcher +} + +func (om *orMatcher) match(data *rpcData) bool { + // Range through child matchers and pass in data about incoming RPC, and + // only one child matcher has to match to be logically successful. + for _, m := range om.matchers { + if m.match(data) { + return true + } + } + return false +} + +// andMatcher is a matcher that is successful if every child matcher +// matches. andMatcher implements the matcher interface. +type andMatcher struct { + matchers []matcher +} + +func (am *andMatcher) match(data *rpcData) bool { + for _, m := range am.matchers { + if !m.match(data) { + return false + } + } + return true +} + +// alwaysMatcher is a matcher that will always match. This logically +// represents an any rule for a permission or a principal. alwaysMatcher +// implements the matcher interface. +type alwaysMatcher struct { +} + +func (am *alwaysMatcher) match(data *rpcData) bool { + return true +} + +// notMatcher is a matcher that nots an underlying matcher. notMatcher +// implements the matcher interface. +type notMatcher struct { + matcherToNot matcher +} + +func (nm *notMatcher) match(data *rpcData) bool { + return !nm.matcherToNot.match(data) +} + +// headerMatcher is a matcher that matches on incoming HTTP Headers present +// in the incoming RPC. headerMatcher implements the matcher interface. +type headerMatcher struct { + matcher internalmatcher.HeaderMatcher +} + +func newHeaderMatcher(headerMatcherConfig *v3route_componentspb.HeaderMatcher) (*headerMatcher, error) { + var m internalmatcher.HeaderMatcher + switch headerMatcherConfig.HeaderMatchSpecifier.(type) { + case *v3route_componentspb.HeaderMatcher_ExactMatch: + m = internalmatcher.NewHeaderExactMatcher(headerMatcherConfig.Name, headerMatcherConfig.GetExactMatch()) + case *v3route_componentspb.HeaderMatcher_SafeRegexMatch: + regex, err := regexp.Compile(headerMatcherConfig.GetSafeRegexMatch().Regex) + if err != nil { + return nil, err + } + m = internalmatcher.NewHeaderRegexMatcher(headerMatcherConfig.Name, regex) + case *v3route_componentspb.HeaderMatcher_RangeMatch: + m = internalmatcher.NewHeaderRangeMatcher(headerMatcherConfig.Name, headerMatcherConfig.GetRangeMatch().Start, headerMatcherConfig.GetRangeMatch().End) + case *v3route_componentspb.HeaderMatcher_PresentMatch: + m = internalmatcher.NewHeaderPresentMatcher(headerMatcherConfig.Name, headerMatcherConfig.GetPresentMatch()) + case *v3route_componentspb.HeaderMatcher_PrefixMatch: + m = internalmatcher.NewHeaderPrefixMatcher(headerMatcherConfig.Name, headerMatcherConfig.GetPrefixMatch()) + case *v3route_componentspb.HeaderMatcher_SuffixMatch: + m = internalmatcher.NewHeaderSuffixMatcher(headerMatcherConfig.Name, headerMatcherConfig.GetSuffixMatch()) + case *v3route_componentspb.HeaderMatcher_ContainsMatch: + m = internalmatcher.NewHeaderContainsMatcher(headerMatcherConfig.Name, headerMatcherConfig.GetContainsMatch()) + default: + return nil, errors.New("unknown header matcher type") + } + if headerMatcherConfig.InvertMatch { + m = internalmatcher.NewInvertMatcher(m) + } + return &headerMatcher{matcher: m}, nil +} + +func (hm *headerMatcher) match(data *rpcData) bool { + return hm.matcher.Match(data.md) +} + +// urlPathMatcher matches on the URL Path of the incoming RPC. In gRPC, this +// logically maps to the full method name the RPC is calling on the server side. +// urlPathMatcher implements the matcher interface. +type urlPathMatcher struct { + stringMatcher internalmatcher.StringMatcher +} + +func newURLPathMatcher(pathMatcher *v3matcherpb.PathMatcher) (*urlPathMatcher, error) { + stringMatcher, err := internalmatcher.StringMatcherFromProto(pathMatcher.GetPath()) + if err != nil { + return nil, err + } + return &urlPathMatcher{stringMatcher: stringMatcher}, nil +} + +func (upm *urlPathMatcher) match(data *rpcData) bool { + return upm.stringMatcher.Match(data.fullMethod) +} + +// remoteIPMatcher and localIPMatcher both are matchers that match against +// a CIDR Range. Two different matchers are needed as the remote and destination +// ip addresses come from different parts of the data about incoming RPC's +// passed in. Matching a CIDR Range means to determine whether the IP Address +// falls within the CIDR Range or not. They both implement the matcher +// interface. +type remoteIPMatcher struct { + // ipNet represents the CidrRange that this matcher was configured with. + // This is what will remote and destination IP's will be matched against. + ipNet *net.IPNet +} + +func newRemoteIPMatcher(cidrRange *v3corepb.CidrRange) (*remoteIPMatcher, error) { + // Convert configuration to a cidrRangeString, as Go standard library has + // methods that parse cidr string. + cidrRangeString := fmt.Sprintf("%s/%d", cidrRange.AddressPrefix, cidrRange.PrefixLen.Value) + _, ipNet, err := net.ParseCIDR(cidrRangeString) + if err != nil { + return nil, err + } + return &remoteIPMatcher{ipNet: ipNet}, nil +} + +func (sim *remoteIPMatcher) match(data *rpcData) bool { + return sim.ipNet.Contains(net.IP(net.ParseIP(data.peerInfo.Addr.String()))) +} + +type localIPMatcher struct { + ipNet *net.IPNet +} + +func newLocalIPMatcher(cidrRange *v3corepb.CidrRange) (*localIPMatcher, error) { + cidrRangeString := fmt.Sprintf("%s/%d", cidrRange.AddressPrefix, cidrRange.PrefixLen.Value) + _, ipNet, err := net.ParseCIDR(cidrRangeString) + if err != nil { + return nil, err + } + return &localIPMatcher{ipNet: ipNet}, nil +} + +func (dim *localIPMatcher) match(data *rpcData) bool { + return dim.ipNet.Contains(net.IP(net.ParseIP(data.localAddr.String()))) +} + +// portMatcher matches on whether the destination port of the RPC matches the +// destination port this matcher was instantiated with. portMatcher +// implements the matcher interface. +type portMatcher struct { + destinationPort uint32 +} + +func newPortMatcher(destinationPort uint32) *portMatcher { + return &portMatcher{destinationPort: destinationPort} +} + +func (pm *portMatcher) match(data *rpcData) bool { + return data.destinationPort == pm.destinationPort +} + +// authenticatedMatcher matches on the name of the Principal. If set, the URI +// SAN or DNS SAN in that order is used from the certificate, otherwise the +// subject field is used. If unset, it applies to any user that is +// authenticated. authenticatedMatcher implements the matcher interface. +type authenticatedMatcher struct { + stringMatcher *internalmatcher.StringMatcher +} + +func newAuthenticatedMatcher(authenticatedMatcherConfig *v3rbacpb.Principal_Authenticated) (*authenticatedMatcher, error) { + // Represents this line in the RBAC documentation = "If unset, it applies to + // any user that is authenticated" (see package-level comments). + if authenticatedMatcherConfig.PrincipalName == nil { + return &authenticatedMatcher{}, nil + } + stringMatcher, err := internalmatcher.StringMatcherFromProto(authenticatedMatcherConfig.PrincipalName) + if err != nil { + return nil, err + } + return &authenticatedMatcher{stringMatcher: &stringMatcher}, nil +} + +func (am *authenticatedMatcher) match(data *rpcData) bool { + // Represents this line in the RBAC documentation = "If unset, it applies to + // any user that is authenticated" (see package-level comments). An + // authenticated downstream in a stateful TLS connection will have to + // provide a certificate to prove their identity. Thus, you can simply check + // if there is a certificate present. + if am.stringMatcher == nil { + return len(data.certs) != 0 + } + // "If there is no client certificate (thus no SAN nor Subject), check if "" + // (empty string) matches. If it matches, the principal_name is said to + // match" - A41 + if len(data.certs) == 0 { + return am.stringMatcher.Match("") + } + cert := data.certs[0] + // The order of matching as per the RBAC documentation (see package-level comments) + // is as follows: URI SANs, DNS SANs, and then subject name. + for _, uriSAN := range cert.URIs { + if am.stringMatcher.Match(uriSAN.String()) { + return true + } + } + for _, dnsSAN := range cert.DNSNames { + if am.stringMatcher.Match(dnsSAN) { + return true + } + } + return am.stringMatcher.Match(cert.Subject.String()) +} diff --git a/internal/xds/rbac/rbac_engine.go b/internal/xds/rbac/rbac_engine.go new file mode 100644 index 00000000000..a25f9cfdeef --- /dev/null +++ b/internal/xds/rbac/rbac_engine.go @@ -0,0 +1,225 @@ +/* + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package rbac provides service-level and method-level access control for a +// service. See +// https://www.envoyproxy.io/docs/envoy/latest/api-v3/config/rbac/v3/rbac.proto#role-based-access-control-rbac +// for documentation. +package rbac + +import ( + "context" + "crypto/x509" + "errors" + "fmt" + "net" + "strconv" + + v3rbacpb "github.com/envoyproxy/go-control-plane/envoy/config/rbac/v3" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/grpclog" + "google.golang.org/grpc/internal/transport" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/peer" + "google.golang.org/grpc/status" +) + +const logLevel = 2 + +var logger = grpclog.Component("rbac") + +var getConnection = transport.GetConnection + +// ChainEngine represents a chain of RBAC Engines, used to make authorization +// decisions on incoming RPCs. +type ChainEngine struct { + chainedEngines []*engine +} + +// NewChainEngine returns a chain of RBAC engines, used to make authorization +// decisions on incoming RPCs. Returns a non-nil error for invalid policies. +func NewChainEngine(policies []*v3rbacpb.RBAC) (*ChainEngine, error) { + engines := make([]*engine, 0, len(policies)) + for _, policy := range policies { + engine, err := newEngine(policy) + if err != nil { + return nil, err + } + engines = append(engines, engine) + } + return &ChainEngine{chainedEngines: engines}, nil +} + +// IsAuthorized determines if an incoming RPC is authorized based on the chain of RBAC +// engines and their associated actions. +// +// Errors returned by this function are compatible with the status package. +func (cre *ChainEngine) IsAuthorized(ctx context.Context) error { + // This conversion step (i.e. pulling things out of ctx) can be done once, + // and then be used for the whole chain of RBAC Engines. + rpcData, err := newRPCData(ctx) + if err != nil { + logger.Errorf("newRPCData: %v", err) + return status.Errorf(codes.Internal, "gRPC RBAC: %v", err) + } + for _, engine := range cre.chainedEngines { + matchingPolicyName, ok := engine.findMatchingPolicy(rpcData) + if logger.V(logLevel) && ok { + logger.Infof("incoming RPC matched to policy %v in engine with action %v", matchingPolicyName, engine.action) + } + + switch { + case engine.action == v3rbacpb.RBAC_ALLOW && !ok: + return status.Errorf(codes.PermissionDenied, "incoming RPC did not match an allow policy") + case engine.action == v3rbacpb.RBAC_DENY && ok: + return status.Errorf(codes.PermissionDenied, "incoming RPC matched a deny policy %q", matchingPolicyName) + } + // Every policy in the engine list must be queried. Thus, iterate to the + // next policy. + } + // If the incoming RPC gets through all of the engines successfully (i.e. + // doesn't not match an allow or match a deny engine), the RPC is authorized + // to proceed. + return nil +} + +// engine is used for matching incoming RPCs to policies. +type engine struct { + policies map[string]*policyMatcher + // action must be ALLOW or DENY. + action v3rbacpb.RBAC_Action +} + +// newEngine creates an RBAC Engine based on the contents of policy. Returns a +// non-nil error if the policy is invalid. +func newEngine(config *v3rbacpb.RBAC) (*engine, error) { + a := *config.Action.Enum() + if a != v3rbacpb.RBAC_ALLOW && a != v3rbacpb.RBAC_DENY { + return nil, fmt.Errorf("unsupported action %s", config.Action) + } + + policies := make(map[string]*policyMatcher, len(config.Policies)) + for name, policy := range config.Policies { + matcher, err := newPolicyMatcher(policy) + if err != nil { + return nil, err + } + policies[name] = matcher + } + return &engine{ + policies: policies, + action: a, + }, nil +} + +// findMatchingPolicy determines if an incoming RPC matches a policy. On a +// successful match, it returns the name of the matching policy and a true bool +// to specify that there was a matching policy found. It returns false in +// the case of not finding a matching policy. +func (r *engine) findMatchingPolicy(rpcData *rpcData) (string, bool) { + for policy, matcher := range r.policies { + if matcher.match(rpcData) { + return policy, true + } + } + return "", false +} + +// newRPCData takes an incoming context (should be a context representing state +// needed for server RPC Call with metadata, peer info (used for source ip/port +// and TLS information) and connection (used for destination ip/port) piped into +// it) and the method name of the Service being called server side and populates +// an rpcData struct ready to be passed to the RBAC Engine to find a matching +// policy. +func newRPCData(ctx context.Context) (*rpcData, error) { + // The caller should populate all of these fields (i.e. for empty headers, + // pipe an empty md into context). + md, ok := metadata.FromIncomingContext(ctx) + if !ok { + return nil, errors.New("missing metadata in incoming context") + } + // ":method can be hard-coded to POST if unavailable" - A41 + md[":method"] = []string{"POST"} + // "If the transport exposes TE in Metadata, then RBAC must special-case the + // header to treat it as not present." - A41 + delete(md, "TE") + + pi, ok := peer.FromContext(ctx) + if !ok { + return nil, errors.New("missing peer info in incoming context") + } + + // The methodName will be available in the passed in ctx from a unary or streaming + // interceptor, as grpc.Server pipes in a transport stream which contains the methodName + // into contexts available in both unary or streaming interceptors. + mn, ok := grpc.Method(ctx) + if !ok { + return nil, errors.New("missing method in incoming context") + } + + // The connection is needed in order to find the destination address and + // port of the incoming RPC Call. + conn := getConnection(ctx) + if conn == nil { + return nil, errors.New("missing connection in incoming context") + } + _, dPort, err := net.SplitHostPort(conn.LocalAddr().String()) + if err != nil { + return nil, fmt.Errorf("error parsing local address: %v", err) + } + dp, err := strconv.ParseUint(dPort, 10, 32) + if err != nil { + return nil, fmt.Errorf("error parsing local address: %v", err) + } + + var peerCertificates []*x509.Certificate + if pi.AuthInfo != nil { + tlsInfo, ok := pi.AuthInfo.(credentials.TLSInfo) + if ok { + peerCertificates = tlsInfo.State.PeerCertificates + } + } + + return &rpcData{ + md: md, + peerInfo: pi, + fullMethod: mn, + destinationPort: uint32(dp), + localAddr: conn.LocalAddr(), + certs: peerCertificates, + }, nil +} + +// rpcData wraps data pulled from an incoming RPC that the RBAC engine needs to +// find a matching policy. +type rpcData struct { + // md is the HTTP Headers that are present in the incoming RPC. + md metadata.MD + // peerInfo is information about the downstream peer. + peerInfo *peer.Peer + // fullMethod is the method name being called on the upstream service. + fullMethod string + // destinationPort is the port that the RPC is being sent to on the + // server. + destinationPort uint32 + // localAddr is the address that the RPC is being sent to. + localAddr net.Addr + // certs are the certificates presented by the peer during a TLS + // handshake. + certs []*x509.Certificate +} diff --git a/internal/xds/rbac/rbac_engine_test.go b/internal/xds/rbac/rbac_engine_test.go new file mode 100644 index 00000000000..17832458209 --- /dev/null +++ b/internal/xds/rbac/rbac_engine_test.go @@ -0,0 +1,1007 @@ +/* + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package rbac + +import ( + "context" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "net" + "net/url" + "testing" + + v3corepb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" + v3rbacpb "github.com/envoyproxy/go-control-plane/envoy/config/rbac/v3" + v3routepb "github.com/envoyproxy/go-control-plane/envoy/config/route/v3" + v3matcherpb "github.com/envoyproxy/go-control-plane/envoy/type/matcher/v3" + v3typepb "github.com/envoyproxy/go-control-plane/envoy/type/v3" + wrapperspb "github.com/golang/protobuf/ptypes/wrappers" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/internal/grpctest" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/peer" + "google.golang.org/grpc/status" +) + +type s struct { + grpctest.Tester +} + +func Test(t *testing.T) { + grpctest.RunSubTests(t, s{}) +} + +type addr struct { + ipAddress string +} + +func (addr) Network() string { return "" } +func (a *addr) String() string { return a.ipAddress } + +// TestNewChainEngine tests the construction of the ChainEngine. Due to some +// types of RBAC configuration being logically wrong and returning an error +// rather than successfully constructing the RBAC Engine, this test tests both +// RBAC Configurations deemed successful and also RBAC Configurations that will +// raise errors. +func (s) TestNewChainEngine(t *testing.T) { + tests := []struct { + name string + policies []*v3rbacpb.RBAC + wantErr bool + }{ + { + name: "SuccessCaseAnyMatchSingular", + policies: []*v3rbacpb.RBAC{ + { + Action: v3rbacpb.RBAC_ALLOW, + Policies: map[string]*v3rbacpb.Policy{ + "anyone": { + Permissions: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_Any{Any: true}}, + }, + Principals: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_Any{Any: true}}, + }, + }, + }, + }, + }, + }, + { + name: "SuccessCaseAnyMatchMultiple", + policies: []*v3rbacpb.RBAC{ + { + Action: v3rbacpb.RBAC_ALLOW, + Policies: map[string]*v3rbacpb.Policy{ + "anyone": { + Permissions: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_Any{Any: true}}, + }, + Principals: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_Any{Any: true}}, + }, + }, + }, + }, + { + Action: v3rbacpb.RBAC_DENY, + Policies: map[string]*v3rbacpb.Policy{ + "anyone": { + Permissions: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_Any{Any: true}}, + }, + Principals: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_Any{Any: true}}, + }, + }, + }, + }, + }, + }, + { + name: "SuccessCaseSimplePolicySingular", + policies: []*v3rbacpb.RBAC{ + { + Action: v3rbacpb.RBAC_ALLOW, + Policies: map[string]*v3rbacpb.Policy{ + "localhost-fan": { + Permissions: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_DestinationPort{DestinationPort: 8080}}, + {Rule: &v3rbacpb.Permission_UrlPath{UrlPath: &v3matcherpb.PathMatcher{Rule: &v3matcherpb.PathMatcher_Path{Path: &v3matcherpb.StringMatcher{MatchPattern: &v3matcherpb.StringMatcher_Exact{Exact: "localhost-fan-page"}}}}}}, + }, + Principals: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_Any{Any: true}}, + }, + }, + }, + }, + }, + }, + // SuccessCaseSimplePolicyTwoPolicies tests the construction of the + // chained engines in the case where there are two policies in a list, + // one with an allow policy and one with a deny policy. A situation + // where two policies (allow and deny) is a very common use case for + // this API, and should successfully build. + { + name: "SuccessCaseSimplePolicyTwoPolicies", + policies: []*v3rbacpb.RBAC{ + { + Action: v3rbacpb.RBAC_ALLOW, + Policies: map[string]*v3rbacpb.Policy{ + "localhost-fan": { + Permissions: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_DestinationPort{DestinationPort: 8080}}, + {Rule: &v3rbacpb.Permission_UrlPath{UrlPath: &v3matcherpb.PathMatcher{Rule: &v3matcherpb.PathMatcher_Path{Path: &v3matcherpb.StringMatcher{MatchPattern: &v3matcherpb.StringMatcher_Exact{Exact: "localhost-fan-page"}}}}}}, + }, + Principals: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_Any{Any: true}}, + }, + }, + }, + }, + { + Action: v3rbacpb.RBAC_DENY, + Policies: map[string]*v3rbacpb.Policy{ + "localhost-fan": { + Permissions: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_DestinationPort{DestinationPort: 8080}}, + {Rule: &v3rbacpb.Permission_UrlPath{UrlPath: &v3matcherpb.PathMatcher{Rule: &v3matcherpb.PathMatcher_Path{Path: &v3matcherpb.StringMatcher{MatchPattern: &v3matcherpb.StringMatcher_Exact{Exact: "localhost-fan-page"}}}}}}, + }, + Principals: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_Any{Any: true}}, + }, + }, + }, + }, + }, + }, + { + name: "SuccessCaseEnvoyExampleSingular", + policies: []*v3rbacpb.RBAC{ + { + Action: v3rbacpb.RBAC_ALLOW, + Policies: map[string]*v3rbacpb.Policy{ + "service-admin": { + Permissions: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_Any{Any: true}}, + }, + Principals: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_Authenticated_{Authenticated: &v3rbacpb.Principal_Authenticated{PrincipalName: &v3matcherpb.StringMatcher{MatchPattern: &v3matcherpb.StringMatcher_Exact{Exact: "cluster.local/ns/default/sa/admin"}}}}}, + {Identifier: &v3rbacpb.Principal_Authenticated_{Authenticated: &v3rbacpb.Principal_Authenticated{PrincipalName: &v3matcherpb.StringMatcher{MatchPattern: &v3matcherpb.StringMatcher_Exact{Exact: "cluster.local/ns/default/sa/superuser"}}}}}, + }, + }, + "product-viewer": { + Permissions: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_AndRules{AndRules: &v3rbacpb.Permission_Set{ + Rules: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_Header{Header: &v3routepb.HeaderMatcher{Name: ":method", HeaderMatchSpecifier: &v3routepb.HeaderMatcher_ExactMatch{ExactMatch: "GET"}}}}, + {Rule: &v3rbacpb.Permission_UrlPath{UrlPath: &v3matcherpb.PathMatcher{Rule: &v3matcherpb.PathMatcher_Path{Path: &v3matcherpb.StringMatcher{MatchPattern: &v3matcherpb.StringMatcher_Prefix{Prefix: "/products"}}}}}}, + {Rule: &v3rbacpb.Permission_OrRules{OrRules: &v3rbacpb.Permission_Set{ + Rules: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_DestinationPort{DestinationPort: 80}}, + {Rule: &v3rbacpb.Permission_DestinationPort{DestinationPort: 443}}, + }, + }, + }, + }, + }, + }, + }, + }, + }, + Principals: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_Any{Any: true}}, + }, + }, + }, + }, + }, + }, + { + name: "SourceIpMatcherSuccessSingular", + policies: []*v3rbacpb.RBAC{ + { + Action: v3rbacpb.RBAC_ALLOW, + Policies: map[string]*v3rbacpb.Policy{ + "certain-source-ip": { + Permissions: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_Any{Any: true}}, + }, + Principals: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_DirectRemoteIp{DirectRemoteIp: &v3corepb.CidrRange{AddressPrefix: "0.0.0.0", PrefixLen: &wrapperspb.UInt32Value{Value: uint32(10)}}}}, + }, + }, + }, + }, + }, + }, + { + name: "SourceIpMatcherFailureSingular", + policies: []*v3rbacpb.RBAC{ + { + Action: v3rbacpb.RBAC_ALLOW, + Policies: map[string]*v3rbacpb.Policy{ + "certain-source-ip": { + Permissions: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_Any{Any: true}}, + }, + Principals: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_DirectRemoteIp{DirectRemoteIp: &v3corepb.CidrRange{AddressPrefix: "not a correct address", PrefixLen: &wrapperspb.UInt32Value{Value: uint32(10)}}}}, + }, + }, + }, + }, + }, + wantErr: true, + }, + { + name: "DestinationIpMatcherSuccess", + policies: []*v3rbacpb.RBAC{ + { + Action: v3rbacpb.RBAC_ALLOW, + Policies: map[string]*v3rbacpb.Policy{ + "certain-destination-ip": { + Permissions: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_DestinationIp{DestinationIp: &v3corepb.CidrRange{AddressPrefix: "0.0.0.0", PrefixLen: &wrapperspb.UInt32Value{Value: uint32(10)}}}}, + }, + Principals: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_Any{Any: true}}, + }, + }, + }, + }, + }, + }, + { + name: "DestinationIpMatcherFailure", + policies: []*v3rbacpb.RBAC{ + { + Action: v3rbacpb.RBAC_ALLOW, + Policies: map[string]*v3rbacpb.Policy{ + "certain-destination-ip": { + Permissions: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_DestinationIp{DestinationIp: &v3corepb.CidrRange{AddressPrefix: "not a correct address", PrefixLen: &wrapperspb.UInt32Value{Value: uint32(10)}}}}, + }, + Principals: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_Any{Any: true}}, + }, + }, + }, + }, + }, + wantErr: true, + }, + { + name: "MatcherToNotPolicy", + policies: []*v3rbacpb.RBAC{ + { + Action: v3rbacpb.RBAC_ALLOW, + Policies: map[string]*v3rbacpb.Policy{ + "not-secret-content": { + Permissions: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_NotRule{NotRule: &v3rbacpb.Permission{Rule: &v3rbacpb.Permission_UrlPath{UrlPath: &v3matcherpb.PathMatcher{Rule: &v3matcherpb.PathMatcher_Path{Path: &v3matcherpb.StringMatcher{MatchPattern: &v3matcherpb.StringMatcher_Prefix{Prefix: "/secret-content"}}}}}}}}, + }, + Principals: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_Any{Any: true}}, + }, + }, + }, + }, + }, + }, + { + name: "MatcherToNotPrinicipal", + policies: []*v3rbacpb.RBAC{ + { + Action: v3rbacpb.RBAC_ALLOW, + Policies: map[string]*v3rbacpb.Policy{ + "not-from-certain-ip": { + Permissions: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_Any{Any: true}}, + }, + Principals: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_NotId{NotId: &v3rbacpb.Principal{Identifier: &v3rbacpb.Principal_DirectRemoteIp{DirectRemoteIp: &v3corepb.CidrRange{AddressPrefix: "0.0.0.0", PrefixLen: &wrapperspb.UInt32Value{Value: uint32(10)}}}}}}, + }, + }, + }, + }, + }, + }, + // PrinicpalProductViewer tests the construction of a chained engine + // with a policy that allows any downstream to send a GET request on a + // certain path. + { + name: "PrincipalProductViewer", + policies: []*v3rbacpb.RBAC{ + { + Action: v3rbacpb.RBAC_ALLOW, + Policies: map[string]*v3rbacpb.Policy{ + "product-viewer": { + Permissions: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_Any{Any: true}}, + }, + Principals: []*v3rbacpb.Principal{ + { + Identifier: &v3rbacpb.Principal_AndIds{AndIds: &v3rbacpb.Principal_Set{Ids: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_Header{Header: &v3routepb.HeaderMatcher{Name: ":method", HeaderMatchSpecifier: &v3routepb.HeaderMatcher_ExactMatch{ExactMatch: "GET"}}}}, + {Identifier: &v3rbacpb.Principal_OrIds{OrIds: &v3rbacpb.Principal_Set{ + Ids: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_UrlPath{UrlPath: &v3matcherpb.PathMatcher{Rule: &v3matcherpb.PathMatcher_Path{Path: &v3matcherpb.StringMatcher{MatchPattern: &v3matcherpb.StringMatcher_Prefix{Prefix: "/books"}}}}}}, + {Identifier: &v3rbacpb.Principal_UrlPath{UrlPath: &v3matcherpb.PathMatcher{Rule: &v3matcherpb.PathMatcher_Path{Path: &v3matcherpb.StringMatcher{MatchPattern: &v3matcherpb.StringMatcher_Prefix{Prefix: "/cars"}}}}}}, + }, + }}}, + }}}, + }, + }, + }, + }, + }, + }, + }, + // Certain Headers tests the construction of a chained engine with a + // policy that allows any downstream to send an HTTP request with + // certain headers. + { + name: "CertainHeaders", + policies: []*v3rbacpb.RBAC{ + { + Policies: map[string]*v3rbacpb.Policy{ + "certain-headers": { + Permissions: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_Any{Any: true}}, + }, + Principals: []*v3rbacpb.Principal{ + { + Identifier: &v3rbacpb.Principal_OrIds{OrIds: &v3rbacpb.Principal_Set{Ids: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_Header{Header: &v3routepb.HeaderMatcher{Name: ":method", HeaderMatchSpecifier: &v3routepb.HeaderMatcher_ExactMatch{ExactMatch: "GET"}}}}, + {Identifier: &v3rbacpb.Principal_Header{Header: &v3routepb.HeaderMatcher{Name: ":method", HeaderMatchSpecifier: &v3routepb.HeaderMatcher_SafeRegexMatch{SafeRegexMatch: &v3matcherpb.RegexMatcher{Regex: "GET"}}}}}, + {Identifier: &v3rbacpb.Principal_Header{Header: &v3routepb.HeaderMatcher{Name: ":method", HeaderMatchSpecifier: &v3routepb.HeaderMatcher_RangeMatch{RangeMatch: &v3typepb.Int64Range{ + Start: 0, + End: 64, + }}}}}, + {Identifier: &v3rbacpb.Principal_Header{Header: &v3routepb.HeaderMatcher{Name: ":method", HeaderMatchSpecifier: &v3routepb.HeaderMatcher_PresentMatch{PresentMatch: true}}}}, + {Identifier: &v3rbacpb.Principal_Header{Header: &v3routepb.HeaderMatcher{Name: ":method", HeaderMatchSpecifier: &v3routepb.HeaderMatcher_PrefixMatch{PrefixMatch: "GET"}}}}, + {Identifier: &v3rbacpb.Principal_Header{Header: &v3routepb.HeaderMatcher{Name: ":method", HeaderMatchSpecifier: &v3routepb.HeaderMatcher_SuffixMatch{SuffixMatch: "GET"}}}}, + {Identifier: &v3rbacpb.Principal_Header{Header: &v3routepb.HeaderMatcher{Name: ":method", HeaderMatchSpecifier: &v3routepb.HeaderMatcher_ContainsMatch{ContainsMatch: "GET"}}}}, + }}}, + }, + }, + }, + }, + }, + }, + }, + { + name: "LogAction", + policies: []*v3rbacpb.RBAC{ + { + Action: v3rbacpb.RBAC_LOG, + Policies: map[string]*v3rbacpb.Policy{ + "anyone": { + Permissions: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_Any{Any: true}}, + }, + Principals: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_Any{Any: true}}, + }, + }, + }, + }, + }, + wantErr: true, + }, + { + name: "ActionNotSpecified", + policies: []*v3rbacpb.RBAC{ + { + Policies: map[string]*v3rbacpb.Policy{ + "anyone": { + Permissions: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_Any{Any: true}}, + }, + Principals: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_Any{Any: true}}, + }, + }, + }, + }, + }, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if _, err := NewChainEngine(test.policies); (err != nil) != test.wantErr { + t.Fatalf("NewChainEngine(%+v) returned err: %v, wantErr: %v", test.policies, err, test.wantErr) + } + }) + } +} + +// TestChainEngine tests the chain of RBAC Engines by configuring the chain of +// engines in a certain way in different scenarios. After configuring the chain +// of engines in a certain way, this test pings the chain of engines with +// different types of data representing incoming RPC's (piped into a context), +// and verifies that it works as expected. +func (s) TestChainEngine(t *testing.T) { + defer func(gc func(ctx context.Context) net.Conn) { + getConnection = gc + }(getConnection) + tests := []struct { + name string + rbacConfigs []*v3rbacpb.RBAC + rbacQueries []struct { + rpcData *rpcData + wantStatusCode codes.Code + } + }{ + // SuccessCaseAnyMatch tests a single RBAC Engine instantiated with + // a config with a policy with any rules for both permissions and + // principals, meaning that any data about incoming RPC's that the RBAC + // Engine is queried with should match that policy. + { + name: "SuccessCaseAnyMatch", + rbacConfigs: []*v3rbacpb.RBAC{ + { + Policies: map[string]*v3rbacpb.Policy{ + "anyone": { + Permissions: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_Any{Any: true}}, + }, + Principals: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_Any{Any: true}}, + }, + }, + }, + }, + }, + rbacQueries: []struct { + rpcData *rpcData + wantStatusCode codes.Code + }{ + { + rpcData: &rpcData{ + fullMethod: "some method", + peerInfo: &peer.Peer{ + Addr: &addr{ipAddress: "0.0.0.0"}, + }, + }, + wantStatusCode: codes.OK, + }, + }, + }, + // SuccessCaseSimplePolicy is a test that tests a single policy + // that only allows an rpc to proceed if the rpc is calling with a certain + // path. + { + name: "SuccessCaseSimplePolicy", + rbacConfigs: []*v3rbacpb.RBAC{ + { + Policies: map[string]*v3rbacpb.Policy{ + "localhost-fan": { + Permissions: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_UrlPath{UrlPath: &v3matcherpb.PathMatcher{Rule: &v3matcherpb.PathMatcher_Path{Path: &v3matcherpb.StringMatcher{MatchPattern: &v3matcherpb.StringMatcher_Exact{Exact: "localhost-fan-page"}}}}}}, + }, + Principals: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_Any{Any: true}}, + }, + }, + }, + }, + }, + rbacQueries: []struct { + rpcData *rpcData + wantStatusCode codes.Code + }{ + // This RPC should match with the local host fan policy. Thus, + // this RPC should be allowed to proceed. + { + rpcData: &rpcData{ + fullMethod: "localhost-fan-page", + peerInfo: &peer.Peer{ + Addr: &addr{ipAddress: "0.0.0.0"}, + }, + }, + wantStatusCode: codes.OK, + }, + + // This RPC shouldn't match with the local host fan policy. Thus, + // this rpc shouldn't be allowed to proceed. + { + rpcData: &rpcData{ + peerInfo: &peer.Peer{ + Addr: &addr{ipAddress: "0.0.0.0"}, + }, + }, + wantStatusCode: codes.PermissionDenied, + }, + }, + }, + // SuccessCaseEnvoyExample is a test based on the example provided + // in the EnvoyProxy docs. The RBAC Config contains two policies, + // service admin and product viewer, that provides an example of a real + // RBAC Config that might be configured for a given for a given backend + // service. + { + name: "SuccessCaseEnvoyExample", + rbacConfigs: []*v3rbacpb.RBAC{ + { + Policies: map[string]*v3rbacpb.Policy{ + "service-admin": { + Permissions: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_Any{Any: true}}, + }, + Principals: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_Authenticated_{Authenticated: &v3rbacpb.Principal_Authenticated{PrincipalName: &v3matcherpb.StringMatcher{MatchPattern: &v3matcherpb.StringMatcher_Exact{Exact: "//cluster.local/ns/default/sa/admin"}}}}}, + {Identifier: &v3rbacpb.Principal_Authenticated_{Authenticated: &v3rbacpb.Principal_Authenticated{PrincipalName: &v3matcherpb.StringMatcher{MatchPattern: &v3matcherpb.StringMatcher_Exact{Exact: "//cluster.local/ns/default/sa/superuser"}}}}}, + }, + }, + "product-viewer": { + Permissions: []*v3rbacpb.Permission{ + { + Rule: &v3rbacpb.Permission_AndRules{AndRules: &v3rbacpb.Permission_Set{ + Rules: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_Header{Header: &v3routepb.HeaderMatcher{Name: ":method", HeaderMatchSpecifier: &v3routepb.HeaderMatcher_ExactMatch{ExactMatch: "GET"}}}}, + {Rule: &v3rbacpb.Permission_UrlPath{UrlPath: &v3matcherpb.PathMatcher{Rule: &v3matcherpb.PathMatcher_Path{Path: &v3matcherpb.StringMatcher{MatchPattern: &v3matcherpb.StringMatcher_Prefix{Prefix: "/products"}}}}}}, + }, + }, + }, + }, + }, + Principals: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_Any{Any: true}}, + }, + }, + }, + }, + }, + rbacQueries: []struct { + rpcData *rpcData + wantStatusCode codes.Code + }{ + // This incoming RPC Call should match with the service admin + // policy. + { + rpcData: &rpcData{ + fullMethod: "some method", + peerInfo: &peer.Peer{ + Addr: &addr{ipAddress: "0.0.0.0"}, + AuthInfo: credentials.TLSInfo{ + State: tls.ConnectionState{ + PeerCertificates: []*x509.Certificate{ + { + URIs: []*url.URL{ + { + Host: "cluster.local", + Path: "/ns/default/sa/admin", + }, + }, + }, + }, + }, + }, + }, + }, + wantStatusCode: codes.OK, + }, + // These incoming RPC calls should not match any policy. + { + rpcData: &rpcData{ + peerInfo: &peer.Peer{ + Addr: &addr{ipAddress: "0.0.0.0"}, + }, + }, + wantStatusCode: codes.PermissionDenied, + }, + { + rpcData: &rpcData{ + fullMethod: "get-product-list", + peerInfo: &peer.Peer{ + Addr: &addr{ipAddress: "0.0.0.0"}, + }, + }, + wantStatusCode: codes.PermissionDenied, + }, + { + rpcData: &rpcData{ + peerInfo: &peer.Peer{ + Addr: &addr{ipAddress: "0.0.0.0"}, + AuthInfo: credentials.TLSInfo{ + State: tls.ConnectionState{ + PeerCertificates: []*x509.Certificate{ + { + Subject: pkix.Name{ + CommonName: "localhost", + }, + }, + }, + }, + }, + }, + }, + wantStatusCode: codes.PermissionDenied, + }, + }, + }, + { + name: "NotMatcher", + rbacConfigs: []*v3rbacpb.RBAC{ + { + Policies: map[string]*v3rbacpb.Policy{ + "not-secret-content": { + Permissions: []*v3rbacpb.Permission{ + { + Rule: &v3rbacpb.Permission_NotRule{ + NotRule: &v3rbacpb.Permission{Rule: &v3rbacpb.Permission_UrlPath{UrlPath: &v3matcherpb.PathMatcher{Rule: &v3matcherpb.PathMatcher_Path{Path: &v3matcherpb.StringMatcher{MatchPattern: &v3matcherpb.StringMatcher_Prefix{Prefix: "/secret-content"}}}}}}, + }, + }, + }, + Principals: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_Any{Any: true}}, + }, + }, + }, + }, + }, + rbacQueries: []struct { + rpcData *rpcData + wantStatusCode codes.Code + }{ + // This incoming RPC Call should match with the not-secret-content policy. + { + rpcData: &rpcData{ + fullMethod: "/regular-content", + peerInfo: &peer.Peer{ + Addr: &addr{ipAddress: "0.0.0.0"}, + }, + }, + wantStatusCode: codes.OK, + }, + // This incoming RPC Call shouldn't match with the not-secret-content-policy. + { + rpcData: &rpcData{ + fullMethod: "/secret-content", + peerInfo: &peer.Peer{ + Addr: &addr{ipAddress: "0.0.0.0"}, + }, + }, + wantStatusCode: codes.PermissionDenied, + }, + }, + }, + { + name: "DirectRemoteIpMatcher", + rbacConfigs: []*v3rbacpb.RBAC{ + { + Policies: map[string]*v3rbacpb.Policy{ + "certain-direct-remote-ip": { + Permissions: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_Any{Any: true}}, + }, + Principals: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_DirectRemoteIp{DirectRemoteIp: &v3corepb.CidrRange{AddressPrefix: "0.0.0.0", PrefixLen: &wrapperspb.UInt32Value{Value: uint32(10)}}}}, + }, + }, + }, + }, + }, + rbacQueries: []struct { + rpcData *rpcData + wantStatusCode codes.Code + }{ + // This incoming RPC Call should match with the certain-direct-remote-ip policy. + { + rpcData: &rpcData{ + peerInfo: &peer.Peer{ + Addr: &addr{ipAddress: "0.0.0.0"}, + }, + }, + wantStatusCode: codes.OK, + }, + // This incoming RPC Call shouldn't match with the certain-direct-remote-ip policy. + { + rpcData: &rpcData{ + peerInfo: &peer.Peer{ + Addr: &addr{ipAddress: "10.0.0.0"}, + }, + }, + wantStatusCode: codes.PermissionDenied, + }, + }, + }, + // This test tests a RBAC policy configured with a remote-ip policy. + // This should be logically equivalent to configuring a Engine with a + // direct-remote-ip policy, as per A41 - "allow equating RBAC's + // direct_remote_ip and remote_ip." + { + name: "RemoteIpMatcher", + rbacConfigs: []*v3rbacpb.RBAC{ + { + Policies: map[string]*v3rbacpb.Policy{ + "certain-remote-ip": { + Permissions: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_Any{Any: true}}, + }, + Principals: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_RemoteIp{RemoteIp: &v3corepb.CidrRange{AddressPrefix: "0.0.0.0", PrefixLen: &wrapperspb.UInt32Value{Value: uint32(10)}}}}, + }, + }, + }, + }, + }, + rbacQueries: []struct { + rpcData *rpcData + wantStatusCode codes.Code + }{ + // This incoming RPC Call should match with the certain-remote-ip policy. + { + rpcData: &rpcData{ + peerInfo: &peer.Peer{ + Addr: &addr{ipAddress: "0.0.0.0"}, + }, + }, + wantStatusCode: codes.OK, + }, + // This incoming RPC Call shouldn't match with the certain-remote-ip policy. + { + rpcData: &rpcData{ + peerInfo: &peer.Peer{ + Addr: &addr{ipAddress: "10.0.0.0"}, + }, + }, + wantStatusCode: codes.PermissionDenied, + }, + }, + }, + { + name: "DestinationIpMatcher", + rbacConfigs: []*v3rbacpb.RBAC{ + { + Policies: map[string]*v3rbacpb.Policy{ + "certain-destination-ip": { + Permissions: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_DestinationIp{DestinationIp: &v3corepb.CidrRange{AddressPrefix: "0.0.0.0", PrefixLen: &wrapperspb.UInt32Value{Value: uint32(10)}}}}, + }, + Principals: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_Any{Any: true}}, + }, + }, + }, + }, + }, + rbacQueries: []struct { + rpcData *rpcData + wantStatusCode codes.Code + }{ + // This incoming RPC Call shouldn't match with the + // certain-destination-ip policy, as the test listens on local + // host. + { + rpcData: &rpcData{ + peerInfo: &peer.Peer{ + Addr: &addr{ipAddress: "10.0.0.0"}, + }, + }, + wantStatusCode: codes.PermissionDenied, + }, + }, + }, + // AllowAndDenyPolicy tests a policy with an allow (on path) and + // deny (on port) policy chained together. This represents how a user + // configured interceptor would use this, and also is a potential + // configuration for a dynamic xds interceptor. + { + name: "AllowAndDenyPolicy", + rbacConfigs: []*v3rbacpb.RBAC{ + { + Policies: map[string]*v3rbacpb.Policy{ + "certain-source-ip": { + Permissions: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_Any{Any: true}}, + }, + Principals: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_DirectRemoteIp{DirectRemoteIp: &v3corepb.CidrRange{AddressPrefix: "0.0.0.0", PrefixLen: &wrapperspb.UInt32Value{Value: uint32(10)}}}}, + }, + }, + }, + Action: v3rbacpb.RBAC_ALLOW, + }, + { + Policies: map[string]*v3rbacpb.Policy{ + "localhost-fan": { + Permissions: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_UrlPath{UrlPath: &v3matcherpb.PathMatcher{Rule: &v3matcherpb.PathMatcher_Path{Path: &v3matcherpb.StringMatcher{MatchPattern: &v3matcherpb.StringMatcher_Exact{Exact: "localhost-fan-page"}}}}}}, + }, + Principals: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_Any{Any: true}}, + }, + }, + }, + Action: v3rbacpb.RBAC_DENY, + }, + }, + rbacQueries: []struct { + rpcData *rpcData + wantStatusCode codes.Code + }{ + // This RPC should match with the allow policy, and shouldn't + // match with the deny and thus should be allowed to proceed. + { + rpcData: &rpcData{ + peerInfo: &peer.Peer{ + Addr: &addr{ipAddress: "0.0.0.0"}, + }, + }, + wantStatusCode: codes.OK, + }, + // This RPC should match with both the allow policy and deny policy + // and thus shouldn't be allowed to proceed as matched with deny. + { + rpcData: &rpcData{ + fullMethod: "localhost-fan-page", + peerInfo: &peer.Peer{ + Addr: &addr{ipAddress: "0.0.0.0"}, + }, + }, + wantStatusCode: codes.PermissionDenied, + }, + // This RPC shouldn't match with either policy, and thus + // shouldn't be allowed to proceed as didn't match with allow. + { + rpcData: &rpcData{ + peerInfo: &peer.Peer{ + Addr: &addr{ipAddress: "10.0.0.0"}, + }, + }, + wantStatusCode: codes.PermissionDenied, + }, + // This RPC shouldn't match with allow, match with deny, and + // thus shouldn't be allowed to proceed. + { + rpcData: &rpcData{ + fullMethod: "localhost-fan-page", + peerInfo: &peer.Peer{ + Addr: &addr{ipAddress: "10.0.0.0"}, + }, + }, + wantStatusCode: codes.PermissionDenied, + }, + }, + }, + // This test tests that when there are no SANs or Subject's + // distinguished name in incoming RPC's, that authenticated matchers + // match against the empty string. + { + name: "default-matching-no-credentials", + rbacConfigs: []*v3rbacpb.RBAC{ + { + Policies: map[string]*v3rbacpb.Policy{ + "service-admin": { + Permissions: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_Any{Any: true}}, + }, + Principals: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_Authenticated_{Authenticated: &v3rbacpb.Principal_Authenticated{PrincipalName: &v3matcherpb.StringMatcher{MatchPattern: &v3matcherpb.StringMatcher_Exact{Exact: ""}}}}}, + }, + }, + }, + }, + }, + rbacQueries: []struct { + rpcData *rpcData + wantStatusCode codes.Code + }{ + // This incoming RPC Call should match with the service admin + // policy. No authentication info is provided, so the + // authenticated matcher should match to the string matcher on + // the empty string, matching to the service-admin policy. + { + rpcData: &rpcData{ + fullMethod: "some method", + peerInfo: &peer.Peer{ + Addr: &addr{ipAddress: "0.0.0.0"}, + }, + }, + wantStatusCode: codes.OK, + }, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + // Instantiate the chainedRBACEngine with different configurations that are + // interesting to test and to query. + cre, err := NewChainEngine(test.rbacConfigs) + if err != nil { + t.Fatalf("Error constructing RBAC Engine: %v", err) + } + // Query the created chain of RBAC Engines with different args to see + // if the chain of RBAC Engines configured as such works as intended. + for _, data := range test.rbacQueries { + func() { + // Construct the context with three data points that have enough + // information to represent incoming RPC's. This will be how a + // user uses this API. A user will have to put MD, PeerInfo, and + // the connection the RPC is sent on in the context. + ctx := metadata.NewIncomingContext(context.Background(), data.rpcData.md) + + // Make a TCP connection with a certain destination port. The + // address/port of this connection will be used to populate the + // destination ip/port in RPCData struct. This represents what + // the user of ChainEngine will have to place into + // context, as this is only way to get destination ip and port. + lis, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatalf("Error listening: %v", err) + } + defer lis.Close() + connCh := make(chan net.Conn, 1) + go func() { + conn, err := lis.Accept() + if err != nil { + t.Errorf("Error accepting connection: %v", err) + return + } + connCh <- conn + }() + _, err = net.Dial("tcp", lis.Addr().String()) + if err != nil { + t.Fatalf("Error dialing: %v", err) + } + conn := <-connCh + defer conn.Close() + getConnection = func(context.Context) net.Conn { + return conn + } + ctx = peer.NewContext(ctx, data.rpcData.peerInfo) + stream := &ServerTransportStreamWithMethod{ + method: data.rpcData.fullMethod, + } + + ctx = grpc.NewContextWithServerTransportStream(ctx, stream) + err = cre.IsAuthorized(ctx) + if gotCode := status.Code(err); gotCode != data.wantStatusCode { + t.Fatalf("IsAuthorized(%+v, %+v) returned (%+v), want(%+v)", ctx, data.rpcData.fullMethod, gotCode, data.wantStatusCode) + } + }() + } + }) + } +} + +type ServerTransportStreamWithMethod struct { + method string +} + +func (sts *ServerTransportStreamWithMethod) Method() string { + return sts.method +} + +func (sts *ServerTransportStreamWithMethod) SetHeader(md metadata.MD) error { + return nil +} + +func (sts *ServerTransportStreamWithMethod) SendHeader(md metadata.MD) error { + return nil +} + +func (sts *ServerTransportStreamWithMethod) SetTrailer(md metadata.MD) error { + return nil +} diff --git a/internal/xds_handshake_cluster.go b/internal/xds_handshake_cluster.go new file mode 100644 index 00000000000..3677c3f04f8 --- /dev/null +++ b/internal/xds_handshake_cluster.go @@ -0,0 +1,40 @@ +/* + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package internal + +import ( + "google.golang.org/grpc/attributes" + "google.golang.org/grpc/resolver" +) + +// handshakeClusterNameKey is the type used as the key to store cluster name in +// the Attributes field of resolver.Address. +type handshakeClusterNameKey struct{} + +// SetXDSHandshakeClusterName returns a copy of addr in which the Attributes field +// is updated with the cluster name. +func SetXDSHandshakeClusterName(addr resolver.Address, clusterName string) resolver.Address { + addr.Attributes = addr.Attributes.WithValues(handshakeClusterNameKey{}, clusterName) + return addr +} + +// GetXDSHandshakeClusterName returns cluster name stored in attr. +func GetXDSHandshakeClusterName(attr *attributes.Attributes) (string, bool) { + v := attr.Value(handshakeClusterNameKey{}) + name, ok := v.(string) + return name, ok +} diff --git a/interop/client/client.go b/interop/client/client.go index 8854ed2d76e..f41f56fbbd5 100644 --- a/interop/client/client.go +++ b/interop/client/client.go @@ -20,9 +20,13 @@ package main import ( + "crypto/tls" + "crypto/x509" "flag" + "io/ioutil" "net" "strconv" + "time" "google.golang.org/grpc" _ "google.golang.org/grpc/balancer/grpclb" @@ -34,6 +38,7 @@ import ( "google.golang.org/grpc/interop" "google.golang.org/grpc/resolver" "google.golang.org/grpc/testdata" + _ "google.golang.org/grpc/xds/googledirectpath" testgrpc "google.golang.org/grpc/interop/grpc_testing" ) @@ -44,19 +49,24 @@ const ( ) var ( - caFile = flag.String("ca_file", "", "The file containning the CA root cert file") - useTLS = flag.Bool("use_tls", false, "Connection uses TLS if true") - useALTS = flag.Bool("use_alts", false, "Connection uses ALTS if true (this option can only be used on GCP)") - customCredentialsType = flag.String("custom_credentials_type", "", "Custom creds to use, excluding TLS or ALTS") - altsHSAddr = flag.String("alts_handshaker_service_address", "", "ALTS handshaker gRPC service address") - testCA = flag.Bool("use_test_ca", false, "Whether to replace platform root CAs with test CA as the CA root") - serviceAccountKeyFile = flag.String("service_account_key_file", "", "Path to service account json key file") - oauthScope = flag.String("oauth_scope", "", "The scope for OAuth2 tokens") - defaultServiceAccount = flag.String("default_service_account", "", "Email of GCE default service account") - serverHost = flag.String("server_host", "localhost", "The server host name") - serverPort = flag.Int("server_port", 10000, "The server port number") - tlsServerName = flag.String("server_host_override", "", "The server name use to verify the hostname returned by TLS handshake if it is not empty. Otherwise, --server_host is used.") - testCase = flag.String("test_case", "large_unary", + caFile = flag.String("ca_file", "", "The file containning the CA root cert file") + useTLS = flag.Bool("use_tls", false, "Connection uses TLS if true") + useALTS = flag.Bool("use_alts", false, "Connection uses ALTS if true (this option can only be used on GCP)") + customCredentialsType = flag.String("custom_credentials_type", "", "Custom creds to use, excluding TLS or ALTS") + altsHSAddr = flag.String("alts_handshaker_service_address", "", "ALTS handshaker gRPC service address") + testCA = flag.Bool("use_test_ca", false, "Whether to replace platform root CAs with test CA as the CA root") + serviceAccountKeyFile = flag.String("service_account_key_file", "", "Path to service account json key file") + oauthScope = flag.String("oauth_scope", "", "The scope for OAuth2 tokens") + defaultServiceAccount = flag.String("default_service_account", "", "Email of GCE default service account") + serverHost = flag.String("server_host", "localhost", "The server host name") + serverPort = flag.Int("server_port", 10000, "The server port number") + serviceConfigJSON = flag.String("service_config_json", "", "Disables service config lookups and sets the provided string as the default service config.") + soakIterations = flag.Int("soak_iterations", 10, "The number of iterations to use for the two soak tests: rpc_soak and channel_soak") + soakMaxFailures = flag.Int("soak_max_failures", 0, "The number of iterations in soak tests that are allowed to fail (either due to non-OK status code or exceeding the per-iteration max acceptable latency).") + soakPerIterationMaxAcceptableLatencyMs = flag.Int("soak_per_iteration_max_acceptable_latency_ms", 1000, "The number of milliseconds a single iteration in the two soak tests (rpc_soak and channel_soak) should take.") + soakOverallTimeoutSeconds = flag.Int("soak_overall_timeout_seconds", 10, "The overall number of seconds after which a soak test should stop and fail, if the desired number of iterations have not yet completed.") + tlsServerName = flag.String("server_host_override", "", "The server name used to verify the hostname returned by TLS handshake if it is not empty. Otherwise, --server_host is used.") + testCase = flag.String("test_case", "large_unary", `Configure different test cases. Valid options are: empty_unary : empty (zero bytes) request and response; large_unary : single request and (large) response; @@ -126,26 +136,32 @@ func main() { } resolver.SetDefaultScheme("dns") - serverAddr := net.JoinHostPort(*serverHost, strconv.Itoa(*serverPort)) + serverAddr := *serverHost + if *serverPort != 0 { + serverAddr = net.JoinHostPort(*serverHost, strconv.Itoa(*serverPort)) + } var opts []grpc.DialOption switch credsChosen { case credsTLS: - var sn string - if *tlsServerName != "" { - sn = *tlsServerName - } - var creds credentials.TransportCredentials + var roots *x509.CertPool if *testCA { - var err error if *caFile == "" { *caFile = testdata.Path("ca.pem") } - creds, err = credentials.NewClientTLSFromFile(*caFile, sn) + b, err := ioutil.ReadFile(*caFile) if err != nil { - logger.Fatalf("Failed to create TLS credentials %v", err) + logger.Fatalf("Failed to read root certificate file %q: %v", *caFile, err) + } + roots = x509.NewCertPool() + if !roots.AppendCertsFromPEM(b) { + logger.Fatalf("Failed to append certificates: %s", string(b)) } + } + var creds credentials.TransportCredentials + if *tlsServerName != "" { + creds = credentials.NewClientTLSFromCert(roots, *tlsServerName) } else { - creds = credentials.NewClientTLSFromCert(nil, sn) + creds = credentials.NewTLS(&tls.Config{RootCAs: roots}) } opts = append(opts, grpc.WithTransportCredentials(creds)) case credsALTS: @@ -183,7 +199,9 @@ func main() { opts = append(opts, grpc.WithPerRPCCredentials(oauth.NewOauthAccess(interop.GetToken(*serviceAccountKeyFile, *oauthScope)))) } } - opts = append(opts, grpc.WithBlock()) + if len(*serviceConfigJSON) > 0 { + opts = append(opts, grpc.WithDisableServiceConfig(), grpc.WithDefaultServiceConfig(*serviceConfigJSON)) + } conn, err := grpc.Dial(serverAddr, opts...) if err != nil { logger.Fatalf("Fail to dial: %v", err) @@ -278,6 +296,12 @@ func main() { case "pick_first_unary": interop.DoPickFirstUnary(tc) logger.Infoln("PickFirstUnary done") + case "rpc_soak": + interop.DoSoakTest(tc, serverAddr, opts, false /* resetChannel */, *soakIterations, *soakMaxFailures, time.Duration(*soakPerIterationMaxAcceptableLatencyMs)*time.Millisecond, time.Now().Add(time.Duration(*soakOverallTimeoutSeconds)*time.Second)) + logger.Infoln("RpcSoak done") + case "channel_soak": + interop.DoSoakTest(tc, serverAddr, opts, true /* resetChannel */, *soakIterations, *soakMaxFailures, time.Duration(*soakPerIterationMaxAcceptableLatencyMs)*time.Millisecond, time.Now().Add(time.Duration(*soakOverallTimeoutSeconds)*time.Second)) + logger.Infoln("ChannelSoak done") default: logger.Fatal("Unsupported test case: ", *testCase) } diff --git a/interop/grpc_testing/benchmark_service_grpc.pb.go b/interop/grpc_testing/benchmark_service_grpc.pb.go index 1dcba4587d2..f4e4436e97e 100644 --- a/interop/grpc_testing/benchmark_service_grpc.pb.go +++ b/interop/grpc_testing/benchmark_service_grpc.pb.go @@ -1,4 +1,8 @@ // Code generated by protoc-gen-go-grpc. DO NOT EDIT. +// versions: +// - protoc-gen-go-grpc v1.1.0 +// - protoc v3.14.0 +// source: grpc/testing/benchmark_service.proto package grpc_testing diff --git a/interop/grpc_testing/report_qps_scenario_service_grpc.pb.go b/interop/grpc_testing/report_qps_scenario_service_grpc.pb.go index b0fe8c8f5ee..4bf3fce68ab 100644 --- a/interop/grpc_testing/report_qps_scenario_service_grpc.pb.go +++ b/interop/grpc_testing/report_qps_scenario_service_grpc.pb.go @@ -1,4 +1,8 @@ // Code generated by protoc-gen-go-grpc. DO NOT EDIT. +// versions: +// - protoc-gen-go-grpc v1.1.0 +// - protoc v3.14.0 +// source: grpc/testing/report_qps_scenario_service.proto package grpc_testing diff --git a/interop/grpc_testing/test_grpc.pb.go b/interop/grpc_testing/test_grpc.pb.go index ad5310aed62..137a1e98ce6 100644 --- a/interop/grpc_testing/test_grpc.pb.go +++ b/interop/grpc_testing/test_grpc.pb.go @@ -1,4 +1,8 @@ // Code generated by protoc-gen-go-grpc. DO NOT EDIT. +// versions: +// - protoc-gen-go-grpc v1.1.0 +// - protoc v3.14.0 +// source: grpc/testing/test.proto package grpc_testing diff --git a/interop/grpc_testing/worker_service_grpc.pb.go b/interop/grpc_testing/worker_service_grpc.pb.go index cc49b22b926..a97366df09a 100644 --- a/interop/grpc_testing/worker_service_grpc.pb.go +++ b/interop/grpc_testing/worker_service_grpc.pb.go @@ -1,4 +1,8 @@ // Code generated by protoc-gen-go-grpc. DO NOT EDIT. +// versions: +// - protoc-gen-go-grpc v1.1.0 +// - protoc v3.14.0 +// source: grpc/testing/worker_service.proto package grpc_testing diff --git a/interop/grpclb_fallback/client.go b/interop/grpclb_fallback/client_linux.go similarity index 99% rename from interop/grpclb_fallback/client.go rename to interop/grpclb_fallback/client_linux.go index 61b2fae6968..c9b25a894b3 100644 --- a/interop/grpclb_fallback/client.go +++ b/interop/grpclb_fallback/client_linux.go @@ -1,5 +1,3 @@ -// +build linux,!appengine - /* * * Copyright 2019 gRPC authors. diff --git a/interop/test_utils.go b/interop/test_utils.go index cbcbcc4da17..19a5c1f7cd3 100644 --- a/interop/test_utils.go +++ b/interop/test_utils.go @@ -20,10 +20,12 @@ package interop import ( + "bytes" "context" "fmt" "io" "io/ioutil" + "os" "strings" "time" @@ -31,6 +33,7 @@ import ( "golang.org/x/oauth2" "golang.org/x/oauth2/google" "google.golang.org/grpc" + "google.golang.org/grpc/benchmark/stats" "google.golang.org/grpc/codes" "google.golang.org/grpc/grpclog" "google.golang.org/grpc/metadata" @@ -673,6 +676,91 @@ func DoPickFirstUnary(tc testgrpc.TestServiceClient) { } } +func doOneSoakIteration(ctx context.Context, tc testgrpc.TestServiceClient, resetChannel bool, serverAddr string, dopts []grpc.DialOption) (latency time.Duration, err error) { + start := time.Now() + client := tc + if resetChannel { + var conn *grpc.ClientConn + conn, err = grpc.Dial(serverAddr, dopts...) + if err != nil { + return + } + defer conn.Close() + client = testgrpc.NewTestServiceClient(conn) + } + // per test spec, don't include channel shutdown in latency measurement + defer func() { latency = time.Since(start) }() + // do a large-unary RPC + pl := ClientNewPayload(testpb.PayloadType_COMPRESSABLE, largeReqSize) + req := &testpb.SimpleRequest{ + ResponseType: testpb.PayloadType_COMPRESSABLE, + ResponseSize: int32(largeRespSize), + Payload: pl, + } + var reply *testpb.SimpleResponse + reply, err = client.UnaryCall(ctx, req) + if err != nil { + err = fmt.Errorf("/TestService/UnaryCall RPC failed: %s", err) + return + } + t := reply.GetPayload().GetType() + s := len(reply.GetPayload().GetBody()) + if t != testpb.PayloadType_COMPRESSABLE || s != largeRespSize { + err = fmt.Errorf("got the reply with type %d len %d; want %d, %d", t, s, testpb.PayloadType_COMPRESSABLE, largeRespSize) + return + } + return +} + +// DoSoakTest runs large unary RPCs in a loop for a configurable number of times, with configurable failure thresholds. +// If resetChannel is false, then each RPC will be performed on tc. Otherwise, each RPC will be performed on a new +// stub that is created with the provided server address and dial options. +func DoSoakTest(tc testgrpc.TestServiceClient, serverAddr string, dopts []grpc.DialOption, resetChannel bool, soakIterations int, maxFailures int, perIterationMaxAcceptableLatency time.Duration, overallDeadline time.Time) { + start := time.Now() + ctx, cancel := context.WithDeadline(context.Background(), overallDeadline) + defer cancel() + iterationsDone := 0 + totalFailures := 0 + hopts := stats.HistogramOptions{ + NumBuckets: 20, + GrowthFactor: 1, + BaseBucketSize: 1, + MinValue: 0, + } + h := stats.NewHistogram(hopts) + for i := 0; i < soakIterations; i++ { + if time.Now().After(overallDeadline) { + break + } + iterationsDone++ + latency, err := doOneSoakIteration(ctx, tc, resetChannel, serverAddr, dopts) + latencyMs := int64(latency / time.Millisecond) + h.Add(latencyMs) + if err != nil { + totalFailures++ + fmt.Fprintf(os.Stderr, "soak iteration: %d elapsed_ms: %d failed: %s\n", i, latencyMs, err) + continue + } + if latency > perIterationMaxAcceptableLatency { + totalFailures++ + fmt.Fprintf(os.Stderr, "soak iteration: %d elapsed_ms: %d exceeds max acceptable latency: %d\n", i, latencyMs, perIterationMaxAcceptableLatency.Milliseconds()) + continue + } + fmt.Fprintf(os.Stderr, "soak iteration: %d elapsed_ms: %d succeeded\n", i, latencyMs) + } + var b bytes.Buffer + h.Print(&b) + fmt.Fprintln(os.Stderr, "Histogram of per-iteration latencies in milliseconds:") + fmt.Fprintln(os.Stderr, b.String()) + fmt.Fprintf(os.Stderr, "soak test ran: %d / %d iterations. total failures: %d. max failures threshold: %d. See breakdown above for which iterations succeeded, failed, and why for more info.\n", iterationsDone, soakIterations, totalFailures, maxFailures) + if iterationsDone < soakIterations { + logger.Fatalf("soak test consumed all %f seconds of time and quit early, only having ran %d out of desired %d iterations.", overallDeadline.Sub(start).Seconds(), iterationsDone, soakIterations) + } + if totalFailures > maxFailures { + logger.Fatalf("soak test total failures: %d exceeds max failures threshold: %d.", totalFailures, maxFailures) + } +} + type testServer struct { testgrpc.UnimplementedTestServiceServer } diff --git a/interop/xds/client/Dockerfile b/interop/xds/client/Dockerfile new file mode 100644 index 00000000000..533bb6adb3e --- /dev/null +++ b/interop/xds/client/Dockerfile @@ -0,0 +1,37 @@ +# Copyright 2021 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Dockerfile for building the xDS interop client. To build the image, run the +# following command from grpc-go directory: +# docker build -t -f interop/xds/client/Dockerfile . + +FROM golang:1.16-alpine as build + +# Make a grpc-go directory and copy the repo into it. +WORKDIR /go/src/grpc-go +COPY . . + +# Build a static binary without cgo so that we can copy just the binary in the +# final image, and can get rid of Go compiler and gRPC-Go dependencies. +RUN go build -tags osusergo,netgo interop/xds/client/client.go + +# Second stage of the build which copies over only the client binary and skips +# the Go compiler and gRPC repo from the earlier stage. This significantly +# reduces the docker image size. +FROM alpine +COPY --from=build /go/src/grpc-go/client . +ENV GRPC_GO_LOG_VERBOSITY_LEVEL=2 +ENV GRPC_GO_LOG_SEVERITY_LEVEL="info" +ENV GRPC_GO_LOG_FORMATTER="json" +ENTRYPOINT ["./client"] diff --git a/interop/xds/client/client.go b/interop/xds/client/client.go index 5b755272d3e..190c8a7d786 100644 --- a/interop/xds/client/client.go +++ b/interop/xds/client/client.go @@ -31,9 +31,13 @@ import ( "time" "google.golang.org/grpc" + "google.golang.org/grpc/admin" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/credentials/xds" "google.golang.org/grpc/grpclog" "google.golang.org/grpc/metadata" "google.golang.org/grpc/peer" + "google.golang.org/grpc/reflection" "google.golang.org/grpc/status" _ "google.golang.org/grpc/xds" @@ -173,6 +177,7 @@ var ( rpcTimeout = flag.Duration("rpc_timeout", 20*time.Second, "Per RPC timeout") server = flag.String("server", "localhost:8080", "Address of server to connect to") statsPort = flag.Int("stats_port", 8081, "Port to expose peer distribution stats service") + secureMode = flag.Bool("secure_mode", false, "If true, retrieve security configuration from the management server. Else, use insecure credentials.") rpcCfgs atomic.Value @@ -309,12 +314,13 @@ const ( emptyCall string = "EmptyCall" ) -func parseRPCTypes(rpcStr string) (ret []string) { +func parseRPCTypes(rpcStr string) []string { if len(rpcStr) == 0 { return []string{unaryCall} } rpcs := strings.Split(rpcStr, ",") + ret := make([]string, 0, len(rpcStr)) for _, r := range rpcs { switch r { case unaryCall, emptyCall: @@ -324,7 +330,7 @@ func parseRPCTypes(rpcStr string) (ret []string) { log.Fatalf("unsupported RPC type: %v", r) } } - return + return ret } type rpcConfig struct { @@ -370,11 +376,26 @@ func main() { defer s.Stop() testgrpc.RegisterLoadBalancerStatsServiceServer(s, &statsService{}) testgrpc.RegisterXdsUpdateClientConfigureServiceServer(s, &configureService{}) + reflection.Register(s) + cleanup, err := admin.Register(s) + if err != nil { + logger.Fatalf("Failed to register admin: %v", err) + } + defer cleanup() go s.Serve(lis) + creds := insecure.NewCredentials() + if *secureMode { + var err error + creds, err = xds.NewClientCredentials(xds.ClientOptions{FallbackCreds: insecure.NewCredentials()}) + if err != nil { + logger.Fatalf("Failed to create xDS credentials: %v", err) + } + } + clients := make([]testgrpc.TestServiceClient, *numChannels) for i := 0; i < *numChannels; i++ { - conn, err := grpc.Dial(*server, grpc.WithInsecure()) + conn, err := grpc.Dial(*server, grpc.WithTransportCredentials(creds)) if err != nil { logger.Fatalf("Fail to dial: %v", err) } diff --git a/interop/xds/server/Dockerfile b/interop/xds/server/Dockerfile new file mode 100644 index 00000000000..cd8dcb5ccaa --- /dev/null +++ b/interop/xds/server/Dockerfile @@ -0,0 +1,36 @@ +# Copyright 2021 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Dockerfile for building the xDS interop server. To build the image, run the +# following command from grpc-go directory: +# docker build -t -f interop/xds/server/Dockerfile . + +FROM golang:1.16-alpine as build + +# Make a grpc-go directory and copy the repo into it. +WORKDIR /go/src/grpc-go +COPY . . + +# Build a static binary without cgo so that we can copy just the binary in the +# final image, and can get rid of the Go compiler and gRPC-Go dependencies. +RUN go build -tags osusergo,netgo interop/xds/server/server.go + +# Second stage of the build which copies over only the client binary and skips +# the Go compiler and gRPC repo from the earlier stage. This significantly +# reduces the docker image size. +FROM alpine +COPY --from=build /go/src/grpc-go/server . +ENV GRPC_GO_LOG_VERBOSITY_LEVEL=2 +ENV GRPC_GO_LOG_SEVERITY_LEVEL="info" +ENTRYPOINT ["./server"] diff --git a/interop/xds/server/server.go b/interop/xds/server/server.go index 4989eb728ee..afbbc56af89 100644 --- a/interop/xds/server/server.go +++ b/interop/xds/server/server.go @@ -1,6 +1,6 @@ /* * - * Copyright 2020 gRPC authors. + * Copyright 2021 gRPC authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,34 +16,46 @@ * */ -// Binary server for xDS interop tests. +// Binary server is the server used for xDS interop tests. package main import ( "context" "flag" + "fmt" "log" "net" "os" - "strconv" "google.golang.org/grpc" + "google.golang.org/grpc/admin" + "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/grpclog" + "google.golang.org/grpc/health" "google.golang.org/grpc/metadata" + "google.golang.org/grpc/reflection" + "google.golang.org/grpc/xds" + xdscreds "google.golang.org/grpc/credentials/xds" + healthpb "google.golang.org/grpc/health/grpc_health_v1" testgrpc "google.golang.org/grpc/interop/grpc_testing" testpb "google.golang.org/grpc/interop/grpc_testing" ) var ( - port = flag.Int("port", 8080, "The server port") - serverID = flag.String("server_id", "go_server", "Server ID included in response") - hostname = getHostname() + port = flag.Int("port", 8080, "Listening port for test service") + maintenancePort = flag.Int("maintenance_port", 8081, "Listening port for maintenance services like health, reflection, channelz etc when -secure_mode is true. When -secure_mode is false, all these services will be registered on -port") + serverID = flag.String("server_id", "go_server", "Server ID included in response") + secureMode = flag.Bool("secure_mode", false, "If true, retrieve security configuration from the management server. Else, use insecure credentials.") + hostNameOverride = flag.String("host_name_override", "", "If set, use this as the hostname instead of the real hostname") logger = grpclog.Component("interop") ) func getHostname() string { + if *hostNameOverride != "" { + return *hostNameOverride + } hostname, err := os.Hostname() if err != nil { log.Fatalf("failed to get hostname: %v", err) @@ -51,28 +63,127 @@ func getHostname() string { return hostname } -type server struct { +// testServiceImpl provides an implementation of the TestService defined in +// grpc.testing package. +type testServiceImpl struct { testgrpc.UnimplementedTestServiceServer + hostname string + serverID string +} + +func (s *testServiceImpl) EmptyCall(ctx context.Context, _ *testpb.Empty) (*testpb.Empty, error) { + grpc.SetHeader(ctx, metadata.Pairs("hostname", s.hostname)) + return &testpb.Empty{}, nil +} + +func (s *testServiceImpl) UnaryCall(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) { + grpc.SetHeader(ctx, metadata.Pairs("hostname", s.hostname)) + return &testpb.SimpleResponse{ServerId: s.serverID, Hostname: s.hostname}, nil +} + +// xdsUpdateHealthServiceImpl provides an implementation of the +// XdsUpdateHealthService defined in grpc.testing package. +type xdsUpdateHealthServiceImpl struct { + testgrpc.UnimplementedXdsUpdateHealthServiceServer + healthServer *health.Server } -func (s *server) EmptyCall(ctx context.Context, _ *testpb.Empty) (*testpb.Empty, error) { - grpc.SetHeader(ctx, metadata.Pairs("hostname", hostname)) +func (x *xdsUpdateHealthServiceImpl) SetServing(_ context.Context, _ *testpb.Empty) (*testpb.Empty, error) { + x.healthServer.SetServingStatus("", healthpb.HealthCheckResponse_SERVING) + return &testpb.Empty{}, nil + +} + +func (x *xdsUpdateHealthServiceImpl) SetNotServing(_ context.Context, _ *testpb.Empty) (*testpb.Empty, error) { + x.healthServer.SetServingStatus("", healthpb.HealthCheckResponse_NOT_SERVING) return &testpb.Empty{}, nil } -func (s *server) UnaryCall(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) { - grpc.SetHeader(ctx, metadata.Pairs("hostname", hostname)) - return &testpb.SimpleResponse{ServerId: *serverID, Hostname: hostname}, nil +func xdsServingModeCallback(addr net.Addr, args xds.ServingModeChangeArgs) { + logger.Infof("Serving mode for xDS server at %s changed to %s", addr.String(), args.Mode) + if args.Err != nil { + logger.Infof("ServingModeCallback returned error: %v", args.Err) + } } func main() { flag.Parse() - p := strconv.Itoa(*port) - lis, err := net.Listen("tcp", ":"+p) + + if *secureMode && *port == *maintenancePort { + logger.Fatal("-port and -maintenance_port must be different when -secure_mode is set") + } + + testService := &testServiceImpl{hostname: getHostname(), serverID: *serverID} + healthServer := health.NewServer() + updateHealthService := &xdsUpdateHealthServiceImpl{healthServer: healthServer} + + // If -secure_mode is not set, expose all services on -port with a regular + // gRPC server. + if !*secureMode { + lis, err := net.Listen("tcp4", fmt.Sprintf(":%d", *port)) + if err != nil { + logger.Fatalf("net.Listen(%s) failed: %v", fmt.Sprintf(":%d", *port), err) + } + + server := grpc.NewServer() + testgrpc.RegisterTestServiceServer(server, testService) + healthServer.SetServingStatus("", healthpb.HealthCheckResponse_SERVING) + healthpb.RegisterHealthServer(server, healthServer) + testgrpc.RegisterXdsUpdateHealthServiceServer(server, updateHealthService) + reflection.Register(server) + cleanup, err := admin.Register(server) + if err != nil { + logger.Fatalf("Failed to register admin services: %v", err) + } + defer cleanup() + if err := server.Serve(lis); err != nil { + logger.Errorf("Serve() failed: %v", err) + } + return + } + + // Create a listener on -port to expose the test service. + testLis, err := net.Listen("tcp4", fmt.Sprintf(":%d", *port)) if err != nil { - logger.Fatalf("failed to listen: %v", err) + logger.Fatalf("net.Listen(%s) failed: %v", fmt.Sprintf(":%d", *port), err) + } + + // Create server-side xDS credentials with a plaintext fallback. + creds, err := xdscreds.NewServerCredentials(xdscreds.ServerOptions{FallbackCreds: insecure.NewCredentials()}) + if err != nil { + logger.Fatalf("Failed to create xDS credentials: %v", err) + } + + // Create an xDS enabled gRPC server, register the test service + // implementation and start serving. + testServer := xds.NewGRPCServer(grpc.Creds(creds), xds.ServingModeCallback(xdsServingModeCallback)) + testgrpc.RegisterTestServiceServer(testServer, testService) + go func() { + if err := testServer.Serve(testLis); err != nil { + logger.Errorf("test server Serve() failed: %v", err) + } + }() + defer testServer.Stop() + + // Create a listener on -maintenance_port to expose other services. + maintenanceLis, err := net.Listen("tcp4", fmt.Sprintf(":%d", *maintenancePort)) + if err != nil { + logger.Fatalf("net.Listen(%s) failed: %v", fmt.Sprintf(":%d", *maintenancePort), err) + } + + // Create a regular gRPC server and register the maintenance services on + // it and start serving. + maintenanceServer := grpc.NewServer() + healthServer.SetServingStatus("", healthpb.HealthCheckResponse_SERVING) + healthpb.RegisterHealthServer(maintenanceServer, healthServer) + testgrpc.RegisterXdsUpdateHealthServiceServer(maintenanceServer, updateHealthService) + reflection.Register(maintenanceServer) + cleanup, err := admin.Register(maintenanceServer) + if err != nil { + logger.Fatalf("Failed to register admin services: %v", err) + } + defer cleanup() + if err := maintenanceServer.Serve(maintenanceLis); err != nil { + logger.Errorf("maintenance server Serve() failed: %v", err) } - s := grpc.NewServer() - testgrpc.RegisterTestServiceServer(s, &server{}) - s.Serve(lis) } diff --git a/metadata/metadata.go b/metadata/metadata.go index cf6d1b94781..3604c7819fd 100644 --- a/metadata/metadata.go +++ b/metadata/metadata.go @@ -75,13 +75,9 @@ func Pairs(kv ...string) MD { panic(fmt.Sprintf("metadata: Pairs got the odd number of input pairs for metadata: %d", len(kv))) } md := MD{} - var key string - for i, s := range kv { - if i%2 == 0 { - key = strings.ToLower(s) - continue - } - md[key] = append(md[key], s) + for i := 0; i < len(kv); i += 2 { + key := strings.ToLower(kv[i]) + md[key] = append(md[key], kv[i+1]) } return md } @@ -97,12 +93,16 @@ func (md MD) Copy() MD { } // Get obtains the values for a given key. +// +// k is converted to lowercase before searching in md. func (md MD) Get(k string) []string { k = strings.ToLower(k) return md[k] } // Set sets the value of a given key with a slice of values. +// +// k is converted to lowercase before storing in md. func (md MD) Set(k string, vals ...string) { if len(vals) == 0 { return @@ -111,7 +111,10 @@ func (md MD) Set(k string, vals ...string) { md[k] = vals } -// Append adds the values to key k, not overwriting what was already stored at that key. +// Append adds the values to key k, not overwriting what was already stored at +// that key. +// +// k is converted to lowercase before storing in md. func (md MD) Append(k string, vals ...string) { if len(vals) == 0 { return @@ -120,9 +123,17 @@ func (md MD) Append(k string, vals ...string) { md[k] = append(md[k], vals...) } +// Delete removes the values for a given key k which is converted to lowercase +// before removing it from md. +func (md MD) Delete(k string) { + k = strings.ToLower(k) + delete(md, k) +} + // Join joins any number of mds into a single MD. -// The order of values for each key is determined by the order in which -// the mds containing those values are presented to Join. +// +// The order of values for each key is determined by the order in which the mds +// containing those values are presented to Join. func Join(mds ...MD) MD { out := MD{} for _, md := range mds { @@ -149,8 +160,8 @@ func NewOutgoingContext(ctx context.Context, md MD) context.Context { } // AppendToOutgoingContext returns a new context with the provided kv merged -// with any existing metadata in the context. Please refer to the -// documentation of Pairs for a description of kv. +// with any existing metadata in the context. Please refer to the documentation +// of Pairs for a description of kv. func AppendToOutgoingContext(ctx context.Context, kv ...string) context.Context { if len(kv)%2 == 1 { panic(fmt.Sprintf("metadata: AppendToOutgoingContext got an odd number of input pairs for metadata: %d", len(kv))) @@ -163,20 +174,34 @@ func AppendToOutgoingContext(ctx context.Context, kv ...string) context.Context return context.WithValue(ctx, mdOutgoingKey{}, rawMD{md: md.md, added: added}) } -// FromIncomingContext returns the incoming metadata in ctx if it exists. The -// returned MD should not be modified. Writing to it may cause races. -// Modification should be made to copies of the returned MD. -func FromIncomingContext(ctx context.Context) (md MD, ok bool) { - md, ok = ctx.Value(mdIncomingKey{}).(MD) - return +// FromIncomingContext returns the incoming metadata in ctx if it exists. +// +// All keys in the returned MD are lowercase. +func FromIncomingContext(ctx context.Context) (MD, bool) { + md, ok := ctx.Value(mdIncomingKey{}).(MD) + if !ok { + return nil, false + } + out := MD{} + for k, v := range md { + // We need to manually convert all keys to lower case, because MD is a + // map, and there's no guarantee that the MD attached to the context is + // created using our helper functions. + key := strings.ToLower(k) + out[key] = v + } + return out, true } -// FromOutgoingContextRaw returns the un-merged, intermediary contents -// of rawMD. Remember to perform strings.ToLower on the keys. The returned -// MD should not be modified. Writing to it may cause races. Modification -// should be made to copies of the returned MD. +// FromOutgoingContextRaw returns the un-merged, intermediary contents of rawMD. // -// This is intended for gRPC-internal use ONLY. +// Remember to perform strings.ToLower on the keys, for both the returned MD (MD +// is a map, there's no guarantee it's created using our helper functions) and +// the extra kv pairs (AppendToOutgoingContext doesn't turn them into +// lowercase). +// +// This is intended for gRPC-internal use ONLY. Users should use +// FromOutgoingContext instead. func FromOutgoingContextRaw(ctx context.Context) (MD, [][]string, bool) { raw, ok := ctx.Value(mdOutgoingKey{}).(rawMD) if !ok { @@ -186,21 +211,34 @@ func FromOutgoingContextRaw(ctx context.Context) (MD, [][]string, bool) { return raw.md, raw.added, true } -// FromOutgoingContext returns the outgoing metadata in ctx if it exists. The -// returned MD should not be modified. Writing to it may cause races. -// Modification should be made to copies of the returned MD. +// FromOutgoingContext returns the outgoing metadata in ctx if it exists. +// +// All keys in the returned MD are lowercase. func FromOutgoingContext(ctx context.Context) (MD, bool) { raw, ok := ctx.Value(mdOutgoingKey{}).(rawMD) if !ok { return nil, false } - mds := make([]MD, 0, len(raw.added)+1) - mds = append(mds, raw.md) - for _, vv := range raw.added { - mds = append(mds, Pairs(vv...)) + out := MD{} + for k, v := range raw.md { + // We need to manually convert all keys to lower case, because MD is a + // map, and there's no guarantee that the MD attached to the context is + // created using our helper functions. + key := strings.ToLower(k) + out[key] = v + } + for _, added := range raw.added { + if len(added)%2 == 1 { + panic(fmt.Sprintf("metadata: FromOutgoingContext got an odd number of input pairs for metadata: %d", len(added))) + } + + for i := 0; i < len(added); i += 2 { + key := strings.ToLower(added[i]) + out[key] = append(out[key], added[i+1]) + } } - return Join(mds...), ok + return out, ok } type rawMD struct { diff --git a/metadata/metadata_test.go b/metadata/metadata_test.go index f1fb5f6d324..89be06eaada 100644 --- a/metadata/metadata_test.go +++ b/metadata/metadata_test.go @@ -169,6 +169,35 @@ func (s) TestAppend(t *testing.T) { } } +func (s) TestDelete(t *testing.T) { + for _, test := range []struct { + md MD + deleteKey string + want MD + }{ + { + md: Pairs("My-Optional-Header", "42"), + deleteKey: "My-Optional-Header", + want: Pairs(), + }, + { + md: Pairs("My-Optional-Header", "42"), + deleteKey: "Other-Key", + want: Pairs("my-optional-header", "42"), + }, + { + md: Pairs("My-Optional-Header", "42"), + deleteKey: "my-OptIoNal-HeAder", + want: Pairs(), + }, + } { + test.md.Delete(test.deleteKey) + if !reflect.DeepEqual(test.md, test.want) { + t.Errorf("value of metadata is %v, want %v", test.md, test.want) + } + } +} + func (s) TestAppendToOutgoingContext(t *testing.T) { // Pre-existing metadata tCtx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) diff --git a/picker_wrapper.go b/picker_wrapper.go index a58174b6f43..e8367cb8993 100644 --- a/picker_wrapper.go +++ b/picker_wrapper.go @@ -144,10 +144,10 @@ func (pw *pickerWrapper) pick(ctx context.Context, failfast bool, info balancer. acw, ok := pickResult.SubConn.(*acBalancerWrapper) if !ok { - logger.Error("subconn returned from pick is not *acBalancerWrapper") + logger.Errorf("subconn returned from pick is type %T, not *acBalancerWrapper", pickResult.SubConn) continue } - if t, ok := acw.getAddrConn().getReadyTransport(); ok { + if t := acw.getAddrConn().getReadyTransport(); t != nil { if channelz.IsOn() { return t, doneChannelzWrapper(acw, pickResult.Done), nil } diff --git a/pickfirst.go b/pickfirst.go index b858c2a5e63..f194d14a081 100644 --- a/pickfirst.go +++ b/pickfirst.go @@ -107,10 +107,12 @@ func (b *pickfirstBalancer) UpdateSubConnState(sc balancer.SubConn, s balancer.S } switch s.ConnectivityState { - case connectivity.Ready, connectivity.Idle: + case connectivity.Ready: b.cc.UpdateState(balancer.State{ConnectivityState: s.ConnectivityState, Picker: &picker{result: balancer.PickResult{SubConn: sc}}}) case connectivity.Connecting: b.cc.UpdateState(balancer.State{ConnectivityState: s.ConnectivityState, Picker: &picker{err: balancer.ErrNoSubConnAvailable}}) + case connectivity.Idle: + b.cc.UpdateState(balancer.State{ConnectivityState: s.ConnectivityState, Picker: &idlePicker{sc: sc}}) case connectivity.TransientFailure: b.cc.UpdateState(balancer.State{ ConnectivityState: s.ConnectivityState, @@ -122,6 +124,12 @@ func (b *pickfirstBalancer) UpdateSubConnState(sc balancer.SubConn, s balancer.S func (b *pickfirstBalancer) Close() { } +func (b *pickfirstBalancer) ExitIdle() { + if b.state == connectivity.Idle { + b.sc.Connect() + } +} + type picker struct { result balancer.PickResult err error @@ -131,6 +139,17 @@ func (p *picker) Pick(info balancer.PickInfo) (balancer.PickResult, error) { return p.result, p.err } +// idlePicker is used when the SubConn is IDLE and kicks the SubConn into +// CONNECTING when Pick is called. +type idlePicker struct { + sc balancer.SubConn +} + +func (i *idlePicker) Pick(info balancer.PickInfo) (balancer.PickResult, error) { + i.sc.Connect() + return balancer.PickResult{}, balancer.ErrNoSubConnAvailable +} + func init() { balancer.Register(newPickfirstBuilder()) } diff --git a/profiling/proto/service_grpc.pb.go b/profiling/proto/service_grpc.pb.go index bfdcc69bffb..a0656149bda 100644 --- a/profiling/proto/service_grpc.pb.go +++ b/profiling/proto/service_grpc.pb.go @@ -1,4 +1,8 @@ // Code generated by protoc-gen-go-grpc. DO NOT EDIT. +// versions: +// - protoc-gen-go-grpc v1.1.0 +// - protoc v3.14.0 +// source: profiling/proto/service.proto package proto diff --git a/reflection/grpc_reflection_v1alpha/reflection_grpc.pb.go b/reflection/grpc_reflection_v1alpha/reflection_grpc.pb.go index c2b7429a06b..7d05c14ebd8 100644 --- a/reflection/grpc_reflection_v1alpha/reflection_grpc.pb.go +++ b/reflection/grpc_reflection_v1alpha/reflection_grpc.pb.go @@ -1,4 +1,8 @@ // Code generated by protoc-gen-go-grpc. DO NOT EDIT. +// versions: +// - protoc-gen-go-grpc v1.1.0 +// - protoc v3.14.0 +// source: reflection/grpc_reflection_v1alpha/reflection.proto package grpc_reflection_v1alpha diff --git a/reflection/grpc_testing/test_grpc.pb.go b/reflection/grpc_testing/test_grpc.pb.go index 76ec8d52d68..235b5d82484 100644 --- a/reflection/grpc_testing/test_grpc.pb.go +++ b/reflection/grpc_testing/test_grpc.pb.go @@ -1,4 +1,8 @@ // Code generated by protoc-gen-go-grpc. DO NOT EDIT. +// versions: +// - protoc-gen-go-grpc v1.1.0 +// - protoc v3.14.0 +// source: reflection/grpc_testing/test.proto package grpc_testing diff --git a/reflection/serverreflection.go b/reflection/serverreflection.go index d2696168b10..82a5ba7f244 100644 --- a/reflection/serverreflection.go +++ b/reflection/serverreflection.go @@ -54,9 +54,19 @@ import ( "google.golang.org/grpc/status" ) +// GRPCServer is the interface provided by a gRPC server. It is implemented by +// *grpc.Server, but could also be implemented by other concrete types. It acts +// as a registry, for accumulating the services exposed by the server. +type GRPCServer interface { + grpc.ServiceRegistrar + GetServiceInfo() map[string]grpc.ServiceInfo +} + +var _ GRPCServer = (*grpc.Server)(nil) + type serverReflectionServer struct { rpb.UnimplementedServerReflectionServer - s *grpc.Server + s GRPCServer initSymbols sync.Once serviceNames []string @@ -64,7 +74,7 @@ type serverReflectionServer struct { } // Register registers the server reflection service on the given gRPC server. -func Register(s *grpc.Server) { +func Register(s GRPCServer) { rpb.RegisterServerReflectionServer(s, &serverReflectionServer{ s: s, }) diff --git a/regenerate.sh b/regenerate.sh index fc6725b89f8..dfd3226a1d9 100755 --- a/regenerate.sh +++ b/regenerate.sh @@ -48,11 +48,6 @@ mkdir -p ${WORKDIR}/googleapis/google/rpc echo "curl https://raw.githubusercontent.com/googleapis/googleapis/master/google/rpc/code.proto" curl --silent https://raw.githubusercontent.com/googleapis/googleapis/master/google/rpc/code.proto > ${WORKDIR}/googleapis/google/rpc/code.proto -# Pull in the MeshCA service proto. -mkdir -p ${WORKDIR}/istio/istio/google/security/meshca/v1 -echo "curl https://raw.githubusercontent.com/istio/istio/master/security/proto/providers/google/meshca.proto" -curl --silent https://raw.githubusercontent.com/istio/istio/master/security/proto/providers/google/meshca.proto > ${WORKDIR}/istio/istio/google/security/meshca/v1/meshca.proto - mkdir -p ${WORKDIR}/out # Generates sources without the embed requirement @@ -76,7 +71,6 @@ SOURCES=( ${WORKDIR}/grpc-proto/grpc/service_config/service_config.proto ${WORKDIR}/grpc-proto/grpc/testing/*.proto ${WORKDIR}/grpc-proto/grpc/core/*.proto - ${WORKDIR}/istio/istio/google/security/meshca/v1/meshca.proto ) # These options of the form 'Mfoo.proto=bar' instruct the codegen to use an @@ -122,8 +116,4 @@ mv ${WORKDIR}/out/grpc/service_config/service_config.pb.go internal/proto/grpc_s mv ${WORKDIR}/out/grpc/testing/*.pb.go interop/grpc_testing/ mv ${WORKDIR}/out/grpc/core/*.pb.go interop/grpc_testing/core/ -# istio/google/security/meshca/v1/meshca.proto does not have a go_package option. -mkdir -p ${WORKDIR}/out/google.golang.org/grpc/credentials/tls/certprovider/meshca/internal/v1/ -mv ${WORKDIR}/out/istio/google/security/meshca/v1/* ${WORKDIR}/out/google.golang.org/grpc/credentials/tls/certprovider/meshca/internal/v1/ - cp -R ${WORKDIR}/out/google.golang.org/grpc/* . diff --git a/resolver/manual/manual.go b/resolver/manual/manual.go index 3679d702ab9..f6e7b5ae358 100644 --- a/resolver/manual/manual.go +++ b/resolver/manual/manual.go @@ -27,7 +27,9 @@ import ( // NewBuilderWithScheme creates a new test resolver builder with the given scheme. func NewBuilderWithScheme(scheme string) *Resolver { return &Resolver{ + BuildCallback: func(resolver.Target, resolver.ClientConn, resolver.BuildOptions) {}, ResolveNowCallback: func(resolver.ResolveNowOptions) {}, + CloseCallback: func() {}, scheme: scheme, } } @@ -35,11 +37,17 @@ func NewBuilderWithScheme(scheme string) *Resolver { // Resolver is also a resolver builder. // It's build() function always returns itself. type Resolver struct { + // BuildCallback is called when the Build method is called. Must not be + // nil. Must not be changed after the resolver may be built. + BuildCallback func(resolver.Target, resolver.ClientConn, resolver.BuildOptions) // ResolveNowCallback is called when the ResolveNow method is called on the // resolver. Must not be nil. Must not be changed after the resolver may // be built. ResolveNowCallback func(resolver.ResolveNowOptions) - scheme string + // CloseCallback is called when the Close method is called. Must not be + // nil. Must not be changed after the resolver may be built. + CloseCallback func() + scheme string // Fields actually belong to the resolver. CC resolver.ClientConn @@ -54,6 +62,7 @@ func (r *Resolver) InitialState(s resolver.State) { // Build returns itself for Resolver, because it's both a builder and a resolver. func (r *Resolver) Build(target resolver.Target, cc resolver.ClientConn, opts resolver.BuildOptions) (resolver.Resolver, error) { + r.BuildCallback(target, cc, opts) r.CC = cc if r.bootstrapState != nil { r.UpdateState(*r.bootstrapState) @@ -72,9 +81,16 @@ func (r *Resolver) ResolveNow(o resolver.ResolveNowOptions) { } // Close is a noop for Resolver. -func (*Resolver) Close() {} +func (r *Resolver) Close() { + r.CloseCallback() +} // UpdateState calls CC.UpdateState. func (r *Resolver) UpdateState(s resolver.State) { r.CC.UpdateState(s) } + +// ReportError calls CC.ReportError. +func (r *Resolver) ReportError(err error) { + r.CC.ReportError(err) +} diff --git a/resolver/resolver.go b/resolver/resolver.go index e9fa8e33d92..6a9d234a597 100644 --- a/resolver/resolver.go +++ b/resolver/resolver.go @@ -181,7 +181,7 @@ type State struct { // gRPC to add new methods to this interface. type ClientConn interface { // UpdateState updates the state of the ClientConn appropriately. - UpdateState(State) + UpdateState(State) error // ReportError notifies the ClientConn that the Resolver encountered an // error. The ClientConn will notify the load balancer and begin calling // ResolveNow on the Resolver with exponential backoff. diff --git a/resolver_conn_wrapper.go b/resolver_conn_wrapper.go index f2d81968f9e..2c47cd54f07 100644 --- a/resolver_conn_wrapper.go +++ b/resolver_conn_wrapper.go @@ -22,7 +22,6 @@ import ( "fmt" "strings" "sync" - "time" "google.golang.org/grpc/balancer" "google.golang.org/grpc/credentials" @@ -41,8 +40,7 @@ type ccResolverWrapper struct { done *grpcsync.Event curState resolver.State - pollingMu sync.Mutex - polling chan struct{} + incomingMu sync.Mutex // Synchronizes all the incoming calls. } // newCCResolverWrapper uses the resolver.Builder to build a Resolver and @@ -93,71 +91,37 @@ func (ccr *ccResolverWrapper) close() { ccr.resolverMu.Unlock() } -// poll begins or ends asynchronous polling of the resolver based on whether -// err is ErrBadResolverState. -func (ccr *ccResolverWrapper) poll(err error) { - ccr.pollingMu.Lock() - defer ccr.pollingMu.Unlock() - if err != balancer.ErrBadResolverState { - // stop polling - if ccr.polling != nil { - close(ccr.polling) - ccr.polling = nil - } - return - } - if ccr.polling != nil { - // already polling - return - } - p := make(chan struct{}) - ccr.polling = p - go func() { - for i := 0; ; i++ { - ccr.resolveNow(resolver.ResolveNowOptions{}) - t := time.NewTimer(ccr.cc.dopts.resolveNowBackoff(i)) - select { - case <-p: - t.Stop() - return - case <-ccr.done.Done(): - // Resolver has been closed. - t.Stop() - return - case <-t.C: - select { - case <-p: - return - default: - } - // Timer expired; re-resolve. - } - } - }() -} - -func (ccr *ccResolverWrapper) UpdateState(s resolver.State) { +func (ccr *ccResolverWrapper) UpdateState(s resolver.State) error { + ccr.incomingMu.Lock() + defer ccr.incomingMu.Unlock() if ccr.done.HasFired() { - return + return nil } channelz.Infof(logger, ccr.cc.channelzID, "ccResolverWrapper: sending update to cc: %v", s) if channelz.IsOn() { ccr.addChannelzTraceEvent(s) } ccr.curState = s - ccr.poll(ccr.cc.updateResolverState(ccr.curState, nil)) + if err := ccr.cc.updateResolverState(ccr.curState, nil); err == balancer.ErrBadResolverState { + return balancer.ErrBadResolverState + } + return nil } func (ccr *ccResolverWrapper) ReportError(err error) { + ccr.incomingMu.Lock() + defer ccr.incomingMu.Unlock() if ccr.done.HasFired() { return } channelz.Warningf(logger, ccr.cc.channelzID, "ccResolverWrapper: reporting error to cc: %v", err) - ccr.poll(ccr.cc.updateResolverState(resolver.State{}, err)) + ccr.cc.updateResolverState(resolver.State{}, err) } // NewAddress is called by the resolver implementation to send addresses to gRPC. func (ccr *ccResolverWrapper) NewAddress(addrs []resolver.Address) { + ccr.incomingMu.Lock() + defer ccr.incomingMu.Unlock() if ccr.done.HasFired() { return } @@ -166,12 +130,14 @@ func (ccr *ccResolverWrapper) NewAddress(addrs []resolver.Address) { ccr.addChannelzTraceEvent(resolver.State{Addresses: addrs, ServiceConfig: ccr.curState.ServiceConfig}) } ccr.curState.Addresses = addrs - ccr.poll(ccr.cc.updateResolverState(ccr.curState, nil)) + ccr.cc.updateResolverState(ccr.curState, nil) } // NewServiceConfig is called by the resolver implementation to send service // configs to gRPC. func (ccr *ccResolverWrapper) NewServiceConfig(sc string) { + ccr.incomingMu.Lock() + defer ccr.incomingMu.Unlock() if ccr.done.HasFired() { return } @@ -183,14 +149,13 @@ func (ccr *ccResolverWrapper) NewServiceConfig(sc string) { scpr := parseServiceConfig(sc) if scpr.Err != nil { channelz.Warningf(logger, ccr.cc.channelzID, "ccResolverWrapper: error parsing service config: %v", scpr.Err) - ccr.poll(balancer.ErrBadResolverState) return } if channelz.IsOn() { ccr.addChannelzTraceEvent(resolver.State{Addresses: ccr.curState.Addresses, ServiceConfig: scpr}) } ccr.curState.ServiceConfig = scpr - ccr.poll(ccr.cc.updateResolverState(ccr.curState, nil)) + ccr.cc.updateResolverState(ccr.curState, nil) } func (ccr *ccResolverWrapper) ParseServiceConfig(scJSON string) *serviceconfig.ParseResult { diff --git a/resolver_conn_wrapper_test.go b/resolver_conn_wrapper_test.go index f13a408937b..81c5b9ea874 100644 --- a/resolver_conn_wrapper_test.go +++ b/resolver_conn_wrapper_test.go @@ -67,62 +67,6 @@ func (s) TestDialParseTargetUnknownScheme(t *testing.T) { } } -func testResolverErrorPolling(t *testing.T, badUpdate func(*manual.Resolver), goodUpdate func(*manual.Resolver), dopts ...DialOption) { - boIter := make(chan int) - resolverBackoff := func(v int) time.Duration { - boIter <- v - return 0 - } - - r := manual.NewBuilderWithScheme("whatever") - rn := make(chan struct{}) - defer func() { close(rn) }() - r.ResolveNowCallback = func(resolver.ResolveNowOptions) { rn <- struct{}{} } - - defaultDialOptions := []DialOption{ - WithInsecure(), - WithResolvers(r), - withResolveNowBackoff(resolverBackoff), - } - cc, err := Dial(r.Scheme()+":///test.server", append(defaultDialOptions, dopts...)...) - if err != nil { - t.Fatalf("Dial(_, _) = _, %v; want _, nil", err) - } - defer cc.Close() - badUpdate(r) - - panicAfter := time.AfterFunc(5*time.Second, func() { panic("timed out polling resolver") }) - defer panicAfter.Stop() - - // Ensure ResolveNow is called, then Backoff with the right parameter, several times - for i := 0; i < 7; i++ { - <-rn - if v := <-boIter; v != i { - t.Errorf("Backoff call %v uses value %v", i, v) - } - } - - // UpdateState will block if ResolveNow is being called (which blocks on - // rn), so call it in a goroutine. - goodUpdate(r) - - // Wait awhile to ensure ResolveNow and Backoff stop being called when the - // state is OK (i.e. polling was cancelled). - for { - t := time.NewTimer(50 * time.Millisecond) - select { - case <-rn: - // ClientConn is still calling ResolveNow - <-boIter - time.Sleep(5 * time.Millisecond) - continue - case <-t.C: - // ClientConn stopped calling ResolveNow; success - } - break - } -} - const happyBalancerName = "happy balancer" func init() { @@ -136,35 +80,6 @@ func init() { stub.Register(happyBalancerName, bf) } -// TestResolverErrorPolling injects resolver errors and verifies ResolveNow is -// called with the appropriate backoff strategy being consulted between -// ResolveNow calls. -func (s) TestResolverErrorPolling(t *testing.T) { - testResolverErrorPolling(t, func(r *manual.Resolver) { - r.CC.ReportError(errors.New("res err")) - }, func(r *manual.Resolver) { - // UpdateState will block if ResolveNow is being called (which blocks on - // rn), so call it in a goroutine. - go r.CC.UpdateState(resolver.State{}) - }, - WithDefaultServiceConfig(fmt.Sprintf(`{ "loadBalancingConfig": [{"%v": {}}] }`, happyBalancerName))) -} - -// TestServiceConfigErrorPolling injects a service config error and verifies -// ResolveNow is called with the appropriate backoff strategy being consulted -// between ResolveNow calls. -func (s) TestServiceConfigErrorPolling(t *testing.T) { - testResolverErrorPolling(t, func(r *manual.Resolver) { - badsc := r.CC.ParseServiceConfig("bad config") - r.UpdateState(resolver.State{ServiceConfig: badsc}) - }, func(r *manual.Resolver) { - // UpdateState will block if ResolveNow is being called (which blocks on - // rn), so call it in a goroutine. - go r.CC.UpdateState(resolver.State{}) - }, - WithDefaultServiceConfig(fmt.Sprintf(`{ "loadBalancingConfig": [{"%v": {}}] }`, happyBalancerName))) -} - // TestResolverErrorInBuild makes the resolver.Builder call into the ClientConn // during the Build call. We use two separate mutexes in the code which make // sure there is no data race in this code path, and also that there is no diff --git a/rpc_util.go b/rpc_util.go index f781aaf751e..05904c816de 100644 --- a/rpc_util.go +++ b/rpc_util.go @@ -258,7 +258,8 @@ func (o PeerCallOption) after(c *callInfo, attempt *csAttempt) { } // WaitForReady configures the action to take when an RPC is attempted on broken -// connections or unreachable servers. If waitForReady is false, the RPC will fail +// connections or unreachable servers. If waitForReady is false and the +// connection is in the TRANSIENT_FAILURE state, the RPC will fail // immediately. Otherwise, the RPC client will block the call until a // connection is available (or the call is canceled or times out) and will // retry the call if it fails due to a transient error. gRPC will not retry if @@ -429,9 +430,10 @@ func (o ContentSubtypeCallOption) before(c *callInfo) error { } func (o ContentSubtypeCallOption) after(c *callInfo, attempt *csAttempt) {} -// ForceCodec returns a CallOption that will set the given Codec to be -// used for all request and response messages for a call. The result of calling -// String() will be used as the content-subtype in a case-insensitive manner. +// ForceCodec returns a CallOption that will set codec to be used for all +// request and response messages for a call. The result of calling Name() will +// be used as the content-subtype after converting to lowercase, unless +// CallContentSubtype is also used. // // See Content-Type on // https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#requests for @@ -834,33 +836,45 @@ func Errorf(c codes.Code, format string, a ...interface{}) error { // toRPCErr converts an error into an error from the status package. func toRPCErr(err error) error { - if err == nil || err == io.EOF { + switch err { + case nil, io.EOF: return err - } - if err == io.ErrUnexpectedEOF { + case context.DeadlineExceeded: + return status.Error(codes.DeadlineExceeded, err.Error()) + case context.Canceled: + return status.Error(codes.Canceled, err.Error()) + case io.ErrUnexpectedEOF: return status.Error(codes.Internal, err.Error()) } - if _, ok := status.FromError(err); ok { - return err - } + switch e := err.(type) { case transport.ConnectionError: return status.Error(codes.Unavailable, e.Desc) - default: - switch err { - case context.DeadlineExceeded: - return status.Error(codes.DeadlineExceeded, err.Error()) - case context.Canceled: - return status.Error(codes.Canceled, err.Error()) - } + case *transport.NewStreamError: + return toRPCErr(e.Err) } + + if _, ok := status.FromError(err); ok { + return err + } + return status.Error(codes.Unknown, err.Error()) } // setCallInfoCodec should only be called after CallOptions have been applied. func setCallInfoCodec(c *callInfo) error { if c.codec != nil { - // codec was already set by a CallOption; use it. + // codec was already set by a CallOption; use it, but set the content + // subtype if it is not set. + if c.contentSubtype == "" { + // c.codec is a baseCodec to hide the difference between grpc.Codec and + // encoding.Codec (Name vs. String method name). We only support + // setting content subtype from encoding.Codec to avoid a behavior + // change with the deprecated version. + if ec, ok := c.codec.(encoding.Codec); ok { + c.contentSubtype = strings.ToLower(ec.Name()) + } + } return nil } diff --git a/security/advancedtls/advancedtls.go b/security/advancedtls/advancedtls.go index 534a3ed417b..1892c5ed766 100644 --- a/security/advancedtls/advancedtls.go +++ b/security/advancedtls/advancedtls.go @@ -181,6 +181,9 @@ type ClientOptions struct { RootOptions RootCertificateOptions // VType is the verification type on the client side. VType VerificationType + // RevocationConfig is the configurations for certificate revocation checks. + // It could be nil if such checks are not needed. + RevocationConfig *RevocationConfig } // ServerOptions contains the fields needed to be filled by the server. @@ -199,6 +202,9 @@ type ServerOptions struct { RequireClientCert bool // VType is the verification type on the server side. VType VerificationType + // RevocationConfig is the configurations for certificate revocation checks. + // It could be nil if such checks are not needed. + RevocationConfig *RevocationConfig } func (o *ClientOptions) config() (*tls.Config, error) { @@ -356,11 +362,12 @@ func (o *ServerOptions) config() (*tls.Config, error) { // advancedTLSCreds is the credentials required for authenticating a connection // using TLS. type advancedTLSCreds struct { - config *tls.Config - verifyFunc CustomVerificationFunc - getRootCAs func(params *GetRootCAsParams) (*GetRootCAsResults, error) - isClient bool - vType VerificationType + config *tls.Config + verifyFunc CustomVerificationFunc + getRootCAs func(params *GetRootCAsParams) (*GetRootCAsResults, error) + isClient bool + vType VerificationType + revocationConfig *RevocationConfig } func (c advancedTLSCreds) Info() credentials.ProtocolInfo { @@ -451,6 +458,14 @@ func buildVerifyFunc(c *advancedTLSCreds, return func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { chains := verifiedChains var leafCert *x509.Certificate + rawCertList := make([]*x509.Certificate, len(rawCerts)) + for i, asn1Data := range rawCerts { + cert, err := x509.ParseCertificate(asn1Data) + if err != nil { + return err + } + rawCertList[i] = cert + } if c.vType == CertAndHostVerification || c.vType == CertVerification { // perform possible trust credential reloading and certificate check rootCAs := c.config.RootCAs @@ -469,14 +484,6 @@ func buildVerifyFunc(c *advancedTLSCreds, rootCAs = results.TrustCerts } // Verify peers' certificates against RootCAs and get verifiedChains. - certs := make([]*x509.Certificate, len(rawCerts)) - for i, asn1Data := range rawCerts { - cert, err := x509.ParseCertificate(asn1Data) - if err != nil { - return err - } - certs[i] = cert - } keyUsages := []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth} if !c.isClient { keyUsages = []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth} @@ -487,7 +494,7 @@ func buildVerifyFunc(c *advancedTLSCreds, Intermediates: x509.NewCertPool(), KeyUsages: keyUsages, } - for _, cert := range certs[1:] { + for _, cert := range rawCertList[1:] { opts.Intermediates.AddCert(cert) } // Perform default hostname check if specified. @@ -501,11 +508,21 @@ func buildVerifyFunc(c *advancedTLSCreds, opts.DNSName = parsedName } var err error - chains, err = certs[0].Verify(opts) + chains, err = rawCertList[0].Verify(opts) if err != nil { return err } - leafCert = certs[0] + leafCert = rawCertList[0] + } + // Perform certificate revocation check if specified. + if c.revocationConfig != nil { + verifiedChains := chains + if verifiedChains == nil { + verifiedChains = [][]*x509.Certificate{rawCertList} + } + if err := CheckChainRevocation(verifiedChains, *c.revocationConfig); err != nil { + return err + } } // Perform custom verification check if specified. if c.verifyFunc != nil { @@ -529,11 +546,12 @@ func NewClientCreds(o *ClientOptions) (credentials.TransportCredentials, error) return nil, err } tc := &advancedTLSCreds{ - config: conf, - isClient: true, - getRootCAs: o.RootOptions.GetRootCertificates, - verifyFunc: o.VerifyPeer, - vType: o.VType, + config: conf, + isClient: true, + getRootCAs: o.RootOptions.GetRootCertificates, + verifyFunc: o.VerifyPeer, + vType: o.VType, + revocationConfig: o.RevocationConfig, } tc.config.NextProtos = credinternal.AppendH2ToNextProtos(tc.config.NextProtos) return tc, nil @@ -547,11 +565,12 @@ func NewServerCreds(o *ServerOptions) (credentials.TransportCredentials, error) return nil, err } tc := &advancedTLSCreds{ - config: conf, - isClient: false, - getRootCAs: o.RootOptions.GetRootCertificates, - verifyFunc: o.VerifyPeer, - vType: o.VType, + config: conf, + isClient: false, + getRootCAs: o.RootOptions.GetRootCertificates, + verifyFunc: o.VerifyPeer, + vType: o.VType, + revocationConfig: o.RevocationConfig, } tc.config.NextProtos = credinternal.AppendH2ToNextProtos(tc.config.NextProtos) return tc, nil diff --git a/security/advancedtls/advancedtls_integration_test.go b/security/advancedtls/advancedtls_integration_test.go index 4bb9e645b0a..8cddfc234b1 100644 --- a/security/advancedtls/advancedtls_integration_test.go +++ b/security/advancedtls/advancedtls_integration_test.go @@ -380,7 +380,7 @@ func (s) TestEnd2End(t *testing.T) { } clientTLSCreds, err := NewClientCreds(clientOptions) if err != nil { - t.Fatalf("clientTLSCreds failed to create") + t.Fatalf("clientTLSCreds failed to create: %v", err) } // ------------------------Scenario 1------------------------------------ // stage = 0, initial connection should succeed @@ -796,7 +796,7 @@ func (s) TestDefaultHostNameCheck(t *testing.T) { } clientTLSCreds, err := NewClientCreds(clientOptions) if err != nil { - t.Fatalf("clientTLSCreds failed to create") + t.Fatalf("clientTLSCreds failed to create: %v", err) } shouldFail := false if test.expectError { diff --git a/security/advancedtls/advancedtls_test.go b/security/advancedtls/advancedtls_test.go index 64da81a1700..7092d46e60f 100644 --- a/security/advancedtls/advancedtls_test.go +++ b/security/advancedtls/advancedtls_test.go @@ -27,10 +27,12 @@ import ( "net" "testing" + lru "github.com/hashicorp/golang-lru" "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/tls/certprovider" "google.golang.org/grpc/internal/grpctest" "google.golang.org/grpc/security/advancedtls/internal/testutils" + "google.golang.org/grpc/security/advancedtls/testdata" ) type s struct { @@ -339,6 +341,10 @@ func (s) TestClientServerHandshake(t *testing.T) { getRootCAsForServerBad := func(params *GetRootCAsParams) (*GetRootCAsResults, error) { return nil, fmt.Errorf("bad root certificate reloading") } + cache, err := lru.New(5) + if err != nil { + t.Fatalf("lru.New: err = %v", err) + } for _, test := range []struct { desc string clientCert []tls.Certificate @@ -349,6 +355,7 @@ func (s) TestClientServerHandshake(t *testing.T) { clientVType VerificationType clientRootProvider certprovider.Provider clientIdentityProvider certprovider.Provider + clientRevocationConfig *RevocationConfig clientExpectHandshakeError bool serverMutualTLS bool serverCert []tls.Certificate @@ -359,6 +366,7 @@ func (s) TestClientServerHandshake(t *testing.T) { serverVType VerificationType serverRootProvider certprovider.Provider serverIdentityProvider certprovider.Provider + serverRevocationConfig *RevocationConfig serverExpectError bool }{ // Client: nil setting except verifyFuncGood @@ -642,6 +650,30 @@ func (s) TestClientServerHandshake(t *testing.T) { serverRootProvider: fakeProvider{isClient: false}, serverVType: CertVerification, }, + // Client: set valid credentials with the revocation config + // Server: set valid credentials with the revocation config + // Expected Behavior: success, because non of the certificate chains sent in the connection are revoked + { + desc: "Client sets peer cert, reload root function with verifyFuncGood; Server sets peer cert, reload root function; mutualTLS", + clientCert: []tls.Certificate{cs.ClientCert1}, + clientGetRoot: getRootCAsForClient, + clientVerifyFunc: clientVerifyFuncGood, + clientVType: CertVerification, + clientRevocationConfig: &RevocationConfig{ + RootDir: testdata.Path("crl"), + AllowUndetermined: true, + Cache: cache, + }, + serverMutualTLS: true, + serverCert: []tls.Certificate{cs.ServerCert1}, + serverGetRoot: getRootCAsForServer, + serverVType: CertVerification, + serverRevocationConfig: &RevocationConfig{ + RootDir: testdata.Path("crl"), + AllowUndetermined: true, + Cache: cache, + }, + }, } { test := test t.Run(test.desc, func(t *testing.T) { @@ -665,6 +697,7 @@ func (s) TestClientServerHandshake(t *testing.T) { RequireClientCert: test.serverMutualTLS, VerifyPeer: test.serverVerifyFunc, VType: test.serverVType, + RevocationConfig: test.serverRevocationConfig, } go func(done chan credentials.AuthInfo, lis net.Listener, serverOptions *ServerOptions) { serverRawConn, err := lis.Accept() @@ -706,7 +739,8 @@ func (s) TestClientServerHandshake(t *testing.T) { GetRootCertificates: test.clientGetRoot, RootProvider: test.clientRootProvider, }, - VType: test.clientVType, + VType: test.clientVType, + RevocationConfig: test.clientRevocationConfig, } clientTLS, err := NewClientCreds(clientOptions) if err != nil { diff --git a/security/advancedtls/crl.go b/security/advancedtls/crl.go new file mode 100644 index 00000000000..3931c1ec629 --- /dev/null +++ b/security/advancedtls/crl.go @@ -0,0 +1,499 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package advancedtls + +import ( + "bytes" + "crypto/sha1" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/asn1" + "encoding/binary" + "encoding/hex" + "errors" + "fmt" + "io/ioutil" + "path/filepath" + "strings" + "time" + + "google.golang.org/grpc/grpclog" +) + +var grpclogLogger = grpclog.Component("advancedtls") + +// Cache is an interface to cache CRL files. +// The cache implementation must be concurrency safe. +// A fixed size lru cache from golang-lru is recommended. +type Cache interface { + // Add adds a value to the cache. + Add(key, value interface{}) bool + // Get looks up a key's value from the cache. + Get(key interface{}) (value interface{}, ok bool) +} + +// RevocationConfig contains options for CRL lookup. +type RevocationConfig struct { + // RootDir is the directory to search for CRL files. + // Directory format must match OpenSSL X509_LOOKUP_hash_dir(3). + RootDir string + // AllowUndetermined controls if certificate chains with RevocationUndetermined + // revocation status are allowed to complete. + AllowUndetermined bool + // Cache will store CRL files if not nil, otherwise files are reloaded for every lookup. + Cache Cache +} + +// RevocationStatus is the revocation status for a certificate or chain. +type RevocationStatus int + +const ( + // RevocationUndetermined means we couldn't find or verify a CRL for the cert. + RevocationUndetermined RevocationStatus = iota + // RevocationUnrevoked means we found the CRL for the cert and the cert is not revoked. + RevocationUnrevoked + // RevocationRevoked means we found the CRL and the cert is revoked. + RevocationRevoked +) + +func (s RevocationStatus) String() string { + return [...]string{"RevocationUndetermined", "RevocationUnrevoked", "RevocationRevoked"}[s] +} + +// certificateListExt contains a pkix.CertificateList and parsed +// extensions that aren't provided by the golang CRL parser. +type certificateListExt struct { + CertList *pkix.CertificateList + // RFC5280, 5.2.1, all conforming CRLs must have a AKID with the ID method. + AuthorityKeyID []byte +} + +const tagDirectoryName = 4 + +var ( + // RFC5280, 5.2.4 id-ce-deltaCRLIndicator OBJECT IDENTIFIER ::= { id-ce 27 } + oidDeltaCRLIndicator = asn1.ObjectIdentifier{2, 5, 29, 27} + // RFC5280, 5.2.5 id-ce-issuingDistributionPoint OBJECT IDENTIFIER ::= { id-ce 28 } + oidIssuingDistributionPoint = asn1.ObjectIdentifier{2, 5, 29, 28} + // RFC5280, 5.3.3 id-ce-certificateIssuer OBJECT IDENTIFIER ::= { id-ce 29 } + oidCertificateIssuer = asn1.ObjectIdentifier{2, 5, 29, 29} + // RFC5290, 4.2.1.1 id-ce-authorityKeyIdentifier OBJECT IDENTIFIER ::= { id-ce 35 } + oidAuthorityKeyIdentifier = asn1.ObjectIdentifier{2, 5, 29, 35} +) + +// x509NameHash implements the OpenSSL X509_NAME_hash function for hashed directory lookups. +func x509NameHash(r pkix.RDNSequence) string { + var canonBytes []byte + // First, canonicalize all the strings. + for _, rdnSet := range r { + for i, rdn := range rdnSet { + value, ok := rdn.Value.(string) + if !ok { + continue + } + // OpenSSL trims all whitespace, does a tolower, and removes extra spaces between words. + // Implemented in x509_name_canon in OpenSSL + canonStr := strings.Join(strings.Fields( + strings.TrimSpace(strings.ToLower(value))), " ") + // Then it changes everything to UTF8 strings + rdnSet[i].Value = asn1.RawValue{Tag: asn1.TagUTF8String, Bytes: []byte(canonStr)} + + } + } + + // Finally, OpenSSL drops the initial sequence tag + // so we marshal all the RDNs separately instead of as a group. + for _, canonRdn := range r { + b, err := asn1.Marshal(canonRdn) + if err != nil { + continue + } + canonBytes = append(canonBytes, b...) + } + + issuerHash := sha1.Sum(canonBytes) + // Openssl takes the first 4 bytes and encodes them as a little endian + // uint32 and then uses the hex to make the file name. + // In C++, this would be: + // (((unsigned long)md[0]) | ((unsigned long)md[1] << 8L) | + // ((unsigned long)md[2] << 16L) | ((unsigned long)md[3] << 24L) + // ) & 0xffffffffL; + fileHash := binary.LittleEndian.Uint32(issuerHash[0:4]) + return fmt.Sprintf("%08x", fileHash) +} + +// CheckRevocation checks the connection for revoked certificates based on RFC5280. +// This implementation has the following major limitations: +// * Indirect CRL files are not supported. +// * CRL loading is only supported from directories in the X509_LOOKUP_hash_dir format. +// * OnlySomeReasons is not supported. +// * Delta CRL files are not supported. +// * Certificate CRLDistributionPoint must be URLs, but are then ignored and converted into a file path. +// * CRL checks are done after path building, which goes against RFC4158. +func CheckRevocation(conn tls.ConnectionState, cfg RevocationConfig) error { + return CheckChainRevocation(conn.VerifiedChains, cfg) +} + +// CheckChainRevocation checks the verified certificate chain +// for revoked certificates based on RFC5280. +func CheckChainRevocation(verifiedChains [][]*x509.Certificate, cfg RevocationConfig) error { + // Iterate the verified chains looking for one that is RevocationUnrevoked. + // A single RevocationUnrevoked chain is enough to allow the connection, and a single RevocationRevoked + // chain does not mean the connection should fail. + count := make(map[RevocationStatus]int) + for _, chain := range verifiedChains { + switch checkChain(chain, cfg) { + case RevocationUnrevoked: + // If any chain is RevocationUnrevoked then return no error. + return nil + case RevocationRevoked: + // If this chain is revoked, keep looking for another chain. + count[RevocationRevoked]++ + continue + case RevocationUndetermined: + if cfg.AllowUndetermined { + return nil + } + count[RevocationUndetermined]++ + continue + } + } + return fmt.Errorf("no unrevoked chains found: %v", count) +} + +// checkChain will determine and check all certificates in chain against the CRL +// defined in the certificate with the following rules: +// 1. If any certificate is RevocationRevoked, return RevocationRevoked. +// 2. If any certificate is RevocationUndetermined, return RevocationUndetermined. +// 3. If all certificates are RevocationUnrevoked, return RevocationUnrevoked. +func checkChain(chain []*x509.Certificate, cfg RevocationConfig) RevocationStatus { + chainStatus := RevocationUnrevoked + for _, c := range chain { + switch checkCert(c, chain, cfg) { + case RevocationRevoked: + // Easy case, if a cert in the chain is revoked, the chain is revoked. + return RevocationRevoked + case RevocationUndetermined: + // If we couldn't find the revocation status for a cert, the chain is at best RevocationUndetermined + // keep looking to see if we find a cert in the chain that's RevocationRevoked, + // but return RevocationUndetermined at a minimum. + chainStatus = RevocationUndetermined + case RevocationUnrevoked: + // Continue iterating up the cert chain. + continue + } + } + return chainStatus +} + +func cachedCrl(rawIssuer []byte, cache Cache) (*certificateListExt, bool) { + val, ok := cache.Get(hex.EncodeToString(rawIssuer)) + if !ok { + return nil, false + } + crl, ok := val.(*certificateListExt) + if !ok { + return nil, false + } + // If the CRL is expired, force a reload. + if crl.CertList.HasExpired(time.Now()) { + return nil, false + } + return crl, true +} + +// fetchIssuerCRL fetches and verifies the CRL for rawIssuer from disk or cache if configured in cfg. +func fetchIssuerCRL(crlDistributionPoint string, rawIssuer []byte, crlVerifyCrt []*x509.Certificate, cfg RevocationConfig) (*certificateListExt, error) { + if cfg.Cache != nil { + if crl, ok := cachedCrl(rawIssuer, cfg.Cache); ok { + return crl, nil + } + } + + crl, err := fetchCRL(crlDistributionPoint, rawIssuer, cfg) + if err != nil { + return nil, fmt.Errorf("fetchCRL(%v) failed err = %v", crlDistributionPoint, err) + } + + if err := verifyCRL(crl, rawIssuer, crlVerifyCrt); err != nil { + return nil, fmt.Errorf("verifyCRL(%v) failed err = %v", crlDistributionPoint, err) + } + if cfg.Cache != nil { + cfg.Cache.Add(hex.EncodeToString(rawIssuer), crl) + } + return crl, nil +} + +// checkCert checks a single certificate against the CRL defined in the certificate. +// It will fetch and verify the CRL(s) defined by CRLDistributionPoints. +// If we can't load any authoritative CRL files, the status is RevocationUndetermined. +// c is the certificate to check. +// crlVerifyCrt is the group of possible certificates to verify the crl. +func checkCert(c *x509.Certificate, crlVerifyCrt []*x509.Certificate, cfg RevocationConfig) RevocationStatus { + if len(c.CRLDistributionPoints) == 0 { + return RevocationUnrevoked + } + // Iterate through CRL distribution points to check for status + for _, dp := range c.CRLDistributionPoints { + crl, err := fetchIssuerCRL(dp, c.RawIssuer, crlVerifyCrt, cfg) + if err != nil { + grpclogLogger.Warningf("getIssuerCRL(%v) err = %v", c.Issuer, err) + continue + } + revocation, err := checkCertRevocation(c, crl) + if err != nil { + grpclogLogger.Warningf("checkCertRevocation(CRL %v) failed %v", crl.CertList.TBSCertList.Issuer, err) + // We couldn't check the CRL file for some reason, so continue + // to the next file + continue + } + // Here we've gotten a CRL that loads and verifies. + // We only handle all-reasons CRL files, so this file + // is authoritative for the certificate. + return revocation + + } + // We couldn't load any CRL files for the certificate, so we don't know if it's RevocationUnrevoked or not. + return RevocationUndetermined +} + +func checkCertRevocation(c *x509.Certificate, crl *certificateListExt) (RevocationStatus, error) { + // Per section 5.3.3 we prime the certificate issuer with the CRL issuer. + // Subsequent entries use the previous entry's issuer. + rawEntryIssuer, err := asn1.Marshal(crl.CertList.TBSCertList.Issuer) + if err != nil { + return RevocationUndetermined, err + } + + // Loop through all the revoked certificates. + for _, revCert := range crl.CertList.TBSCertList.RevokedCertificates { + // 5.3 Loop through CRL entry extensions for needed information. + for _, ext := range revCert.Extensions { + if oidCertificateIssuer.Equal(ext.Id) { + extIssuer, err := parseCertIssuerExt(ext) + if err != nil { + grpclogLogger.Info(err) + if ext.Critical { + return RevocationUndetermined, err + } + // Since this is a non-critical extension, we can skip it even though + // there was a parsing failure. + continue + } + rawEntryIssuer = extIssuer + } else if ext.Critical { + return RevocationUndetermined, fmt.Errorf("checkCertRevocation: Unhandled critical extension: %v", ext.Id) + } + } + + // If the issuer and serial number appear in the CRL, the certificate is revoked. + if bytes.Equal(c.RawIssuer, rawEntryIssuer) && c.SerialNumber.Cmp(revCert.SerialNumber) == 0 { + // CRL contains the serial, so return revoked. + return RevocationRevoked, nil + } + } + // We did not find the serial in the CRL file that was valid for the cert + // so the certificate is not revoked. + return RevocationUnrevoked, nil +} + +func parseCertIssuerExt(ext pkix.Extension) ([]byte, error) { + // 5.3.3 Certificate Issuer + // CertificateIssuer ::= GeneralNames + // GeneralNames ::= SEQUENCE SIZE (1..MAX) OF GeneralName + var generalNames []asn1.RawValue + if rest, err := asn1.Unmarshal(ext.Value, &generalNames); err != nil || len(rest) != 0 { + return nil, fmt.Errorf("asn1.Unmarshal failed err = %v", err) + } + + for _, generalName := range generalNames { + // GeneralName ::= CHOICE { + // otherName [0] OtherName, + // rfc822Name [1] IA5String, + // dNSName [2] IA5String, + // x400Address [3] ORAddress, + // directoryName [4] Name, + // ediPartyName [5] EDIPartyName, + // uniformResourceIdentifier [6] IA5String, + // iPAddress [7] OCTET STRING, + // registeredID [8] OBJECT IDENTIFIER } + if generalName.Tag == tagDirectoryName { + return generalName.Bytes, nil + } + } + // Conforming CRL issuers MUST include in this extension the + // distinguished name (DN) from the issuer field of the certificate that + // corresponds to this CRL entry. + // If we couldn't get a directoryName, we can't reason about this file so cert status is + // RevocationUndetermined. + return nil, errors.New("no DN found in certificate issuer") +} + +// RFC 5280, 4.2.1.1 +type authKeyID struct { + ID []byte `asn1:"optional,tag:0"` +} + +// RFC5280, 5.2.5 +// id-ce-issuingDistributionPoint OBJECT IDENTIFIER ::= { id-ce 28 } + +// IssuingDistributionPoint ::= SEQUENCE { +// distributionPoint [0] DistributionPointName OPTIONAL, +// onlyContainsUserCerts [1] BOOLEAN DEFAULT FALSE, +// onlyContainsCACerts [2] BOOLEAN DEFAULT FALSE, +// onlySomeReasons [3] ReasonFlags OPTIONAL, +// indirectCRL [4] BOOLEAN DEFAULT FALSE, +// onlyContainsAttributeCerts [5] BOOLEAN DEFAULT FALSE } + +// -- at most one of onlyContainsUserCerts, onlyContainsCACerts, +// -- and onlyContainsAttributeCerts may be set to TRUE. +type issuingDistributionPoint struct { + DistributionPoint asn1.RawValue `asn1:"optional,tag:0"` + OnlyContainsUserCerts bool `asn1:"optional,tag:1"` + OnlyContainsCACerts bool `asn1:"optional,tag:2"` + OnlySomeReasons asn1.BitString `asn1:"optional,tag:3"` + IndirectCRL bool `asn1:"optional,tag:4"` + OnlyContainsAttributeCerts bool `asn1:"optional,tag:5"` +} + +// parseCRLExtensions parses the extensions for a CRL +// and checks that they're supported by the parser. +func parseCRLExtensions(c *pkix.CertificateList) (*certificateListExt, error) { + if c == nil { + return nil, errors.New("c is nil, expected any value") + } + certList := &certificateListExt{CertList: c} + + for _, ext := range c.TBSCertList.Extensions { + switch { + case oidDeltaCRLIndicator.Equal(ext.Id): + return nil, fmt.Errorf("delta CRLs unsupported") + + case oidAuthorityKeyIdentifier.Equal(ext.Id): + var a authKeyID + if rest, err := asn1.Unmarshal(ext.Value, &a); err != nil { + return nil, fmt.Errorf("asn1.Unmarshal failed. err = %v", err) + } else if len(rest) != 0 { + return nil, errors.New("trailing data after AKID extension") + } + certList.AuthorityKeyID = a.ID + + case oidIssuingDistributionPoint.Equal(ext.Id): + var dp issuingDistributionPoint + if rest, err := asn1.Unmarshal(ext.Value, &dp); err != nil { + return nil, fmt.Errorf("asn1.Unmarshal failed. err = %v", err) + } else if len(rest) != 0 { + return nil, errors.New("trailing data after IssuingDistributionPoint extension") + } + + if dp.OnlyContainsUserCerts || dp.OnlyContainsCACerts || dp.OnlyContainsAttributeCerts { + return nil, errors.New("CRL only contains some certificate types") + } + if dp.IndirectCRL { + return nil, errors.New("indirect CRLs unsupported") + } + if dp.OnlySomeReasons.BitLength != 0 { + return nil, errors.New("onlySomeReasons unsupported") + } + + case ext.Critical: + return nil, fmt.Errorf("unsupported critical extension: %v", ext.Id) + } + } + + if len(certList.AuthorityKeyID) == 0 { + return nil, errors.New("authority key identifier extension missing") + } + return certList, nil +} + +func fetchCRL(loc string, rawIssuer []byte, cfg RevocationConfig) (*certificateListExt, error) { + var parsedCRL *certificateListExt + // 6.3.3 (a) (1) (ii) + // According to X509_LOOKUP_hash_dir the format is issuer_hash.rN where N is an increasing number. + // There are no gaps, so we break when we can't find a file. + for i := 0; ; i++ { + // Unmarshal to RDNSeqence according to http://go/godoc/crypto/x509/pkix/#Name. + var r pkix.RDNSequence + rest, err := asn1.Unmarshal(rawIssuer, &r) + if len(rest) != 0 || err != nil { + return nil, fmt.Errorf("asn1.Unmarshal(Issuer) len(rest) = %v, err = %v", len(rest), err) + } + crlPath := fmt.Sprintf("%s.r%d", filepath.Join(cfg.RootDir, x509NameHash(r)), i) + crlBytes, err := ioutil.ReadFile(crlPath) + if err != nil { + // Break when we can't read a CRL file. + grpclogLogger.Infof("readFile: %v", err) + break + } + + crl, err := x509.ParseCRL(crlBytes) + if err != nil { + // Parsing errors for a CRL shouldn't happen so fail. + return nil, fmt.Errorf("x509.ParseCrl(%v) failed err = %v", crlPath, err) + } + var certList *certificateListExt + if certList, err = parseCRLExtensions(crl); err != nil { + grpclogLogger.Infof("fetchCRL: unsupported crl %v, err = %v", crlPath, err) + // Continue to find a supported CRL + continue + } + + rawCRLIssuer, err := asn1.Marshal(certList.CertList.TBSCertList.Issuer) + if err != nil { + return nil, fmt.Errorf("asn1.Marshal(%v) failed err = %v", certList.CertList.TBSCertList.Issuer, err) + } + // RFC5280, 6.3.3 (b) Verify the issuer and scope of the complete CRL. + if bytes.Equal(rawIssuer, rawCRLIssuer) { + parsedCRL = certList + // Continue to find the highest number in the .rN suffix. + continue + } + } + + if parsedCRL == nil { + return nil, fmt.Errorf("fetchCrls no CRLs found for issuer") + } + return parsedCRL, nil +} + +func verifyCRL(crl *certificateListExt, rawIssuer []byte, chain []*x509.Certificate) error { + // RFC5280, 6.3.3 (f) Obtain and validateate the certification path for the issuer of the complete CRL + // We intentionally limit our CRLs to be signed with the same certificate path as the certificate + // so we can use the chain from the connection. + rawCRLIssuer, err := asn1.Marshal(crl.CertList.TBSCertList.Issuer) + if err != nil { + return fmt.Errorf("asn1.Marshal(%v) failed err = %v", crl.CertList.TBSCertList.Issuer, err) + } + + for _, c := range chain { + // Use the key where the subject and KIDs match. + // This departs from RFC4158, 3.5.12 which states that KIDs + // cannot eliminate certificates, but RFC5280, 5.2.1 states that + // "Conforming CRL issuers MUST use the key identifier method, and MUST + // include this extension in all CRLs issued." + // So, this is much simpler than RFC4158 and should be compatible. + if bytes.Equal(c.SubjectKeyId, crl.AuthorityKeyID) && bytes.Equal(c.RawSubject, rawCRLIssuer) { + // RFC5280, 6.3.3 (g) Validate signature. + return c.CheckCRLSignature(crl.CertList) + } + } + return fmt.Errorf("verifyCRL: No certificates mached CRL issuer (%v)", crl.CertList.TBSCertList.Issuer) +} diff --git a/security/advancedtls/crl_test.go b/security/advancedtls/crl_test.go new file mode 100644 index 00000000000..ec4483304c7 --- /dev/null +++ b/security/advancedtls/crl_test.go @@ -0,0 +1,718 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package advancedtls + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/asn1" + "encoding/hex" + "encoding/pem" + "fmt" + "io/ioutil" + "math/big" + "net" + "os" + "path" + "strings" + "testing" + "time" + + lru "github.com/hashicorp/golang-lru" + "google.golang.org/grpc/security/advancedtls/testdata" +) + +func TestX509NameHash(t *testing.T) { + nameTests := []struct { + in pkix.Name + out string + }{ + { + in: pkix.Name{ + Country: []string{"US"}, + Organization: []string{"Example"}, + }, + out: "9cdd41ff", + }, + { + in: pkix.Name{ + Country: []string{"us"}, + Organization: []string{"example"}, + }, + out: "9cdd41ff", + }, + { + in: pkix.Name{ + Country: []string{" us"}, + Organization: []string{"example"}, + }, + out: "9cdd41ff", + }, + { + in: pkix.Name{ + Country: []string{"US"}, + Province: []string{"California"}, + Locality: []string{"Mountain View"}, + Organization: []string{"BoringSSL"}, + }, + out: "c24414d9", + }, + { + in: pkix.Name{ + Country: []string{"US"}, + Province: []string{"California"}, + Locality: []string{"Mountain View"}, + Organization: []string{"BoringSSL"}, + }, + out: "c24414d9", + }, + { + in: pkix.Name{ + SerialNumber: "87f4514475ba0a2b", + }, + out: "9dc713cd", + }, + { + in: pkix.Name{ + Country: []string{"US"}, + Province: []string{"California"}, + Locality: []string{"Mountain View"}, + Organization: []string{"Google LLC"}, + OrganizationalUnit: []string{"Production", "campus-sln"}, + CommonName: "Root CA (2021-02-02T07:30:36-08:00)", + }, + out: "0b35a562", + }, + { + in: pkix.Name{ + ExtraNames: []pkix.AttributeTypeAndValue{ + {Type: asn1.ObjectIdentifier{5, 5, 5, 5}, Value: "aaaa"}, + }, + }, + out: "eea339da", + }, + } + for _, tt := range nameTests { + t.Run(tt.in.String(), func(t *testing.T) { + h := x509NameHash(tt.in.ToRDNSequence()) + if h != tt.out { + t.Errorf("x509NameHash(%v): Got %v wanted %v", tt.in, h, tt.out) + } + }) + } +} + +func TestUnsupportedCRLs(t *testing.T) { + crlBytesSomeReasons := []byte(`-----BEGIN X509 CRL----- +MIIEeDCCA2ACAQEwDQYJKoZIhvcNAQELBQAwQjELMAkGA1UEBhMCVVMxHjAcBgNV +BAoTFUdvb2dsZSBUcnVzdCBTZXJ2aWNlczETMBEGA1UEAxMKR1RTIENBIDFPMRcN +MjEwNDI2MTI1OTQxWhcNMjEwNTA2MTE1OTQwWjCCAn0wIgIRAPOOG3L4VLC7CAAA +AABxQgEXDTIxMDQxOTEyMTgxOFowIQIQUK0UwBZkVdQIAAAAAHFCBRcNMjEwNDE5 +MTIxODE4WjAhAhBRIXBJaKoQkQgAAAAAcULHFw0yMTA0MjAxMjE4MTdaMCICEQCv +qQWUq5UxmQgAAAAAcULMFw0yMTA0MjAxMjE4MTdaMCICEQDdv5k1kKwKTQgAAAAA +cUOQFw0yMTA0MjExMjE4MTZaMCICEQDGIEfR8N9sEAgAAAAAcUOWFw0yMTA0MjEx +MjE4MThaMCECEBHgbLXlj5yUCAAAAABxQ/IXDTIxMDQyMTIzMDAyNlowIQIQE1wT +2GGYqKwIAAAAAHFD7xcNMjEwNDIxMjMwMDI5WjAiAhEAo/bSyDjpVtsIAAAAAHFE +txcNMjEwNDIyMjMwMDI3WjAhAhARdCrSrHE0dAgAAAAAcUS/Fw0yMTA0MjIyMzAw +MjhaMCECEHONohfWn3wwCAAAAABxRX8XDTIxMDQyMzIzMDAyOVowIgIRAOYkiUPA +os4vCAAAAABxRYgXDTIxMDQyMzIzMDAyOFowIQIQRNTow5Eg2gEIAAAAAHFGShcN +MjEwNDI0MjMwMDI2WjAhAhBX32dH4/WQ6AgAAAAAcUZNFw0yMTA0MjQyMzAwMjZa +MCICEQDHnUM1vsaP/wgAAAAAcUcQFw0yMTA0MjUyMzAwMjZaMCECEEm5rvmL8sj6 +CAAAAABxRxQXDTIxMDQyNTIzMDAyN1owIQIQW16OQs4YQYkIAAAAAHFIABcNMjEw +NDI2MTI1NDA4WjAhAhAhSohpYsJtDQgAAAAAcUgEFw0yMTA0MjYxMjU0MDlaoGkw +ZzAfBgNVHSMEGDAWgBSY0fhuEOvPm+xgnxiQG6DrfQn9KzALBgNVHRQEBAICBngw +NwYDVR0cAQH/BC0wK6AmoCSGImh0dHA6Ly9jcmwucGtpLmdvb2cvR1RTMU8xY29y +ZS5jcmyBAf8wDQYJKoZIhvcNAQELBQADggEBADPBXbxVxMJ1HC7btXExRUpJHUlU +YbeCZGx6zj5F8pkopbmpV7cpewwhm848Fx4VaFFppZQZd92O08daEC6aEqoug4qF +z6ZrOLzhuKfpW8E93JjgL91v0FYN7iOcT7+ERKCwVEwEkuxszxs7ggW6OJYJNvHh +priIdmcPoiQ3ZrIRH0vE3BfUcNXnKFGATWuDkiRI0I4A5P7NiOf+lAuGZet3/eom +0chgts6sdau10GfeUpHUd4f8e93cS/QeLeG16z7LC8vRLstU3m3vrknpZbdGqSia +97w66mqcnQh9V0swZiEnVLmLufaiuDZJ+6nUzSvLqBlb/ei3T/tKV0BoKJA= +-----END X509 CRL-----`) + + crlBytesIndirect := []byte(`-----BEGIN X509 CRL----- +MIIDGjCCAgICAQEwDQYJKoZIhvcNAQELBQAwdjELMAkGA1UEBhMCVVMxEzARBgNV +BAgTCkNhbGlmb3JuaWExFDASBgNVBAoTC1Rlc3RpbmcgTHRkMSowKAYDVQQLEyFU +ZXN0aW5nIEx0ZCBDZXJ0aWZpY2F0ZSBBdXRob3JpdHkxEDAOBgNVBAMTB1Rlc3Qg +Q0EXDTIxMDExNjAyMjAxNloXDTIxMDEyMDA2MjAxNlowgfIwbAIBAhcNMjEwMTE2 +MDIyMDE2WjBYMAoGA1UdFQQDCgEEMEoGA1UdHQEB/wRAMD6kPDA6MQwwCgYDVQQG +EwNVU0ExDTALBgNVBAcTBGhlcmUxCzAJBgNVBAoTAnVzMQ4wDAYDVQQDEwVUZXN0 +MTAgAgEDFw0yMTAxMTYwMjIwMTZaMAwwCgYDVR0VBAMKAQEwYAIBBBcNMjEwMTE2 +MDIyMDE2WjBMMEoGA1UdHQEB/wRAMD6kPDA6MQwwCgYDVQQGEwNVU0ExDTALBgNV +BAcTBGhlcmUxCzAJBgNVBAoTAnVzMQ4wDAYDVQQDEwVUZXN0MqBjMGEwHwYDVR0j +BBgwFoAURJSDWAOfhGCryBjl8dsQjBitl3swCgYDVR0UBAMCAQEwMgYDVR0cAQH/ +BCgwJqAhoB+GHWh0dHA6Ly9jcmxzLnBraS5nb29nL3Rlc3QuY3JshAH/MA0GCSqG +SIb3DQEBCwUAA4IBAQBVXX67mr2wFPmEWCe6mf/wFnPl3xL6zNOl96YJtsd7ulcS +TEbdJpaUnWFQ23+Tpzdj/lI2aQhTg5Lvii3o+D8C5r/Jc5NhSOtVJJDI/IQLh4pG +NgGdljdbJQIT5D2Z71dgbq1ocxn8DefZIJjO3jp8VnAm7AIMX2tLTySzD2MpMeMq +XmcN4lG1e4nx+xjzp7MySYO42NRY3LkphVzJhu3dRBYhBKViRJxw9hLttChitJpF +6Kh6a0QzrEY/QDJGhE1VrAD2c5g/SKnHPDVoCWo4ACIICi76KQQSIWfIdp4W/SY3 +qsSIp8gfxSyzkJP+Ngkm2DdLjlJQCZ9R0MZP9Xj4 +-----END X509 CRL-----`) + + var tests = []struct { + desc string + in []byte + }{ + { + desc: "some reasons", + in: crlBytesSomeReasons, + }, + { + desc: "indirect", + in: crlBytesIndirect, + }, + } + + for _, tt := range tests { + t.Run(tt.desc, func(t *testing.T) { + crl, err := x509.ParseCRL(tt.in) + if err != nil { + t.Fatal(err) + } + if _, err := parseCRLExtensions(crl); err == nil { + t.Error("expected error got ok") + } + }) + } +} + +func TestCheckCertRevocation(t *testing.T) { + dummyCrlFile := []byte(`-----BEGIN X509 CRL----- +MIIDGjCCAgICAQEwDQYJKoZIhvcNAQELBQAwdjELMAkGA1UEBhMCVVMxEzARBgNV +BAgTCkNhbGlmb3JuaWExFDASBgNVBAoTC1Rlc3RpbmcgTHRkMSowKAYDVQQLEyFU +ZXN0aW5nIEx0ZCBDZXJ0aWZpY2F0ZSBBdXRob3JpdHkxEDAOBgNVBAMTB1Rlc3Qg +Q0EXDTIxMDExNjAyMjAxNloXDTIxMDEyMDA2MjAxNlowgfIwbAIBAhcNMjEwMTE2 +MDIyMDE2WjBYMAoGA1UdFQQDCgEEMEoGA1UdHQEB/wRAMD6kPDA6MQwwCgYDVQQG +EwNVU0ExDTALBgNVBAcTBGhlcmUxCzAJBgNVBAoTAnVzMQ4wDAYDVQQDEwVUZXN0 +MTAgAgEDFw0yMTAxMTYwMjIwMTZaMAwwCgYDVR0VBAMKAQEwYAIBBBcNMjEwMTE2 +MDIyMDE2WjBMMEoGA1UdHQEB/wRAMD6kPDA6MQwwCgYDVQQGEwNVU0ExDTALBgNV +BAcTBGhlcmUxCzAJBgNVBAoTAnVzMQ4wDAYDVQQDEwVUZXN0MqBjMGEwHwYDVR0j +BBgwFoAURJSDWAOfhGCryBjl8dsQjBitl3swCgYDVR0UBAMCAQEwMgYDVR0cAQH/ +BCgwJqAhoB+GHWh0dHA6Ly9jcmxzLnBraS5nb29nL3Rlc3QuY3JshAH/MA0GCSqG +SIb3DQEBCwUAA4IBAQBVXX67mr2wFPmEWCe6mf/wFnPl3xL6zNOl96YJtsd7ulcS +TEbdJpaUnWFQ23+Tpzdj/lI2aQhTg5Lvii3o+D8C5r/Jc5NhSOtVJJDI/IQLh4pG +NgGdljdbJQIT5D2Z71dgbq1ocxn8DefZIJjO3jp8VnAm7AIMX2tLTySzD2MpMeMq +XmcN4lG1e4nx+xjzp7MySYO42NRY3LkphVzJhu3dRBYhBKViRJxw9hLttChitJpF +6Kh6a0QzrEY/QDJGhE1VrAD2c5g/SKnHPDVoCWo4ACIICi76KQQSIWfIdp4W/SY3 +qsSIp8gfxSyzkJP+Ngkm2DdLjlJQCZ9R0MZP9Xj4 +-----END X509 CRL-----`) + crl, err := x509.ParseCRL(dummyCrlFile) + if err != nil { + t.Fatalf("x509.ParseCRL(dummyCrlFile) failed: %v", err) + } + crlExt := &certificateListExt{CertList: crl} + var crlIssuer pkix.Name + crlIssuer.FillFromRDNSequence(&crl.TBSCertList.Issuer) + + var revocationTests = []struct { + desc string + in x509.Certificate + revoked RevocationStatus + }{ + { + desc: "Single revoked", + in: x509.Certificate{ + Issuer: pkix.Name{ + Country: []string{"USA"}, + Locality: []string{"here"}, + Organization: []string{"us"}, + CommonName: "Test1", + }, + SerialNumber: big.NewInt(2), + CRLDistributionPoints: []string{"test"}, + }, + revoked: RevocationRevoked, + }, + { + desc: "Revoked no entry issuer", + in: x509.Certificate{ + Issuer: pkix.Name{ + Country: []string{"USA"}, + Locality: []string{"here"}, + Organization: []string{"us"}, + CommonName: "Test1", + }, + SerialNumber: big.NewInt(3), + CRLDistributionPoints: []string{"test"}, + }, + revoked: RevocationRevoked, + }, + { + desc: "Revoked new entry issuer", + in: x509.Certificate{ + Issuer: pkix.Name{ + Country: []string{"USA"}, + Locality: []string{"here"}, + Organization: []string{"us"}, + CommonName: "Test2", + }, + SerialNumber: big.NewInt(4), + CRLDistributionPoints: []string{"test"}, + }, + revoked: RevocationRevoked, + }, + { + desc: "Single unrevoked", + in: x509.Certificate{ + Issuer: pkix.Name{ + Country: []string{"USA"}, + Locality: []string{"here"}, + Organization: []string{"us"}, + CommonName: "Test2", + }, + SerialNumber: big.NewInt(1), + CRLDistributionPoints: []string{"test"}, + }, + revoked: RevocationUnrevoked, + }, + { + desc: "Single unrevoked Issuer", + in: x509.Certificate{ + Issuer: crlIssuer, + SerialNumber: big.NewInt(2), + CRLDistributionPoints: []string{"test"}, + }, + revoked: RevocationUnrevoked, + }, + } + + for _, tt := range revocationTests { + rawIssuer, err := asn1.Marshal(tt.in.Issuer.ToRDNSequence()) + if err != nil { + t.Fatalf("asn1.Marshal(%v) failed: %v", tt.in.Issuer.ToRDNSequence(), err) + } + tt.in.RawIssuer = rawIssuer + t.Run(tt.desc, func(t *testing.T) { + rev, err := checkCertRevocation(&tt.in, crlExt) + if err != nil { + t.Errorf("checkCertRevocation(%v) err = %v", tt.in.Issuer, err) + } else if rev != tt.revoked { + t.Errorf("checkCertRevocation(%v(%v)) returned %v wanted %v", + tt.in.Issuer, tt.in.SerialNumber, rev, tt.revoked) + } + }) + } +} + +func makeChain(t *testing.T, name string) []*x509.Certificate { + t.Helper() + + certChain := make([]*x509.Certificate, 0) + + rest, err := ioutil.ReadFile(name) + if err != nil { + t.Fatalf("ioutil.ReadFile(%v) failed %v", name, err) + } + for len(rest) > 0 { + var block *pem.Block + block, rest = pem.Decode(rest) + c, err := x509.ParseCertificate(block.Bytes) + if err != nil { + t.Fatalf("ParseCertificate error %v", err) + } + t.Logf("Parsed Cert sub = %v iss = %v", c.Subject, c.Issuer) + certChain = append(certChain, c) + } + return certChain +} + +func loadCRL(t *testing.T, path string) *pkix.CertificateList { + b, err := ioutil.ReadFile(path) + if err != nil { + t.Fatalf("readFile(%v) failed err = %v", path, err) + } + crl, err := x509.ParseCRL(b) + if err != nil { + t.Fatalf("ParseCrl(%v) failed err = %v", path, err) + } + return crl +} + +func TestCachedCRL(t *testing.T) { + cache, err := lru.New(5) + if err != nil { + t.Fatalf("lru.New: err = %v", err) + } + + tests := []struct { + desc string + val interface{} + ok bool + }{ + { + desc: "Valid", + val: &certificateListExt{ + CertList: &pkix.CertificateList{ + TBSCertList: pkix.TBSCertificateList{ + NextUpdate: time.Now().Add(time.Hour), + }, + }}, + ok: true, + }, + { + desc: "Expired", + val: &certificateListExt{ + CertList: &pkix.CertificateList{ + TBSCertList: pkix.TBSCertificateList{ + NextUpdate: time.Now().Add(-time.Hour), + }, + }}, + ok: false, + }, + { + desc: "Wrong Type", + val: "string", + ok: false, + }, + { + desc: "Empty", + val: nil, + ok: false, + }, + } + for _, tt := range tests { + t.Run(tt.desc, func(t *testing.T) { + if tt.val != nil { + cache.Add(hex.EncodeToString([]byte(tt.desc)), tt.val) + } + _, ok := cachedCrl([]byte(tt.desc), cache) + if tt.ok != ok { + t.Errorf("Cache ok error expected %v vs %v", tt.ok, ok) + } + }) + } +} + +func TestGetIssuerCRLCache(t *testing.T) { + cache, err := lru.New(5) + if err != nil { + t.Fatalf("lru.New: err = %v", err) + } + + tests := []struct { + desc string + rawIssuer []byte + certs []*x509.Certificate + }{ + { + desc: "Valid", + rawIssuer: makeChain(t, testdata.Path("crl/unrevoked.pem"))[1].RawIssuer, + certs: makeChain(t, testdata.Path("crl/unrevoked.pem")), + }, + { + desc: "Unverified", + rawIssuer: makeChain(t, testdata.Path("crl/unrevoked.pem"))[1].RawIssuer, + }, + { + desc: "Not Found", + rawIssuer: []byte("not_found"), + }, + } + + for _, tt := range tests { + t.Run(tt.desc, func(t *testing.T) { + cache.Purge() + _, err := fetchIssuerCRL("test", tt.rawIssuer, tt.certs, RevocationConfig{ + RootDir: testdata.Path("."), + Cache: cache, + }) + if err == nil && cache.Len() == 0 { + t.Error("Verified CRL not added to cache") + } + if err != nil && cache.Len() != 0 { + t.Error("Unverified CRL added to cache") + } + }) + } +} + +func TestVerifyCrl(t *testing.T) { + tampered := loadCRL(t, testdata.Path("crl/1.crl")) + // Change the signature so it won't verify + tampered.SignatureValue.Bytes[0]++ + + verifyTests := []struct { + desc string + crl *pkix.CertificateList + certs []*x509.Certificate + cert *x509.Certificate + errWant string + }{ + { + desc: "Pass intermediate", + crl: loadCRL(t, testdata.Path("crl/1.crl")), + certs: makeChain(t, testdata.Path("crl/unrevoked.pem")), + cert: makeChain(t, testdata.Path("crl/unrevoked.pem"))[1], + errWant: "", + }, + { + desc: "Pass leaf", + crl: loadCRL(t, testdata.Path("crl/2.crl")), + certs: makeChain(t, testdata.Path("crl/unrevoked.pem")), + cert: makeChain(t, testdata.Path("crl/unrevoked.pem"))[2], + errWant: "", + }, + { + desc: "Fail wrong cert chain", + crl: loadCRL(t, testdata.Path("crl/3.crl")), + certs: makeChain(t, testdata.Path("crl/unrevoked.pem")), + cert: makeChain(t, testdata.Path("crl/revokedInt.pem"))[1], + errWant: "No certificates mached", + }, + { + desc: "Fail no certs", + crl: loadCRL(t, testdata.Path("crl/1.crl")), + certs: []*x509.Certificate{}, + cert: makeChain(t, testdata.Path("crl/unrevoked.pem"))[1], + errWant: "No certificates mached", + }, + { + desc: "Fail Tampered signature", + crl: tampered, + certs: makeChain(t, testdata.Path("crl/unrevoked.pem")), + cert: makeChain(t, testdata.Path("crl/unrevoked.pem"))[1], + errWant: "verification failure", + }, + } + + for _, tt := range verifyTests { + t.Run(tt.desc, func(t *testing.T) { + crlExt, err := parseCRLExtensions(tt.crl) + if err != nil { + t.Fatalf("parseCRLExtensions(%v) failed, err = %v", tt.crl.TBSCertList.Issuer, err) + } + err = verifyCRL(crlExt, tt.cert.RawIssuer, tt.certs) + switch { + case tt.errWant == "" && err != nil: + t.Errorf("Valid CRL did not verify err = %v", err) + case tt.errWant != "" && err == nil: + t.Error("Invalid CRL verified") + case tt.errWant != "" && !strings.Contains(err.Error(), tt.errWant): + t.Errorf("fetchIssuerCRL(_, %v, %v, _) = %v; want Contains(%v)", tt.cert.RawIssuer, tt.certs, err, tt.errWant) + } + }) + } +} + +func TestRevokedCert(t *testing.T) { + revokedIntChain := makeChain(t, testdata.Path("crl/revokedInt.pem")) + revokedLeafChain := makeChain(t, testdata.Path("crl/revokedLeaf.pem")) + validChain := makeChain(t, testdata.Path("crl/unrevoked.pem")) + cache, err := lru.New(5) + if err != nil { + t.Fatalf("lru.New: err = %v", err) + } + + var revocationTests = []struct { + desc string + in tls.ConnectionState + revoked bool + allowUndetermined bool + }{ + { + desc: "Single unrevoked", + in: tls.ConnectionState{VerifiedChains: [][]*x509.Certificate{validChain}}, + revoked: false, + }, + { + desc: "Single revoked intermediate", + in: tls.ConnectionState{VerifiedChains: [][]*x509.Certificate{revokedIntChain}}, + revoked: true, + }, + { + desc: "Single revoked leaf", + in: tls.ConnectionState{VerifiedChains: [][]*x509.Certificate{revokedLeafChain}}, + revoked: true, + }, + { + desc: "Multi one revoked", + in: tls.ConnectionState{VerifiedChains: [][]*x509.Certificate{validChain, revokedLeafChain}}, + revoked: false, + }, + { + desc: "Multi revoked", + in: tls.ConnectionState{VerifiedChains: [][]*x509.Certificate{revokedLeafChain, revokedIntChain}}, + revoked: true, + }, + { + desc: "Multi unrevoked", + in: tls.ConnectionState{VerifiedChains: [][]*x509.Certificate{validChain, validChain}}, + revoked: false, + }, + { + desc: "Undetermined revoked", + in: tls.ConnectionState{VerifiedChains: [][]*x509.Certificate{ + {&x509.Certificate{CRLDistributionPoints: []string{"test"}}}, + }}, + revoked: true, + }, + { + desc: "Undetermined allowed", + in: tls.ConnectionState{VerifiedChains: [][]*x509.Certificate{ + {&x509.Certificate{CRLDistributionPoints: []string{"test"}}}, + }}, + revoked: false, + allowUndetermined: true, + }, + } + + for _, tt := range revocationTests { + t.Run(tt.desc, func(t *testing.T) { + err := CheckRevocation(tt.in, RevocationConfig{ + RootDir: testdata.Path("crl"), + AllowUndetermined: tt.allowUndetermined, + Cache: cache, + }) + t.Logf("CheckRevocation err = %v", err) + if tt.revoked && err == nil { + t.Error("Revoked certificate chain was allowed") + } else if !tt.revoked && err != nil { + t.Error("Unrevoked certificate not allowed") + } + }) + } +} + +func setupTLSConn(t *testing.T) (net.Listener, *x509.Certificate, *ecdsa.PrivateKey) { + t.Helper() + templ := x509.Certificate{ + SerialNumber: big.NewInt(5), + BasicConstraintsValid: true, + NotBefore: time.Now().Add(-time.Hour), + NotAfter: time.Now().Add(time.Hour), + IsCA: true, + Subject: pkix.Name{CommonName: "test-cert"}, + KeyUsage: x509.KeyUsageCertSign, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth}, + IPAddresses: []net.IP{net.ParseIP("::1")}, + CRLDistributionPoints: []string{"http://static.corp.google.com/crl/campus-sln/borg"}, + } + + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatalf("ecdsa.GenerateKey failed err = %v", err) + } + rawCert, err := x509.CreateCertificate(rand.Reader, &templ, &templ, key.Public(), key) + if err != nil { + t.Fatalf("x509.CreateCertificate failed err = %v", err) + } + cert, err := x509.ParseCertificate(rawCert) + if err != nil { + t.Fatalf("x509.ParseCertificate failed err = %v", err) + } + + srvCfg := tls.Config{ + Certificates: []tls.Certificate{ + { + Certificate: [][]byte{cert.Raw}, + PrivateKey: key, + }, + }, + } + l, err := tls.Listen("tcp6", "[::1]:0", &srvCfg) + if err != nil { + t.Fatalf("tls.Listen failed err = %v", err) + } + return l, cert, key +} + +// TestVerifyConnection will setup a client/server connection and check revocation in the real TLS dialer +func TestVerifyConnection(t *testing.T) { + lis, cert, key := setupTLSConn(t) + defer func() { + lis.Close() + }() + + var handshakeTests = []struct { + desc string + revoked []pkix.RevokedCertificate + success bool + }{ + { + desc: "Empty CRL", + revoked: []pkix.RevokedCertificate{}, + success: true, + }, + { + desc: "Revoked Cert", + revoked: []pkix.RevokedCertificate{ + { + SerialNumber: cert.SerialNumber, + RevocationTime: time.Now(), + }, + }, + success: false, + }, + } + for _, tt := range handshakeTests { + t.Run(tt.desc, func(t *testing.T) { + // Accept one connection. + go func() { + conn, err := lis.Accept() + if err != nil { + t.Errorf("tls.Accept failed err = %v", err) + } else { + conn.Write([]byte("Hello, World!")) + conn.Close() + } + }() + + dir, err := ioutil.TempDir("", "crl_dir") + if err != nil { + t.Fatalf("ioutil.TempDir failed err = %v", err) + } + defer os.RemoveAll(dir) + + crl, err := cert.CreateCRL(rand.Reader, key, tt.revoked, time.Now(), time.Now().Add(time.Hour)) + if err != nil { + t.Fatalf("templ.CreateCRL failed err = %v", err) + } + + err = ioutil.WriteFile(path.Join(dir, fmt.Sprintf("%s.r0", x509NameHash(cert.Subject.ToRDNSequence()))), crl, 0777) + if err != nil { + t.Fatalf("ioutil.WriteFile failed err = %v", err) + } + + cp := x509.NewCertPool() + cp.AddCert(cert) + cliCfg := tls.Config{ + RootCAs: cp, + VerifyConnection: func(cs tls.ConnectionState) error { + return CheckRevocation(cs, RevocationConfig{RootDir: dir}) + }, + } + conn, err := tls.Dial(lis.Addr().Network(), lis.Addr().String(), &cliCfg) + t.Logf("tls.Dial err = %v", err) + if tt.success && err != nil { + t.Errorf("Expected success got err = %v", err) + } + if !tt.success && err == nil { + t.Error("Expected error, but got success") + } + if err == nil { + conn.Close() + } + }) + } +} diff --git a/security/advancedtls/examples/go.mod b/security/advancedtls/examples/go.mod index 936aa476893..20ed81e24d3 100644 --- a/security/advancedtls/examples/go.mod +++ b/security/advancedtls/examples/go.mod @@ -3,7 +3,7 @@ module google.golang.org/grpc/security/advancedtls/examples go 1.15 require ( - google.golang.org/grpc v1.33.1 + google.golang.org/grpc v1.38.0 google.golang.org/grpc/examples v0.0.0-20201112215255-90f1b3ee835b google.golang.org/grpc/security/advancedtls v0.0.0-20201112215255-90f1b3ee835b ) diff --git a/security/advancedtls/examples/go.sum b/security/advancedtls/examples/go.sum index 519267dbc27..272f1afa407 100644 --- a/security/advancedtls/examples/go.sum +++ b/security/advancedtls/examples/go.sum @@ -1,20 +1,25 @@ cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= +github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= +github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= +github.com/cncf/xds/go v0.0.0-20210805033703-aa0b78936158/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/envoyproxy/go-control-plane v0.9.9-0.20210217033140-668b12f5399d/go.mod h1:cXg6YxExXjJnVBQHBLXeUAgxn2UodCpnH306RInaBQk= +github.com/envoyproxy/go-control-plane v0.9.10-0.20210907150352-cf90f659a021/go.mod h1:AFq3mo9L8Lqqiid3OhADV3RfLJnjiw63cSpi+fDTRC0= github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= +github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA= github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs= github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w= github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QDs8UjoX8= -github.com/golang/protobuf v1.4.2 h1:+Z5KGCizgyZCbGh1KZqA0fcLLkwbsjIzS4aV2v7wJX0= -github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= +github.com/golang/protobuf v1.4.3 h1:JjCZWpVbqXDqFVmTfYWEVTMIYrL/NPdPSCHPJ0T/raM= +github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= @@ -22,35 +27,41 @@ github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/go-cmp v0.5.1 h1:JFrFEBb2xKufg6XkJsJr+WbKb4FQlURi5RUcBveYu9k= github.com/google/go-cmp v0.5.1/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw= +github.com/hashicorp/golang-lru v0.5.4 h1:YDjusn29QI/Das2iO9M0BHnIbxPeyuCHsjMW+lJfyTc= +github.com/hashicorp/golang-lru v0.5.4/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uGBxy+E8yxSoD4= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= +github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +go.opentelemetry.io/proto/otlp v0.7.0/go.mod h1:PqfVotwruBrMGOCsRd/89rSnXhoiJIqeYNgFYFoEGnI= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= -golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20190311183353-d8887717615a h1:oWX7TPOiFAMXLq8o0ikBYfCJVlRHBcsciT5bXOrH628= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20200822124328-c89045814202 h1:VvcQYSHwXgi7W+TpUR6A9g6Up98WAHf3f/ulnJ62IyA= +golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a h1:1BGLXjeY4akVXGgbC9HugT3Jv3hCI0z56oJR5vAMgBU= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd h1:xhmwyvizuTgC2qz7ZlMluP20uW+C3Rm0FD/WLDX8884= +golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= -golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= -google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= +google.golang.org/genproto v0.0.0-20200513103714-09dca8ec2884/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo= google.golang.org/genproto v0.0.0-20200806141610-86f49bd18e98 h1:LCO0fg4kb6WwkXQXRQQgUYsFeFb5taTX5WAx5O/Vt28= google.golang.org/genproto v0.0.0-20200806141610-86f49bd18e98/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= @@ -66,6 +77,6 @@ google.golang.org/protobuf v1.24.0/go.mod h1:r/3tXBNzIEhYS9I1OUVjXDlt8tc493IdKGj google.golang.org/protobuf v1.25.0 h1:Ejskq+SyPohKW+1uil0JJMtmHCgJPJ/qWTxr8qp+R4c= google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +gopkg.in/yaml.v2 v2.2.3/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= diff --git a/security/advancedtls/go.mod b/security/advancedtls/go.mod index be35029503d..75527018ee7 100644 --- a/security/advancedtls/go.mod +++ b/security/advancedtls/go.mod @@ -4,7 +4,8 @@ go 1.14 require ( github.com/google/go-cmp v0.5.1 // indirect - google.golang.org/grpc v1.31.0 + github.com/hashicorp/golang-lru v0.5.4 + google.golang.org/grpc v1.38.0 google.golang.org/grpc/examples v0.0.0-20201112215255-90f1b3ee835b ) diff --git a/security/advancedtls/go.sum b/security/advancedtls/go.sum index 519267dbc27..272f1afa407 100644 --- a/security/advancedtls/go.sum +++ b/security/advancedtls/go.sum @@ -1,20 +1,25 @@ cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= +github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= +github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= +github.com/cncf/xds/go v0.0.0-20210805033703-aa0b78936158/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/envoyproxy/go-control-plane v0.9.9-0.20210217033140-668b12f5399d/go.mod h1:cXg6YxExXjJnVBQHBLXeUAgxn2UodCpnH306RInaBQk= +github.com/envoyproxy/go-control-plane v0.9.10-0.20210907150352-cf90f659a021/go.mod h1:AFq3mo9L8Lqqiid3OhADV3RfLJnjiw63cSpi+fDTRC0= github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= +github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA= github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs= github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w= github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QDs8UjoX8= -github.com/golang/protobuf v1.4.2 h1:+Z5KGCizgyZCbGh1KZqA0fcLLkwbsjIzS4aV2v7wJX0= -github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= +github.com/golang/protobuf v1.4.3 h1:JjCZWpVbqXDqFVmTfYWEVTMIYrL/NPdPSCHPJ0T/raM= +github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= @@ -22,35 +27,41 @@ github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/go-cmp v0.5.1 h1:JFrFEBb2xKufg6XkJsJr+WbKb4FQlURi5RUcBveYu9k= github.com/google/go-cmp v0.5.1/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw= +github.com/hashicorp/golang-lru v0.5.4 h1:YDjusn29QI/Das2iO9M0BHnIbxPeyuCHsjMW+lJfyTc= +github.com/hashicorp/golang-lru v0.5.4/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uGBxy+E8yxSoD4= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= +github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +go.opentelemetry.io/proto/otlp v0.7.0/go.mod h1:PqfVotwruBrMGOCsRd/89rSnXhoiJIqeYNgFYFoEGnI= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= -golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20190311183353-d8887717615a h1:oWX7TPOiFAMXLq8o0ikBYfCJVlRHBcsciT5bXOrH628= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20200822124328-c89045814202 h1:VvcQYSHwXgi7W+TpUR6A9g6Up98WAHf3f/ulnJ62IyA= +golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a h1:1BGLXjeY4akVXGgbC9HugT3Jv3hCI0z56oJR5vAMgBU= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd h1:xhmwyvizuTgC2qz7ZlMluP20uW+C3Rm0FD/WLDX8884= +golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= -golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= -google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= +google.golang.org/genproto v0.0.0-20200513103714-09dca8ec2884/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo= google.golang.org/genproto v0.0.0-20200806141610-86f49bd18e98 h1:LCO0fg4kb6WwkXQXRQQgUYsFeFb5taTX5WAx5O/Vt28= google.golang.org/genproto v0.0.0-20200806141610-86f49bd18e98/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= @@ -66,6 +77,6 @@ google.golang.org/protobuf v1.24.0/go.mod h1:r/3tXBNzIEhYS9I1OUVjXDlt8tc493IdKGj google.golang.org/protobuf v1.25.0 h1:Ejskq+SyPohKW+1uil0JJMtmHCgJPJ/qWTxr8qp+R4c= google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +gopkg.in/yaml.v2 v2.2.3/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= diff --git a/security/advancedtls/sni.go b/security/advancedtls/sni.go index 120acf2b376..3e7befb1f90 100644 --- a/security/advancedtls/sni.go +++ b/security/advancedtls/sni.go @@ -1,5 +1,3 @@ -// +build !appengine,go1.14 - /* * * Copyright 2020 gRPC authors. diff --git a/security/advancedtls/sni_beforego114.go b/security/advancedtls/sni_beforego114.go deleted file mode 100644 index 26a09b98849..00000000000 --- a/security/advancedtls/sni_beforego114.go +++ /dev/null @@ -1,42 +0,0 @@ -// +build !appengine,!go1.14 - -/* - * - * Copyright 2020 gRPC authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - */ - -package advancedtls - -import ( - "crypto/tls" - "fmt" -) - -// buildGetCertificates returns the first cert contained in ServerOptions for -// non-appengine builds before version 1.4. -func buildGetCertificates(clientHello *tls.ClientHelloInfo, o *ServerOptions) (*tls.Certificate, error) { - if o.IdentityOptions.GetIdentityCertificatesForServer == nil { - return nil, fmt.Errorf("function GetCertificates must be specified") - } - certificates, err := o.IdentityOptions.GetIdentityCertificatesForServer(clientHello) - if err != nil { - return nil, err - } - if len(certificates) == 0 { - return nil, fmt.Errorf("no certificates configured") - } - return certificates[0], nil -} diff --git a/security/advancedtls/testdata/crl/0b35a562.r0 b/security/advancedtls/testdata/crl/0b35a562.r0 new file mode 120000 index 00000000000..1a84eabdfc7 --- /dev/null +++ b/security/advancedtls/testdata/crl/0b35a562.r0 @@ -0,0 +1 @@ +5.crl \ No newline at end of file diff --git a/security/advancedtls/testdata/crl/0b35a562.r1 b/security/advancedtls/testdata/crl/0b35a562.r1 new file mode 120000 index 00000000000..6e6f1097891 --- /dev/null +++ b/security/advancedtls/testdata/crl/0b35a562.r1 @@ -0,0 +1 @@ +1.crl \ No newline at end of file diff --git a/security/advancedtls/testdata/crl/1.crl b/security/advancedtls/testdata/crl/1.crl new file mode 100644 index 00000000000..5b12ded4a66 --- /dev/null +++ b/security/advancedtls/testdata/crl/1.crl @@ -0,0 +1,10 @@ +-----BEGIN X509 CRL----- +MIIBYDCCAQYCAQEwCgYIKoZIzj0EAwIwgaUxCzAJBgNVBAYTAlVTMRMwEQYDVQQI +EwpDYWxpZm9ybmlhMRYwFAYDVQQHEw1Nb3VudGFpbiBWaWV3MRMwEQYDVQQKEwpH +b29nbGUgTExDMSYwEQYDVQQLEwpQcm9kdWN0aW9uMBEGA1UECxMKY2FtcHVzLXNs +bjEsMCoGA1UEAxMjUm9vdCBDQSAoMjAyMS0wMi0wMlQwNzozMDozNi0wODowMCkX +DTIxMDIwMjE1MzAzNloXDTIxMDIwOTE1MzAzNlqgLzAtMB8GA1UdIwQYMBaAFPQN +tnCIBcG4ReQgoVi0kPgTROseMAoGA1UdFAQDAgEAMAoGCCqGSM49BAMCA0gAMEUC +IQDB9WEPBPHEo5xjCv8CT9okockJJnkLDOus6FypVLqj5QIgYw9/PYLwb41/Uc+4 +LLTAsfdDWh7xBJmqvVQglMoJOEc= +-----END X509 CRL----- diff --git a/security/advancedtls/testdata/crl/1ab871c8.r0 b/security/advancedtls/testdata/crl/1ab871c8.r0 new file mode 120000 index 00000000000..f2cd877e7ed --- /dev/null +++ b/security/advancedtls/testdata/crl/1ab871c8.r0 @@ -0,0 +1 @@ +2.crl \ No newline at end of file diff --git a/security/advancedtls/testdata/crl/2.crl b/security/advancedtls/testdata/crl/2.crl new file mode 100644 index 00000000000..5ca9afd7141 --- /dev/null +++ b/security/advancedtls/testdata/crl/2.crl @@ -0,0 +1,10 @@ +-----BEGIN X509 CRL----- +MIIBYDCCAQYCAQEwCgYIKoZIzj0EAwIwgaUxCzAJBgNVBAYTAlVTMRMwEQYDVQQI +EwpDYWxpZm9ybmlhMRYwFAYDVQQHEw1Nb3VudGFpbiBWaWV3MRMwEQYDVQQKEwpH +b29nbGUgTExDMSYwEQYDVQQLEwpQcm9kdWN0aW9uMBEGA1UECxMKY2FtcHVzLXNs +bjEsMCoGA1UEAxMjbm9kZSBDQSAoMjAyMS0wMi0wMlQwNzozMDozNi0wODowMCkX +DTIxMDIwMjE1MzAzNloXDTIxMDIwOTE1MzAzNlqgLzAtMB8GA1UdIwQYMBaAFBjo +V5Jnk/gp1k7fmWwkvTk/cF/IMAoGA1UdFAQDAgEAMAoGCCqGSM49BAMCA0gAMEUC +IQDgjA1Vj/pNFtNRL0vFEdapmFoArHM2+rn4IiP8jYLsCAIgAj2KEHbbtJ3zl5XP +WVW6ZyW7r3wIX+Bt3vLJWPrQtf8= +-----END X509 CRL----- diff --git a/security/advancedtls/testdata/crl/3.crl b/security/advancedtls/testdata/crl/3.crl new file mode 100644 index 00000000000..d37ad2247f5 --- /dev/null +++ b/security/advancedtls/testdata/crl/3.crl @@ -0,0 +1,11 @@ +-----BEGIN X509 CRL----- +MIIBiDCCAS8CAQEwCgYIKoZIzj0EAwIwgaUxCzAJBgNVBAYTAlVTMRMwEQYDVQQI +EwpDYWxpZm9ybmlhMRYwFAYDVQQHEw1Nb3VudGFpbiBWaWV3MRMwEQYDVQQKEwpH +b29nbGUgTExDMSYwEQYDVQQLEwpQcm9kdWN0aW9uMBEGA1UECxMKY2FtcHVzLXNs +bjEsMCoGA1UEAxMjUm9vdCBDQSAoMjAyMS0wMi0wMlQwNzozMTo1NC0wODowMCkX +DTIxMDIwMjE1MzE1NFoXDTIxMDIwOTE1MzE1NFowJzAlAhQAroEYW855BRqTrlov +5cBCGvkutxcNMjEwMjAyMTUzMTU0WqAvMC0wHwYDVR0jBBgwFoAUeq/TQ959KbWk +/um08jSTXogXpWUwCgYDVR0UBAMCAQEwCgYIKoZIzj0EAwIDRwAwRAIgaSOIhJDg +wOLYlbXkmxW0cqy/AfOUNYbz5D/8/FfvhosCICftg7Vzlu0Nh83jikyjy+wtkiJt +ZYNvGFQ3Sp2L3A9e +-----END X509 CRL----- diff --git a/security/advancedtls/testdata/crl/4.crl b/security/advancedtls/testdata/crl/4.crl new file mode 100644 index 00000000000..d4ee6f7cf18 --- /dev/null +++ b/security/advancedtls/testdata/crl/4.crl @@ -0,0 +1,10 @@ +-----BEGIN X509 CRL----- +MIIBYDCCAQYCAQEwCgYIKoZIzj0EAwIwgaUxCzAJBgNVBAYTAlVTMRMwEQYDVQQI +EwpDYWxpZm9ybmlhMRYwFAYDVQQHEw1Nb3VudGFpbiBWaWV3MRMwEQYDVQQKEwpH +b29nbGUgTExDMSYwEQYDVQQLEwpQcm9kdWN0aW9uMBEGA1UECxMKY2FtcHVzLXNs +bjEsMCoGA1UEAxMjbm9kZSBDQSAoMjAyMS0wMi0wMlQwNzozMTo1NC0wODowMCkX +DTIxMDIwMjE1MzE1NFoXDTIxMDIwOTE1MzE1NFqgLzAtMB8GA1UdIwQYMBaAFIVn +8tIFgZpIdhomgYJ2c5ULLzpSMAoGA1UdFAQDAgEAMAoGCCqGSM49BAMCA0gAMEUC +ICupTvOqgAyRa1nn7+Pe/1vvlJPAQ8gUfTQsQ6XX3v6oAiEA08B2PsK6aTEwzjry +pXqhlUNZFzgaXrVVQuEJbyJ1qoU= +-----END X509 CRL----- diff --git a/security/advancedtls/testdata/crl/5.crl b/security/advancedtls/testdata/crl/5.crl new file mode 100644 index 00000000000..d1c24f0f25a --- /dev/null +++ b/security/advancedtls/testdata/crl/5.crl @@ -0,0 +1,10 @@ +-----BEGIN X509 CRL----- +MIIBXzCCAQYCAQEwCgYIKoZIzj0EAwIwgaUxCzAJBgNVBAYTAlVTMRMwEQYDVQQI +EwpDYWxpZm9ybmlhMRYwFAYDVQQHEw1Nb3VudGFpbiBWaWV3MRMwEQYDVQQKEwpH +b29nbGUgTExDMSYwEQYDVQQLEwpQcm9kdWN0aW9uMBEGA1UECxMKY2FtcHVzLXNs +bjEsMCoGA1UEAxMjUm9vdCBDQSAoMjAyMS0wMi0wMlQwNzozMjo1Ny0wODowMCkX +DTIxMDIwMjE1MzI1N1oXDTIxMDIwOTE1MzI1N1qgLzAtMB8GA1UdIwQYMBaAFN+g +xTAtSTlb5Qqvrbp4rZtsaNzqMAoGA1UdFAQDAgEAMAoGCCqGSM49BAMCA0cAMEQC +IHrRKjieY7w7gxvpkJAdszPZBlaSSp/c9wILutBTy7SyAiAwhaHfgas89iRfaBs2 +EhGIeK39A+kSzqu6qEQBHpK36g== +-----END X509 CRL----- diff --git a/security/advancedtls/testdata/crl/6.crl b/security/advancedtls/testdata/crl/6.crl new file mode 100644 index 00000000000..87ef378f6ab --- /dev/null +++ b/security/advancedtls/testdata/crl/6.crl @@ -0,0 +1,11 @@ +-----BEGIN X509 CRL----- +MIIBiDCCAS8CAQEwCgYIKoZIzj0EAwIwgaUxCzAJBgNVBAYTAlVTMRMwEQYDVQQI +EwpDYWxpZm9ybmlhMRYwFAYDVQQHEw1Nb3VudGFpbiBWaWV3MRMwEQYDVQQKEwpH +b29nbGUgTExDMSYwEQYDVQQLEwpQcm9kdWN0aW9uMBEGA1UECxMKY2FtcHVzLXNs +bjEsMCoGA1UEAxMjbm9kZSBDQSAoMjAyMS0wMi0wMlQwNzozMjo1Ny0wODowMCkX +DTIxMDIwMjE1MzI1N1oXDTIxMDIwOTE1MzI1N1owJzAlAhQAxSe/pGmyvzN7mxm5 +6ZJTYUXYuhcNMjEwMjAyMTUzMjU3WqAvMC0wHwYDVR0jBBgwFoAUpZ30UJXB4lI9 +j2SzodCtRFckrRcwCgYDVR0UBAMCAQEwCgYIKoZIzj0EAwIDRwAwRAIgRg3u7t3b +oyV5FhMuGGzWnfIwnKclpT8imnp8tEN253sCIFUY7DjiDohwu4Zup3bWs1OaZ3q3 +cm+j0H/oe8zzCAgp +-----END X509 CRL----- diff --git a/security/advancedtls/testdata/crl/71eac5a2.r0 b/security/advancedtls/testdata/crl/71eac5a2.r0 new file mode 120000 index 00000000000..9f37924cae0 --- /dev/null +++ b/security/advancedtls/testdata/crl/71eac5a2.r0 @@ -0,0 +1 @@ +4.crl \ No newline at end of file diff --git a/security/advancedtls/testdata/crl/7a1799af.r0 b/security/advancedtls/testdata/crl/7a1799af.r0 new file mode 120000 index 00000000000..f34df5b59c0 --- /dev/null +++ b/security/advancedtls/testdata/crl/7a1799af.r0 @@ -0,0 +1 @@ +3.crl \ No newline at end of file diff --git a/security/advancedtls/testdata/crl/8828a7e6.r0 b/security/advancedtls/testdata/crl/8828a7e6.r0 new file mode 120000 index 00000000000..70bead214cc --- /dev/null +++ b/security/advancedtls/testdata/crl/8828a7e6.r0 @@ -0,0 +1 @@ +6.crl \ No newline at end of file diff --git a/security/advancedtls/testdata/crl/README.md b/security/advancedtls/testdata/crl/README.md new file mode 100644 index 00000000000..00cb09c3192 --- /dev/null +++ b/security/advancedtls/testdata/crl/README.md @@ -0,0 +1,48 @@ +# CRL Test Data + +This directory contains cert chains and CRL files for revocation testing. + +To print the chain, use a command like, + +```shell +openssl crl2pkcs7 -nocrl -certfile security/crl/x509/client/testdata/revokedLeaf.pem | openssl pkcs7 -print_certs -text -noout +``` + +The crl file symlinks are generated with `openssl rehash` + +## unrevoked.pem + +A certificate chain with CRL files and unrevoked certs + +* Subject: C=US, ST=California, L=Mountain View, O=Google LLC, OU=Production, + OU=campus-sln, CN=Root CA (2021-02-02T07:30:36-08:00) + * 1.crl + +NOTE: 1.crl file is symlinked with 5.crl to simulate two issuers that hash to +the same value to test that loading multiple files works. + +* Subject: C=US, ST=California, L=Mountain View, O=Google LLC, OU=Production, + OU=campus-sln, CN=node CA (2021-02-02T07:30:36-08:00) + * 2.crl + +## revokedInt.pem + +Certificate chain where the intermediate is revoked + +* Subject: C=US, ST=California, L=Mountain View, O=Google LLC, OU=Production, + OU=campus-sln, CN=Root CA (2021-02-02T07:31:54-08:00) + * 3.crl +* Subject: C=US, ST=California, L=Mountain View, O=Google LLC, OU=Production, + OU=campus-sln, CN=node CA (2021-02-02T07:31:54-08:00) + * 4.crl + +## revokedLeaf.pem + +Certificate chain where the leaf is revoked + +* Subject: C=US, ST=California, L=Mountain View, O=Google LLC, OU=Production, + OU=campus-sln, CN=Root CA (2021-02-02T07:32:57-08:00) + * 5.crl +* Subject: C=US, ST=California, L=Mountain View, O=Google LLC, OU=Production, + OU=campus-sln, CN=node CA (2021-02-02T07:32:57-08:00) + * 6.crl diff --git a/security/advancedtls/testdata/crl/deee447d.r0 b/security/advancedtls/testdata/crl/deee447d.r0 new file mode 120000 index 00000000000..1a84eabdfc7 --- /dev/null +++ b/security/advancedtls/testdata/crl/deee447d.r0 @@ -0,0 +1 @@ +5.crl \ No newline at end of file diff --git a/security/advancedtls/testdata/crl/revokedInt.pem b/security/advancedtls/testdata/crl/revokedInt.pem new file mode 100644 index 00000000000..8b7282ff822 --- /dev/null +++ b/security/advancedtls/testdata/crl/revokedInt.pem @@ -0,0 +1,58 @@ +-----BEGIN CERTIFICATE----- +MIIDAzCCAqmgAwIBAgITAWjKwm2dNQvkO62Jgyr5rAvVQzAKBggqhkjOPQQDAjCB +pTELMAkGA1UEBhMCVVMxEzARBgNVBAgTCkNhbGlmb3JuaWExFjAUBgNVBAcTDU1v +dW50YWluIFZpZXcxEzARBgNVBAoTCkdvb2dsZSBMTEMxJjARBgNVBAsTClByb2R1 +Y3Rpb24wEQYDVQQLEwpjYW1wdXMtc2xuMSwwKgYDVQQDEyNSb290IENBICgyMDIx +LTAyLTAyVDA3OjMxOjU0LTA4OjAwKTAgFw0yMTAyMDIxNTMxNTRaGA85OTk5MTIz +MTIzNTk1OVowgaUxCzAJBgNVBAYTAlVTMRMwEQYDVQQIEwpDYWxpZm9ybmlhMRYw +FAYDVQQHEw1Nb3VudGFpbiBWaWV3MRMwEQYDVQQKEwpHb29nbGUgTExDMSYwEQYD +VQQLEwpQcm9kdWN0aW9uMBEGA1UECxMKY2FtcHVzLXNsbjEsMCoGA1UEAxMjUm9v +dCBDQSAoMjAyMS0wMi0wMlQwNzozMTo1NC0wODowMCkwWTATBgcqhkjOPQIBBggq +hkjOPQMBBwNCAAQhA0/puhTtSxbVVHseVhL2z7QhpPyJs5Q4beKi7tpaYRDmVn6p +Phh+jbRzg8Qj4gKI/Q1rrdm4rKer63LHpdWdo4GzMIGwMA4GA1UdDwEB/wQEAwIB +BjAdBgNVHSUEFjAUBggrBgEFBQcDAQYIKwYBBQUHAwIwDwYDVR0TAQH/BAUwAwEB +/zAdBgNVHQ4EFgQUeq/TQ959KbWk/um08jSTXogXpWUwHwYDVR0jBBgwFoAUeq/T +Q959KbWk/um08jSTXogXpWUwLgYDVR0RBCcwJYYjc3BpZmZlOi8vY2FtcHVzLXNs +bi5wcm9kLmdvb2dsZS5jb20wCgYIKoZIzj0EAwIDSAAwRQIgOSQZvyDPQwVOWnpF +zWvI+DS2yXIj/2T2EOvJz2XgcK4CIQCL0mh/+DxLiO4zzbInKr0mxpGSxSeZCUk7 +1ZF7AeLlbw== +-----END CERTIFICATE----- +-----BEGIN CERTIFICATE----- +MIIDizCCAzKgAwIBAgIUAK6BGFvOeQUak65aL+XAQhr5LrcwCgYIKoZIzj0EAwIw +gaUxCzAJBgNVBAYTAlVTMRMwEQYDVQQIEwpDYWxpZm9ybmlhMRYwFAYDVQQHEw1N +b3VudGFpbiBWaWV3MRMwEQYDVQQKEwpHb29nbGUgTExDMSYwEQYDVQQLEwpQcm9k +dWN0aW9uMBEGA1UECxMKY2FtcHVzLXNsbjEsMCoGA1UEAxMjUm9vdCBDQSAoMjAy +MS0wMi0wMlQwNzozMTo1NC0wODowMCkwIBcNMjEwMjAyMTUzMTU0WhgPOTk5OTEy +MzEyMzU5NTlaMIGlMQswCQYDVQQGEwJVUzETMBEGA1UECBMKQ2FsaWZvcm5pYTEW +MBQGA1UEBxMNTW91bnRhaW4gVmlldzETMBEGA1UEChMKR29vZ2xlIExMQzEmMBEG +A1UECxMKUHJvZHVjdGlvbjARBgNVBAsTCmNhbXB1cy1zbG4xLDAqBgNVBAMTI25v +ZGUgQ0EgKDIwMjEtMDItMDJUMDc6MzE6NTQtMDg6MDApMFkwEwYHKoZIzj0CAQYI +KoZIzj0DAQcDQgAEye6UOlBos8Q3FFBiLahD9BaLTA18bO4MTPyv35T3lppvxD5X +U/AnEllOnx5OMtMjMBbIQjSkMbiQ9xNXoSqB6aOCATowggE2MA4GA1UdDwEB/wQE +AwIBBjAdBgNVHSUEFjAUBggrBgEFBQcDAQYIKwYBBQUHAwIwDwYDVR0TAQH/BAUw +AwEB/zAdBgNVHQ4EFgQUhWfy0gWBmkh2GiaBgnZzlQsvOlIwHwYDVR0jBBgwFoAU +eq/TQ959KbWk/um08jSTXogXpWUwMwYDVR0RBCwwKoYoc3BpZmZlOi8vbm9kZS5j +YW1wdXMtc2xuLnByb2QuZ29vZ2xlLmNvbTA7BgNVHR4BAf8EMTAvoC0wK4YpY3Nj +cy10ZWFtLm5vZGUuY2FtcHVzLXNsbi5wcm9kLmdvb2dsZS5jb20wQgYDVR0fBDsw +OTA3oDWgM4YxaHR0cDovL3N0YXRpYy5jb3JwLmdvb2dsZS5jb20vY3JsL2NhbXB1 +cy1zbG4vbm9kZTAKBggqhkjOPQQDAgNHADBEAiA79rPu6ZO1/0qB6RxL7jVz1200 +UTo8ioB4itbTzMnJqAIgJqp/Rc8OhpsfzQX8XnIIkl+SewT+tOxJT1MHVNMlVhc= +-----END CERTIFICATE----- +-----BEGIN CERTIFICATE----- +MIIC0DCCAnWgAwIBAgITXQ2c/C27OGqk4Pbu+MNJlOtpYTAKBggqhkjOPQQDAjCB +pTELMAkGA1UEBhMCVVMxEzARBgNVBAgTCkNhbGlmb3JuaWExFjAUBgNVBAcTDU1v +dW50YWluIFZpZXcxEzARBgNVBAoTCkdvb2dsZSBMTEMxJjARBgNVBAsTClByb2R1 +Y3Rpb24wEQYDVQQLEwpjYW1wdXMtc2xuMSwwKgYDVQQDEyNub2RlIENBICgyMDIx +LTAyLTAyVDA3OjMxOjU0LTA4OjAwKTAgFw0yMTAyMDIxNTMxNTRaGA85OTk5MTIz +MTIzNTk1OVowADBZMBMGByqGSM49AgEGCCqGSM49AwEHA0IABN2/1le5d3hS/piw +hrNMHjd7gPEjzXwtuXQTzdV+aaeOf3ldnC6OnEF/bggym9MldQSJZLXPYSaoj430 +Vu5PRNejggEkMIIBIDAOBgNVHQ8BAf8EBAMCB4AwHQYDVR0lBBYwFAYIKwYBBQUH +AwIGCCsGAQUFBwMBMB0GA1UdDgQWBBTEewP3JgrJPekWWGGjChVqaMhaqTAfBgNV +HSMEGDAWgBSFZ/LSBYGaSHYaJoGCdnOVCy86UjBrBgNVHREBAf8EYTBfghZqemFi +MTIucHJvZC5nb29nbGUuY29thkVzcGlmZmU6Ly9jc2NzLXRlYW0ubm9kZS5jYW1w +dXMtc2xuLnByb2QuZ29vZ2xlLmNvbS9yb2xlL2JvcmctYWRtaW4tY28wQgYDVR0f +BDswOTA3oDWgM4YxaHR0cDovL3N0YXRpYy5jb3JwLmdvb2dsZS5jb20vY3JsL2Nh +bXB1cy1zbG4vbm9kZTAKBggqhkjOPQQDAgNJADBGAiEA9w4qp3nHpXo+6d7mZc69 +QoALfP5ynfBCArt8bAlToo8CIQCgc/lTfl2BtBko+7h/w6pKxLeuoQkvCL5gHFyK +LXE6vA== +-----END CERTIFICATE----- diff --git a/security/advancedtls/testdata/crl/revokedLeaf.pem b/security/advancedtls/testdata/crl/revokedLeaf.pem new file mode 100644 index 00000000000..b7541abf621 --- /dev/null +++ b/security/advancedtls/testdata/crl/revokedLeaf.pem @@ -0,0 +1,59 @@ +-----BEGIN CERTIFICATE----- +MIIDAzCCAqmgAwIBAgITTwodm6C4ZabFVUVa5yBw0TbzJTAKBggqhkjOPQQDAjCB +pTELMAkGA1UEBhMCVVMxEzARBgNVBAgTCkNhbGlmb3JuaWExFjAUBgNVBAcTDU1v +dW50YWluIFZpZXcxEzARBgNVBAoTCkdvb2dsZSBMTEMxJjARBgNVBAsTClByb2R1 +Y3Rpb24wEQYDVQQLEwpjYW1wdXMtc2xuMSwwKgYDVQQDEyNSb290IENBICgyMDIx +LTAyLTAyVDA3OjMyOjU3LTA4OjAwKTAgFw0yMTAyMDIxNTMyNTdaGA85OTk5MTIz +MTIzNTk1OVowgaUxCzAJBgNVBAYTAlVTMRMwEQYDVQQIEwpDYWxpZm9ybmlhMRYw +FAYDVQQHEw1Nb3VudGFpbiBWaWV3MRMwEQYDVQQKEwpHb29nbGUgTExDMSYwEQYD +VQQLEwpQcm9kdWN0aW9uMBEGA1UECxMKY2FtcHVzLXNsbjEsMCoGA1UEAxMjUm9v +dCBDQSAoMjAyMS0wMi0wMlQwNzozMjo1Ny0wODowMCkwWTATBgcqhkjOPQIBBggq +hkjOPQMBBwNCAARoZnzQWvAoyhvCLA2cFIK17khSaA9aA+flS5X9fLRt4RsfPCx3 +kim7wYKQSmBhQdc1UM4h3969r1c1Fvsh2H9qo4GzMIGwMA4GA1UdDwEB/wQEAwIB +BjAdBgNVHSUEFjAUBggrBgEFBQcDAQYIKwYBBQUHAwIwDwYDVR0TAQH/BAUwAwEB +/zAdBgNVHQ4EFgQU36DFMC1JOVvlCq+tunitm2xo3OowHwYDVR0jBBgwFoAU36DF +MC1JOVvlCq+tunitm2xo3OowLgYDVR0RBCcwJYYjc3BpZmZlOi8vY2FtcHVzLXNs +bi5wcm9kLmdvb2dsZS5jb20wCgYIKoZIzj0EAwIDSAAwRQIgN7S9dQOQzNih92ag +7c5uQxuz+M6wnxWj/uwGQIIghRUCIQD2UDH6kkRSYQuyP0oN7XYO3XFjmZ2Yer6m +1ZS8fyWYYA== +-----END CERTIFICATE----- +-----BEGIN CERTIFICATE----- +MIIDjTCCAzKgAwIBAgIUAOmArBu9gihLTlqP3W7Et0UoocEwCgYIKoZIzj0EAwIw +gaUxCzAJBgNVBAYTAlVTMRMwEQYDVQQIEwpDYWxpZm9ybmlhMRYwFAYDVQQHEw1N +b3VudGFpbiBWaWV3MRMwEQYDVQQKEwpHb29nbGUgTExDMSYwEQYDVQQLEwpQcm9k +dWN0aW9uMBEGA1UECxMKY2FtcHVzLXNsbjEsMCoGA1UEAxMjUm9vdCBDQSAoMjAy +MS0wMi0wMlQwNzozMjo1Ny0wODowMCkwIBcNMjEwMjAyMTUzMjU3WhgPOTk5OTEy +MzEyMzU5NTlaMIGlMQswCQYDVQQGEwJVUzETMBEGA1UECBMKQ2FsaWZvcm5pYTEW +MBQGA1UEBxMNTW91bnRhaW4gVmlldzETMBEGA1UEChMKR29vZ2xlIExMQzEmMBEG +A1UECxMKUHJvZHVjdGlvbjARBgNVBAsTCmNhbXB1cy1zbG4xLDAqBgNVBAMTI25v +ZGUgQ0EgKDIwMjEtMDItMDJUMDc6MzI6NTctMDg6MDApMFkwEwYHKoZIzj0CAQYI +KoZIzj0DAQcDQgAEfrgVEVQfSEFeCF1/FGeW7oq0yxecenT1BESfj4Z0zJ8p7P9W +bj1o6Rn6dUNlEhGrx7E3/4NFJ0cL1BSNGHkjiqOCATowggE2MA4GA1UdDwEB/wQE +AwIBBjAdBgNVHSUEFjAUBggrBgEFBQcDAQYIKwYBBQUHAwIwDwYDVR0TAQH/BAUw +AwEB/zAdBgNVHQ4EFgQUpZ30UJXB4lI9j2SzodCtRFckrRcwHwYDVR0jBBgwFoAU +36DFMC1JOVvlCq+tunitm2xo3OowMwYDVR0RBCwwKoYoc3BpZmZlOi8vbm9kZS5j +YW1wdXMtc2xuLnByb2QuZ29vZ2xlLmNvbTA7BgNVHR4BAf8EMTAvoC0wK4YpY3Nj +cy10ZWFtLm5vZGUuY2FtcHVzLXNsbi5wcm9kLmdvb2dsZS5jb20wQgYDVR0fBDsw +OTA3oDWgM4YxaHR0cDovL3N0YXRpYy5jb3JwLmdvb2dsZS5jb20vY3JsL2NhbXB1 +cy1zbG4vbm9kZTAKBggqhkjOPQQDAgNJADBGAiEAnuONgMqmbBlj4ibw5BgDtZUM +pboACSFJtEOJu4Yqjt0CIQDI5193J4wUcAY0BK0vO9rRfbNOIc+4ke9ieBDPSuhm +mA== +-----END CERTIFICATE----- +-----BEGIN CERTIFICATE----- +MIICzzCCAnagAwIBAgIUAMUnv6Rpsr8ze5sZuemSU2FF2LowCgYIKoZIzj0EAwIw +gaUxCzAJBgNVBAYTAlVTMRMwEQYDVQQIEwpDYWxpZm9ybmlhMRYwFAYDVQQHEw1N +b3VudGFpbiBWaWV3MRMwEQYDVQQKEwpHb29nbGUgTExDMSYwEQYDVQQLEwpQcm9k +dWN0aW9uMBEGA1UECxMKY2FtcHVzLXNsbjEsMCoGA1UEAxMjbm9kZSBDQSAoMjAy +MS0wMi0wMlQwNzozMjo1Ny0wODowMCkwIBcNMjEwMjAyMTUzMjU3WhgPOTk5OTEy +MzEyMzU5NTlaMAAwWTATBgcqhkjOPQIBBggqhkjOPQMBBwNCAASCmYiIHUux5WFz +S0ksJzAPL7YTEh5o5MdXgLPB/WM6x9sVsQDSYU0PF5qc9vPNhkQzGBW79dkBnxhW +AGJkFr1Po4IBJDCCASAwDgYDVR0PAQH/BAQDAgeAMB0GA1UdJQQWMBQGCCsGAQUF +BwMCBggrBgEFBQcDATAdBgNVHQ4EFgQUCR1CGEdlks0qcxCExO0rP1B/Z7UwHwYD +VR0jBBgwFoAUpZ30UJXB4lI9j2SzodCtRFckrRcwawYDVR0RAQH/BGEwX4IWanph +YjEyLnByb2QuZ29vZ2xlLmNvbYZFc3BpZmZlOi8vY3Njcy10ZWFtLm5vZGUuY2Ft +cHVzLXNsbi5wcm9kLmdvb2dsZS5jb20vcm9sZS9ib3JnLWFkbWluLWNvMEIGA1Ud +HwQ7MDkwN6A1oDOGMWh0dHA6Ly9zdGF0aWMuY29ycC5nb29nbGUuY29tL2NybC9j +YW1wdXMtc2xuL25vZGUwCgYIKoZIzj0EAwIDRwAwRAIgK9vQYNoL8HlEwWv89ioG +aQ1+8swq6Bo/5mJBrdVLvY8CIGxo6M9vJkPdObmetWNC+lmKuZDoqJWI0AAmBT2J +mR2r +-----END CERTIFICATE----- diff --git a/security/advancedtls/testdata/crl/unrevoked.pem b/security/advancedtls/testdata/crl/unrevoked.pem new file mode 100644 index 00000000000..5c5fc58a7a5 --- /dev/null +++ b/security/advancedtls/testdata/crl/unrevoked.pem @@ -0,0 +1,58 @@ +-----BEGIN CERTIFICATE----- +MIIDBDCCAqqgAwIBAgIUALy864QhnkTdceLH52k2XVOe8IQwCgYIKoZIzj0EAwIw +gaUxCzAJBgNVBAYTAlVTMRMwEQYDVQQIEwpDYWxpZm9ybmlhMRYwFAYDVQQHEw1N +b3VudGFpbiBWaWV3MRMwEQYDVQQKEwpHb29nbGUgTExDMSYwEQYDVQQLEwpQcm9k +dWN0aW9uMBEGA1UECxMKY2FtcHVzLXNsbjEsMCoGA1UEAxMjUm9vdCBDQSAoMjAy +MS0wMi0wMlQwNzozMDozNi0wODowMCkwIBcNMjEwMjAyMTUzMDM2WhgPOTk5OTEy +MzEyMzU5NTlaMIGlMQswCQYDVQQGEwJVUzETMBEGA1UECBMKQ2FsaWZvcm5pYTEW +MBQGA1UEBxMNTW91bnRhaW4gVmlldzETMBEGA1UEChMKR29vZ2xlIExMQzEmMBEG +A1UECxMKUHJvZHVjdGlvbjARBgNVBAsTCmNhbXB1cy1zbG4xLDAqBgNVBAMTI1Jv +b3QgQ0EgKDIwMjEtMDItMDJUMDc6MzA6MzYtMDg6MDApMFkwEwYHKoZIzj0CAQYI +KoZIzj0DAQcDQgAEYv/JS5hQ5kIgdKqYZWTKCO/6gloHAmIb1G8lmY0oXLXYNHQ4 +qHN7/pPtlcHQp0WK/hM8IGvgOUDoynA8mj0H9KOBszCBsDAOBgNVHQ8BAf8EBAMC +AQYwHQYDVR0lBBYwFAYIKwYBBQUHAwEGCCsGAQUFBwMCMA8GA1UdEwEB/wQFMAMB +Af8wHQYDVR0OBBYEFPQNtnCIBcG4ReQgoVi0kPgTROseMB8GA1UdIwQYMBaAFPQN +tnCIBcG4ReQgoVi0kPgTROseMC4GA1UdEQQnMCWGI3NwaWZmZTovL2NhbXB1cy1z +bG4ucHJvZC5nb29nbGUuY29tMAoGCCqGSM49BAMCA0gAMEUCIQDwBn20DB4X/7Uk +Q5BR8JxQYUPxOfvuedjfeA8bPvQ2FwIgOEWa0cXJs1JxarILJeCXtdXvBgu6LEGQ +3Pk/bgz8Gek= +-----END CERTIFICATE----- +-----BEGIN CERTIFICATE----- +MIIDizCCAzKgAwIBAgIUAM/6RKQ7Vke0i4xp5LaAqV73cmIwCgYIKoZIzj0EAwIw +gaUxCzAJBgNVBAYTAlVTMRMwEQYDVQQIEwpDYWxpZm9ybmlhMRYwFAYDVQQHEw1N +b3VudGFpbiBWaWV3MRMwEQYDVQQKEwpHb29nbGUgTExDMSYwEQYDVQQLEwpQcm9k +dWN0aW9uMBEGA1UECxMKY2FtcHVzLXNsbjEsMCoGA1UEAxMjUm9vdCBDQSAoMjAy +MS0wMi0wMlQwNzozMDozNi0wODowMCkwIBcNMjEwMjAyMTUzMDM2WhgPOTk5OTEy +MzEyMzU5NTlaMIGlMQswCQYDVQQGEwJVUzETMBEGA1UECBMKQ2FsaWZvcm5pYTEW +MBQGA1UEBxMNTW91bnRhaW4gVmlldzETMBEGA1UEChMKR29vZ2xlIExMQzEmMBEG +A1UECxMKUHJvZHVjdGlvbjARBgNVBAsTCmNhbXB1cy1zbG4xLDAqBgNVBAMTI25v +ZGUgQ0EgKDIwMjEtMDItMDJUMDc6MzA6MzYtMDg6MDApMFkwEwYHKoZIzj0CAQYI +KoZIzj0DAQcDQgAEllnhxmMYiUPUgRGmenbnm10gXpM94zHx3D1/HumPs6arjYuT +Zlhx81XL+g4bu4HII2qcGdP+Hqj/MMFNDI9z4aOCATowggE2MA4GA1UdDwEB/wQE +AwIBBjAdBgNVHSUEFjAUBggrBgEFBQcDAQYIKwYBBQUHAwIwDwYDVR0TAQH/BAUw +AwEB/zAdBgNVHQ4EFgQUGOhXkmeT+CnWTt+ZbCS9OT9wX8gwHwYDVR0jBBgwFoAU +9A22cIgFwbhF5CChWLSQ+BNE6x4wMwYDVR0RBCwwKoYoc3BpZmZlOi8vbm9kZS5j +YW1wdXMtc2xuLnByb2QuZ29vZ2xlLmNvbTA7BgNVHR4BAf8EMTAvoC0wK4YpY3Nj +cy10ZWFtLm5vZGUuY2FtcHVzLXNsbi5wcm9kLmdvb2dsZS5jb20wQgYDVR0fBDsw +OTA3oDWgM4YxaHR0cDovL3N0YXRpYy5jb3JwLmdvb2dsZS5jb20vY3JsL2NhbXB1 +cy1zbG4vbm9kZTAKBggqhkjOPQQDAgNHADBEAiA86egqPw0qyapAeMGbHxrmYZYa +i5ARQsSKRmQixgYizQIgW+2iRWN6Kbqt4WcwpmGv/xDckdRXakF5Ign/WUDO5u4= +-----END CERTIFICATE----- +-----BEGIN CERTIFICATE----- +MIICzzCCAnWgAwIBAgITYjjKfYZUKQNUjNyF+hLDGpHJKTAKBggqhkjOPQQDAjCB +pTELMAkGA1UEBhMCVVMxEzARBgNVBAgTCkNhbGlmb3JuaWExFjAUBgNVBAcTDU1v +dW50YWluIFZpZXcxEzARBgNVBAoTCkdvb2dsZSBMTEMxJjARBgNVBAsTClByb2R1 +Y3Rpb24wEQYDVQQLEwpjYW1wdXMtc2xuMSwwKgYDVQQDEyNub2RlIENBICgyMDIx +LTAyLTAyVDA3OjMwOjM2LTA4OjAwKTAgFw0yMTAyMDIxNTMwMzZaGA85OTk5MTIz +MTIzNTk1OVowADBZMBMGByqGSM49AgEGCCqGSM49AwEHA0IABD4r4+nCgZExYF8v +CLvGn0lY/cmam8mAkJDXRN2Ja2t+JwaTOptPmbbXft+1NTk5gCg5wB+FJCnaV3I/ +HaxEhBWjggEkMIIBIDAOBgNVHQ8BAf8EBAMCB4AwHQYDVR0lBBYwFAYIKwYBBQUH +AwIGCCsGAQUFBwMBMB0GA1UdDgQWBBTTCjXX1Txjc00tBg/5cFzpeCSKuDAfBgNV +HSMEGDAWgBQY6FeSZ5P4KdZO35lsJL05P3BfyDBrBgNVHREBAf8EYTBfghZqemFi +MTIucHJvZC5nb29nbGUuY29thkVzcGlmZmU6Ly9jc2NzLXRlYW0ubm9kZS5jYW1w +dXMtc2xuLnByb2QuZ29vZ2xlLmNvbS9yb2xlL2JvcmctYWRtaW4tY28wQgYDVR0f +BDswOTA3oDWgM4YxaHR0cDovL3N0YXRpYy5jb3JwLmdvb2dsZS5jb20vY3JsL2Nh +bXB1cy1zbG4vbm9kZTAKBggqhkjOPQQDAgNIADBFAiBq3URViNyMLpvzZHC1Y+4L ++35guyIJfjHu08P3S8/xswIhAJtWSQ1ZtozdOzGxg7GfUo4hR+5SP6rBTgIqXEfq +48fW +-----END CERTIFICATE----- diff --git a/security/authorization/go.mod b/security/authorization/go.mod index 0581b3401f3..ce34742af2c 100644 --- a/security/authorization/go.mod +++ b/security/authorization/go.mod @@ -1,6 +1,6 @@ module google.golang.org/grpc/security/authorization -go 1.12 +go 1.14 require ( github.com/envoyproxy/go-control-plane v0.9.5 diff --git a/security/authorization/go.sum b/security/authorization/go.sum index a953711e01e..3c7ea6cf47f 100644 --- a/security/authorization/go.sum +++ b/security/authorization/go.sum @@ -14,14 +14,11 @@ github.com/envoyproxy/go-control-plane v0.9.5 h1:lRJIqDD8yjV1YyPRqecMdytjDLs2fTX github.com/envoyproxy/go-control-plane v0.9.5/go.mod h1:OXl5to++W0ctG+EHWTFUjiypVxC/Y4VLc/KFU+al13s= github.com/envoyproxy/protoc-gen-validate v0.1.0 h1:EQciDnbrYxy13PgWoY8AqoxGiPrpgBZ1R8UNe3ddc+A= github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= -github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b h1:VKtxabqXZkF25pY9ekfRL6a582T4P37/31XEstQ5p58= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= -github.com/golang/protobuf v1.3.2 h1:6nsPYzhq5kReh6QImI3k5qWzO4PEbvbIW2cwSfR/6xs= github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= -github.com/golang/protobuf v1.3.4 h1:87PNWwrRvUSnqS4dlcBU/ftvOIBep4sYuBLlh6rX2wk= github.com/golang/protobuf v1.3.4/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA= @@ -36,7 +33,6 @@ github.com/google/cel-spec v0.4.0/go.mod h1:2pBM5cU4UKjbPDXBgwWkiwBsVgnxknuEJ7C5 github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= -github.com/google/go-cmp v0.4.0 h1:xsAVV57WRhGj6kEIi8ReJzQlHHqcBYCElAvkovg3B/4= github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.0 h1:/QaMHBdZ26BB3SSst0Iwl10Epc+xhTquomWX0oZEB6w= github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= @@ -49,7 +45,6 @@ golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHl golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20190311183353-d8887717615a h1:oWX7TPOiFAMXLq8o0ikBYfCJVlRHBcsciT5bXOrH628= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20200301022130-244492dfa37a h1:GuSPYbZzB5/dcLNCwLQLsg3obCJtX9IJhpXkvY7kzk0= golang.org/x/net v0.0.0-20200301022130-244492dfa37a/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= @@ -58,11 +53,9 @@ golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a h1:1BGLXjeY4akVXGgbC9HugT3Jv3hCI0z56oJR5vAMgBU= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20200302150141-5c8b2ff67527 h1:uYVVQ9WP/Ds2ROhcaGPeIdVq0RIXVLwsHlnvJ+cT1So= golang.org/x/sys v0.0.0-20200302150141-5c8b2ff67527/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= @@ -76,18 +69,14 @@ golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8T google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= -google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55 h1:gSJIx1SDwno+2ElGhA4+qG2zF97qiUzTM+rQ0klBOcE= google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= -google.golang.org/genproto v0.0.0-20200305110556-506484158171 h1:xes2Q2k+d/+YNXVw0FpZkIDJiaux4OVrRKXRAzH6A0U= google.golang.org/genproto v0.0.0-20200305110556-506484158171/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013 h1:+kGHl1aib/qcwaRi1CbqBZ1rk19r85MNUf8HaBghugY= google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo= google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= -google.golang.org/grpc v1.25.1 h1:wdKvqQk7IttEw92GoRyKG2IDrUIpgpj6H6m81yfeMW0= google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY= google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= -google.golang.org/grpc v1.27.1 h1:zvIju4sqAGvwKspUQOhwnpcqSbzi7/H6QomNNjTL4sk= google.golang.org/grpc v1.27.1/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= google.golang.org/grpc v1.31.0 h1:T7P4R73V3SSDPhH7WW7ATbfViLtmamH0DKrP3f9AuDI= google.golang.org/grpc v1.31.0/go.mod h1:N36X2cJ7JwdamYAgDz+s+rVMFjt3numwzf/HckM8pak= diff --git a/server.go b/server.go index 7a2aa28a114..eadf9e05fd1 100644 --- a/server.go +++ b/server.go @@ -57,12 +57,22 @@ import ( const ( defaultServerMaxReceiveMessageSize = 1024 * 1024 * 4 defaultServerMaxSendMessageSize = math.MaxInt32 + + // Server transports are tracked in a map which is keyed on listener + // address. For regular gRPC traffic, connections are accepted in Serve() + // through a call to Accept(), and we use the actual listener address as key + // when we add it to the map. But for connections received through + // ServeHTTP(), we do not have a listener and hence use this dummy value. + listenerAddressForServeHTTP = "listenerAddressForServeHTTP" ) func init() { internal.GetServerCredentials = func(srv *Server) credentials.TransportCredentials { return srv.opts.creds } + internal.DrainServerTransports = func(srv *Server, addr string) { + srv.drainServerTransports(addr) + } } var statusOK = status.New(codes.OK, "") @@ -107,9 +117,12 @@ type serverWorkerData struct { type Server struct { opts serverOptions - mu sync.Mutex // guards following - lis map[net.Listener]bool - conns map[transport.ServerTransport]bool + mu sync.Mutex // guards following + lis map[net.Listener]bool + // conns contains all active server transports. It is a map keyed on a + // listener address with the value being the set of active transports + // belonging to that listener. + conns map[string]map[transport.ServerTransport]bool serve bool drain bool cv *sync.Cond // signaled when connections close for GracefulStop @@ -266,6 +279,35 @@ func CustomCodec(codec Codec) ServerOption { }) } +// ForceServerCodec returns a ServerOption that sets a codec for message +// marshaling and unmarshaling. +// +// This will override any lookups by content-subtype for Codecs registered +// with RegisterCodec. +// +// See Content-Type on +// https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#requests for +// more details. Also see the documentation on RegisterCodec and +// CallContentSubtype for more details on the interaction between encoding.Codec +// and content-subtype. +// +// This function is provided for advanced users; prefer to register codecs +// using encoding.RegisterCodec. +// The server will automatically use registered codecs based on the incoming +// requests' headers. See also +// https://github.com/grpc/grpc-go/blob/master/Documentation/encoding.md#using-a-codec. +// Will be supported throughout 1.x. +// +// Experimental +// +// Notice: This API is EXPERIMENTAL and may be changed or removed in a +// later release. +func ForceServerCodec(codec encoding.Codec) ServerOption { + return newFuncServerOption(func(o *serverOptions) { + o.codec = codec + }) +} + // RPCCompressor returns a ServerOption that sets a compressor for outbound // messages. For backward compatibility, all outbound messages will be sent // using this compressor, regardless of incoming message compression. By @@ -376,6 +418,11 @@ func ChainStreamInterceptor(interceptors ...StreamServerInterceptor) ServerOptio // InTapHandle returns a ServerOption that sets the tap handle for all the server // transport to be created. Only one can be installed. +// +// Experimental +// +// Notice: This API is EXPERIMENTAL and may be changed or removed in a +// later release. func InTapHandle(h tap.ServerInHandle) ServerOption { return newFuncServerOption(func(o *serverOptions) { if o.inTapHandle != nil { @@ -519,7 +566,7 @@ func NewServer(opt ...ServerOption) *Server { s := &Server{ lis: make(map[net.Listener]bool), opts: opts, - conns: make(map[transport.ServerTransport]bool), + conns: make(map[string]map[transport.ServerTransport]bool), services: make(map[string]*serviceInfo), quit: grpcsync.NewEvent(), done: grpcsync.NewEvent(), @@ -663,13 +710,6 @@ func (s *Server) GetServiceInfo() map[string]ServiceInfo { // the server being stopped. var ErrServerStopped = errors.New("grpc: the server has been stopped") -func (s *Server) useTransportAuthenticator(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { - if s.opts.creds == nil { - return rawConn, nil, nil - } - return s.opts.creds.ServerHandshake(rawConn) -} - type listenSocket struct { net.Listener channelzID int64 @@ -778,7 +818,7 @@ func (s *Server) Serve(lis net.Listener) error { // s.conns before this conn can be added. s.serveWG.Add(1) go func() { - s.handleRawConn(rawConn) + s.handleRawConn(lis.Addr().String(), rawConn) s.serveWG.Done() }() } @@ -786,49 +826,45 @@ func (s *Server) Serve(lis net.Listener) error { // handleRawConn forks a goroutine to handle a just-accepted connection that // has not had any I/O performed on it yet. -func (s *Server) handleRawConn(rawConn net.Conn) { +func (s *Server) handleRawConn(lisAddr string, rawConn net.Conn) { if s.quit.HasFired() { rawConn.Close() return } rawConn.SetDeadline(time.Now().Add(s.opts.connectionTimeout)) - conn, authInfo, err := s.useTransportAuthenticator(rawConn) - if err != nil { - // ErrConnDispatched means that the connection was dispatched away from - // gRPC; those connections should be left open. - if err != credentials.ErrConnDispatched { - s.mu.Lock() - s.errorf("ServerHandshake(%q) failed: %v", rawConn.RemoteAddr(), err) - s.mu.Unlock() - channelz.Warningf(logger, s.channelzID, "grpc: Server.Serve failed to complete security handshake from %q: %v", rawConn.RemoteAddr(), err) - rawConn.Close() - } - rawConn.SetDeadline(time.Time{}) - return - } // Finish handshaking (HTTP2) - st := s.newHTTP2Transport(conn, authInfo) + st := s.newHTTP2Transport(rawConn) + rawConn.SetDeadline(time.Time{}) if st == nil { return } - rawConn.SetDeadline(time.Time{}) - if !s.addConn(st) { + if !s.addConn(lisAddr, st) { return } go func() { s.serveStreams(st) - s.removeConn(st) + s.removeConn(lisAddr, st) }() } +func (s *Server) drainServerTransports(addr string) { + s.mu.Lock() + conns := s.conns[addr] + for st := range conns { + st.Drain() + } + s.mu.Unlock() +} + // newHTTP2Transport sets up a http/2 transport (using the // gRPC http2 server transport in transport/http2_server.go). -func (s *Server) newHTTP2Transport(c net.Conn, authInfo credentials.AuthInfo) transport.ServerTransport { +func (s *Server) newHTTP2Transport(c net.Conn) transport.ServerTransport { config := &transport.ServerConfig{ MaxStreams: s.opts.maxConcurrentStreams, - AuthInfo: authInfo, + ConnectionTimeout: s.opts.connectionTimeout, + Credentials: s.opts.creds, InTapHandle: s.opts.inTapHandle, StatsHandler: s.opts.statsHandler, KeepaliveParams: s.opts.keepaliveParams, @@ -841,13 +877,20 @@ func (s *Server) newHTTP2Transport(c net.Conn, authInfo credentials.AuthInfo) tr MaxHeaderListSize: s.opts.maxHeaderListSize, HeaderTableSize: s.opts.headerTableSize, } - st, err := transport.NewServerTransport("http2", c, config) + st, err := transport.NewServerTransport(c, config) if err != nil { s.mu.Lock() s.errorf("NewServerTransport(%q) failed: %v", c.RemoteAddr(), err) s.mu.Unlock() - c.Close() - channelz.Warning(logger, s.channelzID, "grpc: Server.Serve failed to create ServerTransport: ", err) + // ErrConnDispatched means that the connection was dispatched away from + // gRPC; those connections should be left open. + if err != credentials.ErrConnDispatched { + // Don't log on ErrConnDispatched and io.EOF to prevent log spam. + if err != io.EOF { + channelz.Warning(logger, s.channelzID, "grpc: Server.Serve failed to create ServerTransport: ", err) + } + c.Close() + } return nil } @@ -924,10 +967,10 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { http.Error(w, err.Error(), http.StatusInternalServerError) return } - if !s.addConn(st) { + if !s.addConn(listenerAddressForServeHTTP, st) { return } - defer s.removeConn(st) + defer s.removeConn(listenerAddressForServeHTTP, st) s.serveStreams(st) } @@ -955,7 +998,7 @@ func (s *Server) traceInfo(st transport.ServerTransport, stream *transport.Strea return trInfo } -func (s *Server) addConn(st transport.ServerTransport) bool { +func (s *Server) addConn(addr string, st transport.ServerTransport) bool { s.mu.Lock() defer s.mu.Unlock() if s.conns == nil { @@ -967,15 +1010,28 @@ func (s *Server) addConn(st transport.ServerTransport) bool { // immediately. st.Drain() } - s.conns[st] = true + + if s.conns[addr] == nil { + // Create a map entry if this is the first connection on this listener. + s.conns[addr] = make(map[transport.ServerTransport]bool) + } + s.conns[addr][st] = true return true } -func (s *Server) removeConn(st transport.ServerTransport) { +func (s *Server) removeConn(addr string, st transport.ServerTransport) { s.mu.Lock() defer s.mu.Unlock() - if s.conns != nil { - delete(s.conns, st) + + conns := s.conns[addr] + if conns != nil { + delete(conns, st) + if len(conns) == 0 { + // If the last connection for this address is being removed, also + // remove the map entry corresponding to the address. This is used + // in GracefulStop() when waiting for all connections to be closed. + delete(s.conns, addr) + } s.cv.Broadcast() } } @@ -1040,22 +1096,29 @@ func chainUnaryServerInterceptors(s *Server) { } else if len(interceptors) == 1 { chainedInt = interceptors[0] } else { - chainedInt = func(ctx context.Context, req interface{}, info *UnaryServerInfo, handler UnaryHandler) (interface{}, error) { - return interceptors[0](ctx, req, info, getChainUnaryHandler(interceptors, 0, info, handler)) - } + chainedInt = chainUnaryInterceptors(interceptors) } s.opts.unaryInt = chainedInt } -// getChainUnaryHandler recursively generate the chained UnaryHandler -func getChainUnaryHandler(interceptors []UnaryServerInterceptor, curr int, info *UnaryServerInfo, finalHandler UnaryHandler) UnaryHandler { - if curr == len(interceptors)-1 { - return finalHandler - } - - return func(ctx context.Context, req interface{}) (interface{}, error) { - return interceptors[curr+1](ctx, req, info, getChainUnaryHandler(interceptors, curr+1, info, finalHandler)) +func chainUnaryInterceptors(interceptors []UnaryServerInterceptor) UnaryServerInterceptor { + return func(ctx context.Context, req interface{}, info *UnaryServerInfo, handler UnaryHandler) (interface{}, error) { + // the struct ensures the variables are allocated together, rather than separately, since we + // know they should be garbage collected together. This saves 1 allocation and decreases + // time/call by about 10% on the microbenchmark. + var state struct { + i int + next UnaryHandler + } + state.next = func(ctx context.Context, req interface{}) (interface{}, error) { + if state.i == len(interceptors)-1 { + return interceptors[state.i](ctx, req, info, handler) + } + state.i++ + return interceptors[state.i-1](ctx, req, info, state.next) + } + return state.next(ctx, req) } } @@ -1069,7 +1132,9 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. if sh != nil { beginTime := time.Now() statsBegin = &stats.Begin{ - BeginTime: beginTime, + BeginTime: beginTime, + IsClientStream: false, + IsServerStream: false, } sh.HandleRPC(stream.Context(), statsBegin) } @@ -1321,22 +1386,29 @@ func chainStreamServerInterceptors(s *Server) { } else if len(interceptors) == 1 { chainedInt = interceptors[0] } else { - chainedInt = func(srv interface{}, ss ServerStream, info *StreamServerInfo, handler StreamHandler) error { - return interceptors[0](srv, ss, info, getChainStreamHandler(interceptors, 0, info, handler)) - } + chainedInt = chainStreamInterceptors(interceptors) } s.opts.streamInt = chainedInt } -// getChainStreamHandler recursively generate the chained StreamHandler -func getChainStreamHandler(interceptors []StreamServerInterceptor, curr int, info *StreamServerInfo, finalHandler StreamHandler) StreamHandler { - if curr == len(interceptors)-1 { - return finalHandler - } - - return func(srv interface{}, ss ServerStream) error { - return interceptors[curr+1](srv, ss, info, getChainStreamHandler(interceptors, curr+1, info, finalHandler)) +func chainStreamInterceptors(interceptors []StreamServerInterceptor) StreamServerInterceptor { + return func(srv interface{}, ss ServerStream, info *StreamServerInfo, handler StreamHandler) error { + // the struct ensures the variables are allocated together, rather than separately, since we + // know they should be garbage collected together. This saves 1 allocation and decreases + // time/call by about 10% on the microbenchmark. + var state struct { + i int + next StreamHandler + } + state.next = func(srv interface{}, ss ServerStream) error { + if state.i == len(interceptors)-1 { + return interceptors[state.i](srv, ss, info, handler) + } + state.i++ + return interceptors[state.i-1](srv, ss, info, state.next) + } + return state.next(srv, ss) } } @@ -1349,7 +1421,9 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp if sh != nil { beginTime := time.Now() statsBegin = &stats.Begin{ - BeginTime: beginTime, + BeginTime: beginTime, + IsClientStream: sd.ClientStreams, + IsServerStream: sd.ServerStreams, } sh.HandleRPC(stream.Context(), statsBegin) } @@ -1452,6 +1526,8 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp } } + ss.ctx = newContextWithRPCInfo(ss.ctx, false, ss.codec, ss.cp, ss.comp) + if trInfo != nil { trInfo.tr.LazyLog(&trInfo.firstLine, false) } @@ -1519,7 +1595,7 @@ func (s *Server) handleStream(t transport.ServerTransport, stream *transport.Str trInfo.tr.SetError() } errDesc := fmt.Sprintf("malformed method name: %q", stream.Method()) - if err := t.WriteStatus(stream, status.New(codes.ResourceExhausted, errDesc)); err != nil { + if err := t.WriteStatus(stream, status.New(codes.Unimplemented, errDesc)); err != nil { if trInfo != nil { trInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{err}}, true) trInfo.tr.SetError() @@ -1639,7 +1715,7 @@ func (s *Server) Stop() { s.mu.Lock() listeners := s.lis s.lis = nil - st := s.conns + conns := s.conns s.conns = nil // interrupt GracefulStop if Stop and GracefulStop are called concurrently. s.cv.Broadcast() @@ -1648,8 +1724,10 @@ func (s *Server) Stop() { for lis := range listeners { lis.Close() } - for c := range st { - c.Close() + for _, cs := range conns { + for st := range cs { + st.Close() + } } if s.opts.numServerWorkers > 0 { s.stopServerWorkers() @@ -1686,8 +1764,10 @@ func (s *Server) GracefulStop() { } s.lis = nil if !s.drain { - for st := range s.conns { - st.Drain() + for _, conns := range s.conns { + for st := range conns { + st.Drain() + } } s.drain = true } diff --git a/server_test.go b/server_test.go index fcfde30706c..b1593916014 100644 --- a/server_test.go +++ b/server_test.go @@ -22,6 +22,7 @@ import ( "context" "net" "reflect" + "strconv" "strings" "testing" "time" @@ -130,3 +131,59 @@ func (s) TestStreamContext(t *testing.T) { t.Fatalf("GetStreamFromContext(%v) = %v, %t, want: %v, true", ctx, stream, ok, expectedStream) } } + +func BenchmarkChainUnaryInterceptor(b *testing.B) { + for _, n := range []int{1, 3, 5, 10} { + n := n + b.Run(strconv.Itoa(n), func(b *testing.B) { + interceptors := make([]UnaryServerInterceptor, 0, n) + for i := 0; i < n; i++ { + interceptors = append(interceptors, func( + ctx context.Context, req interface{}, info *UnaryServerInfo, handler UnaryHandler, + ) (interface{}, error) { + return handler(ctx, req) + }) + } + + s := NewServer(ChainUnaryInterceptor(interceptors...)) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + if _, err := s.opts.unaryInt(context.Background(), nil, nil, + func(ctx context.Context, req interface{}) (interface{}, error) { + return nil, nil + }, + ); err != nil { + b.Fatal(err) + } + } + }) + } +} + +func BenchmarkChainStreamInterceptor(b *testing.B) { + for _, n := range []int{1, 3, 5, 10} { + n := n + b.Run(strconv.Itoa(n), func(b *testing.B) { + interceptors := make([]StreamServerInterceptor, 0, n) + for i := 0; i < n; i++ { + interceptors = append(interceptors, func( + srv interface{}, ss ServerStream, info *StreamServerInfo, handler StreamHandler, + ) error { + return handler(srv, ss) + }) + } + + s := NewServer(ChainStreamInterceptor(interceptors...)) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + if err := s.opts.streamInt(nil, nil, nil, func(srv interface{}, stream ServerStream) error { + return nil + }); err != nil { + b.Fatal(err) + } + } + }) + } +} diff --git a/stats/stats.go b/stats/stats.go index 63e476ee7ff..0285dcc6a26 100644 --- a/stats/stats.go +++ b/stats/stats.go @@ -36,15 +36,22 @@ type RPCStats interface { IsClient() bool } -// Begin contains stats when an RPC begins. +// Begin contains stats when an RPC attempt begins. // FailFast is only valid if this Begin is from client side. type Begin struct { // Client is true if this Begin is from client side. Client bool - // BeginTime is the time when the RPC begins. + // BeginTime is the time when the RPC attempt begins. BeginTime time.Time // FailFast indicates if this RPC is failfast. FailFast bool + // IsClientStream indicates whether the RPC is a client streaming RPC. + IsClientStream bool + // IsServerStream indicates whether the RPC is a server streaming RPC. + IsServerStream bool + // IsTransparentRetryAttempt indicates whether this attempt was initiated + // due to transparently retrying a previous attempt. + IsTransparentRetryAttempt bool } // IsClient indicates if the stats information is from client side. diff --git a/stats/stats_test.go b/stats/stats_test.go index 306f2f6b8e9..dfc6edfc3d3 100644 --- a/stats/stats_test.go +++ b/stats/stats_test.go @@ -407,15 +407,17 @@ func (te *test) doServerStreamCall(c *rpcConfig) (*testpb.StreamingOutputCallReq } type expectedData struct { - method string - serverAddr string - compression string - reqIdx int - requests []proto.Message - respIdx int - responses []proto.Message - err error - failfast bool + method string + isClientStream bool + isServerStream bool + serverAddr string + compression string + reqIdx int + requests []proto.Message + respIdx int + responses []proto.Message + err error + failfast bool } type gotData struct { @@ -456,6 +458,12 @@ func checkBegin(t *testing.T, d *gotData, e *expectedData) { t.Fatalf("st.FailFast = %v, want %v", st.FailFast, e.failfast) } } + if st.IsClientStream != e.isClientStream { + t.Fatalf("st.IsClientStream = %v, want %v", st.IsClientStream, e.isClientStream) + } + if st.IsServerStream != e.isServerStream { + t.Fatalf("st.IsServerStream = %v, want %v", st.IsServerStream, e.isServerStream) + } } func checkInHeader(t *testing.T, d *gotData, e *expectedData) { @@ -847,6 +855,9 @@ func testServerStats(t *testing.T, tc *testConfig, cc *rpcConfig, checkFuncs []f err error method string + isClientStream bool + isServerStream bool + req proto.Message resp proto.Message e error @@ -864,14 +875,18 @@ func testServerStats(t *testing.T, tc *testConfig, cc *rpcConfig, checkFuncs []f reqs, resp, e = te.doClientStreamCall(cc) resps = []proto.Message{resp} err = e + isClientStream = true case serverStreamRPC: method = "/grpc.testing.TestService/StreamingOutputCall" req, resps, e = te.doServerStreamCall(cc) reqs = []proto.Message{req} err = e + isServerStream = true case fullDuplexStreamRPC: method = "/grpc.testing.TestService/FullDuplexCall" reqs, resps, err = te.doFullDuplexCallRoundtrip(cc) + isClientStream = true + isServerStream = true } if cc.success != (err == nil) { t.Fatalf("cc.success: %v, got error: %v", cc.success, err) @@ -900,12 +915,14 @@ func testServerStats(t *testing.T, tc *testConfig, cc *rpcConfig, checkFuncs []f } expect := &expectedData{ - serverAddr: te.srvAddr, - compression: tc.compress, - method: method, - requests: reqs, - responses: resps, - err: err, + serverAddr: te.srvAddr, + compression: tc.compress, + method: method, + requests: reqs, + responses: resps, + err: err, + isClientStream: isClientStream, + isServerStream: isServerStream, } h.mu.Lock() @@ -1138,6 +1155,9 @@ func testClientStats(t *testing.T, tc *testConfig, cc *rpcConfig, checkFuncs map method string err error + isClientStream bool + isServerStream bool + req proto.Message resp proto.Message e error @@ -1154,14 +1174,18 @@ func testClientStats(t *testing.T, tc *testConfig, cc *rpcConfig, checkFuncs map reqs, resp, e = te.doClientStreamCall(cc) resps = []proto.Message{resp} err = e + isClientStream = true case serverStreamRPC: method = "/grpc.testing.TestService/StreamingOutputCall" req, resps, e = te.doServerStreamCall(cc) reqs = []proto.Message{req} err = e + isServerStream = true case fullDuplexStreamRPC: method = "/grpc.testing.TestService/FullDuplexCall" reqs, resps, err = te.doFullDuplexCallRoundtrip(cc) + isClientStream = true + isServerStream = true } if cc.success != (err == nil) { t.Fatalf("cc.success: %v, got error: %v", cc.success, err) @@ -1194,13 +1218,15 @@ func testClientStats(t *testing.T, tc *testConfig, cc *rpcConfig, checkFuncs map } expect := &expectedData{ - serverAddr: te.srvAddr, - compression: tc.compress, - method: method, - requests: reqs, - responses: resps, - failfast: cc.failfast, - err: err, + serverAddr: te.srvAddr, + compression: tc.compress, + method: method, + requests: reqs, + responses: resps, + failfast: cc.failfast, + err: err, + isClientStream: isClientStream, + isServerStream: isServerStream, } h.mu.Lock() diff --git a/stream.go b/stream.go index 8ba0e8a5eea..fc3299ae1f0 100644 --- a/stream.go +++ b/stream.go @@ -52,14 +52,20 @@ import ( // of the RPC. type StreamHandler func(srv interface{}, stream ServerStream) error -// StreamDesc represents a streaming RPC service's method specification. +// StreamDesc represents a streaming RPC service's method specification. Used +// on the server when registering services and on the client when initiating +// new streams. type StreamDesc struct { - StreamName string - Handler StreamHandler - - // At least one of these is true. - ServerStreams bool - ClientStreams bool + // StreamName and Handler are only used when registering handlers on a + // server. + StreamName string // the name of the method excluding the service + Handler StreamHandler // the handler called for the method + + // ServerStreams and ClientStreams are used for registering handlers on a + // server as well as defining RPC behavior when passed to NewClientStream + // and ClientConn.NewStream. At least one must be true. + ServerStreams bool // indicates the server can perform streaming sends + ClientStreams bool // indicates the client can perform streaming sends } // Stream defines the common interface a client or server stream has to satisfy. @@ -181,7 +187,7 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth rpcInfo := iresolver.RPCInfo{Context: ctx, Method: method} rpcConfig, err := cc.safeConfigSelector.SelectConfig(rpcInfo) if err != nil { - return nil, status.Convert(err).Err() + return nil, toRPCErr(err) } if rpcConfig != nil { @@ -196,7 +202,7 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth newStream = func(ctx context.Context, done func()) (iresolver.ClientStream, error) { cs, err := rpcConfig.Interceptor.NewStream(ctx, rpcInfo, done, ns) if err != nil { - return nil, status.Convert(err).Err() + return nil, toRPCErr(err) } return cs, nil } @@ -268,33 +274,6 @@ func newClientStreamWithParams(ctx context.Context, desc *StreamDesc, cc *Client if c.creds != nil { callHdr.Creds = c.creds } - var trInfo *traceInfo - if EnableTracing { - trInfo = &traceInfo{ - tr: trace.New("grpc.Sent."+methodFamily(method), method), - firstLine: firstLine{ - client: true, - }, - } - if deadline, ok := ctx.Deadline(); ok { - trInfo.firstLine.deadline = time.Until(deadline) - } - trInfo.tr.LazyLog(&trInfo.firstLine, false) - ctx = trace.NewContext(ctx, trInfo.tr) - } - ctx = newContextWithRPCInfo(ctx, c.failFast, c.codec, cp, comp) - sh := cc.dopts.copts.StatsHandler - var beginTime time.Time - if sh != nil { - ctx = sh.TagRPC(ctx, &stats.RPCTagInfo{FullMethodName: method, FailFast: c.failFast}) - beginTime = time.Now() - begin := &stats.Begin{ - Client: true, - BeginTime: beginTime, - FailFast: c.failFast, - } - sh.HandleRPC(ctx, begin) - } cs := &clientStream{ callHdr: callHdr, @@ -308,7 +287,6 @@ func newClientStreamWithParams(ctx context.Context, desc *StreamDesc, cc *Client cp: cp, comp: comp, cancel: cancel, - beginTime: beginTime, firstAttempt: true, onCommit: onCommit, } @@ -317,9 +295,7 @@ func newClientStreamWithParams(ctx context.Context, desc *StreamDesc, cc *Client } cs.binlog = binarylog.GetMethodLogger(method) - // Only this initial attempt has stats/tracing. - // TODO(dfawley): move to newAttempt when per-attempt stats are implemented. - if err := cs.newAttemptLocked(sh, trInfo); err != nil { + if err := cs.newAttemptLocked(false /* isTransparent */); err != nil { cs.finish(err) return nil, err } @@ -367,8 +343,43 @@ func newClientStreamWithParams(ctx context.Context, desc *StreamDesc, cc *Client // newAttemptLocked creates a new attempt with a transport. // If it succeeds, then it replaces clientStream's attempt with this new attempt. -func (cs *clientStream) newAttemptLocked(sh stats.Handler, trInfo *traceInfo) (retErr error) { +func (cs *clientStream) newAttemptLocked(isTransparent bool) (retErr error) { + ctx := newContextWithRPCInfo(cs.ctx, cs.callInfo.failFast, cs.callInfo.codec, cs.cp, cs.comp) + method := cs.callHdr.Method + sh := cs.cc.dopts.copts.StatsHandler + var beginTime time.Time + if sh != nil { + ctx = sh.TagRPC(ctx, &stats.RPCTagInfo{FullMethodName: method, FailFast: cs.callInfo.failFast}) + beginTime = time.Now() + begin := &stats.Begin{ + Client: true, + BeginTime: beginTime, + FailFast: cs.callInfo.failFast, + IsClientStream: cs.desc.ClientStreams, + IsServerStream: cs.desc.ServerStreams, + IsTransparentRetryAttempt: isTransparent, + } + sh.HandleRPC(ctx, begin) + } + + var trInfo *traceInfo + if EnableTracing { + trInfo = &traceInfo{ + tr: trace.New("grpc.Sent."+methodFamily(method), method), + firstLine: firstLine{ + client: true, + }, + } + if deadline, ok := ctx.Deadline(); ok { + trInfo.firstLine.deadline = time.Until(deadline) + } + trInfo.tr.LazyLog(&trInfo.firstLine, false) + ctx = trace.NewContext(ctx, trInfo.tr) + } + newAttempt := &csAttempt{ + ctx: ctx, + beginTime: beginTime, cs: cs, dc: cs.cc.dopts.dc, statsHandler: sh, @@ -383,15 +394,14 @@ func (cs *clientStream) newAttemptLocked(sh stats.Handler, trInfo *traceInfo) (r } }() - if err := cs.ctx.Err(); err != nil { + if err := ctx.Err(); err != nil { return toRPCErr(err) } - ctx := cs.ctx if cs.cc.parsedTarget.Scheme == "xds" { // Add extra metadata (metadata that will be added by transport) to context // so the balancer can see them. - ctx = grpcutil.WithExtraMetadata(cs.ctx, metadata.Pairs( + ctx = grpcutil.WithExtraMetadata(ctx, metadata.Pairs( "content-type", grpcutil.ContentType(cs.callHdr.ContentSubtype), )) } @@ -411,14 +421,11 @@ func (cs *clientStream) newAttemptLocked(sh stats.Handler, trInfo *traceInfo) (r func (a *csAttempt) newStream() error { cs := a.cs cs.callHdr.PreviousAttempts = cs.numRetries - s, err := a.t.NewStream(cs.ctx, cs.callHdr) + s, err := a.t.NewStream(a.ctx, cs.callHdr) if err != nil { - if _, ok := err.(transport.PerformedIOError); ok { - // Return without converting to an RPC error so retry code can - // inspect. - return err - } - return toRPCErr(err) + // Return without converting to an RPC error so retry code can + // inspect. + return err } cs.attempt.s = s cs.attempt.p = &parser{r: s} @@ -439,8 +446,7 @@ type clientStream struct { cancel context.CancelFunc // cancels all attempts - sentLast bool // sent an end stream - beginTime time.Time + sentLast bool // sent an end stream methodConfig *MethodConfig @@ -480,6 +486,7 @@ type clientStream struct { // csAttempt implements a single transport stream attempt within a // clientStream. type csAttempt struct { + ctx context.Context cs *clientStream t transport.ClientTransport s *transport.Stream @@ -498,6 +505,7 @@ type csAttempt struct { trInfo *traceInfo statsHandler stats.Handler + beginTime time.Time } func (cs *clientStream) commitAttemptLocked() { @@ -515,46 +523,57 @@ func (cs *clientStream) commitAttempt() { } // shouldRetry returns nil if the RPC should be retried; otherwise it returns -// the error that should be returned by the operation. -func (cs *clientStream) shouldRetry(err error) error { - unprocessed := false +// the error that should be returned by the operation. If the RPC should be +// retried, the bool indicates whether it is being retried transparently. +func (cs *clientStream) shouldRetry(err error) (bool, error) { if cs.attempt.s == nil { - pioErr, ok := err.(transport.PerformedIOError) - if ok { - // Unwrap error. - err = toRPCErr(pioErr.Err) - } else { - unprocessed = true + // Error from NewClientStream. + nse, ok := err.(*transport.NewStreamError) + if !ok { + // Unexpected, but assume no I/O was performed and the RPC is not + // fatal, so retry indefinitely. + return true, nil } - if !ok && !cs.callInfo.failFast { - // In the event of a non-IO operation error from NewStream, we - // never attempted to write anything to the wire, so we can retry - // indefinitely for non-fail-fast RPCs. - return nil + + // Unwrap and convert error. + err = toRPCErr(nse.Err) + + // Never retry DoNotRetry errors, which indicate the RPC should not be + // retried due to max header list size violation, etc. + if nse.DoNotRetry { + return false, err + } + + // In the event of a non-IO operation error from NewStream, we never + // attempted to write anything to the wire, so we can retry + // indefinitely. + if !nse.DoNotTransparentRetry { + return true, nil } } if cs.finished || cs.committed { // RPC is finished or committed; cannot retry. - return err + return false, err } // Wait for the trailers. + unprocessed := false if cs.attempt.s != nil { <-cs.attempt.s.Done() unprocessed = cs.attempt.s.Unprocessed() } if cs.firstAttempt && unprocessed { // First attempt, stream unprocessed: transparently retry. - return nil + return true, nil } if cs.cc.dopts.disableRetry { - return err + return false, err } pushback := 0 hasPushback := false if cs.attempt.s != nil { if !cs.attempt.s.TrailersOnly() { - return err + return false, err } // TODO(retry): Move down if the spec changes to not check server pushback @@ -565,13 +584,13 @@ func (cs *clientStream) shouldRetry(err error) error { if pushback, e = strconv.Atoi(sps[0]); e != nil || pushback < 0 { channelz.Infof(logger, cs.cc.channelzID, "Server retry pushback specified to abort (%q).", sps[0]) cs.retryThrottler.throttle() // This counts as a failure for throttling. - return err + return false, err } hasPushback = true } else if len(sps) > 1 { channelz.Warningf(logger, cs.cc.channelzID, "Server retry pushback specified multiple values (%q); not retrying.", sps) cs.retryThrottler.throttle() // This counts as a failure for throttling. - return err + return false, err } } @@ -584,16 +603,16 @@ func (cs *clientStream) shouldRetry(err error) error { rp := cs.methodConfig.RetryPolicy if rp == nil || !rp.RetryableStatusCodes[code] { - return err + return false, err } // Note: the ordering here is important; we count this as a failure // only if the code matched a retryable code. if cs.retryThrottler.throttle() { - return err + return false, err } if cs.numRetries+1 >= rp.MaxAttempts { - return err + return false, err } var dur time.Duration @@ -616,23 +635,24 @@ func (cs *clientStream) shouldRetry(err error) error { select { case <-t.C: cs.numRetries++ - return nil + return false, nil case <-cs.ctx.Done(): t.Stop() - return status.FromContextError(cs.ctx.Err()).Err() + return false, status.FromContextError(cs.ctx.Err()).Err() } } // Returns nil if a retry was performed and succeeded; error otherwise. func (cs *clientStream) retryLocked(lastErr error) error { for { - cs.attempt.finish(lastErr) - if err := cs.shouldRetry(lastErr); err != nil { + cs.attempt.finish(toRPCErr(lastErr)) + isTransparent, err := cs.shouldRetry(lastErr) + if err != nil { cs.commitAttemptLocked() return err } cs.firstAttempt = false - if err := cs.newAttemptLocked(nil, nil); err != nil { + if err := cs.newAttemptLocked(isTransparent); err != nil { return err } if lastErr = cs.replayBufferLocked(); lastErr == nil { @@ -653,7 +673,11 @@ func (cs *clientStream) withRetry(op func(a *csAttempt) error, onSuccess func()) for { if cs.committed { cs.mu.Unlock() - return op(cs.attempt) + // toRPCErr is used in case the error from the attempt comes from + // NewClientStream, which intentionally doesn't return a status + // error to allow for further inspection; all other errors should + // already be status errors. + return toRPCErr(op(cs.attempt)) } a := cs.attempt cs.mu.Unlock() @@ -918,7 +942,7 @@ func (a *csAttempt) sendMsg(m interface{}, hdr, payld, data []byte) error { return io.EOF } if a.statsHandler != nil { - a.statsHandler.HandleRPC(cs.ctx, outPayload(true, m, data, payld, time.Now())) + a.statsHandler.HandleRPC(a.ctx, outPayload(true, m, data, payld, time.Now())) } if channelz.IsOn() { a.t.IncrMsgSent() @@ -966,7 +990,7 @@ func (a *csAttempt) recvMsg(m interface{}, payInfo *payloadInfo) (err error) { a.mu.Unlock() } if a.statsHandler != nil { - a.statsHandler.HandleRPC(cs.ctx, &stats.InPayload{ + a.statsHandler.HandleRPC(a.ctx, &stats.InPayload{ Client: true, RecvTime: time.Now(), Payload: m, @@ -1028,12 +1052,12 @@ func (a *csAttempt) finish(err error) { if a.statsHandler != nil { end := &stats.End{ Client: true, - BeginTime: a.cs.beginTime, + BeginTime: a.beginTime, EndTime: time.Now(), Trailer: tr, Error: err, } - a.statsHandler.HandleRPC(a.cs.ctx, end) + a.statsHandler.HandleRPC(a.ctx, end) } if a.trInfo != nil && a.trInfo.tr != nil { if err == nil { diff --git a/stress/grpc_testing/metrics_grpc.pb.go b/stress/grpc_testing/metrics_grpc.pb.go index 2ece0325563..0730fad49a4 100644 --- a/stress/grpc_testing/metrics_grpc.pb.go +++ b/stress/grpc_testing/metrics_grpc.pb.go @@ -1,4 +1,8 @@ // Code generated by protoc-gen-go-grpc. DO NOT EDIT. +// versions: +// - protoc-gen-go-grpc v1.1.0 +// - protoc v3.14.0 +// source: stress/grpc_testing/metrics.proto package grpc_testing diff --git a/tap/tap.go b/tap/tap.go index caea1ebed6e..dbf34e6bb5f 100644 --- a/tap/tap.go +++ b/tap/tap.go @@ -37,16 +37,16 @@ type Info struct { // TODO: More to be added. } -// ServerInHandle defines the function which runs before a new stream is created -// on the server side. If it returns a non-nil error, the stream will not be -// created and a RST_STREAM will be sent back to the client with REFUSED_STREAM. -// The client will receive an RPC error "code = Unavailable, desc = stream -// terminated by RST_STREAM with error code: REFUSED_STREAM". +// ServerInHandle defines the function which runs before a new stream is +// created on the server side. If it returns a non-nil error, the stream will +// not be created and an error will be returned to the client. If the error +// returned is a status error, that status code and message will be used, +// otherwise PermissionDenied will be the code and err.Error() will be the +// message. // // It's intended to be used in situations where you don't want to waste the -// resources to accept the new stream (e.g. rate-limiting). And the content of -// the error will be ignored and won't be sent back to the client. For other -// general usages, please use interceptors. +// resources to accept the new stream (e.g. rate-limiting). For other general +// usages, please use interceptors. // // Note that it is executed in the per-connection I/O goroutine(s) instead of // per-RPC goroutine. Therefore, users should NOT have any diff --git a/test/authority_test.go b/test/authority_test.go index 17ae178b73c..15afa759c90 100644 --- a/test/authority_test.go +++ b/test/authority_test.go @@ -1,3 +1,4 @@ +//go:build linux // +build linux /* diff --git a/test/balancer_test.go b/test/balancer_test.go index bc22036dbac..e2fa4cf31d0 100644 --- a/test/balancer_test.go +++ b/test/balancer_test.go @@ -28,6 +28,7 @@ import ( "time" "github.com/google/go-cmp/cmp" + "google.golang.org/grpc" "google.golang.org/grpc/attributes" "google.golang.org/grpc/balancer" @@ -37,7 +38,6 @@ import ( "google.golang.org/grpc/credentials" "google.golang.org/grpc/internal/balancer/stub" "google.golang.org/grpc/internal/balancerload" - "google.golang.org/grpc/internal/grpcsync" "google.golang.org/grpc/internal/grpcutil" imetadata "google.golang.org/grpc/internal/metadata" "google.golang.org/grpc/internal/stubserver" @@ -88,7 +88,7 @@ func (b *testBalancer) UpdateClientConnState(state balancer.ClientConnState) err logger.Errorf("testBalancer: failed to NewSubConn: %v", err) return nil } - b.cc.UpdateState(balancer.State{ConnectivityState: connectivity.Connecting, Picker: &picker{sc: b.sc, bal: b}}) + b.cc.UpdateState(balancer.State{ConnectivityState: connectivity.Connecting, Picker: &picker{err: balancer.ErrNoSubConnAvailable, bal: b}}) b.sc.Connect() } return nil @@ -106,8 +106,10 @@ func (b *testBalancer) UpdateSubConnState(sc balancer.SubConn, s balancer.SubCon } switch s.ConnectivityState { - case connectivity.Ready, connectivity.Idle: + case connectivity.Ready: b.cc.UpdateState(balancer.State{ConnectivityState: s.ConnectivityState, Picker: &picker{sc: sc, bal: b}}) + case connectivity.Idle: + b.cc.UpdateState(balancer.State{ConnectivityState: s.ConnectivityState, Picker: &picker{sc: sc, bal: b, idle: true}}) case connectivity.Connecting: b.cc.UpdateState(balancer.State{ConnectivityState: s.ConnectivityState, Picker: &picker{err: balancer.ErrNoSubConnAvailable, bal: b}}) case connectivity.TransientFailure: @@ -117,16 +119,23 @@ func (b *testBalancer) UpdateSubConnState(sc balancer.SubConn, s balancer.SubCon func (b *testBalancer) Close() {} +func (b *testBalancer) ExitIdle() {} + type picker struct { - err error - sc balancer.SubConn - bal *testBalancer + err error + sc balancer.SubConn + bal *testBalancer + idle bool } func (p *picker) Pick(info balancer.PickInfo) (balancer.PickResult, error) { if p.err != nil { return balancer.PickResult{}, p.err } + if p.idle { + p.sc.Connect() + return balancer.PickResult{}, balancer.ErrNoSubConnAvailable + } extraMD, _ := grpcutil.ExtraMetadata(info.Ctx) info.Ctx = nil // Do not validate context. p.bal.pickInfos = append(p.bal.pickInfos, info) @@ -195,14 +204,14 @@ func testPickExtraMetadata(t *testing.T, e env) { cc := te.clientConn() tc := testpb.NewTestServiceClient(cc) - // The RPCs will fail, but we don't care. We just need the pick to happen. - ctx1, cancel1 := context.WithTimeout(context.Background(), time.Second) - defer cancel1() - tc.EmptyCall(ctx1, &testpb.Empty{}) - - ctx2, cancel2 := context.WithTimeout(context.Background(), time.Second) - defer cancel2() - tc.EmptyCall(ctx2, &testpb.Empty{}, grpc.CallContentSubtype(testSubContentType)) + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + if _, err := tc.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(true)); err != nil { + t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, %v", err, nil) + } + if _, err := tc.EmptyCall(ctx, &testpb.Empty{}, grpc.CallContentSubtype(testSubContentType)); err != nil { + t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, %v", err, nil) + } want := []metadata.MD{ // First RPC doesn't have sub-content-type. @@ -210,9 +219,8 @@ func testPickExtraMetadata(t *testing.T, e env) { // Second RPC has sub-content-type "proto". {"content-type": []string{"application/grpc+proto"}}, } - - if !cmp.Equal(b.pickExtraMDs, want) { - t.Fatalf("%s", cmp.Diff(b.pickExtraMDs, want)) + if diff := cmp.Diff(want, b.pickExtraMDs); diff != "" { + t.Fatalf("unexpected diff in metadata (-want, +got): %s", diff) } } @@ -374,8 +382,9 @@ func (testBalancerKeepAddresses) UpdateSubConnState(sc balancer.SubConn, s balan panic("not used") } -func (testBalancerKeepAddresses) Close() { -} +func (testBalancerKeepAddresses) Close() {} + +func (testBalancerKeepAddresses) ExitIdle() {} // Make sure that non-grpclb balancers don't get grpclb addresses even if name // resolver sends them @@ -698,10 +707,7 @@ func (s) TestEmptyAddrs(t *testing.T) { // Initialize pickfirst client pfr := manual.NewBuilderWithScheme("whatever") - pfrnCalled := grpcsync.NewEvent() - pfr.ResolveNowCallback = func(resolver.ResolveNowOptions) { - pfrnCalled.Fire() - } + pfr.InitialState(resolver.State{Addresses: []resolver.Address{{Addr: lis.Addr().String()}}}) pfcc, err := grpc.DialContext(ctx, pfr.Scheme()+":///", grpc.WithInsecure(), grpc.WithResolvers(pfr)) @@ -718,16 +724,10 @@ func (s) TestEmptyAddrs(t *testing.T) { // Remove all addresses. pfr.UpdateState(resolver.State{}) - // Wait for a ResolveNow call on the pick first client's resolver. - <-pfrnCalled.Done() // Initialize roundrobin client rrr := manual.NewBuilderWithScheme("whatever") - rrrnCalled := grpcsync.NewEvent() - rrr.ResolveNowCallback = func(resolver.ResolveNowOptions) { - rrrnCalled.Fire() - } rrr.InitialState(resolver.State{Addresses: []resolver.Address{{Addr: lis.Addr().String()}}}) rrcc, err := grpc.DialContext(ctx, rrr.Scheme()+":///", grpc.WithInsecure(), grpc.WithResolvers(rrr), @@ -745,8 +745,6 @@ func (s) TestEmptyAddrs(t *testing.T) { // Remove all addresses. rrr.UpdateState(resolver.State{}) - // Wait for a ResolveNow call on the round robin client's resolver. - <-rrrnCalled.Done() // Confirm several new RPCs succeed on pick first. for i := 0; i < 10; i++ { diff --git a/test/bufconn/bufconn.go b/test/bufconn/bufconn.go index 168cdb8578d..3f77f4876eb 100644 --- a/test/bufconn/bufconn.go +++ b/test/bufconn/bufconn.go @@ -21,6 +21,7 @@ package bufconn import ( + "context" "fmt" "io" "net" @@ -86,8 +87,17 @@ func (l *Listener) Addr() net.Addr { return addr{} } // providing it the server half of the connection, and returns the client half // of the connection. func (l *Listener) Dial() (net.Conn, error) { + return l.DialContext(context.Background()) +} + +// DialContext creates an in-memory full-duplex network connection, unblocks Accept by +// providing it the server half of the connection, and returns the client half +// of the connection. If ctx is Done, returns ctx.Err() +func (l *Listener) DialContext(ctx context.Context) (net.Conn, error) { p1, p2 := newPipe(l.sz), newPipe(l.sz) select { + case <-ctx.Done(): + return nil, ctx.Err() case <-l.done: return nil, errClosed case l.ch <- &conn{p1, p2}: diff --git a/test/channelz_linux_go110_test.go b/test/channelz_linux_test.go similarity index 99% rename from test/channelz_linux_go110_test.go rename to test/channelz_linux_test.go index dea374bfc08..aa6febe537a 100644 --- a/test/channelz_linux_go110_test.go +++ b/test/channelz_linux_test.go @@ -1,5 +1,3 @@ -// +build linux - /* * * Copyright 2018 gRPC authors. diff --git a/test/channelz_test.go b/test/channelz_test.go index 47e7eb92716..6cb09dd8d89 100644 --- a/test/channelz_test.go +++ b/test/channelz_test.go @@ -1689,8 +1689,22 @@ func (s) TestCZSubChannelPickedNewAddress(t *testing.T) { } te.srvs[0].Stop() te.srvs[1].Stop() - // Here, we just wait for all sockets to be up. In the future, if we implement - // IDLE, we may need to make several rpc calls to create the sockets. + // Here, we just wait for all sockets to be up. Make several rpc calls to + // create the sockets since we do not automatically reconnect. + done := make(chan struct{}) + defer close(done) + go func() { + for { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + tc.EmptyCall(ctx, &testpb.Empty{}) + cancel() + select { + case <-time.After(10 * time.Millisecond): + case <-done: + return + } + } + }() if err := verifyResultWithDelay(func() (bool, error) { tcs, _ := channelz.GetTopChannels(0, 0) if len(tcs) != 1 { diff --git a/test/end2end_test.go b/test/end2end_test.go index 902e9424104..957d13f731f 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -55,12 +55,14 @@ import ( "google.golang.org/grpc/health" healthgrpc "google.golang.org/grpc/health/grpc_health_v1" healthpb "google.golang.org/grpc/health/grpc_health_v1" + "google.golang.org/grpc/internal" "google.golang.org/grpc/internal/channelz" "google.golang.org/grpc/internal/grpcsync" "google.golang.org/grpc/internal/grpctest" "google.golang.org/grpc/internal/stubserver" "google.golang.org/grpc/internal/testutils" "google.golang.org/grpc/internal/transport" + "google.golang.org/grpc/keepalive" "google.golang.org/grpc/metadata" "google.golang.org/grpc/peer" "google.golang.org/grpc/resolver" @@ -69,6 +71,7 @@ import ( "google.golang.org/grpc/stats" "google.golang.org/grpc/status" "google.golang.org/grpc/tap" + "google.golang.org/grpc/test/bufconn" testpb "google.golang.org/grpc/test/grpc_testing" "google.golang.org/grpc/testdata" ) @@ -506,10 +509,6 @@ type test struct { customDialOptions []grpc.DialOption resolverScheme string - // All test dialing is blocking by default. Set this to true if dial - // should be non-blocking. - nonBlockingDial bool - // These are are set once startServer is called. The common case is to have // only one testServer. srv stopper @@ -826,10 +825,6 @@ func (te *test) configDial(opts ...grpc.DialOption) ([]grpc.DialOption, string) if te.customCodec != nil { opts = append(opts, grpc.WithDefaultCallOptions(grpc.ForceCodec(te.customCodec))) } - if !te.nonBlockingDial && te.srvAddr != "" { - // Only do a blocking dial if server is up. - opts = append(opts, grpc.WithBlock()) - } if te.srvAddr == "" { te.srvAddr = "client.side.only.test" } @@ -1333,6 +1328,131 @@ func testConcurrentServerStopAndGoAway(t *testing.T, e env) { awaitNewConnLogOutput() } +func (s) TestDetailedConnectionCloseErrorPropagatesToRpcError(t *testing.T) { + rpcStartedOnServer := make(chan struct{}) + rpcDoneOnClient := make(chan struct{}) + ss := &stubserver.StubServer{ + FullDuplexCallF: func(stream testpb.TestService_FullDuplexCallServer) error { + close(rpcStartedOnServer) + <-rpcDoneOnClient + return status.Error(codes.Internal, "arbitrary status") + }, + } + if err := ss.Start(nil); err != nil { + t.Fatalf("Error starting endpoint server: %v", err) + } + defer ss.Stop() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + // The precise behavior of this test is subject to raceyness around the timing of when TCP packets + // are sent from client to server, and when we tell the server to stop, so we need to account for both + // of these possible error messages: + // 1) If the call to ss.S.Stop() causes the server's sockets to close while there's still in-fight + // data from the client on the TCP connection, then the kernel can send an RST back to the client (also + // see https://stackoverflow.com/questions/33053507/econnreset-in-send-linux-c). Note that while this + // condition is expected to be rare due to the rpcStartedOnServer synchronization, in theory it should + // be possible, e.g. if the client sends a BDP ping at the right time. + // 2) If, for example, the call to ss.S.Stop() happens after the RPC headers have been received at the + // server, then the TCP connection can shutdown gracefully when the server's socket closes. + const possibleConnResetMsg = "connection reset by peer" + const possibleEOFMsg = "error reading from server: EOF" + // Start an RPC. Then, while the RPC is still being accepted or handled at the server, abruptly + // stop the server, killing the connection. The RPC error message should include details about the specific + // connection error that was encountered. + stream, err := ss.Client.FullDuplexCall(ctx) + if err != nil { + t.Fatalf("%v.FullDuplexCall = _, %v, want _, ", ss.Client, err) + } + // Block until the RPC has been started on the server. This ensures that the ClientConn will find a healthy + // connection for the RPC to go out on initially, and that the TCP connection will shut down strictly after + // the RPC has been started on it. + <-rpcStartedOnServer + ss.S.Stop() + if _, err := stream.Recv(); err == nil || (!strings.Contains(err.Error(), possibleConnResetMsg) && !strings.Contains(err.Error(), possibleEOFMsg)) { + t.Fatalf("%v.Recv() = _, %v, want _, rpc error containing substring: %q OR %q", stream, err, possibleConnResetMsg, possibleEOFMsg) + } + close(rpcDoneOnClient) +} + +func (s) TestDetailedGoawayErrorOnGracefulClosePropagatesToRPCError(t *testing.T) { + rpcDoneOnClient := make(chan struct{}) + ss := &stubserver.StubServer{ + FullDuplexCallF: func(stream testpb.TestService_FullDuplexCallServer) error { + <-rpcDoneOnClient + return status.Error(codes.Internal, "arbitrary status") + }, + } + sopts := []grpc.ServerOption{ + grpc.KeepaliveParams(keepalive.ServerParameters{ + MaxConnectionAge: time.Millisecond * 100, + MaxConnectionAgeGrace: time.Millisecond, + }), + } + if err := ss.Start(sopts); err != nil { + t.Fatalf("Error starting endpoint server: %v", err) + } + defer ss.Stop() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + stream, err := ss.Client.FullDuplexCall(ctx) + if err != nil { + t.Fatalf("%v.FullDuplexCall = _, %v, want _, ", ss.Client, err) + } + const expectedErrorMessageSubstring = "received prior goaway: code: NO_ERROR" + _, err = stream.Recv() + close(rpcDoneOnClient) + if err == nil || !strings.Contains(err.Error(), expectedErrorMessageSubstring) { + t.Fatalf("%v.Recv() = _, %v, want _, rpc error containing substring: %q", stream, err, expectedErrorMessageSubstring) + } +} + +func (s) TestDetailedGoawayErrorOnAbruptClosePropagatesToRPCError(t *testing.T) { + // set the min keepalive time very low so that this test can take + // a reasonable amount of time + prev := internal.KeepaliveMinPingTime + internal.KeepaliveMinPingTime = time.Millisecond + defer func() { internal.KeepaliveMinPingTime = prev }() + + rpcDoneOnClient := make(chan struct{}) + ss := &stubserver.StubServer{ + FullDuplexCallF: func(stream testpb.TestService_FullDuplexCallServer) error { + <-rpcDoneOnClient + return status.Error(codes.Internal, "arbitrary status") + }, + } + sopts := []grpc.ServerOption{ + grpc.KeepaliveEnforcementPolicy(keepalive.EnforcementPolicy{ + MinTime: time.Second * 1000, /* arbitrary, large value */ + }), + } + dopts := []grpc.DialOption{ + grpc.WithKeepaliveParams(keepalive.ClientParameters{ + Time: time.Millisecond, /* should trigger "too many pings" error quickly */ + Timeout: time.Second * 1000, /* arbitrary, large value */ + PermitWithoutStream: false, + }), + } + if err := ss.Start(sopts, dopts...); err != nil { + t.Fatalf("Error starting endpoint server: %v", err) + } + defer ss.Stop() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + stream, err := ss.Client.FullDuplexCall(ctx) + if err != nil { + t.Fatalf("%v.FullDuplexCall = _, %v, want _, ", ss.Client, err) + } + const expectedErrorMessageSubstring = `received prior goaway: code: ENHANCE_YOUR_CALM, debug data: "too_many_pings"` + _, err = stream.Recv() + close(rpcDoneOnClient) + if err == nil || !strings.Contains(err.Error(), expectedErrorMessageSubstring) { + t.Fatalf("%v.Recv() = _, %v, want _, rpc error containing substring: |%v|", stream, err, expectedErrorMessageSubstring) + } +} + func (s) TestClientConnCloseAfterGoAwayWithActiveStream(t *testing.T) { for _, e := range listTestEnv() { if e.name == "handler-tls" { @@ -1744,7 +1864,6 @@ func (s) TestServiceConfigMaxMsgSize(t *testing.T) { defer te1.tearDown() te1.resolverScheme = r.Scheme() - te1.nonBlockingDial = true te1.startServer(&testServer{security: e.security}) cc1 := te1.clientConn(grpc.WithResolvers(r)) @@ -1832,7 +1951,6 @@ func (s) TestServiceConfigMaxMsgSize(t *testing.T) { // Case2: Client API set maxReqSize to 1024 (send), maxRespSize to 1024 (recv). Sc sets maxReqSize to 2048 (send), maxRespSize to 2048 (recv). te2 := testServiceConfigSetup(t, e) te2.resolverScheme = r.Scheme() - te2.nonBlockingDial = true te2.maxClientReceiveMsgSize = newInt(1024) te2.maxClientSendMsgSize = newInt(1024) @@ -1892,7 +2010,6 @@ func (s) TestServiceConfigMaxMsgSize(t *testing.T) { // Case3: Client API set maxReqSize to 4096 (send), maxRespSize to 4096 (recv). Sc sets maxReqSize to 2048 (send), maxRespSize to 2048 (recv). te3 := testServiceConfigSetup(t, e) te3.resolverScheme = r.Scheme() - te3.nonBlockingDial = true te3.maxClientReceiveMsgSize = newInt(4096) te3.maxClientSendMsgSize = newInt(4096) @@ -1985,7 +2102,6 @@ func (s) TestStreamingRPCWithTimeoutInServiceConfigRecv(t *testing.T) { r := manual.NewBuilderWithScheme("whatever") te.resolverScheme = r.Scheme() - te.nonBlockingDial = true cc := te.clientConn(grpc.WithResolvers(r)) tc := testpb.NewTestServiceClient(cc) @@ -2121,6 +2237,61 @@ func testPreloaderClientSend(t *testing.T, e env) { } } +func (s) TestPreloaderSenderSend(t *testing.T) { + ss := &stubserver.StubServer{ + FullDuplexCallF: func(stream testpb.TestService_FullDuplexCallServer) error { + for i := 0; i < 10; i++ { + preparedMsg := &grpc.PreparedMsg{} + err := preparedMsg.Encode(stream, &testpb.StreamingOutputCallResponse{ + Payload: &testpb.Payload{ + Body: []byte{'0' + uint8(i)}, + }, + }) + if err != nil { + return err + } + stream.SendMsg(preparedMsg) + } + return nil + }, + } + if err := ss.Start(nil); err != nil { + t.Fatalf("Error starting endpoint server: %v", err) + } + defer ss.Stop() + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + stream, err := ss.Client.FullDuplexCall(ctx) + if err != nil { + t.Fatalf("ss.Client.EmptyCall(_, _) = _, %v; want _, nil", err) + } + + var ngot int + var buf bytes.Buffer + for { + reply, err := stream.Recv() + if err == io.EOF { + break + } + if err != nil { + t.Fatal(err) + } + ngot++ + if buf.Len() > 0 { + buf.WriteByte(',') + } + buf.Write(reply.GetPayload().GetBody()) + } + if want := 10; ngot != want { + t.Errorf("Got %d replies, want %d", ngot, want) + } + if got, want := buf.String(), "0,1,2,3,4,5,6,7,8,9"; got != want { + t.Errorf("Got replies %q; want %q", got, want) + } +} + func (s) TestMaxMsgSizeClientDefault(t *testing.T) { for _, e := range listTestEnv() { testMaxMsgSizeClientDefault(t, e) @@ -2380,10 +2551,13 @@ type myTap struct { func (t *myTap) handle(ctx context.Context, info *tap.Info) (context.Context, error) { if info != nil { - if info.FullMethodName == "/grpc.testing.TestService/EmptyCall" { + switch info.FullMethodName { + case "/grpc.testing.TestService/EmptyCall": t.cnt++ - } else if info.FullMethodName == "/grpc.testing.TestService/UnaryCall" { + case "/grpc.testing.TestService/UnaryCall": return nil, fmt.Errorf("tap error") + case "/grpc.testing.TestService/FullDuplexCall": + return nil, status.Errorf(codes.FailedPrecondition, "test custom error") } } return ctx, nil @@ -2423,8 +2597,15 @@ func testTap(t *testing.T, e env) { ResponseSize: 45, Payload: payload, } - if _, err := tc.UnaryCall(ctx, req); status.Code(err) != codes.Unavailable { - t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, %s", err, codes.Unavailable) + if _, err := tc.UnaryCall(ctx, req); status.Code(err) != codes.PermissionDenied { + t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, %s", err, codes.PermissionDenied) + } + str, err := tc.FullDuplexCall(ctx) + if err != nil { + t.Fatalf("Unexpected error creating stream: %v", err) + } + if _, err := str.Recv(); status.Code(err) != codes.FailedPrecondition { + t.Fatalf("FullDuplexCall Recv() = _, %v, want _, %s", err, codes.FailedPrecondition) } } @@ -3512,66 +3693,79 @@ func testMalformedHTTP2Metadata(t *testing.T, e env) { } } +// Tests that the client transparently retries correctly when receiving a +// RST_STREAM with code REFUSED_STREAM. func (s) TestTransparentRetry(t *testing.T) { - for _, e := range listTestEnv() { - if e.name == "handler-tls" { - // Fails with RST_STREAM / FLOW_CONTROL_ERROR - continue - } - testTransparentRetry(t, e) - } -} - -// This test makes sure RPCs are retried times when they receive a RST_STREAM -// with the REFUSED_STREAM error code, which the InTapHandle provokes. -func testTransparentRetry(t *testing.T, e env) { - te := newTest(t, e) - attempts := 0 - successAttempt := 2 - te.tapHandle = func(ctx context.Context, _ *tap.Info) (context.Context, error) { - attempts++ - if attempts < successAttempt { - return nil, errors.New("not now") - } - return ctx, nil - } - te.startServer(&testServer{security: e.security}) - defer te.tearDown() - - cc := te.clientConn() - tsc := testpb.NewTestServiceClient(cc) testCases := []struct { - successAttempt int - failFast bool - errCode codes.Code + failFast bool + errCode codes.Code }{{ - successAttempt: 1, + // success attempt: 1, (stream ID 1) }, { - successAttempt: 2, + // success attempt: 2, (stream IDs 3, 5) }, { - successAttempt: 3, - errCode: codes.Unavailable, + // no success attempt (stream IDs 7, 9) + errCode: codes.Unavailable, }, { - successAttempt: 1, - failFast: true, + // success attempt: 1 (stream ID 11), + failFast: true, }, { - successAttempt: 2, - failFast: true, + // success attempt: 2 (stream IDs 13, 15), + failFast: true, }, { - successAttempt: 3, - failFast: true, - errCode: codes.Unavailable, + // no success attempt (stream IDs 17, 19) + failFast: true, + errCode: codes.Unavailable, }} - for _, tc := range testCases { - attempts = 0 - successAttempt = tc.successAttempt - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - _, err := tsc.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(!tc.failFast)) - cancel() - if status.Code(err) != tc.errCode { - t.Errorf("%+v: tsc.EmptyCall(_, _) = _, %v, want _, Code=%v", tc, err, tc.errCode) + lis, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatalf("Failed to listen. Err: %v", err) + } + defer lis.Close() + server := &httpServer{ + responses: []httpServerResponse{{ + trailers: [][]string{{ + ":status", "200", + "content-type", "application/grpc", + "grpc-status", "0", + }}, + }}, + refuseStream: func(i uint32) bool { + switch i { + case 1, 5, 11, 15: // these stream IDs succeed + return false + } + return true // these are refused + }, + } + server.start(t, lis) + cc, err := grpc.Dial(lis.Addr().String(), grpc.WithInsecure()) + if err != nil { + t.Fatalf("failed to dial due to err: %v", err) + } + defer cc.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + client := testpb.NewTestServiceClient(cc) + + for i, tc := range testCases { + stream, err := client.FullDuplexCall(ctx) + if err != nil { + t.Fatalf("error creating stream due to err: %v", err) + } + code := func(err error) codes.Code { + if err == io.EOF { + return codes.OK + } + return status.Code(err) } + if _, err := stream.Recv(); code(err) != tc.errCode { + t.Fatalf("%v: stream.Recv() = _, %v, want error code: %v", i, err, tc.errCode) + } + } } @@ -4937,8 +5131,7 @@ func (s) TestFlowControlLogicalRace(t *testing.T) { go s.Serve(lis) - ctx := context.Background() - cc, err := grpc.Dial(lis.Addr().String(), grpc.WithInsecure(), grpc.WithBlock()) + cc, err := grpc.Dial(lis.Addr().String(), grpc.WithInsecure()) if err != nil { t.Fatalf("grpc.Dial(%q) = %v", lis.Addr().String(), err) } @@ -4947,7 +5140,7 @@ func (s) TestFlowControlLogicalRace(t *testing.T) { failures := 0 for i := 0; i < requestCount; i++ { - ctx, cancel := context.WithTimeout(ctx, requestTimeout) + ctx, cancel := context.WithTimeout(context.Background(), requestTimeout) output, err := cl.StreamingOutputCall(ctx, &testpb.StreamingOutputCallRequest{}) if err != nil { t.Fatalf("StreamingOutputCall; err = %q", err) @@ -5145,7 +5338,7 @@ func (s) TestGRPCMethod(t *testing.T) { } defer ss.Stop() - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}); err != nil { @@ -5157,6 +5350,86 @@ func (s) TestGRPCMethod(t *testing.T) { } } +// renameProtoCodec is an encoding.Codec wrapper that allows customizing the +// Name() of another codec. +type renameProtoCodec struct { + encoding.Codec + name string +} + +func (r *renameProtoCodec) Name() string { return r.name } + +// TestForceCodecName confirms that the ForceCodec call option sets the subtype +// in the content-type header according to the Name() of the codec provided. +func (s) TestForceCodecName(t *testing.T) { + wantContentTypeCh := make(chan []string, 1) + defer close(wantContentTypeCh) + + ss := &stubserver.StubServer{ + EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) { + md, ok := metadata.FromIncomingContext(ctx) + if !ok { + return nil, status.Errorf(codes.Internal, "no metadata in context") + } + if got, want := md["content-type"], <-wantContentTypeCh; !reflect.DeepEqual(got, want) { + return nil, status.Errorf(codes.Internal, "got content-type=%q; want [%q]", got, want) + } + return &testpb.Empty{}, nil + }, + } + if err := ss.Start([]grpc.ServerOption{grpc.ForceServerCodec(encoding.GetCodec("proto"))}); err != nil { + t.Fatalf("Error starting endpoint server: %v", err) + } + defer ss.Stop() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + codec := &renameProtoCodec{Codec: encoding.GetCodec("proto"), name: "some-test-name"} + wantContentTypeCh <- []string{"application/grpc+some-test-name"} + if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}, grpc.ForceCodec(codec)); err != nil { + t.Fatalf("ss.Client.EmptyCall(_, _) = _, %v; want _, nil", err) + } + + // Confirm the name is converted to lowercase before transmitting. + codec.name = "aNoTHeRNaME" + wantContentTypeCh <- []string{"application/grpc+anothername"} + if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}, grpc.ForceCodec(codec)); err != nil { + t.Fatalf("ss.Client.EmptyCall(_, _) = _, %v; want _, nil", err) + } +} + +func (s) TestForceServerCodec(t *testing.T) { + ss := &stubserver.StubServer{ + EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) { + return &testpb.Empty{}, nil + }, + } + codec := &countingProtoCodec{} + if err := ss.Start([]grpc.ServerOption{grpc.ForceServerCodec(codec)}); err != nil { + t.Fatalf("Error starting endpoint server: %v", err) + } + defer ss.Stop() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}); err != nil { + t.Fatalf("ss.Client.EmptyCall(_, _) = _, %v; want _, nil", err) + } + + unmarshalCount := atomic.LoadInt32(&codec.unmarshalCount) + const wantUnmarshalCount = 1 + if unmarshalCount != wantUnmarshalCount { + t.Fatalf("protoCodec.unmarshalCount = %d; want %d", unmarshalCount, wantUnmarshalCount) + } + marshalCount := atomic.LoadInt32(&codec.marshalCount) + const wantMarshalCount = 1 + if marshalCount != wantMarshalCount { + t.Fatalf("protoCodec.marshalCount = %d; want %d", marshalCount, wantMarshalCount) + } +} + func (s) TestUnaryProxyDoesNotForwardMetadata(t *testing.T) { const mdkey = "somedata" @@ -5526,6 +5799,33 @@ func (c *errCodec) Name() string { return "Fermat's near-miss." } +type countingProtoCodec struct { + marshalCount int32 + unmarshalCount int32 +} + +func (p *countingProtoCodec) Marshal(v interface{}) ([]byte, error) { + atomic.AddInt32(&p.marshalCount, 1) + vv, ok := v.(proto.Message) + if !ok { + return nil, fmt.Errorf("failed to marshal, message is %T, want proto.Message", v) + } + return proto.Marshal(vv) +} + +func (p *countingProtoCodec) Unmarshal(data []byte, v interface{}) error { + atomic.AddInt32(&p.unmarshalCount, 1) + vv, ok := v.(proto.Message) + if !ok { + return fmt.Errorf("failed to unmarshal, message is %T, want proto.Message", v) + } + return proto.Unmarshal(data, vv) +} + +func (*countingProtoCodec) Name() string { + return "proto" +} + func (s) TestEncodeDoesntPanic(t *testing.T) { for _, e := range listTestEnv() { testEncodeDoesntPanic(t, e) @@ -6036,6 +6336,23 @@ func testServiceConfigMaxMsgSizeTD(t *testing.T, e env) { } } +// TestMalformedStreamMethod starts a test server and sends an RPC with a +// malformed method name. The server should respond with an UNIMPLEMENTED status +// code in this case. +func (s) TestMalformedStreamMethod(t *testing.T) { + const testMethod = "a-method-name-without-any-slashes" + te := newTest(t, tcpClearRREnv) + te.startServer(nil) + defer te.tearDown() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + err := te.clientConn().Invoke(ctx, testMethod, nil, nil) + if gotCode := status.Code(err); gotCode != codes.Unimplemented { + t.Fatalf("Invoke with method %q, got code %s, want %s", testMethod, gotCode, codes.Unimplemented) + } +} + func (s) TestMethodFromServerStream(t *testing.T) { const testMethod = "/package.service/method" e := tcpClearRREnv @@ -6244,7 +6561,7 @@ func (s) TestServeExitsWhenListenerClosed(t *testing.T) { close(done) }() - cc, err := grpc.Dial(lis.Addr().String(), grpc.WithInsecure(), grpc.WithBlock()) + cc, err := grpc.Dial(lis.Addr().String(), grpc.WithInsecure()) if err != nil { t.Fatalf("Failed to dial server: %v", err) } @@ -6457,15 +6774,13 @@ func (s) TestDisabledIOBuffers(t *testing.T) { t.Fatalf("Failed to create listener: %v", err) } - done := make(chan struct{}) go func() { s.Serve(lis) - close(done) }() defer s.Stop() dctx, dcancel := context.WithTimeout(context.Background(), 5*time.Second) defer dcancel() - cc, err := grpc.DialContext(dctx, lis.Addr().String(), grpc.WithInsecure(), grpc.WithBlock(), grpc.WithWriteBufferSize(0), grpc.WithReadBufferSize(0)) + cc, err := grpc.DialContext(dctx, lis.Addr().String(), grpc.WithInsecure(), grpc.WithWriteBufferSize(0), grpc.WithReadBufferSize(0)) if err != nil { t.Fatalf("Failed to dial server") } @@ -6738,6 +7053,10 @@ func (s) TestGoAwayThenClose(t *testing.T) { return &testpb.SimpleResponse{}, nil }, fullDuplexCall: func(stream testpb.TestService_FullDuplexCallServer) error { + if err := stream.Send(&testpb.StreamingOutputCallResponse{}); err != nil { + t.Errorf("unexpected error from send: %v", err) + return err + } // Wait forever. _, err := stream.Recv() if err == nil { @@ -6757,11 +7076,11 @@ func (s) TestGoAwayThenClose(t *testing.T) { s2 := grpc.NewServer() defer s2.Stop() testpb.RegisterTestServiceServer(s2, ts) - go s2.Serve(lis2) r := manual.NewBuilderWithScheme("whatever") r.InitialState(resolver.State{Addresses: []resolver.Address{ {Addr: lis1.Addr().String()}, + {Addr: lis2.Addr().String()}, }}) cc, err := grpc.DialContext(ctx, r.Scheme()+":///", grpc.WithResolvers(r), grpc.WithInsecure()) if err != nil { @@ -6771,30 +7090,49 @@ func (s) TestGoAwayThenClose(t *testing.T) { client := testpb.NewTestServiceClient(cc) - // Should go on connection 1. We use a long-lived RPC because it will cause GracefulStop to send GO_AWAY, but the - // connection doesn't get closed until the server stops and the client receives. + // We make a streaming RPC and do an one-message-round-trip to make sure + // it's created on connection 1. + // + // We use a long-lived RPC because it will cause GracefulStop to send + // GO_AWAY, but the connection doesn't get closed until the server stops and + // the client receives the error. stream, err := client.FullDuplexCall(ctx) if err != nil { t.Fatalf("FullDuplexCall(_) = _, %v; want _, nil", err) } + if _, err = stream.Recv(); err != nil { + t.Fatalf("unexpected error from first recv: %v", err) + } - r.UpdateState(resolver.State{Addresses: []resolver.Address{ - {Addr: lis1.Addr().String()}, - {Addr: lis2.Addr().String()}, - }}) + go s2.Serve(lis2) // Send GO_AWAY to connection 1. go s1.GracefulStop() - // Wait for connection 2 to be established. + // Wait for the ClientConn to enter IDLE state. + state := cc.GetState() + for ; state != connectivity.Idle && cc.WaitForStateChange(ctx, state); state = cc.GetState() { + } + if state != connectivity.Idle { + t.Fatalf("timed out waiting for IDLE channel state; last state = %v", state) + } + + // Initiate another RPC to create another connection. + if _, err := client.UnaryCall(ctx, &testpb.SimpleRequest{}); err != nil { + t.Fatalf("UnaryCall(_) = _, %v; want _, nil", err) + } + + // Assert that connection 2 has been established. <-conn2Established.Done() + // Close the listener for server2 to prevent it from allowing new connections. + lis2.Close() + // Close connection 1. s1.Stop() // Wait for client to close. - _, err = stream.Recv() - if err == nil { + if _, err = stream.Recv(); err == nil { t.Fatal("expected the stream to die, but got a successful Recv") } @@ -6831,7 +7169,6 @@ func (s) TestRPCWaitsForResolver(t *testing.T) { r := manual.NewBuilderWithScheme("whatever") te.resolverScheme = r.Scheme() - te.nonBlockingDial = true cc := te.clientConn(grpc.WithResolvers(r)) tc := testpb.NewTestServiceClient(cc) @@ -6918,7 +7255,7 @@ func (s) TestHTTPHeaderFrameErrorHandlingInitialHeader(t *testing.T) { ":status", "403", "content-type", "application/grpc", }, - errCode: codes.Unknown, + errCode: codes.PermissionDenied, }, { // malformed grpc-status. @@ -6937,7 +7274,7 @@ func (s) TestHTTPHeaderFrameErrorHandlingInitialHeader(t *testing.T) { "grpc-status", "0", "grpc-tags-bin", "???", }, - errCode: codes.Internal, + errCode: codes.Unavailable, }, { // gRPC status error. @@ -6946,14 +7283,14 @@ func (s) TestHTTPHeaderFrameErrorHandlingInitialHeader(t *testing.T) { "content-type", "application/grpc", "grpc-status", "3", }, - errCode: codes.InvalidArgument, + errCode: codes.Unavailable, }, } { doHTTPHeaderTest(t, test.errCode, test.header) } } -// Testing non-Trailers-only Trailers (delievered in second HEADERS frame) +// Testing non-Trailers-only Trailers (delivered in second HEADERS frame) func (s) TestHTTPHeaderFrameErrorHandlingNormalTrailer(t *testing.T) { for _, test := range []struct { responseHeader []string @@ -6969,7 +7306,7 @@ func (s) TestHTTPHeaderFrameErrorHandlingNormalTrailer(t *testing.T) { // trailer missing grpc-status ":status", "502", }, - errCode: codes.Unknown, + errCode: codes.Unavailable, }, { responseHeader: []string{ @@ -6981,6 +7318,18 @@ func (s) TestHTTPHeaderFrameErrorHandlingNormalTrailer(t *testing.T) { "grpc-status", "0", "grpc-status-details-bin", "????", }, + errCode: codes.Unimplemented, + }, + { + responseHeader: []string{ + ":status", "200", + "content-type", "application/grpc", + }, + trailer: []string{ + // malformed grpc-status-details-bin field + "grpc-status", "0", + "grpc-status-details-bin", "????", + }, errCode: codes.Internal, }, } { @@ -6996,8 +7345,18 @@ func (s) TestHTTPHeaderFrameErrorHandlingMoreThanTwoHeaders(t *testing.T) { doHTTPHeaderTest(t, codes.Internal, header, header, header) } +type httpServerResponse struct { + headers [][]string + payload []byte + trailers [][]string +} + type httpServer struct { - headerFields [][]string + // If waitForEndStream is set, wait for the client to send a frame with end + // stream in it before sending a response/refused stream. + waitForEndStream bool + refuseStream func(uint32) bool + responses []httpServerResponse } func (s *httpServer) writeHeader(framer *http2.Framer, sid uint32, headerFields []string, endStream bool) error { @@ -7021,6 +7380,10 @@ func (s *httpServer) writeHeader(framer *http2.Framer, sid uint32, headerFields }) } +func (s *httpServer) writePayload(framer *http2.Framer, sid uint32, payload []byte) error { + return framer.WriteData(sid, false, payload) +} + func (s *httpServer) start(t *testing.T, lis net.Listener) { // Launch an HTTP server to send back header. go func() { @@ -7045,24 +7408,66 @@ func (s *httpServer) start(t *testing.T, lis net.Listener) { writer.Flush() // necessary since client is expecting preface before declaring connection fully setup. var sid uint32 - // Read frames until a header is received. - for { - frame, err := framer.ReadFrame() - if err != nil { - t.Errorf("Error at server-side while reading frame. Err: %v", err) - return + // Loop until conn is closed and framer returns io.EOF + for requestNum := 0; ; requestNum = (requestNum + 1) % len(s.responses) { + // Read frames until a header is received. + for { + frame, err := framer.ReadFrame() + if err != nil { + if err != io.EOF { + t.Errorf("Error at server-side while reading frame. Err: %v", err) + } + return + } + sid = 0 + switch fr := frame.(type) { + case *http2.HeadersFrame: + // Respond after this if we are not waiting for an end + // stream or if this frame ends it. + if !s.waitForEndStream || fr.StreamEnded() { + sid = fr.Header().StreamID + } + + case *http2.DataFrame: + // Respond after this if we were waiting for an end stream + // and this frame ends it. (If we were not waiting for an + // end stream, this stream was already responded to when + // the headers were received.) + if s.waitForEndStream && fr.StreamEnded() { + sid = fr.Header().StreamID + } + } + if sid != 0 { + if s.refuseStream == nil || !s.refuseStream(sid) { + break + } + framer.WriteRSTStream(sid, http2.ErrCodeRefusedStream) + writer.Flush() + } } - if hframe, ok := frame.(*http2.HeadersFrame); ok { - sid = hframe.Header().StreamID - break + + response := s.responses[requestNum] + for _, header := range response.headers { + if err = s.writeHeader(framer, sid, header, false); err != nil { + t.Errorf("Error at server-side while writing headers. Err: %v", err) + return + } + writer.Flush() } - } - for i, headers := range s.headerFields { - if err = s.writeHeader(framer, sid, headers, i == len(s.headerFields)-1); err != nil { - t.Errorf("Error at server-side while writing headers. Err: %v", err) - return + if response.payload != nil { + if err = s.writePayload(framer, sid, response.payload); err != nil { + t.Errorf("Error at server-side while writing payload. Err: %v", err) + return + } + writer.Flush() + } + for i, trailer := range response.trailers { + if err = s.writeHeader(framer, sid, trailer, i == len(response.trailers)-1); err != nil { + t.Errorf("Error at server-side while writing trailers. Err: %v", err) + return + } + writer.Flush() } - writer.Flush() } }() } @@ -7075,7 +7480,7 @@ func doHTTPHeaderTest(t *testing.T, errCode codes.Code, headerFields ...[]string } defer lis.Close() server := &httpServer{ - headerFields: headerFields, + responses: []httpServerResponse{{trailers: headerFields}}, } server.start(t, lis) cc, err := grpc.Dial(lis.Addr().String(), grpc.WithInsecure()) @@ -7251,3 +7656,395 @@ func (s) TestCanceledRPCCallOptionRace(t *testing.T) { } wg.Wait() } + +func (s) TestClientSettingsFloodCloseConn(t *testing.T) { + // Tests that the server properly closes its transport if the client floods + // settings frames and then closes the connection. + + // Minimize buffer sizes to stimulate failure condition more quickly. + s := grpc.NewServer(grpc.WriteBufferSize(20)) + l := bufconn.Listen(20) + go s.Serve(l) + + // Dial our server and handshake. + conn, err := l.Dial() + if err != nil { + t.Fatalf("Error dialing bufconn: %v", err) + } + + n, err := conn.Write([]byte(http2.ClientPreface)) + if err != nil || n != len(http2.ClientPreface) { + t.Fatalf("Error writing client preface: %v, %v", n, err) + } + + fr := http2.NewFramer(conn, conn) + f, err := fr.ReadFrame() + if err != nil { + t.Fatalf("Error reading initial settings frame: %v", err) + } + if _, ok := f.(*http2.SettingsFrame); ok { + if err := fr.WriteSettingsAck(); err != nil { + t.Fatalf("Error writing settings ack: %v", err) + } + } else { + t.Fatalf("Error reading initial settings frame: type=%T", f) + } + + // Confirm settings can be written, and that an ack is read. + if err = fr.WriteSettings(); err != nil { + t.Fatalf("Error writing settings frame: %v", err) + } + if f, err = fr.ReadFrame(); err != nil { + t.Fatalf("Error reading frame: %v", err) + } + if sf, ok := f.(*http2.SettingsFrame); !ok || !sf.IsAck() { + t.Fatalf("Unexpected frame: %v", f) + } + + // Flood settings frames until a timeout occurs, indiciating the server has + // stopped reading from the connection, then close the conn. + for { + conn.SetWriteDeadline(time.Now().Add(50 * time.Millisecond)) + if err := fr.WriteSettings(); err != nil { + if to, ok := err.(interface{ Timeout() bool }); !ok || !to.Timeout() { + t.Fatalf("Received unexpected write error: %v", err) + } + break + } + } + conn.Close() + + // If the server does not handle this situation correctly, it will never + // close the transport. This is because its loopyWriter.run() will have + // exited, and thus not handle the goAway the draining process initiates. + // Also, we would see a goroutine leak in this case, as the reader would be + // blocked on the controlBuf's throttle() method indefinitely. + + timer := time.AfterFunc(5*time.Second, func() { + t.Errorf("Timeout waiting for GracefulStop to return") + s.Stop() + }) + s.GracefulStop() + timer.Stop() +} + +// TestDeadlineSetOnConnectionOnClientCredentialHandshake tests that there is a deadline +// set on the net.Conn when a credential handshake happens in http2_client. +func (s) TestDeadlineSetOnConnectionOnClientCredentialHandshake(t *testing.T) { + lis, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatalf("Failed to listen: %v", err) + } + connCh := make(chan net.Conn, 1) + go func() { + defer close(connCh) + conn, err := lis.Accept() + if err != nil { + t.Errorf("Error accepting connection: %v", err) + return + } + connCh <- conn + }() + defer func() { + conn := <-connCh + if conn != nil { + conn.Close() + } + }() + deadlineCh := testutils.NewChannel() + cvd := &credentialsVerifyDeadline{ + deadlineCh: deadlineCh, + } + dOpt := grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) { + conn, err := (&net.Dialer{}).DialContext(ctx, "tcp", addr) + if err != nil { + return nil, err + } + return &infoConn{Conn: conn}, nil + }) + cc, err := grpc.Dial(lis.Addr().String(), dOpt, grpc.WithTransportCredentials(cvd)) + if err != nil { + t.Fatalf("Failed to dial: %v", err) + } + defer cc.Close() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + deadline, err := deadlineCh.Receive(ctx) + if err != nil { + t.Fatalf("Error receiving from credsInvoked: %v", err) + } + // Default connection timeout is 20 seconds, so if the deadline exceeds now + // + 18 seconds it should be valid. + if !deadline.(time.Time).After(time.Now().Add(time.Second * 18)) { + t.Fatalf("Connection did not have deadline set.") + } +} + +type infoConn struct { + net.Conn + deadline time.Time +} + +func (c *infoConn) SetDeadline(t time.Time) error { + c.deadline = t + return c.Conn.SetDeadline(t) +} + +type credentialsVerifyDeadline struct { + deadlineCh *testutils.Channel +} + +func (cvd *credentialsVerifyDeadline) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { + return rawConn, nil, nil +} + +func (cvd *credentialsVerifyDeadline) ClientHandshake(ctx context.Context, authority string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { + cvd.deadlineCh.Send(rawConn.(*infoConn).deadline) + return rawConn, nil, nil +} + +func (cvd *credentialsVerifyDeadline) Info() credentials.ProtocolInfo { + return credentials.ProtocolInfo{} +} +func (cvd *credentialsVerifyDeadline) Clone() credentials.TransportCredentials { + return cvd +} +func (cvd *credentialsVerifyDeadline) OverrideServerName(s string) error { + return nil +} + +func unaryInterceptorVerifyConn(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { + conn := transport.GetConnection(ctx) + if conn == nil { + return nil, status.Error(codes.NotFound, "connection was not in context") + } + return nil, status.Error(codes.OK, "") +} + +// TestUnaryServerInterceptorGetsConnection tests whether the accepted conn on +// the server gets to any unary interceptors on the server side. +func (s) TestUnaryServerInterceptorGetsConnection(t *testing.T) { + ss := &stubserver.StubServer{} + if err := ss.Start([]grpc.ServerOption{grpc.UnaryInterceptor(unaryInterceptorVerifyConn)}); err != nil { + t.Fatalf("Error starting endpoint server: %v", err) + } + defer ss.Stop() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}); status.Code(err) != codes.OK { + t.Fatalf("ss.Client.EmptyCall(_, _) = _, %v, want _, error code %s", err, codes.OK) + } +} + +func streamingInterceptorVerifyConn(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { + conn := transport.GetConnection(ss.Context()) + if conn == nil { + return status.Error(codes.NotFound, "connection was not in context") + } + return status.Error(codes.OK, "") +} + +// TestStreamingServerInterceptorGetsConnection tests whether the accepted conn on +// the server gets to any streaming interceptors on the server side. +func (s) TestStreamingServerInterceptorGetsConnection(t *testing.T) { + ss := &stubserver.StubServer{} + if err := ss.Start([]grpc.ServerOption{grpc.StreamInterceptor(streamingInterceptorVerifyConn)}); err != nil { + t.Fatalf("Error starting endpoint server: %v", err) + } + defer ss.Stop() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + s, err := ss.Client.StreamingOutputCall(ctx, &testpb.StreamingOutputCallRequest{}) + if err != nil { + t.Fatalf("ss.Client.StreamingOutputCall(_) = _, %v, want _, ", err) + } + if _, err := s.Recv(); err != io.EOF { + t.Fatalf("ss.Client.StreamingInputCall(_) = _, %v, want _, %v", err, io.EOF) + } +} + +// unaryInterceptorVerifyAuthority verifies there is an unambiguous :authority +// once the request gets to an interceptor. An unambiguous :authority is defined +// as at most a single :authority header, and no host header according to A41. +func unaryInterceptorVerifyAuthority(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { + md, ok := metadata.FromIncomingContext(ctx) + if !ok { + return nil, status.Error(codes.NotFound, "metadata was not in context") + } + authority := md.Get(":authority") + if len(authority) > 1 { // Should be an unambiguous authority by the time it gets to interceptor. + return nil, status.Error(codes.NotFound, ":authority value had more than one value") + } + // Host header shouldn't be present by the time it gets to the interceptor + // level (should either be renamed to :authority or explicitly deleted). + host := md.Get("host") + if len(host) != 0 { + return nil, status.Error(codes.NotFound, "host header should not be present in metadata") + } + // Pass back the authority for verification on client - NotFound so + // grpc-message will be available to read for verification. + if len(authority) == 0 { + // Represent no :authority header present with an empty string. + return nil, status.Error(codes.NotFound, "") + } + return nil, status.Error(codes.NotFound, authority[0]) +} + +// TestAuthorityHeader tests that the eventual :authority that reaches the grpc +// layer is unambiguous due to logic added in A41. +func (s) TestAuthorityHeader(t *testing.T) { + tests := []struct { + name string + headers []string + wantAuthority string + }{ + // "If :authority is missing, Host must be renamed to :authority." - A41 + { + name: "Missing :authority", + // Codepath triggered by incoming headers with no authority but with + // a host. + headers: []string{ + ":method", "POST", + ":path", "/grpc.testing.TestService/UnaryCall", + "content-type", "application/grpc", + "te", "trailers", + "host", "localhost", + }, + wantAuthority: "localhost", + }, + { + name: "Missing :authority and host", + // Codepath triggered by incoming headers with no :authority and no + // host. + headers: []string{ + ":method", "POST", + ":path", "/grpc.testing.TestService/UnaryCall", + "content-type", "application/grpc", + "te", "trailers", + }, + wantAuthority: "", + }, + // "If :authority is present, Host must be discarded." - A41 + { + name: ":authority and host present", + // Codepath triggered by incoming headers with both an authority + // header and a host header. + headers: []string{ + ":method", "POST", + ":path", "/grpc.testing.TestService/UnaryCall", + ":authority", "localhost", + "content-type", "application/grpc", + "host", "localhost2", + }, + wantAuthority: "localhost", + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + te := newTest(t, tcpClearRREnv) + ts := &funcServer{unaryCall: func(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) { + return &testpb.SimpleResponse{}, nil + }} + te.unaryServerInt = unaryInterceptorVerifyAuthority + te.startServer(ts) + defer te.tearDown() + success := testutils.NewChannel() + te.withServerTester(func(st *serverTester) { + st.writeHeaders(http2.HeadersFrameParam{ + StreamID: 1, + BlockFragment: st.encodeHeader(test.headers...), + EndStream: false, + EndHeaders: true, + }) + st.writeData(1, true, []byte{0, 0, 0, 0, 0}) + + for { + frame := st.wantAnyFrame() + f, ok := frame.(*http2.MetaHeadersFrame) + if !ok { + continue + } + for _, header := range f.Fields { + if header.Name == "grpc-message" { + success.Send(header.Value) + return + } + } + } + }) + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + gotAuthority, err := success.Receive(ctx) + if err != nil { + t.Fatalf("Error receiving from channel: %v", err) + } + if gotAuthority != test.wantAuthority { + t.Fatalf("gotAuthority: %v, wantAuthority %v", gotAuthority, test.wantAuthority) + } + }) + } +} + +// wrapCloseListener tracks Accepts/Closes and maintains a counter of the +// number of open connections. +type wrapCloseListener struct { + net.Listener + connsOpen int32 +} + +// wrapCloseListener is returned by wrapCloseListener.Accept and decrements its +// connsOpen when Close is called. +type wrapCloseConn struct { + net.Conn + lis *wrapCloseListener + closeOnce sync.Once +} + +func (w *wrapCloseListener) Accept() (net.Conn, error) { + conn, err := w.Listener.Accept() + if err != nil { + return nil, err + } + atomic.AddInt32(&w.connsOpen, 1) + return &wrapCloseConn{Conn: conn, lis: w}, nil +} + +func (w *wrapCloseConn) Close() error { + defer w.closeOnce.Do(func() { atomic.AddInt32(&w.lis.connsOpen, -1) }) + return w.Conn.Close() +} + +// TestServerClosesConn ensures conn.Close is always closed even if the client +// doesn't complete the HTTP/2 handshake. +func (s) TestServerClosesConn(t *testing.T) { + lis := bufconn.Listen(20) + wrapLis := &wrapCloseListener{Listener: lis} + + s := grpc.NewServer() + go s.Serve(wrapLis) + defer s.Stop() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + for i := 0; i < 10; i++ { + conn, err := lis.DialContext(ctx) + if err != nil { + t.Fatalf("Dial = _, %v; want _, nil", err) + } + conn.Close() + } + for ctx.Err() == nil { + if atomic.LoadInt32(&wrapLis.connsOpen) == 0 { + return + } + time.Sleep(50 * time.Millisecond) + } + t.Fatalf("timed out waiting for conns to be closed by server; still open: %v", atomic.LoadInt32(&wrapLis.connsOpen)) +} diff --git a/test/go_vet/vet.go b/test/go_vet/vet.go deleted file mode 100644 index 475e8d683fc..00000000000 --- a/test/go_vet/vet.go +++ /dev/null @@ -1,53 +0,0 @@ -/* - * - * Copyright 2018 gRPC authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - */ - -// vet checks whether files that are supposed to be built on appengine running -// Go 1.10 or earlier import an unsupported package (e.g. "unsafe", "syscall"). -package main - -import ( - "fmt" - "go/build" - "os" -) - -func main() { - fail := false - b := build.Default - b.BuildTags = []string{"appengine", "appenginevm"} - argsWithoutProg := os.Args[1:] - for _, dir := range argsWithoutProg { - p, err := b.Import(".", dir, 0) - if _, ok := err.(*build.NoGoError); ok { - continue - } else if err != nil { - fmt.Printf("build.Import failed due to %v\n", err) - fail = true - continue - } - for _, pkg := range p.Imports { - if pkg == "syscall" || pkg == "unsafe" { - fmt.Printf("Package %s/%s importing %s package without appengine build tag is NOT ALLOWED!\n", p.Dir, p.Name, pkg) - fail = true - } - } - } - if fail { - os.Exit(1) - } -} diff --git a/test/grpc_testing/test_grpc.pb.go b/test/grpc_testing/test_grpc.pb.go index ab3b68a92bc..76b3935620c 100644 --- a/test/grpc_testing/test_grpc.pb.go +++ b/test/grpc_testing/test_grpc.pb.go @@ -1,4 +1,8 @@ // Code generated by protoc-gen-go-grpc. DO NOT EDIT. +// versions: +// - protoc-gen-go-grpc v1.1.0 +// - protoc v3.14.0 +// source: test/grpc_testing/test.proto package grpc_testing diff --git a/test/insecure_creds_test.go b/test/insecure_creds_test.go index 19f8bb8b791..791cf650887 100644 --- a/test/insecure_creds_test.go +++ b/test/insecure_creds_test.go @@ -124,21 +124,19 @@ func (s) TestInsecureCreds(t *testing.T) { go s.Serve(lis) addr := lis.Addr().String() - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() - cOpts := []grpc.DialOption{grpc.WithBlock()} + opts := []grpc.DialOption{grpc.WithInsecure()} if test.clientInsecureCreds { - cOpts = append(cOpts, grpc.WithTransportCredentials(insecure.NewCredentials())) - } else { - cOpts = append(cOpts, grpc.WithInsecure()) + opts = []grpc.DialOption{grpc.WithTransportCredentials(insecure.NewCredentials())} } - cc, err := grpc.DialContext(ctx, addr, cOpts...) + cc, err := grpc.Dial(addr, opts...) if err != nil { t.Fatalf("grpc.Dial(%q) failed: %v", addr, err) } defer cc.Close() c := testpb.NewTestServiceClient(cc) + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() if _, err = c.EmptyCall(ctx, &testpb.Empty{}); err != nil { t.Fatalf("EmptyCall(_, _) = _, %v; want _, ", err) } @@ -151,19 +149,16 @@ func (s) TestInsecureCredsWithPerRPCCredentials(t *testing.T) { desc string perRPCCredsViaDialOptions bool perRPCCredsViaCallOptions bool - wantErr string }{ { desc: "send PerRPCCredentials via DialOptions", perRPCCredsViaDialOptions: true, perRPCCredsViaCallOptions: false, - wantErr: "context deadline exceeded", }, { desc: "send PerRPCCredentials via CallOptions", perRPCCredsViaDialOptions: false, perRPCCredsViaCallOptions: true, - wantErr: "transport: cannot send secure credentials on an insecure connection", }, } for _, test := range tests { @@ -174,44 +169,38 @@ func (s) TestInsecureCredsWithPerRPCCredentials(t *testing.T) { }, } - sOpts := []grpc.ServerOption{} - sOpts = append(sOpts, grpc.Creds(insecure.NewCredentials())) - s := grpc.NewServer(sOpts...) + s := grpc.NewServer(grpc.Creds(insecure.NewCredentials())) defer s.Stop() - testpb.RegisterTestServiceServer(s, ss) lis, err := net.Listen("tcp", "localhost:0") if err != nil { t.Fatalf("net.Listen(tcp, localhost:0) failed: %v", err) } - go s.Serve(lis) addr := lis.Addr().String() ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() - cOpts := []grpc.DialOption{grpc.WithBlock()} - cOpts = append(cOpts, grpc.WithTransportCredentials(insecure.NewCredentials())) + dopts := []grpc.DialOption{grpc.WithTransportCredentials(insecure.NewCredentials())} if test.perRPCCredsViaDialOptions { - cOpts = append(cOpts, grpc.WithPerRPCCredentials(testLegacyPerRPCCredentials{})) - if _, err := grpc.DialContext(ctx, addr, cOpts...); !strings.Contains(err.Error(), test.wantErr) { - t.Fatalf("InsecureCredsWithPerRPCCredentials/send_PerRPCCredentials_via_DialOptions = %v; want %s", err, test.wantErr) - } + dopts = append(dopts, grpc.WithPerRPCCredentials(testLegacyPerRPCCredentials{})) } - + copts := []grpc.CallOption{} if test.perRPCCredsViaCallOptions { - cc, err := grpc.DialContext(ctx, addr, cOpts...) - if err != nil { - t.Fatalf("grpc.Dial(%q) failed: %v", addr, err) - } - defer cc.Close() - - c := testpb.NewTestServiceClient(cc) - if _, err = c.EmptyCall(ctx, &testpb.Empty{}, grpc.PerRPCCredentials(testLegacyPerRPCCredentials{})); !strings.Contains(err.Error(), test.wantErr) { - t.Fatalf("InsecureCredsWithPerRPCCredentials/send_PerRPCCredentials_via_CallOptions = %v; want %s", err, test.wantErr) - } + copts = append(copts, grpc.PerRPCCredentials(testLegacyPerRPCCredentials{})) + } + cc, err := grpc.Dial(addr, dopts...) + if err != nil { + t.Fatalf("grpc.Dial(%q) failed: %v", addr, err) + } + defer cc.Close() + + const wantErr = "transport: cannot send secure credentials on an insecure connection" + c := testpb.NewTestServiceClient(cc) + if _, err = c.EmptyCall(ctx, &testpb.Empty{}, copts...); err == nil || !strings.Contains(err.Error(), wantErr) { + t.Fatalf("InsecureCredsWithPerRPCCredentials/send_PerRPCCredentials_via_CallOptions = %v; want %s", err, wantErr) } }) } diff --git a/test/kokoro/xds.cfg b/test/kokoro/xds.cfg index d1a078217b8..a1e4ed0bb5e 100644 --- a/test/kokoro/xds.cfg +++ b/test/kokoro/xds.cfg @@ -2,7 +2,7 @@ # Location of the continuous shell script in repository. build_file: "grpc-go/test/kokoro/xds.sh" -timeout_mins: 120 +timeout_mins: 360 action { define_artifacts { regex: "**/*sponge_log.*" diff --git a/test/kokoro/xds.sh b/test/kokoro/xds.sh index 36a3f4563cf..7b7f48dba30 100755 --- a/test/kokoro/xds.sh +++ b/test/kokoro/xds.sh @@ -3,6 +3,12 @@ set -exu -o pipefail [[ -f /VERSION ]] && cat /VERSION + +echo "Remove the expired letsencrypt.org cert and update the CA certificates" +sudo apt-get install -y ca-certificates +sudo rm /usr/share/ca-certificates/mozilla/DST_Root_CA_X3.crt +sudo update-ca-certificates + cd github export GOPATH="${HOME}/gopath" @@ -27,7 +33,7 @@ grpc/tools/run_tests/helper_scripts/prep_xds.sh # they are added into "all". GRPC_GO_LOG_VERBOSITY_LEVEL=99 GRPC_GO_LOG_SEVERITY_LEVEL=info \ python3 grpc/tools/run_tests/run_xds_tests.py \ - --test_case="all,path_matching,header_matching,circuit_breaking,timeout" \ + --test_case="all,circuit_breaking,timeout,fault_injection,csds" \ --project_id=grpc-testing \ --project_num=830293263384 \ --source_image=projects/grpc-testing/global/images/xds-test-server-4 \ diff --git a/test/kokoro/xds_k8s.cfg b/test/kokoro/xds_k8s.cfg new file mode 100644 index 00000000000..4d5e019991f --- /dev/null +++ b/test/kokoro/xds_k8s.cfg @@ -0,0 +1,13 @@ +# Config file for internal CI + +# Location of the continuous shell script in repository. +build_file: "grpc-go/test/kokoro/xds_k8s.sh" +timeout_mins: 120 + +action { + define_artifacts { + regex: "artifacts/**/*sponge_log.xml" + regex: "artifacts/**/*sponge_log.log" + strip_prefix: "artifacts" + } +} diff --git a/test/kokoro/xds_k8s.sh b/test/kokoro/xds_k8s.sh new file mode 100755 index 00000000000..f91d1d026d6 --- /dev/null +++ b/test/kokoro/xds_k8s.sh @@ -0,0 +1,155 @@ +#!/usr/bin/env bash +# Copyright 2021 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +set -eo pipefail + +# Constants +readonly GITHUB_REPOSITORY_NAME="grpc-go" +readonly TEST_DRIVER_INSTALL_SCRIPT_URL="https://raw.githubusercontent.com/${TEST_DRIVER_REPO_OWNER:-grpc}/grpc/${TEST_DRIVER_BRANCH:-master}/tools/internal_ci/linux/grpc_xds_k8s_install_test_driver.sh" +## xDS test server/client Docker images +readonly SERVER_IMAGE_NAME="gcr.io/grpc-testing/xds-interop/go-server" +readonly CLIENT_IMAGE_NAME="gcr.io/grpc-testing/xds-interop/go-client" +readonly FORCE_IMAGE_BUILD="${FORCE_IMAGE_BUILD:-0}" + +####################################### +# Builds test app Docker images and pushes them to GCR +# Globals: +# SERVER_IMAGE_NAME: Test server Docker image name +# CLIENT_IMAGE_NAME: Test client Docker image name +# GIT_COMMIT: SHA-1 of git commit being built +# Arguments: +# None +# Outputs: +# Writes the output of `gcloud builds submit` to stdout, stderr +####################################### +build_test_app_docker_images() { + echo "Building Go xDS interop test app Docker images" + docker build -f "${SRC_DIR}/interop/xds/client/Dockerfile" -t "${CLIENT_IMAGE_NAME}:${GIT_COMMIT}" "${SRC_DIR}" + docker build -f "${SRC_DIR}/interop/xds/server/Dockerfile" -t "${SERVER_IMAGE_NAME}:${GIT_COMMIT}" "${SRC_DIR}" + gcloud -q auth configure-docker + docker push "${CLIENT_IMAGE_NAME}:${GIT_COMMIT}" + docker push "${SERVER_IMAGE_NAME}:${GIT_COMMIT}" + if [[ -n $KOKORO_JOB_NAME ]]; then + branch_name=$(echo "$KOKORO_JOB_NAME" | sed -E 's|^grpc/go/([^/]+)/.*|\1|') + tag_and_push_docker_image "${CLIENT_IMAGE_NAME}" "${GIT_COMMIT}" "${branch_name}" + tag_and_push_docker_image "${SERVER_IMAGE_NAME}" "${GIT_COMMIT}" "${branch_name}" + fi +} + +####################################### +# Builds test app and its docker images unless they already exist +# Globals: +# SERVER_IMAGE_NAME: Test server Docker image name +# CLIENT_IMAGE_NAME: Test client Docker image name +# GIT_COMMIT: SHA-1 of git commit being built +# FORCE_IMAGE_BUILD +# Arguments: +# None +# Outputs: +# Writes the output to stdout, stderr +####################################### +build_docker_images_if_needed() { + # Check if images already exist + server_tags="$(gcloud_gcr_list_image_tags "${SERVER_IMAGE_NAME}" "${GIT_COMMIT}")" + printf "Server image: %s:%s\n" "${SERVER_IMAGE_NAME}" "${GIT_COMMIT}" + echo "${server_tags:-Server image not found}" + + client_tags="$(gcloud_gcr_list_image_tags "${CLIENT_IMAGE_NAME}" "${GIT_COMMIT}")" + printf "Client image: %s:%s\n" "${CLIENT_IMAGE_NAME}" "${GIT_COMMIT}" + echo "${client_tags:-Client image not found}" + + # Build if any of the images are missing, or FORCE_IMAGE_BUILD=1 + if [[ "${FORCE_IMAGE_BUILD}" == "1" || -z "${server_tags}" || -z "${client_tags}" ]]; then + build_test_app_docker_images + else + echo "Skipping Go test app build" + fi +} + +####################################### +# Executes the test case +# Globals: +# TEST_DRIVER_FLAGFILE: Relative path to test driver flagfile +# KUBE_CONTEXT: The name of kubectl context with GKE cluster access +# TEST_XML_OUTPUT_DIR: Output directory for the test xUnit XML report +# SERVER_IMAGE_NAME: Test server Docker image name +# CLIENT_IMAGE_NAME: Test client Docker image name +# GIT_COMMIT: SHA-1 of git commit being built +# Arguments: +# Test case name +# Outputs: +# Writes the output of test execution to stdout, stderr +# Test xUnit report to ${TEST_XML_OUTPUT_DIR}/${test_name}/sponge_log.xml +####################################### +run_test() { + # Test driver usage: + # https://github.com/grpc/grpc/tree/master/tools/run_tests/xds_k8s_test_driver#basic-usage + local test_name="${1:?Usage: run_test test_name}" + set -x + python -m "tests.${test_name}" \ + --flagfile="${TEST_DRIVER_FLAGFILE}" \ + --kube_context="${KUBE_CONTEXT}" \ + --server_image="${SERVER_IMAGE_NAME}:${GIT_COMMIT}" \ + --client_image="${CLIENT_IMAGE_NAME}:${GIT_COMMIT}" \ + --xml_output_file="${TEST_XML_OUTPUT_DIR}/${test_name}/sponge_log.xml" \ + --force_cleanup \ + --nocheck_local_certs + set +x +} + +####################################### +# Main function: provision software necessary to execute tests, and run them +# Globals: +# KOKORO_ARTIFACTS_DIR +# GITHUB_REPOSITORY_NAME +# SRC_DIR: Populated with absolute path to the source repo +# TEST_DRIVER_REPO_DIR: Populated with the path to the repo containing +# the test driver +# TEST_DRIVER_FULL_DIR: Populated with the path to the test driver source code +# TEST_DRIVER_FLAGFILE: Populated with relative path to test driver flagfile +# TEST_XML_OUTPUT_DIR: Populated with the path to test xUnit XML report +# GIT_ORIGIN_URL: Populated with the origin URL of git repo used for the build +# GIT_COMMIT: Populated with the SHA-1 of git commit being built +# GIT_COMMIT_SHORT: Populated with the short SHA-1 of git commit being built +# KUBE_CONTEXT: Populated with name of kubectl context with GKE cluster access +# Arguments: +# None +# Outputs: +# Writes the output of test execution to stdout, stderr +####################################### +main() { + local script_dir + script_dir="$(dirname "$0")" + + # Source the test driver from the master branch. + echo "Sourcing test driver install script from: ${TEST_DRIVER_INSTALL_SCRIPT_URL}" + source /dev/stdin <<< "$(curl -s "${TEST_DRIVER_INSTALL_SCRIPT_URL}")" + + activate_gke_cluster GKE_CLUSTER_PSM_SECURITY + + set -x + if [[ -n "${KOKORO_ARTIFACTS_DIR}" ]]; then + kokoro_setup_test_driver "${GITHUB_REPOSITORY_NAME}" + else + local_setup_test_driver "${script_dir}" + fi + build_docker_images_if_needed + # Run tests + cd "${TEST_DRIVER_FULL_DIR}" + run_test baseline_test + run_test security_test +} + +main "$@" diff --git a/test/kokoro/xds_url_map.cfg b/test/kokoro/xds_url_map.cfg new file mode 100644 index 00000000000..f6fd84a419a --- /dev/null +++ b/test/kokoro/xds_url_map.cfg @@ -0,0 +1,13 @@ +# Config file for internal CI + +# Location of the continuous shell script in repository. +build_file: "grpc-go/test/kokoro/xds_url_map.sh" +timeout_mins: 60 + +action { + define_artifacts { + regex: "artifacts/**/*sponge_log.xml" + regex: "artifacts/**/*sponge_log.log" + strip_prefix: "artifacts" + } +} diff --git a/test/kokoro/xds_url_map.sh b/test/kokoro/xds_url_map.sh new file mode 100755 index 00000000000..34805d43a13 --- /dev/null +++ b/test/kokoro/xds_url_map.sh @@ -0,0 +1,138 @@ +#!/usr/bin/env bash +# Copyright 2021 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +set -eo pipefail + +# Constants +readonly GITHUB_REPOSITORY_NAME="grpc-go" +readonly TEST_DRIVER_INSTALL_SCRIPT_URL="https://raw.githubusercontent.com/${TEST_DRIVER_REPO_OWNER:-grpc}/grpc/${TEST_DRIVER_BRANCH:-master}/tools/internal_ci/linux/grpc_xds_k8s_install_test_driver.sh" +## xDS test client Docker images +readonly CLIENT_IMAGE_NAME="gcr.io/grpc-testing/xds-interop/go-client" +readonly FORCE_IMAGE_BUILD="${FORCE_IMAGE_BUILD:-0}" + +####################################### +# Builds test app Docker images and pushes them to GCR +# Globals: +# CLIENT_IMAGE_NAME: Test client Docker image name +# GIT_COMMIT: SHA-1 of git commit being built +# Arguments: +# None +# Outputs: +# Writes the output of `gcloud builds submit` to stdout, stderr +####################################### +build_test_app_docker_images() { + echo "Building Go xDS interop test app Docker images" + docker build -f "${SRC_DIR}/interop/xds/client/Dockerfile" -t "${CLIENT_IMAGE_NAME}:${GIT_COMMIT}" "${SRC_DIR}" + gcloud -q auth configure-docker + docker push "${CLIENT_IMAGE_NAME}:${GIT_COMMIT}" +} + +####################################### +# Builds test app and its docker images unless they already exist +# Globals: +# CLIENT_IMAGE_NAME: Test client Docker image name +# GIT_COMMIT: SHA-1 of git commit being built +# FORCE_IMAGE_BUILD +# Arguments: +# None +# Outputs: +# Writes the output to stdout, stderr +####################################### +build_docker_images_if_needed() { + # Check if images already exist + client_tags="$(gcloud_gcr_list_image_tags "${CLIENT_IMAGE_NAME}" "${GIT_COMMIT}")" + printf "Client image: %s:%s\n" "${CLIENT_IMAGE_NAME}" "${GIT_COMMIT}" + echo "${client_tags:-Client image not found}" + + # Build if any of the images are missing, or FORCE_IMAGE_BUILD=1 + if [[ "${FORCE_IMAGE_BUILD}" == "1" || -z "${client_tags}" ]]; then + build_test_app_docker_images + else + echo "Skipping Go test app build" + fi +} + +####################################### +# Executes the test case +# Globals: +# TEST_DRIVER_FLAGFILE: Relative path to test driver flagfile +# KUBE_CONTEXT: The name of kubectl context with GKE cluster access +# TEST_XML_OUTPUT_DIR: Output directory for the test xUnit XML report +# CLIENT_IMAGE_NAME: Test client Docker image name +# GIT_COMMIT: SHA-1 of git commit being built +# Arguments: +# Test case name +# Outputs: +# Writes the output of test execution to stdout, stderr +# Test xUnit report to ${TEST_XML_OUTPUT_DIR}/${test_name}/sponge_log.xml +####################################### +run_test() { + # Test driver usage: + # https://github.com/grpc/grpc/tree/master/tools/run_tests/xds_k8s_test_driver#basic-usage + local test_name="${1:?Usage: run_test test_name}" + set -x + python -m "tests.${test_name}" \ + --flagfile="${TEST_DRIVER_FLAGFILE}" \ + --kube_context="${KUBE_CONTEXT}" \ + --client_image="${CLIENT_IMAGE_NAME}:${GIT_COMMIT}" \ + --testing_version=$(echo "$KOKORO_JOB_NAME" | sed -E 's|^grpc/go/([^/]+)/.*|\1|') \ + --xml_output_file="${TEST_XML_OUTPUT_DIR}/${test_name}/sponge_log.xml" \ + --flagfile="config/url-map.cfg" + set +x +} + +####################################### +# Main function: provision software necessary to execute tests, and run them +# Globals: +# KOKORO_ARTIFACTS_DIR +# GITHUB_REPOSITORY_NAME +# SRC_DIR: Populated with absolute path to the source repo +# TEST_DRIVER_REPO_DIR: Populated with the path to the repo containing +# the test driver +# TEST_DRIVER_FULL_DIR: Populated with the path to the test driver source code +# TEST_DRIVER_FLAGFILE: Populated with relative path to test driver flagfile +# TEST_XML_OUTPUT_DIR: Populated with the path to test xUnit XML report +# GIT_ORIGIN_URL: Populated with the origin URL of git repo used for the build +# GIT_COMMIT: Populated with the SHA-1 of git commit being built +# GIT_COMMIT_SHORT: Populated with the short SHA-1 of git commit being built +# KUBE_CONTEXT: Populated with name of kubectl context with GKE cluster access +# Arguments: +# None +# Outputs: +# Writes the output of test execution to stdout, stderr +####################################### +main() { + local script_dir + script_dir="$(dirname "$0")" + + # Source the test driver from the master branch. + echo "Sourcing test driver install script from: ${TEST_DRIVER_INSTALL_SCRIPT_URL}" + source /dev/stdin <<< "$(curl -s "${TEST_DRIVER_INSTALL_SCRIPT_URL}")" + + activate_gke_cluster GKE_CLUSTER_PSM_BASIC + + set -x + if [[ -n "${KOKORO_ARTIFACTS_DIR}" ]]; then + kokoro_setup_test_driver "${GITHUB_REPOSITORY_NAME}" + else + local_setup_test_driver "${script_dir}" + fi + build_docker_images_if_needed + # Run tests + cd "${TEST_DRIVER_FULL_DIR}" + run_test url_map +} + +main "$@" diff --git a/test/kokoro/xds_v3.cfg b/test/kokoro/xds_v3.cfg index c4c8aad9e6f..1991efd325d 100644 --- a/test/kokoro/xds_v3.cfg +++ b/test/kokoro/xds_v3.cfg @@ -2,7 +2,7 @@ # Location of the continuous shell script in repository. build_file: "grpc-go/test/kokoro/xds_v3.sh" -timeout_mins: 120 +timeout_mins: 360 action { define_artifacts { regex: "**/*sponge_log.*" diff --git a/test/race.go b/test/race.go index acfa0dfae37..d99f0a410ac 100644 --- a/test/race.go +++ b/test/race.go @@ -1,3 +1,4 @@ +//go:build race // +build race /* diff --git a/test/retry_test.go b/test/retry_test.go index f93c9ac053f..7f068d79f44 100644 --- a/test/retry_test.go +++ b/test/retry_test.go @@ -22,9 +22,11 @@ import ( "context" "fmt" "io" - "os" + "net" + "reflect" "strconv" "strings" + "sync" "testing" "time" @@ -34,6 +36,7 @@ import ( "google.golang.org/grpc/internal/envconfig" "google.golang.org/grpc/internal/stubserver" "google.golang.org/grpc/metadata" + "google.golang.org/grpc/stats" "google.golang.org/grpc/status" testpb "google.golang.org/grpc/test/grpc_testing" ) @@ -112,68 +115,6 @@ func (s) TestRetryUnary(t *testing.T) { } } -func (s) TestRetryDisabledByDefault(t *testing.T) { - if strings.EqualFold(os.Getenv("GRPC_GO_RETRY"), "on") { - return - } - i := -1 - ss := &stubserver.StubServer{ - EmptyCallF: func(context.Context, *testpb.Empty) (*testpb.Empty, error) { - i++ - switch i { - case 0: - return nil, status.New(codes.AlreadyExists, "retryable error").Err() - } - return &testpb.Empty{}, nil - }, - } - if err := ss.Start([]grpc.ServerOption{}); err != nil { - t.Fatalf("Error starting endpoint server: %v", err) - } - defer ss.Stop() - ss.NewServiceConfig(`{ - "methodConfig": [{ - "name": [{"service": "grpc.testing.TestService"}], - "waitForReady": true, - "retryPolicy": { - "MaxAttempts": 4, - "InitialBackoff": ".01s", - "MaxBackoff": ".01s", - "BackoffMultiplier": 1.0, - "RetryableStatusCodes": [ "ALREADY_EXISTS" ] - } - }]}`) - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) - for { - if ctx.Err() != nil { - t.Fatalf("Timed out waiting for service config update") - } - if ss.CC.GetMethodConfig("/grpc.testing.TestService/EmptyCall").WaitForReady != nil { - break - } - time.Sleep(time.Millisecond) - } - cancel() - - testCases := []struct { - code codes.Code - count int - }{ - {codes.AlreadyExists, 0}, - } - for _, tc := range testCases { - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}) - cancel() - if status.Code(err) != tc.code { - t.Fatalf("EmptyCall(_, _) = _, %v; want _, ", err, tc.code) - } - if i != tc.count { - t.Fatalf("i = %v; want %v", i, tc.count) - } - } -} - func (s) TestRetryThrottling(t *testing.T) { defer enableRetry()() i := -1 @@ -549,3 +490,167 @@ func (s) TestRetryStreaming(t *testing.T) { }() } } + +type retryStatsHandler struct { + mu sync.Mutex + s []stats.RPCStats +} + +func (*retryStatsHandler) TagRPC(ctx context.Context, _ *stats.RPCTagInfo) context.Context { + return ctx +} +func (h *retryStatsHandler) HandleRPC(_ context.Context, s stats.RPCStats) { + h.mu.Lock() + h.s = append(h.s, s) + h.mu.Unlock() +} +func (*retryStatsHandler) TagConn(ctx context.Context, _ *stats.ConnTagInfo) context.Context { + return ctx +} +func (*retryStatsHandler) HandleConn(context.Context, stats.ConnStats) {} + +func (s) TestRetryStats(t *testing.T) { + defer enableRetry()() + lis, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatalf("Failed to listen. Err: %v", err) + } + defer lis.Close() + server := &httpServer{ + waitForEndStream: true, + responses: []httpServerResponse{{ + trailers: [][]string{{ + ":status", "200", + "content-type", "application/grpc", + "grpc-status", "14", // UNAVAILABLE + "grpc-message", "unavailable retry", + "grpc-retry-pushback-ms", "10", + }}, + }, { + headers: [][]string{{ + ":status", "200", + "content-type", "application/grpc", + }}, + payload: []byte{0, 0, 0, 0, 0}, // header for 0-byte response message. + trailers: [][]string{{ + "grpc-status", "0", // OK + }}, + }}, + refuseStream: func(i uint32) bool { + return i == 1 + }, + } + server.start(t, lis) + handler := &retryStatsHandler{} + cc, err := grpc.Dial(lis.Addr().String(), grpc.WithInsecure(), grpc.WithStatsHandler(handler), + grpc.WithDefaultServiceConfig((`{ + "methodConfig": [{ + "name": [{"service": "grpc.testing.TestService"}], + "retryPolicy": { + "MaxAttempts": 4, + "InitialBackoff": ".01s", + "MaxBackoff": ".01s", + "BackoffMultiplier": 1.0, + "RetryableStatusCodes": [ "UNAVAILABLE" ] + } + }]}`))) + if err != nil { + t.Fatalf("failed to dial due to err: %v", err) + } + defer cc.Close() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + client := testpb.NewTestServiceClient(cc) + + if _, err := client.EmptyCall(ctx, &testpb.Empty{}); err != nil { + t.Fatalf("unexpected EmptyCall error: %v", err) + } + handler.mu.Lock() + want := []stats.RPCStats{ + &stats.Begin{}, + &stats.OutHeader{FullMethod: "/grpc.testing.TestService/EmptyCall"}, + &stats.OutPayload{WireLength: 5}, + &stats.End{}, + + &stats.Begin{IsTransparentRetryAttempt: true}, + &stats.OutHeader{FullMethod: "/grpc.testing.TestService/EmptyCall"}, + &stats.OutPayload{WireLength: 5}, + &stats.InTrailer{Trailer: metadata.Pairs("content-type", "application/grpc", "grpc-retry-pushback-ms", "10")}, + &stats.End{}, + + &stats.Begin{}, + &stats.OutHeader{FullMethod: "/grpc.testing.TestService/EmptyCall"}, + &stats.OutPayload{WireLength: 5}, + &stats.InHeader{}, + &stats.InPayload{WireLength: 5}, + &stats.InTrailer{}, + &stats.End{}, + } + + toString := func(ss []stats.RPCStats) (ret []string) { + for _, s := range ss { + ret = append(ret, fmt.Sprintf("%T - %v", s, s)) + } + return ret + } + t.Logf("Handler received frames:\n%v\n---\nwant:\n%v\n", + strings.Join(toString(handler.s), "\n"), + strings.Join(toString(want), "\n")) + + if len(handler.s) != len(want) { + t.Fatalf("received unexpected number of RPCStats: got %v; want %v", len(handler.s), len(want)) + } + + // There is a race between receiving the payload (triggered by the + // application / gRPC library) and receiving the trailer (triggered at the + // transport layer). Adjust the received stats accordingly if necessary. + const tIdx, pIdx = 13, 14 + _, okT := handler.s[tIdx].(*stats.InTrailer) + _, okP := handler.s[pIdx].(*stats.InPayload) + if okT && okP { + handler.s[pIdx], handler.s[tIdx] = handler.s[tIdx], handler.s[pIdx] + } + + for i := range handler.s { + w, s := want[i], handler.s[i] + + // Validate the event type + if reflect.TypeOf(w) != reflect.TypeOf(s) { + t.Fatalf("at position %v: got %T; want %T", i, s, w) + } + wv, sv := reflect.ValueOf(w).Elem(), reflect.ValueOf(s).Elem() + + // Validate that Client is always true + if sv.FieldByName("Client").Interface().(bool) != true { + t.Fatalf("at position %v: got Client=false; want true", i) + } + + // Validate any set fields in want + for i := 0; i < wv.NumField(); i++ { + if !wv.Field(i).IsZero() { + if got, want := sv.Field(i).Interface(), wv.Field(i).Interface(); !reflect.DeepEqual(got, want) { + name := reflect.TypeOf(w).Elem().Field(i).Name + t.Fatalf("at position %v, field %v: got %v; want %v", i, name, got, want) + } + } + } + + // Since the above only tests non-zero-value fields, test + // IsTransparentRetryAttempt=false explicitly when needed. + if wb, ok := w.(*stats.Begin); ok && !wb.IsTransparentRetryAttempt { + if s.(*stats.Begin).IsTransparentRetryAttempt { + t.Fatalf("at position %v: got IsTransparentRetryAttempt=true; want false", i) + } + } + } + + // Validate timings between last Begin and preceding End. + end := handler.s[8].(*stats.End) + begin := handler.s[9].(*stats.Begin) + diff := begin.BeginTime.Sub(end.EndTime) + if diff < 10*time.Millisecond || diff > 50*time.Millisecond { + t.Fatalf("pushback time before final attempt = %v; want ~10ms", diff) + } +} diff --git a/test/tools/go.mod b/test/tools/go.mod index 874268d34fc..9c964971413 100644 --- a/test/tools/go.mod +++ b/test/tools/go.mod @@ -1,6 +1,6 @@ module google.golang.org/grpc/test/tools -go 1.11 +go 1.14 require ( github.com/client9/misspell v0.3.4 diff --git a/test/tools/tools.go b/test/tools/tools.go index 511dc253446..646a144ccca 100644 --- a/test/tools/tools.go +++ b/test/tools/tools.go @@ -1,3 +1,4 @@ +//go:build tools // +build tools /* @@ -18,10 +19,9 @@ * */ -// This package exists to cause `go mod` and `go get` to believe these tools -// are dependencies, even though they are not runtime dependencies of any grpc -// package. This means they will appear in our `go.mod` file, but will not be -// a part of the build. +// This file is not intended to be compiled. Because some of these imports are +// not actual go packages, we use a build constraint at the top of this file to +// prevent tools from inspecting the imports. package tools diff --git a/xds/go113.go b/test/tools/tools_vet.go similarity index 75% rename from xds/go113.go rename to test/tools/tools_vet.go index 40f82cde5c1..06ab2fd10be 100644 --- a/xds/go113.go +++ b/test/tools/tools_vet.go @@ -1,8 +1,6 @@ -// +build go1.13 - /* * - * Copyright 2020 gRPC authors. + * Copyright 2021 gRPC authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,8 +16,6 @@ * */ -package xds - -import ( - _ "google.golang.org/grpc/credentials/tls/certprovider/meshca" // Register the MeshCA certificate provider plugin. -) +// Package tools is used to pin specific versions of external tools in this +// module's go.mod that gRPC uses for internal testing. +package tools diff --git a/version.go b/version.go index c149b22ad8a..6ba1fd5bdb7 100644 --- a/version.go +++ b/version.go @@ -19,4 +19,4 @@ package grpc // Version is the current grpc version. -const Version = "1.37.0-dev" +const Version = "1.42.0-dev" diff --git a/vet.sh b/vet.sh index dcd939bb390..d923187a7b3 100755 --- a/vet.sh +++ b/vet.sh @@ -32,26 +32,14 @@ PATH="${HOME}/go/bin:${GOROOT}/bin:${PATH}" go version if [[ "$1" = "-install" ]]; then - # Check for module support - if go help mod >& /dev/null; then - # Install the pinned versions as defined in module tools. - pushd ./test/tools - go install \ - golang.org/x/lint/golint \ - golang.org/x/tools/cmd/goimports \ - honnef.co/go/tools/cmd/staticcheck \ - github.com/client9/misspell/cmd/misspell - popd - else - # Ye olde `go get` incantation. - # Note: this gets the latest version of all tools (vs. the pinned versions - # with Go modules). - go get -u \ - golang.org/x/lint/golint \ - golang.org/x/tools/cmd/goimports \ - honnef.co/go/tools/cmd/staticcheck \ - github.com/client9/misspell/cmd/misspell - fi + # Install the pinned versions as defined in module tools. + pushd ./test/tools + go install \ + golang.org/x/lint/golint \ + golang.org/x/tools/cmd/goimports \ + honnef.co/go/tools/cmd/staticcheck \ + github.com/client9/misspell/cmd/misspell + popd if [[ -z "${VET_SKIP_PROTO}" ]]; then if [[ "${TRAVIS}" = "true" ]]; then PROTOBUF_VERSION=3.14.0 @@ -101,16 +89,6 @@ not git grep "\(import \|^\s*\)\"github.com/golang/protobuf/ptypes/" -- "*.go" # - Ensure all xds proto imports are renamed to *pb or *grpc. git grep '"github.com/envoyproxy/go-control-plane/envoy' -- '*.go' ':(exclude)*.pb.go' | not grep -v 'pb "\|grpc "' -# - Check imports that are illegal in appengine (until Go 1.11). -# TODO: Remove when we drop Go 1.10 support -go list -f {{.Dir}} ./... | xargs go run test/go_vet/vet.go - -# - gofmt, goimports, golint (with exceptions for generated code), go vet. -gofmt -s -d -l . 2>&1 | fail_on_output -goimports -l . 2>&1 | not grep -vE "\.pb\.go" -golint ./... 2>&1 | not grep -vE "/testv3\.pb\.go:" -go vet -all ./... - misspell -error . # - Check that generated proto files are up to date. @@ -120,12 +98,22 @@ if [[ -z "${VET_SKIP_PROTO}" ]]; then (git status; git --no-pager diff; exit 1) fi -# - Check that our modules are tidy. -if go help mod >& /dev/null; then - find . -name 'go.mod' | xargs -IXXX bash -c 'cd $(dirname XXX); go mod tidy' +# - gofmt, goimports, golint (with exceptions for generated code), go vet, +# go mod tidy. +# Perform these checks on each module inside gRPC. +for MOD_FILE in $(find . -name 'go.mod'); do + MOD_DIR=$(dirname ${MOD_FILE}) + pushd ${MOD_DIR} + go vet -all ./... | fail_on_output + gofmt -s -d -l . 2>&1 | fail_on_output + goimports -l . 2>&1 | not grep -vE "\.pb\.go" + golint ./... 2>&1 | not grep -vE "/testv3\.pb\.go:" + + go mod tidy git status --porcelain 2>&1 | fail_on_output || \ (git status; git --no-pager diff; exit 1) -fi + popd +done # - Collection of static analysis checks # diff --git a/xds/csds/csds.go b/xds/csds/csds.go new file mode 100644 index 00000000000..c4477a55d1a --- /dev/null +++ b/xds/csds/csds.go @@ -0,0 +1,305 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// Package csds implements features to dump the status (xDS responses) the +// xds_client is using. +// +// Notice: This package is EXPERIMENTAL and may be changed or removed in a later +// release. +package csds + +import ( + "context" + "io" + "time" + + v3adminpb "github.com/envoyproxy/go-control-plane/envoy/admin/v3" + v2corepb "github.com/envoyproxy/go-control-plane/envoy/api/v2/core" + v3corepb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" + v3statusgrpc "github.com/envoyproxy/go-control-plane/envoy/service/status/v3" + v3statuspb "github.com/envoyproxy/go-control-plane/envoy/service/status/v3" + "github.com/golang/protobuf/proto" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/grpclog" + "google.golang.org/grpc/status" + "google.golang.org/grpc/xds/internal/xdsclient" + "google.golang.org/protobuf/types/known/timestamppb" + + _ "google.golang.org/grpc/xds/internal/xdsclient/v2" // Register v2 xds_client. + _ "google.golang.org/grpc/xds/internal/xdsclient/v3" // Register v3 xds_client. +) + +var ( + logger = grpclog.Component("xds") + newXDSClient = func() xdsclient.XDSClient { + c, err := xdsclient.New() + if err != nil { + logger.Warningf("failed to create xds client: %v", err) + return nil + } + return c + } +) + +// ClientStatusDiscoveryServer implementations interface ClientStatusDiscoveryServiceServer. +type ClientStatusDiscoveryServer struct { + // xdsClient will always be the same in practice. But we keep a copy in each + // server instance for testing. + xdsClient xdsclient.XDSClient +} + +// NewClientStatusDiscoveryServer returns an implementation of the CSDS server that can be +// registered on a gRPC server. +func NewClientStatusDiscoveryServer() (*ClientStatusDiscoveryServer, error) { + return &ClientStatusDiscoveryServer{xdsClient: newXDSClient()}, nil +} + +// StreamClientStatus implementations interface ClientStatusDiscoveryServiceServer. +func (s *ClientStatusDiscoveryServer) StreamClientStatus(stream v3statusgrpc.ClientStatusDiscoveryService_StreamClientStatusServer) error { + for { + req, err := stream.Recv() + if err == io.EOF { + return nil + } + if err != nil { + return err + } + resp, err := s.buildClientStatusRespForReq(req) + if err != nil { + return err + } + if err := stream.Send(resp); err != nil { + return err + } + } +} + +// FetchClientStatus implementations interface ClientStatusDiscoveryServiceServer. +func (s *ClientStatusDiscoveryServer) FetchClientStatus(_ context.Context, req *v3statuspb.ClientStatusRequest) (*v3statuspb.ClientStatusResponse, error) { + return s.buildClientStatusRespForReq(req) +} + +// buildClientStatusRespForReq fetches the status from the client, and returns +// the response to be sent back to xdsclient. +// +// If it returns an error, the error is a status error. +func (s *ClientStatusDiscoveryServer) buildClientStatusRespForReq(req *v3statuspb.ClientStatusRequest) (*v3statuspb.ClientStatusResponse, error) { + if s.xdsClient == nil { + return &v3statuspb.ClientStatusResponse{}, nil + } + // Field NodeMatchers is unsupported, by design + // https://github.com/grpc/proposal/blob/master/A40-csds-support.md#detail-node-matching. + if len(req.NodeMatchers) != 0 { + return nil, status.Errorf(codes.InvalidArgument, "node_matchers are not supported, request contains node_matchers: %v", req.NodeMatchers) + } + + ret := &v3statuspb.ClientStatusResponse{ + Config: []*v3statuspb.ClientConfig{ + { + Node: nodeProtoToV3(s.xdsClient.BootstrapConfig().NodeProto), + XdsConfig: []*v3statuspb.PerXdsConfig{ + s.buildLDSPerXDSConfig(), + s.buildRDSPerXDSConfig(), + s.buildCDSPerXDSConfig(), + s.buildEDSPerXDSConfig(), + }, + }, + }, + } + return ret, nil +} + +// Close cleans up the resources. +func (s *ClientStatusDiscoveryServer) Close() { + if s.xdsClient != nil { + s.xdsClient.Close() + } +} + +// nodeProtoToV3 converts the given proto into a v3.Node. n is from bootstrap +// config, it can be either v2.Node or v3.Node. +// +// If n is already a v3.Node, return it. +// If n is v2.Node, marshal and unmarshal it to v3. +// Otherwise, return nil. +// +// The default case (not v2 or v3) is nil, instead of error, because the +// resources in the response are more important than the node. The worst case is +// that the user will receive no Node info, but will still get resources. +func nodeProtoToV3(n proto.Message) *v3corepb.Node { + var node *v3corepb.Node + switch nn := n.(type) { + case *v3corepb.Node: + node = nn + case *v2corepb.Node: + v2, err := proto.Marshal(nn) + if err != nil { + logger.Warningf("Failed to marshal node (%v): %v", n, err) + break + } + node = new(v3corepb.Node) + if err := proto.Unmarshal(v2, node); err != nil { + logger.Warningf("Failed to unmarshal node (%v): %v", v2, err) + } + default: + logger.Warningf("node from bootstrap is %#v, only v2.Node and v3.Node are supported", nn) + } + return node +} + +func (s *ClientStatusDiscoveryServer) buildLDSPerXDSConfig() *v3statuspb.PerXdsConfig { + version, dump := s.xdsClient.DumpLDS() + resources := make([]*v3adminpb.ListenersConfigDump_DynamicListener, 0, len(dump)) + for name, d := range dump { + configDump := &v3adminpb.ListenersConfigDump_DynamicListener{ + Name: name, + ClientStatus: serviceStatusToProto(d.MD.Status), + } + if (d.MD.Timestamp != time.Time{}) { + configDump.ActiveState = &v3adminpb.ListenersConfigDump_DynamicListenerState{ + VersionInfo: d.MD.Version, + Listener: d.Raw, + LastUpdated: timestamppb.New(d.MD.Timestamp), + } + } + if errState := d.MD.ErrState; errState != nil { + configDump.ErrorState = &v3adminpb.UpdateFailureState{ + LastUpdateAttempt: timestamppb.New(errState.Timestamp), + Details: errState.Err.Error(), + VersionInfo: errState.Version, + } + } + resources = append(resources, configDump) + } + return &v3statuspb.PerXdsConfig{ + PerXdsConfig: &v3statuspb.PerXdsConfig_ListenerConfig{ + ListenerConfig: &v3adminpb.ListenersConfigDump{ + VersionInfo: version, + DynamicListeners: resources, + }, + }, + } +} + +func (s *ClientStatusDiscoveryServer) buildRDSPerXDSConfig() *v3statuspb.PerXdsConfig { + _, dump := s.xdsClient.DumpRDS() + resources := make([]*v3adminpb.RoutesConfigDump_DynamicRouteConfig, 0, len(dump)) + for _, d := range dump { + configDump := &v3adminpb.RoutesConfigDump_DynamicRouteConfig{ + VersionInfo: d.MD.Version, + ClientStatus: serviceStatusToProto(d.MD.Status), + } + if (d.MD.Timestamp != time.Time{}) { + configDump.RouteConfig = d.Raw + configDump.LastUpdated = timestamppb.New(d.MD.Timestamp) + } + if errState := d.MD.ErrState; errState != nil { + configDump.ErrorState = &v3adminpb.UpdateFailureState{ + LastUpdateAttempt: timestamppb.New(errState.Timestamp), + Details: errState.Err.Error(), + VersionInfo: errState.Version, + } + } + resources = append(resources, configDump) + } + return &v3statuspb.PerXdsConfig{ + PerXdsConfig: &v3statuspb.PerXdsConfig_RouteConfig{ + RouteConfig: &v3adminpb.RoutesConfigDump{ + DynamicRouteConfigs: resources, + }, + }, + } +} + +func (s *ClientStatusDiscoveryServer) buildCDSPerXDSConfig() *v3statuspb.PerXdsConfig { + version, dump := s.xdsClient.DumpCDS() + resources := make([]*v3adminpb.ClustersConfigDump_DynamicCluster, 0, len(dump)) + for _, d := range dump { + configDump := &v3adminpb.ClustersConfigDump_DynamicCluster{ + VersionInfo: d.MD.Version, + ClientStatus: serviceStatusToProto(d.MD.Status), + } + if (d.MD.Timestamp != time.Time{}) { + configDump.Cluster = d.Raw + configDump.LastUpdated = timestamppb.New(d.MD.Timestamp) + } + if errState := d.MD.ErrState; errState != nil { + configDump.ErrorState = &v3adminpb.UpdateFailureState{ + LastUpdateAttempt: timestamppb.New(errState.Timestamp), + Details: errState.Err.Error(), + VersionInfo: errState.Version, + } + } + resources = append(resources, configDump) + } + return &v3statuspb.PerXdsConfig{ + PerXdsConfig: &v3statuspb.PerXdsConfig_ClusterConfig{ + ClusterConfig: &v3adminpb.ClustersConfigDump{ + VersionInfo: version, + DynamicActiveClusters: resources, + }, + }, + } +} + +func (s *ClientStatusDiscoveryServer) buildEDSPerXDSConfig() *v3statuspb.PerXdsConfig { + _, dump := s.xdsClient.DumpEDS() + resources := make([]*v3adminpb.EndpointsConfigDump_DynamicEndpointConfig, 0, len(dump)) + for _, d := range dump { + configDump := &v3adminpb.EndpointsConfigDump_DynamicEndpointConfig{ + VersionInfo: d.MD.Version, + ClientStatus: serviceStatusToProto(d.MD.Status), + } + if (d.MD.Timestamp != time.Time{}) { + configDump.EndpointConfig = d.Raw + configDump.LastUpdated = timestamppb.New(d.MD.Timestamp) + } + if errState := d.MD.ErrState; errState != nil { + configDump.ErrorState = &v3adminpb.UpdateFailureState{ + LastUpdateAttempt: timestamppb.New(errState.Timestamp), + Details: errState.Err.Error(), + VersionInfo: errState.Version, + } + } + resources = append(resources, configDump) + } + return &v3statuspb.PerXdsConfig{ + PerXdsConfig: &v3statuspb.PerXdsConfig_EndpointConfig{ + EndpointConfig: &v3adminpb.EndpointsConfigDump{ + DynamicEndpointConfigs: resources, + }, + }, + } +} + +func serviceStatusToProto(serviceStatus xdsclient.ServiceStatus) v3adminpb.ClientResourceStatus { + switch serviceStatus { + case xdsclient.ServiceStatusUnknown: + return v3adminpb.ClientResourceStatus_UNKNOWN + case xdsclient.ServiceStatusRequested: + return v3adminpb.ClientResourceStatus_REQUESTED + case xdsclient.ServiceStatusNotExist: + return v3adminpb.ClientResourceStatus_DOES_NOT_EXIST + case xdsclient.ServiceStatusACKed: + return v3adminpb.ClientResourceStatus_ACKED + case xdsclient.ServiceStatusNACKed: + return v3adminpb.ClientResourceStatus_NACKED + default: + return v3adminpb.ClientResourceStatus_UNKNOWN + } +} diff --git a/xds/csds/csds_test.go b/xds/csds/csds_test.go new file mode 100644 index 00000000000..9de83d37fec --- /dev/null +++ b/xds/csds/csds_test.go @@ -0,0 +1,739 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package csds + +import ( + "context" + "fmt" + "strings" + "testing" + "time" + + "github.com/golang/protobuf/jsonpb" + "github.com/golang/protobuf/proto" + "github.com/golang/protobuf/ptypes" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "github.com/google/uuid" + "google.golang.org/grpc" + "google.golang.org/grpc/internal/testutils" + "google.golang.org/grpc/internal/xds" + _ "google.golang.org/grpc/xds/internal/httpfilter/router" + xtestutils "google.golang.org/grpc/xds/internal/testutils" + "google.golang.org/grpc/xds/internal/testutils/e2e" + "google.golang.org/grpc/xds/internal/xdsclient" + "google.golang.org/protobuf/testing/protocmp" + "google.golang.org/protobuf/types/known/anypb" + "google.golang.org/protobuf/types/known/timestamppb" + + v3adminpb "github.com/envoyproxy/go-control-plane/envoy/admin/v3" + v2corepb "github.com/envoyproxy/go-control-plane/envoy/api/v2/core" + v3clusterpb "github.com/envoyproxy/go-control-plane/envoy/config/cluster/v3" + v3corepb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" + v3endpointpb "github.com/envoyproxy/go-control-plane/envoy/config/endpoint/v3" + v3listenerpb "github.com/envoyproxy/go-control-plane/envoy/config/listener/v3" + v3routepb "github.com/envoyproxy/go-control-plane/envoy/config/route/v3" + v3statuspb "github.com/envoyproxy/go-control-plane/envoy/service/status/v3" + v3statuspbgrpc "github.com/envoyproxy/go-control-plane/envoy/service/status/v3" +) + +const ( + defaultTestTimeout = 10 * time.Second +) + +var cmpOpts = cmp.Options{ + cmpopts.EquateEmpty(), + cmp.Comparer(func(a, b *timestamppb.Timestamp) bool { return true }), + protocmp.IgnoreFields(&v3adminpb.UpdateFailureState{}, "last_update_attempt", "details"), + protocmp.SortRepeated(func(a, b *v3adminpb.ListenersConfigDump_DynamicListener) bool { + return strings.Compare(a.Name, b.Name) < 0 + }), + protocmp.SortRepeated(func(a, b *v3adminpb.RoutesConfigDump_DynamicRouteConfig) bool { + if a.RouteConfig == nil { + return false + } + if b.RouteConfig == nil { + return true + } + var at, bt v3routepb.RouteConfiguration + if err := ptypes.UnmarshalAny(a.RouteConfig, &at); err != nil { + panic("failed to unmarshal RouteConfig" + err.Error()) + } + if err := ptypes.UnmarshalAny(b.RouteConfig, &bt); err != nil { + panic("failed to unmarshal RouteConfig" + err.Error()) + } + return strings.Compare(at.Name, bt.Name) < 0 + }), + protocmp.SortRepeated(func(a, b *v3adminpb.ClustersConfigDump_DynamicCluster) bool { + if a.Cluster == nil { + return false + } + if b.Cluster == nil { + return true + } + var at, bt v3clusterpb.Cluster + if err := ptypes.UnmarshalAny(a.Cluster, &at); err != nil { + panic("failed to unmarshal Cluster" + err.Error()) + } + if err := ptypes.UnmarshalAny(b.Cluster, &bt); err != nil { + panic("failed to unmarshal Cluster" + err.Error()) + } + return strings.Compare(at.Name, bt.Name) < 0 + }), + protocmp.SortRepeated(func(a, b *v3adminpb.EndpointsConfigDump_DynamicEndpointConfig) bool { + if a.EndpointConfig == nil { + return false + } + if b.EndpointConfig == nil { + return true + } + var at, bt v3endpointpb.ClusterLoadAssignment + if err := ptypes.UnmarshalAny(a.EndpointConfig, &at); err != nil { + panic("failed to unmarshal Endpoints" + err.Error()) + } + if err := ptypes.UnmarshalAny(b.EndpointConfig, &bt); err != nil { + panic("failed to unmarshal Endpoints" + err.Error()) + } + return strings.Compare(at.ClusterName, bt.ClusterName) < 0 + }), + protocmp.IgnoreFields(&v3adminpb.ListenersConfigDump_DynamicListenerState{}, "last_updated"), + protocmp.IgnoreFields(&v3adminpb.RoutesConfigDump_DynamicRouteConfig{}, "last_updated"), + protocmp.IgnoreFields(&v3adminpb.ClustersConfigDump_DynamicCluster{}, "last_updated"), + protocmp.IgnoreFields(&v3adminpb.EndpointsConfigDump_DynamicEndpointConfig{}, "last_updated"), + protocmp.Transform(), +} + +var ( + ldsTargets = []string{"lds.target.good:0000", "lds.target.good:1111"} + listeners = make([]*v3listenerpb.Listener, len(ldsTargets)) + listenerAnys = make([]*anypb.Any, len(ldsTargets)) + + rdsTargets = []string{"route-config-0", "route-config-1"} + routes = make([]*v3routepb.RouteConfiguration, len(rdsTargets)) + routeAnys = make([]*anypb.Any, len(rdsTargets)) + + cdsTargets = []string{"cluster-0", "cluster-1"} + clusters = make([]*v3clusterpb.Cluster, len(cdsTargets)) + clusterAnys = make([]*anypb.Any, len(cdsTargets)) + + edsTargets = []string{"endpoints-0", "endpoints-1"} + endpoints = make([]*v3endpointpb.ClusterLoadAssignment, len(edsTargets)) + endpointAnys = make([]*anypb.Any, len(edsTargets)) + ips = []string{"0.0.0.0", "1.1.1.1"} + ports = []uint32{123, 456} +) + +func init() { + for i := range ldsTargets { + listeners[i] = e2e.DefaultClientListener(ldsTargets[i], rdsTargets[i]) + listenerAnys[i] = testutils.MarshalAny(listeners[i]) + } + for i := range rdsTargets { + routes[i] = e2e.DefaultRouteConfig(rdsTargets[i], ldsTargets[i], cdsTargets[i]) + routeAnys[i] = testutils.MarshalAny(routes[i]) + } + for i := range cdsTargets { + clusters[i] = e2e.DefaultCluster(cdsTargets[i], edsTargets[i], e2e.SecurityLevelNone) + clusterAnys[i] = testutils.MarshalAny(clusters[i]) + } + for i := range edsTargets { + endpoints[i] = e2e.DefaultEndpoint(edsTargets[i], ips[i], ports[i:i+1]) + endpointAnys[i] = testutils.MarshalAny(endpoints[i]) + } +} + +func TestCSDS(t *testing.T) { + const retryCount = 10 + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + xdsC, mgmServer, nodeID, stream, cleanup := commonSetup(ctx, t) + defer cleanup() + + for _, target := range ldsTargets { + xdsC.WatchListener(target, func(xdsclient.ListenerUpdate, error) {}) + } + for _, target := range rdsTargets { + xdsC.WatchRouteConfig(target, func(xdsclient.RouteConfigUpdate, error) {}) + } + for _, target := range cdsTargets { + xdsC.WatchCluster(target, func(xdsclient.ClusterUpdate, error) {}) + } + for _, target := range edsTargets { + xdsC.WatchEndpoints(target, func(xdsclient.EndpointsUpdate, error) {}) + } + + for i := 0; i < retryCount; i++ { + err := checkForRequested(stream) + if err == nil { + break + } + if i == retryCount-1 { + t.Fatalf("%v", err) + } + time.Sleep(time.Millisecond * 100) + } + + if err := mgmServer.Update(ctx, e2e.UpdateOptions{ + NodeID: nodeID, + Listeners: listeners, + Routes: routes, + Clusters: clusters, + Endpoints: endpoints, + }); err != nil { + t.Fatal(err) + } + for i := 0; i < retryCount; i++ { + err := checkForACKed(stream) + if err == nil { + break + } + if i == retryCount-1 { + t.Fatalf("%v", err) + } + time.Sleep(time.Millisecond * 100) + } + + const nackResourceIdx = 0 + var ( + nackListeners = append([]*v3listenerpb.Listener{}, listeners...) + nackRoutes = append([]*v3routepb.RouteConfiguration{}, routes...) + nackClusters = append([]*v3clusterpb.Cluster{}, clusters...) + nackEndpoints = append([]*v3endpointpb.ClusterLoadAssignment{}, endpoints...) + ) + nackListeners[0] = &v3listenerpb.Listener{Name: ldsTargets[nackResourceIdx], ApiListener: &v3listenerpb.ApiListener{}} // 0 will be nacked. 1 will stay the same. + nackRoutes[0] = &v3routepb.RouteConfiguration{ + Name: rdsTargets[nackResourceIdx], VirtualHosts: []*v3routepb.VirtualHost{{Routes: []*v3routepb.Route{{}}}}, + } + nackClusters[0] = &v3clusterpb.Cluster{ + Name: cdsTargets[nackResourceIdx], ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_STATIC}, + } + nackEndpoints[0] = &v3endpointpb.ClusterLoadAssignment{ + ClusterName: edsTargets[nackResourceIdx], Endpoints: []*v3endpointpb.LocalityLbEndpoints{{}}, + } + if err := mgmServer.Update(ctx, e2e.UpdateOptions{ + NodeID: nodeID, + Listeners: nackListeners, + Routes: nackRoutes, + Clusters: nackClusters, + Endpoints: nackEndpoints, + SkipValidation: true, + }); err != nil { + t.Fatal(err) + } + for i := 0; i < retryCount; i++ { + err := checkForNACKed(nackResourceIdx, stream) + if err == nil { + break + } + if i == retryCount-1 { + t.Fatalf("%v", err) + } + time.Sleep(time.Millisecond * 100) + } +} + +func commonSetup(ctx context.Context, t *testing.T) (xdsclient.XDSClient, *e2e.ManagementServer, string, v3statuspbgrpc.ClientStatusDiscoveryService_StreamClientStatusClient, func()) { + t.Helper() + + // Spin up a xDS management server on a local port. + nodeID := uuid.New().String() + fs, err := e2e.StartManagementServer() + if err != nil { + t.Fatal(err) + } + + // Create a bootstrap file in a temporary directory. + bootstrapCleanup, err := xds.SetupBootstrapFile(xds.BootstrapOptions{ + Version: xds.TransportV3, + NodeID: nodeID, + ServerURI: fs.Address, + }) + if err != nil { + t.Fatal(err) + } + // Create xds_client. + xdsC, err := xdsclient.New() + if err != nil { + t.Fatalf("failed to create xds client: %v", err) + } + oldNewXDSClient := newXDSClient + newXDSClient = func() xdsclient.XDSClient { return xdsC } + + // Initialize an gRPC server and register CSDS on it. + server := grpc.NewServer() + csdss, err := NewClientStatusDiscoveryServer() + if err != nil { + t.Fatal(err) + } + v3statuspbgrpc.RegisterClientStatusDiscoveryServiceServer(server, csdss) + // Create a local listener and pass it to Serve(). + lis, err := xtestutils.LocalTCPListener() + if err != nil { + t.Fatalf("xtestutils.LocalTCPListener() failed: %v", err) + } + go func() { + if err := server.Serve(lis); err != nil { + t.Errorf("Serve() failed: %v", err) + } + }() + + // Create CSDS client. + conn, err := grpc.Dial(lis.Addr().String(), grpc.WithInsecure()) + if err != nil { + t.Fatalf("cannot connect to server: %v", err) + } + c := v3statuspbgrpc.NewClientStatusDiscoveryServiceClient(conn) + stream, err := c.StreamClientStatus(ctx, grpc.WaitForReady(true)) + if err != nil { + t.Fatalf("cannot get ServerReflectionInfo: %v", err) + } + + return xdsC, fs, nodeID, stream, func() { + fs.Stop() + conn.Close() + server.Stop() + csdss.Close() + newXDSClient = oldNewXDSClient + xdsC.Close() + bootstrapCleanup() + } +} + +func checkForRequested(stream v3statuspbgrpc.ClientStatusDiscoveryService_StreamClientStatusClient) error { + if err := stream.Send(&v3statuspb.ClientStatusRequest{Node: nil}); err != nil { + return fmt.Errorf("failed to send request: %v", err) + } + r, err := stream.Recv() + if err != nil { + // io.EOF is not ok. + return fmt.Errorf("failed to recv response: %v", err) + } + + if n := len(r.Config); n != 1 { + return fmt.Errorf("got %d configs, want 1: %v", n, proto.MarshalTextString(r)) + } + if n := len(r.Config[0].XdsConfig); n != 4 { + return fmt.Errorf("got %d xds configs (one for each type), want 4: %v", n, proto.MarshalTextString(r)) + } + for _, cfg := range r.Config[0].XdsConfig { + switch config := cfg.PerXdsConfig.(type) { + case *v3statuspb.PerXdsConfig_ListenerConfig: + var wantLis []*v3adminpb.ListenersConfigDump_DynamicListener + for i := range ldsTargets { + wantLis = append(wantLis, &v3adminpb.ListenersConfigDump_DynamicListener{ + Name: ldsTargets[i], + ClientStatus: v3adminpb.ClientResourceStatus_REQUESTED, + }) + } + wantDump := &v3adminpb.ListenersConfigDump{ + DynamicListeners: wantLis, + } + if diff := cmp.Diff(config.ListenerConfig, wantDump, cmpOpts); diff != "" { + return fmt.Errorf(diff) + } + case *v3statuspb.PerXdsConfig_RouteConfig: + var wantRoutes []*v3adminpb.RoutesConfigDump_DynamicRouteConfig + for range rdsTargets { + wantRoutes = append(wantRoutes, &v3adminpb.RoutesConfigDump_DynamicRouteConfig{ + ClientStatus: v3adminpb.ClientResourceStatus_REQUESTED, + }) + } + wantDump := &v3adminpb.RoutesConfigDump{ + DynamicRouteConfigs: wantRoutes, + } + if diff := cmp.Diff(config.RouteConfig, wantDump, cmpOpts); diff != "" { + return fmt.Errorf(diff) + } + case *v3statuspb.PerXdsConfig_ClusterConfig: + var wantCluster []*v3adminpb.ClustersConfigDump_DynamicCluster + for range cdsTargets { + wantCluster = append(wantCluster, &v3adminpb.ClustersConfigDump_DynamicCluster{ + ClientStatus: v3adminpb.ClientResourceStatus_REQUESTED, + }) + } + wantDump := &v3adminpb.ClustersConfigDump{ + DynamicActiveClusters: wantCluster, + } + if diff := cmp.Diff(config.ClusterConfig, wantDump, cmpOpts); diff != "" { + return fmt.Errorf(diff) + } + case *v3statuspb.PerXdsConfig_EndpointConfig: + var wantEndpoint []*v3adminpb.EndpointsConfigDump_DynamicEndpointConfig + for range cdsTargets { + wantEndpoint = append(wantEndpoint, &v3adminpb.EndpointsConfigDump_DynamicEndpointConfig{ + ClientStatus: v3adminpb.ClientResourceStatus_REQUESTED, + }) + } + wantDump := &v3adminpb.EndpointsConfigDump{ + DynamicEndpointConfigs: wantEndpoint, + } + if diff := cmp.Diff(config.EndpointConfig, wantDump, cmpOpts); diff != "" { + return fmt.Errorf(diff) + } + default: + return fmt.Errorf("unexpected PerXdsConfig: %+v; %v", cfg.PerXdsConfig, protoToJSON(r)) + } + } + return nil +} + +func checkForACKed(stream v3statuspbgrpc.ClientStatusDiscoveryService_StreamClientStatusClient) error { + const wantVersion = "1" + + if err := stream.Send(&v3statuspb.ClientStatusRequest{Node: nil}); err != nil { + return fmt.Errorf("failed to send: %v", err) + } + r, err := stream.Recv() + if err != nil { + // io.EOF is not ok. + return fmt.Errorf("failed to recv response: %v", err) + } + + if n := len(r.Config); n != 1 { + return fmt.Errorf("got %d configs, want 1: %v", n, proto.MarshalTextString(r)) + } + if n := len(r.Config[0].XdsConfig); n != 4 { + return fmt.Errorf("got %d xds configs (one for each type), want 4: %v", n, proto.MarshalTextString(r)) + } + for _, cfg := range r.Config[0].XdsConfig { + switch config := cfg.PerXdsConfig.(type) { + case *v3statuspb.PerXdsConfig_ListenerConfig: + var wantLis []*v3adminpb.ListenersConfigDump_DynamicListener + for i := range ldsTargets { + wantLis = append(wantLis, &v3adminpb.ListenersConfigDump_DynamicListener{ + Name: ldsTargets[i], + ActiveState: &v3adminpb.ListenersConfigDump_DynamicListenerState{ + VersionInfo: wantVersion, + Listener: listenerAnys[i], + LastUpdated: nil, + }, + ErrorState: nil, + ClientStatus: v3adminpb.ClientResourceStatus_ACKED, + }) + } + wantDump := &v3adminpb.ListenersConfigDump{ + VersionInfo: wantVersion, + DynamicListeners: wantLis, + } + if diff := cmp.Diff(config.ListenerConfig, wantDump, cmpOpts); diff != "" { + return fmt.Errorf(diff) + } + case *v3statuspb.PerXdsConfig_RouteConfig: + var wantRoutes []*v3adminpb.RoutesConfigDump_DynamicRouteConfig + for i := range rdsTargets { + wantRoutes = append(wantRoutes, &v3adminpb.RoutesConfigDump_DynamicRouteConfig{ + VersionInfo: wantVersion, + RouteConfig: routeAnys[i], + LastUpdated: nil, + ClientStatus: v3adminpb.ClientResourceStatus_ACKED, + }) + } + wantDump := &v3adminpb.RoutesConfigDump{ + DynamicRouteConfigs: wantRoutes, + } + if diff := cmp.Diff(config.RouteConfig, wantDump, cmpOpts); diff != "" { + return fmt.Errorf(diff) + } + case *v3statuspb.PerXdsConfig_ClusterConfig: + var wantCluster []*v3adminpb.ClustersConfigDump_DynamicCluster + for i := range cdsTargets { + wantCluster = append(wantCluster, &v3adminpb.ClustersConfigDump_DynamicCluster{ + VersionInfo: wantVersion, + Cluster: clusterAnys[i], + LastUpdated: nil, + ClientStatus: v3adminpb.ClientResourceStatus_ACKED, + }) + } + wantDump := &v3adminpb.ClustersConfigDump{ + VersionInfo: wantVersion, + DynamicActiveClusters: wantCluster, + } + if diff := cmp.Diff(config.ClusterConfig, wantDump, cmpOpts); diff != "" { + return fmt.Errorf(diff) + } + case *v3statuspb.PerXdsConfig_EndpointConfig: + var wantEndpoint []*v3adminpb.EndpointsConfigDump_DynamicEndpointConfig + for i := range cdsTargets { + wantEndpoint = append(wantEndpoint, &v3adminpb.EndpointsConfigDump_DynamicEndpointConfig{ + VersionInfo: wantVersion, + EndpointConfig: endpointAnys[i], + LastUpdated: nil, + ClientStatus: v3adminpb.ClientResourceStatus_ACKED, + }) + } + wantDump := &v3adminpb.EndpointsConfigDump{ + DynamicEndpointConfigs: wantEndpoint, + } + if diff := cmp.Diff(config.EndpointConfig, wantDump, cmpOpts); diff != "" { + return fmt.Errorf(diff) + } + default: + return fmt.Errorf("unexpected PerXdsConfig: %+v; %v", cfg.PerXdsConfig, protoToJSON(r)) + } + } + return nil +} + +func checkForNACKed(nackResourceIdx int, stream v3statuspbgrpc.ClientStatusDiscoveryService_StreamClientStatusClient) error { + const ( + ackVersion = "1" + nackVersion = "2" + ) + if err := stream.Send(&v3statuspb.ClientStatusRequest{Node: nil}); err != nil { + return fmt.Errorf("failed to send: %v", err) + } + r, err := stream.Recv() + if err != nil { + // io.EOF is not ok. + return fmt.Errorf("failed to recv response: %v", err) + } + + if n := len(r.Config); n != 1 { + return fmt.Errorf("got %d configs, want 1: %v", n, proto.MarshalTextString(r)) + } + if n := len(r.Config[0].XdsConfig); n != 4 { + return fmt.Errorf("got %d xds configs (one for each type), want 4: %v", n, proto.MarshalTextString(r)) + } + for _, cfg := range r.Config[0].XdsConfig { + switch config := cfg.PerXdsConfig.(type) { + case *v3statuspb.PerXdsConfig_ListenerConfig: + var wantLis []*v3adminpb.ListenersConfigDump_DynamicListener + for i := range ldsTargets { + configDump := &v3adminpb.ListenersConfigDump_DynamicListener{ + Name: ldsTargets[i], + ActiveState: &v3adminpb.ListenersConfigDump_DynamicListenerState{ + VersionInfo: nackVersion, + Listener: listenerAnys[i], + LastUpdated: nil, + }, + ClientStatus: v3adminpb.ClientResourceStatus_ACKED, + } + if i == nackResourceIdx { + configDump.ActiveState.VersionInfo = ackVersion + configDump.ClientStatus = v3adminpb.ClientResourceStatus_NACKED + configDump.ErrorState = &v3adminpb.UpdateFailureState{ + Details: "blahblah", + VersionInfo: nackVersion, + } + } + wantLis = append(wantLis, configDump) + } + wantDump := &v3adminpb.ListenersConfigDump{ + VersionInfo: nackVersion, + DynamicListeners: wantLis, + } + if diff := cmp.Diff(config.ListenerConfig, wantDump, cmpOpts); diff != "" { + return fmt.Errorf(diff) + } + case *v3statuspb.PerXdsConfig_RouteConfig: + var wantRoutes []*v3adminpb.RoutesConfigDump_DynamicRouteConfig + for i := range rdsTargets { + configDump := &v3adminpb.RoutesConfigDump_DynamicRouteConfig{ + VersionInfo: nackVersion, + RouteConfig: routeAnys[i], + LastUpdated: nil, + ClientStatus: v3adminpb.ClientResourceStatus_ACKED, + } + if i == nackResourceIdx { + configDump.VersionInfo = ackVersion + configDump.ClientStatus = v3adminpb.ClientResourceStatus_NACKED + configDump.ErrorState = &v3adminpb.UpdateFailureState{ + Details: "blahblah", + VersionInfo: nackVersion, + } + } + wantRoutes = append(wantRoutes, configDump) + } + wantDump := &v3adminpb.RoutesConfigDump{ + DynamicRouteConfigs: wantRoutes, + } + if diff := cmp.Diff(config.RouteConfig, wantDump, cmpOpts); diff != "" { + return fmt.Errorf(diff) + } + case *v3statuspb.PerXdsConfig_ClusterConfig: + var wantCluster []*v3adminpb.ClustersConfigDump_DynamicCluster + for i := range cdsTargets { + configDump := &v3adminpb.ClustersConfigDump_DynamicCluster{ + VersionInfo: nackVersion, + Cluster: clusterAnys[i], + LastUpdated: nil, + ClientStatus: v3adminpb.ClientResourceStatus_ACKED, + } + if i == nackResourceIdx { + configDump.VersionInfo = ackVersion + configDump.ClientStatus = v3adminpb.ClientResourceStatus_NACKED + configDump.ErrorState = &v3adminpb.UpdateFailureState{ + Details: "blahblah", + VersionInfo: nackVersion, + } + } + wantCluster = append(wantCluster, configDump) + } + wantDump := &v3adminpb.ClustersConfigDump{ + VersionInfo: nackVersion, + DynamicActiveClusters: wantCluster, + } + if diff := cmp.Diff(config.ClusterConfig, wantDump, cmpOpts); diff != "" { + return fmt.Errorf(diff) + } + case *v3statuspb.PerXdsConfig_EndpointConfig: + var wantEndpoint []*v3adminpb.EndpointsConfigDump_DynamicEndpointConfig + for i := range cdsTargets { + configDump := &v3adminpb.EndpointsConfigDump_DynamicEndpointConfig{ + VersionInfo: nackVersion, + EndpointConfig: endpointAnys[i], + LastUpdated: nil, + ClientStatus: v3adminpb.ClientResourceStatus_ACKED, + } + if i == nackResourceIdx { + configDump.VersionInfo = ackVersion + configDump.ClientStatus = v3adminpb.ClientResourceStatus_NACKED + configDump.ErrorState = &v3adminpb.UpdateFailureState{ + Details: "blahblah", + VersionInfo: nackVersion, + } + } + wantEndpoint = append(wantEndpoint, configDump) + } + wantDump := &v3adminpb.EndpointsConfigDump{ + DynamicEndpointConfigs: wantEndpoint, + } + if diff := cmp.Diff(config.EndpointConfig, wantDump, cmpOpts); diff != "" { + return fmt.Errorf(diff) + } + default: + return fmt.Errorf("unexpected PerXdsConfig: %+v; %v", cfg.PerXdsConfig, protoToJSON(r)) + } + } + return nil +} + +func protoToJSON(p proto.Message) string { + mm := jsonpb.Marshaler{ + Indent: " ", + } + ret, _ := mm.MarshalToString(p) + return ret +} + +func TestCSDSNoXDSClient(t *testing.T) { + oldNewXDSClient := newXDSClient + newXDSClient = func() xdsclient.XDSClient { return nil } + defer func() { newXDSClient = oldNewXDSClient }() + + // Initialize an gRPC server and register CSDS on it. + server := grpc.NewServer() + csdss, err := NewClientStatusDiscoveryServer() + if err != nil { + t.Fatal(err) + } + defer csdss.Close() + v3statuspbgrpc.RegisterClientStatusDiscoveryServiceServer(server, csdss) + // Create a local listener and pass it to Serve(). + lis, err := xtestutils.LocalTCPListener() + if err != nil { + t.Fatalf("xtestutils.LocalTCPListener() failed: %v", err) + } + go func() { + if err := server.Serve(lis); err != nil { + t.Errorf("Serve() failed: %v", err) + } + }() + defer server.Stop() + + // Create CSDS client. + conn, err := grpc.Dial(lis.Addr().String(), grpc.WithInsecure()) + if err != nil { + t.Fatalf("cannot connect to server: %v", err) + } + defer conn.Close() + c := v3statuspbgrpc.NewClientStatusDiscoveryServiceClient(conn) + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + stream, err := c.StreamClientStatus(ctx, grpc.WaitForReady(true)) + if err != nil { + t.Fatalf("cannot get ServerReflectionInfo: %v", err) + } + + if err := stream.Send(&v3statuspb.ClientStatusRequest{Node: nil}); err != nil { + t.Fatalf("failed to send: %v", err) + } + r, err := stream.Recv() + if err != nil { + // io.EOF is not ok. + t.Fatalf("failed to recv response: %v", err) + } + if n := len(r.Config); n != 0 { + t.Fatalf("got %d configs, want 0: %v", n, proto.MarshalTextString(r)) + } +} + +func Test_nodeProtoToV3(t *testing.T) { + const ( + testID = "test-id" + testCluster = "test-cluster" + testZone = "test-zone" + ) + tests := []struct { + name string + n proto.Message + want *v3corepb.Node + }{ + { + name: "v3", + n: &v3corepb.Node{ + Id: testID, + Cluster: testCluster, + Locality: &v3corepb.Locality{Zone: testZone}, + }, + want: &v3corepb.Node{ + Id: testID, + Cluster: testCluster, + Locality: &v3corepb.Locality{Zone: testZone}, + }, + }, + { + name: "v2", + n: &v2corepb.Node{ + Id: testID, + Cluster: testCluster, + Locality: &v2corepb.Locality{Zone: testZone}, + }, + want: &v3corepb.Node{ + Id: testID, + Cluster: testCluster, + Locality: &v3corepb.Locality{Zone: testZone}, + }, + }, + { + name: "not node", + n: &v2corepb.Locality{Zone: testZone}, + want: nil, // Input is not a node, should return nil. + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := nodeProtoToV3(tt.n) + if diff := cmp.Diff(got, tt.want, protocmp.Transform()); diff != "" { + t.Errorf("nodeProtoToV3() got unexpected result, diff (-got, +want): %v", diff) + } + }) + } +} diff --git a/xds/googledirectpath/googlec2p.go b/xds/googledirectpath/googlec2p.go new file mode 100644 index 00000000000..b9f1c712014 --- /dev/null +++ b/xds/googledirectpath/googlec2p.go @@ -0,0 +1,178 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// Package googledirectpath implements a resolver that configures xds to make +// cloud to prod directpath connection. +// +// It's a combo of DNS and xDS resolvers. It delegates to DNS if +// - not on GCE, or +// - xDS bootstrap env var is set (so this client needs to do normal xDS, not +// direct path, and clients with this scheme is not part of the xDS mesh). +package googledirectpath + +import ( + "fmt" + "time" + + v3corepb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/google" + "google.golang.org/grpc/grpclog" + "google.golang.org/grpc/internal/googlecloud" + internalgrpclog "google.golang.org/grpc/internal/grpclog" + "google.golang.org/grpc/internal/grpcrand" + "google.golang.org/grpc/internal/xds/env" + "google.golang.org/grpc/resolver" + _ "google.golang.org/grpc/xds" // To register xds resolvers and balancers. + "google.golang.org/grpc/xds/internal/version" + "google.golang.org/grpc/xds/internal/xdsclient" + "google.golang.org/grpc/xds/internal/xdsclient/bootstrap" + "google.golang.org/protobuf/types/known/structpb" +) + +const ( + c2pScheme = "google-c2p" + + tdURL = "directpath-pa.googleapis.com" + httpReqTimeout = 10 * time.Second + zoneURL = "http://metadata.google.internal/computeMetadata/v1/instance/zone" + ipv6URL = "http://metadata.google.internal/computeMetadata/v1/instance/network-interfaces/0/ipv6s" + + gRPCUserAgentName = "gRPC Go" + clientFeatureNoOverprovisioning = "envoy.lb.does_not_support_overprovisioning" + ipv6CapableMetadataName = "TRAFFICDIRECTOR_DIRECTPATH_C2P_IPV6_CAPABLE" + + logPrefix = "[google-c2p-resolver]" + + dnsName, xdsName = "dns", "xds" +) + +// For overriding in unittests. +var ( + onGCE = googlecloud.OnGCE + + newClientWithConfig = func(config *bootstrap.Config) (xdsclient.XDSClient, error) { + return xdsclient.NewWithConfig(config) + } + + logger = internalgrpclog.NewPrefixLogger(grpclog.Component("directpath"), logPrefix) +) + +func init() { + if env.C2PResolverSupport { + resolver.Register(c2pResolverBuilder{}) + } +} + +type c2pResolverBuilder struct{} + +func (c2pResolverBuilder) Build(t resolver.Target, cc resolver.ClientConn, opts resolver.BuildOptions) (resolver.Resolver, error) { + if !runDirectPath() { + // If not xDS, fallback to DNS. + t.Scheme = dnsName + return resolver.Get(dnsName).Build(t, cc, opts) + } + + // Note that the following calls to getZone() and getIPv6Capable() does I/O, + // and has 10 seconds timeout each. + // + // This should be fine in most of the cases. In certain error cases, this + // could block Dial() for up to 10 seconds (each blocking call has its own + // goroutine). + zoneCh, ipv6CapableCh := make(chan string), make(chan bool) + go func() { zoneCh <- getZone(httpReqTimeout) }() + go func() { ipv6CapableCh <- getIPv6Capable(httpReqTimeout) }() + + balancerName := env.C2PResolverTestOnlyTrafficDirectorURI + if balancerName == "" { + balancerName = tdURL + } + config := &bootstrap.Config{ + BalancerName: balancerName, + Creds: grpc.WithCredentialsBundle(google.NewDefaultCredentials()), + TransportAPI: version.TransportV3, + NodeProto: newNode(<-zoneCh, <-ipv6CapableCh), + } + + // Create singleton xds client with this config. The xds client will be + // used by the xds resolver later. + xdsC, err := newClientWithConfig(config) + if err != nil { + return nil, fmt.Errorf("failed to start xDS client: %v", err) + } + + // Create and return an xDS resolver. + t.Scheme = xdsName + xdsR, err := resolver.Get(xdsName).Build(t, cc, opts) + if err != nil { + xdsC.Close() + return nil, err + } + return &c2pResolver{ + Resolver: xdsR, + client: xdsC, + }, nil +} + +func (c2pResolverBuilder) Scheme() string { + return c2pScheme +} + +type c2pResolver struct { + resolver.Resolver + client xdsclient.XDSClient +} + +func (r *c2pResolver) Close() { + r.Resolver.Close() + r.client.Close() +} + +var ipv6EnabledMetadata = &structpb.Struct{ + Fields: map[string]*structpb.Value{ + ipv6CapableMetadataName: structpb.NewBoolValue(true), + }, +} + +var id = fmt.Sprintf("C2P-%d", grpcrand.Int()) + +// newNode makes a copy of defaultNode, and populate it's Metadata and +// Locality fields. +func newNode(zone string, ipv6Capable bool) *v3corepb.Node { + ret := &v3corepb.Node{ + // Not all required fields are set in defaultNote. Metadata will be set + // if ipv6 is enabled. Locality will be set to the value from metadata. + Id: id, + UserAgentName: gRPCUserAgentName, + UserAgentVersionType: &v3corepb.Node_UserAgentVersion{UserAgentVersion: grpc.Version}, + ClientFeatures: []string{clientFeatureNoOverprovisioning}, + } + ret.Locality = &v3corepb.Locality{Zone: zone} + if ipv6Capable { + ret.Metadata = ipv6EnabledMetadata + } + return ret +} + +// runDirectPath returns whether this resolver should use direct path. +// +// direct path is enabled if this client is running on GCE, and the normal xDS +// is not used (bootstrap env vars are not set). +func runDirectPath() bool { + return env.BootstrapFileName == "" && env.BootstrapFileContent == "" && onGCE() +} diff --git a/xds/googledirectpath/googlec2p_test.go b/xds/googledirectpath/googlec2p_test.go new file mode 100644 index 00000000000..a208fad66c5 --- /dev/null +++ b/xds/googledirectpath/googlec2p_test.go @@ -0,0 +1,242 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package googledirectpath + +import ( + "strconv" + "testing" + "time" + + v3corepb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "google.golang.org/grpc" + "google.golang.org/grpc/internal/xds/env" + "google.golang.org/grpc/resolver" + "google.golang.org/grpc/xds/internal/version" + "google.golang.org/grpc/xds/internal/xdsclient" + "google.golang.org/grpc/xds/internal/xdsclient/bootstrap" + "google.golang.org/protobuf/testing/protocmp" + "google.golang.org/protobuf/types/known/structpb" +) + +type emptyResolver struct { + resolver.Resolver + scheme string +} + +func (er *emptyResolver) Build(_ resolver.Target, _ resolver.ClientConn, _ resolver.BuildOptions) (resolver.Resolver, error) { + return er, nil +} + +func (er *emptyResolver) Scheme() string { + return er.scheme +} + +func (er *emptyResolver) Close() {} + +var ( + testDNSResolver = &emptyResolver{scheme: "dns"} + testXDSResolver = &emptyResolver{scheme: "xds"} +) + +func replaceResolvers() func() { + var registerForTesting bool + if resolver.Get(c2pScheme) == nil { + // If env var to enable c2p is not set, the resolver isn't registered. + // Need to register and unregister in defer. + registerForTesting = true + resolver.Register(&c2pResolverBuilder{}) + } + oldDNS := resolver.Get("dns") + resolver.Register(testDNSResolver) + oldXDS := resolver.Get("xds") + resolver.Register(testXDSResolver) + return func() { + if oldDNS != nil { + resolver.Register(oldDNS) + } else { + resolver.UnregisterForTesting("dns") + } + if oldXDS != nil { + resolver.Register(oldXDS) + } else { + resolver.UnregisterForTesting("xds") + } + if registerForTesting { + resolver.UnregisterForTesting(c2pScheme) + } + } +} + +// Test that when bootstrap env is set, fallback to DNS. +func TestBuildWithBootstrapEnvSet(t *testing.T) { + defer replaceResolvers()() + builder := resolver.Get(c2pScheme) + + for i, envP := range []*string{&env.BootstrapFileName, &env.BootstrapFileContent} { + t.Run(strconv.Itoa(i), func(t *testing.T) { + // Set bootstrap config env var. + oldEnv := *envP + *envP = "does not matter" + defer func() { *envP = oldEnv }() + + // Build should return DNS, not xDS. + r, err := builder.Build(resolver.Target{}, nil, resolver.BuildOptions{}) + if err != nil { + t.Fatalf("failed to build resolver: %v", err) + } + if r != testDNSResolver { + t.Fatalf("want dns resolver, got %#v", r) + } + }) + } +} + +// Test that when not on GCE, fallback to DNS. +func TestBuildNotOnGCE(t *testing.T) { + defer replaceResolvers()() + builder := resolver.Get(c2pScheme) + + oldOnGCE := onGCE + onGCE = func() bool { return false } + defer func() { onGCE = oldOnGCE }() + + // Build should return DNS, not xDS. + r, err := builder.Build(resolver.Target{}, nil, resolver.BuildOptions{}) + if err != nil { + t.Fatalf("failed to build resolver: %v", err) + } + if r != testDNSResolver { + t.Fatalf("want dns resolver, got %#v", r) + } +} + +type testXDSClient struct { + xdsclient.XDSClient + closed chan struct{} +} + +func (c *testXDSClient) Close() { + c.closed <- struct{}{} +} + +// Test that when xDS is built, the client is built with the correct config. +func TestBuildXDS(t *testing.T) { + defer replaceResolvers()() + builder := resolver.Get(c2pScheme) + + oldOnGCE := onGCE + onGCE = func() bool { return true } + defer func() { onGCE = oldOnGCE }() + + const testZone = "test-zone" + oldGetZone := getZone + getZone = func(time.Duration) string { return testZone } + defer func() { getZone = oldGetZone }() + + for _, tt := range []struct { + name string + ipv6 bool + tdURI string // traffic director URI will be overridden if this is set. + }{ + {name: "ipv6 true", ipv6: true}, + {name: "ipv6 false", ipv6: false}, + {name: "override TD URI", ipv6: true, tdURI: "test-uri"}, + } { + t.Run(tt.name, func(t *testing.T) { + oldGetIPv6Capability := getIPv6Capable + getIPv6Capable = func(time.Duration) bool { return tt.ipv6 } + defer func() { getIPv6Capable = oldGetIPv6Capability }() + + if tt.tdURI != "" { + oldURI := env.C2PResolverTestOnlyTrafficDirectorURI + env.C2PResolverTestOnlyTrafficDirectorURI = tt.tdURI + defer func() { + env.C2PResolverTestOnlyTrafficDirectorURI = oldURI + }() + } + + tXDSClient := &testXDSClient{closed: make(chan struct{}, 1)} + + configCh := make(chan *bootstrap.Config, 1) + oldNewClient := newClientWithConfig + newClientWithConfig = func(config *bootstrap.Config) (xdsclient.XDSClient, error) { + configCh <- config + return tXDSClient, nil + } + defer func() { newClientWithConfig = oldNewClient }() + + // Build should return DNS, not xDS. + r, err := builder.Build(resolver.Target{}, nil, resolver.BuildOptions{}) + if err != nil { + t.Fatalf("failed to build resolver: %v", err) + } + rr := r.(*c2pResolver) + if rrr := rr.Resolver; rrr != testXDSResolver { + t.Fatalf("want xds resolver, got %#v, ", rrr) + } + + wantNode := &v3corepb.Node{ + Id: id, + Metadata: nil, + Locality: &v3corepb.Locality{Zone: testZone}, + UserAgentName: gRPCUserAgentName, + UserAgentVersionType: &v3corepb.Node_UserAgentVersion{UserAgentVersion: grpc.Version}, + ClientFeatures: []string{clientFeatureNoOverprovisioning}, + } + if tt.ipv6 { + wantNode.Metadata = &structpb.Struct{ + Fields: map[string]*structpb.Value{ + ipv6CapableMetadataName: { + Kind: &structpb.Value_BoolValue{BoolValue: true}, + }, + }, + } + } + wantConfig := &bootstrap.Config{ + BalancerName: tdURL, + TransportAPI: version.TransportV3, + NodeProto: wantNode, + } + if tt.tdURI != "" { + wantConfig.BalancerName = tt.tdURI + } + cmpOpts := cmp.Options{ + cmpopts.IgnoreFields(bootstrap.Config{}, "Creds"), + protocmp.Transform(), + } + select { + case c := <-configCh: + if diff := cmp.Diff(c, wantConfig, cmpOpts); diff != "" { + t.Fatalf("%v", diff) + } + case <-time.After(time.Second): + t.Fatalf("timeout waiting for client config") + } + + r.Close() + select { + case <-tXDSClient.closed: + case <-time.After(time.Second): + t.Fatalf("timeout waiting for client close") + } + }) + } +} diff --git a/xds/googledirectpath/utils.go b/xds/googledirectpath/utils.go new file mode 100644 index 00000000000..60044197978 --- /dev/null +++ b/xds/googledirectpath/utils.go @@ -0,0 +1,96 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package googledirectpath + +import ( + "bytes" + "fmt" + "io/ioutil" + "net/http" + "net/url" + "sync" + "time" +) + +func getFromMetadata(timeout time.Duration, urlStr string) ([]byte, error) { + parsedURL, err := url.Parse(urlStr) + if err != nil { + return nil, err + } + client := &http.Client{Timeout: timeout} + req := &http.Request{ + Method: http.MethodGet, + URL: parsedURL, + Header: http.Header{"Metadata-Flavor": {"Google"}}, + } + resp, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("failed communicating with metadata server: %v", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("metadata server returned resp with non-OK: %v", resp) + } + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed reading from metadata server: %v", err) + } + return body, nil +} + +var ( + zone string + zoneOnce sync.Once +) + +// Defined as var to be overridden in tests. +var getZone = func(timeout time.Duration) string { + zoneOnce.Do(func() { + qualifiedZone, err := getFromMetadata(timeout, zoneURL) + if err != nil { + logger.Warningf("could not discover instance zone: %v", err) + return + } + i := bytes.LastIndexByte(qualifiedZone, '/') + if i == -1 { + logger.Warningf("could not parse zone from metadata server: %s", qualifiedZone) + return + } + zone = string(qualifiedZone[i+1:]) + }) + return zone +} + +var ( + ipv6Capable bool + ipv6CapableOnce sync.Once +) + +// Defined as var to be overridden in tests. +var getIPv6Capable = func(timeout time.Duration) bool { + ipv6CapableOnce.Do(func() { + _, err := getFromMetadata(timeout, ipv6URL) + if err != nil { + logger.Warningf("could not discover ipv6 capability: %v", err) + return + } + ipv6Capable = true + }) + return ipv6Capable +} diff --git a/xds/internal/balancer/balancer.go b/xds/internal/balancer/balancer.go index 5883027a2c5..86656736a61 100644 --- a/xds/internal/balancer/balancer.go +++ b/xds/internal/balancer/balancer.go @@ -20,8 +20,10 @@ package balancer import ( - _ "google.golang.org/grpc/xds/internal/balancer/cdsbalancer" // Register the CDS balancer - _ "google.golang.org/grpc/xds/internal/balancer/clustermanager" // Register the xds_cluster_manager balancer - _ "google.golang.org/grpc/xds/internal/balancer/edsbalancer" // Register the EDS balancer - _ "google.golang.org/grpc/xds/internal/balancer/weightedtarget" // Register the weighted_target balancer + _ "google.golang.org/grpc/xds/internal/balancer/cdsbalancer" // Register the CDS balancer + _ "google.golang.org/grpc/xds/internal/balancer/clusterimpl" // Register the xds_cluster_impl balancer + _ "google.golang.org/grpc/xds/internal/balancer/clustermanager" // Register the xds_cluster_manager balancer + _ "google.golang.org/grpc/xds/internal/balancer/clusterresolver" // Register the xds_cluster_resolver balancer + _ "google.golang.org/grpc/xds/internal/balancer/priority" // Register the priority balancer + _ "google.golang.org/grpc/xds/internal/balancer/weightedtarget" // Register the weighted_target balancer ) diff --git a/xds/internal/balancer/balancergroup/balancergroup.go b/xds/internal/balancer/balancergroup/balancergroup.go index 2ec576a4b57..5798b03ac50 100644 --- a/xds/internal/balancer/balancergroup/balancergroup.go +++ b/xds/internal/balancer/balancergroup/balancergroup.go @@ -24,7 +24,7 @@ import ( "time" orcapb "github.com/cncf/udpa/go/udpa/data/orca/v1" - "google.golang.org/grpc/xds/internal/client/load" + "google.golang.org/grpc/xds/internal/xdsclient/load" "google.golang.org/grpc/balancer" "google.golang.org/grpc/connectivity" @@ -104,6 +104,22 @@ func (sbc *subBalancerWrapper) startBalancer() { } } +func (sbc *subBalancerWrapper) exitIdle() { + b := sbc.balancer + if b == nil { + return + } + if ei, ok := b.(balancer.ExitIdler); ok { + ei.ExitIdle() + return + } + for sc, b := range sbc.group.scToSubBalancer { + if b == sbc { + sc.Connect() + } + } +} + func (sbc *subBalancerWrapper) updateSubConnState(sc balancer.SubConn, state balancer.SubConnState) { b := sbc.balancer if b == nil { @@ -183,7 +199,7 @@ type BalancerGroup struct { cc balancer.ClientConn buildOpts balancer.BuildOptions logger *grpclog.PrefixLogger - loadStore load.PerClusterReporter + loadStore load.PerClusterReporter // TODO: delete this, no longer needed. It was used by EDS. // stateAggregator is where the state/picker updates will be sent to. It's // provided by the parent balancer, to build a picker with all the @@ -479,6 +495,10 @@ func (bg *BalancerGroup) Close() { } bg.incomingMu.Unlock() + // Clear(true) runs clear function to close sub-balancers in cache. It + // must be called out of outgoing mutex. + bg.balancerCache.Clear(true) + bg.outgoingMu.Lock() if bg.outgoingStarted { bg.outgoingStarted = false @@ -487,9 +507,17 @@ func (bg *BalancerGroup) Close() { } } bg.outgoingMu.Unlock() - // Clear(true) runs clear function to close sub-balancers in cache. It - // must be called out of outgoing mutex. - bg.balancerCache.Clear(true) +} + +// ExitIdle should be invoked when the parent LB policy's ExitIdle is invoked. +// It will trigger this on all sub-balancers, or reconnect their subconns if +// not supported. +func (bg *BalancerGroup) ExitIdle() { + bg.outgoingMu.Lock() + for _, config := range bg.idToBalancerConfig { + config.exitIdle() + } + bg.outgoingMu.Unlock() } const ( diff --git a/xds/internal/balancer/balancergroup/balancergroup_test.go b/xds/internal/balancer/balancergroup/balancergroup_test.go index 0ad4bf8df10..9cc7bd072ec 100644 --- a/xds/internal/balancer/balancergroup/balancergroup_test.go +++ b/xds/internal/balancer/balancergroup/balancergroup_test.go @@ -42,8 +42,8 @@ import ( "google.golang.org/grpc/internal/balancer/stub" "google.golang.org/grpc/resolver" "google.golang.org/grpc/xds/internal/balancer/weightedtarget/weightedaggregator" - "google.golang.org/grpc/xds/internal/client/load" "google.golang.org/grpc/xds/internal/testutils" + "google.golang.org/grpc/xds/internal/xdsclient/load" ) var ( @@ -843,7 +843,7 @@ func (s) TestBalancerGroup_locality_caching_not_readd_within_timeout(t *testing. defer replaceDefaultSubBalancerCloseTimeout(time.Second)() _, _, cc, addrToSC := initBalancerGroupForCachingTest(t) - // The sub-balancer is not re-added withtin timeout. The subconns should be + // The sub-balancer is not re-added within timeout. The subconns should be // removed. removeTimeout := time.After(DefaultSubBalancerCloseTimeout) scToRemove := map[balancer.SubConn]int{ @@ -938,6 +938,36 @@ func (s) TestBalancerGroup_locality_caching_readd_with_different_builder(t *test } } +// After removing a sub-balancer, it will be kept in cache. Make sure that this +// sub-balancer's Close is called when the balancer group is closed. +func (s) TestBalancerGroup_CloseStopsBalancerInCache(t *testing.T) { + const balancerName = "stub-TestBalancerGroup_check_close" + closed := make(chan struct{}) + stub.Register(balancerName, stub.BalancerFuncs{Close: func(_ *stub.BalancerData) { + close(closed) + }}) + builder := balancer.Get(balancerName) + + defer replaceDefaultSubBalancerCloseTimeout(time.Second)() + gator, bg, _, _ := initBalancerGroupForCachingTest(t) + + // Add balancer, and remove + gator.Add(testBalancerIDs[2], 1) + bg.Add(testBalancerIDs[2], builder) + gator.Remove(testBalancerIDs[2]) + bg.Remove(testBalancerIDs[2]) + + // Immediately close balancergroup, before the cache timeout. + bg.Close() + + // Make sure the removed child balancer is closed eventually. + select { + case <-closed: + case <-time.After(time.Second * 2): + t.Fatalf("timeout waiting for the child balancer in cache to be closed") + } +} + // TestBalancerGroupBuildOptions verifies that the balancer.BuildOptions passed // to the balancergroup at creation time is passed to child policies. func (s) TestBalancerGroupBuildOptions(t *testing.T) { diff --git a/xds/internal/balancer/cdsbalancer/cdsbalancer.go b/xds/internal/balancer/cdsbalancer/cdsbalancer.go index e4d349753e1..82d2a96958e 100644 --- a/xds/internal/balancer/cdsbalancer/cdsbalancer.go +++ b/xds/internal/balancer/cdsbalancer/cdsbalancer.go @@ -31,66 +31,57 @@ import ( xdsinternal "google.golang.org/grpc/internal/credentials/xds" "google.golang.org/grpc/internal/grpclog" "google.golang.org/grpc/internal/grpcsync" + "google.golang.org/grpc/internal/pretty" + internalserviceconfig "google.golang.org/grpc/internal/serviceconfig" "google.golang.org/grpc/resolver" "google.golang.org/grpc/serviceconfig" - "google.golang.org/grpc/xds/internal/balancer/edsbalancer" - xdsclient "google.golang.org/grpc/xds/internal/client" - "google.golang.org/grpc/xds/internal/client/bootstrap" + "google.golang.org/grpc/xds/internal/balancer/clusterresolver" + "google.golang.org/grpc/xds/internal/balancer/ringhash" + "google.golang.org/grpc/xds/internal/xdsclient" ) const ( cdsName = "cds_experimental" - edsName = "eds_experimental" ) var ( errBalancerClosed = errors.New("cdsBalancer is closed") - // newEDSBalancer is a helper function to build a new edsBalancer and will be - // overridden in unittests. - newEDSBalancer = func(cc balancer.ClientConn, opts balancer.BuildOptions) (balancer.Balancer, error) { - builder := balancer.Get(edsName) + // newChildBalancer is a helper function to build a new cluster_resolver + // balancer and will be overridden in unittests. + newChildBalancer = func(cc balancer.ClientConn, opts balancer.BuildOptions) (balancer.Balancer, error) { + builder := balancer.Get(clusterresolver.Name) if builder == nil { - return nil, fmt.Errorf("xds: no balancer builder with name %v", edsName) + return nil, fmt.Errorf("xds: no balancer builder with name %v", clusterresolver.Name) } - // We directly pass the parent clientConn to the - // underlying edsBalancer because the cdsBalancer does - // not deal with subConns. + // We directly pass the parent clientConn to the underlying + // cluster_resolver balancer because the cdsBalancer does not deal with + // subConns. return builder.Build(cc, opts), nil } - newXDSClient = func() (xdsClientInterface, error) { return xdsclient.New() } buildProvider = buildProviderFunc ) func init() { - balancer.Register(cdsBB{}) + balancer.Register(bb{}) } -// cdsBB (short for cdsBalancerBuilder) implements the balancer.Builder -// interface to help build a cdsBalancer. +// bb implements the balancer.Builder interface to help build a cdsBalancer. // It also implements the balancer.ConfigParser interface to help parse the // JSON service config, to be passed to the cdsBalancer. -type cdsBB struct{} +type bb struct{} // Build creates a new CDS balancer with the ClientConn. -func (cdsBB) Build(cc balancer.ClientConn, opts balancer.BuildOptions) balancer.Balancer { +func (bb) Build(cc balancer.ClientConn, opts balancer.BuildOptions) balancer.Balancer { b := &cdsBalancer{ - bOpts: opts, - updateCh: buffer.NewUnbounded(), - closed: grpcsync.NewEvent(), - cancelWatch: func() {}, // No-op at this point. - xdsHI: xdsinternal.NewHandshakeInfo(nil, nil), + bOpts: opts, + updateCh: buffer.NewUnbounded(), + closed: grpcsync.NewEvent(), + done: grpcsync.NewEvent(), + xdsHI: xdsinternal.NewHandshakeInfo(nil, nil), } b.logger = prefixLogger((b)) b.logger.Infof("Created") - - client, err := newXDSClient() - if err != nil { - b.logger.Errorf("failed to create xds-client: %v", err) - return nil - } - b.xdsClient = client - var creds credentials.TransportCredentials switch { case opts.DialCreds != nil: @@ -102,7 +93,7 @@ func (cdsBB) Build(cc balancer.ClientConn, opts balancer.BuildOptions) balancer. b.xdsCredsInUse = true } b.logger.Infof("xDS credentials in use: %v", b.xdsCredsInUse) - + b.clusterHandler = newClusterHandler(b) b.ccw = &ccWrapper{ ClientConn: cc, xdsHI: b.xdsHI, @@ -112,7 +103,7 @@ func (cdsBB) Build(cc balancer.ClientConn, opts balancer.BuildOptions) balancer. } // Name returns the name of balancers built by this builder. -func (cdsBB) Name() string { +func (bb) Name() string { return cdsName } @@ -125,7 +116,7 @@ type lbConfig struct { // ParseConfig parses the JSON load balancer config provided into an // internal form or returns an error if the config is invalid. -func (cdsBB) ParseConfig(c json.RawMessage) (serviceconfig.LoadBalancingConfig, error) { +func (bb) ParseConfig(c json.RawMessage) (serviceconfig.LoadBalancingConfig, error) { var cfg lbConfig if err := json.Unmarshal(c, &cfg); err != nil { return nil, fmt.Errorf("xds: unable to unmarshal lbconfig: %s, error: %v", string(c), err) @@ -133,54 +124,40 @@ func (cdsBB) ParseConfig(c json.RawMessage) (serviceconfig.LoadBalancingConfig, return &cfg, nil } -// xdsClientInterface contains methods from xdsClient.Client which are used by -// the cdsBalancer. This will be faked out in unittests. -type xdsClientInterface interface { - WatchCluster(string, func(xdsclient.ClusterUpdate, error)) func() - BootstrapConfig() *bootstrap.Config - Close() -} - // ccUpdate wraps a clientConn update received from gRPC (pushed from the // xdsResolver). A valid clusterName causes the cdsBalancer to register a CDS // watcher with the xdsClient, while a non-nil error causes it to cancel the -// existing watch and propagate the error to the underlying edsBalancer. +// existing watch and propagate the error to the underlying cluster_resolver +// balancer. type ccUpdate struct { clusterName string err error } // scUpdate wraps a subConn update received from gRPC. This is directly passed -// on to the edsBalancer. +// on to the cluster_resolver balancer. type scUpdate struct { subConn balancer.SubConn state balancer.SubConnState } -// watchUpdate wraps the information received from a registered CDS watcher. A -// non-nil error is propagated to the underlying edsBalancer. A valid update -// results in creating a new edsBalancer (if one doesn't already exist) and -// pushing the update to it. -type watchUpdate struct { - cds xdsclient.ClusterUpdate - err error -} +type exitIdle struct{} -// cdsBalancer implements a CDS based LB policy. It instantiates an EDS based -// LB policy to further resolve the serviceName received from CDS, into -// localities and endpoints. Implements the balancer.Balancer interface which -// is exposed to gRPC and implements the balancer.ClientConn interface which is -// exposed to the edsBalancer. +// cdsBalancer implements a CDS based LB policy. It instantiates a +// cluster_resolver balancer to further resolve the serviceName received from +// CDS, into localities and endpoints. Implements the balancer.Balancer +// interface which is exposed to gRPC and implements the balancer.ClientConn +// interface which is exposed to the cluster_resolver balancer. type cdsBalancer struct { ccw *ccWrapper // ClientConn interface passed to child LB. bOpts balancer.BuildOptions // BuildOptions passed to child LB. updateCh *buffer.Unbounded // Channel for gRPC and xdsClient updates. - xdsClient xdsClientInterface // xDS client to watch Cluster resource. - cancelWatch func() // Cluster watch cancel func. - edsLB balancer.Balancer // EDS child policy. - clusterToWatch string + xdsClient xdsclient.XDSClient // xDS client to watch Cluster resource. + clusterHandler *clusterHandler // To watch the clusters. + childLB balancer.Balancer logger *grpclog.PrefixLogger closed *grpcsync.Event + done *grpcsync.Event // The certificate providers are cached here to that they can be closed when // a new provider is to be created. @@ -193,25 +170,15 @@ type cdsBalancer struct { // handleClientConnUpdate handles a ClientConnUpdate received from gRPC. Good // updates lead to registration of a CDS watch. Updates with error lead to // cancellation of existing watch and propagation of the same error to the -// edsBalancer. +// cluster_resolver balancer. func (b *cdsBalancer) handleClientConnUpdate(update *ccUpdate) { // We first handle errors, if any, and then proceed with handling the // update, only if the status quo has changed. if err := update.err; err != nil { b.handleErrorFromUpdate(err, true) - } - if b.clusterToWatch == update.clusterName { return } - if update.clusterName != "" { - cancelWatch := b.xdsClient.WatchCluster(update.clusterName, b.handleClusterUpdate) - b.logger.Infof("Watch started on resource name %v with xds-client %p", update.clusterName, b.xdsClient) - b.cancelWatch = func() { - cancelWatch() - b.logger.Infof("Watch cancelled on resource name %v with xds-client %p", update.clusterName, b.xdsClient) - } - b.clusterToWatch = update.clusterName - } + b.clusterHandler.updateRootCluster(update.clusterName) } // handleSecurityConfig processes the security configuration received from the @@ -235,7 +202,7 @@ func (b *cdsBalancer) handleSecurityConfig(config *xdsclient.SecurityConfig) err // one where fallback credentials are to be used. b.xdsHI.SetRootCertProvider(nil) b.xdsHI.SetIdentityCertProvider(nil) - b.xdsHI.SetAcceptedSANs(nil) + b.xdsHI.SetSANMatchers(nil) return nil } @@ -278,7 +245,7 @@ func (b *cdsBalancer) handleSecurityConfig(config *xdsclient.SecurityConfig) err // could have been non-nil earlier. b.xdsHI.SetRootCertProvider(rootProvider) b.xdsHI.SetIdentityCertProvider(identityProvider) - b.xdsHI.SetAcceptedSANs(config.AcceptedSANs) + b.xdsHI.SetSANMatchers(config.SubjectAltNameMatchers) return nil } @@ -303,22 +270,22 @@ func buildProviderFunc(configs map[string]*certprovider.BuildableConfig, instanc } // handleWatchUpdate handles a watch update from the xDS Client. Good updates -// lead to clientConn updates being invoked on the underlying edsBalancer. -func (b *cdsBalancer) handleWatchUpdate(update *watchUpdate) { +// lead to clientConn updates being invoked on the underlying cluster_resolver balancer. +func (b *cdsBalancer) handleWatchUpdate(update clusterHandlerUpdate) { if err := update.err; err != nil { b.logger.Warningf("Watch error from xds-client %p: %v", b.xdsClient, err) b.handleErrorFromUpdate(err, false) return } - b.logger.Infof("Watch update from xds-client %p, content: %+v", b.xdsClient, update.cds) + b.logger.Infof("Watch update from xds-client %p, content: %+v, security config: %v", b.xdsClient, pretty.ToJSON(update.updates), pretty.ToJSON(update.securityCfg)) // Process the security config from the received update before building the // child policy or forwarding the update to it. We do this because the child // policy may try to create a new subConn inline. Processing the security // configuration here and setting up the handshakeInfo will make sure that // such attempts are handled properly. - if err := b.handleSecurityConfig(update.cds.SecurityCfg); err != nil { + if err := b.handleSecurityConfig(update.securityCfg); err != nil { // If the security config is invalid, for example, if the provider // instance is not found in the bootstrap config, we need to put the // channel in transient failure. @@ -328,33 +295,67 @@ func (b *cdsBalancer) handleWatchUpdate(update *watchUpdate) { } // The first good update from the watch API leads to the instantiation of an - // edsBalancer. Further updates/errors are propagated to the existing - // edsBalancer. - if b.edsLB == nil { - edsLB, err := newEDSBalancer(b.ccw, b.bOpts) + // cluster_resolver balancer. Further updates/errors are propagated to the existing + // cluster_resolver balancer. + if b.childLB == nil { + childLB, err := newChildBalancer(b.ccw, b.bOpts) if err != nil { - b.logger.Errorf("Failed to create child policy of type %s, %v", edsName, err) + b.logger.Errorf("Failed to create child policy of type %s, %v", clusterresolver.Name, err) return } - b.edsLB = edsLB - b.logger.Infof("Created child policy %p of type %s", b.edsLB, edsName) + b.childLB = childLB + b.logger.Infof("Created child policy %p of type %s", b.childLB, clusterresolver.Name) + } + + dms := make([]clusterresolver.DiscoveryMechanism, len(update.updates)) + for i, cu := range update.updates { + switch cu.ClusterType { + case xdsclient.ClusterTypeEDS: + dms[i] = clusterresolver.DiscoveryMechanism{ + Type: clusterresolver.DiscoveryMechanismTypeEDS, + Cluster: cu.ClusterName, + EDSServiceName: cu.EDSServiceName, + MaxConcurrentRequests: cu.MaxRequests, + } + if cu.EnableLRS { + // An empty string here indicates that the cluster_resolver balancer should use the + // same xDS server for load reporting as it does for EDS + // requests/responses. + dms[i].LoadReportingServerName = new(string) + + } + case xdsclient.ClusterTypeLogicalDNS: + dms[i] = clusterresolver.DiscoveryMechanism{ + Type: clusterresolver.DiscoveryMechanismTypeLogicalDNS, + DNSHostname: cu.DNSHostName, + } + default: + b.logger.Infof("unexpected cluster type %v when handling update from cluster handler", cu.ClusterType) + } } - lbCfg := &edsbalancer.EDSConfig{ - EDSServiceName: update.cds.ServiceName, - MaxConcurrentRequests: update.cds.MaxRequests, + lbCfg := &clusterresolver.LBConfig{ + DiscoveryMechanisms: dms, } - if update.cds.EnableLRS { - // An empty string here indicates that the edsBalancer should use the - // same xDS server for load reporting as it does for EDS - // requests/responses. - lbCfg.LrsLoadReportingServerName = new(string) + // lbPolicy is set only when the policy is ringhash. The default (when it's + // not set) is roundrobin. And similarly, we only need to set XDSLBPolicy + // for ringhash (it also defaults to roundrobin). + if lbp := update.lbPolicy; lbp != nil { + lbCfg.XDSLBPolicy = &internalserviceconfig.BalancerConfig{ + Name: ringhash.Name, + Config: &ringhash.LBConfig{ + MinRingSize: lbp.MinimumRingSize, + MaxRingSize: lbp.MaximumRingSize, + }, + } } + ccState := balancer.ClientConnState{ + ResolverState: xdsclient.SetClient(resolver.State{}, b.xdsClient), BalancerConfig: lbCfg, } - if err := b.edsLB.UpdateClientConnState(ccState); err != nil { - b.logger.Errorf("xds: edsBalancer.UpdateClientConnState(%+v) returned error: %v", ccState, err) + if err := b.childLB.UpdateClientConnState(ccState); err != nil { + b.logger.Errorf("xds: cluster_resolver balancer.UpdateClientConnState(%+v) returned error: %v", ccState, err) } } @@ -371,28 +372,41 @@ func (b *cdsBalancer) run() { b.handleClientConnUpdate(update) case *scUpdate: // SubConn updates are passthrough and are simply handed over to - // the underlying edsBalancer. - if b.edsLB == nil { - b.logger.Errorf("xds: received scUpdate {%+v} with no edsBalancer", update) + // the underlying cluster_resolver balancer. + if b.childLB == nil { + b.logger.Errorf("xds: received scUpdate {%+v} with no cluster_resolver balancer", update) break } - b.edsLB.UpdateSubConnState(update.subConn, update.state) - case *watchUpdate: - b.handleWatchUpdate(update) + b.childLB.UpdateSubConnState(update.subConn, update.state) + case exitIdle: + if b.childLB == nil { + b.logger.Errorf("xds: received ExitIdle with no child balancer") + break + } + // This implementation assumes the child balancer supports + // ExitIdle (but still checks for the interface's existence to + // avoid a panic if not). If the child does not, no subconns + // will be connected. + if ei, ok := b.childLB.(balancer.ExitIdler); ok { + ei.ExitIdle() + } } - - // Close results in cancellation of the CDS watch and closing of the - // underlying edsBalancer and is the only way to exit this goroutine. + case u := <-b.clusterHandler.updateChannel: + b.handleWatchUpdate(u) case <-b.closed.Done(): - b.cancelWatch() - b.cancelWatch = func() {} - - if b.edsLB != nil { - b.edsLB.Close() - b.edsLB = nil + b.clusterHandler.close() + if b.childLB != nil { + b.childLB.Close() + b.childLB = nil + } + if b.cachedRoot != nil { + b.cachedRoot.Close() + } + if b.cachedIdentity != nil { + b.cachedIdentity.Close() } - // This is the *ONLY* point of return from this function. b.logger.Infof("Shutdown") + b.done.Fire() return } } @@ -411,23 +425,22 @@ func (b *cdsBalancer) run() { // - If it's from xds client, it means CDS resource were removed. The CDS // watcher should keep watching. // -// In both cases, the error will be forwarded to EDS balancer. And if error is -// resource-not-found, the child EDS balancer will stop watching EDS. +// In both cases, the error will be forwarded to the child balancer. And if +// error is resource-not-found, the child balancer will stop watching EDS. func (b *cdsBalancer) handleErrorFromUpdate(err error, fromParent bool) { - // TODO: connection errors will be sent to the eds balancers directly, and - // also forwarded by the parent balancers/resolvers. So the eds balancer may - // see the same error multiple times. We way want to only forward the error - // to eds if it's not a connection error. - // // This is not necessary today, because xds client never sends connection // errors. if fromParent && xdsclient.ErrType(err) == xdsclient.ErrorTypeResourceNotFound { - b.cancelWatch() + b.clusterHandler.close() } - if b.edsLB != nil { - b.edsLB.ResolverError(err) + if b.childLB != nil { + if xdsclient.ErrType(err) != xdsclient.ErrorTypeConnection { + // Connection errors will be sent to the child balancers directly. + // There's no need to forward them. + b.childLB.ResolverError(err) + } } else { - // If eds balancer was never created, fail the RPCs with + // If child balancer was never created, fail the RPCs with // errors. b.ccw.UpdateState(balancer.State{ ConnectivityState: connectivity.TransientFailure, @@ -436,16 +449,6 @@ func (b *cdsBalancer) handleErrorFromUpdate(err error, fromParent bool) { } } -// handleClusterUpdate is the CDS watch API callback. It simply pushes the -// received information on to the update channel for run() to pick it up. -func (b *cdsBalancer) handleClusterUpdate(cu xdsclient.ClusterUpdate, err error) { - if b.closed.HasFired() { - b.logger.Warningf("xds: received cluster update {%+v} after cdsBalancer was closed", cu) - return - } - b.updateCh.Put(&watchUpdate{cds: cu, err: err}) -} - // UpdateClientConnState receives the serviceConfig (which contains the // clusterName to watch for in CDS) and the xdsClient object from the // xdsResolver. @@ -455,7 +458,15 @@ func (b *cdsBalancer) UpdateClientConnState(state balancer.ClientConnState) erro return errBalancerClosed } - b.logger.Infof("Received update from resolver, balancer config: %+v", state.BalancerConfig) + if b.xdsClient == nil { + c := xdsclient.FromResolverState(state.ResolverState) + if c == nil { + return balancer.ErrBadResolverState + } + b.xdsClient = c + } + + b.logger.Infof("Received update from resolver, balancer config: %+v", pretty.ToJSON(state.BalancerConfig)) // The errors checked here should ideally never happen because the // ServiceConfig in this case is prepared by the xdsResolver and is not // something that is received on the wire. @@ -490,10 +501,15 @@ func (b *cdsBalancer) UpdateSubConnState(sc balancer.SubConn, state balancer.Sub b.updateCh.Put(&scUpdate{subConn: sc, state: state}) } -// Close closes the cdsBalancer and the underlying edsBalancer. +// Close cancels the CDS watch, closes the child policy and closes the +// cdsBalancer. func (b *cdsBalancer) Close() { b.closed.Fire() - b.xdsClient.Close() + <-b.done.Done() +} + +func (b *cdsBalancer) ExitIdle() { + b.updateCh.Put(exitIdle{}) } // ccWrapper wraps the balancer.ClientConn passed to the CDS balancer at diff --git a/xds/internal/balancer/cdsbalancer/cdsbalancer_security_test.go b/xds/internal/balancer/cdsbalancer/cdsbalancer_security_test.go index fee48c262eb..9483818e306 100644 --- a/xds/internal/balancer/cdsbalancer/cdsbalancer_security_test.go +++ b/xds/internal/balancer/cdsbalancer/cdsbalancer_security_test.go @@ -20,47 +20,63 @@ import ( "context" "errors" "fmt" + "regexp" "testing" + "github.com/google/go-cmp/cmp" "google.golang.org/grpc/attributes" "google.golang.org/grpc/balancer" "google.golang.org/grpc/credentials/local" "google.golang.org/grpc/credentials/tls/certprovider" "google.golang.org/grpc/credentials/xds" "google.golang.org/grpc/internal" - xdsinternal "google.golang.org/grpc/internal/credentials/xds" + xdscredsinternal "google.golang.org/grpc/internal/credentials/xds" "google.golang.org/grpc/internal/testutils" + "google.golang.org/grpc/internal/xds/matcher" "google.golang.org/grpc/resolver" - xdsclient "google.golang.org/grpc/xds/internal/client" - "google.golang.org/grpc/xds/internal/client/bootstrap" xdstestutils "google.golang.org/grpc/xds/internal/testutils" "google.golang.org/grpc/xds/internal/testutils/fakeclient" + "google.golang.org/grpc/xds/internal/xdsclient" + "google.golang.org/grpc/xds/internal/xdsclient/bootstrap" ) const ( fakeProvider1Name = "fake-certificate-provider-1" fakeProvider2Name = "fake-certificate-provider-2" fakeConfig = "my fake config" + testSAN = "test-san" ) var ( + testSANMatchers = []matcher.StringMatcher{ + matcher.StringMatcherForTesting(newStringP(testSAN), nil, nil, nil, nil, true), + matcher.StringMatcherForTesting(nil, newStringP(testSAN), nil, nil, nil, false), + matcher.StringMatcherForTesting(nil, nil, newStringP(testSAN), nil, nil, false), + matcher.StringMatcherForTesting(nil, nil, nil, nil, regexp.MustCompile(testSAN), false), + matcher.StringMatcherForTesting(nil, nil, nil, newStringP(testSAN), nil, false), + } fpb1, fpb2 *fakeProviderBuilder bootstrapConfig *bootstrap.Config cdsUpdateWithGoodSecurityCfg = xdsclient.ClusterUpdate{ - ServiceName: serviceName, + ClusterName: serviceName, SecurityCfg: &xdsclient.SecurityConfig{ - RootInstanceName: "default1", - IdentityInstanceName: "default2", + RootInstanceName: "default1", + IdentityInstanceName: "default2", + SubjectAltNameMatchers: testSANMatchers, }, } cdsUpdateWithMissingSecurityCfg = xdsclient.ClusterUpdate{ - ServiceName: serviceName, + ClusterName: serviceName, SecurityCfg: &xdsclient.SecurityConfig{ RootInstanceName: "not-default", }, } ) +func newStringP(s string) *string { + return &s +} + func init() { fpb1 = &fakeProviderBuilder{name: fakeProvider1Name} fpb2 = &fakeProviderBuilder{name: fakeProvider2Name} @@ -115,11 +131,7 @@ func (p *fakeProvider) Close() { // xDSCredentials. func setupWithXDSCreds(t *testing.T) (*fakeclient.Client, *cdsBalancer, *testEDSBalancer, *xdstestutils.TestClientConn, func()) { t.Helper() - xdsC := fakeclient.NewClient() - oldNewXDSClient := newXDSClient - newXDSClient = func() (xdsClientInterface, error) { return xdsC, nil } - builder := balancer.Get(cdsName) if builder == nil { t.Fatalf("balancer.Get(%q) returned nil", cdsName) @@ -139,14 +151,14 @@ func setupWithXDSCreds(t *testing.T) (*fakeclient.Client, *cdsBalancer, *testEDS // Override the creation of the EDS balancer to return a fake EDS balancer // implementation. edsB := newTestEDSBalancer() - oldEDSBalancerBuilder := newEDSBalancer - newEDSBalancer = func(cc balancer.ClientConn, opts balancer.BuildOptions) (balancer.Balancer, error) { + oldEDSBalancerBuilder := newChildBalancer + newChildBalancer = func(cc balancer.ClientConn, opts balancer.BuildOptions) (balancer.Balancer, error) { edsB.parentCC = cc return edsB, nil } // Push a ClientConnState update to the CDS balancer with a cluster name. - if err := cdsB.UpdateClientConnState(cdsCCS(clusterName)); err != nil { + if err := cdsB.UpdateClientConnState(cdsCCS(clusterName, xdsC)); err != nil { t.Fatalf("cdsBalancer.UpdateClientConnState failed with error: %v", err) } @@ -163,8 +175,8 @@ func setupWithXDSCreds(t *testing.T) (*fakeclient.Client, *cdsBalancer, *testEDS } return xdsC, cdsB.(*cdsBalancer), edsB, tcc, func() { - newXDSClient = oldNewXDSClient - newEDSBalancer = oldEDSBalancerBuilder + newChildBalancer = oldEDSBalancerBuilder + xdsC.Close() } } @@ -190,7 +202,7 @@ func makeNewSubConn(ctx context.Context, edsCC balancer.ClientConn, parentCC *xd if got, want := gotAddrs[0].Addr, addrs[0].Addr; got != want { return nil, fmt.Errorf("resolver.Address passed to parent ClientConn has address %q, want %q", got, want) } - getHI := internal.GetXDSHandshakeInfoForTesting.(func(attr *attributes.Attributes) *xdsinternal.HandshakeInfo) + getHI := internal.GetXDSHandshakeInfoForTesting.(func(attr *attributes.Attributes) *xdscredsinternal.HandshakeInfo) hi := getHI(gotAddrs[0].Attributes) if hi == nil { return nil, errors.New("resolver.Address passed to parent ClientConn doesn't contain attributes") @@ -198,6 +210,11 @@ func makeNewSubConn(ctx context.Context, edsCC balancer.ClientConn, parentCC *xd if gotFallback := hi.UseFallbackCreds(); gotFallback != wantFallback { return nil, fmt.Errorf("resolver.Address HandshakeInfo uses fallback creds? %v, want %v", gotFallback, wantFallback) } + if !wantFallback { + if diff := cmp.Diff(testSANMatchers, hi.GetSANMatchersForTesting(), cmp.AllowUnexported(regexp.Regexp{})); diff != "" { + return nil, fmt.Errorf("unexpected diff in the list of SAN matchers (-got, +want):\n%s", diff) + } + } } return sc, nil } @@ -232,9 +249,9 @@ func (s) TestSecurityConfigWithoutXDSCreds(t *testing.T) { // will trigger the watch handler on the CDS balancer, which will attempt to // create a new EDS balancer. The fake EDS balancer created above will be // returned to the CDS balancer, because we have overridden the - // newEDSBalancer function as part of test setup. - cdsUpdate := xdsclient.ClusterUpdate{ServiceName: serviceName} - wantCCS := edsCCS(serviceName, nil, false) + // newChildBalancer function as part of test setup. + cdsUpdate := xdsclient.ClusterUpdate{ClusterName: serviceName} + wantCCS := edsCCS(serviceName, nil, false, nil) ctx, ctxCancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer ctxCancel() if err := invokeWatchCbAndWait(ctx, xdsC, cdsWatchInfo{cdsUpdate, nil}, wantCCS, edsB); err != nil { @@ -287,10 +304,10 @@ func (s) TestNoSecurityConfigWithXDSCreds(t *testing.T) { // will trigger the watch handler on the CDS balancer, which will attempt to // create a new EDS balancer. The fake EDS balancer created above will be // returned to the CDS balancer, because we have overridden the - // newEDSBalancer function as part of test setup. No security config is + // newChildBalancer function as part of test setup. No security config is // passed to the CDS balancer as part of this update. - cdsUpdate := xdsclient.ClusterUpdate{ServiceName: serviceName} - wantCCS := edsCCS(serviceName, nil, false) + cdsUpdate := xdsclient.ClusterUpdate{ClusterName: serviceName} + wantCCS := edsCCS(serviceName, nil, false, nil) ctx, ctxCancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer ctxCancel() if err := invokeWatchCbAndWait(ctx, xdsC, cdsWatchInfo{cdsUpdate, nil}, wantCCS, edsB); err != nil { @@ -445,8 +462,8 @@ func (s) TestSecurityConfigUpdate_BadToGood(t *testing.T) { // will trigger the watch handler on the CDS balancer, which will attempt to // create a new EDS balancer. The fake EDS balancer created above will be // returned to the CDS balancer, because we have overridden the - // newEDSBalancer function as part of test setup. - wantCCS := edsCCS(serviceName, nil, false) + // newChildBalancer function as part of test setup. + wantCCS := edsCCS(serviceName, nil, false, nil) if err := invokeWatchCbAndWait(ctx, xdsC, cdsWatchInfo{cdsUpdateWithGoodSecurityCfg, nil}, wantCCS, edsB); err != nil { t.Fatal(err) } @@ -479,8 +496,8 @@ func (s) TestGoodSecurityConfig(t *testing.T) { // will trigger the watch handler on the CDS balancer, which will attempt to // create a new EDS balancer. The fake EDS balancer created above will be // returned to the CDS balancer, because we have overridden the - // newEDSBalancer function as part of test setup. - wantCCS := edsCCS(serviceName, nil, false) + // newChildBalancer function as part of test setup. + wantCCS := edsCCS(serviceName, nil, false, nil) ctx, ctxCancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer ctxCancel() if err := invokeWatchCbAndWait(ctx, xdsC, cdsWatchInfo{cdsUpdateWithGoodSecurityCfg, nil}, wantCCS, edsB); err != nil { @@ -507,7 +524,7 @@ func (s) TestGoodSecurityConfig(t *testing.T) { if got, want := gotAddrs[0].Addr, addrs[0].Addr; got != want { t.Fatalf("resolver.Address passed to parent ClientConn through UpdateAddresses() has address %q, want %q", got, want) } - getHI := internal.GetXDSHandshakeInfoForTesting.(func(attr *attributes.Attributes) *xdsinternal.HandshakeInfo) + getHI := internal.GetXDSHandshakeInfoForTesting.(func(attr *attributes.Attributes) *xdscredsinternal.HandshakeInfo) hi := getHI(gotAddrs[0].Attributes) if hi == nil { t.Fatal("resolver.Address passed to parent ClientConn through UpdateAddresses() doesn't contain attributes") @@ -532,8 +549,8 @@ func (s) TestSecurityConfigUpdate_GoodToFallback(t *testing.T) { // will trigger the watch handler on the CDS balancer, which will attempt to // create a new EDS balancer. The fake EDS balancer created above will be // returned to the CDS balancer, because we have overridden the - // newEDSBalancer function as part of test setup. - wantCCS := edsCCS(serviceName, nil, false) + // newChildBalancer function as part of test setup. + wantCCS := edsCCS(serviceName, nil, false, nil) ctx, ctxCancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer ctxCancel() if err := invokeWatchCbAndWait(ctx, xdsC, cdsWatchInfo{cdsUpdateWithGoodSecurityCfg, nil}, wantCCS, edsB); err != nil { @@ -549,7 +566,7 @@ func (s) TestSecurityConfigUpdate_GoodToFallback(t *testing.T) { // an update which contains bad security config. So, we expect the CDS // balancer to forward this error to the EDS balancer and eventually the // channel needs to be put in a bad state. - cdsUpdate := xdsclient.ClusterUpdate{ServiceName: serviceName} + cdsUpdate := xdsclient.ClusterUpdate{ClusterName: serviceName} if err := invokeWatchCbAndWait(ctx, xdsC, cdsWatchInfo{cdsUpdate, nil}, wantCCS, edsB); err != nil { t.Fatal(err) } @@ -582,8 +599,8 @@ func (s) TestSecurityConfigUpdate_GoodToBad(t *testing.T) { // will trigger the watch handler on the CDS balancer, which will attempt to // create a new EDS balancer. The fake EDS balancer created above will be // returned to the CDS balancer, because we have overridden the - // newEDSBalancer function as part of test setup. - wantCCS := edsCCS(serviceName, nil, false) + // newChildBalancer function as part of test setup. + wantCCS := edsCCS(serviceName, nil, false, nil) ctx, ctxCancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer ctxCancel() if err := invokeWatchCbAndWait(ctx, xdsC, cdsWatchInfo{cdsUpdateWithGoodSecurityCfg, nil}, wantCCS, edsB); err != nil { @@ -617,7 +634,7 @@ func (s) TestSecurityConfigUpdate_GoodToBad(t *testing.T) { // registered watch should not be cancelled. sCtx, sCancel := context.WithTimeout(context.Background(), defaultTestShortTimeout) defer sCancel() - if err := xdsC.WaitForCancelClusterWatch(sCtx); err != context.DeadlineExceeded { + if _, err := xdsC.WaitForCancelClusterWatch(sCtx); err != context.DeadlineExceeded { t.Fatal("cluster watch cancelled for a non-resource-not-found-error") } } @@ -653,14 +670,15 @@ func (s) TestSecurityConfigUpdate_GoodToGood(t *testing.T) { // will trigger the watch handler on the CDS balancer, which will attempt to // create a new EDS balancer. The fake EDS balancer created above will be // returned to the CDS balancer, because we have overridden the - // newEDSBalancer function as part of test setup. + // newChildBalancer function as part of test setup. cdsUpdate := xdsclient.ClusterUpdate{ - ServiceName: serviceName, + ClusterName: serviceName, SecurityCfg: &xdsclient.SecurityConfig{ - RootInstanceName: "default1", + RootInstanceName: "default1", + SubjectAltNameMatchers: testSANMatchers, }, } - wantCCS := edsCCS(serviceName, nil, false) + wantCCS := edsCCS(serviceName, nil, false, nil) ctx, ctxCancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer ctxCancel() if err := invokeWatchCbAndWait(ctx, xdsC, cdsWatchInfo{cdsUpdate, nil}, wantCCS, edsB); err != nil { @@ -679,9 +697,10 @@ func (s) TestSecurityConfigUpdate_GoodToGood(t *testing.T) { // Push another update with a new security configuration. cdsUpdate = xdsclient.ClusterUpdate{ - ServiceName: serviceName, + ClusterName: serviceName, SecurityCfg: &xdsclient.SecurityConfig{ - RootInstanceName: "default2", + RootInstanceName: "default2", + SubjectAltNameMatchers: testSANMatchers, }, } if err := invokeWatchCbAndWait(ctx, xdsC, cdsWatchInfo{cdsUpdate, nil}, wantCCS, edsB); err != nil { diff --git a/xds/internal/balancer/cdsbalancer/cdsbalancer_test.go b/xds/internal/balancer/cdsbalancer/cdsbalancer_test.go index 9c7bc2362ab..30b612fc7d0 100644 --- a/xds/internal/balancer/cdsbalancer/cdsbalancer_test.go +++ b/xds/internal/balancer/cdsbalancer/cdsbalancer_test.go @@ -26,19 +26,19 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" - "google.golang.org/grpc/attributes" "google.golang.org/grpc/balancer" "google.golang.org/grpc/connectivity" "google.golang.org/grpc/internal" "google.golang.org/grpc/internal/grpctest" + internalserviceconfig "google.golang.org/grpc/internal/serviceconfig" "google.golang.org/grpc/internal/testutils" "google.golang.org/grpc/resolver" "google.golang.org/grpc/serviceconfig" - "google.golang.org/grpc/xds/internal/balancer/edsbalancer" - "google.golang.org/grpc/xds/internal/client" - xdsclient "google.golang.org/grpc/xds/internal/client" + "google.golang.org/grpc/xds/internal/balancer/clusterresolver" + "google.golang.org/grpc/xds/internal/balancer/ringhash" xdstestutils "google.golang.org/grpc/xds/internal/testutils" "google.golang.org/grpc/xds/internal/testutils/fakeclient" + "google.golang.org/grpc/xds/internal/xdsclient" ) const ( @@ -83,7 +83,8 @@ type testEDSBalancer struct { // resolverErrCh is a channel used to signal a resolver error. resolverErrCh *testutils.Channel // closeCh is a channel used to signal the closing of this balancer. - closeCh *testutils.Channel + closeCh *testutils.Channel + exitIdleCh *testutils.Channel // parentCC is the balancer.ClientConn passed to this test balancer as part // of the Build() call. parentCC balancer.ClientConn @@ -100,6 +101,7 @@ func newTestEDSBalancer() *testEDSBalancer { scStateCh: testutils.NewChannel(), resolverErrCh: testutils.NewChannel(), closeCh: testutils.NewChannel(), + exitIdleCh: testutils.NewChannel(), } } @@ -120,6 +122,10 @@ func (tb *testEDSBalancer) Close() { tb.closeCh.Send(struct{}{}) } +func (tb *testEDSBalancer) ExitIdle() { + tb.exitIdleCh.Send(struct{}{}) +} + // waitForClientConnUpdate verifies if the testEDSBalancer receives the // provided ClientConnState within a reasonable amount of time. func (tb *testEDSBalancer) waitForClientConnUpdate(ctx context.Context, wantCCS balancer.ClientConnState) error { @@ -128,8 +134,11 @@ func (tb *testEDSBalancer) waitForClientConnUpdate(ctx context.Context, wantCCS return err } gotCCS := ccs.(balancer.ClientConnState) - if !cmp.Equal(gotCCS, wantCCS, cmpopts.IgnoreUnexported(attributes.Attributes{})) { - return fmt.Errorf("received ClientConnState: %+v, want %+v", gotCCS, wantCCS) + if xdsclient.FromResolverState(gotCCS.ResolverState) == nil { + return fmt.Errorf("want resolver state with XDSClient attached, got one without") + } + if diff := cmp.Diff(gotCCS, wantCCS, cmpopts.IgnoreFields(resolver.State{}, "Attributes")); diff != "" { + return fmt.Errorf("received unexpected ClientConnState, diff (-got +want): %v", diff) } return nil } @@ -172,7 +181,7 @@ func (tb *testEDSBalancer) waitForClose(ctx context.Context) error { // cdsCCS is a helper function to construct a good update passed from the // xdsResolver to the cdsBalancer. -func cdsCCS(cluster string) balancer.ClientConnState { +func cdsCCS(cluster string, xdsC xdsclient.XDSClient) balancer.ClientConnState { const cdsLBConfig = `{ "loadBalancingConfig":[ { @@ -184,37 +193,40 @@ func cdsCCS(cluster string) balancer.ClientConnState { }` jsonSC := fmt.Sprintf(cdsLBConfig, cluster) return balancer.ClientConnState{ - ResolverState: resolver.State{ + ResolverState: xdsclient.SetClient(resolver.State{ ServiceConfig: internal.ParseServiceConfigForTesting.(func(string) *serviceconfig.ParseResult)(jsonSC), - }, + }, xdsC), BalancerConfig: &lbConfig{ClusterName: clusterName}, } } // edsCCS is a helper function to construct a good update passed from the // cdsBalancer to the edsBalancer. -func edsCCS(service string, countMax *uint32, enableLRS bool) balancer.ClientConnState { - lbCfg := &edsbalancer.EDSConfig{ - EDSServiceName: service, +func edsCCS(service string, countMax *uint32, enableLRS bool, xdslbpolicy *internalserviceconfig.BalancerConfig) balancer.ClientConnState { + discoveryMechanism := clusterresolver.DiscoveryMechanism{ + Type: clusterresolver.DiscoveryMechanismTypeEDS, + Cluster: service, MaxConcurrentRequests: countMax, } if enableLRS { - lbCfg.LrsLoadReportingServerName = new(string) + discoveryMechanism.LoadReportingServerName = new(string) + } + lbCfg := &clusterresolver.LBConfig{ + DiscoveryMechanisms: []clusterresolver.DiscoveryMechanism{discoveryMechanism}, + XDSLBPolicy: xdslbpolicy, + } + return balancer.ClientConnState{ BalancerConfig: lbCfg, } } // setup creates a cdsBalancer and an edsBalancer (and overrides the -// newEDSBalancer function to return it), and also returns a cleanup function. +// newChildBalancer function to return it), and also returns a cleanup function. func setup(t *testing.T) (*fakeclient.Client, *cdsBalancer, *testEDSBalancer, *xdstestutils.TestClientConn, func()) { t.Helper() - xdsC := fakeclient.NewClient() - oldNewXDSClient := newXDSClient - newXDSClient = func() (xdsClientInterface, error) { return xdsC, nil } - builder := balancer.Get(cdsName) if builder == nil { t.Fatalf("balancer.Get(%q) returned nil", cdsName) @@ -223,15 +235,15 @@ func setup(t *testing.T) (*fakeclient.Client, *cdsBalancer, *testEDSBalancer, *x cdsB := builder.Build(tcc, balancer.BuildOptions{}) edsB := newTestEDSBalancer() - oldEDSBalancerBuilder := newEDSBalancer - newEDSBalancer = func(cc balancer.ClientConn, opts balancer.BuildOptions) (balancer.Balancer, error) { + oldEDSBalancerBuilder := newChildBalancer + newChildBalancer = func(cc balancer.ClientConn, opts balancer.BuildOptions) (balancer.Balancer, error) { edsB.parentCC = cc return edsB, nil } return xdsC, cdsB.(*cdsBalancer), edsB, tcc, func() { - newEDSBalancer = oldEDSBalancerBuilder - newXDSClient = oldNewXDSClient + newChildBalancer = oldEDSBalancerBuilder + xdsC.Close() } } @@ -241,7 +253,7 @@ func setupWithWatch(t *testing.T) (*fakeclient.Client, *cdsBalancer, *testEDSBal t.Helper() xdsC, cdsB, edsB, tcc, cancel := setup(t) - if err := cdsB.UpdateClientConnState(cdsCCS(clusterName)); err != nil { + if err := cdsB.UpdateClientConnState(cdsCCS(clusterName, xdsC)); err != nil { t.Fatalf("cdsBalancer.UpdateClientConnState failed with error: %v", err) } @@ -261,6 +273,9 @@ func setupWithWatch(t *testing.T) (*fakeclient.Client, *cdsBalancer, *testEDSBal // cdsBalancer with different inputs and verifies that the CDS watch API on the // provided xdsClient is invoked appropriately. func (s) TestUpdateClientConnState(t *testing.T) { + xdsC := fakeclient.NewClient() + defer xdsC.Close() + tests := []struct { name string ccs balancer.ClientConnState @@ -279,14 +294,14 @@ func (s) TestUpdateClientConnState(t *testing.T) { }, { name: "happy-good-case", - ccs: cdsCCS(clusterName), + ccs: cdsCCS(clusterName, xdsC), wantCluster: clusterName, }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - xdsC, cdsB, _, _, cancel := setup(t) + _, cdsB, _, _, cancel := setup(t) defer func() { cancel() cdsB.Close() @@ -323,7 +338,7 @@ func (s) TestUpdateClientConnStateWithSameState(t *testing.T) { }() // This is the same clientConn update sent in setupWithWatch(). - if err := cdsB.UpdateClientConnState(cdsCCS(clusterName)); err != nil { + if err := cdsB.UpdateClientConnState(cdsCCS(clusterName, xdsC)); err != nil { t.Fatalf("cdsBalancer.UpdateClientConnState failed with error: %v", err) } // The above update should not result in a new watch being registered. @@ -352,13 +367,24 @@ func (s) TestHandleClusterUpdate(t *testing.T) { }{ { name: "happy-case-with-lrs", - cdsUpdate: xdsclient.ClusterUpdate{ServiceName: serviceName, EnableLRS: true}, - wantCCS: edsCCS(serviceName, nil, true), + cdsUpdate: xdsclient.ClusterUpdate{ClusterName: serviceName, EnableLRS: true}, + wantCCS: edsCCS(serviceName, nil, true, nil), }, { name: "happy-case-without-lrs", - cdsUpdate: xdsclient.ClusterUpdate{ServiceName: serviceName}, - wantCCS: edsCCS(serviceName, nil, false), + cdsUpdate: xdsclient.ClusterUpdate{ClusterName: serviceName}, + wantCCS: edsCCS(serviceName, nil, false, nil), + }, + { + name: "happy-case-with-ring-hash-lb-policy", + cdsUpdate: xdsclient.ClusterUpdate{ + ClusterName: serviceName, + LBPolicy: &xdsclient.ClusterLBPolicyRingHash{MinimumRingSize: 10, MaximumRingSize: 100}, + }, + wantCCS: edsCCS(serviceName, nil, false, &internalserviceconfig.BalancerConfig{ + Name: ringhash.Name, + Config: &ringhash.LBConfig{MinRingSize: 10, MaxRingSize: 100}, + }), }, } @@ -397,7 +423,7 @@ func (s) TestHandleClusterUpdateError(t *testing.T) { // registered watch should not be cancelled. sCtx, sCancel := context.WithTimeout(context.Background(), defaultTestShortTimeout) defer sCancel() - if err := xdsC.WaitForCancelClusterWatch(sCtx); err != context.DeadlineExceeded { + if _, err := xdsC.WaitForCancelClusterWatch(sCtx); err != context.DeadlineExceeded { t.Fatal("cluster watch cancelled for a non-resource-not-found-error") } // The CDS balancer has not yet created an EDS balancer. So, this resolver @@ -424,9 +450,9 @@ func (s) TestHandleClusterUpdateError(t *testing.T) { // will trigger the watch handler on the CDS balancer, which will attempt to // create a new EDS balancer. The fake EDS balancer created above will be // returned to the CDS balancer, because we have overridden the - // newEDSBalancer function as part of test setup. - cdsUpdate := xdsclient.ClusterUpdate{ServiceName: serviceName} - wantCCS := edsCCS(serviceName, nil, false) + // newChildBalancer function as part of test setup. + cdsUpdate := xdsclient.ClusterUpdate{ClusterName: serviceName} + wantCCS := edsCCS(serviceName, nil, false, nil) if err := invokeWatchCbAndWait(ctx, xdsC, cdsWatchInfo{cdsUpdate, nil}, wantCCS, edsB); err != nil { t.Fatal(err) } @@ -436,7 +462,7 @@ func (s) TestHandleClusterUpdateError(t *testing.T) { // Make sure the registered watch is not cancelled. sCtx, sCancel = context.WithTimeout(context.Background(), defaultTestShortTimeout) defer sCancel() - if err := xdsC.WaitForCancelClusterWatch(sCtx); err != context.DeadlineExceeded { + if _, err := xdsC.WaitForCancelClusterWatch(sCtx); err != context.DeadlineExceeded { t.Fatal("cluster watch cancelled for a non-resource-not-found-error") } // Make sure the error is forwarded to the EDS balancer. @@ -451,7 +477,7 @@ func (s) TestHandleClusterUpdateError(t *testing.T) { // request cluster resource is not found. We should continue to watch it. sCtx, sCancel = context.WithTimeout(context.Background(), defaultTestShortTimeout) defer sCancel() - if err := xdsC.WaitForCancelClusterWatch(sCtx); err != context.DeadlineExceeded { + if _, err := xdsC.WaitForCancelClusterWatch(sCtx); err != context.DeadlineExceeded { t.Fatal("cluster watch cancelled for a resource-not-found-error") } // Make sure the error is forwarded to the EDS balancer. @@ -483,7 +509,7 @@ func (s) TestResolverError(t *testing.T) { // registered watch should not be cancelled. sCtx, sCancel := context.WithTimeout(context.Background(), defaultTestShortTimeout) defer sCancel() - if err := xdsC.WaitForCancelClusterWatch(sCtx); err != context.DeadlineExceeded { + if _, err := xdsC.WaitForCancelClusterWatch(sCtx); err != context.DeadlineExceeded { t.Fatal("cluster watch cancelled for a non-resource-not-found-error") } // The CDS balancer has not yet created an EDS balancer. So, this resolver @@ -509,9 +535,9 @@ func (s) TestResolverError(t *testing.T) { // will trigger the watch handler on the CDS balancer, which will attempt to // create a new EDS balancer. The fake EDS balancer created above will be // returned to the CDS balancer, because we have overridden the - // newEDSBalancer function as part of test setup. - cdsUpdate := xdsclient.ClusterUpdate{ServiceName: serviceName} - wantCCS := edsCCS(serviceName, nil, false) + // newChildBalancer function as part of test setup. + cdsUpdate := xdsclient.ClusterUpdate{ClusterName: serviceName} + wantCCS := edsCCS(serviceName, nil, false, nil) if err := invokeWatchCbAndWait(ctx, xdsC, cdsWatchInfo{cdsUpdate, nil}, wantCCS, edsB); err != nil { t.Fatal(err) } @@ -521,7 +547,7 @@ func (s) TestResolverError(t *testing.T) { // Make sure the registered watch is not cancelled. sCtx, sCancel = context.WithTimeout(context.Background(), defaultTestShortTimeout) defer sCancel() - if err := xdsC.WaitForCancelClusterWatch(sCtx); err != context.DeadlineExceeded { + if _, err := xdsC.WaitForCancelClusterWatch(sCtx); err != context.DeadlineExceeded { t.Fatal("cluster watch cancelled for a non-resource-not-found-error") } // Make sure the error is forwarded to the EDS balancer. @@ -533,7 +559,7 @@ func (s) TestResolverError(t *testing.T) { resourceErr := xdsclient.NewErrorf(xdsclient.ErrorTypeResourceNotFound, "cdsBalancer resource not found error") cdsB.ResolverError(resourceErr) // Make sure the registered watch is cancelled. - if err := xdsC.WaitForCancelClusterWatch(ctx); err != nil { + if _, err := xdsC.WaitForCancelClusterWatch(ctx); err != nil { t.Fatalf("want watch to be canceled, watchForCancel failed: %v", err) } // Make sure the error is forwarded to the EDS balancer. @@ -558,9 +584,9 @@ func (s) TestUpdateSubConnState(t *testing.T) { // will trigger the watch handler on the CDS balancer, which will attempt to // create a new EDS balancer. The fake EDS balancer created above will be // returned to the CDS balancer, because we have overridden the - // newEDSBalancer function as part of test setup. - cdsUpdate := xdsclient.ClusterUpdate{ServiceName: serviceName} - wantCCS := edsCCS(serviceName, nil, false) + // newChildBalancer function as part of test setup. + cdsUpdate := xdsclient.ClusterUpdate{ClusterName: serviceName} + wantCCS := edsCCS(serviceName, nil, false, nil) ctx, ctxCancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer ctxCancel() if err := invokeWatchCbAndWait(ctx, xdsC, cdsWatchInfo{cdsUpdate, nil}, wantCCS, edsB); err != nil { @@ -594,8 +620,8 @@ func (s) TestCircuitBreaking(t *testing.T) { // will trigger the watch handler on the CDS balancer, which will update // the service's counter with the new max requests. var maxRequests uint32 = 1 - cdsUpdate := xdsclient.ClusterUpdate{ServiceName: serviceName, MaxRequests: &maxRequests} - wantCCS := edsCCS(serviceName, &maxRequests, false) + cdsUpdate := xdsclient.ClusterUpdate{ClusterName: clusterName, MaxRequests: &maxRequests} + wantCCS := edsCCS(clusterName, &maxRequests, false, nil) ctx, ctxCancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer ctxCancel() if err := invokeWatchCbAndWait(ctx, xdsC, cdsWatchInfo{cdsUpdate, nil}, wantCCS, edsB); err != nil { @@ -604,7 +630,7 @@ func (s) TestCircuitBreaking(t *testing.T) { // Since the counter's max requests was set to 1, the first request should // succeed and the second should fail. - counter := client.GetServiceRequestsCounter(serviceName) + counter := xdsclient.GetClusterRequestsCounter(clusterName, "") if err := counter.StartRequest(maxRequests); err != nil { t.Fatal(err) } @@ -626,9 +652,9 @@ func (s) TestClose(t *testing.T) { // will trigger the watch handler on the CDS balancer, which will attempt to // create a new EDS balancer. The fake EDS balancer created above will be // returned to the CDS balancer, because we have overridden the - // newEDSBalancer function as part of test setup. - cdsUpdate := xdsclient.ClusterUpdate{ServiceName: serviceName} - wantCCS := edsCCS(serviceName, nil, false) + // newChildBalancer function as part of test setup. + cdsUpdate := xdsclient.ClusterUpdate{ClusterName: serviceName} + wantCCS := edsCCS(serviceName, nil, false, nil) ctx, ctxCancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer ctxCancel() if err := invokeWatchCbAndWait(ctx, xdsC, cdsWatchInfo{cdsUpdate, nil}, wantCCS, edsB); err != nil { @@ -640,7 +666,7 @@ func (s) TestClose(t *testing.T) { // Make sure that the cluster watch registered by the CDS balancer is // cancelled. - if err := xdsC.WaitForCancelClusterWatch(ctx); err != nil { + if _, err := xdsC.WaitForCancelClusterWatch(ctx); err != nil { t.Fatal(err) } @@ -659,7 +685,7 @@ func (s) TestClose(t *testing.T) { // Make sure that the UpdateClientConnState() method on the CDS balancer // returns error. - if err := cdsB.UpdateClientConnState(cdsCCS(clusterName)); err != errBalancerClosed { + if err := cdsB.UpdateClientConnState(cdsCCS(clusterName, xdsC)); err != errBalancerClosed { t.Fatalf("UpdateClientConnState() after close returned %v, want %v", err, errBalancerClosed) } @@ -683,6 +709,35 @@ func (s) TestClose(t *testing.T) { } } +func (s) TestExitIdle(t *testing.T) { + // This creates a CDS balancer, pushes a ClientConnState update with a fake + // xdsClient, and makes sure that the CDS balancer registers a watch on the + // provided xdsClient. + xdsC, cdsB, edsB, _, cancel := setupWithWatch(t) + defer func() { + cancel() + cdsB.Close() + }() + + // Here we invoke the watch callback registered on the fake xdsClient. This + // will trigger the watch handler on the CDS balancer, which will attempt to + // create a new EDS balancer. The fake EDS balancer created above will be + // returned to the CDS balancer, because we have overridden the + // newChildBalancer function as part of test setup. + cdsUpdate := xdsclient.ClusterUpdate{ClusterName: serviceName} + wantCCS := edsCCS(serviceName, nil, false, nil) + ctx, ctxCancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer ctxCancel() + if err := invokeWatchCbAndWait(ctx, xdsC, cdsWatchInfo{cdsUpdate, nil}, wantCCS, edsB); err != nil { + t.Fatal(err) + } + + // Call ExitIdle on the CDS balancer. + cdsB.ExitIdle() + + edsB.exitIdleCh.Receive(ctx) +} + // TestParseConfig verifies the ParseConfig() method in the CDS balancer. func (s) TestParseConfig(t *testing.T) { bb := balancer.Get(cdsName) diff --git a/xds/internal/balancer/cdsbalancer/cluster_handler.go b/xds/internal/balancer/cdsbalancer/cluster_handler.go new file mode 100644 index 00000000000..163a8c0a2e1 --- /dev/null +++ b/xds/internal/balancer/cdsbalancer/cluster_handler.go @@ -0,0 +1,318 @@ +/* + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package cdsbalancer + +import ( + "errors" + "sync" + + "google.golang.org/grpc/xds/internal/xdsclient" +) + +var errNotReceivedUpdate = errors.New("tried to construct a cluster update on a cluster that has not received an update") + +// clusterHandlerUpdate wraps the information received from the registered CDS +// watcher. A non-nil error is propagated to the underlying cluster_resolver +// balancer. A valid update results in creating a new cluster_resolver balancer +// (if one doesn't already exist) and pushing the update to it. +type clusterHandlerUpdate struct { + // securityCfg is the Security Config from the top (root) cluster. + securityCfg *xdsclient.SecurityConfig + // lbPolicy is the lb policy from the top (root) cluster. + // + // Currently, we only support roundrobin or ringhash, and since roundrobin + // does need configs, this is only set to the ringhash config, if the policy + // is ringhash. In the future, if we support more policies, we can make this + // an interface, and set it to config of the other policies. + lbPolicy *xdsclient.ClusterLBPolicyRingHash + + // updates is a list of ClusterUpdates from all the leaf clusters. + updates []xdsclient.ClusterUpdate + err error +} + +// clusterHandler will be given a name representing a cluster. It will then +// update the CDS policy constantly with a list of Clusters to pass down to +// XdsClusterResolverLoadBalancingPolicyConfig in a stream like fashion. +type clusterHandler struct { + parent *cdsBalancer + + // A mutex to protect entire tree of clusters. + clusterMutex sync.Mutex + root *clusterNode + rootClusterName string + + // A way to ping CDS Balancer about any updates or errors to a Node in the + // tree. This will either get called from this handler constructing an + // update or from a child with an error. Capacity of one as the only update + // CDS Balancer cares about is the most recent update. + updateChannel chan clusterHandlerUpdate +} + +func newClusterHandler(parent *cdsBalancer) *clusterHandler { + return &clusterHandler{ + parent: parent, + updateChannel: make(chan clusterHandlerUpdate, 1), + } +} + +func (ch *clusterHandler) updateRootCluster(rootClusterName string) { + ch.clusterMutex.Lock() + defer ch.clusterMutex.Unlock() + if ch.root == nil { + // Construct a root node on first update. + ch.root = createClusterNode(rootClusterName, ch.parent.xdsClient, ch) + ch.rootClusterName = rootClusterName + return + } + // Check if root cluster was changed. If it was, delete old one and start + // new one, if not do nothing. + if rootClusterName != ch.rootClusterName { + ch.root.delete() + ch.root = createClusterNode(rootClusterName, ch.parent.xdsClient, ch) + ch.rootClusterName = rootClusterName + } +} + +// This function tries to construct a cluster update to send to CDS. +func (ch *clusterHandler) constructClusterUpdate() { + if ch.root == nil { + // If root is nil, this handler is closed, ignore the update. + return + } + clusterUpdate, err := ch.root.constructClusterUpdate() + if err != nil { + // If there was an error received no op, as this simply means one of the + // children hasn't received an update yet. + return + } + // For a ClusterUpdate, the only update CDS cares about is the most + // recent one, so opportunistically drain the update channel before + // sending the new update. + select { + case <-ch.updateChannel: + default: + } + ch.updateChannel <- clusterHandlerUpdate{ + securityCfg: ch.root.clusterUpdate.SecurityCfg, + lbPolicy: ch.root.clusterUpdate.LBPolicy, + updates: clusterUpdate, + } +} + +// close() is meant to be called by CDS when the CDS balancer is closed, and it +// cancels the watches for every cluster in the cluster tree. +func (ch *clusterHandler) close() { + ch.clusterMutex.Lock() + defer ch.clusterMutex.Unlock() + if ch.root == nil { + return + } + ch.root.delete() + ch.root = nil + ch.rootClusterName = "" +} + +// This logically represents a cluster. This handles all the logic for starting +// and stopping a cluster watch, handling any updates, and constructing a list +// recursively for the ClusterHandler. +type clusterNode struct { + // A way to cancel the watch for the cluster. + cancelFunc func() + + // A list of children, as the Node can be an aggregate Cluster. + children []*clusterNode + + // A ClusterUpdate in order to build a list of cluster updates for CDS to + // send down to child XdsClusterResolverLoadBalancingPolicy. + clusterUpdate xdsclient.ClusterUpdate + + // This boolean determines whether this Node has received an update or not. + // This isn't the best practice, but this will protect a list of Cluster + // Updates from being constructed if a cluster in the tree has not received + // an update yet. + receivedUpdate bool + + clusterHandler *clusterHandler +} + +// CreateClusterNode creates a cluster node from a given clusterName. This will +// also start the watch for that cluster. +func createClusterNode(clusterName string, xdsClient xdsclient.XDSClient, topLevelHandler *clusterHandler) *clusterNode { + c := &clusterNode{ + clusterHandler: topLevelHandler, + } + // Communicate with the xds client here. + topLevelHandler.parent.logger.Infof("CDS watch started on %v", clusterName) + cancel := xdsClient.WatchCluster(clusterName, c.handleResp) + c.cancelFunc = func() { + topLevelHandler.parent.logger.Infof("CDS watch canceled on %v", clusterName) + cancel() + } + return c +} + +// This function cancels the cluster watch on the cluster and all of it's +// children. +func (c *clusterNode) delete() { + c.cancelFunc() + for _, child := range c.children { + child.delete() + } +} + +// Construct cluster update (potentially a list of ClusterUpdates) for a node. +func (c *clusterNode) constructClusterUpdate() ([]xdsclient.ClusterUpdate, error) { + // If the cluster has not yet received an update, the cluster update is not + // yet ready. + if !c.receivedUpdate { + return nil, errNotReceivedUpdate + } + + // Base case - LogicalDNS or EDS. Both of these cluster types will be tied + // to a single ClusterUpdate. + if c.clusterUpdate.ClusterType != xdsclient.ClusterTypeAggregate { + return []xdsclient.ClusterUpdate{c.clusterUpdate}, nil + } + + // If an aggregate construct a list by recursively calling down to all of + // it's children. + var childrenUpdates []xdsclient.ClusterUpdate + for _, child := range c.children { + childUpdateList, err := child.constructClusterUpdate() + if err != nil { + return nil, err + } + childrenUpdates = append(childrenUpdates, childUpdateList...) + } + return childrenUpdates, nil +} + +// handleResp handles a xds response for a particular cluster. This function +// also handles any logic with regards to any child state that may have changed. +// At the end of the handleResp(), the clusterUpdate will be pinged in certain +// situations to try and construct an update to send back to CDS. +func (c *clusterNode) handleResp(clusterUpdate xdsclient.ClusterUpdate, err error) { + c.clusterHandler.clusterMutex.Lock() + defer c.clusterHandler.clusterMutex.Unlock() + if err != nil { // Write this error for run() to pick up in CDS LB policy. + // For a ClusterUpdate, the only update CDS cares about is the most + // recent one, so opportunistically drain the update channel before + // sending the new update. + select { + case <-c.clusterHandler.updateChannel: + default: + } + c.clusterHandler.updateChannel <- clusterHandlerUpdate{err: err} + return + } + + c.receivedUpdate = true + c.clusterUpdate = clusterUpdate + + // If the cluster was a leaf node, if the cluster update received had change + // in the cluster update then the overall cluster update would change and + // there is a possibility for the overall update to build so ping cluster + // handler to return. Also, if there was any children from previously, + // delete the children, as the cluster type is no longer an aggregate + // cluster. + if clusterUpdate.ClusterType != xdsclient.ClusterTypeAggregate { + for _, child := range c.children { + child.delete() + } + c.children = nil + // This is an update in the one leaf node, should try to send an update + // to the parent CDS balancer. + // + // Note that this update might be a duplicate from the previous one. + // Because the update contains not only the cluster name to watch, but + // also the extra fields (e.g. security config). There's no good way to + // compare all the fields. + c.clusterHandler.constructClusterUpdate() + return + } + + // Aggregate cluster handling. + newChildren := make(map[string]bool) + for _, childName := range clusterUpdate.PrioritizedClusterNames { + newChildren[childName] = true + } + + // These booleans help determine whether this callback will ping the overall + // clusterHandler to try and construct an update to send back to CDS. This + // will be determined by whether there would be a change in the overall + // clusterUpdate for the whole tree (ex. change in clusterUpdate for current + // cluster or a deleted child) and also if there's even a possibility for + // the update to build (ex. if a child is created and a watch is started, + // that child hasn't received an update yet due to the mutex lock on this + // callback). + var createdChild, deletedChild bool + + // This map will represent the current children of the cluster. It will be + // first added to in order to represent the new children. It will then have + // any children deleted that are no longer present. Then, from the cluster + // update received, will be used to construct the new child list. + mapCurrentChildren := make(map[string]*clusterNode) + for _, child := range c.children { + mapCurrentChildren[child.clusterUpdate.ClusterName] = child + } + + // Add and construct any new child nodes. + for child := range newChildren { + if _, inChildrenAlready := mapCurrentChildren[child]; !inChildrenAlready { + createdChild = true + mapCurrentChildren[child] = createClusterNode(child, c.clusterHandler.parent.xdsClient, c.clusterHandler) + } + } + + // Delete any child nodes no longer in the aggregate cluster's children. + for child := range mapCurrentChildren { + if _, stillAChild := newChildren[child]; !stillAChild { + deletedChild = true + mapCurrentChildren[child].delete() + delete(mapCurrentChildren, child) + } + } + + // The order of the children list matters, so use the clusterUpdate from + // xdsclient as the ordering, and use that logical ordering for the new + // children list. This will be a mixture of child nodes which are all + // already constructed in the mapCurrentChildrenMap. + var children = make([]*clusterNode, 0, len(clusterUpdate.PrioritizedClusterNames)) + + for _, orderedChild := range clusterUpdate.PrioritizedClusterNames { + // The cluster's already have watches started for them in xds client, so + // you can use these pointers to construct the new children list, you + // just have to put them in the correct order using the original cluster + // update. + currentChild := mapCurrentChildren[orderedChild] + children = append(children, currentChild) + } + + c.children = children + + // If the cluster is an aggregate cluster, if this callback created any new + // child cluster nodes, then there's no possibility for a full cluster + // update to successfully build, as those created children will not have + // received an update yet. However, if there was simply a child deleted, + // then there is a possibility that it will have a full cluster update to + // build and also will have a changed overall cluster update from the + // deleted child. + if deletedChild && !createdChild { + c.clusterHandler.constructClusterUpdate() + } +} diff --git a/xds/internal/balancer/cdsbalancer/cluster_handler_test.go b/xds/internal/balancer/cdsbalancer/cluster_handler_test.go new file mode 100644 index 00000000000..cb9b4e14da3 --- /dev/null +++ b/xds/internal/balancer/cdsbalancer/cluster_handler_test.go @@ -0,0 +1,685 @@ +/* + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package cdsbalancer + +import ( + "context" + "errors" + "testing" + + "github.com/google/go-cmp/cmp" + "google.golang.org/grpc/xds/internal/testutils/fakeclient" + "google.golang.org/grpc/xds/internal/xdsclient" +) + +const ( + edsService = "EDS Service" + logicalDNSService = "Logical DNS Service" + edsService2 = "EDS Service 2" + logicalDNSService2 = "Logical DNS Service 2" + aggregateClusterService = "Aggregate Cluster Service" +) + +// setupTests creates a clusterHandler with a fake xds client for control over +// xds client. +func setupTests(t *testing.T) (*clusterHandler, *fakeclient.Client) { + xdsC := fakeclient.NewClient() + ch := newClusterHandler(&cdsBalancer{xdsClient: xdsC}) + return ch, xdsC +} + +// Simplest case: the cluster handler receives a cluster name, handler starts a +// watch for that cluster, xds client returns that it is a Leaf Node (EDS or +// LogicalDNS), not a tree, so expectation that update is written to buffer +// which will be read by CDS LB. +func (s) TestSuccessCaseLeafNode(t *testing.T) { + tests := []struct { + name string + clusterName string + clusterUpdate xdsclient.ClusterUpdate + lbPolicy *xdsclient.ClusterLBPolicyRingHash + }{ + { + name: "test-update-root-cluster-EDS-success", + clusterName: edsService, + clusterUpdate: xdsclient.ClusterUpdate{ + ClusterType: xdsclient.ClusterTypeEDS, + ClusterName: edsService, + }, + }, + { + name: "test-update-root-cluster-EDS-with-ring-hash", + clusterName: logicalDNSService, + clusterUpdate: xdsclient.ClusterUpdate{ + ClusterType: xdsclient.ClusterTypeLogicalDNS, + ClusterName: logicalDNSService, + LBPolicy: &xdsclient.ClusterLBPolicyRingHash{MinimumRingSize: 10, MaximumRingSize: 100}, + }, + lbPolicy: &xdsclient.ClusterLBPolicyRingHash{MinimumRingSize: 10, MaximumRingSize: 100}, + }, + { + name: "test-update-root-cluster-Logical-DNS-success", + clusterName: logicalDNSService, + clusterUpdate: xdsclient.ClusterUpdate{ + ClusterType: xdsclient.ClusterTypeLogicalDNS, + ClusterName: logicalDNSService, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ch, fakeClient := setupTests(t) + // When you first update the root cluster, it should hit the code + // path which will start a cluster node for that root. Updating the + // root cluster logically represents a ping from a ClientConn. + ch.updateRootCluster(test.clusterName) + // Starting a cluster node involves communicating with the + // xdsClient, telling it to watch a cluster. + ctx, ctxCancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer ctxCancel() + gotCluster, err := fakeClient.WaitForWatchCluster(ctx) + if err != nil { + t.Fatalf("xdsClient.WatchCDS failed with error: %v", err) + } + if gotCluster != test.clusterName { + t.Fatalf("xdsClient.WatchCDS called for cluster: %v, want: %v", gotCluster, test.clusterName) + } + // Invoke callback with xds client with a certain clusterUpdate. Due + // to this cluster update filling out the whole cluster tree, as the + // cluster is of a root type (EDS or Logical DNS) and not an + // aggregate cluster, this should trigger the ClusterHandler to + // write to the update buffer to update the CDS policy. + fakeClient.InvokeWatchClusterCallback(test.clusterUpdate, nil) + select { + case chu := <-ch.updateChannel: + if diff := cmp.Diff(chu.updates, []xdsclient.ClusterUpdate{test.clusterUpdate}); diff != "" { + t.Fatalf("got unexpected cluster update, diff (-got, +want): %v", diff) + } + if diff := cmp.Diff(chu.lbPolicy, test.lbPolicy); diff != "" { + t.Fatalf("got unexpected lb policy in cluster update, diff (-got, +want): %v", diff) + } + case <-ctx.Done(): + t.Fatal("Timed out waiting for update from update channel.") + } + // Close the clusterHandler. This is meant to be called when the CDS + // Balancer is closed, and the call should cancel the watch for this + // cluster. + ch.close() + clusterNameDeleted, err := fakeClient.WaitForCancelClusterWatch(ctx) + if err != nil { + t.Fatalf("xdsClient.CancelCDS failed with error: %v", err) + } + if clusterNameDeleted != test.clusterName { + t.Fatalf("xdsClient.CancelCDS called for cluster %v, want: %v", clusterNameDeleted, logicalDNSService) + } + }) + } +} + +// The cluster handler receives a cluster name, handler starts a watch for that +// cluster, xds client returns that it is a Leaf Node (EDS or LogicalDNS), not a +// tree, so expectation that first update is written to buffer which will be +// read by CDS LB. Then, send a new cluster update that is different, with the +// expectation that it is also written to the update buffer to send back to CDS. +func (s) TestSuccessCaseLeafNodeThenNewUpdate(t *testing.T) { + tests := []struct { + name string + clusterName string + clusterUpdate xdsclient.ClusterUpdate + newClusterUpdate xdsclient.ClusterUpdate + }{ + {name: "test-update-root-cluster-then-new-update-EDS-success", + clusterName: edsService, + clusterUpdate: xdsclient.ClusterUpdate{ + ClusterType: xdsclient.ClusterTypeEDS, + ClusterName: edsService, + }, + newClusterUpdate: xdsclient.ClusterUpdate{ + ClusterType: xdsclient.ClusterTypeEDS, + ClusterName: edsService2, + }, + }, + { + name: "test-update-root-cluster-then-new-update-Logical-DNS-success", + clusterName: logicalDNSService, + clusterUpdate: xdsclient.ClusterUpdate{ + ClusterType: xdsclient.ClusterTypeLogicalDNS, + ClusterName: logicalDNSService, + }, + newClusterUpdate: xdsclient.ClusterUpdate{ + ClusterType: xdsclient.ClusterTypeLogicalDNS, + ClusterName: logicalDNSService2, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ch, fakeClient := setupTests(t) + ch.updateRootCluster(test.clusterName) + ctx, ctxCancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer ctxCancel() + _, err := fakeClient.WaitForWatchCluster(ctx) + if err != nil { + t.Fatalf("xdsClient.WatchCDS failed with error: %v", err) + } + fakeClient.InvokeWatchClusterCallback(test.clusterUpdate, nil) + select { + case <-ch.updateChannel: + case <-ctx.Done(): + t.Fatal("Timed out waiting for update from updateChannel.") + } + + // Check that sending the same cluster update also induces an update + // to be written to update buffer. + fakeClient.InvokeWatchClusterCallback(test.clusterUpdate, nil) + shouldNotHappenCtx, shouldNotHappenCtxCancel := context.WithTimeout(context.Background(), defaultTestShortTimeout) + defer shouldNotHappenCtxCancel() + select { + case <-ch.updateChannel: + case <-shouldNotHappenCtx.Done(): + t.Fatal("Timed out waiting for update from updateChannel.") + } + + // Above represents same thing as the simple + // TestSuccessCaseLeafNode, extra behavior + validation (clusterNode + // which is a leaf receives a changed clusterUpdate, which should + // ping clusterHandler, which should then write to the update + // buffer). + fakeClient.InvokeWatchClusterCallback(test.newClusterUpdate, nil) + select { + case chu := <-ch.updateChannel: + if diff := cmp.Diff(chu.updates, []xdsclient.ClusterUpdate{test.newClusterUpdate}); diff != "" { + t.Fatalf("got unexpected cluster update, diff (-got, +want): %v", diff) + } + case <-ctx.Done(): + t.Fatal("Timed out waiting for update from updateChannel.") + } + }) + } +} + +// TestUpdateRootClusterAggregateSuccess tests the case where an aggregate +// cluster is a root pointing to two child clusters one of type EDS and the +// other of type LogicalDNS. This test will then send cluster updates for both +// the children, and at the end there should be a successful clusterUpdate +// written to the update buffer to send back to CDS. +func (s) TestUpdateRootClusterAggregateSuccess(t *testing.T) { + ch, fakeClient := setupTests(t) + ch.updateRootCluster(aggregateClusterService) + + ctx, ctxCancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer ctxCancel() + gotCluster, err := fakeClient.WaitForWatchCluster(ctx) + if err != nil { + t.Fatalf("xdsClient.WatchCDS failed with error: %v", err) + } + if gotCluster != aggregateClusterService { + t.Fatalf("xdsClient.WatchCDS called for cluster: %v, want: %v", gotCluster, aggregateClusterService) + } + + // The xdsClient telling the clusterNode that the cluster type is an + // aggregate cluster which will cause a lot of downstream behavior. For a + // cluster type that isn't an aggregate, the behavior is simple. The + // clusterNode will simply get a successful update, which will then ping the + // clusterHandler which will successfully build an update to send to the CDS + // policy. In the aggregate cluster case, the handleResp callback must also + // start watches for the aggregate cluster's children. The ping to the + // clusterHandler at the end of handleResp should be a no-op, as neither the + // EDS or LogicalDNS child clusters have received an update yet. + fakeClient.InvokeWatchClusterCallback(xdsclient.ClusterUpdate{ + ClusterType: xdsclient.ClusterTypeAggregate, + ClusterName: aggregateClusterService, + PrioritizedClusterNames: []string{edsService, logicalDNSService}, + }, nil) + + // xds client should be called to start a watch for one of the child + // clusters of the aggregate. The order of the children in the update + // written to the buffer to send to CDS matters, however there is no + // guarantee on the order it will start the watches of the children. + gotCluster, err = fakeClient.WaitForWatchCluster(ctx) + if err != nil { + t.Fatalf("xdsClient.WatchCDS failed with error: %v", err) + } + if gotCluster != edsService { + if gotCluster != logicalDNSService { + t.Fatalf("xdsClient.WatchCDS called for cluster: %v, want: %v", gotCluster, edsService) + } + } + + // xds client should then be called to start a watch for the second child + // cluster. + gotCluster, err = fakeClient.WaitForWatchCluster(ctx) + if err != nil { + t.Fatalf("xdsClient.WatchCDS failed with error: %v", err) + } + if gotCluster != edsService { + if gotCluster != logicalDNSService { + t.Fatalf("xdsClient.WatchCDS called for cluster: %v, want: %v", gotCluster, logicalDNSService) + } + } + + // The handleResp() call on the root aggregate cluster should not ping the + // cluster handler to try and construct an update, as the handleResp() + // callback knows that when a child is created, it cannot possibly build a + // successful update yet. Thus, there should be nothing in the update + // channel. + + shouldNotHappenCtx, shouldNotHappenCtxCancel := context.WithTimeout(context.Background(), defaultTestShortTimeout) + defer shouldNotHappenCtxCancel() + + select { + case <-ch.updateChannel: + t.Fatal("Cluster Handler wrote an update to updateChannel when it shouldn't have, as each node in the full cluster tree has not yet received an update") + case <-shouldNotHappenCtx.Done(): + } + + // Send callback for the EDS child cluster. + fakeClient.InvokeWatchClusterCallback(xdsclient.ClusterUpdate{ + ClusterType: xdsclient.ClusterTypeEDS, + ClusterName: edsService, + }, nil) + + // EDS child cluster will ping the Cluster Handler, to try an update, which + // still won't successfully build as the LogicalDNS child of the root + // aggregate cluster has not yet received and handled an update. + select { + case <-ch.updateChannel: + t.Fatal("Cluster Handler wrote an update to updateChannel when it shouldn't have, as each node in the full cluster tree has not yet received an update") + case <-shouldNotHappenCtx.Done(): + } + + // Invoke callback for Logical DNS child cluster. + + fakeClient.InvokeWatchClusterCallback(xdsclient.ClusterUpdate{ + ClusterType: xdsclient.ClusterTypeLogicalDNS, + ClusterName: logicalDNSService, + }, nil) + + // Will Ping Cluster Handler, which will finally successfully build an + // update as all nodes in the tree of clusters have received an update. + // Since this cluster is an aggregate cluster comprised of two children, the + // returned update should be length 2, as the xds cluster resolver LB policy + // only cares about the full list of LogicalDNS and EDS clusters + // representing the base nodes of the tree of clusters. This list should be + // ordered as per the cluster update. + select { + case chu := <-ch.updateChannel: + if diff := cmp.Diff(chu.updates, []xdsclient.ClusterUpdate{{ + ClusterType: xdsclient.ClusterTypeEDS, + ClusterName: edsService, + }, { + ClusterType: xdsclient.ClusterTypeLogicalDNS, + ClusterName: logicalDNSService, + }}); diff != "" { + t.Fatalf("got unexpected cluster update, diff (-got, +want): %v", diff) + } + case <-ctx.Done(): + t.Fatal("Timed out waiting for the cluster update to be written to the update buffer.") + } +} + +// TestUpdateRootClusterAggregateThenChangeChild tests the scenario where you +// have an aggregate cluster with an EDS child and a LogicalDNS child, then you +// change one of the children and send an update for the changed child. This +// should write a new update to the update buffer to send back to CDS. +func (s) TestUpdateRootClusterAggregateThenChangeChild(t *testing.T) { + // This initial code is the same as the test for the aggregate success case, + // except without validations. This will get this test to the point where it + // can change one of the children. + ch, fakeClient := setupTests(t) + ch.updateRootCluster(aggregateClusterService) + + ctx, ctxCancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer ctxCancel() + _, err := fakeClient.WaitForWatchCluster(ctx) + if err != nil { + t.Fatalf("xdsClient.WatchCDS failed with error: %v", err) + } + + fakeClient.InvokeWatchClusterCallback(xdsclient.ClusterUpdate{ + ClusterType: xdsclient.ClusterTypeAggregate, + ClusterName: aggregateClusterService, + PrioritizedClusterNames: []string{edsService, logicalDNSService}, + }, nil) + fakeClient.WaitForWatchCluster(ctx) + fakeClient.WaitForWatchCluster(ctx) + fakeClient.InvokeWatchClusterCallback(xdsclient.ClusterUpdate{ + ClusterType: xdsclient.ClusterTypeEDS, + ClusterName: edsService, + }, nil) + fakeClient.InvokeWatchClusterCallback(xdsclient.ClusterUpdate{ + ClusterType: xdsclient.ClusterTypeLogicalDNS, + ClusterName: logicalDNSService, + }, nil) + + select { + case <-ch.updateChannel: + case <-ctx.Done(): + t.Fatal("Timed out waiting for the cluster update to be written to the update buffer.") + } + + fakeClient.InvokeWatchClusterCallback(xdsclient.ClusterUpdate{ + ClusterType: xdsclient.ClusterTypeAggregate, + ClusterName: aggregateClusterService, + PrioritizedClusterNames: []string{edsService, logicalDNSService2}, + }, nil) + + // The cluster update let's the aggregate cluster know that it's children + // are now edsService and logicalDNSService2, which implies that the + // aggregateCluster lost it's old logicalDNSService child. Thus, the + // logicalDNSService child should be deleted. + clusterNameDeleted, err := fakeClient.WaitForCancelClusterWatch(ctx) + if err != nil { + t.Fatalf("xdsClient.CancelCDS failed with error: %v", err) + } + if clusterNameDeleted != logicalDNSService { + t.Fatalf("xdsClient.CancelCDS called for cluster %v, want: %v", clusterNameDeleted, logicalDNSService) + } + + // The handleResp() callback should then start a watch for + // logicalDNSService2. + clusterNameCreated, err := fakeClient.WaitForWatchCluster(ctx) + if err != nil { + t.Fatalf("xdsClient.WatchCDS failed with error: %v", err) + } + if clusterNameCreated != logicalDNSService2 { + t.Fatalf("xdsClient.WatchCDS called for cluster %v, want: %v", clusterNameCreated, logicalDNSService2) + } + + // handleResp() should try and send an update here, but it will fail as + // logicalDNSService2 has not yet received an update. + shouldNotHappenCtx, shouldNotHappenCtxCancel := context.WithTimeout(context.Background(), defaultTestShortTimeout) + defer shouldNotHappenCtxCancel() + select { + case <-ch.updateChannel: + t.Fatal("Cluster Handler wrote an update to updateChannel when it shouldn't have, as each node in the full cluster tree has not yet received an update") + case <-shouldNotHappenCtx.Done(): + } + + // Invoke a callback for the new logicalDNSService2 - this will fill out the + // tree with successful updates. + fakeClient.InvokeWatchClusterCallback(xdsclient.ClusterUpdate{ + ClusterType: xdsclient.ClusterTypeLogicalDNS, + ClusterName: logicalDNSService2, + }, nil) + + // Behavior: This update make every node in the tree of cluster have + // received an update. Thus, at the end of this callback, when you ping the + // clusterHandler to try and construct an update, the update should now + // successfully be written to update buffer to send back to CDS. This new + // update should contain the new child of LogicalDNS2. + + select { + case chu := <-ch.updateChannel: + if diff := cmp.Diff(chu.updates, []xdsclient.ClusterUpdate{{ + ClusterType: xdsclient.ClusterTypeEDS, + ClusterName: edsService, + }, { + ClusterType: xdsclient.ClusterTypeLogicalDNS, + ClusterName: logicalDNSService2, + }}); diff != "" { + t.Fatalf("got unexpected cluster update, diff (-got, +want): %v", diff) + } + case <-ctx.Done(): + t.Fatal("Timed out waiting for the cluster update to be written to the update buffer.") + } +} + +// TestUpdateRootClusterAggregateThenChangeRootToEDS tests the situation where +// you have a fully updated aggregate cluster (where AggregateCluster success +// test gets you) as the root cluster, then you update that root cluster to a +// cluster of type EDS. +func (s) TestUpdateRootClusterAggregateThenChangeRootToEDS(t *testing.T) { + // This initial code is the same as the test for the aggregate success case, + // except without validations. This will get this test to the point where it + // can update the root cluster to one of type EDS. + ch, fakeClient := setupTests(t) + ch.updateRootCluster(aggregateClusterService) + + ctx, ctxCancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer ctxCancel() + _, err := fakeClient.WaitForWatchCluster(ctx) + if err != nil { + t.Fatalf("xdsClient.WatchCDS failed with error: %v", err) + } + + fakeClient.InvokeWatchClusterCallback(xdsclient.ClusterUpdate{ + ClusterType: xdsclient.ClusterTypeAggregate, + ClusterName: aggregateClusterService, + PrioritizedClusterNames: []string{edsService, logicalDNSService}, + }, nil) + fakeClient.WaitForWatchCluster(ctx) + fakeClient.WaitForWatchCluster(ctx) + fakeClient.InvokeWatchClusterCallback(xdsclient.ClusterUpdate{ + ClusterType: xdsclient.ClusterTypeEDS, + ClusterName: edsService, + }, nil) + fakeClient.InvokeWatchClusterCallback(xdsclient.ClusterUpdate{ + ClusterType: xdsclient.ClusterTypeLogicalDNS, + ClusterName: logicalDNSService, + }, nil) + + select { + case <-ch.updateChannel: + case <-ctx.Done(): + t.Fatal("Timed out waiting for the cluster update to be written to the update buffer.") + } + + // Changes the root aggregate cluster to a EDS cluster. This should delete + // the root aggregate cluster and all of it's children by successfully + // canceling the watches for them. + ch.updateRootCluster(edsService2) + + // Reads from the cancel channel, should first be type Aggregate, then EDS + // then Logical DNS. + clusterNameDeleted, err := fakeClient.WaitForCancelClusterWatch(ctx) + if err != nil { + t.Fatalf("xdsClient.CancelCDS failed with error: %v", err) + } + if clusterNameDeleted != aggregateClusterService { + t.Fatalf("xdsClient.CancelCDS called for cluster %v, want: %v", clusterNameDeleted, logicalDNSService) + } + + clusterNameDeleted, err = fakeClient.WaitForCancelClusterWatch(ctx) + if err != nil { + t.Fatalf("xdsClient.CancelCDS failed with error: %v", err) + } + if clusterNameDeleted != edsService { + t.Fatalf("xdsClient.CancelCDS called for cluster %v, want: %v", clusterNameDeleted, logicalDNSService) + } + + clusterNameDeleted, err = fakeClient.WaitForCancelClusterWatch(ctx) + if err != nil { + t.Fatalf("xdsClient.CancelCDS failed with error: %v", err) + } + if clusterNameDeleted != logicalDNSService { + t.Fatalf("xdsClient.CancelCDS called for cluster %v, want: %v", clusterNameDeleted, logicalDNSService) + } + + // After deletion, it should start a watch for the EDS Cluster. The behavior + // for this EDS Cluster receiving an update from xds client and then + // successfully writing an update to send back to CDS is already tested in + // the updateEDS success case. + gotCluster, err := fakeClient.WaitForWatchCluster(ctx) + if err != nil { + t.Fatalf("xdsClient.WatchCDS failed with error: %v", err) + } + if gotCluster != edsService2 { + t.Fatalf("xdsClient.WatchCDS called for cluster: %v, want: %v", gotCluster, edsService2) + } +} + +// TestHandleRespInvokedWithError tests that when handleResp is invoked with an +// error, that the error is successfully written to the update buffer. +func (s) TestHandleRespInvokedWithError(t *testing.T) { + ch, fakeClient := setupTests(t) + ch.updateRootCluster(edsService) + ctx, ctxCancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer ctxCancel() + _, err := fakeClient.WaitForWatchCluster(ctx) + if err != nil { + t.Fatalf("xdsClient.WatchCDS failed with error: %v", err) + } + fakeClient.InvokeWatchClusterCallback(xdsclient.ClusterUpdate{}, errors.New("some error")) + select { + case chu := <-ch.updateChannel: + if chu.err.Error() != "some error" { + t.Fatalf("Did not receive the expected error, instead received: %v", chu.err.Error()) + } + case <-ctx.Done(): + t.Fatal("Timed out waiting for update from update channel.") + } +} + +// TestSwitchClusterNodeBetweenLeafAndAggregated tests having an existing +// cluster node switch between a leaf and an aggregated cluster. When the +// cluster switches from a leaf to an aggregated cluster, it should add +// children, and when it switches back to a leaf, it should delete those new +// children and also successfully write a cluster update to the update buffer. +func (s) TestSwitchClusterNodeBetweenLeafAndAggregated(t *testing.T) { + // Getting the test to the point where there's a root cluster which is a eds + // leaf. + ch, fakeClient := setupTests(t) + ch.updateRootCluster(edsService2) + ctx, ctxCancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer ctxCancel() + _, err := fakeClient.WaitForWatchCluster(ctx) + if err != nil { + t.Fatalf("xdsClient.WatchCDS failed with error: %v", err) + } + fakeClient.InvokeWatchClusterCallback(xdsclient.ClusterUpdate{ + ClusterType: xdsclient.ClusterTypeEDS, + ClusterName: edsService2, + }, nil) + select { + case <-ch.updateChannel: + case <-ctx.Done(): + t.Fatal("Timed out waiting for update from update channel.") + } + // Switch the cluster to an aggregate cluster, this should cause two new + // child watches to be created. + fakeClient.InvokeWatchClusterCallback(xdsclient.ClusterUpdate{ + ClusterType: xdsclient.ClusterTypeAggregate, + ClusterName: edsService2, + PrioritizedClusterNames: []string{edsService, logicalDNSService}, + }, nil) + + // xds client should be called to start a watch for one of the child + // clusters of the aggregate. The order of the children in the update + // written to the buffer to send to CDS matters, however there is no + // guarantee on the order it will start the watches of the children. + gotCluster, err := fakeClient.WaitForWatchCluster(ctx) + if err != nil { + t.Fatalf("xdsClient.WatchCDS failed with error: %v", err) + } + if gotCluster != edsService { + if gotCluster != logicalDNSService { + t.Fatalf("xdsClient.WatchCDS called for cluster: %v, want: %v", gotCluster, edsService) + } + } + + // xds client should then be called to start a watch for the second child + // cluster. + gotCluster, err = fakeClient.WaitForWatchCluster(ctx) + if err != nil { + t.Fatalf("xdsClient.WatchCDS failed with error: %v", err) + } + if gotCluster != edsService { + if gotCluster != logicalDNSService { + t.Fatalf("xdsClient.WatchCDS called for cluster: %v, want: %v", gotCluster, logicalDNSService) + } + } + + // After starting a watch for the second child cluster, there should be no + // more watches started on the xds client. + shouldNotHappenCtx, shouldNotHappenCtxCancel := context.WithTimeout(context.Background(), defaultTestShortTimeout) + defer shouldNotHappenCtxCancel() + gotCluster, err = fakeClient.WaitForWatchCluster(shouldNotHappenCtx) + if err == nil { + t.Fatalf("xdsClient.WatchCDS called for cluster: %v, no more watches should be started.", gotCluster) + } + + // The handleResp() call on the root aggregate cluster should not ping the + // cluster handler to try and construct an update, as the handleResp() + // callback knows that when a child is created, it cannot possibly build a + // successful update yet. Thus, there should be nothing in the update + // channel. + + shouldNotHappenCtx, shouldNotHappenCtxCancel = context.WithTimeout(context.Background(), defaultTestShortTimeout) + defer shouldNotHappenCtxCancel() + + select { + case <-ch.updateChannel: + t.Fatal("Cluster Handler wrote an update to updateChannel when it shouldn't have, as each node in the full cluster tree has not yet received an update") + case <-shouldNotHappenCtx.Done(): + } + + // Switch the cluster back to an EDS Cluster. This should cause the two + // children to be deleted. + fakeClient.InvokeWatchClusterCallback(xdsclient.ClusterUpdate{ + ClusterType: xdsclient.ClusterTypeEDS, + ClusterName: edsService2, + }, nil) + + // Should delete the two children (no guarantee of ordering deleted, which + // is ok), then successfully write an update to the update buffer as the + // full cluster tree has received updates. + clusterNameDeleted, err := fakeClient.WaitForCancelClusterWatch(ctx) + if err != nil { + t.Fatalf("xdsClient.CancelCDS failed with error: %v", err) + } + // No guarantee of ordering, so one of the children should be deleted first. + if clusterNameDeleted != edsService { + if clusterNameDeleted != logicalDNSService { + t.Fatalf("xdsClient.CancelCDS called for cluster %v, want either: %v or: %v", clusterNameDeleted, edsService, logicalDNSService) + } + } + // Then the other child should be deleted. + clusterNameDeleted, err = fakeClient.WaitForCancelClusterWatch(ctx) + if err != nil { + t.Fatalf("xdsClient.CancelCDS failed with error: %v", err) + } + if clusterNameDeleted != edsService { + if clusterNameDeleted != logicalDNSService { + t.Fatalf("xdsClient.CancelCDS called for cluster %v, want either: %v or: %v", clusterNameDeleted, edsService, logicalDNSService) + } + } + + // After cancelling a watch for the second child cluster, there should be no + // more watches cancelled on the xds client. + shouldNotHappenCtx, shouldNotHappenCtxCancel = context.WithTimeout(context.Background(), defaultTestShortTimeout) + defer shouldNotHappenCtxCancel() + gotCluster, err = fakeClient.WaitForCancelClusterWatch(shouldNotHappenCtx) + if err == nil { + t.Fatalf("xdsClient.WatchCDS called for cluster: %v, no more watches should be cancelled.", gotCluster) + } + + // Then an update should successfully be written to the update buffer. + select { + case chu := <-ch.updateChannel: + if diff := cmp.Diff(chu.updates, []xdsclient.ClusterUpdate{{ + ClusterType: xdsclient.ClusterTypeEDS, + ClusterName: edsService2, + }}); diff != "" { + t.Fatalf("got unexpected cluster update, diff (-got, +want): %v", diff) + } + case <-ctx.Done(): + t.Fatal("Timed out waiting for update from update channel.") + } +} diff --git a/xds/internal/balancer/clusterimpl/balancer_test.go b/xds/internal/balancer/clusterimpl/balancer_test.go index 3e6ac0fd290..65ec17348f4 100644 --- a/xds/internal/balancer/clusterimpl/balancer_test.go +++ b/xds/internal/balancer/clusterimpl/balancer_test.go @@ -20,6 +20,8 @@ package clusterimpl import ( "context" + "errors" + "fmt" "strings" "testing" "time" @@ -27,20 +29,28 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "google.golang.org/grpc/balancer" + "google.golang.org/grpc/balancer/base" "google.golang.org/grpc/balancer/roundrobin" "google.golang.org/grpc/connectivity" + "google.golang.org/grpc/internal" + "google.golang.org/grpc/internal/balancer/stub" + "google.golang.org/grpc/internal/grpctest" internalserviceconfig "google.golang.org/grpc/internal/serviceconfig" "google.golang.org/grpc/resolver" - "google.golang.org/grpc/xds/internal/client/load" + xdsinternal "google.golang.org/grpc/xds/internal" "google.golang.org/grpc/xds/internal/testutils" "google.golang.org/grpc/xds/internal/testutils/fakeclient" + "google.golang.org/grpc/xds/internal/xdsclient" + "google.golang.org/grpc/xds/internal/xdsclient/load" ) const ( - defaultTestTimeout = 1 * time.Second - testClusterName = "test-cluster" - testServiceName = "test-eds-service" - testLRSServerName = "test-lrs-name" + defaultTestTimeout = 1 * time.Second + defaultShortTestTimeout = 100 * time.Microsecond + + testClusterName = "test-cluster" + testServiceName = "test-eds-service" + testLRSServerName = "test-lrs-name" ) var ( @@ -54,19 +64,33 @@ var ( } ) +type s struct { + grpctest.Tester +} + +func Test(t *testing.T) { + grpctest.RunSubTests(t, s{}) +} + +func subConnFromPicker(p balancer.Picker) func() balancer.SubConn { + return func() balancer.SubConn { + scst, _ := p.Pick(balancer.PickInfo{}) + return scst.SubConn + } +} + func init() { - newRandomWRR = testutils.NewTestWRR + NewRandomWRR = testutils.NewTestWRR } // TestDropByCategory verifies that the balancer correctly drops the picks, and // that the drops are reported. -func TestDropByCategory(t *testing.T) { +func (s) TestDropByCategory(t *testing.T) { + defer xdsclient.ClearCounterForTesting(testClusterName, testServiceName) xdsC := fakeclient.NewClient() - oldNewXDSClient := newXDSClient - newXDSClient = func() (xdsClientInterface, error) { return xdsC, nil } - defer func() { newXDSClient = oldNewXDSClient }() + defer xdsC.Close() - builder := balancer.Get(clusterImplName) + builder := balancer.Get(Name) cc := testutils.NewTestClientConn(t) b := builder.Build(cc, balancer.BuildOptions{}) defer b.Close() @@ -77,14 +101,12 @@ func TestDropByCategory(t *testing.T) { dropDenominator = 2 ) if err := b.UpdateClientConnState(balancer.ClientConnState{ - ResolverState: resolver.State{ - Addresses: testBackendAddrs, - }, - BalancerConfig: &lbConfig{ - Cluster: testClusterName, - EDSServiceName: testServiceName, - LRSLoadReportingServerName: newString(testLRSServerName), - DropCategories: []dropCategory{{ + ResolverState: xdsclient.SetClient(resolver.State{Addresses: testBackendAddrs}, xdsC), + BalancerConfig: &LBConfig{ + Cluster: testClusterName, + EDSServiceName: testServiceName, + LoadReportingServerName: newString(testLRSServerName), + DropCategories: []DropConfig{{ Category: dropReason, RequestsPerMillion: million * dropNumerator / dropDenominator, }}, @@ -150,6 +172,9 @@ func TestDropByCategory(t *testing.T) { Service: testServiceName, TotalDrops: dropCount, Drops: map[string]uint64{dropReason: dropCount}, + LocalityStats: map[string]load.LocalityData{ + assertString(xdsinternal.LocalityID{}.ToString): {RequestStats: load.RequestData{Succeeded: rpcCount - dropCount}}, + }, }} gotStatsData0 := loadStore.Stats([]string{testClusterName}) @@ -164,14 +189,12 @@ func TestDropByCategory(t *testing.T) { dropDenominator2 = 4 ) if err := b.UpdateClientConnState(balancer.ClientConnState{ - ResolverState: resolver.State{ - Addresses: testBackendAddrs, - }, - BalancerConfig: &lbConfig{ - Cluster: testClusterName, - EDSServiceName: testServiceName, - LRSLoadReportingServerName: newString(testLRSServerName), - DropCategories: []dropCategory{{ + ResolverState: xdsclient.SetClient(resolver.State{Addresses: testBackendAddrs}, xdsC), + BalancerConfig: &LBConfig{ + Cluster: testClusterName, + EDSServiceName: testServiceName, + LoadReportingServerName: newString(testLRSServerName), + DropCategories: []DropConfig{{ Category: dropReason2, RequestsPerMillion: million * dropNumerator2 / dropDenominator2, }}, @@ -207,6 +230,9 @@ func TestDropByCategory(t *testing.T) { Service: testServiceName, TotalDrops: dropCount2, Drops: map[string]uint64{dropReason2: dropCount2}, + LocalityStats: map[string]load.LocalityData{ + assertString(xdsinternal.LocalityID{}.ToString): {RequestStats: load.RequestData{Succeeded: rpcCount - dropCount2}}, + }, }} gotStatsData1 := loadStore.Stats([]string{testClusterName}) @@ -217,27 +243,24 @@ func TestDropByCategory(t *testing.T) { // TestDropCircuitBreaking verifies that the balancer correctly drops the picks // due to circuit breaking, and that the drops are reported. -func TestDropCircuitBreaking(t *testing.T) { +func (s) TestDropCircuitBreaking(t *testing.T) { + defer xdsclient.ClearCounterForTesting(testClusterName, testServiceName) xdsC := fakeclient.NewClient() - oldNewXDSClient := newXDSClient - newXDSClient = func() (xdsClientInterface, error) { return xdsC, nil } - defer func() { newXDSClient = oldNewXDSClient }() + defer xdsC.Close() - builder := balancer.Get(clusterImplName) + builder := balancer.Get(Name) cc := testutils.NewTestClientConn(t) b := builder.Build(cc, balancer.BuildOptions{}) defer b.Close() var maxRequest uint32 = 50 if err := b.UpdateClientConnState(balancer.ClientConnState{ - ResolverState: resolver.State{ - Addresses: testBackendAddrs, - }, - BalancerConfig: &lbConfig{ - Cluster: testClusterName, - EDSServiceName: testServiceName, - LRSLoadReportingServerName: newString(testLRSServerName), - MaxConcurrentRequests: &maxRequest, + ResolverState: xdsclient.SetClient(resolver.State{Addresses: testBackendAddrs}, xdsC), + BalancerConfig: &LBConfig{ + Cluster: testClusterName, + EDSServiceName: testServiceName, + LoadReportingServerName: newString(testLRSServerName), + MaxConcurrentRequests: &maxRequest, ChildPolicy: &internalserviceconfig.BalancerConfig{ Name: roundrobin.Name, }, @@ -317,6 +340,9 @@ func TestDropCircuitBreaking(t *testing.T) { Cluster: testClusterName, Service: testServiceName, TotalDrops: uint64(maxRequest), + LocalityStats: map[string]load.LocalityData{ + assertString(xdsinternal.LocalityID{}.ToString): {RequestStats: load.RequestData{Succeeded: uint64(rpcCount - maxRequest + 50)}}, + }, }} gotStatsData0 := loadStore.Stats([]string{testClusterName}) @@ -324,3 +350,452 @@ func TestDropCircuitBreaking(t *testing.T) { t.Fatalf("got unexpected drop reports, diff (-got, +want): %v", diff) } } + +// TestPickerUpdateAfterClose covers the case where a child policy sends a +// picker update after the cluster_impl policy is closed. Because picker updates +// are handled in the run() goroutine, which exits before Close() returns, we +// expect the above picker update to be dropped. +func (s) TestPickerUpdateAfterClose(t *testing.T) { + defer xdsclient.ClearCounterForTesting(testClusterName, testServiceName) + xdsC := fakeclient.NewClient() + defer xdsC.Close() + + builder := balancer.Get(Name) + cc := testutils.NewTestClientConn(t) + b := builder.Build(cc, balancer.BuildOptions{}) + + // Create a stub balancer which waits for the cluster_impl policy to be + // closed before sending a picker update (upon receipt of a subConn state + // change). + closeCh := make(chan struct{}) + const childPolicyName = "stubBalancer-TestPickerUpdateAfterClose" + stub.Register(childPolicyName, stub.BalancerFuncs{ + UpdateClientConnState: func(bd *stub.BalancerData, ccs balancer.ClientConnState) error { + // Create a subConn which will be used later on to test the race + // between UpdateSubConnState() and Close(). + bd.ClientConn.NewSubConn(ccs.ResolverState.Addresses, balancer.NewSubConnOptions{}) + return nil + }, + UpdateSubConnState: func(bd *stub.BalancerData, _ balancer.SubConn, _ balancer.SubConnState) { + go func() { + // Wait for Close() to be called on the parent policy before + // sending the picker update. + <-closeCh + bd.ClientConn.UpdateState(balancer.State{ + Picker: base.NewErrPicker(errors.New("dummy error picker")), + }) + }() + }, + }) + + var maxRequest uint32 = 50 + if err := b.UpdateClientConnState(balancer.ClientConnState{ + ResolverState: xdsclient.SetClient(resolver.State{Addresses: testBackendAddrs}, xdsC), + BalancerConfig: &LBConfig{ + Cluster: testClusterName, + EDSServiceName: testServiceName, + MaxConcurrentRequests: &maxRequest, + ChildPolicy: &internalserviceconfig.BalancerConfig{ + Name: childPolicyName, + }, + }, + }); err != nil { + b.Close() + t.Fatalf("unexpected error from UpdateClientConnState: %v", err) + } + + // Send a subConn state change to trigger a picker update. The stub balancer + // that we use as the child policy will not send a picker update until the + // parent policy is closed. + sc1 := <-cc.NewSubConnCh + b.UpdateSubConnState(sc1, balancer.SubConnState{ConnectivityState: connectivity.Connecting}) + b.Close() + close(closeCh) + + select { + case <-cc.NewPickerCh: + t.Fatalf("unexpected picker update after balancer is closed") + case <-time.After(defaultShortTestTimeout): + } +} + +// TestClusterNameInAddressAttributes covers the case that cluster name is +// attached to the subconn address attributes. +func (s) TestClusterNameInAddressAttributes(t *testing.T) { + defer xdsclient.ClearCounterForTesting(testClusterName, testServiceName) + xdsC := fakeclient.NewClient() + defer xdsC.Close() + + builder := balancer.Get(Name) + cc := testutils.NewTestClientConn(t) + b := builder.Build(cc, balancer.BuildOptions{}) + defer b.Close() + + if err := b.UpdateClientConnState(balancer.ClientConnState{ + ResolverState: xdsclient.SetClient(resolver.State{Addresses: testBackendAddrs}, xdsC), + BalancerConfig: &LBConfig{ + Cluster: testClusterName, + EDSServiceName: testServiceName, + ChildPolicy: &internalserviceconfig.BalancerConfig{ + Name: roundrobin.Name, + }, + }, + }); err != nil { + t.Fatalf("unexpected error from UpdateClientConnState: %v", err) + } + + sc1 := <-cc.NewSubConnCh + b.UpdateSubConnState(sc1, balancer.SubConnState{ConnectivityState: connectivity.Connecting}) + // This should get the connecting picker. + p0 := <-cc.NewPickerCh + for i := 0; i < 10; i++ { + _, err := p0.Pick(balancer.PickInfo{}) + if err != balancer.ErrNoSubConnAvailable { + t.Fatalf("picker.Pick, got _,%v, want Err=%v", err, balancer.ErrNoSubConnAvailable) + } + } + + addrs1 := <-cc.NewSubConnAddrsCh + if got, want := addrs1[0].Addr, testBackendAddrs[0].Addr; got != want { + t.Fatalf("sc is created with addr %v, want %v", got, want) + } + cn, ok := internal.GetXDSHandshakeClusterName(addrs1[0].Attributes) + if !ok || cn != testClusterName { + t.Fatalf("sc is created with addr with cluster name %v, %v, want cluster name %v", cn, ok, testClusterName) + } + + b.UpdateSubConnState(sc1, balancer.SubConnState{ConnectivityState: connectivity.Ready}) + // Test pick with one backend. + p1 := <-cc.NewPickerCh + const rpcCount = 20 + for i := 0; i < rpcCount; i++ { + gotSCSt, err := p1.Pick(balancer.PickInfo{}) + if err != nil || !cmp.Equal(gotSCSt.SubConn, sc1, cmp.AllowUnexported(testutils.TestSubConn{})) { + t.Fatalf("picker.Pick, got %v, %v, want SubConn=%v", gotSCSt, err, sc1) + } + if gotSCSt.Done != nil { + gotSCSt.Done(balancer.DoneInfo{}) + } + } + + const testClusterName2 = "test-cluster-2" + var addr2 = resolver.Address{Addr: "2.2.2.2"} + if err := b.UpdateClientConnState(balancer.ClientConnState{ + ResolverState: xdsclient.SetClient(resolver.State{Addresses: []resolver.Address{addr2}}, xdsC), + BalancerConfig: &LBConfig{ + Cluster: testClusterName2, + EDSServiceName: testServiceName, + ChildPolicy: &internalserviceconfig.BalancerConfig{ + Name: roundrobin.Name, + }, + }, + }); err != nil { + t.Fatalf("unexpected error from UpdateClientConnState: %v", err) + } + + addrs2 := <-cc.NewSubConnAddrsCh + if got, want := addrs2[0].Addr, addr2.Addr; got != want { + t.Fatalf("sc is created with addr %v, want %v", got, want) + } + // New addresses should have the new cluster name. + cn2, ok := internal.GetXDSHandshakeClusterName(addrs2[0].Attributes) + if !ok || cn2 != testClusterName2 { + t.Fatalf("sc is created with addr with cluster name %v, %v, want cluster name %v", cn2, ok, testClusterName2) + } +} + +// TestReResolution verifies that when a SubConn turns transient failure, +// re-resolution is triggered. +func (s) TestReResolution(t *testing.T) { + defer xdsclient.ClearCounterForTesting(testClusterName, testServiceName) + xdsC := fakeclient.NewClient() + defer xdsC.Close() + + builder := balancer.Get(Name) + cc := testutils.NewTestClientConn(t) + b := builder.Build(cc, balancer.BuildOptions{}) + defer b.Close() + + if err := b.UpdateClientConnState(balancer.ClientConnState{ + ResolverState: xdsclient.SetClient(resolver.State{Addresses: testBackendAddrs}, xdsC), + BalancerConfig: &LBConfig{ + Cluster: testClusterName, + EDSServiceName: testServiceName, + ChildPolicy: &internalserviceconfig.BalancerConfig{ + Name: roundrobin.Name, + }, + }, + }); err != nil { + t.Fatalf("unexpected error from UpdateClientConnState: %v", err) + } + + sc1 := <-cc.NewSubConnCh + b.UpdateSubConnState(sc1, balancer.SubConnState{ConnectivityState: connectivity.Connecting}) + // This should get the connecting picker. + p0 := <-cc.NewPickerCh + for i := 0; i < 10; i++ { + _, err := p0.Pick(balancer.PickInfo{}) + if err != balancer.ErrNoSubConnAvailable { + t.Fatalf("picker.Pick, got _,%v, want Err=%v", err, balancer.ErrNoSubConnAvailable) + } + } + + b.UpdateSubConnState(sc1, balancer.SubConnState{ConnectivityState: connectivity.TransientFailure}) + // This should get the transient failure picker. + p1 := <-cc.NewPickerCh + for i := 0; i < 10; i++ { + _, err := p1.Pick(balancer.PickInfo{}) + if err == nil { + t.Fatalf("picker.Pick, got _,%v, want not nil", err) + } + } + + // The transient failure should trigger a re-resolution. + select { + case <-cc.ResolveNowCh: + case <-time.After(defaultTestTimeout): + t.Fatalf("timeout waiting for ResolveNow()") + } + + b.UpdateSubConnState(sc1, balancer.SubConnState{ConnectivityState: connectivity.Ready}) + // Test pick with one backend. + p2 := <-cc.NewPickerCh + want := []balancer.SubConn{sc1} + if err := testutils.IsRoundRobin(want, subConnFromPicker(p2)); err != nil { + t.Fatalf("want %v, got %v", want, err) + } + + b.UpdateSubConnState(sc1, balancer.SubConnState{ConnectivityState: connectivity.TransientFailure}) + // This should get the transient failure picker. + p3 := <-cc.NewPickerCh + for i := 0; i < 10; i++ { + _, err := p3.Pick(balancer.PickInfo{}) + if err == nil { + t.Fatalf("picker.Pick, got _,%v, want not nil", err) + } + } + + // The transient failure should trigger a re-resolution. + select { + case <-cc.ResolveNowCh: + case <-time.After(defaultTestTimeout): + t.Fatalf("timeout waiting for ResolveNow()") + } +} + +func (s) TestLoadReporting(t *testing.T) { + var testLocality = xdsinternal.LocalityID{ + Region: "test-region", + Zone: "test-zone", + SubZone: "test-sub-zone", + } + + xdsC := fakeclient.NewClient() + defer xdsC.Close() + + builder := balancer.Get(Name) + cc := testutils.NewTestClientConn(t) + b := builder.Build(cc, balancer.BuildOptions{}) + defer b.Close() + + addrs := make([]resolver.Address, len(testBackendAddrs)) + for i, a := range testBackendAddrs { + addrs[i] = xdsinternal.SetLocalityID(a, testLocality) + } + if err := b.UpdateClientConnState(balancer.ClientConnState{ + ResolverState: xdsclient.SetClient(resolver.State{Addresses: addrs}, xdsC), + BalancerConfig: &LBConfig{ + Cluster: testClusterName, + EDSServiceName: testServiceName, + LoadReportingServerName: newString(testLRSServerName), + // Locality: testLocality, + ChildPolicy: &internalserviceconfig.BalancerConfig{ + Name: roundrobin.Name, + }, + }, + }); err != nil { + t.Fatalf("unexpected error from UpdateClientConnState: %v", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + got, err := xdsC.WaitForReportLoad(ctx) + if err != nil { + t.Fatalf("xdsClient.ReportLoad failed with error: %v", err) + } + if got.Server != testLRSServerName { + t.Fatalf("xdsClient.ReportLoad called with {%q}: want {%q}", got.Server, testLRSServerName) + } + + sc1 := <-cc.NewSubConnCh + b.UpdateSubConnState(sc1, balancer.SubConnState{ConnectivityState: connectivity.Connecting}) + // This should get the connecting picker. + p0 := <-cc.NewPickerCh + for i := 0; i < 10; i++ { + _, err := p0.Pick(balancer.PickInfo{}) + if err != balancer.ErrNoSubConnAvailable { + t.Fatalf("picker.Pick, got _,%v, want Err=%v", err, balancer.ErrNoSubConnAvailable) + } + } + + b.UpdateSubConnState(sc1, balancer.SubConnState{ConnectivityState: connectivity.Ready}) + // Test pick with one backend. + p1 := <-cc.NewPickerCh + const successCount = 5 + for i := 0; i < successCount; i++ { + gotSCSt, err := p1.Pick(balancer.PickInfo{}) + if !cmp.Equal(gotSCSt.SubConn, sc1, cmp.AllowUnexported(testutils.TestSubConn{})) { + t.Fatalf("picker.Pick, got %v, %v, want SubConn=%v", gotSCSt, err, sc1) + } + gotSCSt.Done(balancer.DoneInfo{}) + } + const errorCount = 5 + for i := 0; i < errorCount; i++ { + gotSCSt, err := p1.Pick(balancer.PickInfo{}) + if !cmp.Equal(gotSCSt.SubConn, sc1, cmp.AllowUnexported(testutils.TestSubConn{})) { + t.Fatalf("picker.Pick, got %v, %v, want SubConn=%v", gotSCSt, err, sc1) + } + gotSCSt.Done(balancer.DoneInfo{Err: fmt.Errorf("error")}) + } + + // Dump load data from the store and compare with expected counts. + loadStore := xdsC.LoadStore() + if loadStore == nil { + t.Fatal("loadStore is nil in xdsClient") + } + sds := loadStore.Stats([]string{testClusterName}) + if len(sds) == 0 { + t.Fatalf("loads for cluster %v not found in store", testClusterName) + } + sd := sds[0] + if sd.Cluster != testClusterName || sd.Service != testServiceName { + t.Fatalf("got unexpected load for %q, %q, want %q, %q", sd.Cluster, sd.Service, testClusterName, testServiceName) + } + testLocalityJSON, _ := testLocality.ToString() + localityData, ok := sd.LocalityStats[testLocalityJSON] + if !ok { + t.Fatalf("loads for %v not found in store", testLocality) + } + reqStats := localityData.RequestStats + if reqStats.Succeeded != successCount { + t.Errorf("got succeeded %v, want %v", reqStats.Succeeded, successCount) + } + if reqStats.Errored != errorCount { + t.Errorf("got errord %v, want %v", reqStats.Errored, errorCount) + } + if reqStats.InProgress != 0 { + t.Errorf("got inProgress %v, want %v", reqStats.InProgress, 0) + } + + b.Close() + if err := xdsC.WaitForCancelReportLoad(ctx); err != nil { + t.Fatalf("unexpected error waiting form load report to be canceled: %v", err) + } +} + +// TestUpdateLRSServer covers the cases +// - the init config specifies "" as the LRS server +// - config modifies LRS server to a different string +// - config sets LRS server to nil to stop load reporting +func (s) TestUpdateLRSServer(t *testing.T) { + var testLocality = xdsinternal.LocalityID{ + Region: "test-region", + Zone: "test-zone", + SubZone: "test-sub-zone", + } + + xdsC := fakeclient.NewClient() + defer xdsC.Close() + + builder := balancer.Get(Name) + cc := testutils.NewTestClientConn(t) + b := builder.Build(cc, balancer.BuildOptions{}) + defer b.Close() + + addrs := make([]resolver.Address, len(testBackendAddrs)) + for i, a := range testBackendAddrs { + addrs[i] = xdsinternal.SetLocalityID(a, testLocality) + } + if err := b.UpdateClientConnState(balancer.ClientConnState{ + ResolverState: xdsclient.SetClient(resolver.State{Addresses: addrs}, xdsC), + BalancerConfig: &LBConfig{ + Cluster: testClusterName, + EDSServiceName: testServiceName, + LoadReportingServerName: newString(""), + ChildPolicy: &internalserviceconfig.BalancerConfig{ + Name: roundrobin.Name, + }, + }, + }); err != nil { + t.Fatalf("unexpected error from UpdateClientConnState: %v", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + got, err := xdsC.WaitForReportLoad(ctx) + if err != nil { + t.Fatalf("xdsClient.ReportLoad failed with error: %v", err) + } + if got.Server != "" { + t.Fatalf("xdsClient.ReportLoad called with {%q}: want {%q}", got.Server, "") + } + + // Update LRS server to a different name. + if err := b.UpdateClientConnState(balancer.ClientConnState{ + ResolverState: xdsclient.SetClient(resolver.State{Addresses: addrs}, xdsC), + BalancerConfig: &LBConfig{ + Cluster: testClusterName, + EDSServiceName: testServiceName, + LoadReportingServerName: newString(testLRSServerName), + ChildPolicy: &internalserviceconfig.BalancerConfig{ + Name: roundrobin.Name, + }, + }, + }); err != nil { + t.Fatalf("unexpected error from UpdateClientConnState: %v", err) + } + if err := xdsC.WaitForCancelReportLoad(ctx); err != nil { + t.Fatalf("unexpected error waiting form load report to be canceled: %v", err) + } + got2, err2 := xdsC.WaitForReportLoad(ctx) + if err2 != nil { + t.Fatalf("xdsClient.ReportLoad failed with error: %v", err2) + } + if got2.Server != testLRSServerName { + t.Fatalf("xdsClient.ReportLoad called with {%q}: want {%q}", got2.Server, testLRSServerName) + } + + // Update LRS server to nil, to disable LRS. + if err := b.UpdateClientConnState(balancer.ClientConnState{ + ResolverState: xdsclient.SetClient(resolver.State{Addresses: addrs}, xdsC), + BalancerConfig: &LBConfig{ + Cluster: testClusterName, + EDSServiceName: testServiceName, + LoadReportingServerName: nil, + ChildPolicy: &internalserviceconfig.BalancerConfig{ + Name: roundrobin.Name, + }, + }, + }); err != nil { + t.Fatalf("unexpected error from UpdateClientConnState: %v", err) + } + if err := xdsC.WaitForCancelReportLoad(ctx); err != nil { + t.Fatalf("unexpected error waiting form load report to be canceled: %v", err) + } + + shortCtx, shortCancel := context.WithTimeout(context.Background(), defaultShortTestTimeout) + defer shortCancel() + if s, err := xdsC.WaitForReportLoad(shortCtx); err != context.DeadlineExceeded { + t.Fatalf("unexpected load report to server: %q", s) + } +} + +func assertString(f func() (string, error)) string { + s, err := f() + if err != nil { + panic(err.Error()) + } + return s +} diff --git a/xds/internal/balancer/clusterimpl/clusterimpl.go b/xds/internal/balancer/clusterimpl/clusterimpl.go index 4e4af5a02b4..03d357b1f4e 100644 --- a/xds/internal/balancer/clusterimpl/clusterimpl.go +++ b/xds/internal/balancer/clusterimpl/clusterimpl.go @@ -26,107 +26,132 @@ package clusterimpl import ( "encoding/json" "fmt" + "sync" + "sync/atomic" "google.golang.org/grpc/balancer" + "google.golang.org/grpc/connectivity" + "google.golang.org/grpc/internal" "google.golang.org/grpc/internal/buffer" "google.golang.org/grpc/internal/grpclog" "google.golang.org/grpc/internal/grpcsync" + "google.golang.org/grpc/internal/pretty" + "google.golang.org/grpc/resolver" "google.golang.org/grpc/serviceconfig" + xdsinternal "google.golang.org/grpc/xds/internal" "google.golang.org/grpc/xds/internal/balancer/loadstore" - xdsclient "google.golang.org/grpc/xds/internal/client" - "google.golang.org/grpc/xds/internal/client/load" + "google.golang.org/grpc/xds/internal/xdsclient" + "google.golang.org/grpc/xds/internal/xdsclient/load" ) const ( - clusterImplName = "xds_cluster_impl_experimental" + // Name is the name of the cluster_impl balancer. + Name = "xds_cluster_impl_experimental" defaultRequestCountMax = 1024 ) func init() { - balancer.Register(clusterImplBB{}) + balancer.Register(bb{}) } -var newXDSClient = func() (xdsClientInterface, error) { return xdsclient.New() } +type bb struct{} -type clusterImplBB struct{} - -func (clusterImplBB) Build(cc balancer.ClientConn, bOpts balancer.BuildOptions) balancer.Balancer { +func (bb) Build(cc balancer.ClientConn, bOpts balancer.BuildOptions) balancer.Balancer { b := &clusterImplBalancer{ ClientConn: cc, bOpts: bOpts, closed: grpcsync.NewEvent(), + done: grpcsync.NewEvent(), loadWrapper: loadstore.NewWrapper(), + scWrappers: make(map[balancer.SubConn]*scWrapper), pickerUpdateCh: buffer.NewUnbounded(), requestCountMax: defaultRequestCountMax, } b.logger = prefixLogger(b) - - client, err := newXDSClient() - if err != nil { - b.logger.Errorf("failed to create xds-client: %v", err) - return nil - } - b.xdsC = client go b.run() - b.logger.Infof("Created") return b } -func (clusterImplBB) Name() string { - return clusterImplName +func (bb) Name() string { + return Name } -func (clusterImplBB) ParseConfig(c json.RawMessage) (serviceconfig.LoadBalancingConfig, error) { +func (bb) ParseConfig(c json.RawMessage) (serviceconfig.LoadBalancingConfig, error) { return parseConfig(c) } -// xdsClientInterface contains only the xds_client methods needed by LRS -// balancer. It's defined so we can override xdsclient in tests. -type xdsClientInterface interface { - ReportLoad(server string) (*load.Store, func()) - Close() -} - type clusterImplBalancer struct { balancer.ClientConn - bOpts balancer.BuildOptions + + // mu guarantees mutual exclusion between Close() and handling of picker + // update to the parent ClientConn in run(). It's to make sure that the + // run() goroutine doesn't send picker update to parent after the balancer + // is closed. + // + // It's only used by the run() goroutine, but not the other exported + // functions. Because the exported functions are guaranteed to be + // synchronized with Close(). + mu sync.Mutex closed *grpcsync.Event - logger *grpclog.PrefixLogger - xdsC xdsClientInterface + done *grpcsync.Event - config *lbConfig + bOpts balancer.BuildOptions + logger *grpclog.PrefixLogger + xdsClient xdsclient.XDSClient + + config *LBConfig childLB balancer.Balancer cancelLoadReport func() - clusterName string edsServiceName string - lrsServerName string + lrsServerName *string loadWrapper *loadstore.Wrapper - // childState/drops/requestCounter can only be accessed in run(). And run() - // is the only goroutine that sends picker to the parent ClientConn. All + clusterNameMu sync.Mutex + clusterName string + + scWrappersMu sync.Mutex + // The SubConns passed to the child policy are wrapped in a wrapper, to keep + // locality ID. But when the parent ClientConn sends updates, it's going to + // give the original SubConn, not the wrapper. But the child policies only + // know about the wrapper, so when forwarding SubConn updates, they must be + // sent for the wrappers. + // + // This keeps a map from original SubConn to wrapper, so that when + // forwarding the SubConn state update, the child policy will get the + // wrappers. + scWrappers map[balancer.SubConn]*scWrapper + + // childState/drops/requestCounter keeps the state used by the most recently + // generated picker. All fields can only be accessed in run(). And run() is + // the only goroutine that sends picker to the parent ClientConn. All // requests to update picker need to be sent to pickerUpdateCh. - childState balancer.State - drops []*dropper - requestCounter *xdsclient.ServiceRequestsCounter - requestCountMax uint32 - pickerUpdateCh *buffer.Unbounded + childState balancer.State + dropCategories []DropConfig // The categories for drops. + drops []*dropper + requestCounterCluster string // The cluster name for the request counter. + requestCounterService string // The service name for the request counter. + requestCounter *xdsclient.ClusterRequestsCounter + requestCountMax uint32 + pickerUpdateCh *buffer.Unbounded } // updateLoadStore checks the config for load store, and decides whether it // needs to restart the load reporting stream. -func (cib *clusterImplBalancer) updateLoadStore(newConfig *lbConfig) error { +func (b *clusterImplBalancer) updateLoadStore(newConfig *LBConfig) error { var updateLoadClusterAndService bool // ClusterName is different, restart. ClusterName is from ClusterName and - // EdsServiceName. - if cib.clusterName != newConfig.Cluster { + // EDSServiceName. + clusterName := b.getClusterName() + if clusterName != newConfig.Cluster { updateLoadClusterAndService = true - cib.clusterName = newConfig.Cluster + b.setClusterName(newConfig.Cluster) + clusterName = newConfig.Cluster } - if cib.edsServiceName != newConfig.EDSServiceName { + if b.edsServiceName != newConfig.EDSServiceName { updateLoadClusterAndService = true - cib.edsServiceName = newConfig.EDSServiceName + b.edsServiceName = newConfig.EDSServiceName } if updateLoadClusterAndService { // This updates the clusterName and serviceName that will be reported @@ -137,39 +162,66 @@ func (cib *clusterImplBalancer) updateLoadStore(newConfig *lbConfig) error { // On the other hand, this will almost never happen. Each LRS policy // shouldn't get updated config. The parent should do a graceful switch // when the clusterName or serviceName is changed. - cib.loadWrapper.UpdateClusterAndService(cib.clusterName, cib.edsServiceName) + b.loadWrapper.UpdateClusterAndService(clusterName, b.edsServiceName) } + var ( + stopOldLoadReport bool + startNewLoadReport bool + ) + // Check if it's necessary to restart load report. - var newLRSServerName string - if newConfig.LRSLoadReportingServerName != nil { - newLRSServerName = *newConfig.LRSLoadReportingServerName - } - if cib.lrsServerName != newLRSServerName { - // LrsLoadReportingServerName is different, load should be report to a - // different server, restart. - cib.lrsServerName = newLRSServerName - if cib.cancelLoadReport != nil { - cib.cancelLoadReport() - cib.cancelLoadReport = nil + if b.lrsServerName == nil { + if newConfig.LoadReportingServerName != nil { + // Old is nil, new is not nil, start new LRS. + b.lrsServerName = newConfig.LoadReportingServerName + startNewLoadReport = true + } + // Old is nil, new is nil, do nothing. + } else if newConfig.LoadReportingServerName == nil { + // Old is not nil, new is nil, stop old, don't start new. + b.lrsServerName = newConfig.LoadReportingServerName + stopOldLoadReport = true + } else { + // Old is not nil, new is not nil, compare string values, if + // different, stop old and start new. + if *b.lrsServerName != *newConfig.LoadReportingServerName { + b.lrsServerName = newConfig.LoadReportingServerName + stopOldLoadReport = true + startNewLoadReport = true } + } + + if stopOldLoadReport { + if b.cancelLoadReport != nil { + b.cancelLoadReport() + b.cancelLoadReport = nil + if !startNewLoadReport { + // If a new LRS stream will be started later, no need to update + // it to nil here. + b.loadWrapper.UpdateLoadStore(nil) + } + } + } + if startNewLoadReport { var loadStore *load.Store - if cib.xdsC != nil { - loadStore, cib.cancelLoadReport = cib.xdsC.ReportLoad(cib.lrsServerName) + if b.xdsClient != nil { + loadStore, b.cancelLoadReport = b.xdsClient.ReportLoad(*b.lrsServerName) } - cib.loadWrapper.UpdateLoadStore(loadStore) + b.loadWrapper.UpdateLoadStore(loadStore) } return nil } -func (cib *clusterImplBalancer) UpdateClientConnState(s balancer.ClientConnState) error { - if cib.closed.HasFired() { - cib.logger.Warningf("xds: received ClientConnState {%+v} after clusterImplBalancer was closed", s) +func (b *clusterImplBalancer) UpdateClientConnState(s balancer.ClientConnState) error { + if b.closed.HasFired() { + b.logger.Warningf("xds: received ClientConnState {%+v} after clusterImplBalancer was closed", s) return nil } - newConfig, ok := s.BalancerConfig.(*lbConfig) + b.logger.Infof("Received update from resolver, balancer config: %+v", pretty.ToJSON(s.BalancerConfig)) + newConfig, ok := s.BalancerConfig.(*LBConfig) if !ok { return fmt.Errorf("unexpected balancer config with type: %T", s.BalancerConfig) } @@ -182,59 +234,32 @@ func (cib *clusterImplBalancer) UpdateClientConnState(s balancer.ClientConnState return fmt.Errorf("balancer %q not registered", newConfig.ChildPolicy.Name) } + if b.xdsClient == nil { + c := xdsclient.FromResolverState(s.ResolverState) + if c == nil { + return balancer.ErrBadResolverState + } + b.xdsClient = c + } + // Update load reporting config. This needs to be done before updating the // child policy because we need the loadStore from the updated client to be // passed to the ccWrapper, so that the next picker from the child policy // will pick up the new loadStore. - if err := cib.updateLoadStore(newConfig); err != nil { + if err := b.updateLoadStore(newConfig); err != nil { return err } - // Compare new drop config. And update picker if it's changed. - var updatePicker bool - if cib.config == nil || !equalDropCategories(cib.config.DropCategories, newConfig.DropCategories) { - cib.drops = make([]*dropper, 0, len(newConfig.DropCategories)) - for _, c := range newConfig.DropCategories { - cib.drops = append(cib.drops, newDropper(c)) - } - updatePicker = true - } - - // Compare cluster name. And update picker if it's changed, because circuit - // breaking's stream counter will be different. - if cib.config == nil || cib.config.Cluster != newConfig.Cluster { - cib.requestCounter = xdsclient.GetServiceRequestsCounter(newConfig.Cluster) - updatePicker = true - } - // Compare upper bound of stream count. And update picker if it's changed. - // This is also for circuit breaking. - var newRequestCountMax uint32 = 1024 - if newConfig.MaxConcurrentRequests != nil { - newRequestCountMax = *newConfig.MaxConcurrentRequests - } - if cib.requestCountMax != newRequestCountMax { - cib.requestCountMax = newRequestCountMax - updatePicker = true - } - - if updatePicker { - cib.pickerUpdateCh.Put(&dropConfigs{ - drops: cib.drops, - requestCounter: cib.requestCounter, - requestCountMax: cib.requestCountMax, - }) - } - // If child policy is a different type, recreate the sub-balancer. - if cib.config == nil || cib.config.ChildPolicy.Name != newConfig.ChildPolicy.Name { - if cib.childLB != nil { - cib.childLB.Close() + if b.config == nil || b.config.ChildPolicy.Name != newConfig.ChildPolicy.Name { + if b.childLB != nil { + b.childLB.Close() } - cib.childLB = bb.Build(cib, cib.bOpts) + b.childLB = bb.Build(b, b.bOpts) } - cib.config = newConfig + b.config = newConfig - if cib.childLB == nil { + if b.childLB == nil { // This is not an expected situation, and should be super rare in // practice. // @@ -244,85 +269,274 @@ func (cib *clusterImplBalancer) UpdateClientConnState(s balancer.ClientConnState return fmt.Errorf("child policy is nil, this means balancer %q's Build() returned nil", newConfig.ChildPolicy.Name) } + // Notify run() of this new config, in case drop and request counter need + // update (which means a new picker needs to be generated). + b.pickerUpdateCh.Put(newConfig) + // Addresses and sub-balancer config are sent to sub-balancer. - return cib.childLB.UpdateClientConnState(balancer.ClientConnState{ + return b.childLB.UpdateClientConnState(balancer.ClientConnState{ ResolverState: s.ResolverState, - BalancerConfig: cib.config.ChildPolicy.Config, + BalancerConfig: b.config.ChildPolicy.Config, }) } -func (cib *clusterImplBalancer) ResolverError(err error) { - if cib.closed.HasFired() { - cib.logger.Warningf("xds: received resolver error {%+v} after clusterImplBalancer was closed", err) +func (b *clusterImplBalancer) ResolverError(err error) { + if b.closed.HasFired() { + b.logger.Warningf("xds: received resolver error {%+v} after clusterImplBalancer was closed", err) return } - if cib.childLB != nil { - cib.childLB.ResolverError(err) + if b.childLB != nil { + b.childLB.ResolverError(err) } } -func (cib *clusterImplBalancer) UpdateSubConnState(sc balancer.SubConn, s balancer.SubConnState) { - if cib.closed.HasFired() { - cib.logger.Warningf("xds: received subconn state change {%+v, %+v} after clusterImplBalancer was closed", sc, s) +func (b *clusterImplBalancer) UpdateSubConnState(sc balancer.SubConn, s balancer.SubConnState) { + if b.closed.HasFired() { + b.logger.Warningf("xds: received subconn state change {%+v, %+v} after clusterImplBalancer was closed", sc, s) return } - if cib.childLB != nil { - cib.childLB.UpdateSubConnState(sc, s) + // Trigger re-resolution when a SubConn turns transient failure. This is + // necessary for the LogicalDNS in cluster_resolver policy to re-resolve. + // + // Note that this happens not only for the addresses from DNS, but also for + // EDS (cluster_impl doesn't know if it's DNS or EDS, only the parent + // knows). The parent priority policy is configured to ignore re-resolution + // signal from the EDS children. + if s.ConnectivityState == connectivity.TransientFailure { + b.ClientConn.ResolveNow(resolver.ResolveNowOptions{}) + } + + b.scWrappersMu.Lock() + if scw, ok := b.scWrappers[sc]; ok { + sc = scw + if s.ConnectivityState == connectivity.Shutdown { + // Remove this SubConn from the map on Shutdown. + delete(b.scWrappers, scw.SubConn) + } + } + b.scWrappersMu.Unlock() + if b.childLB != nil { + b.childLB.UpdateSubConnState(sc, s) } } -func (cib *clusterImplBalancer) Close() { - if cib.childLB != nil { - cib.childLB.Close() - cib.childLB = nil +func (b *clusterImplBalancer) Close() { + b.mu.Lock() + b.closed.Fire() + b.mu.Unlock() + + if b.childLB != nil { + b.childLB.Close() + b.childLB = nil + } + <-b.done.Done() + b.logger.Infof("Shutdown") +} + +func (b *clusterImplBalancer) ExitIdle() { + if b.childLB == nil { + return + } + if ei, ok := b.childLB.(balancer.ExitIdler); ok { + ei.ExitIdle() + return + } + // Fallback for children that don't support ExitIdle -- connect to all + // SubConns. + for _, sc := range b.scWrappers { + sc.Connect() } - cib.xdsC.Close() - cib.closed.Fire() - cib.logger.Infof("Shutdown") } // Override methods to accept updates from the child LB. -func (cib *clusterImplBalancer) UpdateState(state balancer.State) { +func (b *clusterImplBalancer) UpdateState(state balancer.State) { // Instead of updating parent ClientConn inline, send state to run(). - cib.pickerUpdateCh.Put(state) + b.pickerUpdateCh.Put(state) +} + +func (b *clusterImplBalancer) setClusterName(n string) { + b.clusterNameMu.Lock() + defer b.clusterNameMu.Unlock() + b.clusterName = n +} + +func (b *clusterImplBalancer) getClusterName() string { + b.clusterNameMu.Lock() + defer b.clusterNameMu.Unlock() + return b.clusterName +} + +// scWrapper is a wrapper of SubConn with locality ID. The locality ID can be +// retrieved from the addresses when creating SubConn. +// +// All SubConns passed to the child policies are wrapped in this, so that the +// picker can get the localityID from the picked SubConn, and do load reporting. +// +// After wrapping, all SubConns to and from the parent ClientConn (e.g. for +// SubConn state update, update/remove SubConn) must be the original SubConns. +// All SubConns to and from the child policy (NewSubConn, forwarding SubConn +// state update) must be the wrapper. The balancer keeps a map from the original +// SubConn to the wrapper for this purpose. +type scWrapper struct { + balancer.SubConn + // locality needs to be atomic because it can be updated while being read by + // the picker. + locality atomic.Value // type xdsinternal.LocalityID +} + +func (scw *scWrapper) updateLocalityID(lID xdsinternal.LocalityID) { + scw.locality.Store(lID) +} + +func (scw *scWrapper) localityID() xdsinternal.LocalityID { + lID, _ := scw.locality.Load().(xdsinternal.LocalityID) + return lID +} + +func (b *clusterImplBalancer) NewSubConn(addrs []resolver.Address, opts balancer.NewSubConnOptions) (balancer.SubConn, error) { + clusterName := b.getClusterName() + newAddrs := make([]resolver.Address, len(addrs)) + var lID xdsinternal.LocalityID + for i, addr := range addrs { + newAddrs[i] = internal.SetXDSHandshakeClusterName(addr, clusterName) + lID = xdsinternal.GetLocalityID(newAddrs[i]) + } + sc, err := b.ClientConn.NewSubConn(newAddrs, opts) + if err != nil { + return nil, err + } + // Wrap this SubConn in a wrapper, and add it to the map. + b.scWrappersMu.Lock() + ret := &scWrapper{SubConn: sc} + ret.updateLocalityID(lID) + b.scWrappers[sc] = ret + b.scWrappersMu.Unlock() + return ret, nil +} + +func (b *clusterImplBalancer) RemoveSubConn(sc balancer.SubConn) { + scw, ok := sc.(*scWrapper) + if !ok { + b.ClientConn.RemoveSubConn(sc) + return + } + // Remove the original SubConn from the parent ClientConn. + // + // Note that we don't remove this SubConn from the scWrappers map. We will + // need it to forward the final SubConn state Shutdown to the child policy. + // + // This entry is kept in the map until it's state is changes to Shutdown, + // and will be deleted in UpdateSubConnState(). + b.ClientConn.RemoveSubConn(scw.SubConn) +} + +func (b *clusterImplBalancer) UpdateAddresses(sc balancer.SubConn, addrs []resolver.Address) { + clusterName := b.getClusterName() + newAddrs := make([]resolver.Address, len(addrs)) + var lID xdsinternal.LocalityID + for i, addr := range addrs { + newAddrs[i] = internal.SetXDSHandshakeClusterName(addr, clusterName) + lID = xdsinternal.GetLocalityID(newAddrs[i]) + } + if scw, ok := sc.(*scWrapper); ok { + scw.updateLocalityID(lID) + // Need to get the original SubConn from the wrapper before calling + // parent ClientConn. + sc = scw.SubConn + } + b.ClientConn.UpdateAddresses(sc, newAddrs) } type dropConfigs struct { drops []*dropper - requestCounter *xdsclient.ServiceRequestsCounter + requestCounter *xdsclient.ClusterRequestsCounter requestCountMax uint32 } -func (cib *clusterImplBalancer) run() { +// handleDropAndRequestCount compares drop and request counter in newConfig with +// the one currently used by picker. It returns a new dropConfigs if a new +// picker needs to be generated, otherwise it returns nil. +func (b *clusterImplBalancer) handleDropAndRequestCount(newConfig *LBConfig) *dropConfigs { + // Compare new drop config. And update picker if it's changed. + var updatePicker bool + if !equalDropCategories(b.dropCategories, newConfig.DropCategories) { + b.dropCategories = newConfig.DropCategories + b.drops = make([]*dropper, 0, len(newConfig.DropCategories)) + for _, c := range newConfig.DropCategories { + b.drops = append(b.drops, newDropper(c)) + } + updatePicker = true + } + + // Compare cluster name. And update picker if it's changed, because circuit + // breaking's stream counter will be different. + if b.requestCounterCluster != newConfig.Cluster || b.requestCounterService != newConfig.EDSServiceName { + b.requestCounterCluster = newConfig.Cluster + b.requestCounterService = newConfig.EDSServiceName + b.requestCounter = xdsclient.GetClusterRequestsCounter(newConfig.Cluster, newConfig.EDSServiceName) + updatePicker = true + } + // Compare upper bound of stream count. And update picker if it's changed. + // This is also for circuit breaking. + var newRequestCountMax uint32 = 1024 + if newConfig.MaxConcurrentRequests != nil { + newRequestCountMax = *newConfig.MaxConcurrentRequests + } + if b.requestCountMax != newRequestCountMax { + b.requestCountMax = newRequestCountMax + updatePicker = true + } + + if !updatePicker { + return nil + } + return &dropConfigs{ + drops: b.drops, + requestCounter: b.requestCounter, + requestCountMax: b.requestCountMax, + } +} + +func (b *clusterImplBalancer) run() { + defer b.done.Fire() for { select { - case update := <-cib.pickerUpdateCh.Get(): - cib.pickerUpdateCh.Load() + case update := <-b.pickerUpdateCh.Get(): + b.pickerUpdateCh.Load() + b.mu.Lock() + if b.closed.HasFired() { + b.mu.Unlock() + return + } switch u := update.(type) { case balancer.State: - cib.childState = u - cib.ClientConn.UpdateState(balancer.State{ - ConnectivityState: cib.childState.ConnectivityState, - Picker: newDropPicker(cib.childState, &dropConfigs{ - drops: cib.drops, - requestCounter: cib.requestCounter, - requestCountMax: cib.requestCountMax, - }, cib.loadWrapper), + b.childState = u + b.ClientConn.UpdateState(balancer.State{ + ConnectivityState: b.childState.ConnectivityState, + Picker: newPicker(b.childState, &dropConfigs{ + drops: b.drops, + requestCounter: b.requestCounter, + requestCountMax: b.requestCountMax, + }, b.loadWrapper), }) - case *dropConfigs: - cib.drops = u.drops - cib.requestCounter = u.requestCounter - if cib.childState.Picker != nil { - cib.ClientConn.UpdateState(balancer.State{ - ConnectivityState: cib.childState.ConnectivityState, - Picker: newDropPicker(cib.childState, u, cib.loadWrapper), + case *LBConfig: + dc := b.handleDropAndRequestCount(u) + if dc != nil && b.childState.Picker != nil { + b.ClientConn.UpdateState(balancer.State{ + ConnectivityState: b.childState.ConnectivityState, + Picker: newPicker(b.childState, dc, b.loadWrapper), }) } } - case <-cib.closed.Done(): + b.mu.Unlock() + case <-b.closed.Done(): + if b.cancelLoadReport != nil { + b.cancelLoadReport() + b.cancelLoadReport = nil + } return } } diff --git a/xds/internal/balancer/clusterimpl/config.go b/xds/internal/balancer/clusterimpl/config.go index 548ab34bce4..51ff654f6eb 100644 --- a/xds/internal/balancer/clusterimpl/config.go +++ b/xds/internal/balancer/clusterimpl/config.go @@ -25,32 +25,33 @@ import ( "google.golang.org/grpc/serviceconfig" ) -type dropCategory struct { +// DropConfig contains the category, and drop ratio. +type DropConfig struct { Category string RequestsPerMillion uint32 } -// lbConfig is the balancer config for weighted_target. -type lbConfig struct { - serviceconfig.LoadBalancingConfig +// LBConfig is the balancer config for cluster_impl balancer. +type LBConfig struct { + serviceconfig.LoadBalancingConfig `json:"-"` - Cluster string - EDSServiceName string - LRSLoadReportingServerName *string - MaxConcurrentRequests *uint32 - DropCategories []dropCategory - ChildPolicy *internalserviceconfig.BalancerConfig + Cluster string `json:"cluster,omitempty"` + EDSServiceName string `json:"edsServiceName,omitempty"` + LoadReportingServerName *string `json:"lrsLoadReportingServerName,omitempty"` + MaxConcurrentRequests *uint32 `json:"maxConcurrentRequests,omitempty"` + DropCategories []DropConfig `json:"dropCategories,omitempty"` + ChildPolicy *internalserviceconfig.BalancerConfig `json:"childPolicy,omitempty"` } -func parseConfig(c json.RawMessage) (*lbConfig, error) { - var cfg lbConfig +func parseConfig(c json.RawMessage) (*LBConfig, error) { + var cfg LBConfig if err := json.Unmarshal(c, &cfg); err != nil { return nil, err } return &cfg, nil } -func equalDropCategories(a, b []dropCategory) bool { +func equalDropCategories(a, b []DropConfig) bool { if len(a) != len(b) { return false } diff --git a/xds/internal/balancer/clusterimpl/config_test.go b/xds/internal/balancer/clusterimpl/config_test.go index 89696981e2a..ccb0c5e74d9 100644 --- a/xds/internal/balancer/clusterimpl/config_test.go +++ b/xds/internal/balancer/clusterimpl/config_test.go @@ -87,7 +87,7 @@ func TestParseConfig(t *testing.T) { tests := []struct { name string js string - want *lbConfig + want *LBConfig wantErr bool }{ { @@ -105,12 +105,12 @@ func TestParseConfig(t *testing.T) { { name: "OK", js: testJSONConfig, - want: &lbConfig{ - Cluster: "test_cluster", - EDSServiceName: "test-eds", - LRSLoadReportingServerName: newString("lrs_server"), - MaxConcurrentRequests: newUint32(123), - DropCategories: []dropCategory{ + want: &LBConfig{ + Cluster: "test_cluster", + EDSServiceName: "test-eds", + LoadReportingServerName: newString("lrs_server"), + MaxConcurrentRequests: newUint32(123), + DropCategories: []DropConfig{ {Category: "drop-1", RequestsPerMillion: 314}, {Category: "drop-2", RequestsPerMillion: 159}, }, diff --git a/xds/internal/balancer/clusterimpl/picker.go b/xds/internal/balancer/clusterimpl/picker.go index 6e9d2791153..db29c550be1 100644 --- a/xds/internal/balancer/clusterimpl/picker.go +++ b/xds/internal/balancer/clusterimpl/picker.go @@ -19,16 +19,19 @@ package clusterimpl import ( + orcapb "github.com/cncf/udpa/go/udpa/data/orca/v1" "google.golang.org/grpc/balancer" "google.golang.org/grpc/codes" "google.golang.org/grpc/connectivity" "google.golang.org/grpc/internal/wrr" "google.golang.org/grpc/status" - "google.golang.org/grpc/xds/internal/client" - "google.golang.org/grpc/xds/internal/client/load" + "google.golang.org/grpc/xds/internal/xdsclient" + "google.golang.org/grpc/xds/internal/xdsclient/load" ) -var newRandomWRR = wrr.NewRandom +// NewRandomWRR is used when calculating drops. It's exported so that tests can +// override it. +var NewRandomWRR = wrr.NewRandom const million = 1000000 @@ -47,8 +50,8 @@ func gcd(a, b uint32) uint32 { return a } -func newDropper(c dropCategory) *dropper { - w := newRandomWRR() +func newDropper(c DropConfig) *dropper { + w := NewRandomWRR() gcdv := gcd(c.RequestsPerMillion, million) // Return true for RequestPerMillion, false for the rest. w.Add(true, int64(c.RequestsPerMillion/gcdv)) @@ -64,21 +67,30 @@ func (d *dropper) drop() (ret bool) { return d.w.Next().(bool) } +const ( + serverLoadCPUName = "cpu_utilization" + serverLoadMemoryName = "mem_utilization" +) + // loadReporter wraps the methods from the loadStore that are used here. type loadReporter interface { + CallStarted(locality string) + CallFinished(locality string, err error) + CallServerLoad(locality, name string, val float64) CallDropped(locality string) } -type dropPicker struct { +// Picker implements RPC drop, circuit breaking drop and load reporting. +type picker struct { drops []*dropper s balancer.State loadStore loadReporter - counter *client.ServiceRequestsCounter + counter *xdsclient.ClusterRequestsCounter countMax uint32 } -func newDropPicker(s balancer.State, config *dropConfigs, loadStore load.PerClusterReporter) *dropPicker { - return &dropPicker{ +func newPicker(s balancer.State, config *dropConfigs, loadStore load.PerClusterReporter) *picker { + return &picker{ drops: config.drops, s: s, loadStore: loadStore, @@ -87,13 +99,14 @@ func newDropPicker(s balancer.State, config *dropConfigs, loadStore load.PerClus } } -func (d *dropPicker) Pick(info balancer.PickInfo) (balancer.PickResult, error) { +func (d *picker) Pick(info balancer.PickInfo) (balancer.PickResult, error) { // Don't drop unless the inner picker is READY. Similar to // https://github.com/grpc/grpc-go/issues/2622. if d.s.ConnectivityState != connectivity.Ready { return d.s.Picker.Pick(info) } + // Check if this RPC should be dropped by category. for _, dp := range d.drops { if dp.drop() { if d.loadStore != nil { @@ -103,6 +116,7 @@ func (d *dropPicker) Pick(info balancer.PickInfo) (balancer.PickResult, error) { } } + // Check if this RPC should be dropped by circuit breaking. if d.counter != nil { if err := d.counter.StartRequest(d.countMax); err != nil { // Drops by circuit breaking are reported with empty category. They @@ -112,11 +126,58 @@ func (d *dropPicker) Pick(info balancer.PickInfo) (balancer.PickResult, error) { } return balancer.PickResult{}, status.Errorf(codes.Unavailable, err.Error()) } - pr, err := d.s.Picker.Pick(info) - if err != nil { + } + + var lIDStr string + pr, err := d.s.Picker.Pick(info) + if scw, ok := pr.SubConn.(*scWrapper); ok { + // This OK check also covers the case err!=nil, because SubConn will be + // nil. + pr.SubConn = scw.SubConn + var e error + // If locality ID isn't found in the wrapper, an empty locality ID will + // be used. + lIDStr, e = scw.localityID().ToString() + if e != nil { + logger.Infof("failed to marshal LocalityID: %#v, loads won't be reported", scw.localityID()) + } + } + + if err != nil { + if d.counter != nil { + // Release one request count if this pick fails. d.counter.EndRequest() - return pr, err } + return pr, err + } + + if d.loadStore != nil { + d.loadStore.CallStarted(lIDStr) + oldDone := pr.Done + pr.Done = func(info balancer.DoneInfo) { + if oldDone != nil { + oldDone(info) + } + d.loadStore.CallFinished(lIDStr, info.Err) + + load, ok := info.ServerLoad.(*orcapb.OrcaLoadReport) + if !ok { + return + } + d.loadStore.CallServerLoad(lIDStr, serverLoadCPUName, load.CpuUtilization) + d.loadStore.CallServerLoad(lIDStr, serverLoadMemoryName, load.MemUtilization) + for n, c := range load.RequestCost { + d.loadStore.CallServerLoad(lIDStr, n, c) + } + for n, c := range load.Utilization { + d.loadStore.CallServerLoad(lIDStr, n, c) + } + } + } + + if d.counter != nil { + // Update Done() so that when the RPC finishes, the request count will + // be released. oldDone := pr.Done pr.Done = func(doneInfo balancer.DoneInfo) { d.counter.EndRequest() @@ -124,8 +185,7 @@ func (d *dropPicker) Pick(info balancer.PickInfo) (balancer.PickResult, error) { oldDone(doneInfo) } } - return pr, err } - return d.s.Picker.Pick(info) + return pr, err } diff --git a/xds/internal/balancer/clustermanager/balancerstateaggregator.go b/xds/internal/balancer/clustermanager/balancerstateaggregator.go index 35eb86c3590..6e0e03299f9 100644 --- a/xds/internal/balancer/clustermanager/balancerstateaggregator.go +++ b/xds/internal/balancer/clustermanager/balancerstateaggregator.go @@ -183,13 +183,18 @@ func (bsa *balancerStateAggregator) build() balancer.State { // handling the special connecting after ready, as in UpdateState(). Then a // function to calculate the aggregated connectivity state as in this // function. - var readyN, connectingN int + // + // TODO: use balancer.ConnectivityStateEvaluator to calculate the aggregated + // state. + var readyN, connectingN, idleN int for _, ps := range bsa.idToPickerState { switch ps.stateToAggregate { case connectivity.Ready: readyN++ case connectivity.Connecting: connectingN++ + case connectivity.Idle: + idleN++ } } var aggregatedState connectivity.State @@ -198,6 +203,8 @@ func (bsa *balancerStateAggregator) build() balancer.State { aggregatedState = connectivity.Ready case connectingN > 0: aggregatedState = connectivity.Connecting + case idleN > 0: + aggregatedState = connectivity.Idle default: aggregatedState = connectivity.TransientFailure } diff --git a/xds/internal/balancer/clustermanager/clustermanager.go b/xds/internal/balancer/clustermanager/clustermanager.go index 1e4dee7f5d3..318545d79b0 100644 --- a/xds/internal/balancer/clustermanager/clustermanager.go +++ b/xds/internal/balancer/clustermanager/clustermanager.go @@ -27,6 +27,7 @@ import ( "google.golang.org/grpc/grpclog" internalgrpclog "google.golang.org/grpc/internal/grpclog" "google.golang.org/grpc/internal/hierarchy" + "google.golang.org/grpc/internal/pretty" "google.golang.org/grpc/resolver" "google.golang.org/grpc/serviceconfig" "google.golang.org/grpc/xds/internal/balancer/balancergroup" @@ -35,12 +36,12 @@ import ( const balancerName = "xds_cluster_manager_experimental" func init() { - balancer.Register(builder{}) + balancer.Register(bb{}) } -type builder struct{} +type bb struct{} -func (builder) Build(cc balancer.ClientConn, opts balancer.BuildOptions) balancer.Balancer { +func (bb) Build(cc balancer.ClientConn, opts balancer.BuildOptions) balancer.Balancer { b := &bal{} b.logger = prefixLogger(b) b.stateAggregator = newBalancerStateAggregator(cc, b.logger) @@ -51,11 +52,11 @@ func (builder) Build(cc balancer.ClientConn, opts balancer.BuildOptions) balance return b } -func (builder) Name() string { +func (bb) Name() string { return balancerName } -func (builder) ParseConfig(c json.RawMessage) (serviceconfig.LoadBalancingConfig, error) { +func (bb) ParseConfig(c json.RawMessage) (serviceconfig.LoadBalancingConfig, error) { return parseConfig(c) } @@ -115,7 +116,7 @@ func (b *bal) UpdateClientConnState(s balancer.ClientConnState) error { if !ok { return fmt.Errorf("unexpected balancer config with type: %T", s.BalancerConfig) } - b.logger.Infof("update with config %+v, resolver state %+v", s.BalancerConfig, s.ResolverState) + b.logger.Infof("update with config %+v, resolver state %+v", pretty.ToJSON(s.BalancerConfig), s.ResolverState) b.updateChildren(s, newConfig) return nil @@ -132,6 +133,11 @@ func (b *bal) UpdateSubConnState(sc balancer.SubConn, state balancer.SubConnStat func (b *bal) Close() { b.stateAggregator.close() b.bg.Close() + b.logger.Infof("Shutdown") +} + +func (b *bal) ExitIdle() { + b.bg.ExitIdle() } const prefix = "[xds-cluster-manager-lb %p] " diff --git a/xds/internal/balancer/clustermanager/clustermanager_test.go b/xds/internal/balancer/clustermanager/clustermanager_test.go index a40d954ad64..d3475ea3f5d 100644 --- a/xds/internal/balancer/clustermanager/clustermanager_test.go +++ b/xds/internal/balancer/clustermanager/clustermanager_test.go @@ -565,3 +565,68 @@ func TestClusterManagerForwardsBalancerBuildOptions(t *testing.T) { t.Fatal(err2) } } + +const initIdleBalancerName = "test-init-Idle-balancer" + +var errTestInitIdle = fmt.Errorf("init Idle balancer error 0") + +func init() { + stub.Register(initIdleBalancerName, stub.BalancerFuncs{ + UpdateClientConnState: func(bd *stub.BalancerData, opts balancer.ClientConnState) error { + bd.ClientConn.NewSubConn(opts.ResolverState.Addresses, balancer.NewSubConnOptions{}) + return nil + }, + UpdateSubConnState: func(bd *stub.BalancerData, sc balancer.SubConn, state balancer.SubConnState) { + err := fmt.Errorf("wrong picker error") + if state.ConnectivityState == connectivity.Idle { + err = errTestInitIdle + } + bd.ClientConn.UpdateState(balancer.State{ + ConnectivityState: state.ConnectivityState, + Picker: &testutils.TestConstPicker{Err: err}, + }) + }, + }) +} + +// TestInitialIdle covers the case that if the child reports Idle, the overall +// state will be Idle. +func TestInitialIdle(t *testing.T) { + cc := testutils.NewTestClientConn(t) + rtb := rtBuilder.Build(cc, balancer.BuildOptions{}) + + configJSON1 := `{ +"children": { + "cds:cluster_1":{ "childPolicy": [{"test-init-Idle-balancer":""}] } +} +}` + + config1, err := rtParser.ParseConfig([]byte(configJSON1)) + if err != nil { + t.Fatalf("failed to parse balancer config: %v", err) + } + + // Send the config, and an address with hierarchy path ["cluster_1"]. + wantAddrs := []resolver.Address{ + {Addr: testBackendAddrStrs[0], Attributes: nil}, + } + if err := rtb.UpdateClientConnState(balancer.ClientConnState{ + ResolverState: resolver.State{Addresses: []resolver.Address{ + hierarchy.Set(wantAddrs[0], []string{"cds:cluster_1"}), + }}, + BalancerConfig: config1, + }); err != nil { + t.Fatalf("failed to update ClientConn state: %v", err) + } + + // Verify that a subconn is created with the address, and the hierarchy path + // in the address is cleared. + for range wantAddrs { + sc := <-cc.NewSubConnCh + rtb.UpdateSubConnState(sc, balancer.SubConnState{ConnectivityState: connectivity.Idle}) + } + + if state1 := <-cc.NewStateCh; state1 != connectivity.Idle { + t.Fatalf("Received aggregated state: %v, want Idle", state1) + } +} diff --git a/xds/internal/balancer/clusterresolver/clusterresolver.go b/xds/internal/balancer/clusterresolver/clusterresolver.go new file mode 100644 index 00000000000..66a5aab305e --- /dev/null +++ b/xds/internal/balancer/clusterresolver/clusterresolver.go @@ -0,0 +1,378 @@ +/* + * + * Copyright 2019 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// Package clusterresolver contains EDS balancer implementation. +package clusterresolver + +import ( + "encoding/json" + "errors" + "fmt" + + "google.golang.org/grpc/attributes" + "google.golang.org/grpc/balancer" + "google.golang.org/grpc/balancer/base" + "google.golang.org/grpc/connectivity" + "google.golang.org/grpc/internal/buffer" + "google.golang.org/grpc/internal/grpclog" + "google.golang.org/grpc/internal/grpcsync" + "google.golang.org/grpc/internal/pretty" + "google.golang.org/grpc/resolver" + "google.golang.org/grpc/serviceconfig" + "google.golang.org/grpc/xds/internal/balancer/priority" + "google.golang.org/grpc/xds/internal/xdsclient" +) + +// Name is the name of the cluster_resolver balancer. +const Name = "cluster_resolver_experimental" + +var ( + errBalancerClosed = errors.New("cdsBalancer is closed") + newChildBalancer = func(bb balancer.Builder, cc balancer.ClientConn, o balancer.BuildOptions) balancer.Balancer { + return bb.Build(cc, o) + } +) + +func init() { + balancer.Register(bb{}) +} + +type bb struct{} + +// Build helps implement the balancer.Builder interface. +func (bb) Build(cc balancer.ClientConn, opts balancer.BuildOptions) balancer.Balancer { + priorityBuilder := balancer.Get(priority.Name) + if priorityBuilder == nil { + logger.Errorf("priority balancer is needed but not registered") + return nil + } + priorityConfigParser, ok := priorityBuilder.(balancer.ConfigParser) + if !ok { + logger.Errorf("priority balancer builder is not a config parser") + return nil + } + + b := &clusterResolverBalancer{ + bOpts: opts, + updateCh: buffer.NewUnbounded(), + closed: grpcsync.NewEvent(), + done: grpcsync.NewEvent(), + + priorityBuilder: priorityBuilder, + priorityConfigParser: priorityConfigParser, + } + b.logger = prefixLogger(b) + b.logger.Infof("Created") + + b.resourceWatcher = newResourceResolver(b) + b.cc = &ccWrapper{ + ClientConn: cc, + resourceWatcher: b.resourceWatcher, + } + + go b.run() + return b +} + +func (bb) Name() string { + return Name +} + +func (bb) ParseConfig(c json.RawMessage) (serviceconfig.LoadBalancingConfig, error) { + var cfg LBConfig + if err := json.Unmarshal(c, &cfg); err != nil { + return nil, fmt.Errorf("unable to unmarshal balancer config %s into cluster-resolver config, error: %v", string(c), err) + } + return &cfg, nil +} + +// ccUpdate wraps a clientConn update received from gRPC (pushed from the +// xdsResolver). +type ccUpdate struct { + state balancer.ClientConnState + err error +} + +// scUpdate wraps a subConn update received from gRPC. This is directly passed +// on to the child balancer. +type scUpdate struct { + subConn balancer.SubConn + state balancer.SubConnState +} + +type exitIdle struct{} + +// clusterResolverBalancer manages xdsClient and the actual EDS balancer implementation that +// does load balancing. +// +// It currently has only an clusterResolverBalancer. Later, we may add fallback. +type clusterResolverBalancer struct { + cc balancer.ClientConn + bOpts balancer.BuildOptions + updateCh *buffer.Unbounded // Channel for updates from gRPC. + resourceWatcher *resourceResolver + logger *grpclog.PrefixLogger + closed *grpcsync.Event + done *grpcsync.Event + + priorityBuilder balancer.Builder + priorityConfigParser balancer.ConfigParser + + config *LBConfig + configRaw *serviceconfig.ParseResult + xdsClient xdsclient.XDSClient // xDS client to watch EDS resource. + attrsWithClient *attributes.Attributes // Attributes with xdsClient attached to be passed to the child policies. + + child balancer.Balancer + priorities []priorityConfig + watchUpdateReceived bool +} + +// handleClientConnUpdate handles a ClientConnUpdate received from gRPC. Good +// updates lead to registration of EDS and DNS watches. Updates with error lead +// to cancellation of existing watch and propagation of the same error to the +// child balancer. +func (b *clusterResolverBalancer) handleClientConnUpdate(update *ccUpdate) { + // We first handle errors, if any, and then proceed with handling the + // update, only if the status quo has changed. + if err := update.err; err != nil { + b.handleErrorFromUpdate(err, true) + return + } + + b.logger.Infof("Receive update from resolver, balancer config: %v", pretty.ToJSON(update.state.BalancerConfig)) + cfg, _ := update.state.BalancerConfig.(*LBConfig) + if cfg == nil { + b.logger.Warningf("xds: unexpected LoadBalancingConfig type: %T", update.state.BalancerConfig) + return + } + + b.config = cfg + b.configRaw = update.state.ResolverState.ServiceConfig + b.resourceWatcher.updateMechanisms(cfg.DiscoveryMechanisms) + + if !b.watchUpdateReceived { + // If update was not received, wait for it. + return + } + // If eds resp was received before this, the child policy was created. We + // need to generate a new balancer config and send it to the child, because + // certain fields (unrelated to EDS watch) might have changed. + if err := b.updateChildConfig(); err != nil { + b.logger.Warningf("failed to update child policy config: %v", err) + } +} + +// handleWatchUpdate handles a watch update from the xDS Client. Good updates +// lead to clientConn updates being invoked on the underlying child balancer. +func (b *clusterResolverBalancer) handleWatchUpdate(update *resourceUpdate) { + if err := update.err; err != nil { + b.logger.Warningf("Watch error from xds-client %p: %v", b.xdsClient, err) + b.handleErrorFromUpdate(err, false) + return + } + + b.logger.Infof("resource update: %+v", pretty.ToJSON(update.priorities)) + b.watchUpdateReceived = true + b.priorities = update.priorities + + // A new EDS update triggers new child configs (e.g. different priorities + // for the priority balancer), and new addresses (the endpoints come from + // the EDS response). + if err := b.updateChildConfig(); err != nil { + b.logger.Warningf("failed to update child policy's balancer config: %v", err) + } +} + +// updateChildConfig builds a balancer config from eb's cached eds resp and +// service config, and sends that to the child balancer. Note that it also +// generates the addresses, because the endpoints come from the EDS resp. +// +// If child balancer doesn't already exist, one will be created. +func (b *clusterResolverBalancer) updateChildConfig() error { + // Child was build when the first EDS resp was received, so we just build + // the config and addresses. + if b.child == nil { + b.child = newChildBalancer(b.priorityBuilder, b.cc, b.bOpts) + } + + childCfgBytes, addrs, err := buildPriorityConfigJSON(b.priorities, b.config.XDSLBPolicy) + if err != nil { + return fmt.Errorf("failed to build priority balancer config: %v", err) + } + childCfg, err := b.priorityConfigParser.ParseConfig(childCfgBytes) + if err != nil { + return fmt.Errorf("failed to parse generated priority balancer config, this should never happen because the config is generated: %v", err) + } + b.logger.Infof("build balancer config: %v", pretty.ToJSON(childCfg)) + return b.child.UpdateClientConnState(balancer.ClientConnState{ + ResolverState: resolver.State{ + Addresses: addrs, + ServiceConfig: b.configRaw, + Attributes: b.attrsWithClient, + }, + BalancerConfig: childCfg, + }) +} + +// handleErrorFromUpdate handles both the error from parent ClientConn (from CDS +// balancer) and the error from xds client (from the watcher). fromParent is +// true if error is from parent ClientConn. +// +// If the error is connection error, it should be handled for fallback purposes. +// +// If the error is resource-not-found: +// - If it's from CDS balancer (shows as a resolver error), it means LDS or CDS +// resources were removed. The EDS watch should be canceled. +// - If it's from xds client, it means EDS resource were removed. The EDS +// watcher should keep watching. +// In both cases, the sub-balancers will be receive the error. +func (b *clusterResolverBalancer) handleErrorFromUpdate(err error, fromParent bool) { + b.logger.Warningf("Received error: %v", err) + if fromParent && xdsclient.ErrType(err) == xdsclient.ErrorTypeResourceNotFound { + // This is an error from the parent ClientConn (can be the parent CDS + // balancer), and is a resource-not-found error. This means the resource + // (can be either LDS or CDS) was removed. Stop the EDS watch. + b.resourceWatcher.stop() + } + if b.child != nil { + b.child.ResolverError(err) + } else { + // If eds balancer was never created, fail the RPCs with errors. + b.cc.UpdateState(balancer.State{ + ConnectivityState: connectivity.TransientFailure, + Picker: base.NewErrPicker(err), + }) + } + +} + +// run is a long-running goroutine which handles all updates from gRPC and +// xdsClient. All methods which are invoked directly by gRPC or xdsClient simply +// push an update onto a channel which is read and acted upon right here. +func (b *clusterResolverBalancer) run() { + for { + select { + case u := <-b.updateCh.Get(): + b.updateCh.Load() + switch update := u.(type) { + case *ccUpdate: + b.handleClientConnUpdate(update) + case *scUpdate: + // SubConn updates are simply handed over to the underlying + // child balancer. + if b.child == nil { + b.logger.Errorf("xds: received scUpdate {%+v} with no child balancer", update) + break + } + b.child.UpdateSubConnState(update.subConn, update.state) + case exitIdle: + if b.child == nil { + b.logger.Errorf("xds: received ExitIdle with no child balancer") + break + } + // This implementation assumes the child balancer supports + // ExitIdle (but still checks for the interface's existence to + // avoid a panic if not). If the child does not, no subconns + // will be connected. + if ei, ok := b.child.(balancer.ExitIdler); ok { + ei.ExitIdle() + } + } + case u := <-b.resourceWatcher.updateChannel: + b.handleWatchUpdate(u) + + // Close results in cancellation of the EDS watch and closing of the + // underlying child policy and is the only way to exit this goroutine. + case <-b.closed.Done(): + b.resourceWatcher.stop() + + if b.child != nil { + b.child.Close() + b.child = nil + } + // This is the *ONLY* point of return from this function. + b.logger.Infof("Shutdown") + b.done.Fire() + return + } + } +} + +// Following are methods to implement the balancer interface. + +// UpdateClientConnState receives the serviceConfig (which contains the +// clusterName to watch for in CDS) and the xdsClient object from the +// xdsResolver. +func (b *clusterResolverBalancer) UpdateClientConnState(state balancer.ClientConnState) error { + if b.closed.HasFired() { + b.logger.Warningf("xds: received ClientConnState {%+v} after clusterResolverBalancer was closed", state) + return errBalancerClosed + } + + if b.xdsClient == nil { + c := xdsclient.FromResolverState(state.ResolverState) + if c == nil { + return balancer.ErrBadResolverState + } + b.xdsClient = c + b.attrsWithClient = state.ResolverState.Attributes + } + + b.updateCh.Put(&ccUpdate{state: state}) + return nil +} + +// ResolverError handles errors reported by the xdsResolver. +func (b *clusterResolverBalancer) ResolverError(err error) { + if b.closed.HasFired() { + b.logger.Warningf("xds: received resolver error {%v} after clusterResolverBalancer was closed", err) + return + } + b.updateCh.Put(&ccUpdate{err: err}) +} + +// UpdateSubConnState handles subConn updates from gRPC. +func (b *clusterResolverBalancer) UpdateSubConnState(sc balancer.SubConn, state balancer.SubConnState) { + if b.closed.HasFired() { + b.logger.Warningf("xds: received subConn update {%v, %v} after clusterResolverBalancer was closed", sc, state) + return + } + b.updateCh.Put(&scUpdate{subConn: sc, state: state}) +} + +// Close closes the cdsBalancer and the underlying child balancer. +func (b *clusterResolverBalancer) Close() { + b.closed.Fire() + <-b.done.Done() +} + +func (b *clusterResolverBalancer) ExitIdle() { + b.updateCh.Put(exitIdle{}) +} + +// ccWrapper overrides ResolveNow(), so that re-resolution from the child +// policies will trigger the DNS resolver in cluster_resolver balancer. +type ccWrapper struct { + balancer.ClientConn + resourceWatcher *resourceResolver +} + +func (c *ccWrapper) ResolveNow(resolver.ResolveNowOptions) { + c.resourceWatcher.resolveNow() +} diff --git a/xds/internal/balancer/clusterresolver/clusterresolver_test.go b/xds/internal/balancer/clusterresolver/clusterresolver_test.go new file mode 100644 index 00000000000..6af81f89f1f --- /dev/null +++ b/xds/internal/balancer/clusterresolver/clusterresolver_test.go @@ -0,0 +1,500 @@ +/* + * + * Copyright 2019 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package clusterresolver + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "google.golang.org/grpc/balancer" + "google.golang.org/grpc/connectivity" + "google.golang.org/grpc/internal/grpctest" + "google.golang.org/grpc/internal/testutils" + "google.golang.org/grpc/resolver" + "google.golang.org/grpc/xds/internal" + "google.golang.org/grpc/xds/internal/testutils/fakeclient" + "google.golang.org/grpc/xds/internal/xdsclient" + + _ "google.golang.org/grpc/xds/internal/xdsclient/v2" // V2 client registration. +) + +const ( + defaultTestTimeout = 1 * time.Second + defaultTestShortTimeout = 10 * time.Millisecond + testEDSServcie = "test-eds-service-name" + testClusterName = "test-cluster-name" +) + +var ( + // A non-empty endpoints update which is expected to be accepted by the EDS + // LB policy. + defaultEndpointsUpdate = xdsclient.EndpointsUpdate{ + Localities: []xdsclient.Locality{ + { + Endpoints: []xdsclient.Endpoint{{Address: "endpoint1"}}, + ID: internal.LocalityID{Zone: "zone"}, + Priority: 1, + Weight: 100, + }, + }, + } +) + +func init() { + balancer.Register(bb{}) +} + +type s struct { + grpctest.Tester + + cleanup func() +} + +func (ss s) Teardown(t *testing.T) { + xdsclient.ClearAllCountersForTesting() + ss.Tester.Teardown(t) + if ss.cleanup != nil { + ss.cleanup() + } +} + +func Test(t *testing.T) { + grpctest.RunSubTests(t, s{}) +} + +const testBalancerNameFooBar = "foo.bar" + +func newNoopTestClientConn() *noopTestClientConn { + return &noopTestClientConn{} +} + +// noopTestClientConn is used in EDS balancer config update tests that only +// cover the config update handling, but not SubConn/load-balancing. +type noopTestClientConn struct { + balancer.ClientConn +} + +func (t *noopTestClientConn) NewSubConn([]resolver.Address, balancer.NewSubConnOptions) (balancer.SubConn, error) { + return nil, nil +} + +func (noopTestClientConn) Target() string { return testEDSServcie } + +type scStateChange struct { + sc balancer.SubConn + state balancer.SubConnState +} + +type fakeChildBalancer struct { + cc balancer.ClientConn + subConnState *testutils.Channel + clientConnState *testutils.Channel + resolverError *testutils.Channel +} + +func (f *fakeChildBalancer) UpdateClientConnState(state balancer.ClientConnState) error { + f.clientConnState.Send(state) + return nil +} + +func (f *fakeChildBalancer) ResolverError(err error) { + f.resolverError.Send(err) +} + +func (f *fakeChildBalancer) UpdateSubConnState(sc balancer.SubConn, state balancer.SubConnState) { + f.subConnState.Send(&scStateChange{sc: sc, state: state}) +} + +func (f *fakeChildBalancer) Close() {} + +func (f *fakeChildBalancer) ExitIdle() {} + +func (f *fakeChildBalancer) waitForClientConnStateChange(ctx context.Context) error { + _, err := f.clientConnState.Receive(ctx) + if err != nil { + return err + } + return nil +} + +func (f *fakeChildBalancer) waitForResolverError(ctx context.Context) error { + _, err := f.resolverError.Receive(ctx) + if err != nil { + return err + } + return nil +} + +func (f *fakeChildBalancer) waitForSubConnStateChange(ctx context.Context, wantState *scStateChange) error { + val, err := f.subConnState.Receive(ctx) + if err != nil { + return err + } + gotState := val.(*scStateChange) + if !cmp.Equal(gotState, wantState, cmp.AllowUnexported(scStateChange{})) { + return fmt.Errorf("got subconnStateChange %v, want %v", gotState, wantState) + } + return nil +} + +func newFakeChildBalancer(cc balancer.ClientConn) balancer.Balancer { + return &fakeChildBalancer{ + cc: cc, + subConnState: testutils.NewChannelWithSize(10), + clientConnState: testutils.NewChannelWithSize(10), + resolverError: testutils.NewChannelWithSize(10), + } +} + +type fakeSubConn struct{} + +func (*fakeSubConn) UpdateAddresses([]resolver.Address) { panic("implement me") } +func (*fakeSubConn) Connect() { panic("implement me") } + +// waitForNewChildLB makes sure that a new child LB is created by the top-level +// clusterResolverBalancer. +func waitForNewChildLB(ctx context.Context, ch *testutils.Channel) (*fakeChildBalancer, error) { + val, err := ch.Receive(ctx) + if err != nil { + return nil, fmt.Errorf("error when waiting for a new edsLB: %v", err) + } + return val.(*fakeChildBalancer), nil +} + +// setup overrides the functions which are used to create the xdsClient and the +// edsLB, creates fake version of them and makes them available on the provided +// channels. The returned cancel function should be called by the test for +// cleanup. +func setup(childLBCh *testutils.Channel) (*fakeclient.Client, func()) { + xdsC := fakeclient.NewClientWithName(testBalancerNameFooBar) + + origNewChildBalancer := newChildBalancer + newChildBalancer = func(_ balancer.Builder, cc balancer.ClientConn, _ balancer.BuildOptions) balancer.Balancer { + childLB := newFakeChildBalancer(cc) + defer func() { childLBCh.Send(childLB) }() + return childLB + } + return xdsC, func() { + newChildBalancer = origNewChildBalancer + xdsC.Close() + } +} + +// TestSubConnStateChange verifies if the top-level clusterResolverBalancer passes on +// the subConnState to appropriate child balancer. +func (s) TestSubConnStateChange(t *testing.T) { + edsLBCh := testutils.NewChannel() + xdsC, cleanup := setup(edsLBCh) + defer cleanup() + + builder := balancer.Get(Name) + edsB := builder.Build(newNoopTestClientConn(), balancer.BuildOptions{}) + if edsB == nil { + t.Fatalf("builder.Build(%s) failed and returned nil", Name) + } + defer edsB.Close() + + if err := edsB.UpdateClientConnState(balancer.ClientConnState{ + ResolverState: xdsclient.SetClient(resolver.State{}, xdsC), + BalancerConfig: newLBConfigWithOneEDS(testEDSServcie), + }); err != nil { + t.Fatalf("edsB.UpdateClientConnState() failed: %v", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + if _, err := xdsC.WaitForWatchEDS(ctx); err != nil { + t.Fatalf("xdsClient.WatchEndpoints failed with error: %v", err) + } + xdsC.InvokeWatchEDSCallback("", defaultEndpointsUpdate, nil) + edsLB, err := waitForNewChildLB(ctx, edsLBCh) + if err != nil { + t.Fatal(err) + } + + fsc := &fakeSubConn{} + state := balancer.SubConnState{ConnectivityState: connectivity.Ready} + edsB.UpdateSubConnState(fsc, state) + if err := edsLB.waitForSubConnStateChange(ctx, &scStateChange{sc: fsc, state: state}); err != nil { + t.Fatal(err) + } +} + +// TestErrorFromXDSClientUpdate verifies that an error from xdsClient update is +// handled correctly. +// +// If it's resource-not-found, watch will NOT be canceled, the EDS impl will +// receive an empty EDS update, and new RPCs will fail. +// +// If it's connection error, nothing will happen. This will need to change to +// handle fallback. +func (s) TestErrorFromXDSClientUpdate(t *testing.T) { + edsLBCh := testutils.NewChannel() + xdsC, cleanup := setup(edsLBCh) + defer cleanup() + + builder := balancer.Get(Name) + edsB := builder.Build(newNoopTestClientConn(), balancer.BuildOptions{}) + if edsB == nil { + t.Fatalf("builder.Build(%s) failed and returned nil", Name) + } + defer edsB.Close() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + if err := edsB.UpdateClientConnState(balancer.ClientConnState{ + ResolverState: xdsclient.SetClient(resolver.State{}, xdsC), + BalancerConfig: newLBConfigWithOneEDS(testEDSServcie), + }); err != nil { + t.Fatal(err) + } + if _, err := xdsC.WaitForWatchEDS(ctx); err != nil { + t.Fatalf("xdsClient.WatchEndpoints failed with error: %v", err) + } + xdsC.InvokeWatchEDSCallback("", xdsclient.EndpointsUpdate{}, nil) + edsLB, err := waitForNewChildLB(ctx, edsLBCh) + if err != nil { + t.Fatal(err) + } + if err := edsLB.waitForClientConnStateChange(ctx); err != nil { + t.Fatalf("EDS impl got unexpected update: %v", err) + } + + connectionErr := xdsclient.NewErrorf(xdsclient.ErrorTypeConnection, "connection error") + xdsC.InvokeWatchEDSCallback("", xdsclient.EndpointsUpdate{}, connectionErr) + + sCtx, sCancel := context.WithTimeout(context.Background(), defaultTestShortTimeout) + defer sCancel() + if _, err := xdsC.WaitForCancelEDSWatch(sCtx); err != context.DeadlineExceeded { + t.Fatal("watch was canceled, want not canceled (timeout error)") + } + + sCtx, sCancel = context.WithTimeout(context.Background(), defaultTestShortTimeout) + defer sCancel() + if err := edsLB.waitForClientConnStateChange(sCtx); err != context.DeadlineExceeded { + t.Fatal(err) + } + if err := edsLB.waitForResolverError(ctx); err != nil { + t.Fatalf("want resolver error, got %v", err) + } + + resourceErr := xdsclient.NewErrorf(xdsclient.ErrorTypeResourceNotFound, "clusterResolverBalancer resource not found error") + xdsC.InvokeWatchEDSCallback("", xdsclient.EndpointsUpdate{}, resourceErr) + // Even if error is resource not found, watch shouldn't be canceled, because + // this is an EDS resource removed (and xds client actually never sends this + // error, but we still handles it). + sCtx, sCancel = context.WithTimeout(context.Background(), defaultTestShortTimeout) + defer sCancel() + if _, err := xdsC.WaitForCancelEDSWatch(sCtx); err != context.DeadlineExceeded { + t.Fatal("watch was canceled, want not canceled (timeout error)") + } + if err := edsLB.waitForClientConnStateChange(sCtx); err != context.DeadlineExceeded { + t.Fatal(err) + } + if err := edsLB.waitForResolverError(ctx); err != nil { + t.Fatalf("want resolver error, got %v", err) + } + + // An update with the same service name should not trigger a new watch. + if err := edsB.UpdateClientConnState(balancer.ClientConnState{ + ResolverState: xdsclient.SetClient(resolver.State{}, xdsC), + BalancerConfig: newLBConfigWithOneEDS(testEDSServcie), + }); err != nil { + t.Fatal(err) + } + sCtx, sCancel = context.WithTimeout(context.Background(), defaultTestShortTimeout) + defer sCancel() + if _, err := xdsC.WaitForWatchEDS(sCtx); err != context.DeadlineExceeded { + t.Fatal("got unexpected new EDS watch") + } +} + +// TestErrorFromResolver verifies that resolver errors are handled correctly. +// +// If it's resource-not-found, watch will be canceled, the EDS impl will receive +// an empty EDS update, and new RPCs will fail. +// +// If it's connection error, nothing will happen. This will need to change to +// handle fallback. +func (s) TestErrorFromResolver(t *testing.T) { + edsLBCh := testutils.NewChannel() + xdsC, cleanup := setup(edsLBCh) + defer cleanup() + + builder := balancer.Get(Name) + edsB := builder.Build(newNoopTestClientConn(), balancer.BuildOptions{}) + if edsB == nil { + t.Fatalf("builder.Build(%s) failed and returned nil", Name) + } + defer edsB.Close() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + if err := edsB.UpdateClientConnState(balancer.ClientConnState{ + ResolverState: xdsclient.SetClient(resolver.State{}, xdsC), + BalancerConfig: newLBConfigWithOneEDS(testEDSServcie), + }); err != nil { + t.Fatal(err) + } + + if _, err := xdsC.WaitForWatchEDS(ctx); err != nil { + t.Fatalf("xdsClient.WatchEndpoints failed with error: %v", err) + } + xdsC.InvokeWatchEDSCallback("", xdsclient.EndpointsUpdate{}, nil) + edsLB, err := waitForNewChildLB(ctx, edsLBCh) + if err != nil { + t.Fatal(err) + } + if err := edsLB.waitForClientConnStateChange(ctx); err != nil { + t.Fatalf("EDS impl got unexpected update: %v", err) + } + + connectionErr := xdsclient.NewErrorf(xdsclient.ErrorTypeConnection, "connection error") + edsB.ResolverError(connectionErr) + + sCtx, sCancel := context.WithTimeout(context.Background(), defaultTestShortTimeout) + defer sCancel() + if _, err := xdsC.WaitForCancelEDSWatch(sCtx); err != context.DeadlineExceeded { + t.Fatal("watch was canceled, want not canceled (timeout error)") + } + + sCtx, sCancel = context.WithTimeout(context.Background(), defaultTestShortTimeout) + defer sCancel() + if err := edsLB.waitForClientConnStateChange(sCtx); err != context.DeadlineExceeded { + t.Fatal("eds impl got EDS resp, want timeout error") + } + if err := edsLB.waitForResolverError(ctx); err != nil { + t.Fatalf("want resolver error, got %v", err) + } + + resourceErr := xdsclient.NewErrorf(xdsclient.ErrorTypeResourceNotFound, "clusterResolverBalancer resource not found error") + edsB.ResolverError(resourceErr) + if _, err := xdsC.WaitForCancelEDSWatch(ctx); err != nil { + t.Fatalf("want watch to be canceled, waitForCancel failed: %v", err) + } + if err := edsLB.waitForClientConnStateChange(sCtx); err != context.DeadlineExceeded { + t.Fatal(err) + } + if err := edsLB.waitForResolverError(ctx); err != nil { + t.Fatalf("want resolver error, got %v", err) + } + + // An update with the same service name should trigger a new watch, because + // the previous watch was canceled. + if err := edsB.UpdateClientConnState(balancer.ClientConnState{ + ResolverState: xdsclient.SetClient(resolver.State{}, xdsC), + BalancerConfig: newLBConfigWithOneEDS(testEDSServcie), + }); err != nil { + t.Fatal(err) + } + if _, err := xdsC.WaitForWatchEDS(ctx); err != nil { + t.Fatalf("xdsClient.WatchEndpoints failed with error: %v", err) + } +} + +// Given a list of resource names, verifies that EDS requests for the same are +// sent by the EDS balancer, through the fake xDS client. +func verifyExpectedRequests(ctx context.Context, fc *fakeclient.Client, resourceNames ...string) error { + for _, name := range resourceNames { + if name == "" { + // ResourceName empty string indicates a cancel. + if _, err := fc.WaitForCancelEDSWatch(ctx); err != nil { + return fmt.Errorf("timed out when expecting resource %q", name) + } + continue + } + + resName, err := fc.WaitForWatchEDS(ctx) + if err != nil { + return fmt.Errorf("timed out when expecting resource %q, %p", name, fc) + } + if resName != name { + return fmt.Errorf("got EDS request for resource %q, expected: %q", resName, name) + } + } + return nil +} + +// TestClientWatchEDS verifies that the xdsClient inside the top-level EDS LB +// policy registers an EDS watch for expected resource upon receiving an update +// from gRPC. +func (s) TestClientWatchEDS(t *testing.T) { + edsLBCh := testutils.NewChannel() + xdsC, cleanup := setup(edsLBCh) + defer cleanup() + + builder := balancer.Get(Name) + edsB := builder.Build(newNoopTestClientConn(), balancer.BuildOptions{}) + if edsB == nil { + t.Fatalf("builder.Build(%s) failed and returned nil", Name) + } + defer edsB.Close() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + // If eds service name is not set, should watch for cluster name. + if err := edsB.UpdateClientConnState(balancer.ClientConnState{ + ResolverState: xdsclient.SetClient(resolver.State{}, xdsC), + BalancerConfig: newLBConfigWithOneEDS("cluster-1"), + }); err != nil { + t.Fatal(err) + } + if err := verifyExpectedRequests(ctx, xdsC, "cluster-1"); err != nil { + t.Fatal(err) + } + + // Update with an non-empty edsServiceName should trigger an EDS watch for + // the same. + if err := edsB.UpdateClientConnState(balancer.ClientConnState{ + ResolverState: xdsclient.SetClient(resolver.State{}, xdsC), + BalancerConfig: newLBConfigWithOneEDS("foobar-1"), + }); err != nil { + t.Fatal(err) + } + if err := verifyExpectedRequests(ctx, xdsC, "", "foobar-1"); err != nil { + t.Fatal(err) + } + + // Also test the case where the edsServerName changes from one non-empty + // name to another, and make sure a new watch is registered. The previously + // registered watch will be cancelled, which will result in an EDS request + // with no resource names being sent to the server. + if err := edsB.UpdateClientConnState(balancer.ClientConnState{ + ResolverState: xdsclient.SetClient(resolver.State{}, xdsC), + BalancerConfig: newLBConfigWithOneEDS("foobar-2"), + }); err != nil { + t.Fatal(err) + } + if err := verifyExpectedRequests(ctx, xdsC, "", "foobar-2"); err != nil { + t.Fatal(err) + } +} + +func newLBConfigWithOneEDS(edsServiceName string) *LBConfig { + return &LBConfig{ + DiscoveryMechanisms: []DiscoveryMechanism{{ + Cluster: testClusterName, + Type: DiscoveryMechanismTypeEDS, + EDSServiceName: edsServiceName, + }}, + } +} diff --git a/xds/internal/balancer/clusterresolver/config.go b/xds/internal/balancer/clusterresolver/config.go new file mode 100644 index 00000000000..a6a3cbab804 --- /dev/null +++ b/xds/internal/balancer/clusterresolver/config.go @@ -0,0 +1,185 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package clusterresolver + +import ( + "bytes" + "encoding/json" + "fmt" + "strings" + + "google.golang.org/grpc/balancer/roundrobin" + internalserviceconfig "google.golang.org/grpc/internal/serviceconfig" + "google.golang.org/grpc/serviceconfig" + "google.golang.org/grpc/xds/internal/balancer/ringhash" +) + +// DiscoveryMechanismType is the type of discovery mechanism. +type DiscoveryMechanismType int + +const ( + // DiscoveryMechanismTypeEDS is eds. + DiscoveryMechanismTypeEDS DiscoveryMechanismType = iota // `json:"EDS"` + // DiscoveryMechanismTypeLogicalDNS is DNS. + DiscoveryMechanismTypeLogicalDNS // `json:"LOGICAL_DNS"` +) + +// MarshalJSON marshals a DiscoveryMechanismType to a quoted json string. +// +// This is necessary to handle enum (as strings) from JSON. +// +// Note that this needs to be defined on the type not pointer, otherwise the +// variables of this type will marshal to int not string. +func (t DiscoveryMechanismType) MarshalJSON() ([]byte, error) { + buffer := bytes.NewBufferString(`"`) + switch t { + case DiscoveryMechanismTypeEDS: + buffer.WriteString("EDS") + case DiscoveryMechanismTypeLogicalDNS: + buffer.WriteString("LOGICAL_DNS") + } + buffer.WriteString(`"`) + return buffer.Bytes(), nil +} + +// UnmarshalJSON unmarshals a quoted json string to the DiscoveryMechanismType. +func (t *DiscoveryMechanismType) UnmarshalJSON(b []byte) error { + var s string + err := json.Unmarshal(b, &s) + if err != nil { + return err + } + switch s { + case "EDS": + *t = DiscoveryMechanismTypeEDS + case "LOGICAL_DNS": + *t = DiscoveryMechanismTypeLogicalDNS + default: + return fmt.Errorf("unable to unmarshal string %q to type DiscoveryMechanismType", s) + } + return nil +} + +// DiscoveryMechanism is the discovery mechanism, can be either EDS or DNS. +// +// For DNS, the ClientConn target will be used for name resolution. +// +// For EDS, if EDSServiceName is not empty, it will be used for watching. If +// EDSServiceName is empty, Cluster will be used. +type DiscoveryMechanism struct { + // Cluster is the cluster name. + Cluster string `json:"cluster,omitempty"` + // LoadReportingServerName is the LRS server to send load reports to. If + // not present, load reporting will be disabled. If set to the empty string, + // load reporting will be sent to the same server that we obtained CDS data + // from. + LoadReportingServerName *string `json:"lrsLoadReportingServerName,omitempty"` + // MaxConcurrentRequests is the maximum number of outstanding requests can + // be made to the upstream cluster. Default is 1024. + MaxConcurrentRequests *uint32 `json:"maxConcurrentRequests,omitempty"` + // Type is the discovery mechanism type. + Type DiscoveryMechanismType `json:"type,omitempty"` + // EDSServiceName is the EDS service name, as returned in CDS. May be unset + // if not specified in CDS. For type EDS only. + // + // This is used for EDS watch if set. If unset, Cluster is used for EDS + // watch. + EDSServiceName string `json:"edsServiceName,omitempty"` + // DNSHostname is the DNS name to resolve in "host:port" form. For type + // LOGICAL_DNS only. + DNSHostname string `json:"dnsHostname,omitempty"` +} + +// Equal returns whether the DiscoveryMechanism is the same with the parameter. +func (dm DiscoveryMechanism) Equal(b DiscoveryMechanism) bool { + switch { + case dm.Cluster != b.Cluster: + return false + case !equalStringP(dm.LoadReportingServerName, b.LoadReportingServerName): + return false + case !equalUint32P(dm.MaxConcurrentRequests, b.MaxConcurrentRequests): + return false + case dm.Type != b.Type: + return false + case dm.EDSServiceName != b.EDSServiceName: + return false + case dm.DNSHostname != b.DNSHostname: + return false + } + return true +} + +func equalStringP(a, b *string) bool { + if a == nil && b == nil { + return true + } + if a == nil || b == nil { + return false + } + return *a == *b +} + +func equalUint32P(a, b *uint32) bool { + if a == nil && b == nil { + return true + } + if a == nil || b == nil { + return false + } + return *a == *b +} + +// LBConfig is the config for cluster resolver balancer. +type LBConfig struct { + serviceconfig.LoadBalancingConfig `json:"-"` + // DiscoveryMechanisms is an ordered list of discovery mechanisms. + // + // Must have at least one element. Results from each discovery mechanism are + // concatenated together in successive priorities. + DiscoveryMechanisms []DiscoveryMechanism `json:"discoveryMechanisms,omitempty"` + + // XDSLBPolicy specifies the policy for locality picking and endpoint picking. + // + // Note that it's not normal balancing policy, and it can only be either + // ROUND_ROBIN or RING_HASH. + // + // For ROUND_ROBIN, the policy name will be "ROUND_ROBIN", and the config + // will be empty. This sets the locality-picking policy to weighted_target + // and the endpoint-picking policy to round_robin. + // + // For RING_HASH, the policy name will be "RING_HASH", and the config will + // be lb config for the ring_hash_experimental LB Policy. ring_hash policy + // is responsible for both locality picking and endpoint picking. + XDSLBPolicy *internalserviceconfig.BalancerConfig `json:"xdsLbPolicy,omitempty"` +} + +const ( + rrName = roundrobin.Name + rhName = ringhash.Name +) + +func parseConfig(c json.RawMessage) (*LBConfig, error) { + var cfg LBConfig + if err := json.Unmarshal(c, &cfg); err != nil { + return nil, err + } + if lbp := cfg.XDSLBPolicy; lbp != nil && !strings.EqualFold(lbp.Name, rrName) && !strings.EqualFold(lbp.Name, rhName) { + return nil, fmt.Errorf("unsupported child policy with name %q, not one of {%q,%q}", lbp.Name, rrName, rhName) + } + return &cfg, nil +} diff --git a/xds/internal/balancer/clusterresolver/config_test.go b/xds/internal/balancer/clusterresolver/config_test.go new file mode 100644 index 00000000000..796f8a49372 --- /dev/null +++ b/xds/internal/balancer/clusterresolver/config_test.go @@ -0,0 +1,269 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package clusterresolver + +import ( + "encoding/json" + "testing" + + "github.com/google/go-cmp/cmp" + "google.golang.org/grpc/internal/balancer/stub" + internalserviceconfig "google.golang.org/grpc/internal/serviceconfig" + "google.golang.org/grpc/xds/internal/balancer/ringhash" +) + +func TestDiscoveryMechanismTypeMarshalJSON(t *testing.T) { + tests := []struct { + name string + typ DiscoveryMechanismType + want string + }{ + { + name: "eds", + typ: DiscoveryMechanismTypeEDS, + want: `"EDS"`, + }, + { + name: "dns", + typ: DiscoveryMechanismTypeLogicalDNS, + want: `"LOGICAL_DNS"`, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got, err := json.Marshal(tt.typ); err != nil || string(got) != tt.want { + t.Fatalf("DiscoveryMechanismTypeEDS.MarshalJSON() = (%v, %v), want (%s, nil)", string(got), err, tt.want) + } + }) + } +} +func TestDiscoveryMechanismTypeUnmarshalJSON(t *testing.T) { + tests := []struct { + name string + js string + want DiscoveryMechanismType + wantErr bool + }{ + { + name: "eds", + js: `"EDS"`, + want: DiscoveryMechanismTypeEDS, + }, + { + name: "dns", + js: `"LOGICAL_DNS"`, + want: DiscoveryMechanismTypeLogicalDNS, + }, + { + name: "error", + js: `"1234"`, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var got DiscoveryMechanismType + err := json.Unmarshal([]byte(tt.js), &got) + if (err != nil) != tt.wantErr { + t.Fatalf("DiscoveryMechanismTypeEDS.UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr) + } + if diff := cmp.Diff(got, tt.want); diff != "" { + t.Fatalf("DiscoveryMechanismTypeEDS.UnmarshalJSON() got unexpected output, diff (-got +want): %v", diff) + } + }) + } +} + +func init() { + // This is needed now for the config parsing tests to pass. Otherwise they + // will fail with "RING_HASH unsupported". + // + // TODO: delete this once ring-hash policy is implemented and imported. + stub.Register(rhName, stub.BalancerFuncs{}) +} + +const ( + testJSONConfig1 = `{ + "discoveryMechanisms": [{ + "cluster": "test-cluster-name", + "lrsLoadReportingServerName": "test-lrs-server", + "maxConcurrentRequests": 314, + "type": "EDS", + "edsServiceName": "test-eds-service-name" + }] +}` + testJSONConfig2 = `{ + "discoveryMechanisms": [{ + "cluster": "test-cluster-name", + "lrsLoadReportingServerName": "test-lrs-server", + "maxConcurrentRequests": 314, + "type": "EDS", + "edsServiceName": "test-eds-service-name" + },{ + "type": "LOGICAL_DNS" + }] +}` + testJSONConfig3 = `{ + "discoveryMechanisms": [{ + "cluster": "test-cluster-name", + "lrsLoadReportingServerName": "test-lrs-server", + "maxConcurrentRequests": 314, + "type": "EDS", + "edsServiceName": "test-eds-service-name" + }], + "xdsLbPolicy":[{"ROUND_ROBIN":{}}] +}` + testJSONConfig4 = `{ + "discoveryMechanisms": [{ + "cluster": "test-cluster-name", + "lrsLoadReportingServerName": "test-lrs-server", + "maxConcurrentRequests": 314, + "type": "EDS", + "edsServiceName": "test-eds-service-name" + }], + "xdsLbPolicy":[{"ring_hash_experimental":{}}] +}` + testJSONConfig5 = `{ + "discoveryMechanisms": [{ + "cluster": "test-cluster-name", + "lrsLoadReportingServerName": "test-lrs-server", + "maxConcurrentRequests": 314, + "type": "EDS", + "edsServiceName": "test-eds-service-name" + }], + "xdsLbPolicy":[{"pick_first":{}}] +}` +) + +func TestParseConfig(t *testing.T) { + tests := []struct { + name string + js string + want *LBConfig + wantErr bool + }{ + { + name: "empty json", + js: "", + want: nil, + wantErr: true, + }, + { + name: "OK with one discovery mechanism", + js: testJSONConfig1, + want: &LBConfig{ + DiscoveryMechanisms: []DiscoveryMechanism{ + { + Cluster: testClusterName, + LoadReportingServerName: newString(testLRSServer), + MaxConcurrentRequests: newUint32(testMaxRequests), + Type: DiscoveryMechanismTypeEDS, + EDSServiceName: testEDSServcie, + }, + }, + XDSLBPolicy: nil, + }, + wantErr: false, + }, + { + name: "OK with multiple discovery mechanisms", + js: testJSONConfig2, + want: &LBConfig{ + DiscoveryMechanisms: []DiscoveryMechanism{ + { + Cluster: testClusterName, + LoadReportingServerName: newString(testLRSServer), + MaxConcurrentRequests: newUint32(testMaxRequests), + Type: DiscoveryMechanismTypeEDS, + EDSServiceName: testEDSServcie, + }, + { + Type: DiscoveryMechanismTypeLogicalDNS, + }, + }, + XDSLBPolicy: nil, + }, + wantErr: false, + }, + { + name: "OK with picking policy round_robin", + js: testJSONConfig3, + want: &LBConfig{ + DiscoveryMechanisms: []DiscoveryMechanism{ + { + Cluster: testClusterName, + LoadReportingServerName: newString(testLRSServer), + MaxConcurrentRequests: newUint32(testMaxRequests), + Type: DiscoveryMechanismTypeEDS, + EDSServiceName: testEDSServcie, + }, + }, + XDSLBPolicy: &internalserviceconfig.BalancerConfig{ + Name: "ROUND_ROBIN", + Config: nil, + }, + }, + wantErr: false, + }, + { + name: "OK with picking policy ring_hash", + js: testJSONConfig4, + want: &LBConfig{ + DiscoveryMechanisms: []DiscoveryMechanism{ + { + Cluster: testClusterName, + LoadReportingServerName: newString(testLRSServer), + MaxConcurrentRequests: newUint32(testMaxRequests), + Type: DiscoveryMechanismTypeEDS, + EDSServiceName: testEDSServcie, + }, + }, + XDSLBPolicy: &internalserviceconfig.BalancerConfig{ + Name: ringhash.Name, + Config: nil, + }, + }, + wantErr: false, + }, + { + name: "unsupported picking policy", + js: testJSONConfig5, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := parseConfig([]byte(tt.js)) + if (err != nil) != tt.wantErr { + t.Fatalf("parseConfig() error = %v, wantErr %v", err, tt.wantErr) + } + if diff := cmp.Diff(got, tt.want); diff != "" { + t.Errorf("parseConfig() got unexpected output, diff (-got +want): %v", diff) + } + }) + } +} + +func newString(s string) *string { + return &s +} + +func newUint32(i uint32) *uint32 { + return &i +} diff --git a/xds/internal/balancer/clusterresolver/configbuilder.go b/xds/internal/balancer/clusterresolver/configbuilder.go new file mode 100644 index 00000000000..475497d4895 --- /dev/null +++ b/xds/internal/balancer/clusterresolver/configbuilder.go @@ -0,0 +1,364 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package clusterresolver + +import ( + "encoding/json" + "fmt" + "sort" + + "google.golang.org/grpc/balancer/roundrobin" + "google.golang.org/grpc/balancer/weightedroundrobin" + "google.golang.org/grpc/internal/hierarchy" + internalserviceconfig "google.golang.org/grpc/internal/serviceconfig" + "google.golang.org/grpc/resolver" + "google.golang.org/grpc/xds/internal" + "google.golang.org/grpc/xds/internal/balancer/clusterimpl" + "google.golang.org/grpc/xds/internal/balancer/priority" + "google.golang.org/grpc/xds/internal/balancer/ringhash" + "google.golang.org/grpc/xds/internal/balancer/weightedtarget" + "google.golang.org/grpc/xds/internal/xdsclient" +) + +const million = 1000000 + +// priorityConfig is config for one priority. For example, if there an EDS and a +// DNS, the priority list will be [priorityConfig{EDS}, priorityConfig{DNS}]. +// +// Each priorityConfig corresponds to one discovery mechanism from the LBConfig +// generated by the CDS balancer. The CDS balancer resolves the cluster name to +// an ordered list of discovery mechanisms (if the top cluster is an aggregated +// cluster), one for each underlying cluster. +type priorityConfig struct { + mechanism DiscoveryMechanism + // edsResp is set only if type is EDS. + edsResp xdsclient.EndpointsUpdate + // addresses is set only if type is DNS. + addresses []string +} + +// buildPriorityConfigJSON builds balancer config for the passed in +// priorities. +// +// The built tree of balancers (see test for the output struct). +// +// If xds lb policy is ROUND_ROBIN, the children will be weighted_target for +// locality picking, and round_robin for endpoint picking. +// +// ┌────────┐ +// │priority│ +// └┬──────┬┘ +// │ │ +// ┌───────────▼┐ ┌▼───────────┐ +// │cluster_impl│ │cluster_impl│ +// └─┬──────────┘ └──────────┬─┘ +// │ │ +// ┌──────────────▼─┐ ┌─▼──────────────┐ +// │locality_picking│ │locality_picking│ +// └┬──────────────┬┘ └┬──────────────┬┘ +// │ │ │ │ +// ┌─▼─┐ ┌─▼─┐ ┌─▼─┐ ┌─▼─┐ +// │LRS│ │LRS│ │LRS│ │LRS│ +// └─┬─┘ └─┬─┘ └─┬─┘ └─┬─┘ +// │ │ │ │ +// ┌──────────▼─────┐ ┌─────▼──────────┐ ┌──────────▼─────┐ ┌─────▼──────────┐ +// │endpoint_picking│ │endpoint_picking│ │endpoint_picking│ │endpoint_picking│ +// └────────────────┘ └────────────────┘ └────────────────┘ └────────────────┘ +// +// If xds lb policy is RING_HASH, the children will be just a ring_hash policy. +// The endpoints from all localities will be flattened to one addresses list, +// and the ring_hash policy will pick endpoints from it. +// +// ┌────────┐ +// │priority│ +// └┬──────┬┘ +// │ │ +// ┌──────────▼─┐ ┌─▼──────────┐ +// │cluster_impl│ │cluster_impl│ +// └──────┬─────┘ └─────┬──────┘ +// │ │ +// ┌──────▼─────┐ ┌─────▼──────┐ +// │ ring_hash │ │ ring_hash │ +// └────────────┘ └────────────┘ +// +// If endpointPickingPolicy is nil, roundrobin will be used. +// +// Custom locality picking policy isn't support, and weighted_target is always +// used. +func buildPriorityConfigJSON(priorities []priorityConfig, xdsLBPolicy *internalserviceconfig.BalancerConfig) ([]byte, []resolver.Address, error) { + pc, addrs, err := buildPriorityConfig(priorities, xdsLBPolicy) + if err != nil { + return nil, nil, fmt.Errorf("failed to build priority config: %v", err) + } + ret, err := json.Marshal(pc) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal built priority config struct into json: %v", err) + } + return ret, addrs, nil +} + +func buildPriorityConfig(priorities []priorityConfig, xdsLBPolicy *internalserviceconfig.BalancerConfig) (*priority.LBConfig, []resolver.Address, error) { + var ( + retConfig = &priority.LBConfig{Children: make(map[string]*priority.Child)} + retAddrs []resolver.Address + ) + for i, p := range priorities { + switch p.mechanism.Type { + case DiscoveryMechanismTypeEDS: + names, configs, addrs, err := buildClusterImplConfigForEDS(i, p.edsResp, p.mechanism, xdsLBPolicy) + if err != nil { + return nil, nil, err + } + retConfig.Priorities = append(retConfig.Priorities, names...) + for n, c := range configs { + retConfig.Children[n] = &priority.Child{ + Config: &internalserviceconfig.BalancerConfig{Name: clusterimpl.Name, Config: c}, + // Ignore all re-resolution from EDS children. + IgnoreReresolutionRequests: true, + } + } + retAddrs = append(retAddrs, addrs...) + case DiscoveryMechanismTypeLogicalDNS: + name, config, addrs := buildClusterImplConfigForDNS(i, p.addresses) + retConfig.Priorities = append(retConfig.Priorities, name) + retConfig.Children[name] = &priority.Child{ + Config: &internalserviceconfig.BalancerConfig{Name: clusterimpl.Name, Config: config}, + // Not ignore re-resolution from DNS children, they will trigger + // DNS to re-resolve. + IgnoreReresolutionRequests: false, + } + retAddrs = append(retAddrs, addrs...) + } + } + return retConfig, retAddrs, nil +} + +func buildClusterImplConfigForDNS(parentPriority int, addrStrs []string) (string, *clusterimpl.LBConfig, []resolver.Address) { + // Endpoint picking policy for DNS is hardcoded to pick_first. + const childPolicy = "pick_first" + retAddrs := make([]resolver.Address, 0, len(addrStrs)) + pName := fmt.Sprintf("priority-%v", parentPriority) + for _, addrStr := range addrStrs { + retAddrs = append(retAddrs, hierarchy.Set(resolver.Address{Addr: addrStr}, []string{pName})) + } + return pName, &clusterimpl.LBConfig{ChildPolicy: &internalserviceconfig.BalancerConfig{Name: childPolicy}}, retAddrs +} + +// buildClusterImplConfigForEDS returns a list of cluster_impl configs, one for +// each priority, sorted by priority, and the addresses for each priority (with +// hierarchy attributes set). +// +// For example, if there are two priorities, the returned values will be +// - ["p0", "p1"] +// - map{"p0":p0_config, "p1":p1_config} +// - [p0_address_0, p0_address_1, p1_address_0, p1_address_1] +// - p0 addresses' hierarchy attributes are set to p0 +func buildClusterImplConfigForEDS(parentPriority int, edsResp xdsclient.EndpointsUpdate, mechanism DiscoveryMechanism, xdsLBPolicy *internalserviceconfig.BalancerConfig) ([]string, map[string]*clusterimpl.LBConfig, []resolver.Address, error) { + drops := make([]clusterimpl.DropConfig, 0, len(edsResp.Drops)) + for _, d := range edsResp.Drops { + drops = append(drops, clusterimpl.DropConfig{ + Category: d.Category, + RequestsPerMillion: d.Numerator * million / d.Denominator, + }) + } + + priorityChildNames, priorities := groupLocalitiesByPriority(edsResp.Localities) + retNames := make([]string, 0, len(priorityChildNames)) + retAddrs := make([]resolver.Address, 0, len(priorityChildNames)) + retConfigs := make(map[string]*clusterimpl.LBConfig, len(priorityChildNames)) + for _, priorityName := range priorityChildNames { + priorityLocalities := priorities[priorityName] + // Prepend parent priority to the priority names, to avoid duplicates. + pName := fmt.Sprintf("priority-%v-%v", parentPriority, priorityName) + retNames = append(retNames, pName) + cfg, addrs, err := priorityLocalitiesToClusterImpl(priorityLocalities, pName, mechanism, drops, xdsLBPolicy) + if err != nil { + return nil, nil, nil, err + } + retConfigs[pName] = cfg + retAddrs = append(retAddrs, addrs...) + } + return retNames, retConfigs, retAddrs, nil +} + +// groupLocalitiesByPriority returns the localities grouped by priority. +// +// It also returns a list of strings where each string represents a priority, +// and the list is sorted from higher priority to lower priority. +// +// For example, for L0-p0, L1-p0, L2-p1, results will be +// - ["p0", "p1"] +// - map{"p0":[L0, L1], "p1":[L2]} +func groupLocalitiesByPriority(localities []xdsclient.Locality) ([]string, map[string][]xdsclient.Locality) { + var priorityIntSlice []int + priorities := make(map[string][]xdsclient.Locality) + for _, locality := range localities { + if locality.Weight == 0 { + continue + } + priorityName := fmt.Sprintf("%v", locality.Priority) + priorities[priorityName] = append(priorities[priorityName], locality) + priorityIntSlice = append(priorityIntSlice, int(locality.Priority)) + } + // Sort the priorities based on the int value, deduplicate, and then turn + // the sorted list into a string list. This will be child names, in priority + // order. + sort.Ints(priorityIntSlice) + priorityIntSliceDeduped := dedupSortedIntSlice(priorityIntSlice) + priorityNameSlice := make([]string, 0, len(priorityIntSliceDeduped)) + for _, p := range priorityIntSliceDeduped { + priorityNameSlice = append(priorityNameSlice, fmt.Sprintf("%v", p)) + } + return priorityNameSlice, priorities +} + +func dedupSortedIntSlice(a []int) []int { + if len(a) == 0 { + return a + } + i, j := 0, 1 + for ; j < len(a); j++ { + if a[i] == a[j] { + continue + } + i++ + if i != j { + a[i] = a[j] + } + } + return a[:i+1] +} + +// rrBalancerConfig is a const roundrobin config, used as child of +// weighted-roundrobin. To avoid allocating memory everytime. +var rrBalancerConfig = &internalserviceconfig.BalancerConfig{Name: roundrobin.Name} + +// priorityLocalitiesToClusterImpl takes a list of localities (with the same +// priority), and generates a cluster impl policy config, and a list of +// addresses. +func priorityLocalitiesToClusterImpl(localities []xdsclient.Locality, priorityName string, mechanism DiscoveryMechanism, drops []clusterimpl.DropConfig, xdsLBPolicy *internalserviceconfig.BalancerConfig) (*clusterimpl.LBConfig, []resolver.Address, error) { + clusterImplCfg := &clusterimpl.LBConfig{ + Cluster: mechanism.Cluster, + EDSServiceName: mechanism.EDSServiceName, + LoadReportingServerName: mechanism.LoadReportingServerName, + MaxConcurrentRequests: mechanism.MaxConcurrentRequests, + DropCategories: drops, + // ChildPolicy is not set. Will be set based on xdsLBPolicy + } + + if xdsLBPolicy == nil || xdsLBPolicy.Name == rrName { + // If lb policy is ROUND_ROBIN: + // - locality-picking policy is weighted_target + // - endpoint-picking policy is round_robin + logger.Infof("xds lb policy is %q, building config with weighted_target + round_robin", rrName) + // Child of weighted_target is hardcoded to round_robin. + wtConfig, addrs := localitiesToWeightedTarget(localities, priorityName, rrBalancerConfig) + clusterImplCfg.ChildPolicy = &internalserviceconfig.BalancerConfig{Name: weightedtarget.Name, Config: wtConfig} + return clusterImplCfg, addrs, nil + } + + if xdsLBPolicy.Name == rhName { + // If lb policy is RIHG_HASH, will build one ring_hash policy as child. + // The endpoints from all localities will be flattened to one addresses + // list, and the ring_hash policy will pick endpoints from it. + logger.Infof("xds lb policy is %q, building config with ring_hash", rhName) + addrs := localitiesToRingHash(localities, priorityName) + // Set child to ring_hash, note that the ring_hash config is from + // xdsLBPolicy. + clusterImplCfg.ChildPolicy = &internalserviceconfig.BalancerConfig{Name: ringhash.Name, Config: xdsLBPolicy.Config} + return clusterImplCfg, addrs, nil + } + + return nil, nil, fmt.Errorf("unsupported xds LB policy %q, not one of {%q,%q}", xdsLBPolicy.Name, rrName, rhName) +} + +// localitiesToRingHash takes a list of localities (with the same priority), and +// generates a list of addresses. +// +// The addresses have path hierarchy set to [priority-name], so priority knows +// which child policy they are for. +func localitiesToRingHash(localities []xdsclient.Locality, priorityName string) []resolver.Address { + var addrs []resolver.Address + for _, locality := range localities { + var lw uint32 = 1 + if locality.Weight != 0 { + lw = locality.Weight + } + localityStr, err := locality.ID.ToString() + if err != nil { + localityStr = fmt.Sprintf("%+v", locality.ID) + } + for _, endpoint := range locality.Endpoints { + // Filter out all "unhealthy" endpoints (unknown and healthy are + // both considered to be healthy: + // https://www.envoyproxy.io/docs/envoy/latest/api-v2/api/v2/core/health_check.proto#envoy-api-enum-core-healthstatus). + if endpoint.HealthStatus != xdsclient.EndpointHealthStatusHealthy && endpoint.HealthStatus != xdsclient.EndpointHealthStatusUnknown { + continue + } + + var ew uint32 = 1 + if endpoint.Weight != 0 { + ew = endpoint.Weight + } + + // The weight of each endpoint is locality_weight * endpoint_weight. + ai := weightedroundrobin.AddrInfo{Weight: lw * ew} + addr := weightedroundrobin.SetAddrInfo(resolver.Address{Addr: endpoint.Address}, ai) + addr = hierarchy.Set(addr, []string{priorityName, localityStr}) + addr = internal.SetLocalityID(addr, locality.ID) + addrs = append(addrs, addr) + } + } + return addrs +} + +// localitiesToWeightedTarget takes a list of localities (with the same +// priority), and generates a weighted target config, and list of addresses. +// +// The addresses have path hierarchy set to [priority-name, locality-name], so +// priority and weighted target know which child policy they are for. +func localitiesToWeightedTarget(localities []xdsclient.Locality, priorityName string, childPolicy *internalserviceconfig.BalancerConfig) (*weightedtarget.LBConfig, []resolver.Address) { + weightedTargets := make(map[string]weightedtarget.Target) + var addrs []resolver.Address + for _, locality := range localities { + localityStr, err := locality.ID.ToString() + if err != nil { + localityStr = fmt.Sprintf("%+v", locality.ID) + } + weightedTargets[localityStr] = weightedtarget.Target{Weight: locality.Weight, ChildPolicy: childPolicy} + for _, endpoint := range locality.Endpoints { + // Filter out all "unhealthy" endpoints (unknown and healthy are + // both considered to be healthy: + // https://www.envoyproxy.io/docs/envoy/latest/api-v2/api/v2/core/health_check.proto#envoy-api-enum-core-healthstatus). + if endpoint.HealthStatus != xdsclient.EndpointHealthStatusHealthy && endpoint.HealthStatus != xdsclient.EndpointHealthStatusUnknown { + continue + } + + addr := resolver.Address{Addr: endpoint.Address} + if childPolicy.Name == weightedroundrobin.Name && endpoint.Weight != 0 { + ai := weightedroundrobin.AddrInfo{Weight: endpoint.Weight} + addr = weightedroundrobin.SetAddrInfo(addr, ai) + } + addr = hierarchy.Set(addr, []string{priorityName, localityStr}) + addr = internal.SetLocalityID(addr, locality.ID) + addrs = append(addrs, addr) + } + } + return &weightedtarget.LBConfig{Targets: weightedTargets}, addrs +} diff --git a/xds/internal/balancer/clusterresolver/configbuilder_test.go b/xds/internal/balancer/clusterresolver/configbuilder_test.go new file mode 100644 index 00000000000..3e2ad8a2e64 --- /dev/null +++ b/xds/internal/balancer/clusterresolver/configbuilder_test.go @@ -0,0 +1,979 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package clusterresolver + +import ( + "bytes" + "encoding/json" + "fmt" + "sort" + "testing" + + "github.com/google/go-cmp/cmp" + "google.golang.org/grpc/attributes" + "google.golang.org/grpc/balancer" + "google.golang.org/grpc/balancer/roundrobin" + "google.golang.org/grpc/balancer/weightedroundrobin" + "google.golang.org/grpc/internal/hierarchy" + internalserviceconfig "google.golang.org/grpc/internal/serviceconfig" + "google.golang.org/grpc/resolver" + "google.golang.org/grpc/xds/internal" + "google.golang.org/grpc/xds/internal/balancer/clusterimpl" + "google.golang.org/grpc/xds/internal/balancer/priority" + "google.golang.org/grpc/xds/internal/balancer/ringhash" + "google.golang.org/grpc/xds/internal/balancer/weightedtarget" + "google.golang.org/grpc/xds/internal/xdsclient" +) + +const ( + testLRSServer = "test-lrs-server" + testMaxRequests = 314 + testEDSServiceName = "service-name-from-parent" + testDropCategory = "test-drops" + testDropOverMillion = 1 + + localityCount = 5 + addressPerLocality = 2 +) + +var ( + testLocalityIDs []internal.LocalityID + testAddressStrs [][]string + testEndpoints [][]xdsclient.Endpoint + + testLocalitiesP0, testLocalitiesP1 []xdsclient.Locality + + addrCmpOpts = cmp.Options{ + cmp.AllowUnexported(attributes.Attributes{}), + cmp.Transformer("SortAddrs", func(in []resolver.Address) []resolver.Address { + out := append([]resolver.Address(nil), in...) // Copy input to avoid mutating it + sort.Slice(out, func(i, j int) bool { + return out[i].Addr < out[j].Addr + }) + return out + })} +) + +func init() { + for i := 0; i < localityCount; i++ { + testLocalityIDs = append(testLocalityIDs, internal.LocalityID{Zone: fmt.Sprintf("test-zone-%d", i)}) + var ( + addrs []string + ends []xdsclient.Endpoint + ) + for j := 0; j < addressPerLocality; j++ { + addr := fmt.Sprintf("addr-%d-%d", i, j) + addrs = append(addrs, addr) + ends = append(ends, xdsclient.Endpoint{ + Address: addr, + HealthStatus: xdsclient.EndpointHealthStatusHealthy, + }) + } + testAddressStrs = append(testAddressStrs, addrs) + testEndpoints = append(testEndpoints, ends) + } + + testLocalitiesP0 = []xdsclient.Locality{ + { + Endpoints: testEndpoints[0], + ID: testLocalityIDs[0], + Weight: 20, + Priority: 0, + }, + { + Endpoints: testEndpoints[1], + ID: testLocalityIDs[1], + Weight: 80, + Priority: 0, + }, + } + testLocalitiesP1 = []xdsclient.Locality{ + { + Endpoints: testEndpoints[2], + ID: testLocalityIDs[2], + Weight: 20, + Priority: 1, + }, + { + Endpoints: testEndpoints[3], + ID: testLocalityIDs[3], + Weight: 80, + Priority: 1, + }, + } +} + +// TestBuildPriorityConfigJSON is a sanity check that the built balancer config +// can be parsed. The behavior test is covered by TestBuildPriorityConfig. +func TestBuildPriorityConfigJSON(t *testing.T) { + gotConfig, _, err := buildPriorityConfigJSON([]priorityConfig{ + { + mechanism: DiscoveryMechanism{ + Cluster: testClusterName, + LoadReportingServerName: newString(testLRSServer), + MaxConcurrentRequests: newUint32(testMaxRequests), + Type: DiscoveryMechanismTypeEDS, + EDSServiceName: testEDSServiceName, + }, + edsResp: xdsclient.EndpointsUpdate{ + Drops: []xdsclient.OverloadDropConfig{ + { + Category: testDropCategory, + Numerator: testDropOverMillion, + Denominator: million, + }, + }, + Localities: []xdsclient.Locality{ + testLocalitiesP0[0], + testLocalitiesP0[1], + testLocalitiesP1[0], + testLocalitiesP1[1], + }, + }, + }, + { + mechanism: DiscoveryMechanism{ + Type: DiscoveryMechanismTypeLogicalDNS, + }, + addresses: testAddressStrs[4], + }, + }, nil) + if err != nil { + t.Fatalf("buildPriorityConfigJSON(...) failed: %v", err) + } + + var prettyGot bytes.Buffer + if err := json.Indent(&prettyGot, gotConfig, ">>> ", " "); err != nil { + t.Fatalf("json.Indent() failed: %v", err) + } + // Print the indented json if this test fails. + t.Log(prettyGot.String()) + + priorityB := balancer.Get(priority.Name) + if _, err = priorityB.(balancer.ConfigParser).ParseConfig(gotConfig); err != nil { + t.Fatalf("ParseConfig(%+v) failed: %v", gotConfig, err) + } +} + +func TestBuildPriorityConfig(t *testing.T) { + gotConfig, gotAddrs, _ := buildPriorityConfig([]priorityConfig{ + { + mechanism: DiscoveryMechanism{ + Cluster: testClusterName, + LoadReportingServerName: newString(testLRSServer), + MaxConcurrentRequests: newUint32(testMaxRequests), + Type: DiscoveryMechanismTypeEDS, + EDSServiceName: testEDSServiceName, + }, + edsResp: xdsclient.EndpointsUpdate{ + Drops: []xdsclient.OverloadDropConfig{ + { + Category: testDropCategory, + Numerator: testDropOverMillion, + Denominator: million, + }, + }, + Localities: []xdsclient.Locality{ + testLocalitiesP0[0], + testLocalitiesP0[1], + testLocalitiesP1[0], + testLocalitiesP1[1], + }, + }, + }, + { + mechanism: DiscoveryMechanism{ + Type: DiscoveryMechanismTypeLogicalDNS, + }, + addresses: testAddressStrs[4], + }, + }, nil) + + wantConfig := &priority.LBConfig{ + Children: map[string]*priority.Child{ + "priority-0-0": { + Config: &internalserviceconfig.BalancerConfig{ + Name: clusterimpl.Name, + Config: &clusterimpl.LBConfig{ + Cluster: testClusterName, + EDSServiceName: testEDSServiceName, + LoadReportingServerName: newString(testLRSServer), + MaxConcurrentRequests: newUint32(testMaxRequests), + DropCategories: []clusterimpl.DropConfig{ + { + Category: testDropCategory, + RequestsPerMillion: testDropOverMillion, + }, + }, + ChildPolicy: &internalserviceconfig.BalancerConfig{ + Name: weightedtarget.Name, + Config: &weightedtarget.LBConfig{ + Targets: map[string]weightedtarget.Target{ + assertString(testLocalityIDs[0].ToString): { + Weight: 20, + ChildPolicy: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}, + }, + assertString(testLocalityIDs[1].ToString): { + Weight: 80, + ChildPolicy: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}, + }, + }, + }, + }, + }, + }, + IgnoreReresolutionRequests: true, + }, + "priority-0-1": { + Config: &internalserviceconfig.BalancerConfig{ + Name: clusterimpl.Name, + Config: &clusterimpl.LBConfig{ + Cluster: testClusterName, + EDSServiceName: testEDSServiceName, + LoadReportingServerName: newString(testLRSServer), + MaxConcurrentRequests: newUint32(testMaxRequests), + DropCategories: []clusterimpl.DropConfig{ + { + Category: testDropCategory, + RequestsPerMillion: testDropOverMillion, + }, + }, + ChildPolicy: &internalserviceconfig.BalancerConfig{ + Name: weightedtarget.Name, + Config: &weightedtarget.LBConfig{ + Targets: map[string]weightedtarget.Target{ + assertString(testLocalityIDs[2].ToString): { + Weight: 20, + ChildPolicy: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}, + }, + assertString(testLocalityIDs[3].ToString): { + Weight: 80, + ChildPolicy: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}, + }, + }, + }, + }, + }, + }, + IgnoreReresolutionRequests: true, + }, + "priority-1": { + Config: &internalserviceconfig.BalancerConfig{ + Name: clusterimpl.Name, + Config: &clusterimpl.LBConfig{ + ChildPolicy: &internalserviceconfig.BalancerConfig{Name: "pick_first"}, + }, + }, + IgnoreReresolutionRequests: false, + }, + }, + Priorities: []string{"priority-0-0", "priority-0-1", "priority-1"}, + } + wantAddrs := []resolver.Address{ + testAddrWithAttrs(testAddressStrs[0][0], nil, "priority-0-0", &testLocalityIDs[0]), + testAddrWithAttrs(testAddressStrs[0][1], nil, "priority-0-0", &testLocalityIDs[0]), + testAddrWithAttrs(testAddressStrs[1][0], nil, "priority-0-0", &testLocalityIDs[1]), + testAddrWithAttrs(testAddressStrs[1][1], nil, "priority-0-0", &testLocalityIDs[1]), + testAddrWithAttrs(testAddressStrs[2][0], nil, "priority-0-1", &testLocalityIDs[2]), + testAddrWithAttrs(testAddressStrs[2][1], nil, "priority-0-1", &testLocalityIDs[2]), + testAddrWithAttrs(testAddressStrs[3][0], nil, "priority-0-1", &testLocalityIDs[3]), + testAddrWithAttrs(testAddressStrs[3][1], nil, "priority-0-1", &testLocalityIDs[3]), + testAddrWithAttrs(testAddressStrs[4][0], nil, "priority-1", nil), + testAddrWithAttrs(testAddressStrs[4][1], nil, "priority-1", nil), + } + + if diff := cmp.Diff(gotConfig, wantConfig); diff != "" { + t.Errorf("buildPriorityConfig() diff (-got +want) %v", diff) + } + if diff := cmp.Diff(gotAddrs, wantAddrs, addrCmpOpts); diff != "" { + t.Errorf("buildPriorityConfig() diff (-got +want) %v", diff) + } +} + +func TestBuildClusterImplConfigForDNS(t *testing.T) { + gotName, gotConfig, gotAddrs := buildClusterImplConfigForDNS(3, testAddressStrs[0]) + wantName := "priority-3" + wantConfig := &clusterimpl.LBConfig{ + ChildPolicy: &internalserviceconfig.BalancerConfig{ + Name: "pick_first", + }, + } + wantAddrs := []resolver.Address{ + hierarchy.Set(resolver.Address{Addr: testAddressStrs[0][0]}, []string{"priority-3"}), + hierarchy.Set(resolver.Address{Addr: testAddressStrs[0][1]}, []string{"priority-3"}), + } + + if diff := cmp.Diff(gotName, wantName); diff != "" { + t.Errorf("buildClusterImplConfigForDNS() diff (-got +want) %v", diff) + } + if diff := cmp.Diff(gotConfig, wantConfig); diff != "" { + t.Errorf("buildClusterImplConfigForDNS() diff (-got +want) %v", diff) + } + if diff := cmp.Diff(gotAddrs, wantAddrs, addrCmpOpts); diff != "" { + t.Errorf("buildClusterImplConfigForDNS() diff (-got +want) %v", diff) + } +} + +func TestBuildClusterImplConfigForEDS(t *testing.T) { + gotNames, gotConfigs, gotAddrs, _ := buildClusterImplConfigForEDS( + 2, + xdsclient.EndpointsUpdate{ + Drops: []xdsclient.OverloadDropConfig{ + { + Category: testDropCategory, + Numerator: testDropOverMillion, + Denominator: million, + }, + }, + Localities: []xdsclient.Locality{ + { + Endpoints: testEndpoints[3], + ID: testLocalityIDs[3], + Weight: 80, + Priority: 1, + }, { + Endpoints: testEndpoints[1], + ID: testLocalityIDs[1], + Weight: 80, + Priority: 0, + }, { + Endpoints: testEndpoints[2], + ID: testLocalityIDs[2], + Weight: 20, + Priority: 1, + }, { + Endpoints: testEndpoints[0], + ID: testLocalityIDs[0], + Weight: 20, + Priority: 0, + }, + }, + }, + DiscoveryMechanism{ + Cluster: testClusterName, + MaxConcurrentRequests: newUint32(testMaxRequests), + LoadReportingServerName: newString(testLRSServer), + Type: DiscoveryMechanismTypeEDS, + EDSServiceName: testEDSServiceName, + }, + nil, + ) + + wantNames := []string{ + fmt.Sprintf("priority-%v-%v", 2, 0), + fmt.Sprintf("priority-%v-%v", 2, 1), + } + wantConfigs := map[string]*clusterimpl.LBConfig{ + "priority-2-0": { + Cluster: testClusterName, + EDSServiceName: testEDSServiceName, + LoadReportingServerName: newString(testLRSServer), + MaxConcurrentRequests: newUint32(testMaxRequests), + DropCategories: []clusterimpl.DropConfig{ + { + Category: testDropCategory, + RequestsPerMillion: testDropOverMillion, + }, + }, + ChildPolicy: &internalserviceconfig.BalancerConfig{ + Name: weightedtarget.Name, + Config: &weightedtarget.LBConfig{ + Targets: map[string]weightedtarget.Target{ + assertString(testLocalityIDs[0].ToString): { + Weight: 20, + ChildPolicy: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}, + }, + assertString(testLocalityIDs[1].ToString): { + Weight: 80, + ChildPolicy: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}, + }, + }, + }, + }, + }, + "priority-2-1": { + Cluster: testClusterName, + EDSServiceName: testEDSServiceName, + LoadReportingServerName: newString(testLRSServer), + MaxConcurrentRequests: newUint32(testMaxRequests), + DropCategories: []clusterimpl.DropConfig{ + { + Category: testDropCategory, + RequestsPerMillion: testDropOverMillion, + }, + }, + ChildPolicy: &internalserviceconfig.BalancerConfig{ + Name: weightedtarget.Name, + Config: &weightedtarget.LBConfig{ + Targets: map[string]weightedtarget.Target{ + assertString(testLocalityIDs[2].ToString): { + Weight: 20, + ChildPolicy: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}, + }, + assertString(testLocalityIDs[3].ToString): { + Weight: 80, + ChildPolicy: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}, + }, + }, + }, + }, + }, + } + wantAddrs := []resolver.Address{ + testAddrWithAttrs(testAddressStrs[0][0], nil, "priority-2-0", &testLocalityIDs[0]), + testAddrWithAttrs(testAddressStrs[0][1], nil, "priority-2-0", &testLocalityIDs[0]), + testAddrWithAttrs(testAddressStrs[1][0], nil, "priority-2-0", &testLocalityIDs[1]), + testAddrWithAttrs(testAddressStrs[1][1], nil, "priority-2-0", &testLocalityIDs[1]), + testAddrWithAttrs(testAddressStrs[2][0], nil, "priority-2-1", &testLocalityIDs[2]), + testAddrWithAttrs(testAddressStrs[2][1], nil, "priority-2-1", &testLocalityIDs[2]), + testAddrWithAttrs(testAddressStrs[3][0], nil, "priority-2-1", &testLocalityIDs[3]), + testAddrWithAttrs(testAddressStrs[3][1], nil, "priority-2-1", &testLocalityIDs[3]), + } + + if diff := cmp.Diff(gotNames, wantNames); diff != "" { + t.Errorf("buildClusterImplConfigForEDS() diff (-got +want) %v", diff) + } + if diff := cmp.Diff(gotConfigs, wantConfigs); diff != "" { + t.Errorf("buildClusterImplConfigForEDS() diff (-got +want) %v", diff) + } + if diff := cmp.Diff(gotAddrs, wantAddrs, addrCmpOpts); diff != "" { + t.Errorf("buildClusterImplConfigForEDS() diff (-got +want) %v", diff) + } + +} + +func TestGroupLocalitiesByPriority(t *testing.T) { + tests := []struct { + name string + localities []xdsclient.Locality + wantPriorities []string + wantLocalities map[string][]xdsclient.Locality + }{ + { + name: "1 locality 1 priority", + localities: []xdsclient.Locality{testLocalitiesP0[0]}, + wantPriorities: []string{"0"}, + wantLocalities: map[string][]xdsclient.Locality{ + "0": {testLocalitiesP0[0]}, + }, + }, + { + name: "2 locality 1 priority", + localities: []xdsclient.Locality{testLocalitiesP0[0], testLocalitiesP0[1]}, + wantPriorities: []string{"0"}, + wantLocalities: map[string][]xdsclient.Locality{ + "0": {testLocalitiesP0[0], testLocalitiesP0[1]}, + }, + }, + { + name: "1 locality in each", + localities: []xdsclient.Locality{testLocalitiesP0[0], testLocalitiesP1[0]}, + wantPriorities: []string{"0", "1"}, + wantLocalities: map[string][]xdsclient.Locality{ + "0": {testLocalitiesP0[0]}, + "1": {testLocalitiesP1[0]}, + }, + }, + { + name: "2 localities in each sorted", + localities: []xdsclient.Locality{ + testLocalitiesP0[0], testLocalitiesP0[1], + testLocalitiesP1[0], testLocalitiesP1[1]}, + wantPriorities: []string{"0", "1"}, + wantLocalities: map[string][]xdsclient.Locality{ + "0": {testLocalitiesP0[0], testLocalitiesP0[1]}, + "1": {testLocalitiesP1[0], testLocalitiesP1[1]}, + }, + }, + { + // The localities are given in order [p1, p0, p1, p0], but the + // returned priority list must be sorted [p0, p1], because the list + // order is the priority order. + name: "2 localities in each needs to sort", + localities: []xdsclient.Locality{ + testLocalitiesP1[1], testLocalitiesP0[1], + testLocalitiesP1[0], testLocalitiesP0[0]}, + wantPriorities: []string{"0", "1"}, + wantLocalities: map[string][]xdsclient.Locality{ + "0": {testLocalitiesP0[1], testLocalitiesP0[0]}, + "1": {testLocalitiesP1[1], testLocalitiesP1[0]}, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotPriorities, gotLocalities := groupLocalitiesByPriority(tt.localities) + if diff := cmp.Diff(gotPriorities, tt.wantPriorities); diff != "" { + t.Errorf("groupLocalitiesByPriority() diff(-got +want) %v", diff) + } + if diff := cmp.Diff(gotLocalities, tt.wantLocalities); diff != "" { + t.Errorf("groupLocalitiesByPriority() diff(-got +want) %v", diff) + } + }) + } +} + +func TestDedupSortedIntSlice(t *testing.T) { + tests := []struct { + name string + a []int + want []int + }{ + { + name: "empty", + a: []int{}, + want: []int{}, + }, + { + name: "no dup", + a: []int{0, 1, 2, 3}, + want: []int{0, 1, 2, 3}, + }, + { + name: "with dup", + a: []int{0, 0, 1, 1, 1, 2, 3}, + want: []int{0, 1, 2, 3}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := dedupSortedIntSlice(tt.a); !cmp.Equal(got, tt.want) { + t.Errorf("dedupSortedIntSlice() = %v, want %v, diff %v", got, tt.want, cmp.Diff(got, tt.want)) + } + }) + } +} + +func TestPriorityLocalitiesToClusterImpl(t *testing.T) { + tests := []struct { + name string + localities []xdsclient.Locality + priorityName string + mechanism DiscoveryMechanism + childPolicy *internalserviceconfig.BalancerConfig + wantConfig *clusterimpl.LBConfig + wantAddrs []resolver.Address + wantErr bool + }{{ + name: "round robin as child, no LRS", + localities: []xdsclient.Locality{ + { + Endpoints: []xdsclient.Endpoint{ + {Address: "addr-1-1", HealthStatus: xdsclient.EndpointHealthStatusHealthy, Weight: 90}, + {Address: "addr-1-2", HealthStatus: xdsclient.EndpointHealthStatusHealthy, Weight: 10}, + }, + ID: internal.LocalityID{Zone: "test-zone-1"}, + Weight: 20, + }, + { + Endpoints: []xdsclient.Endpoint{ + {Address: "addr-2-1", HealthStatus: xdsclient.EndpointHealthStatusHealthy, Weight: 90}, + {Address: "addr-2-2", HealthStatus: xdsclient.EndpointHealthStatusHealthy, Weight: 10}, + }, + ID: internal.LocalityID{Zone: "test-zone-2"}, + Weight: 80, + }, + }, + priorityName: "test-priority", + childPolicy: &internalserviceconfig.BalancerConfig{Name: rrName}, + mechanism: DiscoveryMechanism{ + Cluster: testClusterName, + Type: DiscoveryMechanismTypeEDS, + EDSServiceName: testEDSServcie, + }, + // lrsServer is nil, so LRS policy will not be used. + wantConfig: &clusterimpl.LBConfig{ + Cluster: testClusterName, + EDSServiceName: testEDSServcie, + ChildPolicy: &internalserviceconfig.BalancerConfig{ + Name: weightedtarget.Name, + Config: &weightedtarget.LBConfig{ + Targets: map[string]weightedtarget.Target{ + assertString(internal.LocalityID{Zone: "test-zone-1"}.ToString): { + Weight: 20, + ChildPolicy: &internalserviceconfig.BalancerConfig{ + Name: roundrobin.Name, + }, + }, + assertString(internal.LocalityID{Zone: "test-zone-2"}.ToString): { + Weight: 80, + ChildPolicy: &internalserviceconfig.BalancerConfig{ + Name: roundrobin.Name, + }, + }, + }, + }, + }, + }, + wantAddrs: []resolver.Address{ + testAddrWithAttrs("addr-1-1", nil, "test-priority", &internal.LocalityID{Zone: "test-zone-1"}), + testAddrWithAttrs("addr-1-2", nil, "test-priority", &internal.LocalityID{Zone: "test-zone-1"}), + testAddrWithAttrs("addr-2-1", nil, "test-priority", &internal.LocalityID{Zone: "test-zone-2"}), + testAddrWithAttrs("addr-2-2", nil, "test-priority", &internal.LocalityID{Zone: "test-zone-2"}), + }, + }, + { + name: "ring_hash as child", + localities: []xdsclient.Locality{ + { + Endpoints: []xdsclient.Endpoint{ + {Address: "addr-1-1", HealthStatus: xdsclient.EndpointHealthStatusHealthy, Weight: 90}, + {Address: "addr-1-2", HealthStatus: xdsclient.EndpointHealthStatusHealthy, Weight: 10}, + }, + ID: internal.LocalityID{Zone: "test-zone-1"}, + Weight: 20, + }, + { + Endpoints: []xdsclient.Endpoint{ + {Address: "addr-2-1", HealthStatus: xdsclient.EndpointHealthStatusHealthy, Weight: 90}, + {Address: "addr-2-2", HealthStatus: xdsclient.EndpointHealthStatusHealthy, Weight: 10}, + }, + ID: internal.LocalityID{Zone: "test-zone-2"}, + Weight: 80, + }, + }, + priorityName: "test-priority", + childPolicy: &internalserviceconfig.BalancerConfig{Name: rhName, Config: &ringhash.LBConfig{MinRingSize: 1, MaxRingSize: 2}}, + // lrsServer is nil, so LRS policy will not be used. + wantConfig: &clusterimpl.LBConfig{ + ChildPolicy: &internalserviceconfig.BalancerConfig{ + Name: ringhash.Name, + Config: &ringhash.LBConfig{MinRingSize: 1, MaxRingSize: 2}, + }, + }, + wantAddrs: []resolver.Address{ + testAddrWithAttrs("addr-1-1", newUint32(1800), "test-priority", &internal.LocalityID{Zone: "test-zone-1"}), + testAddrWithAttrs("addr-1-2", newUint32(200), "test-priority", &internal.LocalityID{Zone: "test-zone-1"}), + testAddrWithAttrs("addr-2-1", newUint32(7200), "test-priority", &internal.LocalityID{Zone: "test-zone-2"}), + testAddrWithAttrs("addr-2-2", newUint32(800), "test-priority", &internal.LocalityID{Zone: "test-zone-2"}), + }, + }, + { + name: "unsupported child", + localities: []xdsclient.Locality{{ + Endpoints: []xdsclient.Endpoint{ + {Address: "addr-1-1", HealthStatus: xdsclient.EndpointHealthStatusHealthy, Weight: 90}, + {Address: "addr-1-2", HealthStatus: xdsclient.EndpointHealthStatusHealthy, Weight: 10}, + }, + ID: internal.LocalityID{Zone: "test-zone-1"}, + Weight: 20, + }}, + priorityName: "test-priority", + childPolicy: &internalserviceconfig.BalancerConfig{Name: "some-child"}, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, got1, err := priorityLocalitiesToClusterImpl(tt.localities, tt.priorityName, tt.mechanism, nil, tt.childPolicy) + if (err != nil) != tt.wantErr { + t.Fatalf("priorityLocalitiesToClusterImpl() error = %v, wantErr %v", err, tt.wantErr) + } + if diff := cmp.Diff(got, tt.wantConfig); diff != "" { + t.Errorf("localitiesToWeightedTarget() diff (-got +want) %v", diff) + } + if diff := cmp.Diff(got1, tt.wantAddrs, cmp.AllowUnexported(attributes.Attributes{})); diff != "" { + t.Errorf("localitiesToWeightedTarget() diff (-got +want) %v", diff) + } + }) + } +} + +func TestLocalitiesToWeightedTarget(t *testing.T) { + tests := []struct { + name string + localities []xdsclient.Locality + priorityName string + childPolicy *internalserviceconfig.BalancerConfig + lrsServer *string + wantConfig *weightedtarget.LBConfig + wantAddrs []resolver.Address + }{ + { + name: "roundrobin as child, with LRS", + localities: []xdsclient.Locality{ + { + Endpoints: []xdsclient.Endpoint{ + {Address: "addr-1-1", HealthStatus: xdsclient.EndpointHealthStatusHealthy}, + {Address: "addr-1-2", HealthStatus: xdsclient.EndpointHealthStatusHealthy}, + }, + ID: internal.LocalityID{Zone: "test-zone-1"}, + Weight: 20, + }, + { + Endpoints: []xdsclient.Endpoint{ + {Address: "addr-2-1", HealthStatus: xdsclient.EndpointHealthStatusHealthy}, + {Address: "addr-2-2", HealthStatus: xdsclient.EndpointHealthStatusHealthy}, + }, + ID: internal.LocalityID{Zone: "test-zone-2"}, + Weight: 80, + }, + }, + priorityName: "test-priority", + childPolicy: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}, + lrsServer: newString("test-lrs-server"), + wantConfig: &weightedtarget.LBConfig{ + Targets: map[string]weightedtarget.Target{ + assertString(internal.LocalityID{Zone: "test-zone-1"}.ToString): { + Weight: 20, + ChildPolicy: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}, + }, + assertString(internal.LocalityID{Zone: "test-zone-2"}.ToString): { + Weight: 80, + ChildPolicy: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}, + }, + }, + }, + wantAddrs: []resolver.Address{ + testAddrWithAttrs("addr-1-1", nil, "test-priority", &internal.LocalityID{Zone: "test-zone-1"}), + testAddrWithAttrs("addr-1-2", nil, "test-priority", &internal.LocalityID{Zone: "test-zone-1"}), + testAddrWithAttrs("addr-2-1", nil, "test-priority", &internal.LocalityID{Zone: "test-zone-2"}), + testAddrWithAttrs("addr-2-2", nil, "test-priority", &internal.LocalityID{Zone: "test-zone-2"}), + }, + }, + { + name: "roundrobin as child, no LRS", + localities: []xdsclient.Locality{ + { + Endpoints: []xdsclient.Endpoint{ + {Address: "addr-1-1", HealthStatus: xdsclient.EndpointHealthStatusHealthy}, + {Address: "addr-1-2", HealthStatus: xdsclient.EndpointHealthStatusHealthy}, + }, + ID: internal.LocalityID{Zone: "test-zone-1"}, + Weight: 20, + }, + { + Endpoints: []xdsclient.Endpoint{ + {Address: "addr-2-1", HealthStatus: xdsclient.EndpointHealthStatusHealthy}, + {Address: "addr-2-2", HealthStatus: xdsclient.EndpointHealthStatusHealthy}, + }, + ID: internal.LocalityID{Zone: "test-zone-2"}, + Weight: 80, + }, + }, + priorityName: "test-priority", + childPolicy: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}, + // lrsServer is nil, so LRS policy will not be used. + wantConfig: &weightedtarget.LBConfig{ + Targets: map[string]weightedtarget.Target{ + assertString(internal.LocalityID{Zone: "test-zone-1"}.ToString): { + Weight: 20, + ChildPolicy: &internalserviceconfig.BalancerConfig{ + Name: roundrobin.Name, + }, + }, + assertString(internal.LocalityID{Zone: "test-zone-2"}.ToString): { + Weight: 80, + ChildPolicy: &internalserviceconfig.BalancerConfig{ + Name: roundrobin.Name, + }, + }, + }, + }, + wantAddrs: []resolver.Address{ + testAddrWithAttrs("addr-1-1", nil, "test-priority", &internal.LocalityID{Zone: "test-zone-1"}), + testAddrWithAttrs("addr-1-2", nil, "test-priority", &internal.LocalityID{Zone: "test-zone-1"}), + testAddrWithAttrs("addr-2-1", nil, "test-priority", &internal.LocalityID{Zone: "test-zone-2"}), + testAddrWithAttrs("addr-2-2", nil, "test-priority", &internal.LocalityID{Zone: "test-zone-2"}), + }, + }, + { + name: "weighted round robin as child, no LRS", + localities: []xdsclient.Locality{ + { + Endpoints: []xdsclient.Endpoint{ + {Address: "addr-1-1", HealthStatus: xdsclient.EndpointHealthStatusHealthy, Weight: 90}, + {Address: "addr-1-2", HealthStatus: xdsclient.EndpointHealthStatusHealthy, Weight: 10}, + }, + ID: internal.LocalityID{Zone: "test-zone-1"}, + Weight: 20, + }, + { + Endpoints: []xdsclient.Endpoint{ + {Address: "addr-2-1", HealthStatus: xdsclient.EndpointHealthStatusHealthy, Weight: 90}, + {Address: "addr-2-2", HealthStatus: xdsclient.EndpointHealthStatusHealthy, Weight: 10}, + }, + ID: internal.LocalityID{Zone: "test-zone-2"}, + Weight: 80, + }, + }, + priorityName: "test-priority", + childPolicy: &internalserviceconfig.BalancerConfig{Name: weightedroundrobin.Name}, + // lrsServer is nil, so LRS policy will not be used. + wantConfig: &weightedtarget.LBConfig{ + Targets: map[string]weightedtarget.Target{ + assertString(internal.LocalityID{Zone: "test-zone-1"}.ToString): { + Weight: 20, + ChildPolicy: &internalserviceconfig.BalancerConfig{ + Name: weightedroundrobin.Name, + }, + }, + assertString(internal.LocalityID{Zone: "test-zone-2"}.ToString): { + Weight: 80, + ChildPolicy: &internalserviceconfig.BalancerConfig{ + Name: weightedroundrobin.Name, + }, + }, + }, + }, + wantAddrs: []resolver.Address{ + testAddrWithAttrs("addr-1-1", newUint32(90), "test-priority", &internal.LocalityID{Zone: "test-zone-1"}), + testAddrWithAttrs("addr-1-2", newUint32(10), "test-priority", &internal.LocalityID{Zone: "test-zone-1"}), + testAddrWithAttrs("addr-2-1", newUint32(90), "test-priority", &internal.LocalityID{Zone: "test-zone-2"}), + testAddrWithAttrs("addr-2-2", newUint32(10), "test-priority", &internal.LocalityID{Zone: "test-zone-2"}), + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, got1 := localitiesToWeightedTarget(tt.localities, tt.priorityName, tt.childPolicy) + if diff := cmp.Diff(got, tt.wantConfig); diff != "" { + t.Errorf("localitiesToWeightedTarget() diff (-got +want) %v", diff) + } + if diff := cmp.Diff(got1, tt.wantAddrs, cmp.AllowUnexported(attributes.Attributes{})); diff != "" { + t.Errorf("localitiesToWeightedTarget() diff (-got +want) %v", diff) + } + }) + } +} + +func TestLocalitiesToRingHash(t *testing.T) { + tests := []struct { + name string + localities []xdsclient.Locality + priorityName string + wantAddrs []resolver.Address + }{ + { + // Check that address weights are locality_weight * endpoint_weight. + name: "with locality and endpoint weight", + localities: []xdsclient.Locality{ + { + Endpoints: []xdsclient.Endpoint{ + {Address: "addr-1-1", HealthStatus: xdsclient.EndpointHealthStatusHealthy, Weight: 90}, + {Address: "addr-1-2", HealthStatus: xdsclient.EndpointHealthStatusHealthy, Weight: 10}, + }, + ID: internal.LocalityID{Zone: "test-zone-1"}, + Weight: 20, + }, + { + Endpoints: []xdsclient.Endpoint{ + {Address: "addr-2-1", HealthStatus: xdsclient.EndpointHealthStatusHealthy, Weight: 90}, + {Address: "addr-2-2", HealthStatus: xdsclient.EndpointHealthStatusHealthy, Weight: 10}, + }, + ID: internal.LocalityID{Zone: "test-zone-2"}, + Weight: 80, + }, + }, + priorityName: "test-priority", + wantAddrs: []resolver.Address{ + testAddrWithAttrs("addr-1-1", newUint32(1800), "test-priority", &internal.LocalityID{Zone: "test-zone-1"}), + testAddrWithAttrs("addr-1-2", newUint32(200), "test-priority", &internal.LocalityID{Zone: "test-zone-1"}), + testAddrWithAttrs("addr-2-1", newUint32(7200), "test-priority", &internal.LocalityID{Zone: "test-zone-2"}), + testAddrWithAttrs("addr-2-2", newUint32(800), "test-priority", &internal.LocalityID{Zone: "test-zone-2"}), + }, + }, + { + // Check that endpoint_weight is 0, weight is the locality weight. + name: "locality weight only", + localities: []xdsclient.Locality{ + { + Endpoints: []xdsclient.Endpoint{ + {Address: "addr-1-1", HealthStatus: xdsclient.EndpointHealthStatusHealthy}, + {Address: "addr-1-2", HealthStatus: xdsclient.EndpointHealthStatusHealthy}, + }, + ID: internal.LocalityID{Zone: "test-zone-1"}, + Weight: 20, + }, + { + Endpoints: []xdsclient.Endpoint{ + {Address: "addr-2-1", HealthStatus: xdsclient.EndpointHealthStatusHealthy}, + {Address: "addr-2-2", HealthStatus: xdsclient.EndpointHealthStatusHealthy}, + }, + ID: internal.LocalityID{Zone: "test-zone-2"}, + Weight: 80, + }, + }, + priorityName: "test-priority", + wantAddrs: []resolver.Address{ + testAddrWithAttrs("addr-1-1", newUint32(20), "test-priority", &internal.LocalityID{Zone: "test-zone-1"}), + testAddrWithAttrs("addr-1-2", newUint32(20), "test-priority", &internal.LocalityID{Zone: "test-zone-1"}), + testAddrWithAttrs("addr-2-1", newUint32(80), "test-priority", &internal.LocalityID{Zone: "test-zone-2"}), + testAddrWithAttrs("addr-2-2", newUint32(80), "test-priority", &internal.LocalityID{Zone: "test-zone-2"}), + }, + }, + { + // Check that locality_weight is 0, weight is the endpoint weight. + name: "endpoint weight only", + localities: []xdsclient.Locality{ + { + Endpoints: []xdsclient.Endpoint{ + {Address: "addr-1-1", HealthStatus: xdsclient.EndpointHealthStatusHealthy, Weight: 90}, + {Address: "addr-1-2", HealthStatus: xdsclient.EndpointHealthStatusHealthy, Weight: 10}, + }, + ID: internal.LocalityID{Zone: "test-zone-1"}, + }, + { + Endpoints: []xdsclient.Endpoint{ + {Address: "addr-2-1", HealthStatus: xdsclient.EndpointHealthStatusHealthy, Weight: 90}, + {Address: "addr-2-2", HealthStatus: xdsclient.EndpointHealthStatusHealthy, Weight: 10}, + }, + ID: internal.LocalityID{Zone: "test-zone-2"}, + }, + }, + priorityName: "test-priority", + wantAddrs: []resolver.Address{ + testAddrWithAttrs("addr-1-1", newUint32(90), "test-priority", &internal.LocalityID{Zone: "test-zone-1"}), + testAddrWithAttrs("addr-1-2", newUint32(10), "test-priority", &internal.LocalityID{Zone: "test-zone-1"}), + testAddrWithAttrs("addr-2-1", newUint32(90), "test-priority", &internal.LocalityID{Zone: "test-zone-2"}), + testAddrWithAttrs("addr-2-2", newUint32(10), "test-priority", &internal.LocalityID{Zone: "test-zone-2"}), + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := localitiesToRingHash(tt.localities, tt.priorityName) + if diff := cmp.Diff(got, tt.wantAddrs, cmp.AllowUnexported(attributes.Attributes{})); diff != "" { + t.Errorf("localitiesToWeightedTarget() diff (-got +want) %v", diff) + } + }) + } +} + +func assertString(f func() (string, error)) string { + s, err := f() + if err != nil { + panic(err.Error()) + } + return s +} + +func testAddrWithAttrs(addrStr string, weight *uint32, priority string, lID *internal.LocalityID) resolver.Address { + addr := resolver.Address{Addr: addrStr} + if weight != nil { + addr = weightedroundrobin.SetAddrInfo(addr, weightedroundrobin.AddrInfo{Weight: *weight}) + } + path := []string{priority} + if lID != nil { + path = append(path, assertString(lID.ToString)) + addr = internal.SetLocalityID(addr, *lID) + } + addr = hierarchy.Set(addr, path) + return addr +} diff --git a/xds/internal/balancer/clusterresolver/eds_impl_test.go b/xds/internal/balancer/clusterresolver/eds_impl_test.go new file mode 100644 index 00000000000..00814a6212b --- /dev/null +++ b/xds/internal/balancer/clusterresolver/eds_impl_test.go @@ -0,0 +1,575 @@ +/* + * Copyright 2019 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package clusterresolver + +import ( + "context" + "fmt" + "sort" + "testing" + "time" + + corepb "github.com/envoyproxy/go-control-plane/envoy/api/v2/core" + "github.com/google/go-cmp/cmp" + "google.golang.org/grpc/balancer" + "google.golang.org/grpc/connectivity" + internalserviceconfig "google.golang.org/grpc/internal/serviceconfig" + "google.golang.org/grpc/resolver" + "google.golang.org/grpc/xds/internal/balancer/balancergroup" + "google.golang.org/grpc/xds/internal/balancer/clusterimpl" + "google.golang.org/grpc/xds/internal/balancer/priority" + "google.golang.org/grpc/xds/internal/balancer/weightedtarget" + "google.golang.org/grpc/xds/internal/testutils" + "google.golang.org/grpc/xds/internal/testutils/fakeclient" + "google.golang.org/grpc/xds/internal/xdsclient" +) + +var ( + testClusterNames = []string{"test-cluster-1", "test-cluster-2"} + testSubZones = []string{"I", "II", "III", "IV"} + testEndpointAddrs []string +) + +const testBackendAddrsCount = 12 + +func init() { + for i := 0; i < testBackendAddrsCount; i++ { + testEndpointAddrs = append(testEndpointAddrs, fmt.Sprintf("%d.%d.%d.%d:%d", i, i, i, i, i)) + } + balancergroup.DefaultSubBalancerCloseTimeout = time.Millisecond + clusterimpl.NewRandomWRR = testutils.NewTestWRR + weightedtarget.NewRandomWRR = testutils.NewTestWRR + balancergroup.DefaultSubBalancerCloseTimeout = time.Millisecond * 100 +} + +func setupTestEDS(t *testing.T, initChild *internalserviceconfig.BalancerConfig) (balancer.Balancer, *testutils.TestClientConn, *fakeclient.Client, func()) { + xdsC := fakeclient.NewClientWithName(testBalancerNameFooBar) + cc := testutils.NewTestClientConn(t) + builder := balancer.Get(Name) + edsb := builder.Build(cc, balancer.BuildOptions{Target: resolver.Target{Endpoint: testEDSServcie}}) + if edsb == nil { + t.Fatalf("builder.Build(%s) failed and returned nil", Name) + } + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + if err := edsb.UpdateClientConnState(balancer.ClientConnState{ + ResolverState: xdsclient.SetClient(resolver.State{}, xdsC), + BalancerConfig: &LBConfig{ + DiscoveryMechanisms: []DiscoveryMechanism{{ + Cluster: testClusterName, + Type: DiscoveryMechanismTypeEDS, + }}, + }, + }); err != nil { + edsb.Close() + xdsC.Close() + t.Fatal(err) + } + if _, err := xdsC.WaitForWatchEDS(ctx); err != nil { + edsb.Close() + xdsC.Close() + t.Fatalf("xdsClient.WatchEndpoints failed with error: %v", err) + } + return edsb, cc, xdsC, func() { + edsb.Close() + xdsC.Close() + } +} + +// One locality +// - add backend +// - remove backend +// - replace backend +// - change drop rate +func (s) TestEDS_OneLocality(t *testing.T) { + edsb, cc, xdsC, cleanup := setupTestEDS(t, nil) + defer cleanup() + + // One locality with one backend. + clab1 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil) + clab1.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[:1], nil) + xdsC.InvokeWatchEDSCallback("", parseEDSRespProtoForTesting(clab1.Build()), nil) + + sc1 := <-cc.NewSubConnCh + edsb.UpdateSubConnState(sc1, balancer.SubConnState{ConnectivityState: connectivity.Connecting}) + edsb.UpdateSubConnState(sc1, balancer.SubConnState{ConnectivityState: connectivity.Ready}) + + // Pick with only the first backend. + if err := testRoundRobinPickerFromCh(cc.NewPickerCh, []balancer.SubConn{sc1}); err != nil { + t.Fatal(err) + } + + // The same locality, add one more backend. + clab2 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil) + clab2.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[:2], nil) + xdsC.InvokeWatchEDSCallback("", parseEDSRespProtoForTesting(clab2.Build()), nil) + + sc2 := <-cc.NewSubConnCh + edsb.UpdateSubConnState(sc2, balancer.SubConnState{ConnectivityState: connectivity.Connecting}) + edsb.UpdateSubConnState(sc2, balancer.SubConnState{ConnectivityState: connectivity.Ready}) + + // Test roundrobin with two subconns. + if err := testRoundRobinPickerFromCh(cc.NewPickerCh, []balancer.SubConn{sc1, sc2}); err != nil { + t.Fatal(err) + } + + // The same locality, delete first backend. + clab3 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil) + clab3.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[1:2], nil) + xdsC.InvokeWatchEDSCallback("", parseEDSRespProtoForTesting(clab3.Build()), nil) + + scToRemove := <-cc.RemoveSubConnCh + if !cmp.Equal(scToRemove, sc1, cmp.AllowUnexported(testutils.TestSubConn{})) { + t.Fatalf("RemoveSubConn, want %v, got %v", sc1, scToRemove) + } + edsb.UpdateSubConnState(scToRemove, balancer.SubConnState{ConnectivityState: connectivity.Shutdown}) + + // Test pick with only the second subconn. + if err := testRoundRobinPickerFromCh(cc.NewPickerCh, []balancer.SubConn{sc2}); err != nil { + t.Fatal(err) + } + + // The same locality, replace backend. + clab4 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil) + clab4.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[2:3], nil) + xdsC.InvokeWatchEDSCallback("", parseEDSRespProtoForTesting(clab4.Build()), nil) + + sc3 := <-cc.NewSubConnCh + edsb.UpdateSubConnState(sc3, balancer.SubConnState{ConnectivityState: connectivity.Connecting}) + edsb.UpdateSubConnState(sc3, balancer.SubConnState{ConnectivityState: connectivity.Ready}) + scToRemove = <-cc.RemoveSubConnCh + if !cmp.Equal(scToRemove, sc2, cmp.AllowUnexported(testutils.TestSubConn{})) { + t.Fatalf("RemoveSubConn, want %v, got %v", sc2, scToRemove) + } + edsb.UpdateSubConnState(scToRemove, balancer.SubConnState{ConnectivityState: connectivity.Shutdown}) + + // Test pick with only the third subconn. + if err := testRoundRobinPickerFromCh(cc.NewPickerCh, []balancer.SubConn{sc3}); err != nil { + t.Fatal(err) + } + + // The same locality, different drop rate, dropping 50%. + clab5 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], map[string]uint32{"test-drop": 50}) + clab5.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[2:3], nil) + xdsC.InvokeWatchEDSCallback("", parseEDSRespProtoForTesting(clab5.Build()), nil) + + // Picks with drops. + if err := testPickerFromCh(cc.NewPickerCh, func(p balancer.Picker) error { + for i := 0; i < 100; i++ { + _, err := p.Pick(balancer.PickInfo{}) + // TODO: the dropping algorithm needs a design. When the dropping algorithm + // is fixed, this test also needs fix. + if i%2 == 0 && err == nil { + return fmt.Errorf("%d - the even number picks should be drops, got error ", i) + } else if i%2 != 0 && err != nil { + return fmt.Errorf("%d - the odd number picks should be non-drops, got error %v", i, err) + } + } + return nil + }); err != nil { + t.Fatal(err) + } + + // The same locality, remove drops. + clab6 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil) + clab6.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[2:3], nil) + xdsC.InvokeWatchEDSCallback("", parseEDSRespProtoForTesting(clab6.Build()), nil) + + // Pick without drops. + if err := testRoundRobinPickerFromCh(cc.NewPickerCh, []balancer.SubConn{sc3}); err != nil { + t.Fatal(err) + } +} + +// 2 locality +// - start with 2 locality +// - add locality +// - remove locality +// - address change for the locality +// - update locality weight +func (s) TestEDS_TwoLocalities(t *testing.T) { + edsb, cc, xdsC, cleanup := setupTestEDS(t, nil) + defer cleanup() + + // Two localities, each with one backend. + clab1 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil) + clab1.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[:1], nil) + xdsC.InvokeWatchEDSCallback("", parseEDSRespProtoForTesting(clab1.Build()), nil) + sc1 := <-cc.NewSubConnCh + edsb.UpdateSubConnState(sc1, balancer.SubConnState{ConnectivityState: connectivity.Connecting}) + edsb.UpdateSubConnState(sc1, balancer.SubConnState{ConnectivityState: connectivity.Ready}) + + // Add the second locality later to make sure sc2 belongs to the second + // locality. Otherwise the test is flaky because of a map is used in EDS to + // keep localities. + clab1.AddLocality(testSubZones[1], 1, 0, testEndpointAddrs[1:2], nil) + xdsC.InvokeWatchEDSCallback("", parseEDSRespProtoForTesting(clab1.Build()), nil) + sc2 := <-cc.NewSubConnCh + edsb.UpdateSubConnState(sc2, balancer.SubConnState{ConnectivityState: connectivity.Connecting}) + edsb.UpdateSubConnState(sc2, balancer.SubConnState{ConnectivityState: connectivity.Ready}) + + // Test roundrobin with two subconns. + if err := testRoundRobinPickerFromCh(cc.NewPickerCh, []balancer.SubConn{sc1, sc2}); err != nil { + t.Fatal(err) + } + + // Add another locality, with one backend. + clab2 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil) + clab2.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[:1], nil) + clab2.AddLocality(testSubZones[1], 1, 0, testEndpointAddrs[1:2], nil) + clab2.AddLocality(testSubZones[2], 1, 0, testEndpointAddrs[2:3], nil) + xdsC.InvokeWatchEDSCallback("", parseEDSRespProtoForTesting(clab2.Build()), nil) + + sc3 := <-cc.NewSubConnCh + edsb.UpdateSubConnState(sc3, balancer.SubConnState{ConnectivityState: connectivity.Connecting}) + edsb.UpdateSubConnState(sc3, balancer.SubConnState{ConnectivityState: connectivity.Ready}) + + // Test roundrobin with three subconns. + if err := testRoundRobinPickerFromCh(cc.NewPickerCh, []balancer.SubConn{sc1, sc2, sc3}); err != nil { + t.Fatal(err) + } + + // Remove first locality. + clab3 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil) + clab3.AddLocality(testSubZones[1], 1, 0, testEndpointAddrs[1:2], nil) + clab3.AddLocality(testSubZones[2], 1, 0, testEndpointAddrs[2:3], nil) + xdsC.InvokeWatchEDSCallback("", parseEDSRespProtoForTesting(clab3.Build()), nil) + + scToRemove := <-cc.RemoveSubConnCh + if !cmp.Equal(scToRemove, sc1, cmp.AllowUnexported(testutils.TestSubConn{})) { + t.Fatalf("RemoveSubConn, want %v, got %v", sc1, scToRemove) + } + edsb.UpdateSubConnState(scToRemove, balancer.SubConnState{ConnectivityState: connectivity.Shutdown}) + + // Test pick with two subconns (without the first one). + if err := testRoundRobinPickerFromCh(cc.NewPickerCh, []balancer.SubConn{sc2, sc3}); err != nil { + t.Fatal(err) + } + + // Add a backend to the last locality. + clab4 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil) + clab4.AddLocality(testSubZones[1], 1, 0, testEndpointAddrs[1:2], nil) + clab4.AddLocality(testSubZones[2], 1, 0, testEndpointAddrs[2:4], nil) + xdsC.InvokeWatchEDSCallback("", parseEDSRespProtoForTesting(clab4.Build()), nil) + + sc4 := <-cc.NewSubConnCh + edsb.UpdateSubConnState(sc4, balancer.SubConnState{ConnectivityState: connectivity.Connecting}) + edsb.UpdateSubConnState(sc4, balancer.SubConnState{ConnectivityState: connectivity.Ready}) + + // Test pick with two subconns (without the first one). + // + // Locality-1 will be picked twice, and locality-2 will be picked twice. + // Locality-1 contains only sc2, locality-2 contains sc3 and sc4. So expect + // two sc2's and sc3, sc4. + if err := testRoundRobinPickerFromCh(cc.NewPickerCh, []balancer.SubConn{sc2, sc2, sc3, sc4}); err != nil { + t.Fatal(err) + } + + // Change weight of the locality[1]. + clab5 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil) + clab5.AddLocality(testSubZones[1], 2, 0, testEndpointAddrs[1:2], nil) + clab5.AddLocality(testSubZones[2], 1, 0, testEndpointAddrs[2:4], nil) + xdsC.InvokeWatchEDSCallback("", parseEDSRespProtoForTesting(clab5.Build()), nil) + + // Test pick with two subconns different locality weight. + // + // Locality-1 will be picked four times, and locality-2 will be picked twice + // (weight 2 and 1). Locality-1 contains only sc2, locality-2 contains sc3 and + // sc4. So expect four sc2's and sc3, sc4. + if err := testRoundRobinPickerFromCh(cc.NewPickerCh, []balancer.SubConn{sc2, sc2, sc2, sc2, sc3, sc4}); err != nil { + t.Fatal(err) + } + + // Change weight of the locality[1] to 0, it should never be picked. + clab6 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil) + clab6.AddLocality(testSubZones[1], 0, 0, testEndpointAddrs[1:2], nil) + clab6.AddLocality(testSubZones[2], 1, 0, testEndpointAddrs[2:4], nil) + xdsC.InvokeWatchEDSCallback("", parseEDSRespProtoForTesting(clab6.Build()), nil) + + // Changing weight of locality[1] to 0 caused it to be removed. It's subconn + // should also be removed. + // + // NOTE: this is because we handle locality with weight 0 same as the + // locality doesn't exist. If this changes in the future, this removeSubConn + // behavior will also change. + scToRemove2 := <-cc.RemoveSubConnCh + if !cmp.Equal(scToRemove2, sc2, cmp.AllowUnexported(testutils.TestSubConn{})) { + t.Fatalf("RemoveSubConn, want %v, got %v", sc2, scToRemove2) + } + + // Test pick with two subconns different locality weight. + // + // Locality-1 will be not be picked, and locality-2 will be picked. + // Locality-2 contains sc3 and sc4. So expect sc3, sc4. + if err := testRoundRobinPickerFromCh(cc.NewPickerCh, []balancer.SubConn{sc3, sc4}); err != nil { + t.Fatal(err) + } +} + +// The EDS balancer gets EDS resp with unhealthy endpoints. Test that only +// healthy ones are used. +func (s) TestEDS_EndpointsHealth(t *testing.T) { + edsb, cc, xdsC, cleanup := setupTestEDS(t, nil) + defer cleanup() + + // Two localities, each 3 backend, one Healthy, one Unhealthy, one Unknown. + clab1 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil) + clab1.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[:6], &testutils.AddLocalityOptions{ + Health: []corepb.HealthStatus{ + corepb.HealthStatus_HEALTHY, + corepb.HealthStatus_UNHEALTHY, + corepb.HealthStatus_UNKNOWN, + corepb.HealthStatus_DRAINING, + corepb.HealthStatus_TIMEOUT, + corepb.HealthStatus_DEGRADED, + }, + }) + clab1.AddLocality(testSubZones[1], 1, 0, testEndpointAddrs[6:12], &testutils.AddLocalityOptions{ + Health: []corepb.HealthStatus{ + corepb.HealthStatus_HEALTHY, + corepb.HealthStatus_UNHEALTHY, + corepb.HealthStatus_UNKNOWN, + corepb.HealthStatus_DRAINING, + corepb.HealthStatus_TIMEOUT, + corepb.HealthStatus_DEGRADED, + }, + }) + xdsC.InvokeWatchEDSCallback("", parseEDSRespProtoForTesting(clab1.Build()), nil) + + var ( + readySCs []balancer.SubConn + newSubConnAddrStrs []string + ) + for i := 0; i < 4; i++ { + addr := <-cc.NewSubConnAddrsCh + newSubConnAddrStrs = append(newSubConnAddrStrs, addr[0].Addr) + sc := <-cc.NewSubConnCh + edsb.UpdateSubConnState(sc, balancer.SubConnState{ConnectivityState: connectivity.Connecting}) + edsb.UpdateSubConnState(sc, balancer.SubConnState{ConnectivityState: connectivity.Ready}) + readySCs = append(readySCs, sc) + } + + wantNewSubConnAddrStrs := []string{ + testEndpointAddrs[0], + testEndpointAddrs[2], + testEndpointAddrs[6], + testEndpointAddrs[8], + } + sortStrTrans := cmp.Transformer("Sort", func(in []string) []string { + out := append([]string(nil), in...) // Copy input to avoid mutating it. + sort.Strings(out) + return out + }) + if !cmp.Equal(newSubConnAddrStrs, wantNewSubConnAddrStrs, sortStrTrans) { + t.Fatalf("want newSubConn with address %v, got %v", wantNewSubConnAddrStrs, newSubConnAddrStrs) + } + + // There should be exactly 4 new SubConns. Check to make sure there's no + // more subconns being created. + select { + case <-cc.NewSubConnCh: + t.Fatalf("Got unexpected new subconn") + case <-time.After(time.Microsecond * 100): + } + + // Test roundrobin with the subconns. + if err := testRoundRobinPickerFromCh(cc.NewPickerCh, readySCs); err != nil { + t.Fatal(err) + } +} + +// TestEDS_EmptyUpdate covers the cases when eds impl receives an empty update. +// +// It should send an error picker with transient failure to the parent. +func (s) TestEDS_EmptyUpdate(t *testing.T) { + edsb, cc, xdsC, cleanup := setupTestEDS(t, nil) + defer cleanup() + + const cacheTimeout = 100 * time.Microsecond + oldCacheTimeout := balancergroup.DefaultSubBalancerCloseTimeout + balancergroup.DefaultSubBalancerCloseTimeout = cacheTimeout + defer func() { balancergroup.DefaultSubBalancerCloseTimeout = oldCacheTimeout }() + + // The first update is an empty update. + xdsC.InvokeWatchEDSCallback("", xdsclient.EndpointsUpdate{}, nil) + // Pick should fail with transient failure, and all priority removed error. + if err := testErrPickerFromCh(cc.NewPickerCh, priority.ErrAllPrioritiesRemoved); err != nil { + t.Fatal(err) + } + + // One locality with one backend. + clab1 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil) + clab1.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[:1], nil) + xdsC.InvokeWatchEDSCallback("", parseEDSRespProtoForTesting(clab1.Build()), nil) + + sc1 := <-cc.NewSubConnCh + edsb.UpdateSubConnState(sc1, balancer.SubConnState{ConnectivityState: connectivity.Connecting}) + edsb.UpdateSubConnState(sc1, balancer.SubConnState{ConnectivityState: connectivity.Ready}) + + // Pick with only the first backend. + if err := testRoundRobinPickerFromCh(cc.NewPickerCh, []balancer.SubConn{sc1}); err != nil { + t.Fatal(err) + } + + xdsC.InvokeWatchEDSCallback("", xdsclient.EndpointsUpdate{}, nil) + // Pick should fail with transient failure, and all priority removed error. + if err := testErrPickerFromCh(cc.NewPickerCh, priority.ErrAllPrioritiesRemoved); err != nil { + t.Fatal(err) + } + + // Wait for the old SubConn to be removed (which happens when the child + // policy is closed), so a new update would trigger a new SubConn (we need + // this new SubConn to tell if the next picker is newly created). + scToRemove := <-cc.RemoveSubConnCh + if !cmp.Equal(scToRemove, sc1, cmp.AllowUnexported(testutils.TestSubConn{})) { + t.Fatalf("RemoveSubConn, want %v, got %v", sc1, scToRemove) + } + edsb.UpdateSubConnState(scToRemove, balancer.SubConnState{ConnectivityState: connectivity.Shutdown}) + + // Handle another update with priorities and localities. + xdsC.InvokeWatchEDSCallback("", parseEDSRespProtoForTesting(clab1.Build()), nil) + + sc2 := <-cc.NewSubConnCh + edsb.UpdateSubConnState(sc2, balancer.SubConnState{ConnectivityState: connectivity.Connecting}) + edsb.UpdateSubConnState(sc2, balancer.SubConnState{ConnectivityState: connectivity.Ready}) + + // Pick with only the first backend. + if err := testRoundRobinPickerFromCh(cc.NewPickerCh, []balancer.SubConn{sc2}); err != nil { + t.Fatal(err) + } +} + +func (s) TestEDS_CircuitBreaking(t *testing.T) { + edsb, cc, xdsC, cleanup := setupTestEDS(t, nil) + defer cleanup() + + var maxRequests uint32 = 50 + if err := edsb.UpdateClientConnState(balancer.ClientConnState{ + BalancerConfig: &LBConfig{ + DiscoveryMechanisms: []DiscoveryMechanism{{ + Cluster: testClusterName, + MaxConcurrentRequests: &maxRequests, + Type: DiscoveryMechanismTypeEDS, + }}, + }, + }); err != nil { + t.Fatal(err) + } + + // One locality with one backend. + clab1 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil) + clab1.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[:1], nil) + xdsC.InvokeWatchEDSCallback("", parseEDSRespProtoForTesting(clab1.Build()), nil) + sc1 := <-cc.NewSubConnCh + edsb.UpdateSubConnState(sc1, balancer.SubConnState{ConnectivityState: connectivity.Connecting}) + edsb.UpdateSubConnState(sc1, balancer.SubConnState{ConnectivityState: connectivity.Ready}) + + // Picks with drops. + dones := []func(){} + p := <-cc.NewPickerCh + for i := 0; i < 100; i++ { + pr, err := p.Pick(balancer.PickInfo{}) + if i < 50 && err != nil { + t.Errorf("The first 50%% picks should be non-drops, got error %v", err) + } else if i > 50 && err == nil { + t.Errorf("The second 50%% picks should be drops, got error ") + } + dones = append(dones, func() { + if pr.Done != nil { + pr.Done(balancer.DoneInfo{}) + } + }) + } + + for _, done := range dones { + done() + } + dones = []func(){} + + // Pick without drops. + for i := 0; i < 50; i++ { + pr, err := p.Pick(balancer.PickInfo{}) + if err != nil { + t.Errorf("The third 50%% picks should be non-drops, got error %v", err) + } + dones = append(dones, func() { + if pr.Done != nil { + pr.Done(balancer.DoneInfo{}) + } + }) + } + + // Without this, future tests with the same service name will fail. + for _, done := range dones { + done() + } + + // Send another update, with only circuit breaking update (and no picker + // update afterwards). Make sure the new picker uses the new configs. + var maxRequests2 uint32 = 10 + if err := edsb.UpdateClientConnState(balancer.ClientConnState{ + BalancerConfig: &LBConfig{ + DiscoveryMechanisms: []DiscoveryMechanism{{ + Cluster: testClusterName, + MaxConcurrentRequests: &maxRequests2, + Type: DiscoveryMechanismTypeEDS, + }}, + }, + }); err != nil { + t.Fatal(err) + } + + // Picks with drops. + dones = []func(){} + p2 := <-cc.NewPickerCh + for i := 0; i < 100; i++ { + pr, err := p2.Pick(balancer.PickInfo{}) + if i < 10 && err != nil { + t.Errorf("The first 10%% picks should be non-drops, got error %v", err) + } else if i > 10 && err == nil { + t.Errorf("The next 90%% picks should be drops, got error ") + } + dones = append(dones, func() { + if pr.Done != nil { + pr.Done(balancer.DoneInfo{}) + } + }) + } + + for _, done := range dones { + done() + } + dones = []func(){} + + // Pick without drops. + for i := 0; i < 10; i++ { + pr, err := p2.Pick(balancer.PickInfo{}) + if err != nil { + t.Errorf("The next 10%% picks should be non-drops, got error %v", err) + } + dones = append(dones, func() { + if pr.Done != nil { + pr.Done(balancer.DoneInfo{}) + } + }) + } + + // Without this, future tests with the same service name will fail. + for _, done := range dones { + done() + } +} diff --git a/xds/internal/balancer/lrs/logging.go b/xds/internal/balancer/clusterresolver/logging.go similarity index 84% rename from xds/internal/balancer/lrs/logging.go rename to xds/internal/balancer/clusterresolver/logging.go index 602dac09959..728f1f709c2 100644 --- a/xds/internal/balancer/lrs/logging.go +++ b/xds/internal/balancer/clusterresolver/logging.go @@ -16,7 +16,7 @@ * */ -package lrs +package clusterresolver import ( "fmt" @@ -25,10 +25,10 @@ import ( internalgrpclog "google.golang.org/grpc/internal/grpclog" ) -const prefix = "[lrs-lb %p] " +const prefix = "[xds-cluster-resolver-lb %p] " var logger = grpclog.Component("xds") -func prefixLogger(p *lrsBalancer) *internalgrpclog.PrefixLogger { +func prefixLogger(p *clusterResolverBalancer) *internalgrpclog.PrefixLogger { return internalgrpclog.NewPrefixLogger(logger, fmt.Sprintf(prefix, p)) } diff --git a/xds/internal/balancer/edsbalancer/eds_impl_priority_test.go b/xds/internal/balancer/clusterresolver/priority_test.go similarity index 54% rename from xds/internal/balancer/edsbalancer/eds_impl_priority_test.go rename to xds/internal/balancer/clusterresolver/priority_test.go index 7696feb5bd0..8438a373d9d 100644 --- a/xds/internal/balancer/edsbalancer/eds_impl_priority_test.go +++ b/xds/internal/balancer/clusterresolver/priority_test.go @@ -15,7 +15,7 @@ * limitations under the License. */ -package edsbalancer +package clusterresolver import ( "context" @@ -26,6 +26,8 @@ import ( "github.com/google/go-cmp/cmp" "google.golang.org/grpc/balancer" "google.golang.org/grpc/connectivity" + "google.golang.org/grpc/resolver" + "google.golang.org/grpc/xds/internal/balancer/priority" "google.golang.org/grpc/xds/internal/testutils" ) @@ -34,15 +36,14 @@ import ( // // Init 0 and 1; 0 is up, use 0; add 2, use 0; remove 2, use 0. func (s) TestEDSPriority_HighPriorityReady(t *testing.T) { - cc := testutils.NewTestClientConn(t) - edsb := newEDSBalancerImpl(cc, balancer.BuildOptions{}, nil, nil, nil) - edsb.enqueueChildBalancerStateUpdate = edsb.updateState + edsb, cc, xdsC, cleanup := setupTestEDS(t, nil) + defer cleanup() // Two localities, with priorities [0, 1], each with one backend. clab1 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil) clab1.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[:1], nil) clab1.AddLocality(testSubZones[1], 1, 1, testEndpointAddrs[1:2], nil) - edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab1.Build())) + xdsC.InvokeWatchEDSCallback("", parseEDSRespProtoForTesting(clab1.Build()), nil) addrs1 := <-cc.NewSubConnAddrsCh if got, want := addrs1[0].Addr, testEndpointAddrs[0]; got != want { @@ -51,22 +52,20 @@ func (s) TestEDSPriority_HighPriorityReady(t *testing.T) { sc1 := <-cc.NewSubConnCh // p0 is ready. - edsb.handleSubConnStateChange(sc1, connectivity.Connecting) - edsb.handleSubConnStateChange(sc1, connectivity.Ready) + edsb.UpdateSubConnState(sc1, balancer.SubConnState{ConnectivityState: connectivity.Connecting}) + edsb.UpdateSubConnState(sc1, balancer.SubConnState{ConnectivityState: connectivity.Ready}) // Test roundrobin with only p0 subconns. - p1 := <-cc.NewPickerCh - want := []balancer.SubConn{sc1} - if err := testutils.IsRoundRobin(want, subConnFromPicker(p1)); err != nil { - t.Fatalf("want %v, got %v", want, err) + if err := testRoundRobinPickerFromCh(cc.NewPickerCh, []balancer.SubConn{sc1}); err != nil { + t.Fatal(err) } - // Add p2, it shouldn't cause any udpates. + // Add p2, it shouldn't cause any updates. clab2 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil) clab2.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[:1], nil) clab2.AddLocality(testSubZones[1], 1, 1, testEndpointAddrs[1:2], nil) clab2.AddLocality(testSubZones[2], 1, 2, testEndpointAddrs[2:3], nil) - edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab2.Build())) + xdsC.InvokeWatchEDSCallback("", parseEDSRespProtoForTesting(clab2.Build()), nil) select { case <-cc.NewPickerCh: @@ -82,7 +81,7 @@ func (s) TestEDSPriority_HighPriorityReady(t *testing.T) { clab3 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil) clab3.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[:1], nil) clab3.AddLocality(testSubZones[1], 1, 1, testEndpointAddrs[1:2], nil) - edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab3.Build())) + xdsC.InvokeWatchEDSCallback("", parseEDSRespProtoForTesting(clab3.Build()), nil) select { case <-cc.NewPickerCh: @@ -100,15 +99,14 @@ func (s) TestEDSPriority_HighPriorityReady(t *testing.T) { // Init 0 and 1; 0 is up, use 0; 0 is down, 1 is up, use 1; add 2, use 1; 1 is // down, use 2; remove 2, use 1. func (s) TestEDSPriority_SwitchPriority(t *testing.T) { - cc := testutils.NewTestClientConn(t) - edsb := newEDSBalancerImpl(cc, balancer.BuildOptions{}, nil, nil, nil) - edsb.enqueueChildBalancerStateUpdate = edsb.updateState + edsb, cc, xdsC, cleanup := setupTestEDS(t, nil) + defer cleanup() // Two localities, with priorities [0, 1], each with one backend. clab1 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil) clab1.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[:1], nil) clab1.AddLocality(testSubZones[1], 1, 1, testEndpointAddrs[1:2], nil) - edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab1.Build())) + xdsC.InvokeWatchEDSCallback("", parseEDSRespProtoForTesting(clab1.Build()), nil) addrs0 := <-cc.NewSubConnAddrsCh if got, want := addrs0[0].Addr, testEndpointAddrs[0]; got != want { @@ -117,41 +115,35 @@ func (s) TestEDSPriority_SwitchPriority(t *testing.T) { sc0 := <-cc.NewSubConnCh // p0 is ready. - edsb.handleSubConnStateChange(sc0, connectivity.Connecting) - edsb.handleSubConnStateChange(sc0, connectivity.Ready) + edsb.UpdateSubConnState(sc0, balancer.SubConnState{ConnectivityState: connectivity.Connecting}) + edsb.UpdateSubConnState(sc0, balancer.SubConnState{ConnectivityState: connectivity.Ready}) // Test roundrobin with only p0 subconns. - p0 := <-cc.NewPickerCh - want := []balancer.SubConn{sc0} - if err := testutils.IsRoundRobin(want, subConnFromPicker(p0)); err != nil { - t.Fatalf("want %v, got %v", want, err) + if err := testRoundRobinPickerFromCh(cc.NewPickerCh, []balancer.SubConn{sc0}); err != nil { + t.Fatal(err) } // Turn down 0, 1 is used. - edsb.handleSubConnStateChange(sc0, connectivity.TransientFailure) + edsb.UpdateSubConnState(sc0, balancer.SubConnState{ConnectivityState: connectivity.TransientFailure}) addrs1 := <-cc.NewSubConnAddrsCh if got, want := addrs1[0].Addr, testEndpointAddrs[1]; got != want { t.Fatalf("sc is created with addr %v, want %v", got, want) } sc1 := <-cc.NewSubConnCh - edsb.handleSubConnStateChange(sc1, connectivity.Connecting) - edsb.handleSubConnStateChange(sc1, connectivity.Ready) + edsb.UpdateSubConnState(sc1, balancer.SubConnState{ConnectivityState: connectivity.Connecting}) + edsb.UpdateSubConnState(sc1, balancer.SubConnState{ConnectivityState: connectivity.Ready}) // Test pick with 1. - p1 := <-cc.NewPickerCh - for i := 0; i < 5; i++ { - gotSCSt, _ := p1.Pick(balancer.PickInfo{}) - if !cmp.Equal(gotSCSt.SubConn, sc1, cmp.AllowUnexported(testutils.TestSubConn{})) { - t.Fatalf("picker.Pick, got %v, want SubConn=%v", gotSCSt, sc1) - } + if err := testRoundRobinPickerFromCh(cc.NewPickerCh, []balancer.SubConn{sc1}); err != nil { + t.Fatal(err) } - // Add p2, it shouldn't cause any udpates. + // Add p2, it shouldn't cause any updates. clab2 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil) clab2.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[:1], nil) clab2.AddLocality(testSubZones[1], 1, 1, testEndpointAddrs[1:2], nil) clab2.AddLocality(testSubZones[2], 1, 2, testEndpointAddrs[2:3], nil) - edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab2.Build())) + xdsC.InvokeWatchEDSCallback("", parseEDSRespProtoForTesting(clab2.Build()), nil) select { case <-cc.NewPickerCh: @@ -164,29 +156,25 @@ func (s) TestEDSPriority_SwitchPriority(t *testing.T) { } // Turn down 1, use 2 - edsb.handleSubConnStateChange(sc1, connectivity.TransientFailure) + edsb.UpdateSubConnState(sc1, balancer.SubConnState{ConnectivityState: connectivity.TransientFailure}) addrs2 := <-cc.NewSubConnAddrsCh if got, want := addrs2[0].Addr, testEndpointAddrs[2]; got != want { t.Fatalf("sc is created with addr %v, want %v", got, want) } sc2 := <-cc.NewSubConnCh - edsb.handleSubConnStateChange(sc2, connectivity.Connecting) - edsb.handleSubConnStateChange(sc2, connectivity.Ready) + edsb.UpdateSubConnState(sc2, balancer.SubConnState{ConnectivityState: connectivity.Connecting}) + edsb.UpdateSubConnState(sc2, balancer.SubConnState{ConnectivityState: connectivity.Ready}) // Test pick with 2. - p2 := <-cc.NewPickerCh - for i := 0; i < 5; i++ { - gotSCSt, _ := p2.Pick(balancer.PickInfo{}) - if !cmp.Equal(gotSCSt.SubConn, sc2, cmp.AllowUnexported(testutils.TestSubConn{})) { - t.Fatalf("picker.Pick, got %v, want SubConn=%v", gotSCSt, sc2) - } + if err := testRoundRobinPickerFromCh(cc.NewPickerCh, []balancer.SubConn{sc2}); err != nil { + t.Fatal(err) } // Remove 2, use 1. clab3 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil) clab3.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[:1], nil) clab3.AddLocality(testSubZones[1], 1, 1, testEndpointAddrs[1:2], nil) - edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab3.Build())) + xdsC.InvokeWatchEDSCallback("", parseEDSRespProtoForTesting(clab3.Build()), nil) // p2 SubConns are removed. scToRemove := <-cc.RemoveSubConnCh @@ -195,28 +183,23 @@ func (s) TestEDSPriority_SwitchPriority(t *testing.T) { } // Should get an update with 1's old picker, to override 2's old picker. - p3 := <-cc.NewPickerCh - for i := 0; i < 5; i++ { - if _, err := p3.Pick(balancer.PickInfo{}); err != balancer.ErrTransientFailure { - t.Fatalf("want pick error %v, got %v", balancer.ErrTransientFailure, err) - } + if err := testErrPickerFromCh(cc.NewPickerCh, balancer.ErrTransientFailure); err != nil { + t.Fatal(err) } + } // Add a lower priority while the higher priority is down. // // Init 0 and 1; 0 and 1 both down; add 2, use 2. func (s) TestEDSPriority_HigherDownWhileAddingLower(t *testing.T) { - cc := testutils.NewTestClientConn(t) - edsb := newEDSBalancerImpl(cc, balancer.BuildOptions{}, nil, nil, nil) - edsb.enqueueChildBalancerStateUpdate = edsb.updateState - + edsb, cc, xdsC, cleanup := setupTestEDS(t, nil) + defer cleanup() // Two localities, with different priorities, each with one backend. clab1 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil) clab1.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[:1], nil) clab1.AddLocality(testSubZones[1], 1, 1, testEndpointAddrs[1:2], nil) - edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab1.Build())) - + xdsC.InvokeWatchEDSCallback("", parseEDSRespProtoForTesting(clab1.Build()), nil) addrs0 := <-cc.NewSubConnAddrsCh if got, want := addrs0[0].Addr, testEndpointAddrs[0]; got != want { t.Fatalf("sc is created with addr %v, want %v", got, want) @@ -224,21 +207,18 @@ func (s) TestEDSPriority_HigherDownWhileAddingLower(t *testing.T) { sc0 := <-cc.NewSubConnCh // Turn down 0, 1 is used. - edsb.handleSubConnStateChange(sc0, connectivity.TransientFailure) + edsb.UpdateSubConnState(sc0, balancer.SubConnState{ConnectivityState: connectivity.TransientFailure}) addrs1 := <-cc.NewSubConnAddrsCh if got, want := addrs1[0].Addr, testEndpointAddrs[1]; got != want { t.Fatalf("sc is created with addr %v, want %v", got, want) } sc1 := <-cc.NewSubConnCh // Turn down 1, pick should error. - edsb.handleSubConnStateChange(sc1, connectivity.TransientFailure) + edsb.UpdateSubConnState(sc1, balancer.SubConnState{ConnectivityState: connectivity.TransientFailure}) // Test pick failure. - pFail := <-cc.NewPickerCh - for i := 0; i < 5; i++ { - if _, err := pFail.Pick(balancer.PickInfo{}); err != balancer.ErrTransientFailure { - t.Fatalf("want pick error %v, got %v", balancer.ErrTransientFailure, err) - } + if err := testErrPickerFromCh(cc.NewPickerCh, balancer.ErrTransientFailure); err != nil { + t.Fatal(err) } // Add p2, it should create a new SubConn. @@ -246,41 +226,34 @@ func (s) TestEDSPriority_HigherDownWhileAddingLower(t *testing.T) { clab2.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[:1], nil) clab2.AddLocality(testSubZones[1], 1, 1, testEndpointAddrs[1:2], nil) clab2.AddLocality(testSubZones[2], 1, 2, testEndpointAddrs[2:3], nil) - edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab2.Build())) - + xdsC.InvokeWatchEDSCallback("", parseEDSRespProtoForTesting(clab2.Build()), nil) addrs2 := <-cc.NewSubConnAddrsCh if got, want := addrs2[0].Addr, testEndpointAddrs[2]; got != want { t.Fatalf("sc is created with addr %v, want %v", got, want) } sc2 := <-cc.NewSubConnCh - edsb.handleSubConnStateChange(sc2, connectivity.Connecting) - edsb.handleSubConnStateChange(sc2, connectivity.Ready) + edsb.UpdateSubConnState(sc2, balancer.SubConnState{ConnectivityState: connectivity.Connecting}) + edsb.UpdateSubConnState(sc2, balancer.SubConnState{ConnectivityState: connectivity.Ready}) // Test pick with 2. - p2 := <-cc.NewPickerCh - for i := 0; i < 5; i++ { - gotSCSt, _ := p2.Pick(balancer.PickInfo{}) - if !cmp.Equal(gotSCSt.SubConn, sc2, cmp.AllowUnexported(testutils.TestSubConn{})) { - t.Fatalf("picker.Pick, got %v, want SubConn=%v", gotSCSt, sc2) - } + if err := testRoundRobinPickerFromCh(cc.NewPickerCh, []balancer.SubConn{sc2}); err != nil { + t.Fatal(err) } + } // When a higher priority becomes available, all lower priorities are closed. // // Init 0,1,2; 0 and 1 down, use 2; 0 up, close 1 and 2. func (s) TestEDSPriority_HigherReadyCloseAllLower(t *testing.T) { - cc := testutils.NewTestClientConn(t) - edsb := newEDSBalancerImpl(cc, balancer.BuildOptions{}, nil, nil, nil) - edsb.enqueueChildBalancerStateUpdate = edsb.updateState - + edsb, cc, xdsC, cleanup := setupTestEDS(t, nil) + defer cleanup() // Two localities, with priorities [0,1,2], each with one backend. clab1 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil) clab1.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[:1], nil) clab1.AddLocality(testSubZones[1], 1, 1, testEndpointAddrs[1:2], nil) clab1.AddLocality(testSubZones[2], 1, 2, testEndpointAddrs[2:3], nil) - edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab1.Build())) - + xdsC.InvokeWatchEDSCallback("", parseEDSRespProtoForTesting(clab1.Build()), nil) addrs0 := <-cc.NewSubConnAddrsCh if got, want := addrs0[0].Addr, testEndpointAddrs[0]; got != want { t.Fatalf("sc is created with addr %v, want %v", got, want) @@ -288,39 +261,55 @@ func (s) TestEDSPriority_HigherReadyCloseAllLower(t *testing.T) { sc0 := <-cc.NewSubConnCh // Turn down 0, 1 is used. - edsb.handleSubConnStateChange(sc0, connectivity.TransientFailure) + edsb.UpdateSubConnState(sc0, balancer.SubConnState{ConnectivityState: connectivity.TransientFailure}) addrs1 := <-cc.NewSubConnAddrsCh if got, want := addrs1[0].Addr, testEndpointAddrs[1]; got != want { t.Fatalf("sc is created with addr %v, want %v", got, want) } sc1 := <-cc.NewSubConnCh // Turn down 1, 2 is used. - edsb.handleSubConnStateChange(sc1, connectivity.TransientFailure) + edsb.UpdateSubConnState(sc1, balancer.SubConnState{ConnectivityState: connectivity.TransientFailure}) addrs2 := <-cc.NewSubConnAddrsCh if got, want := addrs2[0].Addr, testEndpointAddrs[2]; got != want { t.Fatalf("sc is created with addr %v, want %v", got, want) } sc2 := <-cc.NewSubConnCh - edsb.handleSubConnStateChange(sc2, connectivity.Connecting) - edsb.handleSubConnStateChange(sc2, connectivity.Ready) + edsb.UpdateSubConnState(sc2, balancer.SubConnState{ConnectivityState: connectivity.Connecting}) + edsb.UpdateSubConnState(sc2, balancer.SubConnState{ConnectivityState: connectivity.Ready}) // Test pick with 2. - p2 := <-cc.NewPickerCh - for i := 0; i < 5; i++ { - gotSCSt, _ := p2.Pick(balancer.PickInfo{}) - if !cmp.Equal(gotSCSt.SubConn, sc2, cmp.AllowUnexported(testutils.TestSubConn{})) { - t.Fatalf("picker.Pick, got %v, want SubConn=%v", gotSCSt, sc2) - } + if err := testRoundRobinPickerFromCh(cc.NewPickerCh, []balancer.SubConn{sc2}); err != nil { + t.Fatal(err) } // When 0 becomes ready, 0 should be used, 1 and 2 should all be closed. - edsb.handleSubConnStateChange(sc0, connectivity.Ready) + edsb.UpdateSubConnState(sc0, balancer.SubConnState{ConnectivityState: connectivity.Ready}) + var ( + scToRemove []balancer.SubConn + scToRemoveMap = make(map[balancer.SubConn]struct{}) + ) + // Each subconn is removed twice. This is OK in production, but it makes + // testing harder. + // + // The sub-balancer to be closed is priority's child, clusterimpl, who has + // weightedtarget as children. + // + // - When clusterimpl is removed from priority's balancergroup, all its + // subconns are removed once. + // - When clusterimpl is closed, it closes weightedtarget, and this + // weightedtarget's balancer removes all the same subconns again. + for i := 0; i < 4; i++ { + // We expect 2 subconns, so we recv from channel 4 times. + scToRemoveMap[<-cc.RemoveSubConnCh] = struct{}{} + } + for sc := range scToRemoveMap { + scToRemove = append(scToRemove, sc) + } // sc1 and sc2 should be removed. // // With localities caching, the lower priorities are closed after a timeout, // in goroutines. The order is no longer guaranteed. - scToRemove := []balancer.SubConn{<-cc.RemoveSubConnCh, <-cc.RemoveSubConnCh} if !(cmp.Equal(scToRemove[0], sc1, cmp.AllowUnexported(testutils.TestSubConn{})) && cmp.Equal(scToRemove[1], sc2, cmp.AllowUnexported(testutils.TestSubConn{}))) && !(cmp.Equal(scToRemove[0], sc2, cmp.AllowUnexported(testutils.TestSubConn{})) && @@ -329,12 +318,8 @@ func (s) TestEDSPriority_HigherReadyCloseAllLower(t *testing.T) { } // Test pick with 0. - p0 := <-cc.NewPickerCh - for i := 0; i < 5; i++ { - gotSCSt, _ := p0.Pick(balancer.PickInfo{}) - if !cmp.Equal(gotSCSt.SubConn, sc0, cmp.AllowUnexported(testutils.TestSubConn{})) { - t.Fatalf("picker.Pick, got %v, want SubConn=%v", gotSCSt, sc0) - } + if err := testRoundRobinPickerFromCh(cc.NewPickerCh, []balancer.SubConn{sc0}); err != nil { + t.Fatal(err) } } @@ -345,23 +330,20 @@ func (s) TestEDSPriority_HigherReadyCloseAllLower(t *testing.T) { func (s) TestEDSPriority_InitTimeout(t *testing.T) { const testPriorityInitTimeout = time.Second defer func() func() { - old := defaultPriorityInitTimeout - defaultPriorityInitTimeout = testPriorityInitTimeout + old := priority.DefaultPriorityInitTimeout + priority.DefaultPriorityInitTimeout = testPriorityInitTimeout return func() { - defaultPriorityInitTimeout = old + priority.DefaultPriorityInitTimeout = old } }()() - cc := testutils.NewTestClientConn(t) - edsb := newEDSBalancerImpl(cc, balancer.BuildOptions{}, nil, nil, nil) - edsb.enqueueChildBalancerStateUpdate = edsb.updateState - + edsb, cc, xdsC, cleanup := setupTestEDS(t, nil) + defer cleanup() // Two localities, with different priorities, each with one backend. clab1 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil) clab1.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[:1], nil) clab1.AddLocality(testSubZones[1], 1, 1, testEndpointAddrs[1:2], nil) - edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab1.Build())) - + xdsC.InvokeWatchEDSCallback("", parseEDSRespProtoForTesting(clab1.Build()), nil) addrs0 := <-cc.NewSubConnAddrsCh if got, want := addrs0[0].Addr, testEndpointAddrs[0]; got != want { t.Fatalf("sc is created with addr %v, want %v", got, want) @@ -369,7 +351,7 @@ func (s) TestEDSPriority_InitTimeout(t *testing.T) { sc0 := <-cc.NewSubConnCh // Keep 0 in connecting, 1 will be used after init timeout. - edsb.handleSubConnStateChange(sc0, connectivity.Connecting) + edsb.UpdateSubConnState(sc0, balancer.SubConnState{ConnectivityState: connectivity.Connecting}) // Make sure new SubConn is created before timeout. select { @@ -384,16 +366,12 @@ func (s) TestEDSPriority_InitTimeout(t *testing.T) { } sc1 := <-cc.NewSubConnCh - edsb.handleSubConnStateChange(sc1, connectivity.Connecting) - edsb.handleSubConnStateChange(sc1, connectivity.Ready) + edsb.UpdateSubConnState(sc1, balancer.SubConnState{ConnectivityState: connectivity.Connecting}) + edsb.UpdateSubConnState(sc1, balancer.SubConnState{ConnectivityState: connectivity.Ready}) // Test pick with 1. - p1 := <-cc.NewPickerCh - for i := 0; i < 5; i++ { - gotSCSt, _ := p1.Pick(balancer.PickInfo{}) - if !cmp.Equal(gotSCSt.SubConn, sc1, cmp.AllowUnexported(testutils.TestSubConn{})) { - t.Fatalf("picker.Pick, got %v, want SubConn=%v", gotSCSt, sc1) - } + if err := testRoundRobinPickerFromCh(cc.NewPickerCh, []balancer.SubConn{sc1}); err != nil { + t.Fatal(err) } } @@ -402,51 +380,44 @@ func (s) TestEDSPriority_InitTimeout(t *testing.T) { // - start with 2 locality with p0 and p1 // - add localities to existing p0 and p1 func (s) TestEDSPriority_MultipleLocalities(t *testing.T) { - cc := testutils.NewTestClientConn(t) - edsb := newEDSBalancerImpl(cc, balancer.BuildOptions{}, nil, nil, nil) - edsb.enqueueChildBalancerStateUpdate = edsb.updateState - + edsb, cc, xdsC, cleanup := setupTestEDS(t, nil) + defer cleanup() // Two localities, with different priorities, each with one backend. clab0 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil) clab0.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[:1], nil) clab0.AddLocality(testSubZones[1], 1, 1, testEndpointAddrs[1:2], nil) - edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab0.Build())) - + xdsC.InvokeWatchEDSCallback("", parseEDSRespProtoForTesting(clab0.Build()), nil) addrs0 := <-cc.NewSubConnAddrsCh if got, want := addrs0[0].Addr, testEndpointAddrs[0]; got != want { t.Fatalf("sc is created with addr %v, want %v", got, want) } sc0 := <-cc.NewSubConnCh - edsb.handleSubConnStateChange(sc0, connectivity.Connecting) - edsb.handleSubConnStateChange(sc0, connectivity.Ready) + edsb.UpdateSubConnState(sc0, balancer.SubConnState{ConnectivityState: connectivity.Connecting}) + edsb.UpdateSubConnState(sc0, balancer.SubConnState{ConnectivityState: connectivity.Ready}) // Test roundrobin with only p0 subconns. - p0 := <-cc.NewPickerCh - want := []balancer.SubConn{sc0} - if err := testutils.IsRoundRobin(want, subConnFromPicker(p0)); err != nil { - t.Fatalf("want %v, got %v", want, err) + if err := testRoundRobinPickerFromCh(cc.NewPickerCh, []balancer.SubConn{sc0}); err != nil { + t.Fatal(err) } // Turn down p0 subconns, p1 subconns will be created. - edsb.handleSubConnStateChange(sc0, connectivity.TransientFailure) + edsb.UpdateSubConnState(sc0, balancer.SubConnState{ConnectivityState: connectivity.TransientFailure}) addrs1 := <-cc.NewSubConnAddrsCh if got, want := addrs1[0].Addr, testEndpointAddrs[1]; got != want { t.Fatalf("sc is created with addr %v, want %v", got, want) } sc1 := <-cc.NewSubConnCh - edsb.handleSubConnStateChange(sc1, connectivity.Connecting) - edsb.handleSubConnStateChange(sc1, connectivity.Ready) + edsb.UpdateSubConnState(sc1, balancer.SubConnState{ConnectivityState: connectivity.Connecting}) + edsb.UpdateSubConnState(sc1, balancer.SubConnState{ConnectivityState: connectivity.Ready}) // Test roundrobin with only p1 subconns. - p1 := <-cc.NewPickerCh - want = []balancer.SubConn{sc1} - if err := testutils.IsRoundRobin(want, subConnFromPicker(p1)); err != nil { - t.Fatalf("want %v, got %v", want, err) + if err := testRoundRobinPickerFromCh(cc.NewPickerCh, []balancer.SubConn{sc1}); err != nil { + t.Fatal(err) } // Reconnect p0 subconns, p1 subconn will be closed. - edsb.handleSubConnStateChange(sc0, connectivity.Ready) + edsb.UpdateSubConnState(sc0, balancer.SubConnState{ConnectivityState: connectivity.Ready}) scToRemove := <-cc.RemoveSubConnCh if !cmp.Equal(scToRemove, sc1, cmp.AllowUnexported(testutils.TestSubConn{})) { @@ -454,10 +425,8 @@ func (s) TestEDSPriority_MultipleLocalities(t *testing.T) { } // Test roundrobin with only p0 subconns. - p2 := <-cc.NewPickerCh - want = []balancer.SubConn{sc0} - if err := testutils.IsRoundRobin(want, subConnFromPicker(p2)); err != nil { - t.Fatalf("want %v, got %v", want, err) + if err := testRoundRobinPickerFromCh(cc.NewPickerCh, []balancer.SubConn{sc0}); err != nil { + t.Fatal(err) } // Add two localities, with two priorities, with one backend. @@ -466,39 +435,34 @@ func (s) TestEDSPriority_MultipleLocalities(t *testing.T) { clab1.AddLocality(testSubZones[1], 1, 1, testEndpointAddrs[1:2], nil) clab1.AddLocality(testSubZones[2], 1, 0, testEndpointAddrs[2:3], nil) clab1.AddLocality(testSubZones[3], 1, 1, testEndpointAddrs[3:4], nil) - edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab1.Build())) - + xdsC.InvokeWatchEDSCallback("", parseEDSRespProtoForTesting(clab1.Build()), nil) addrs2 := <-cc.NewSubConnAddrsCh if got, want := addrs2[0].Addr, testEndpointAddrs[2]; got != want { t.Fatalf("sc is created with addr %v, want %v", got, want) } sc2 := <-cc.NewSubConnCh - edsb.handleSubConnStateChange(sc2, connectivity.Connecting) - edsb.handleSubConnStateChange(sc2, connectivity.Ready) + edsb.UpdateSubConnState(sc2, balancer.SubConnState{ConnectivityState: connectivity.Connecting}) + edsb.UpdateSubConnState(sc2, balancer.SubConnState{ConnectivityState: connectivity.Ready}) // Test roundrobin with only two p0 subconns. - p3 := <-cc.NewPickerCh - want = []balancer.SubConn{sc0, sc2} - if err := testutils.IsRoundRobin(want, subConnFromPicker(p3)); err != nil { - t.Fatalf("want %v, got %v", want, err) + if err := testRoundRobinPickerFromCh(cc.NewPickerCh, []balancer.SubConn{sc0, sc2}); err != nil { + t.Fatal(err) } // Turn down p0 subconns, p1 subconns will be created. - edsb.handleSubConnStateChange(sc0, connectivity.TransientFailure) - edsb.handleSubConnStateChange(sc2, connectivity.TransientFailure) + edsb.UpdateSubConnState(sc0, balancer.SubConnState{ConnectivityState: connectivity.TransientFailure}) + edsb.UpdateSubConnState(sc2, balancer.SubConnState{ConnectivityState: connectivity.TransientFailure}) sc3 := <-cc.NewSubConnCh - edsb.handleSubConnStateChange(sc3, connectivity.Connecting) - edsb.handleSubConnStateChange(sc3, connectivity.Ready) + edsb.UpdateSubConnState(sc3, balancer.SubConnState{ConnectivityState: connectivity.Connecting}) + edsb.UpdateSubConnState(sc3, balancer.SubConnState{ConnectivityState: connectivity.Ready}) sc4 := <-cc.NewSubConnCh - edsb.handleSubConnStateChange(sc4, connectivity.Connecting) - edsb.handleSubConnStateChange(sc4, connectivity.Ready) + edsb.UpdateSubConnState(sc4, balancer.SubConnState{ConnectivityState: connectivity.Connecting}) + edsb.UpdateSubConnState(sc4, balancer.SubConnState{ConnectivityState: connectivity.Ready}) // Test roundrobin with only p1 subconns. - p4 := <-cc.NewPickerCh - want = []balancer.SubConn{sc3, sc4} - if err := testutils.IsRoundRobin(want, subConnFromPicker(p4)); err != nil { - t.Fatalf("want %v, got %v", want, err) + if err := testRoundRobinPickerFromCh(cc.NewPickerCh, []balancer.SubConn{sc3, sc4}); err != nil { + t.Fatal(err) } } @@ -506,62 +470,55 @@ func (s) TestEDSPriority_MultipleLocalities(t *testing.T) { func (s) TestEDSPriority_RemovesAllLocalities(t *testing.T) { const testPriorityInitTimeout = time.Second defer func() func() { - old := defaultPriorityInitTimeout - defaultPriorityInitTimeout = testPriorityInitTimeout + old := priority.DefaultPriorityInitTimeout + priority.DefaultPriorityInitTimeout = testPriorityInitTimeout return func() { - defaultPriorityInitTimeout = old + priority.DefaultPriorityInitTimeout = old } }()() - cc := testutils.NewTestClientConn(t) - edsb := newEDSBalancerImpl(cc, balancer.BuildOptions{}, nil, nil, nil) - edsb.enqueueChildBalancerStateUpdate = edsb.updateState - + edsb, cc, xdsC, cleanup := setupTestEDS(t, nil) + defer cleanup() // Two localities, with different priorities, each with one backend. clab0 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil) clab0.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[:1], nil) clab0.AddLocality(testSubZones[1], 1, 1, testEndpointAddrs[1:2], nil) - edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab0.Build())) - + xdsC.InvokeWatchEDSCallback("", parseEDSRespProtoForTesting(clab0.Build()), nil) addrs0 := <-cc.NewSubConnAddrsCh if got, want := addrs0[0].Addr, testEndpointAddrs[0]; got != want { t.Fatalf("sc is created with addr %v, want %v", got, want) } sc0 := <-cc.NewSubConnCh - edsb.handleSubConnStateChange(sc0, connectivity.Connecting) - edsb.handleSubConnStateChange(sc0, connectivity.Ready) + edsb.UpdateSubConnState(sc0, balancer.SubConnState{ConnectivityState: connectivity.Connecting}) + edsb.UpdateSubConnState(sc0, balancer.SubConnState{ConnectivityState: connectivity.Ready}) // Test roundrobin with only p0 subconns. - p0 := <-cc.NewPickerCh - want := []balancer.SubConn{sc0} - if err := testutils.IsRoundRobin(want, subConnFromPicker(p0)); err != nil { - t.Fatalf("want %v, got %v", want, err) + if err := testRoundRobinPickerFromCh(cc.NewPickerCh, []balancer.SubConn{sc0}); err != nil { + t.Fatal(err) } // Remove all priorities. clab1 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil) - edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab1.Build())) - + xdsC.InvokeWatchEDSCallback("", parseEDSRespProtoForTesting(clab1.Build()), nil) // p0 subconn should be removed. scToRemove := <-cc.RemoveSubConnCh + <-cc.RemoveSubConnCh // Drain the duplicate subconn removed. if !cmp.Equal(scToRemove, sc0, cmp.AllowUnexported(testutils.TestSubConn{})) { t.Fatalf("RemoveSubConn, want %v, got %v", sc0, scToRemove) } + // time.Sleep(time.Second) + // Test pick return TransientFailure. - pFail := <-cc.NewPickerCh - for i := 0; i < 5; i++ { - if _, err := pFail.Pick(balancer.PickInfo{}); err != errAllPrioritiesRemoved { - t.Fatalf("want pick error %v, got %v", errAllPrioritiesRemoved, err) - } + if err := testErrPickerFromCh(cc.NewPickerCh, priority.ErrAllPrioritiesRemoved); err != nil { + t.Fatal(err) } // Re-add two localities, with previous priorities, but different backends. clab2 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil) clab2.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[2:3], nil) clab2.AddLocality(testSubZones[1], 1, 1, testEndpointAddrs[3:4], nil) - edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab2.Build())) - + xdsC.InvokeWatchEDSCallback("", parseEDSRespProtoForTesting(clab2.Build()), nil) addrs01 := <-cc.NewSubConnAddrsCh if got, want := addrs01[0].Addr, testEndpointAddrs[2]; got != want { t.Fatalf("sc is created with addr %v, want %v", got, want) @@ -578,45 +535,39 @@ func (s) TestEDSPriority_RemovesAllLocalities(t *testing.T) { t.Fatalf("sc is created with addr %v, want %v", got, want) } sc11 := <-cc.NewSubConnCh - edsb.handleSubConnStateChange(sc11, connectivity.Connecting) - edsb.handleSubConnStateChange(sc11, connectivity.Ready) + edsb.UpdateSubConnState(sc11, balancer.SubConnState{ConnectivityState: connectivity.Connecting}) + edsb.UpdateSubConnState(sc11, balancer.SubConnState{ConnectivityState: connectivity.Ready}) // Test roundrobin with only p1 subconns. - p1 := <-cc.NewPickerCh - want = []balancer.SubConn{sc11} - if err := testutils.IsRoundRobin(want, subConnFromPicker(p1)); err != nil { - t.Fatalf("want %v, got %v", want, err) + if err := testRoundRobinPickerFromCh(cc.NewPickerCh, []balancer.SubConn{sc11}); err != nil { + t.Fatal(err) } // Remove p1 from EDS, to fallback to p0. clab3 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil) clab3.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[2:3], nil) - edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab3.Build())) + xdsC.InvokeWatchEDSCallback("", parseEDSRespProtoForTesting(clab3.Build()), nil) // p1 subconn should be removed. scToRemove1 := <-cc.RemoveSubConnCh + <-cc.RemoveSubConnCh // Drain the duplicate subconn removed. if !cmp.Equal(scToRemove1, sc11, cmp.AllowUnexported(testutils.TestSubConn{})) { t.Fatalf("RemoveSubConn, want %v, got %v", sc11, scToRemove1) } // Test pick return TransientFailure. - pFail1 := <-cc.NewPickerCh - for i := 0; i < 5; i++ { - if scst, err := pFail1.Pick(balancer.PickInfo{}); err != balancer.ErrNoSubConnAvailable { - t.Fatalf("want pick error _, %v, got %v, _ ,%v", balancer.ErrTransientFailure, scst, err) - } + if err := testErrPickerFromCh(cc.NewPickerCh, balancer.ErrNoSubConnAvailable); err != nil { + t.Fatal(err) } // Send an ready update for the p0 sc that was received when re-adding // localities to EDS. - edsb.handleSubConnStateChange(sc01, connectivity.Connecting) - edsb.handleSubConnStateChange(sc01, connectivity.Ready) + edsb.UpdateSubConnState(sc01, balancer.SubConnState{ConnectivityState: connectivity.Connecting}) + edsb.UpdateSubConnState(sc01, balancer.SubConnState{ConnectivityState: connectivity.Ready}) // Test roundrobin with only p0 subconns. - p2 := <-cc.NewPickerCh - want = []balancer.SubConn{sc01} - if err := testutils.IsRoundRobin(want, subConnFromPicker(p2)); err != nil { - t.Fatalf("want %v, got %v", want, err) + if err := testRoundRobinPickerFromCh(cc.NewPickerCh, []balancer.SubConn{sc01}); err != nil { + t.Fatal(err) } select { @@ -630,83 +581,16 @@ func (s) TestEDSPriority_RemovesAllLocalities(t *testing.T) { } } -func (s) TestPriorityType(t *testing.T) { - p0 := newPriorityType(0) - p1 := newPriorityType(1) - p2 := newPriorityType(2) - - if !p0.higherThan(p1) || !p0.higherThan(p2) { - t.Errorf("want p0 to be higher than p1 and p2, got p0>p1: %v, p0>p2: %v", !p0.higherThan(p1), !p0.higherThan(p2)) - } - if !p1.lowerThan(p0) || !p1.higherThan(p2) { - t.Errorf("want p1 to be between p0 and p2, got p1p2: %v", !p1.lowerThan(p0), !p1.higherThan(p2)) - } - if !p2.lowerThan(p0) || !p2.lowerThan(p1) { - t.Errorf("want p2 to be lower than p0 and p1, got p2 subConnMu, but this is implicit via - // balancers (starting balancer with next priority while holding priorityMu, - // and the balancer may create new SubConn). - - priorityMu sync.Mutex - // priorities are pointers, and will be nil when EDS returns empty result. - priorityInUse priorityType - priorityLowest priorityType - priorityToState map[priorityType]*balancer.State - // The timer to give a priority 10 seconds to connect. And if the priority - // doesn't go into Ready/Failure, start the next priority. - // - // One timer is enough because there can be at most one priority in init - // state. - priorityInitTimer *time.Timer - - subConnMu sync.Mutex - subConnToPriority map[balancer.SubConn]priorityType - - pickerMu sync.Mutex - dropConfig []xdsclient.OverloadDropConfig - drops []*dropper - innerState balancer.State // The state of the picker without drop support. - serviceRequestsCounter *client.ServiceRequestsCounter - serviceRequestCountMax uint32 -} - -// newEDSBalancerImpl create a new edsBalancerImpl. -func newEDSBalancerImpl(cc balancer.ClientConn, bOpts balancer.BuildOptions, enqueueState func(priorityType, balancer.State), lr load.PerClusterReporter, logger *grpclog.PrefixLogger) *edsBalancerImpl { - edsImpl := &edsBalancerImpl{ - cc: cc, - buildOpts: bOpts, - logger: logger, - subBalancerBuilder: balancer.Get(roundrobin.Name), - loadReporter: lr, - - enqueueChildBalancerStateUpdate: enqueueState, - - priorityToLocalities: make(map[priorityType]*balancerGroupWithConfig), - priorityToState: make(map[priorityType]*balancer.State), - subConnToPriority: make(map[balancer.SubConn]priorityType), - serviceRequestCountMax: defaultServiceRequestCountMax, - } - // Don't start balancer group here. Start it when handling the first EDS - // response. Otherwise the balancer group will be started with round-robin, - // and if users specify a different sub-balancer, all balancers in balancer - // group will be closed and recreated when sub-balancer update happens. - return edsImpl -} - -// handleChildPolicy updates the child balancers handling endpoints. Child -// policy is roundrobin by default. If the specified balancer is not installed, -// the old child balancer will be used. -// -// HandleChildPolicy and HandleEDSResponse must be called by the same goroutine. -func (edsImpl *edsBalancerImpl) handleChildPolicy(name string, config json.RawMessage) { - if edsImpl.subBalancerBuilder.Name() == name { - return - } - newSubBalancerBuilder := balancer.Get(name) - if newSubBalancerBuilder == nil { - edsImpl.logger.Infof("edsBalancerImpl: failed to find balancer with name %q, keep using %q", name, edsImpl.subBalancerBuilder.Name()) - return - } - edsImpl.subBalancerBuilder = newSubBalancerBuilder - for _, bgwc := range edsImpl.priorityToLocalities { - if bgwc == nil { - continue - } - for lid, config := range bgwc.configs { - lidJSON, err := lid.ToString() - if err != nil { - edsImpl.logger.Errorf("failed to marshal LocalityID: %#v, skipping this locality", lid) - continue - } - // TODO: (eds) add support to balancer group to support smoothly - // switching sub-balancers (keep old balancer around until new - // balancer becomes ready). - bgwc.bg.Remove(lidJSON) - bgwc.bg.Add(lidJSON, edsImpl.subBalancerBuilder) - bgwc.bg.UpdateClientConnState(lidJSON, balancer.ClientConnState{ - ResolverState: resolver.State{Addresses: config.addrs}, - }) - // This doesn't need to manually update picker, because the new - // sub-balancer will send it's picker later. - } - } -} - -// updateDrops compares new drop policies with the old. If they are different, -// it updates the drop policies and send ClientConn an updated picker. -func (edsImpl *edsBalancerImpl) updateDrops(dropConfig []xdsclient.OverloadDropConfig) { - if cmp.Equal(dropConfig, edsImpl.dropConfig) { - return - } - edsImpl.pickerMu.Lock() - edsImpl.dropConfig = dropConfig - var newDrops []*dropper - for _, c := range edsImpl.dropConfig { - newDrops = append(newDrops, newDropper(c)) - } - edsImpl.drops = newDrops - if edsImpl.innerState.Picker != nil { - // Update picker with old inner picker, new drops. - edsImpl.cc.UpdateState(balancer.State{ - ConnectivityState: edsImpl.innerState.ConnectivityState, - Picker: newDropPicker(edsImpl.innerState.Picker, newDrops, edsImpl.loadReporter, edsImpl.serviceRequestsCounter, edsImpl.serviceRequestCountMax)}, - ) - } - edsImpl.pickerMu.Unlock() -} - -// handleEDSResponse handles the EDS response and creates/deletes localities and -// SubConns. It also handles drops. -// -// HandleChildPolicy and HandleEDSResponse must be called by the same goroutine. -func (edsImpl *edsBalancerImpl) handleEDSResponse(edsResp xdsclient.EndpointsUpdate) { - // TODO: Unhandled fields from EDS response: - // - edsResp.GetPolicy().GetOverprovisioningFactor() - // - locality.GetPriority() - // - lbEndpoint.GetMetadata(): contains BNS name, send to sub-balancers - // - as service config or as resolved address - // - if socketAddress is not ip:port - // - socketAddress.GetNamedPort(), socketAddress.GetResolverName() - // - resolve endpoint's name with another resolver - - // If the first EDS update is an empty update, nothing is changing from the - // previous update (which is the default empty value). We need to explicitly - // handle first update being empty, and send a transient failure picker. - // - // TODO: define Equal() on type EndpointUpdate to avoid DeepEqual. And do - // the same for the other types. - if !edsImpl.respReceived && reflect.DeepEqual(edsResp, xdsclient.EndpointsUpdate{}) { - edsImpl.cc.UpdateState(balancer.State{ConnectivityState: connectivity.TransientFailure, Picker: base.NewErrPicker(errAllPrioritiesRemoved)}) - } - edsImpl.respReceived = true - - edsImpl.updateDrops(edsResp.Drops) - - // Filter out all localities with weight 0. - // - // Locality weighted load balancer can be enabled by setting an option in - // CDS, and the weight of each locality. Currently, without the guarantee - // that CDS is always sent, we assume locality weighted load balance is - // always enabled, and ignore all weight 0 localities. - // - // In the future, we should look at the config in CDS response and decide - // whether locality weight matters. - newLocalitiesWithPriority := make(map[priorityType][]xdsclient.Locality) - for _, locality := range edsResp.Localities { - if locality.Weight == 0 { - continue - } - priority := newPriorityType(locality.Priority) - newLocalitiesWithPriority[priority] = append(newLocalitiesWithPriority[priority], locality) - } - - var ( - priorityLowest priorityType - priorityChanged bool - ) - - for priority, newLocalities := range newLocalitiesWithPriority { - if !priorityLowest.isSet() || priorityLowest.higherThan(priority) { - priorityLowest = priority - } - - bgwc, ok := edsImpl.priorityToLocalities[priority] - if !ok { - // Create balancer group if it's never created (this is the first - // time this priority is received). We don't start it here. It may - // be started when necessary (e.g. when higher is down, or if it's a - // new lowest priority). - ccPriorityWrapper := edsImpl.ccWrapperWithPriority(priority) - stateAggregator := weightedaggregator.New(ccPriorityWrapper, edsImpl.logger, newRandomWRR) - bgwc = &balancerGroupWithConfig{ - bg: balancergroup.New(ccPriorityWrapper, edsImpl.buildOpts, stateAggregator, edsImpl.loadReporter, edsImpl.logger), - stateAggregator: stateAggregator, - configs: make(map[internal.LocalityID]*localityConfig), - } - edsImpl.priorityToLocalities[priority] = bgwc - priorityChanged = true - edsImpl.logger.Infof("New priority %v added", priority) - } - edsImpl.handleEDSResponsePerPriority(bgwc, newLocalities) - } - edsImpl.priorityLowest = priorityLowest - - // Delete priorities that are removed in the latest response, and also close - // the balancer group. - for p, bgwc := range edsImpl.priorityToLocalities { - if _, ok := newLocalitiesWithPriority[p]; !ok { - delete(edsImpl.priorityToLocalities, p) - bgwc.bg.Close() - delete(edsImpl.priorityToState, p) - priorityChanged = true - edsImpl.logger.Infof("Priority %v deleted", p) - } - } - - // If priority was added/removed, it may affect the balancer group to use. - // E.g. priorityInUse was removed, or all priorities are down, and a new - // lower priority was added. - if priorityChanged { - edsImpl.handlePriorityChange() - } -} - -func (edsImpl *edsBalancerImpl) handleEDSResponsePerPriority(bgwc *balancerGroupWithConfig, newLocalities []xdsclient.Locality) { - // newLocalitiesSet contains all names of localities in the new EDS response - // for the same priority. It's used to delete localities that are removed in - // the new EDS response. - newLocalitiesSet := make(map[internal.LocalityID]struct{}) - var rebuildStateAndPicker bool - for _, locality := range newLocalities { - // One balancer for each locality. - - lid := locality.ID - lidJSON, err := lid.ToString() - if err != nil { - edsImpl.logger.Errorf("failed to marshal LocalityID: %#v, skipping this locality", lid) - continue - } - newLocalitiesSet[lid] = struct{}{} - - newWeight := locality.Weight - var newAddrs []resolver.Address - for _, lbEndpoint := range locality.Endpoints { - // Filter out all "unhealthy" endpoints (unknown and - // healthy are both considered to be healthy: - // https://www.envoyproxy.io/docs/envoy/latest/api-v2/api/v2/core/health_check.proto#envoy-api-enum-core-healthstatus). - if lbEndpoint.HealthStatus != xdsclient.EndpointHealthStatusHealthy && - lbEndpoint.HealthStatus != xdsclient.EndpointHealthStatusUnknown { - continue - } - - address := resolver.Address{ - Addr: lbEndpoint.Address, - } - if edsImpl.subBalancerBuilder.Name() == weightedroundrobin.Name && lbEndpoint.Weight != 0 { - ai := weightedroundrobin.AddrInfo{Weight: lbEndpoint.Weight} - address = weightedroundrobin.SetAddrInfo(address, ai) - // Metadata field in resolver.Address is deprecated. The - // attributes field should be used to specify arbitrary - // attributes about the address. We still need to populate the - // Metadata field here to allow users of this field to migrate - // to the new one. - // TODO(easwars): Remove this once all users have migrated. - // See https://github.com/grpc/grpc-go/issues/3563. - address.Metadata = &ai - } - newAddrs = append(newAddrs, address) - } - var weightChanged, addrsChanged bool - config, ok := bgwc.configs[lid] - if !ok { - // A new balancer, add it to balancer group and balancer map. - bgwc.stateAggregator.Add(lidJSON, newWeight) - bgwc.bg.Add(lidJSON, edsImpl.subBalancerBuilder) - config = &localityConfig{ - weight: newWeight, - } - bgwc.configs[lid] = config - - // weightChanged is false for new locality, because there's no need - // to update weight in bg. - addrsChanged = true - edsImpl.logger.Infof("New locality %v added", lid) - } else { - // Compare weight and addrs. - if config.weight != newWeight { - weightChanged = true - } - if !cmp.Equal(config.addrs, newAddrs) { - addrsChanged = true - } - edsImpl.logger.Infof("Locality %v updated, weightedChanged: %v, addrsChanged: %v", lid, weightChanged, addrsChanged) - } - - if weightChanged { - config.weight = newWeight - bgwc.stateAggregator.UpdateWeight(lidJSON, newWeight) - rebuildStateAndPicker = true - } - - if addrsChanged { - config.addrs = newAddrs - bgwc.bg.UpdateClientConnState(lidJSON, balancer.ClientConnState{ - ResolverState: resolver.State{Addresses: newAddrs}, - }) - } - } - - // Delete localities that are removed in the latest response. - for lid := range bgwc.configs { - lidJSON, err := lid.ToString() - if err != nil { - edsImpl.logger.Errorf("failed to marshal LocalityID: %#v, skipping this locality", lid) - continue - } - if _, ok := newLocalitiesSet[lid]; !ok { - bgwc.stateAggregator.Remove(lidJSON) - bgwc.bg.Remove(lidJSON) - delete(bgwc.configs, lid) - edsImpl.logger.Infof("Locality %v deleted", lid) - rebuildStateAndPicker = true - } - } - - if rebuildStateAndPicker { - bgwc.stateAggregator.BuildAndUpdate() - } -} - -// handleSubConnStateChange handles the state change and update pickers accordingly. -func (edsImpl *edsBalancerImpl) handleSubConnStateChange(sc balancer.SubConn, s connectivity.State) { - edsImpl.subConnMu.Lock() - var bgwc *balancerGroupWithConfig - if p, ok := edsImpl.subConnToPriority[sc]; ok { - if s == connectivity.Shutdown { - // Only delete sc from the map when state changed to Shutdown. - delete(edsImpl.subConnToPriority, sc) - } - bgwc = edsImpl.priorityToLocalities[p] - } - edsImpl.subConnMu.Unlock() - if bgwc == nil { - edsImpl.logger.Infof("edsBalancerImpl: priority not found for sc state change") - return - } - if bg := bgwc.bg; bg != nil { - bg.UpdateSubConnState(sc, balancer.SubConnState{ConnectivityState: s}) - } -} - -// updateServiceRequestsConfig handles changes to the circuit breaking configuration. -func (edsImpl *edsBalancerImpl) updateServiceRequestsConfig(serviceName string, max *uint32) { - if !env.CircuitBreakingSupport { - return - } - edsImpl.pickerMu.Lock() - var updatePicker bool - if edsImpl.serviceRequestsCounter == nil || edsImpl.serviceRequestsCounter.ServiceName != serviceName { - edsImpl.serviceRequestsCounter = client.GetServiceRequestsCounter(serviceName) - updatePicker = true - } - - var newMax uint32 = defaultServiceRequestCountMax - if max != nil { - newMax = *max - } - if edsImpl.serviceRequestCountMax != newMax { - edsImpl.serviceRequestCountMax = newMax - updatePicker = true - } - if updatePicker && edsImpl.innerState.Picker != nil { - // Update picker with old inner picker, new counter and counterMax. - edsImpl.cc.UpdateState(balancer.State{ - ConnectivityState: edsImpl.innerState.ConnectivityState, - Picker: newDropPicker(edsImpl.innerState.Picker, edsImpl.drops, edsImpl.loadReporter, edsImpl.serviceRequestsCounter, edsImpl.serviceRequestCountMax)}, - ) - } - edsImpl.pickerMu.Unlock() -} - -// updateState first handles priority, and then wraps picker in a drop picker -// before forwarding the update. -func (edsImpl *edsBalancerImpl) updateState(priority priorityType, s balancer.State) { - _, ok := edsImpl.priorityToLocalities[priority] - if !ok { - edsImpl.logger.Infof("eds: received picker update from unknown priority") - return - } - - if edsImpl.handlePriorityWithNewState(priority, s) { - edsImpl.pickerMu.Lock() - defer edsImpl.pickerMu.Unlock() - edsImpl.innerState = s - // Don't reset drops when it's a state change. - edsImpl.cc.UpdateState(balancer.State{ConnectivityState: s.ConnectivityState, Picker: newDropPicker(s.Picker, edsImpl.drops, edsImpl.loadReporter, edsImpl.serviceRequestsCounter, edsImpl.serviceRequestCountMax)}) - } -} - -func (edsImpl *edsBalancerImpl) ccWrapperWithPriority(priority priorityType) *edsBalancerWrapperCC { - return &edsBalancerWrapperCC{ - ClientConn: edsImpl.cc, - priority: priority, - parent: edsImpl, - } -} - -// edsBalancerWrapperCC implements the balancer.ClientConn API and get passed to -// each balancer group. It contains the locality priority. -type edsBalancerWrapperCC struct { - balancer.ClientConn - priority priorityType - parent *edsBalancerImpl -} - -func (ebwcc *edsBalancerWrapperCC) NewSubConn(addrs []resolver.Address, opts balancer.NewSubConnOptions) (balancer.SubConn, error) { - return ebwcc.parent.newSubConn(ebwcc.priority, addrs, opts) -} -func (ebwcc *edsBalancerWrapperCC) UpdateState(state balancer.State) { - ebwcc.parent.enqueueChildBalancerStateUpdate(ebwcc.priority, state) -} - -func (edsImpl *edsBalancerImpl) newSubConn(priority priorityType, addrs []resolver.Address, opts balancer.NewSubConnOptions) (balancer.SubConn, error) { - sc, err := edsImpl.cc.NewSubConn(addrs, opts) - if err != nil { - return nil, err - } - edsImpl.subConnMu.Lock() - edsImpl.subConnToPriority[sc] = priority - edsImpl.subConnMu.Unlock() - return sc, nil -} - -// close closes the balancer. -func (edsImpl *edsBalancerImpl) close() { - for _, bgwc := range edsImpl.priorityToLocalities { - if bg := bgwc.bg; bg != nil { - bgwc.stateAggregator.Stop() - bg.Close() - } - } -} - -type dropPicker struct { - drops []*dropper - p balancer.Picker - loadStore load.PerClusterReporter - counter *client.ServiceRequestsCounter - countMax uint32 -} - -func newDropPicker(p balancer.Picker, drops []*dropper, loadStore load.PerClusterReporter, counter *client.ServiceRequestsCounter, countMax uint32) *dropPicker { - return &dropPicker{ - drops: drops, - p: p, - loadStore: loadStore, - counter: counter, - countMax: countMax, - } -} - -func (d *dropPicker) Pick(info balancer.PickInfo) (balancer.PickResult, error) { - var ( - drop bool - category string - ) - for _, dp := range d.drops { - if dp.drop() { - drop = true - category = dp.c.Category - break - } - } - if drop { - if d.loadStore != nil { - d.loadStore.CallDropped(category) - } - return balancer.PickResult{}, status.Errorf(codes.Unavailable, "RPC is dropped") - } - if d.counter != nil { - if err := d.counter.StartRequest(d.countMax); err != nil { - // Drops by circuit breaking are reported with empty category. They - // will be reported only in total drops, but not in per category. - if d.loadStore != nil { - d.loadStore.CallDropped("") - } - return balancer.PickResult{}, status.Errorf(codes.Unavailable, err.Error()) - } - pr, err := d.p.Pick(info) - if err != nil { - d.counter.EndRequest() - return pr, err - } - oldDone := pr.Done - pr.Done = func(doneInfo balancer.DoneInfo) { - d.counter.EndRequest() - if oldDone != nil { - oldDone(doneInfo) - } - } - return pr, err - } - // TODO: (eds) don't drop unless the inner picker is READY. Similar to - // https://github.com/grpc/grpc-go/issues/2622. - return d.p.Pick(info) -} diff --git a/xds/internal/balancer/edsbalancer/eds_impl_priority.go b/xds/internal/balancer/edsbalancer/eds_impl_priority.go deleted file mode 100644 index 53ac6ef5e87..00000000000 --- a/xds/internal/balancer/edsbalancer/eds_impl_priority.go +++ /dev/null @@ -1,358 +0,0 @@ -/* - * - * Copyright 2019 gRPC authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package edsbalancer - -import ( - "errors" - "fmt" - "time" - - "google.golang.org/grpc/balancer" - "google.golang.org/grpc/balancer/base" - "google.golang.org/grpc/connectivity" -) - -var errAllPrioritiesRemoved = errors.New("eds: no locality is provided, all priorities are removed") - -// handlePriorityChange handles priority after EDS adds/removes a -// priority. -// -// - If all priorities were deleted, unset priorityInUse, and set parent -// ClientConn to TransientFailure -// - If priorityInUse wasn't set, this is either the first EDS resp, or the -// previous EDS resp deleted everything. Set priorityInUse to 0, and start 0. -// - If priorityInUse was deleted, send the picker from the new lowest priority -// to parent ClientConn, and set priorityInUse to the new lowest. -// - If priorityInUse has a non-Ready state, and also there's a priority lower -// than priorityInUse (which means a lower priority was added), set the next -// priority as new priorityInUse, and start the bg. -func (edsImpl *edsBalancerImpl) handlePriorityChange() { - edsImpl.priorityMu.Lock() - defer edsImpl.priorityMu.Unlock() - - // Everything was removed by EDS. - if !edsImpl.priorityLowest.isSet() { - edsImpl.priorityInUse = newPriorityTypeUnset() - // Stop the init timer. This can happen if the only priority is removed - // shortly after it's added. - if timer := edsImpl.priorityInitTimer; timer != nil { - timer.Stop() - edsImpl.priorityInitTimer = nil - } - edsImpl.cc.UpdateState(balancer.State{ConnectivityState: connectivity.TransientFailure, Picker: base.NewErrPicker(errAllPrioritiesRemoved)}) - return - } - - // priorityInUse wasn't set, use 0. - if !edsImpl.priorityInUse.isSet() { - edsImpl.logger.Infof("Switching priority from unset to %v", 0) - edsImpl.startPriority(newPriorityType(0)) - return - } - - // priorityInUse was deleted, use the new lowest. - if _, ok := edsImpl.priorityToLocalities[edsImpl.priorityInUse]; !ok { - oldP := edsImpl.priorityInUse - edsImpl.priorityInUse = edsImpl.priorityLowest - edsImpl.logger.Infof("Switching priority from %v to %v, because former was deleted", oldP, edsImpl.priorityInUse) - if s, ok := edsImpl.priorityToState[edsImpl.priorityLowest]; ok { - edsImpl.cc.UpdateState(*s) - } else { - // If state for priorityLowest is not found, this means priorityLowest was - // started, but never sent any update. The init timer fired and - // triggered the next priority. The old_priorityInUse (that was just - // deleted EDS) was picked later. - // - // We don't have an old state to send to parent, but we also don't - // want parent to keep using picker from old_priorityInUse. Send an - // update to trigger block picks until a new picker is ready. - edsImpl.cc.UpdateState(balancer.State{ConnectivityState: connectivity.Connecting, Picker: base.NewErrPicker(balancer.ErrNoSubConnAvailable)}) - } - return - } - - // priorityInUse is not ready, look for next priority, and use if found. - if s, ok := edsImpl.priorityToState[edsImpl.priorityInUse]; ok && s.ConnectivityState != connectivity.Ready { - pNext := edsImpl.priorityInUse.nextLower() - if _, ok := edsImpl.priorityToLocalities[pNext]; ok { - edsImpl.logger.Infof("Switching priority from %v to %v, because latter was added, and former wasn't Ready") - edsImpl.startPriority(pNext) - } - } -} - -// startPriority sets priorityInUse to p, and starts the balancer group for p. -// It also starts a timer to fall to next priority after timeout. -// -// Caller must hold priorityMu, priority must exist, and edsImpl.priorityInUse -// must be non-nil. -func (edsImpl *edsBalancerImpl) startPriority(priority priorityType) { - edsImpl.priorityInUse = priority - p := edsImpl.priorityToLocalities[priority] - // NOTE: this will eventually send addresses to sub-balancers. If the - // sub-balancer tries to update picker, it will result in a deadlock on - // priorityMu in the update is handled synchronously. The deadlock is - // currently avoided by handling balancer update in a goroutine (the run - // goroutine in the parent eds balancer). When priority balancer is split - // into its own, this asynchronous state handling needs to be copied. - p.stateAggregator.Start() - p.bg.Start() - // startPriority can be called when - // 1. first EDS resp, start p0 - // 2. a high priority goes Failure, start next - // 3. a high priority init timeout, start next - // - // In all the cases, the existing init timer is either closed, also already - // expired. There's no need to close the old timer. - edsImpl.priorityInitTimer = time.AfterFunc(defaultPriorityInitTimeout, func() { - edsImpl.priorityMu.Lock() - defer edsImpl.priorityMu.Unlock() - if !edsImpl.priorityInUse.isSet() || !edsImpl.priorityInUse.equal(priority) { - return - } - edsImpl.priorityInitTimer = nil - pNext := priority.nextLower() - if _, ok := edsImpl.priorityToLocalities[pNext]; ok { - edsImpl.startPriority(pNext) - } - }) -} - -// handlePriorityWithNewState start/close priorities based on the connectivity -// state. It returns whether the state should be forwarded to parent ClientConn. -func (edsImpl *edsBalancerImpl) handlePriorityWithNewState(priority priorityType, s balancer.State) bool { - edsImpl.priorityMu.Lock() - defer edsImpl.priorityMu.Unlock() - - if !edsImpl.priorityInUse.isSet() { - edsImpl.logger.Infof("eds: received picker update when no priority is in use (EDS returned an empty list)") - return false - } - - if edsImpl.priorityInUse.higherThan(priority) { - // Lower priorities should all be closed, this is an unexpected update. - edsImpl.logger.Infof("eds: received picker update from priority lower then priorityInUse") - return false - } - - bState, ok := edsImpl.priorityToState[priority] - if !ok { - bState = &balancer.State{} - edsImpl.priorityToState[priority] = bState - } - oldState := bState.ConnectivityState - *bState = s - - switch s.ConnectivityState { - case connectivity.Ready: - return edsImpl.handlePriorityWithNewStateReady(priority) - case connectivity.TransientFailure: - return edsImpl.handlePriorityWithNewStateTransientFailure(priority) - case connectivity.Connecting: - return edsImpl.handlePriorityWithNewStateConnecting(priority, oldState) - default: - // New state is Idle, should never happen. Don't forward. - return false - } -} - -// handlePriorityWithNewStateReady handles state Ready and decides whether to -// forward update or not. -// -// An update with state Ready: -// - If it's from higher priority: -// - Forward the update -// - Set the priority as priorityInUse -// - Close all priorities lower than this one -// - If it's from priorityInUse: -// - Forward and do nothing else -// -// Caller must make sure priorityInUse is not higher than priority. -// -// Caller must hold priorityMu. -func (edsImpl *edsBalancerImpl) handlePriorityWithNewStateReady(priority priorityType) bool { - // If one priority higher or equal to priorityInUse goes Ready, stop the - // init timer. If update is from higher than priorityInUse, - // priorityInUse will be closed, and the init timer will become useless. - if timer := edsImpl.priorityInitTimer; timer != nil { - timer.Stop() - edsImpl.priorityInitTimer = nil - } - - if edsImpl.priorityInUse.lowerThan(priority) { - edsImpl.logger.Infof("Switching priority from %v to %v, because latter became Ready", edsImpl.priorityInUse, priority) - edsImpl.priorityInUse = priority - for i := priority.nextLower(); !i.lowerThan(edsImpl.priorityLowest); i = i.nextLower() { - bgwc := edsImpl.priorityToLocalities[i] - bgwc.stateAggregator.Stop() - bgwc.bg.Close() - } - return true - } - return true -} - -// handlePriorityWithNewStateTransientFailure handles state TransientFailure and -// decides whether to forward update or not. -// -// An update with state Failure: -// - If it's from a higher priority: -// - Do not forward, and do nothing -// - If it's from priorityInUse: -// - If there's no lower: -// - Forward and do nothing else -// - If there's a lower priority: -// - Forward -// - Set lower as priorityInUse -// - Start lower -// -// Caller must make sure priorityInUse is not higher than priority. -// -// Caller must hold priorityMu. -func (edsImpl *edsBalancerImpl) handlePriorityWithNewStateTransientFailure(priority priorityType) bool { - if edsImpl.priorityInUse.lowerThan(priority) { - return false - } - // priorityInUse sends a failure. Stop its init timer. - if timer := edsImpl.priorityInitTimer; timer != nil { - timer.Stop() - edsImpl.priorityInitTimer = nil - } - pNext := priority.nextLower() - if _, okNext := edsImpl.priorityToLocalities[pNext]; !okNext { - return true - } - edsImpl.logger.Infof("Switching priority from %v to %v, because former became TransientFailure", priority, pNext) - edsImpl.startPriority(pNext) - return true -} - -// handlePriorityWithNewStateConnecting handles state Connecting and decides -// whether to forward update or not. -// -// An update with state Connecting: -// - If it's from a higher priority -// - Do nothing -// - If it's from priorityInUse, the behavior depends on previous state. -// -// When new state is Connecting, the behavior depends on previous state. If the -// previous state was Ready, this is a transition out from Ready to Connecting. -// Assuming there are multiple backends in the same priority, this mean we are -// in a bad situation and we should failover to the next priority (Side note: -// the current connectivity state aggregating algorhtim (e.g. round-robin) is -// not handling this right, because if many backends all go from Ready to -// Connecting, the overall situation is more like TransientFailure, not -// Connecting). -// -// If the previous state was Idle, we don't do anything special with failure, -// and simply forward the update. The init timer should be in process, will -// handle failover if it timeouts. If the previous state was TransientFailure, -// we do not forward, because the lower priority is in use. -// -// Caller must make sure priorityInUse is not higher than priority. -// -// Caller must hold priorityMu. -func (edsImpl *edsBalancerImpl) handlePriorityWithNewStateConnecting(priority priorityType, oldState connectivity.State) bool { - if edsImpl.priorityInUse.lowerThan(priority) { - return false - } - - switch oldState { - case connectivity.Ready: - pNext := priority.nextLower() - if _, okNext := edsImpl.priorityToLocalities[pNext]; !okNext { - return true - } - edsImpl.logger.Infof("Switching priority from %v to %v, because former became Connecting from Ready", priority, pNext) - edsImpl.startPriority(pNext) - return true - case connectivity.Idle: - return true - case connectivity.TransientFailure: - return false - default: - // Old state is Connecting or Shutdown. Don't forward. - return false - } -} - -// priorityType represents the priority from EDS response. -// -// 0 is the highest priority. The bigger the number, the lower the priority. -type priorityType struct { - set bool - p uint32 -} - -func newPriorityType(p uint32) priorityType { - return priorityType{ - set: true, - p: p, - } -} - -func newPriorityTypeUnset() priorityType { - return priorityType{} -} - -func (p priorityType) isSet() bool { - return p.set -} - -func (p priorityType) equal(p2 priorityType) bool { - if !p.isSet() && !p2.isSet() { - return true - } - if !p.isSet() || !p2.isSet() { - return false - } - return p == p2 -} - -func (p priorityType) higherThan(p2 priorityType) bool { - if !p.isSet() || !p2.isSet() { - // TODO(menghanl): return an appropriate value instead of panic. - panic("priority unset") - } - return p.p < p2.p -} - -func (p priorityType) lowerThan(p2 priorityType) bool { - if !p.isSet() || !p2.isSet() { - // TODO(menghanl): return an appropriate value instead of panic. - panic("priority unset") - } - return p.p > p2.p -} - -func (p priorityType) nextLower() priorityType { - if !p.isSet() { - panic("priority unset") - } - return priorityType{ - set: true, - p: p.p + 1, - } -} - -func (p priorityType) String() string { - if !p.set { - return "Nil" - } - return fmt.Sprint(p.p) -} diff --git a/xds/internal/balancer/edsbalancer/eds_impl_test.go b/xds/internal/balancer/edsbalancer/eds_impl_test.go deleted file mode 100644 index 3334376f4a9..00000000000 --- a/xds/internal/balancer/edsbalancer/eds_impl_test.go +++ /dev/null @@ -1,935 +0,0 @@ -/* - * Copyright 2019 gRPC authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package edsbalancer - -import ( - "fmt" - "reflect" - "sort" - "testing" - "time" - - corepb "github.com/envoyproxy/go-control-plane/envoy/api/v2/core" - "github.com/google/go-cmp/cmp" - "github.com/google/go-cmp/cmp/cmpopts" - - "google.golang.org/grpc/balancer" - "google.golang.org/grpc/balancer/roundrobin" - "google.golang.org/grpc/connectivity" - "google.golang.org/grpc/internal/balancer/stub" - "google.golang.org/grpc/xds/internal" - "google.golang.org/grpc/xds/internal/balancer/balancergroup" - "google.golang.org/grpc/xds/internal/client" - xdsclient "google.golang.org/grpc/xds/internal/client" - "google.golang.org/grpc/xds/internal/client/load" - "google.golang.org/grpc/xds/internal/env" - "google.golang.org/grpc/xds/internal/testutils" -) - -var ( - testClusterNames = []string{"test-cluster-1", "test-cluster-2"} - testSubZones = []string{"I", "II", "III", "IV"} - testEndpointAddrs []string -) - -const testBackendAddrsCount = 12 - -func init() { - for i := 0; i < testBackendAddrsCount; i++ { - testEndpointAddrs = append(testEndpointAddrs, fmt.Sprintf("%d.%d.%d.%d:%d", i, i, i, i, i)) - } - balancergroup.DefaultSubBalancerCloseTimeout = time.Millisecond -} - -// One locality -// - add backend -// - remove backend -// - replace backend -// - change drop rate -func (s) TestEDS_OneLocality(t *testing.T) { - cc := testutils.NewTestClientConn(t) - edsb := newEDSBalancerImpl(cc, balancer.BuildOptions{}, nil, nil, nil) - edsb.enqueueChildBalancerStateUpdate = edsb.updateState - - // One locality with one backend. - clab1 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil) - clab1.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[:1], nil) - edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab1.Build())) - - sc1 := <-cc.NewSubConnCh - edsb.handleSubConnStateChange(sc1, connectivity.Connecting) - edsb.handleSubConnStateChange(sc1, connectivity.Ready) - - // Pick with only the first backend. - p1 := <-cc.NewPickerCh - for i := 0; i < 5; i++ { - gotSCSt, _ := p1.Pick(balancer.PickInfo{}) - if !cmp.Equal(gotSCSt.SubConn, sc1, cmp.AllowUnexported(testutils.TestSubConn{})) { - t.Fatalf("picker.Pick, got %v, want SubConn=%v", gotSCSt, sc1) - } - } - - // The same locality, add one more backend. - clab2 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil) - clab2.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[:2], nil) - edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab2.Build())) - - sc2 := <-cc.NewSubConnCh - edsb.handleSubConnStateChange(sc2, connectivity.Connecting) - edsb.handleSubConnStateChange(sc2, connectivity.Ready) - - // Test roundrobin with two subconns. - p2 := <-cc.NewPickerCh - want := []balancer.SubConn{sc1, sc2} - if err := testutils.IsRoundRobin(want, subConnFromPicker(p2)); err != nil { - t.Fatalf("want %v, got %v", want, err) - } - - // The same locality, delete first backend. - clab3 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil) - clab3.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[1:2], nil) - edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab3.Build())) - - scToRemove := <-cc.RemoveSubConnCh - if !cmp.Equal(scToRemove, sc1, cmp.AllowUnexported(testutils.TestSubConn{})) { - t.Fatalf("RemoveSubConn, want %v, got %v", sc1, scToRemove) - } - edsb.handleSubConnStateChange(scToRemove, connectivity.Shutdown) - - // Test pick with only the second subconn. - p3 := <-cc.NewPickerCh - for i := 0; i < 5; i++ { - gotSCSt, _ := p3.Pick(balancer.PickInfo{}) - if !cmp.Equal(gotSCSt.SubConn, sc2, cmp.AllowUnexported(testutils.TestSubConn{})) { - t.Fatalf("picker.Pick, got %v, want SubConn=%v", gotSCSt, sc2) - } - } - - // The same locality, replace backend. - clab4 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil) - clab4.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[2:3], nil) - edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab4.Build())) - - sc3 := <-cc.NewSubConnCh - edsb.handleSubConnStateChange(sc3, connectivity.Connecting) - edsb.handleSubConnStateChange(sc3, connectivity.Ready) - scToRemove = <-cc.RemoveSubConnCh - if !cmp.Equal(scToRemove, sc2, cmp.AllowUnexported(testutils.TestSubConn{})) { - t.Fatalf("RemoveSubConn, want %v, got %v", sc2, scToRemove) - } - edsb.handleSubConnStateChange(scToRemove, connectivity.Shutdown) - - // Test pick with only the third subconn. - p4 := <-cc.NewPickerCh - for i := 0; i < 5; i++ { - gotSCSt, _ := p4.Pick(balancer.PickInfo{}) - if !cmp.Equal(gotSCSt.SubConn, sc3, cmp.AllowUnexported(testutils.TestSubConn{})) { - t.Fatalf("picker.Pick, got %v, want SubConn=%v", gotSCSt, sc3) - } - } - - // The same locality, different drop rate, dropping 50%. - clab5 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], map[string]uint32{"test-drop": 50}) - clab5.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[2:3], nil) - edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab5.Build())) - - // Picks with drops. - p5 := <-cc.NewPickerCh - for i := 0; i < 100; i++ { - _, err := p5.Pick(balancer.PickInfo{}) - // TODO: the dropping algorithm needs a design. When the dropping algorithm - // is fixed, this test also needs fix. - if i < 50 && err == nil { - t.Errorf("The first 50%% picks should be drops, got error ") - } else if i > 50 && err != nil { - t.Errorf("The second 50%% picks should be non-drops, got error %v", err) - } - } - - // The same locality, remove drops. - clab6 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil) - clab6.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[2:3], nil) - edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab6.Build())) - - // Pick without drops. - p6 := <-cc.NewPickerCh - for i := 0; i < 5; i++ { - gotSCSt, _ := p6.Pick(balancer.PickInfo{}) - if !cmp.Equal(gotSCSt.SubConn, sc3, cmp.AllowUnexported(testutils.TestSubConn{})) { - t.Fatalf("picker.Pick, got %v, want SubConn=%v", gotSCSt, sc3) - } - } -} - -// 2 locality -// - start with 2 locality -// - add locality -// - remove locality -// - address change for the locality -// - update locality weight -func (s) TestEDS_TwoLocalities(t *testing.T) { - cc := testutils.NewTestClientConn(t) - edsb := newEDSBalancerImpl(cc, balancer.BuildOptions{}, nil, nil, nil) - edsb.enqueueChildBalancerStateUpdate = edsb.updateState - - // Two localities, each with one backend. - clab1 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil) - clab1.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[:1], nil) - edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab1.Build())) - sc1 := <-cc.NewSubConnCh - edsb.handleSubConnStateChange(sc1, connectivity.Connecting) - edsb.handleSubConnStateChange(sc1, connectivity.Ready) - - // Add the second locality later to make sure sc2 belongs to the second - // locality. Otherwise the test is flaky because of a map is used in EDS to - // keep localities. - clab1.AddLocality(testSubZones[1], 1, 0, testEndpointAddrs[1:2], nil) - edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab1.Build())) - sc2 := <-cc.NewSubConnCh - edsb.handleSubConnStateChange(sc2, connectivity.Connecting) - edsb.handleSubConnStateChange(sc2, connectivity.Ready) - - // Test roundrobin with two subconns. - p1 := <-cc.NewPickerCh - want := []balancer.SubConn{sc1, sc2} - if err := testutils.IsRoundRobin(want, subConnFromPicker(p1)); err != nil { - t.Fatalf("want %v, got %v", want, err) - } - - // Add another locality, with one backend. - clab2 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil) - clab2.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[:1], nil) - clab2.AddLocality(testSubZones[1], 1, 0, testEndpointAddrs[1:2], nil) - clab2.AddLocality(testSubZones[2], 1, 0, testEndpointAddrs[2:3], nil) - edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab2.Build())) - - sc3 := <-cc.NewSubConnCh - edsb.handleSubConnStateChange(sc3, connectivity.Connecting) - edsb.handleSubConnStateChange(sc3, connectivity.Ready) - - // Test roundrobin with three subconns. - p2 := <-cc.NewPickerCh - want = []balancer.SubConn{sc1, sc2, sc3} - if err := testutils.IsRoundRobin(want, subConnFromPicker(p2)); err != nil { - t.Fatalf("want %v, got %v", want, err) - } - - // Remove first locality. - clab3 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil) - clab3.AddLocality(testSubZones[1], 1, 0, testEndpointAddrs[1:2], nil) - clab3.AddLocality(testSubZones[2], 1, 0, testEndpointAddrs[2:3], nil) - edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab3.Build())) - - scToRemove := <-cc.RemoveSubConnCh - if !cmp.Equal(scToRemove, sc1, cmp.AllowUnexported(testutils.TestSubConn{})) { - t.Fatalf("RemoveSubConn, want %v, got %v", sc1, scToRemove) - } - edsb.handleSubConnStateChange(scToRemove, connectivity.Shutdown) - - // Test pick with two subconns (without the first one). - p3 := <-cc.NewPickerCh - want = []balancer.SubConn{sc2, sc3} - if err := testutils.IsRoundRobin(want, subConnFromPicker(p3)); err != nil { - t.Fatalf("want %v, got %v", want, err) - } - - // Add a backend to the last locality. - clab4 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil) - clab4.AddLocality(testSubZones[1], 1, 0, testEndpointAddrs[1:2], nil) - clab4.AddLocality(testSubZones[2], 1, 0, testEndpointAddrs[2:4], nil) - edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab4.Build())) - - sc4 := <-cc.NewSubConnCh - edsb.handleSubConnStateChange(sc4, connectivity.Connecting) - edsb.handleSubConnStateChange(sc4, connectivity.Ready) - - // Test pick with two subconns (without the first one). - p4 := <-cc.NewPickerCh - // Locality-1 will be picked twice, and locality-2 will be picked twice. - // Locality-1 contains only sc2, locality-2 contains sc3 and sc4. So expect - // two sc2's and sc3, sc4. - want = []balancer.SubConn{sc2, sc2, sc3, sc4} - if err := testutils.IsRoundRobin(want, subConnFromPicker(p4)); err != nil { - t.Fatalf("want %v, got %v", want, err) - } - - // Change weight of the locality[1]. - clab5 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil) - clab5.AddLocality(testSubZones[1], 2, 0, testEndpointAddrs[1:2], nil) - clab5.AddLocality(testSubZones[2], 1, 0, testEndpointAddrs[2:4], nil) - edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab5.Build())) - - // Test pick with two subconns different locality weight. - p5 := <-cc.NewPickerCh - // Locality-1 will be picked four times, and locality-2 will be picked twice - // (weight 2 and 1). Locality-1 contains only sc2, locality-2 contains sc3 and - // sc4. So expect four sc2's and sc3, sc4. - want = []balancer.SubConn{sc2, sc2, sc2, sc2, sc3, sc4} - if err := testutils.IsRoundRobin(want, subConnFromPicker(p5)); err != nil { - t.Fatalf("want %v, got %v", want, err) - } - - // Change weight of the locality[1] to 0, it should never be picked. - clab6 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil) - clab6.AddLocality(testSubZones[1], 0, 0, testEndpointAddrs[1:2], nil) - clab6.AddLocality(testSubZones[2], 1, 0, testEndpointAddrs[2:4], nil) - edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab6.Build())) - - // Changing weight of locality[1] to 0 caused it to be removed. It's subconn - // should also be removed. - // - // NOTE: this is because we handle locality with weight 0 same as the - // locality doesn't exist. If this changes in the future, this removeSubConn - // behavior will also change. - scToRemove2 := <-cc.RemoveSubConnCh - if !cmp.Equal(scToRemove2, sc2, cmp.AllowUnexported(testutils.TestSubConn{})) { - t.Fatalf("RemoveSubConn, want %v, got %v", sc2, scToRemove2) - } - - // Test pick with two subconns different locality weight. - p6 := <-cc.NewPickerCh - // Locality-1 will be not be picked, and locality-2 will be picked. - // Locality-2 contains sc3 and sc4. So expect sc3, sc4. - want = []balancer.SubConn{sc3, sc4} - if err := testutils.IsRoundRobin(want, subConnFromPicker(p6)); err != nil { - t.Fatalf("want %v, got %v", want, err) - } -} - -// The EDS balancer gets EDS resp with unhealthy endpoints. Test that only -// healthy ones are used. -func (s) TestEDS_EndpointsHealth(t *testing.T) { - cc := testutils.NewTestClientConn(t) - edsb := newEDSBalancerImpl(cc, balancer.BuildOptions{}, nil, nil, nil) - edsb.enqueueChildBalancerStateUpdate = edsb.updateState - - // Two localities, each 3 backend, one Healthy, one Unhealthy, one Unknown. - clab1 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil) - clab1.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[:6], &testutils.AddLocalityOptions{ - Health: []corepb.HealthStatus{ - corepb.HealthStatus_HEALTHY, - corepb.HealthStatus_UNHEALTHY, - corepb.HealthStatus_UNKNOWN, - corepb.HealthStatus_DRAINING, - corepb.HealthStatus_TIMEOUT, - corepb.HealthStatus_DEGRADED, - }, - }) - clab1.AddLocality(testSubZones[1], 1, 0, testEndpointAddrs[6:12], &testutils.AddLocalityOptions{ - Health: []corepb.HealthStatus{ - corepb.HealthStatus_HEALTHY, - corepb.HealthStatus_UNHEALTHY, - corepb.HealthStatus_UNKNOWN, - corepb.HealthStatus_DRAINING, - corepb.HealthStatus_TIMEOUT, - corepb.HealthStatus_DEGRADED, - }, - }) - edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab1.Build())) - - var ( - readySCs []balancer.SubConn - newSubConnAddrStrs []string - ) - for i := 0; i < 4; i++ { - addr := <-cc.NewSubConnAddrsCh - newSubConnAddrStrs = append(newSubConnAddrStrs, addr[0].Addr) - sc := <-cc.NewSubConnCh - edsb.handleSubConnStateChange(sc, connectivity.Connecting) - edsb.handleSubConnStateChange(sc, connectivity.Ready) - readySCs = append(readySCs, sc) - } - - wantNewSubConnAddrStrs := []string{ - testEndpointAddrs[0], - testEndpointAddrs[2], - testEndpointAddrs[6], - testEndpointAddrs[8], - } - sortStrTrans := cmp.Transformer("Sort", func(in []string) []string { - out := append([]string(nil), in...) // Copy input to avoid mutating it. - sort.Strings(out) - return out - }) - if !cmp.Equal(newSubConnAddrStrs, wantNewSubConnAddrStrs, sortStrTrans) { - t.Fatalf("want newSubConn with address %v, got %v", wantNewSubConnAddrStrs, newSubConnAddrStrs) - } - - // There should be exactly 4 new SubConns. Check to make sure there's no - // more subconns being created. - select { - case <-cc.NewSubConnCh: - t.Fatalf("Got unexpected new subconn") - case <-time.After(time.Microsecond * 100): - } - - // Test roundrobin with the subconns. - p1 := <-cc.NewPickerCh - want := readySCs - if err := testutils.IsRoundRobin(want, subConnFromPicker(p1)); err != nil { - t.Fatalf("want %v, got %v", want, err) - } -} - -func (s) TestClose(t *testing.T) { - edsb := newEDSBalancerImpl(nil, balancer.BuildOptions{}, nil, nil, nil) - // This is what could happen when switching between fallback and eds. This - // make sure it doesn't panic. - edsb.close() -} - -// TestEDS_EmptyUpdate covers the cases when eds impl receives an empty update. -// -// It should send an error picker with transient failure to the parent. -func (s) TestEDS_EmptyUpdate(t *testing.T) { - cc := testutils.NewTestClientConn(t) - edsb := newEDSBalancerImpl(cc, balancer.BuildOptions{}, nil, nil, nil) - edsb.enqueueChildBalancerStateUpdate = edsb.updateState - - // The first update is an empty update. - edsb.handleEDSResponse(xdsclient.EndpointsUpdate{}) - // Pick should fail with transient failure, and all priority removed error. - perr0 := <-cc.NewPickerCh - for i := 0; i < 5; i++ { - _, err := perr0.Pick(balancer.PickInfo{}) - if !reflect.DeepEqual(err, errAllPrioritiesRemoved) { - t.Fatalf("picker.Pick, got error %v, want error %v", err, errAllPrioritiesRemoved) - } - } - - // One locality with one backend. - clab1 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil) - clab1.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[:1], nil) - edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab1.Build())) - - sc1 := <-cc.NewSubConnCh - edsb.handleSubConnStateChange(sc1, connectivity.Connecting) - edsb.handleSubConnStateChange(sc1, connectivity.Ready) - - // Pick with only the first backend. - p1 := <-cc.NewPickerCh - for i := 0; i < 5; i++ { - gotSCSt, _ := p1.Pick(balancer.PickInfo{}) - if !reflect.DeepEqual(gotSCSt.SubConn, sc1) { - t.Fatalf("picker.Pick, got %v, want SubConn=%v", gotSCSt, sc1) - } - } - - edsb.handleEDSResponse(xdsclient.EndpointsUpdate{}) - // Pick should fail with transient failure, and all priority removed error. - perr1 := <-cc.NewPickerCh - for i := 0; i < 5; i++ { - _, err := perr1.Pick(balancer.PickInfo{}) - if !reflect.DeepEqual(err, errAllPrioritiesRemoved) { - t.Fatalf("picker.Pick, got error %v, want error %v", err, errAllPrioritiesRemoved) - } - } - - // Handle another update with priorities and localities. - edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab1.Build())) - - sc2 := <-cc.NewSubConnCh - edsb.handleSubConnStateChange(sc2, connectivity.Connecting) - edsb.handleSubConnStateChange(sc2, connectivity.Ready) - - // Pick with only the first backend. - p2 := <-cc.NewPickerCh - for i := 0; i < 5; i++ { - gotSCSt, _ := p2.Pick(balancer.PickInfo{}) - if !reflect.DeepEqual(gotSCSt.SubConn, sc2) { - t.Fatalf("picker.Pick, got %v, want SubConn=%v", gotSCSt, sc2) - } - } -} - -// Create XDS balancer, and update sub-balancer before handling eds responses. -// Then switch between round-robin and a test stub-balancer after handling first -// eds response. -func (s) TestEDS_UpdateSubBalancerName(t *testing.T) { - const balancerName = "stubBalancer-TestEDS_UpdateSubBalancerName" - - cc := testutils.NewTestClientConn(t) - stub.Register(balancerName, stub.BalancerFuncs{ - UpdateClientConnState: func(bd *stub.BalancerData, s balancer.ClientConnState) error { - if len(s.ResolverState.Addresses) == 0 { - return nil - } - bd.ClientConn.NewSubConn(s.ResolverState.Addresses, balancer.NewSubConnOptions{}) - return nil - }, - UpdateSubConnState: func(bd *stub.BalancerData, sc balancer.SubConn, state balancer.SubConnState) { - bd.ClientConn.UpdateState(balancer.State{ - ConnectivityState: state.ConnectivityState, - Picker: &testutils.TestConstPicker{Err: testutils.ErrTestConstPicker}, - }) - }, - }) - - edsb := newEDSBalancerImpl(cc, balancer.BuildOptions{}, nil, nil, nil) - edsb.enqueueChildBalancerStateUpdate = edsb.updateState - - t.Logf("update sub-balancer to stub-balancer") - edsb.handleChildPolicy(balancerName, nil) - - // Two localities, each with one backend. - clab1 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil) - clab1.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[:1], nil) - clab1.AddLocality(testSubZones[1], 1, 0, testEndpointAddrs[1:2], nil) - edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab1.Build())) - - for i := 0; i < 2; i++ { - sc := <-cc.NewSubConnCh - edsb.handleSubConnStateChange(sc, connectivity.Ready) - } - - p0 := <-cc.NewPickerCh - for i := 0; i < 5; i++ { - _, err := p0.Pick(balancer.PickInfo{}) - if err != testutils.ErrTestConstPicker { - t.Fatalf("picker.Pick, got err %+v, want err %+v", err, testutils.ErrTestConstPicker) - } - } - - t.Logf("update sub-balancer to round-robin") - edsb.handleChildPolicy(roundrobin.Name, nil) - - for i := 0; i < 2; i++ { - <-cc.RemoveSubConnCh - } - - sc1 := <-cc.NewSubConnCh - edsb.handleSubConnStateChange(sc1, connectivity.Connecting) - edsb.handleSubConnStateChange(sc1, connectivity.Ready) - sc2 := <-cc.NewSubConnCh - edsb.handleSubConnStateChange(sc2, connectivity.Connecting) - edsb.handleSubConnStateChange(sc2, connectivity.Ready) - - // Test roundrobin with two subconns. - p1 := <-cc.NewPickerCh - want := []balancer.SubConn{sc1, sc2} - if err := testutils.IsRoundRobin(want, subConnFromPicker(p1)); err != nil { - t.Fatalf("want %v, got %v", want, err) - } - - t.Logf("update sub-balancer to stub-balancer") - edsb.handleChildPolicy(balancerName, nil) - - for i := 0; i < 2; i++ { - scToRemove := <-cc.RemoveSubConnCh - if !cmp.Equal(scToRemove, sc1, cmp.AllowUnexported(testutils.TestSubConn{})) && - !cmp.Equal(scToRemove, sc2, cmp.AllowUnexported(testutils.TestSubConn{})) { - t.Fatalf("RemoveSubConn, want (%v or %v), got %v", sc1, sc2, scToRemove) - } - edsb.handleSubConnStateChange(scToRemove, connectivity.Shutdown) - } - - for i := 0; i < 2; i++ { - sc := <-cc.NewSubConnCh - edsb.handleSubConnStateChange(sc, connectivity.Ready) - } - - p2 := <-cc.NewPickerCh - for i := 0; i < 5; i++ { - _, err := p2.Pick(balancer.PickInfo{}) - if err != testutils.ErrTestConstPicker { - t.Fatalf("picker.Pick, got err %q, want err %q", err, testutils.ErrTestConstPicker) - } - } - - t.Logf("update sub-balancer to round-robin") - edsb.handleChildPolicy(roundrobin.Name, nil) - - for i := 0; i < 2; i++ { - <-cc.RemoveSubConnCh - } - - sc3 := <-cc.NewSubConnCh - edsb.handleSubConnStateChange(sc3, connectivity.Connecting) - edsb.handleSubConnStateChange(sc3, connectivity.Ready) - sc4 := <-cc.NewSubConnCh - edsb.handleSubConnStateChange(sc4, connectivity.Connecting) - edsb.handleSubConnStateChange(sc4, connectivity.Ready) - - p3 := <-cc.NewPickerCh - want = []balancer.SubConn{sc3, sc4} - if err := testutils.IsRoundRobin(want, subConnFromPicker(p3)); err != nil { - t.Fatalf("want %v, got %v", want, err) - } -} - -func (s) TestEDS_CircuitBreaking(t *testing.T) { - origCircuitBreakingSupport := env.CircuitBreakingSupport - env.CircuitBreakingSupport = true - defer func() { env.CircuitBreakingSupport = origCircuitBreakingSupport }() - - cc := testutils.NewTestClientConn(t) - edsb := newEDSBalancerImpl(cc, balancer.BuildOptions{}, nil, nil, nil) - edsb.enqueueChildBalancerStateUpdate = edsb.updateState - var maxRequests uint32 = 50 - edsb.updateServiceRequestsConfig("test", &maxRequests) - - // One locality with one backend. - clab1 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil) - clab1.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[:1], nil) - edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab1.Build())) - sc1 := <-cc.NewSubConnCh - edsb.handleSubConnStateChange(sc1, connectivity.Connecting) - edsb.handleSubConnStateChange(sc1, connectivity.Ready) - - // Picks with drops. - dones := []func(){} - p := <-cc.NewPickerCh - for i := 0; i < 100; i++ { - pr, err := p.Pick(balancer.PickInfo{}) - if i < 50 && err != nil { - t.Errorf("The first 50%% picks should be non-drops, got error %v", err) - } else if i > 50 && err == nil { - t.Errorf("The second 50%% picks should be drops, got error ") - } - dones = append(dones, func() { - if pr.Done != nil { - pr.Done(balancer.DoneInfo{}) - } - }) - } - - for _, done := range dones { - done() - } - dones = []func(){} - - // Pick without drops. - for i := 0; i < 50; i++ { - pr, err := p.Pick(balancer.PickInfo{}) - if err != nil { - t.Errorf("The third 50%% picks should be non-drops, got error %v", err) - } - dones = append(dones, func() { - if pr.Done != nil { - pr.Done(balancer.DoneInfo{}) - } - }) - } - - // Without this, future tests with the same service name will fail. - for _, done := range dones { - done() - } - - // Send another update, with only circuit breaking update (and no picker - // update afterwards). Make sure the new picker uses the new configs. - var maxRequests2 uint32 = 10 - edsb.updateServiceRequestsConfig("test", &maxRequests2) - - // Picks with drops. - dones = []func(){} - p2 := <-cc.NewPickerCh - for i := 0; i < 100; i++ { - pr, err := p2.Pick(balancer.PickInfo{}) - if i < 10 && err != nil { - t.Errorf("The first 10%% picks should be non-drops, got error %v", err) - } else if i > 10 && err == nil { - t.Errorf("The next 90%% picks should be drops, got error ") - } - dones = append(dones, func() { - if pr.Done != nil { - pr.Done(balancer.DoneInfo{}) - } - }) - } - - for _, done := range dones { - done() - } - dones = []func(){} - - // Pick without drops. - for i := 0; i < 10; i++ { - pr, err := p2.Pick(balancer.PickInfo{}) - if err != nil { - t.Errorf("The next 10%% picks should be non-drops, got error %v", err) - } - dones = append(dones, func() { - if pr.Done != nil { - pr.Done(balancer.DoneInfo{}) - } - }) - } - - // Without this, future tests with the same service name will fail. - for _, done := range dones { - done() - } -} - -func init() { - balancer.Register(&testInlineUpdateBalancerBuilder{}) -} - -// A test balancer that updates balancer.State inline when handling ClientConn -// state. -type testInlineUpdateBalancerBuilder struct{} - -func (*testInlineUpdateBalancerBuilder) Build(cc balancer.ClientConn, opts balancer.BuildOptions) balancer.Balancer { - return &testInlineUpdateBalancer{cc: cc} -} - -func (*testInlineUpdateBalancerBuilder) Name() string { - return "test-inline-update-balancer" -} - -type testInlineUpdateBalancer struct { - cc balancer.ClientConn -} - -func (tb *testInlineUpdateBalancer) ResolverError(error) { - panic("not implemented") -} - -func (tb *testInlineUpdateBalancer) UpdateSubConnState(balancer.SubConn, balancer.SubConnState) { -} - -var errTestInlineStateUpdate = fmt.Errorf("don't like addresses, empty or not") - -func (tb *testInlineUpdateBalancer) UpdateClientConnState(balancer.ClientConnState) error { - tb.cc.UpdateState(balancer.State{ - ConnectivityState: connectivity.Ready, - Picker: &testutils.TestConstPicker{Err: errTestInlineStateUpdate}, - }) - return nil -} - -func (*testInlineUpdateBalancer) Close() { -} - -// When the child policy update picker inline in a handleClientUpdate call -// (e.g., roundrobin handling empty addresses). There could be deadlock caused -// by acquiring a locked mutex. -func (s) TestEDS_ChildPolicyUpdatePickerInline(t *testing.T) { - cc := testutils.NewTestClientConn(t) - edsb := newEDSBalancerImpl(cc, balancer.BuildOptions{}, nil, nil, nil) - edsb.enqueueChildBalancerStateUpdate = func(p priorityType, state balancer.State) { - // For this test, euqueue needs to happen asynchronously (like in the - // real implementation). - go edsb.updateState(p, state) - } - - edsb.handleChildPolicy("test-inline-update-balancer", nil) - - clab1 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil) - clab1.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[:1], nil) - edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab1.Build())) - - p0 := <-cc.NewPickerCh - for i := 0; i < 5; i++ { - _, err := p0.Pick(balancer.PickInfo{}) - if err != errTestInlineStateUpdate { - t.Fatalf("picker.Pick, got err %q, want err %q", err, errTestInlineStateUpdate) - } - } -} - -func (s) TestDropPicker(t *testing.T) { - const pickCount = 12 - var constPicker = &testutils.TestConstPicker{ - SC: testutils.TestSubConns[0], - } - - tests := []struct { - name string - drops []*dropper - }{ - { - name: "no drop", - drops: nil, - }, - { - name: "one drop", - drops: []*dropper{ - newDropper(xdsclient.OverloadDropConfig{Numerator: 1, Denominator: 2}), - }, - }, - { - name: "two drops", - drops: []*dropper{ - newDropper(xdsclient.OverloadDropConfig{Numerator: 1, Denominator: 3}), - newDropper(xdsclient.OverloadDropConfig{Numerator: 1, Denominator: 2}), - }, - }, - { - name: "three drops", - drops: []*dropper{ - newDropper(xdsclient.OverloadDropConfig{Numerator: 1, Denominator: 3}), - newDropper(xdsclient.OverloadDropConfig{Numerator: 1, Denominator: 4}), - newDropper(xdsclient.OverloadDropConfig{Numerator: 1, Denominator: 2}), - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - - p := newDropPicker(constPicker, tt.drops, nil, nil, defaultServiceRequestCountMax) - - // scCount is the number of sc's returned by pick. The opposite of - // drop-count. - var ( - scCount int - wantCount = pickCount - ) - for _, dp := range tt.drops { - wantCount = wantCount * int(dp.c.Denominator-dp.c.Numerator) / int(dp.c.Denominator) - } - - for i := 0; i < pickCount; i++ { - _, err := p.Pick(balancer.PickInfo{}) - if err == nil { - scCount++ - } - } - - if scCount != (wantCount) { - t.Errorf("drops: %+v, scCount %v, wantCount %v", tt.drops, scCount, wantCount) - } - }) - } -} - -func (s) TestEDS_LoadReport(t *testing.T) { - origCircuitBreakingSupport := env.CircuitBreakingSupport - env.CircuitBreakingSupport = true - defer func() { env.CircuitBreakingSupport = origCircuitBreakingSupport }() - - // We create an xdsClientWrapper with a dummy xdsClientInterface which only - // implements the LoadStore() method to return the underlying load.Store to - // be used. - loadStore := load.NewStore() - lsWrapper := &loadStoreWrapper{} - lsWrapper.updateServiceName(testClusterNames[0]) - lsWrapper.updateLoadStore(loadStore) - - cc := testutils.NewTestClientConn(t) - edsb := newEDSBalancerImpl(cc, balancer.BuildOptions{}, nil, lsWrapper, nil) - edsb.enqueueChildBalancerStateUpdate = edsb.updateState - - const ( - testServiceName = "test-service" - cbMaxRequests = 20 - ) - var maxRequestsTemp uint32 = cbMaxRequests - edsb.updateServiceRequestsConfig(testServiceName, &maxRequestsTemp) - defer client.ClearCounterForTesting(testServiceName) - - backendToBalancerID := make(map[balancer.SubConn]internal.LocalityID) - - const testDropCategory = "test-drop" - // Two localities, each with one backend. - clab1 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], map[string]uint32{testDropCategory: 50}) - clab1.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[:1], nil) - edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab1.Build())) - sc1 := <-cc.NewSubConnCh - edsb.handleSubConnStateChange(sc1, connectivity.Connecting) - edsb.handleSubConnStateChange(sc1, connectivity.Ready) - locality1 := internal.LocalityID{SubZone: testSubZones[0]} - backendToBalancerID[sc1] = locality1 - - // Add the second locality later to make sure sc2 belongs to the second - // locality. Otherwise the test is flaky because of a map is used in EDS to - // keep localities. - clab1.AddLocality(testSubZones[1], 1, 0, testEndpointAddrs[1:2], nil) - edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab1.Build())) - sc2 := <-cc.NewSubConnCh - edsb.handleSubConnStateChange(sc2, connectivity.Connecting) - edsb.handleSubConnStateChange(sc2, connectivity.Ready) - locality2 := internal.LocalityID{SubZone: testSubZones[1]} - backendToBalancerID[sc2] = locality2 - - // Test roundrobin with two subconns. - p1 := <-cc.NewPickerCh - // We expect the 10 picks to be split between the localities since they are - // of equal weight. And since we only mark the picks routed to sc2 as done, - // the picks on sc1 should show up as inProgress. - locality1JSON, _ := locality1.ToString() - locality2JSON, _ := locality2.ToString() - const ( - rpcCount = 100 - // 50% will be dropped with category testDropCategory. - dropWithCategory = rpcCount / 2 - // In the remaining RPCs, only cbMaxRequests are allowed by circuit - // breaking. Others will be dropped by CB. - dropWithCB = rpcCount - dropWithCategory - cbMaxRequests - - rpcInProgress = cbMaxRequests / 2 // 50% of RPCs will be never done. - rpcSucceeded = cbMaxRequests / 2 // 50% of RPCs will succeed. - ) - wantStoreData := []*load.Data{{ - Cluster: testClusterNames[0], - Service: "", - LocalityStats: map[string]load.LocalityData{ - locality1JSON: {RequestStats: load.RequestData{InProgress: rpcInProgress}}, - locality2JSON: {RequestStats: load.RequestData{Succeeded: rpcSucceeded}}, - }, - TotalDrops: dropWithCategory + dropWithCB, - Drops: map[string]uint64{ - testDropCategory: dropWithCategory, - }, - }} - - var rpcsToBeDone []balancer.PickResult - // Run the picks, but only pick with sc1 will be done later. - for i := 0; i < rpcCount; i++ { - scst, _ := p1.Pick(balancer.PickInfo{}) - if scst.Done != nil && scst.SubConn != sc1 { - rpcsToBeDone = append(rpcsToBeDone, scst) - } - } - // Call done on those sc1 picks. - for _, scst := range rpcsToBeDone { - scst.Done(balancer.DoneInfo{}) - } - - gotStoreData := loadStore.Stats(testClusterNames[0:1]) - if diff := cmp.Diff(wantStoreData, gotStoreData, cmpopts.EquateEmpty(), cmpopts.IgnoreFields(load.Data{}, "ReportInterval")); diff != "" { - t.Errorf("store.stats() returned unexpected diff (-want +got):\n%s", diff) - } -} - -// TestEDS_LoadReportDisabled covers the case that LRS is disabled. It makes -// sure the EDS implementation isn't broken (doesn't panic). -func (s) TestEDS_LoadReportDisabled(t *testing.T) { - lsWrapper := &loadStoreWrapper{} - lsWrapper.updateServiceName(testClusterNames[0]) - // Not calling lsWrapper.updateLoadStore(loadStore) because LRS is disabled. - - cc := testutils.NewTestClientConn(t) - edsb := newEDSBalancerImpl(cc, balancer.BuildOptions{}, nil, lsWrapper, nil) - edsb.enqueueChildBalancerStateUpdate = edsb.updateState - - // One localities, with one backend. - clab1 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil) - clab1.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[:1], nil) - edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab1.Build())) - sc1 := <-cc.NewSubConnCh - edsb.handleSubConnStateChange(sc1, connectivity.Connecting) - edsb.handleSubConnStateChange(sc1, connectivity.Ready) - - // Test roundrobin with two subconns. - p1 := <-cc.NewPickerCh - // We call picks to make sure they don't panic. - for i := 0; i < 10; i++ { - p1.Pick(balancer.PickInfo{}) - } -} diff --git a/xds/internal/balancer/edsbalancer/eds_test.go b/xds/internal/balancer/edsbalancer/eds_test.go deleted file mode 100644 index 5fe1f2ef6b9..00000000000 --- a/xds/internal/balancer/edsbalancer/eds_test.go +++ /dev/null @@ -1,825 +0,0 @@ -/* - * - * Copyright 2019 gRPC authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - */ - -package edsbalancer - -import ( - "bytes" - "context" - "encoding/json" - "fmt" - "reflect" - "testing" - "time" - - "github.com/golang/protobuf/jsonpb" - wrapperspb "github.com/golang/protobuf/ptypes/wrappers" - "github.com/google/go-cmp/cmp" - "google.golang.org/grpc/balancer" - "google.golang.org/grpc/connectivity" - "google.golang.org/grpc/internal/grpclog" - "google.golang.org/grpc/internal/grpctest" - scpb "google.golang.org/grpc/internal/proto/grpc_service_config" - "google.golang.org/grpc/internal/testutils" - "google.golang.org/grpc/resolver" - "google.golang.org/grpc/serviceconfig" - "google.golang.org/grpc/xds/internal" - xdsclient "google.golang.org/grpc/xds/internal/client" - "google.golang.org/grpc/xds/internal/client/load" - "google.golang.org/grpc/xds/internal/testutils/fakeclient" - - _ "google.golang.org/grpc/xds/internal/client/v2" // V2 client registration. -) - -const ( - defaultTestTimeout = 1 * time.Second - defaultTestShortTimeout = 10 * time.Millisecond - testServiceName = "test/foo" - testEDSClusterName = "test/service/eds" -) - -var ( - // A non-empty endpoints update which is expected to be accepted by the EDS - // LB policy. - defaultEndpointsUpdate = xdsclient.EndpointsUpdate{ - Localities: []xdsclient.Locality{ - { - Endpoints: []xdsclient.Endpoint{{Address: "endpoint1"}}, - ID: internal.LocalityID{Zone: "zone"}, - Priority: 1, - Weight: 100, - }, - }, - } -) - -func init() { - balancer.Register(&edsBalancerBuilder{}) -} - -func subConnFromPicker(p balancer.Picker) func() balancer.SubConn { - return func() balancer.SubConn { - scst, _ := p.Pick(balancer.PickInfo{}) - return scst.SubConn - } -} - -type s struct { - grpctest.Tester -} - -func Test(t *testing.T) { - grpctest.RunSubTests(t, s{}) -} - -const testBalancerNameFooBar = "foo.bar" - -func newNoopTestClientConn() *noopTestClientConn { - return &noopTestClientConn{} -} - -// noopTestClientConn is used in EDS balancer config update tests that only -// cover the config update handling, but not SubConn/load-balancing. -type noopTestClientConn struct { - balancer.ClientConn -} - -func (t *noopTestClientConn) NewSubConn([]resolver.Address, balancer.NewSubConnOptions) (balancer.SubConn, error) { - return nil, nil -} - -func (noopTestClientConn) Target() string { return testServiceName } - -type scStateChange struct { - sc balancer.SubConn - state connectivity.State -} - -type fakeEDSBalancer struct { - cc balancer.ClientConn - childPolicy *testutils.Channel - subconnStateChange *testutils.Channel - edsUpdate *testutils.Channel - serviceName *testutils.Channel - serviceRequestMax *testutils.Channel -} - -func (f *fakeEDSBalancer) handleSubConnStateChange(sc balancer.SubConn, state connectivity.State) { - f.subconnStateChange.Send(&scStateChange{sc: sc, state: state}) -} - -func (f *fakeEDSBalancer) handleChildPolicy(name string, config json.RawMessage) { - f.childPolicy.Send(&loadBalancingConfig{Name: name, Config: config}) -} - -func (f *fakeEDSBalancer) handleEDSResponse(edsResp xdsclient.EndpointsUpdate) { - f.edsUpdate.Send(edsResp) -} - -func (f *fakeEDSBalancer) updateState(priority priorityType, s balancer.State) {} - -func (f *fakeEDSBalancer) updateServiceRequestsConfig(serviceName string, max *uint32) { - f.serviceName.Send(serviceName) - f.serviceRequestMax.Send(max) -} - -func (f *fakeEDSBalancer) close() {} - -func (f *fakeEDSBalancer) waitForChildPolicy(ctx context.Context, wantPolicy *loadBalancingConfig) error { - val, err := f.childPolicy.Receive(ctx) - if err != nil { - return err - } - gotPolicy := val.(*loadBalancingConfig) - if !cmp.Equal(gotPolicy, wantPolicy) { - return fmt.Errorf("got childPolicy %v, want %v", gotPolicy, wantPolicy) - } - return nil -} - -func (f *fakeEDSBalancer) waitForSubConnStateChange(ctx context.Context, wantState *scStateChange) error { - val, err := f.subconnStateChange.Receive(ctx) - if err != nil { - return err - } - gotState := val.(*scStateChange) - if !cmp.Equal(gotState, wantState, cmp.AllowUnexported(scStateChange{})) { - return fmt.Errorf("got subconnStateChange %v, want %v", gotState, wantState) - } - return nil -} - -func (f *fakeEDSBalancer) waitForEDSResponse(ctx context.Context, wantUpdate xdsclient.EndpointsUpdate) error { - val, err := f.edsUpdate.Receive(ctx) - if err != nil { - return err - } - gotUpdate := val.(xdsclient.EndpointsUpdate) - if !reflect.DeepEqual(gotUpdate, wantUpdate) { - return fmt.Errorf("got edsUpdate %+v, want %+v", gotUpdate, wantUpdate) - } - return nil -} - -func (f *fakeEDSBalancer) waitForCounterUpdate(ctx context.Context, wantServiceName string) error { - val, err := f.serviceName.Receive(ctx) - if err != nil { - return err - } - gotServiceName := val.(string) - if gotServiceName != wantServiceName { - return fmt.Errorf("got serviceName %v, want %v", gotServiceName, wantServiceName) - } - return nil -} - -func (f *fakeEDSBalancer) waitForCountMaxUpdate(ctx context.Context, want *uint32) error { - val, err := f.serviceRequestMax.Receive(ctx) - if err != nil { - return err - } - got := val.(*uint32) - - if got == nil && want == nil { - return nil - } - if got != nil && want != nil { - if *got != *want { - return fmt.Errorf("got countMax %v, want %v", *got, *want) - } - return nil - } - return fmt.Errorf("got countMax %+v, want %+v", got, want) -} - -func newFakeEDSBalancer(cc balancer.ClientConn) edsBalancerImplInterface { - return &fakeEDSBalancer{ - cc: cc, - childPolicy: testutils.NewChannelWithSize(10), - subconnStateChange: testutils.NewChannelWithSize(10), - edsUpdate: testutils.NewChannelWithSize(10), - serviceName: testutils.NewChannelWithSize(10), - serviceRequestMax: testutils.NewChannelWithSize(10), - } -} - -type fakeSubConn struct{} - -func (*fakeSubConn) UpdateAddresses([]resolver.Address) { panic("implement me") } -func (*fakeSubConn) Connect() { panic("implement me") } - -// waitForNewEDSLB makes sure that a new edsLB is created by the top-level -// edsBalancer. -func waitForNewEDSLB(ctx context.Context, ch *testutils.Channel) (*fakeEDSBalancer, error) { - val, err := ch.Receive(ctx) - if err != nil { - return nil, fmt.Errorf("error when waiting for a new edsLB: %v", err) - } - return val.(*fakeEDSBalancer), nil -} - -// setup overrides the functions which are used to create the xdsClient and the -// edsLB, creates fake version of them and makes them available on the provided -// channels. The returned cancel function should be called by the test for -// cleanup. -func setup(edsLBCh *testutils.Channel) (*fakeclient.Client, func()) { - xdsC := fakeclient.NewClientWithName(testBalancerNameFooBar) - oldNewXDSClient := newXDSClient - newXDSClient = func() (xdsClientInterface, error) { return xdsC, nil } - - origNewEDSBalancer := newEDSBalancer - newEDSBalancer = func(cc balancer.ClientConn, _ balancer.BuildOptions, _ func(priorityType, balancer.State), _ load.PerClusterReporter, _ *grpclog.PrefixLogger) edsBalancerImplInterface { - edsLB := newFakeEDSBalancer(cc) - defer func() { edsLBCh.Send(edsLB) }() - return edsLB - } - return xdsC, func() { - newEDSBalancer = origNewEDSBalancer - newXDSClient = oldNewXDSClient - } -} - -const ( - fakeBalancerA = "fake_balancer_A" - fakeBalancerB = "fake_balancer_B" -) - -// Install two fake balancers for service config update tests. -// -// ParseConfig only accepts the json if the balancer specified is registered. -func init() { - balancer.Register(&fakeBalancerBuilder{name: fakeBalancerA}) - balancer.Register(&fakeBalancerBuilder{name: fakeBalancerB}) -} - -type fakeBalancerBuilder struct { - name string -} - -func (b *fakeBalancerBuilder) Build(cc balancer.ClientConn, opts balancer.BuildOptions) balancer.Balancer { - return &fakeBalancer{cc: cc} -} - -func (b *fakeBalancerBuilder) Name() string { - return b.name -} - -type fakeBalancer struct { - cc balancer.ClientConn -} - -func (b *fakeBalancer) ResolverError(error) { - panic("implement me") -} - -func (b *fakeBalancer) UpdateClientConnState(balancer.ClientConnState) error { - panic("implement me") -} - -func (b *fakeBalancer) UpdateSubConnState(balancer.SubConn, balancer.SubConnState) { - panic("implement me") -} - -func (b *fakeBalancer) Close() {} - -// TestConfigChildPolicyUpdate verifies scenarios where the childPolicy -// section of the lbConfig is updated. -// -// The test does the following: -// * Builds a new EDS balancer. -// * Pushes a new ClientConnState with a childPolicy set to fakeBalancerA. -// Verifies that an EDS watch is registered. It then pushes a new edsUpdate -// through the fakexds client. Verifies that a new edsLB is created and it -// receives the expected childPolicy. -// * Pushes a new ClientConnState with a childPolicy set to fakeBalancerB. -// Verifies that the existing edsLB receives the new child policy. -func (s) TestConfigChildPolicyUpdate(t *testing.T) { - edsLBCh := testutils.NewChannel() - xdsC, cleanup := setup(edsLBCh) - defer cleanup() - - builder := balancer.Get(edsName) - edsB := builder.Build(newNoopTestClientConn(), balancer.BuildOptions{Target: resolver.Target{Endpoint: testServiceName}}) - if edsB == nil { - t.Fatalf("builder.Build(%s) failed and returned nil", edsName) - } - defer edsB.Close() - - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() - edsLB, err := waitForNewEDSLB(ctx, edsLBCh) - if err != nil { - t.Fatal(err) - } - - lbCfgA := &loadBalancingConfig{ - Name: fakeBalancerA, - Config: json.RawMessage("{}"), - } - if err := edsB.UpdateClientConnState(balancer.ClientConnState{ - BalancerConfig: &EDSConfig{ - ChildPolicy: lbCfgA, - EDSServiceName: testServiceName, - }, - }); err != nil { - t.Fatalf("edsB.UpdateClientConnState() failed: %v", err) - } - - if _, err := xdsC.WaitForWatchEDS(ctx); err != nil { - t.Fatalf("xdsClient.WatchEndpoints failed with error: %v", err) - } - xdsC.InvokeWatchEDSCallback(defaultEndpointsUpdate, nil) - if err := edsLB.waitForChildPolicy(ctx, lbCfgA); err != nil { - t.Fatal(err) - } - if err := edsLB.waitForCounterUpdate(ctx, testServiceName); err != nil { - t.Fatal(err) - } - if err := edsLB.waitForCountMaxUpdate(ctx, nil); err != nil { - t.Fatal(err) - } - - var testCountMax uint32 = 100 - lbCfgB := &loadBalancingConfig{ - Name: fakeBalancerB, - Config: json.RawMessage("{}"), - } - if err := edsB.UpdateClientConnState(balancer.ClientConnState{ - BalancerConfig: &EDSConfig{ - ChildPolicy: lbCfgB, - EDSServiceName: testServiceName, - MaxConcurrentRequests: &testCountMax, - }, - }); err != nil { - t.Fatalf("edsB.UpdateClientConnState() failed: %v", err) - } - if err := edsLB.waitForChildPolicy(ctx, lbCfgB); err != nil { - t.Fatal(err) - } - if err := edsLB.waitForCounterUpdate(ctx, testServiceName); err != nil { - // Counter is updated even though the service name didn't change. The - // eds_impl will compare the service names, and skip if it didn't change. - t.Fatal(err) - } - if err := edsLB.waitForCountMaxUpdate(ctx, &testCountMax); err != nil { - t.Fatal(err) - } -} - -// TestSubConnStateChange verifies if the top-level edsBalancer passes on -// the subConnStateChange to appropriate child balancer. -func (s) TestSubConnStateChange(t *testing.T) { - edsLBCh := testutils.NewChannel() - xdsC, cleanup := setup(edsLBCh) - defer cleanup() - - builder := balancer.Get(edsName) - edsB := builder.Build(newNoopTestClientConn(), balancer.BuildOptions{Target: resolver.Target{Endpoint: testServiceName}}) - if edsB == nil { - t.Fatalf("builder.Build(%s) failed and returned nil", edsName) - } - defer edsB.Close() - - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() - edsLB, err := waitForNewEDSLB(ctx, edsLBCh) - if err != nil { - t.Fatal(err) - } - - if err := edsB.UpdateClientConnState(balancer.ClientConnState{ - BalancerConfig: &EDSConfig{EDSServiceName: testServiceName}, - }); err != nil { - t.Fatalf("edsB.UpdateClientConnState() failed: %v", err) - } - - if _, err := xdsC.WaitForWatchEDS(ctx); err != nil { - t.Fatalf("xdsClient.WatchEndpoints failed with error: %v", err) - } - xdsC.InvokeWatchEDSCallback(defaultEndpointsUpdate, nil) - - fsc := &fakeSubConn{} - state := connectivity.Ready - edsB.UpdateSubConnState(fsc, balancer.SubConnState{ConnectivityState: state}) - if err := edsLB.waitForSubConnStateChange(ctx, &scStateChange{sc: fsc, state: state}); err != nil { - t.Fatal(err) - } -} - -// TestErrorFromXDSClientUpdate verifies that an error from xdsClient update is -// handled correctly. -// -// If it's resource-not-found, watch will NOT be canceled, the EDS impl will -// receive an empty EDS update, and new RPCs will fail. -// -// If it's connection error, nothing will happen. This will need to change to -// handle fallback. -func (s) TestErrorFromXDSClientUpdate(t *testing.T) { - edsLBCh := testutils.NewChannel() - xdsC, cleanup := setup(edsLBCh) - defer cleanup() - - builder := balancer.Get(edsName) - edsB := builder.Build(newNoopTestClientConn(), balancer.BuildOptions{Target: resolver.Target{Endpoint: testServiceName}}) - if edsB == nil { - t.Fatalf("builder.Build(%s) failed and returned nil", edsName) - } - defer edsB.Close() - - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() - edsLB, err := waitForNewEDSLB(ctx, edsLBCh) - if err != nil { - t.Fatal(err) - } - - if err := edsB.UpdateClientConnState(balancer.ClientConnState{ - BalancerConfig: &EDSConfig{EDSServiceName: testServiceName}, - }); err != nil { - t.Fatal(err) - } - - if _, err := xdsC.WaitForWatchEDS(ctx); err != nil { - t.Fatalf("xdsClient.WatchEndpoints failed with error: %v", err) - } - xdsC.InvokeWatchEDSCallback(xdsclient.EndpointsUpdate{}, nil) - if err := edsLB.waitForEDSResponse(ctx, xdsclient.EndpointsUpdate{}); err != nil { - t.Fatalf("EDS impl got unexpected EDS response: %v", err) - } - - connectionErr := xdsclient.NewErrorf(xdsclient.ErrorTypeConnection, "connection error") - xdsC.InvokeWatchEDSCallback(xdsclient.EndpointsUpdate{}, connectionErr) - - sCtx, sCancel := context.WithTimeout(context.Background(), defaultTestShortTimeout) - defer sCancel() - if err := xdsC.WaitForCancelEDSWatch(sCtx); err != context.DeadlineExceeded { - t.Fatal("watch was canceled, want not canceled (timeout error)") - } - - sCtx, sCancel = context.WithTimeout(context.Background(), defaultTestShortTimeout) - defer sCancel() - if err := edsLB.waitForEDSResponse(sCtx, xdsclient.EndpointsUpdate{}); err != context.DeadlineExceeded { - t.Fatal(err) - } - - resourceErr := xdsclient.NewErrorf(xdsclient.ErrorTypeResourceNotFound, "edsBalancer resource not found error") - xdsC.InvokeWatchEDSCallback(xdsclient.EndpointsUpdate{}, resourceErr) - // Even if error is resource not found, watch shouldn't be canceled, because - // this is an EDS resource removed (and xds client actually never sends this - // error, but we still handles it). - sCtx, sCancel = context.WithTimeout(context.Background(), defaultTestShortTimeout) - defer sCancel() - if err := xdsC.WaitForCancelEDSWatch(sCtx); err != context.DeadlineExceeded { - t.Fatal("watch was canceled, want not canceled (timeout error)") - } - if err := edsLB.waitForEDSResponse(ctx, xdsclient.EndpointsUpdate{}); err != nil { - t.Fatalf("eds impl expecting empty update, got %v", err) - } -} - -// TestErrorFromResolver verifies that resolver errors are handled correctly. -// -// If it's resource-not-found, watch will be canceled, the EDS impl will receive -// an empty EDS update, and new RPCs will fail. -// -// If it's connection error, nothing will happen. This will need to change to -// handle fallback. -func (s) TestErrorFromResolver(t *testing.T) { - edsLBCh := testutils.NewChannel() - xdsC, cleanup := setup(edsLBCh) - defer cleanup() - - builder := balancer.Get(edsName) - edsB := builder.Build(newNoopTestClientConn(), balancer.BuildOptions{Target: resolver.Target{Endpoint: testServiceName}}) - if edsB == nil { - t.Fatalf("builder.Build(%s) failed and returned nil", edsName) - } - defer edsB.Close() - - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() - edsLB, err := waitForNewEDSLB(ctx, edsLBCh) - if err != nil { - t.Fatal(err) - } - - if err := edsB.UpdateClientConnState(balancer.ClientConnState{ - BalancerConfig: &EDSConfig{EDSServiceName: testServiceName}, - }); err != nil { - t.Fatal(err) - } - - if _, err := xdsC.WaitForWatchEDS(ctx); err != nil { - t.Fatalf("xdsClient.WatchEndpoints failed with error: %v", err) - } - xdsC.InvokeWatchEDSCallback(xdsclient.EndpointsUpdate{}, nil) - if err := edsLB.waitForEDSResponse(ctx, xdsclient.EndpointsUpdate{}); err != nil { - t.Fatalf("EDS impl got unexpected EDS response: %v", err) - } - - connectionErr := xdsclient.NewErrorf(xdsclient.ErrorTypeConnection, "connection error") - edsB.ResolverError(connectionErr) - - sCtx, sCancel := context.WithTimeout(context.Background(), defaultTestShortTimeout) - defer sCancel() - if err := xdsC.WaitForCancelEDSWatch(sCtx); err != context.DeadlineExceeded { - t.Fatal("watch was canceled, want not canceled (timeout error)") - } - - sCtx, sCancel = context.WithTimeout(context.Background(), defaultTestShortTimeout) - defer sCancel() - if err := edsLB.waitForEDSResponse(sCtx, xdsclient.EndpointsUpdate{}); err != context.DeadlineExceeded { - t.Fatal("eds impl got EDS resp, want timeout error") - } - - resourceErr := xdsclient.NewErrorf(xdsclient.ErrorTypeResourceNotFound, "edsBalancer resource not found error") - edsB.ResolverError(resourceErr) - if err := xdsC.WaitForCancelEDSWatch(ctx); err != nil { - t.Fatalf("want watch to be canceled, waitForCancel failed: %v", err) - } - if err := edsLB.waitForEDSResponse(ctx, xdsclient.EndpointsUpdate{}); err != nil { - t.Fatalf("EDS impl got unexpected EDS response: %v", err) - } -} - -// Given a list of resource names, verifies that EDS requests for the same are -// sent by the EDS balancer, through the fake xDS client. -func verifyExpectedRequests(ctx context.Context, fc *fakeclient.Client, resourceNames ...string) error { - for _, name := range resourceNames { - if name == "" { - // ResourceName empty string indicates a cancel. - if err := fc.WaitForCancelEDSWatch(ctx); err != nil { - return fmt.Errorf("timed out when expecting resource %q", name) - } - return nil - } - - resName, err := fc.WaitForWatchEDS(ctx) - if err != nil { - return fmt.Errorf("timed out when expecting resource %q, %p", name, fc) - } - if resName != name { - return fmt.Errorf("got EDS request for resource %q, expected: %q", resName, name) - } - } - return nil -} - -// TestClientWatchEDS verifies that the xdsClient inside the top-level EDS LB -// policy registers an EDS watch for expected resource upon receiving an update -// from gRPC. -func (s) TestClientWatchEDS(t *testing.T) { - edsLBCh := testutils.NewChannel() - xdsC, cleanup := setup(edsLBCh) - defer cleanup() - - builder := balancer.Get(edsName) - edsB := builder.Build(newNoopTestClientConn(), balancer.BuildOptions{Target: resolver.Target{Endpoint: testServiceName}}) - if edsB == nil { - t.Fatalf("builder.Build(%s) failed and returned nil", edsName) - } - defer edsB.Close() - - // Update with an non-empty edsServiceName should trigger an EDS watch for - // the same. - if err := edsB.UpdateClientConnState(balancer.ClientConnState{ - BalancerConfig: &EDSConfig{EDSServiceName: "foobar-1"}, - }); err != nil { - t.Fatal(err) - } - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() - if err := verifyExpectedRequests(ctx, xdsC, "foobar-1"); err != nil { - t.Fatal(err) - } - - // Also test the case where the edsServerName changes from one non-empty - // name to another, and make sure a new watch is registered. The previously - // registered watch will be cancelled, which will result in an EDS request - // with no resource names being sent to the server. - if err := edsB.UpdateClientConnState(balancer.ClientConnState{ - BalancerConfig: &EDSConfig{EDSServiceName: "foobar-2"}, - }); err != nil { - t.Fatal(err) - } - if err := verifyExpectedRequests(ctx, xdsC, "", "foobar-2"); err != nil { - t.Fatal(err) - } -} - -// TestCounterUpdate verifies that the counter update is triggered with the -// service name from an update's config. -func (s) TestCounterUpdate(t *testing.T) { - edsLBCh := testutils.NewChannel() - _, cleanup := setup(edsLBCh) - defer cleanup() - - builder := balancer.Get(edsName) - edsB := builder.Build(newNoopTestClientConn(), balancer.BuildOptions{Target: resolver.Target{Endpoint: testServiceName}}) - if edsB == nil { - t.Fatalf("builder.Build(%s) failed and returned nil", edsName) - } - defer edsB.Close() - - var testCountMax uint32 = 100 - // Update should trigger counter update with provided service name. - if err := edsB.UpdateClientConnState(balancer.ClientConnState{ - BalancerConfig: &EDSConfig{ - EDSServiceName: "foobar-1", - MaxConcurrentRequests: &testCountMax, - }, - }); err != nil { - t.Fatal(err) - } - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() - edsI := edsB.(*edsBalancer).edsImpl.(*fakeEDSBalancer) - if err := edsI.waitForCounterUpdate(ctx, "foobar-1"); err != nil { - t.Fatal(err) - } - if err := edsI.waitForCountMaxUpdate(ctx, &testCountMax); err != nil { - t.Fatal(err) - } -} - -func (s) TestBalancerConfigParsing(t *testing.T) { - const testEDSName = "eds.service" - var testLRSName = "lrs.server" - b := bytes.NewBuffer(nil) - if err := (&jsonpb.Marshaler{}).Marshal(b, &scpb.XdsConfig{ - ChildPolicy: []*scpb.LoadBalancingConfig{ - {Policy: &scpb.LoadBalancingConfig_Xds{}}, - {Policy: &scpb.LoadBalancingConfig_RoundRobin{ - RoundRobin: &scpb.RoundRobinConfig{}, - }}, - }, - FallbackPolicy: []*scpb.LoadBalancingConfig{ - {Policy: &scpb.LoadBalancingConfig_Xds{}}, - {Policy: &scpb.LoadBalancingConfig_PickFirst{ - PickFirst: &scpb.PickFirstConfig{}, - }}, - }, - EdsServiceName: testEDSName, - LrsLoadReportingServerName: &wrapperspb.StringValue{Value: testLRSName}, - }); err != nil { - t.Fatalf("%v", err) - } - - var testMaxConcurrentRequests uint32 = 123 - tests := []struct { - name string - js json.RawMessage - want serviceconfig.LoadBalancingConfig - wantErr bool - }{ - { - name: "bad json", - js: json.RawMessage(`i am not JSON`), - wantErr: true, - }, - { - name: "empty", - js: json.RawMessage(`{}`), - want: &EDSConfig{}, - }, - { - name: "jsonpb-generated", - js: b.Bytes(), - want: &EDSConfig{ - ChildPolicy: &loadBalancingConfig{ - Name: "round_robin", - Config: json.RawMessage("{}"), - }, - FallBackPolicy: &loadBalancingConfig{ - Name: "pick_first", - Config: json.RawMessage("{}"), - }, - EDSServiceName: testEDSName, - LrsLoadReportingServerName: &testLRSName, - }, - }, - { - // json with random balancers, and the first is not registered. - name: "manually-generated", - js: json.RawMessage(` -{ - "childPolicy": [ - {"fake_balancer_C": {}}, - {"fake_balancer_A": {}}, - {"fake_balancer_B": {}} - ], - "fallbackPolicy": [ - {"fake_balancer_C": {}}, - {"fake_balancer_B": {}}, - {"fake_balancer_A": {}} - ], - "edsServiceName": "eds.service", - "maxConcurrentRequests": 123, - "lrsLoadReportingServerName": "lrs.server" -}`), - want: &EDSConfig{ - ChildPolicy: &loadBalancingConfig{ - Name: "fake_balancer_A", - Config: json.RawMessage("{}"), - }, - FallBackPolicy: &loadBalancingConfig{ - Name: "fake_balancer_B", - Config: json.RawMessage("{}"), - }, - EDSServiceName: testEDSName, - MaxConcurrentRequests: &testMaxConcurrentRequests, - LrsLoadReportingServerName: &testLRSName, - }, - }, - { - // json with no lrs server name, LrsLoadReportingServerName should - // be nil (not an empty string). - name: "no-lrs-server-name", - js: json.RawMessage(` -{ - "edsServiceName": "eds.service" -}`), - want: &EDSConfig{ - EDSServiceName: testEDSName, - LrsLoadReportingServerName: nil, - }, - }, - { - name: "good child policy", - js: json.RawMessage(`{"childPolicy":[{"pick_first":{}}]}`), - want: &EDSConfig{ - ChildPolicy: &loadBalancingConfig{ - Name: "pick_first", - Config: json.RawMessage(`{}`), - }, - }, - }, - { - name: "multiple good child policies", - js: json.RawMessage(`{"childPolicy":[{"round_robin":{}},{"pick_first":{}}]}`), - want: &EDSConfig{ - ChildPolicy: &loadBalancingConfig{ - Name: "round_robin", - Config: json.RawMessage(`{}`), - }, - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - b := &edsBalancerBuilder{} - got, err := b.ParseConfig(tt.js) - if (err != nil) != tt.wantErr { - t.Fatalf("edsBalancerBuilder.ParseConfig() error = %v, wantErr %v", err, tt.wantErr) - } - if tt.wantErr { - return - } - if !cmp.Equal(got, tt.want) { - t.Errorf(cmp.Diff(got, tt.want)) - } - }) - } -} - -func (s) TestEqualStringPointers(t *testing.T) { - var ( - ta1 = "test-a" - ta2 = "test-a" - tb = "test-b" - ) - tests := []struct { - name string - a *string - b *string - want bool - }{ - {"both-nil", nil, nil, true}, - {"a-non-nil", &ta1, nil, false}, - {"b-non-nil", nil, &tb, false}, - {"equal", &ta1, &ta2, true}, - {"different", &ta1, &tb, false}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := equalStringPointers(tt.a, tt.b); got != tt.want { - t.Errorf("equalStringPointers() = %v, want %v", got, tt.want) - } - }) - } -} diff --git a/xds/internal/balancer/edsbalancer/load_store_wrapper.go b/xds/internal/balancer/edsbalancer/load_store_wrapper.go deleted file mode 100644 index 18904e47a42..00000000000 --- a/xds/internal/balancer/edsbalancer/load_store_wrapper.go +++ /dev/null @@ -1,88 +0,0 @@ -/* - * - * Copyright 2019 gRPC authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - */ - -package edsbalancer - -import ( - "sync" - - "google.golang.org/grpc/xds/internal/client/load" -) - -type loadStoreWrapper struct { - mu sync.RWMutex - service string - // Both store and perCluster will be nil if load reporting is disabled (EDS - // response doesn't have LRS server name). Note that methods on Store and - // perCluster all handle nil, so there's no need to check nil before calling - // them. - store *load.Store - perCluster load.PerClusterReporter -} - -func (lsw *loadStoreWrapper) updateServiceName(service string) { - lsw.mu.Lock() - defer lsw.mu.Unlock() - if lsw.service == service { - return - } - lsw.service = service - lsw.perCluster = lsw.store.PerCluster(lsw.service, "") -} - -func (lsw *loadStoreWrapper) updateLoadStore(store *load.Store) { - lsw.mu.Lock() - defer lsw.mu.Unlock() - if store == lsw.store { - return - } - lsw.store = store - lsw.perCluster = lsw.store.PerCluster(lsw.service, "") -} - -func (lsw *loadStoreWrapper) CallStarted(locality string) { - lsw.mu.RLock() - defer lsw.mu.RUnlock() - if lsw.perCluster != nil { - lsw.perCluster.CallStarted(locality) - } -} - -func (lsw *loadStoreWrapper) CallFinished(locality string, err error) { - lsw.mu.RLock() - defer lsw.mu.RUnlock() - if lsw.perCluster != nil { - lsw.perCluster.CallFinished(locality, err) - } -} - -func (lsw *loadStoreWrapper) CallServerLoad(locality, name string, val float64) { - lsw.mu.RLock() - defer lsw.mu.RUnlock() - if lsw.perCluster != nil { - lsw.perCluster.CallServerLoad(locality, name, val) - } -} - -func (lsw *loadStoreWrapper) CallDropped(category string) { - lsw.mu.RLock() - defer lsw.mu.RUnlock() - if lsw.perCluster != nil { - lsw.perCluster.CallDropped(category) - } -} diff --git a/xds/internal/balancer/edsbalancer/util.go b/xds/internal/balancer/edsbalancer/util.go deleted file mode 100644 index 13295042646..00000000000 --- a/xds/internal/balancer/edsbalancer/util.go +++ /dev/null @@ -1,44 +0,0 @@ -/* - * Copyright 2019 gRPC authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package edsbalancer - -import ( - "google.golang.org/grpc/internal/wrr" - xdsclient "google.golang.org/grpc/xds/internal/client" -) - -var newRandomWRR = wrr.NewRandom - -type dropper struct { - c xdsclient.OverloadDropConfig - w wrr.WRR -} - -func newDropper(c xdsclient.OverloadDropConfig) *dropper { - w := newRandomWRR() - w.Add(true, int64(c.Numerator)) - w.Add(false, int64(c.Denominator-c.Numerator)) - - return &dropper{ - c: c, - w: w, - } -} - -func (d *dropper) drop() (ret bool) { - return d.w.Next().(bool) -} diff --git a/xds/internal/balancer/edsbalancer/util_test.go b/xds/internal/balancer/edsbalancer/util_test.go deleted file mode 100644 index 748aeffe2bb..00000000000 --- a/xds/internal/balancer/edsbalancer/util_test.go +++ /dev/null @@ -1,88 +0,0 @@ -/* - * Copyright 2019 gRPC authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package edsbalancer - -import ( - "testing" - - xdsclient "google.golang.org/grpc/xds/internal/client" - "google.golang.org/grpc/xds/internal/testutils" -) - -func init() { - newRandomWRR = testutils.NewTestWRR -} - -func (s) TestDropper(t *testing.T) { - const repeat = 2 - - type args struct { - numerator uint32 - denominator uint32 - } - tests := []struct { - name string - args args - }{ - { - name: "2_3", - args: args{ - numerator: 2, - denominator: 3, - }, - }, - { - name: "4_8", - args: args{ - numerator: 4, - denominator: 8, - }, - }, - { - name: "7_20", - args: args{ - numerator: 7, - denominator: 20, - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - d := newDropper(xdsclient.OverloadDropConfig{ - Category: "", - Numerator: tt.args.numerator, - Denominator: tt.args.denominator, - }) - - var ( - dCount int - wantCount = int(tt.args.numerator) * repeat - loopCount = int(tt.args.denominator) * repeat - ) - for i := 0; i < loopCount; i++ { - if d.drop() { - dCount++ - } - } - - if dCount != (wantCount) { - t.Errorf("with numerator %v, denominator %v repeat %v, got drop count: %v, want %v", - tt.args.numerator, tt.args.denominator, repeat, dCount, wantCount) - } - }) - } -} diff --git a/xds/internal/balancer/edsbalancer/xds_lrs_test.go b/xds/internal/balancer/edsbalancer/xds_lrs_test.go deleted file mode 100644 index 9f93e0b42f0..00000000000 --- a/xds/internal/balancer/edsbalancer/xds_lrs_test.go +++ /dev/null @@ -1,71 +0,0 @@ -/* - * - * Copyright 2019 gRPC authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - */ - -package edsbalancer - -import ( - "context" - "testing" - - "google.golang.org/grpc/balancer" - "google.golang.org/grpc/xds/internal/testutils/fakeclient" -) - -// TestXDSLoadReporting verifies that the edsBalancer starts the loadReport -// stream when the lbConfig passed to it contains a valid value for the LRS -// server (empty string). -func (s) TestXDSLoadReporting(t *testing.T) { - xdsC := fakeclient.NewClient() - oldNewXDSClient := newXDSClient - newXDSClient = func() (xdsClientInterface, error) { return xdsC, nil } - defer func() { newXDSClient = oldNewXDSClient }() - - builder := balancer.Get(edsName) - edsB := builder.Build(newNoopTestClientConn(), balancer.BuildOptions{}) - if edsB == nil { - t.Fatalf("builder.Build(%s) failed and returned nil", edsName) - } - defer edsB.Close() - - if err := edsB.UpdateClientConnState(balancer.ClientConnState{ - BalancerConfig: &EDSConfig{ - EDSServiceName: testEDSClusterName, - LrsLoadReportingServerName: new(string), - }, - }); err != nil { - t.Fatal(err) - } - - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() - gotCluster, err := xdsC.WaitForWatchEDS(ctx) - if err != nil { - t.Fatalf("xdsClient.WatchEndpoints failed with error: %v", err) - } - if gotCluster != testEDSClusterName { - t.Fatalf("xdsClient.WatchEndpoints() called with cluster: %v, want %v", gotCluster, testEDSClusterName) - } - - got, err := xdsC.WaitForReportLoad(ctx) - if err != nil { - t.Fatalf("xdsClient.ReportLoad failed with error: %v", err) - } - if got.Server != "" { - t.Fatalf("xdsClient.ReportLoad called with {%v}: want {\"\"}", got.Server) - } -} diff --git a/xds/internal/balancer/edsbalancer/xds_old.go b/xds/internal/balancer/edsbalancer/xds_old.go deleted file mode 100644 index 6729e6801f1..00000000000 --- a/xds/internal/balancer/edsbalancer/xds_old.go +++ /dev/null @@ -1,46 +0,0 @@ -/* - * - * Copyright 2019 gRPC authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package edsbalancer - -import "google.golang.org/grpc/balancer" - -// The old xds balancer implements logic for both CDS and EDS. With the new -// design, CDS is split and moved to a separate balancer, and the xds balancer -// becomes the EDS balancer. -// -// To keep the existing tests working, this file regisger EDS balancer under the -// old xds balancer name. -// -// TODO: delete this file when migration to new workflow (LDS, RDS, CDS, EDS) is -// done. - -const xdsName = "xds_experimental" - -func init() { - balancer.Register(&xdsBalancerBuilder{}) -} - -// xdsBalancerBuilder register edsBalancerBuilder (now with name -// "eds_experimental") under the old name "xds_experimental". -type xdsBalancerBuilder struct { - edsBalancerBuilder -} - -func (b *xdsBalancerBuilder) Name() string { - return xdsName -} diff --git a/xds/internal/balancer/loadstore/load_store_wrapper.go b/xds/internal/balancer/loadstore/load_store_wrapper.go index 88fa344118c..8ce958d71ca 100644 --- a/xds/internal/balancer/loadstore/load_store_wrapper.go +++ b/xds/internal/balancer/loadstore/load_store_wrapper.go @@ -22,7 +22,7 @@ package loadstore import ( "sync" - "google.golang.org/grpc/xds/internal/client/load" + "google.golang.org/grpc/xds/internal/xdsclient/load" ) // NewWrapper creates a Wrapper. diff --git a/xds/internal/balancer/lrs/balancer.go b/xds/internal/balancer/lrs/balancer.go deleted file mode 100644 index ab9ee7109db..00000000000 --- a/xds/internal/balancer/lrs/balancer.go +++ /dev/null @@ -1,246 +0,0 @@ -/* - * - * Copyright 2020 gRPC authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - */ - -// Package lrs implements load reporting balancer for xds. -package lrs - -import ( - "encoding/json" - "fmt" - - "google.golang.org/grpc/balancer" - "google.golang.org/grpc/internal/grpclog" - "google.golang.org/grpc/serviceconfig" - "google.golang.org/grpc/xds/internal/balancer/loadstore" - xdsclient "google.golang.org/grpc/xds/internal/client" - "google.golang.org/grpc/xds/internal/client/load" -) - -func init() { - balancer.Register(&lrsBB{}) -} - -var newXDSClient = func() (xdsClientInterface, error) { return xdsclient.New() } - -const lrsBalancerName = "lrs_experimental" - -type lrsBB struct{} - -func (l *lrsBB) Build(cc balancer.ClientConn, opts balancer.BuildOptions) balancer.Balancer { - b := &lrsBalancer{ - cc: cc, - buildOpts: opts, - } - b.logger = prefixLogger(b) - b.logger.Infof("Created") - - client, err := newXDSClient() - if err != nil { - b.logger.Errorf("failed to create xds-client: %v", err) - return nil - } - b.client = newXDSClientWrapper(client) - - return b -} - -func (l *lrsBB) Name() string { - return lrsBalancerName -} - -func (l *lrsBB) ParseConfig(c json.RawMessage) (serviceconfig.LoadBalancingConfig, error) { - return parseConfig(c) -} - -type lrsBalancer struct { - cc balancer.ClientConn - buildOpts balancer.BuildOptions - - logger *grpclog.PrefixLogger - client *xdsClientWrapper - - config *lbConfig - lb balancer.Balancer // The sub balancer. -} - -func (b *lrsBalancer) UpdateClientConnState(s balancer.ClientConnState) error { - newConfig, ok := s.BalancerConfig.(*lbConfig) - if !ok { - return fmt.Errorf("unexpected balancer config with type: %T", s.BalancerConfig) - } - - // Update load reporting config or xds client. This needs to be done before - // updating the child policy because we need the loadStore from the updated - // client to be passed to the ccWrapper. - if err := b.client.update(newConfig); err != nil { - return err - } - - // If child policy is a different type, recreate the sub-balancer. - if b.config == nil || b.config.ChildPolicy.Name != newConfig.ChildPolicy.Name { - bb := balancer.Get(newConfig.ChildPolicy.Name) - if bb == nil { - return fmt.Errorf("balancer %q not registered", newConfig.ChildPolicy.Name) - } - if b.lb != nil { - b.lb.Close() - } - lidJSON, err := newConfig.Locality.ToString() - if err != nil { - return fmt.Errorf("failed to marshal LocalityID: %#v", newConfig.Locality) - } - ccWrapper := newCCWrapper(b.cc, b.client.loadStore(), lidJSON) - b.lb = bb.Build(ccWrapper, b.buildOpts) - } - b.config = newConfig - - // Addresses and sub-balancer config are sent to sub-balancer. - return b.lb.UpdateClientConnState(balancer.ClientConnState{ - ResolverState: s.ResolverState, - BalancerConfig: b.config.ChildPolicy.Config, - }) -} - -func (b *lrsBalancer) ResolverError(err error) { - if b.lb != nil { - b.lb.ResolverError(err) - } -} - -func (b *lrsBalancer) UpdateSubConnState(sc balancer.SubConn, s balancer.SubConnState) { - if b.lb != nil { - b.lb.UpdateSubConnState(sc, s) - } -} - -func (b *lrsBalancer) Close() { - if b.lb != nil { - b.lb.Close() - b.lb = nil - } - b.client.close() -} - -type ccWrapper struct { - balancer.ClientConn - loadStore load.PerClusterReporter - localityIDJSON string -} - -func newCCWrapper(cc balancer.ClientConn, loadStore load.PerClusterReporter, localityIDJSON string) *ccWrapper { - return &ccWrapper{ - ClientConn: cc, - loadStore: loadStore, - localityIDJSON: localityIDJSON, - } -} - -func (ccw *ccWrapper) UpdateState(s balancer.State) { - s.Picker = newLoadReportPicker(s.Picker, ccw.localityIDJSON, ccw.loadStore) - ccw.ClientConn.UpdateState(s) -} - -// xdsClientInterface contains only the xds_client methods needed by LRS -// balancer. It's defined so we can override xdsclient in tests. -type xdsClientInterface interface { - ReportLoad(server string) (*load.Store, func()) - Close() -} - -type xdsClientWrapper struct { - c xdsClientInterface - cancelLoadReport func() - clusterName string - edsServiceName string - lrsServerName string - // loadWrapper is a wrapper with loadOriginal, with clusterName and - // edsServiceName. It's used children to report loads. - loadWrapper *loadstore.Wrapper -} - -func newXDSClientWrapper(c xdsClientInterface) *xdsClientWrapper { - return &xdsClientWrapper{ - c: c, - loadWrapper: loadstore.NewWrapper(), - } -} - -// update checks the config and xdsclient, and decides whether it needs to -// restart the load reporting stream. -func (w *xdsClientWrapper) update(newConfig *lbConfig) error { - var ( - restartLoadReport bool - updateLoadClusterAndService bool - ) - - // ClusterName is different, restart. ClusterName is from ClusterName and - // EdsServiceName. - if w.clusterName != newConfig.ClusterName { - updateLoadClusterAndService = true - w.clusterName = newConfig.ClusterName - } - if w.edsServiceName != newConfig.EdsServiceName { - updateLoadClusterAndService = true - w.edsServiceName = newConfig.EdsServiceName - } - - if updateLoadClusterAndService { - // This updates the clusterName and serviceName that will reported for the - // loads. The update here is too early, the perfect timing is when the - // picker is updated with the new connection. But from this balancer's point - // of view, it's impossible to tell. - // - // On the other hand, this will almost never happen. Each LRS policy - // shouldn't get updated config. The parent should do a graceful switch when - // the clusterName or serviceName is changed. - w.loadWrapper.UpdateClusterAndService(w.clusterName, w.edsServiceName) - } - - if w.lrsServerName != newConfig.LrsLoadReportingServerName { - // LrsLoadReportingServerName is different, load should be report to a - // different server, restart. - restartLoadReport = true - w.lrsServerName = newConfig.LrsLoadReportingServerName - } - - if restartLoadReport { - if w.cancelLoadReport != nil { - w.cancelLoadReport() - w.cancelLoadReport = nil - } - var loadStore *load.Store - if w.c != nil { - loadStore, w.cancelLoadReport = w.c.ReportLoad(w.lrsServerName) - } - w.loadWrapper.UpdateLoadStore(loadStore) - } - - return nil -} - -func (w *xdsClientWrapper) loadStore() load.PerClusterReporter { - return w.loadWrapper -} - -func (w *xdsClientWrapper) close() { - if w.cancelLoadReport != nil { - w.cancelLoadReport() - w.cancelLoadReport = nil - } - w.c.Close() -} diff --git a/xds/internal/balancer/lrs/balancer_test.go b/xds/internal/balancer/lrs/balancer_test.go deleted file mode 100644 index 0b575b11210..00000000000 --- a/xds/internal/balancer/lrs/balancer_test.go +++ /dev/null @@ -1,144 +0,0 @@ -/* - * - * Copyright 2019 gRPC authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - */ - -package lrs - -import ( - "context" - "fmt" - "testing" - "time" - - "github.com/google/go-cmp/cmp" - "google.golang.org/grpc/balancer" - "google.golang.org/grpc/balancer/roundrobin" - "google.golang.org/grpc/connectivity" - internalserviceconfig "google.golang.org/grpc/internal/serviceconfig" - "google.golang.org/grpc/resolver" - xdsinternal "google.golang.org/grpc/xds/internal" - "google.golang.org/grpc/xds/internal/testutils" - "google.golang.org/grpc/xds/internal/testutils/fakeclient" -) - -const defaultTestTimeout = 1 * time.Second - -var ( - testBackendAddrs = []resolver.Address{ - {Addr: "1.1.1.1:1"}, - } - testLocality = &xdsinternal.LocalityID{ - Region: "test-region", - Zone: "test-zone", - SubZone: "test-sub-zone", - } -) - -// TestLoadReporting verifies that the lrs balancer starts the loadReport -// stream when the lbConfig passed to it contains a valid value for the LRS -// server (empty string). -func TestLoadReporting(t *testing.T) { - xdsC := fakeclient.NewClient() - oldNewXDSClient := newXDSClient - newXDSClient = func() (xdsClientInterface, error) { return xdsC, nil } - defer func() { newXDSClient = oldNewXDSClient }() - - builder := balancer.Get(lrsBalancerName) - cc := testutils.NewTestClientConn(t) - lrsB := builder.Build(cc, balancer.BuildOptions{}) - defer lrsB.Close() - - if err := lrsB.UpdateClientConnState(balancer.ClientConnState{ - ResolverState: resolver.State{ - Addresses: testBackendAddrs, - }, - BalancerConfig: &lbConfig{ - ClusterName: testClusterName, - EdsServiceName: testServiceName, - LrsLoadReportingServerName: testLRSServerName, - Locality: testLocality, - ChildPolicy: &internalserviceconfig.BalancerConfig{ - Name: roundrobin.Name, - }, - }, - }); err != nil { - t.Fatalf("unexpected error from UpdateClientConnState: %v", err) - } - - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() - - got, err := xdsC.WaitForReportLoad(ctx) - if err != nil { - t.Fatalf("xdsClient.ReportLoad failed with error: %v", err) - } - if got.Server != testLRSServerName { - t.Fatalf("xdsClient.ReportLoad called with {%q}: want {%q}", got.Server, testLRSServerName) - } - - sc1 := <-cc.NewSubConnCh - lrsB.UpdateSubConnState(sc1, balancer.SubConnState{ConnectivityState: connectivity.Connecting}) - lrsB.UpdateSubConnState(sc1, balancer.SubConnState{ConnectivityState: connectivity.Ready}) - - // Test pick with one backend. - p1 := <-cc.NewPickerCh - const successCount = 5 - for i := 0; i < successCount; i++ { - gotSCSt, _ := p1.Pick(balancer.PickInfo{}) - if !cmp.Equal(gotSCSt.SubConn, sc1, cmp.AllowUnexported(testutils.TestSubConn{})) { - t.Fatalf("picker.Pick, got %v, want SubConn=%v", gotSCSt, sc1) - } - gotSCSt.Done(balancer.DoneInfo{}) - } - const errorCount = 5 - for i := 0; i < errorCount; i++ { - gotSCSt, _ := p1.Pick(balancer.PickInfo{}) - if !cmp.Equal(gotSCSt.SubConn, sc1, cmp.AllowUnexported(testutils.TestSubConn{})) { - t.Fatalf("picker.Pick, got %v, want SubConn=%v", gotSCSt, sc1) - } - gotSCSt.Done(balancer.DoneInfo{Err: fmt.Errorf("error")}) - } - - // Dump load data from the store and compare with expected counts. - loadStore := xdsC.LoadStore() - if loadStore == nil { - t.Fatal("loadStore is nil in xdsClient") - } - sds := loadStore.Stats([]string{testClusterName}) - if len(sds) == 0 { - t.Fatalf("loads for cluster %v not found in store", testClusterName) - } - sd := sds[0] - if sd.Cluster != testClusterName || sd.Service != testServiceName { - t.Fatalf("got unexpected load for %q, %q, want %q, %q", sd.Cluster, sd.Service, testClusterName, testServiceName) - } - testLocalityJSON, _ := testLocality.ToString() - localityData, ok := sd.LocalityStats[testLocalityJSON] - if !ok { - t.Fatalf("loads for %v not found in store", testLocality) - } - reqStats := localityData.RequestStats - if reqStats.Succeeded != successCount { - t.Errorf("got succeeded %v, want %v", reqStats.Succeeded, successCount) - } - if reqStats.Errored != errorCount { - t.Errorf("got errord %v, want %v", reqStats.Errored, errorCount) - } - if reqStats.InProgress != 0 { - t.Errorf("got inProgress %v, want %v", reqStats.InProgress, 0) - } -} diff --git a/xds/internal/balancer/lrs/config.go b/xds/internal/balancer/lrs/config.go deleted file mode 100644 index 3d39961401b..00000000000 --- a/xds/internal/balancer/lrs/config.go +++ /dev/null @@ -1,54 +0,0 @@ -/* - * - * Copyright 2020 gRPC authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - */ - -package lrs - -import ( - "encoding/json" - "fmt" - - internalserviceconfig "google.golang.org/grpc/internal/serviceconfig" - "google.golang.org/grpc/serviceconfig" - "google.golang.org/grpc/xds/internal" -) - -type lbConfig struct { - serviceconfig.LoadBalancingConfig - ClusterName string - EdsServiceName string - LrsLoadReportingServerName string - Locality *internal.LocalityID - ChildPolicy *internalserviceconfig.BalancerConfig -} - -func parseConfig(c json.RawMessage) (*lbConfig, error) { - var cfg lbConfig - if err := json.Unmarshal(c, &cfg); err != nil { - return nil, err - } - if cfg.ClusterName == "" { - return nil, fmt.Errorf("required ClusterName is not set in %+v", cfg) - } - if cfg.LrsLoadReportingServerName == "" { - return nil, fmt.Errorf("required LrsLoadReportingServerName is not set in %+v", cfg) - } - if cfg.Locality == nil { - return nil, fmt.Errorf("required Locality is not set in %+v", cfg) - } - return &cfg, nil -} diff --git a/xds/internal/balancer/lrs/config_test.go b/xds/internal/balancer/lrs/config_test.go deleted file mode 100644 index f49430569fe..00000000000 --- a/xds/internal/balancer/lrs/config_test.go +++ /dev/null @@ -1,127 +0,0 @@ -/* - * - * Copyright 2020 gRPC authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - */ - -package lrs - -import ( - "testing" - - "github.com/google/go-cmp/cmp" - "google.golang.org/grpc/balancer/roundrobin" - internalserviceconfig "google.golang.org/grpc/internal/serviceconfig" - xdsinternal "google.golang.org/grpc/xds/internal" -) - -const ( - testClusterName = "test-cluster" - testServiceName = "test-eds-service" - testLRSServerName = "test-lrs-name" -) - -func TestParseConfig(t *testing.T) { - tests := []struct { - name string - js string - want *lbConfig - wantErr bool - }{ - { - name: "no cluster name", - js: `{ - "edsServiceName": "test-eds-service", - "lrsLoadReportingServerName": "test-lrs-name", - "locality": { - "region": "test-region", - "zone": "test-zone", - "subZone": "test-sub-zone" - }, - "childPolicy":[{"round_robin":{}}] -} - `, - wantErr: true, - }, - { - name: "no LRS server name", - js: `{ - "clusterName": "test-cluster", - "edsServiceName": "test-eds-service", - "locality": { - "region": "test-region", - "zone": "test-zone", - "subZone": "test-sub-zone" - }, - "childPolicy":[{"round_robin":{}}] -} - `, - wantErr: true, - }, - { - name: "no locality", - js: `{ - "clusterName": "test-cluster", - "edsServiceName": "test-eds-service", - "lrsLoadReportingServerName": "test-lrs-name", - "childPolicy":[{"round_robin":{}}] -} - `, - wantErr: true, - }, - { - name: "good", - js: `{ - "clusterName": "test-cluster", - "edsServiceName": "test-eds-service", - "lrsLoadReportingServerName": "test-lrs-name", - "locality": { - "region": "test-region", - "zone": "test-zone", - "subZone": "test-sub-zone" - }, - "childPolicy":[{"round_robin":{}}] -} - `, - want: &lbConfig{ - ClusterName: testClusterName, - EdsServiceName: testServiceName, - LrsLoadReportingServerName: testLRSServerName, - Locality: &xdsinternal.LocalityID{ - Region: "test-region", - Zone: "test-zone", - SubZone: "test-sub-zone", - }, - ChildPolicy: &internalserviceconfig.BalancerConfig{ - Name: roundrobin.Name, - Config: nil, - }, - }, - wantErr: false, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := parseConfig([]byte(tt.js)) - if (err != nil) != tt.wantErr { - t.Errorf("parseConfig() error = %v, wantErr %v", err, tt.wantErr) - return - } - if diff := cmp.Diff(got, tt.want); diff != "" { - t.Errorf("parseConfig() got = %v, want %v, diff: %s", got, tt.want, diff) - } - }) - } -} diff --git a/xds/internal/balancer/lrs/picker.go b/xds/internal/balancer/lrs/picker.go deleted file mode 100644 index 1e4ad156e5b..00000000000 --- a/xds/internal/balancer/lrs/picker.go +++ /dev/null @@ -1,85 +0,0 @@ -/* - * - * Copyright 2020 gRPC authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - */ - -package lrs - -import ( - orcapb "github.com/cncf/udpa/go/udpa/data/orca/v1" - "google.golang.org/grpc/balancer" -) - -const ( - serverLoadCPUName = "cpu_utilization" - serverLoadMemoryName = "mem_utilization" -) - -// loadReporter wraps the methods from the loadStore that are used here. -type loadReporter interface { - CallStarted(locality string) - CallFinished(locality string, err error) - CallServerLoad(locality, name string, val float64) -} - -type loadReportPicker struct { - p balancer.Picker - - locality string - loadStore loadReporter -} - -func newLoadReportPicker(p balancer.Picker, id string, loadStore loadReporter) *loadReportPicker { - return &loadReportPicker{ - p: p, - locality: id, - loadStore: loadStore, - } -} - -func (lrp *loadReportPicker) Pick(info balancer.PickInfo) (balancer.PickResult, error) { - res, err := lrp.p.Pick(info) - if err != nil { - return res, err - } - - if lrp.loadStore == nil { - return res, err - } - - lrp.loadStore.CallStarted(lrp.locality) - oldDone := res.Done - res.Done = func(info balancer.DoneInfo) { - if oldDone != nil { - oldDone(info) - } - lrp.loadStore.CallFinished(lrp.locality, info.Err) - - load, ok := info.ServerLoad.(*orcapb.OrcaLoadReport) - if !ok { - return - } - lrp.loadStore.CallServerLoad(lrp.locality, serverLoadCPUName, load.CpuUtilization) - lrp.loadStore.CallServerLoad(lrp.locality, serverLoadMemoryName, load.MemUtilization) - for n, d := range load.RequestCost { - lrp.loadStore.CallServerLoad(lrp.locality, n, d) - } - for n, d := range load.Utilization { - lrp.loadStore.CallServerLoad(lrp.locality, n, d) - } - } - return res, err -} diff --git a/xds/internal/balancer/priority/balancer.go b/xds/internal/balancer/priority/balancer.go index 6c4ff08378e..23e8aa77503 100644 --- a/xds/internal/balancer/priority/balancer.go +++ b/xds/internal/balancer/priority/balancer.go @@ -24,6 +24,7 @@ package priority import ( + "encoding/json" "fmt" "sync" "time" @@ -33,19 +34,22 @@ import ( "google.golang.org/grpc/internal/grpclog" "google.golang.org/grpc/internal/grpcsync" "google.golang.org/grpc/internal/hierarchy" + "google.golang.org/grpc/internal/pretty" "google.golang.org/grpc/resolver" + "google.golang.org/grpc/serviceconfig" "google.golang.org/grpc/xds/internal/balancer/balancergroup" ) -const priorityBalancerName = "priority_experimental" +// Name is the name of the priority balancer. +const Name = "priority_experimental" func init() { - balancer.Register(priorityBB{}) + balancer.Register(bb{}) } -type priorityBB struct{} +type bb struct{} -func (priorityBB) Build(cc balancer.ClientConn, bOpts balancer.BuildOptions) balancer.Balancer { +func (bb) Build(cc balancer.ClientConn, bOpts balancer.BuildOptions) balancer.Balancer { b := &priorityBalancer{ cc: cc, done: grpcsync.NewEvent(), @@ -60,11 +64,14 @@ func (priorityBB) Build(cc balancer.ClientConn, bOpts balancer.BuildOptions) bal go b.run() b.logger.Infof("Created") return b +} +func (b bb) ParseConfig(s json.RawMessage) (serviceconfig.LoadBalancingConfig, error) { + return parseConfig(s) } -func (priorityBB) Name() string { - return priorityBalancerName +func (bb) Name() string { + return Name } // timerWrapper wraps a timer with a boolean. So that when a race happens @@ -102,7 +109,8 @@ type priorityBalancer struct { } func (b *priorityBalancer) UpdateClientConnState(s balancer.ClientConnState) error { - newConfig, ok := s.BalancerConfig.(*lbConfig) + b.logger.Infof("Received update from resolver, balancer config: %+v", pretty.ToJSON(s.BalancerConfig)) + newConfig, ok := s.BalancerConfig.(*LBConfig) if !ok { return fmt.Errorf("unexpected balancer config with type: %T", s.BalancerConfig) } @@ -125,7 +133,7 @@ func (b *priorityBalancer) UpdateClientConnState(s balancer.ClientConnState) err // the balancer isn't built, because this child can be a low // priority. If necessary, it will be built when syncing priorities. cb := newChildBalancer(name, b, bb) - cb.updateConfig(newSubConfig.Config.Config, resolver.State{ + cb.updateConfig(newSubConfig, resolver.State{ Addresses: addressesSplit[name], ServiceConfig: s.ResolverState.ServiceConfig, Attributes: s.ResolverState.Attributes, @@ -140,13 +148,13 @@ func (b *priorityBalancer) UpdateClientConnState(s balancer.ClientConnState) err // rebuild, rebuild will happen when syncing priorities. if currentChild.bb.Name() != bb.Name() { currentChild.stop() - currentChild.bb = bb + currentChild.updateBuilder(bb) } // Update config and address, but note that this doesn't send the // updates to child balancer (the child balancer might not be built, if // it's a low priority). - currentChild.updateConfig(newSubConfig.Config.Config, resolver.State{ + currentChild.updateConfig(newSubConfig, resolver.State{ Addresses: addressesSplit[name], ServiceConfig: s.ResolverState.ServiceConfig, Attributes: s.ResolverState.Attributes, @@ -193,6 +201,10 @@ func (b *priorityBalancer) Close() { b.stopPriorityInitTimer() } +func (b *priorityBalancer) ExitIdle() { + b.bg.ExitIdle() +} + // stopPriorityInitTimer stops the priorityInitTimer if it's not nil, and set it // to nil. // diff --git a/xds/internal/balancer/priority/balancer_child.go b/xds/internal/balancer/priority/balancer_child.go index d012ad4e459..600705da01a 100644 --- a/xds/internal/balancer/priority/balancer_child.go +++ b/xds/internal/balancer/priority/balancer_child.go @@ -29,10 +29,11 @@ import ( type childBalancer struct { name string parent *priorityBalancer - bb balancer.Builder + bb *ignoreResolveNowBalancerBuilder - config serviceconfig.LoadBalancingConfig - rState resolver.State + ignoreReresolutionRequests bool + config serviceconfig.LoadBalancingConfig + rState resolver.State started bool state balancer.State @@ -44,7 +45,7 @@ func newChildBalancer(name string, parent *priorityBalancer, bb balancer.Builder return &childBalancer{ name: name, parent: parent, - bb: bb, + bb: newIgnoreResolveNowBalancerBuilder(bb, false), started: false, // Start with the connecting state and picker with re-pick error, so // that when a priority switch causes this child picked before it's @@ -56,10 +57,16 @@ func newChildBalancer(name string, parent *priorityBalancer, bb balancer.Builder } } +// updateBuilder updates builder for the child, but doesn't build. +func (cb *childBalancer) updateBuilder(bb balancer.Builder) { + cb.bb = newIgnoreResolveNowBalancerBuilder(bb, cb.ignoreReresolutionRequests) +} + // updateConfig sets childBalancer's config and state, but doesn't send update to // the child balancer. -func (cb *childBalancer) updateConfig(config serviceconfig.LoadBalancingConfig, rState resolver.State) { - cb.config = config +func (cb *childBalancer) updateConfig(child *Child, rState resolver.State) { + cb.ignoreReresolutionRequests = child.IgnoreReresolutionRequests + cb.config = child.Config.Config cb.rState = rState } @@ -76,6 +83,7 @@ func (cb *childBalancer) start() { // sendUpdate sends the addresses and config to the child balancer. func (cb *childBalancer) sendUpdate() { + cb.bb.updateIgnoreResolveNow(cb.ignoreReresolutionRequests) // TODO: return and aggregate the returned error in the parent. err := cb.parent.bg.UpdateClientConnState(cb.name, balancer.ClientConnState{ ResolverState: cb.rState, diff --git a/xds/internal/balancer/priority/balancer_priority.go b/xds/internal/balancer/priority/balancer_priority.go index ea2f4f04184..bd2c6724ea5 100644 --- a/xds/internal/balancer/priority/balancer_priority.go +++ b/xds/internal/balancer/priority/balancer_priority.go @@ -28,8 +28,12 @@ import ( ) var ( - errAllPrioritiesRemoved = errors.New("no locality is provided, all priorities are removed") - defaultPriorityInitTimeout = 10 * time.Second + // ErrAllPrioritiesRemoved is returned by the picker when there's no priority available. + ErrAllPrioritiesRemoved = errors.New("no priority is provided, all priorities are removed") + // DefaultPriorityInitTimeout is the timeout after which if a priority is + // not READY, the next will be started. It's exported to be overridden by + // tests. + DefaultPriorityInitTimeout = 10 * time.Second ) // syncPriority handles priority after a config update. It makes sure the @@ -70,7 +74,7 @@ func (b *priorityBalancer) syncPriority() { b.stopPriorityInitTimer() b.cc.UpdateState(balancer.State{ ConnectivityState: connectivity.TransientFailure, - Picker: base.NewErrPicker(errAllPrioritiesRemoved), + Picker: base.NewErrPicker(ErrAllPrioritiesRemoved), }) return } @@ -162,7 +166,7 @@ func (b *priorityBalancer) switchToChild(child *childBalancer, priority int) { // to check the stopped boolean. timerW := &timerWrapper{} b.priorityInitTimer = timerW - timerW.timer = time.AfterFunc(defaultPriorityInitTimeout, func() { + timerW.timer = time.AfterFunc(DefaultPriorityInitTimeout, func() { b.mu.Lock() defer b.mu.Unlock() if timerW.stopped { @@ -221,14 +225,17 @@ func (b *priorityBalancer) handleChildStateUpdate(childName string, s balancer.S child.state = s switch s.ConnectivityState { - case connectivity.Ready: + case connectivity.Ready, connectivity.Idle: + // Note that idle is also handled as if it's Ready. It will close the + // lower priorities (which will be kept in a cache, not deleted), and + // new picks will use the Idle picker. b.handlePriorityWithNewStateReady(child, priority) case connectivity.TransientFailure: b.handlePriorityWithNewStateTransientFailure(child, priority) case connectivity.Connecting: b.handlePriorityWithNewStateConnecting(child, priority, oldState) default: - // New state is Idle, should never happen. Don't forward. + // New state is Shutdown, should never happen. Don't forward. } } diff --git a/xds/internal/balancer/priority/balancer_test.go b/xds/internal/balancer/priority/balancer_test.go index be14231dcb3..b884035442e 100644 --- a/xds/internal/balancer/priority/balancer_test.go +++ b/xds/internal/balancer/priority/balancer_test.go @@ -19,6 +19,7 @@ package priority import ( + "context" "fmt" "testing" "time" @@ -83,7 +84,7 @@ func subConnFromPicker(t *testing.T, p balancer.Picker) func() balancer.SubConn // Init 0 and 1; 0 is up, use 0; add 2, use 0; remove 2, use 0. func (s) TestPriority_HighPriorityReady(t *testing.T) { cc := testutils.NewTestClientConn(t) - bb := balancer.Get(priorityBalancerName) + bb := balancer.Get(Name) pb := bb.Build(cc, balancer.BuildOptions{}) defer pb.Close() @@ -95,10 +96,10 @@ func (s) TestPriority_HighPriorityReady(t *testing.T) { hierarchy.Set(resolver.Address{Addr: testBackendAddrStrs[1]}, []string{"child-1"}), }, }, - BalancerConfig: &lbConfig{ - Children: map[string]*child{ - "child-0": {&internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, - "child-1": {&internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, + BalancerConfig: &LBConfig{ + Children: map[string]*Child{ + "child-0": {Config: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, + "child-1": {Config: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, }, Priorities: []string{"child-0", "child-1"}, }, @@ -132,11 +133,11 @@ func (s) TestPriority_HighPriorityReady(t *testing.T) { hierarchy.Set(resolver.Address{Addr: testBackendAddrStrs[2]}, []string{"child-2"}), }, }, - BalancerConfig: &lbConfig{ - Children: map[string]*child{ - "child-0": {&internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, - "child-1": {&internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, - "child-2": {&internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, + BalancerConfig: &LBConfig{ + Children: map[string]*Child{ + "child-0": {Config: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, + "child-1": {Config: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, + "child-2": {Config: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, }, Priorities: []string{"child-0", "child-1", "child-2"}, }, @@ -162,10 +163,10 @@ func (s) TestPriority_HighPriorityReady(t *testing.T) { hierarchy.Set(resolver.Address{Addr: testBackendAddrStrs[1]}, []string{"child-1"}), }, }, - BalancerConfig: &lbConfig{ - Children: map[string]*child{ - "child-0": {&internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, - "child-1": {&internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, + BalancerConfig: &LBConfig{ + Children: map[string]*Child{ + "child-0": {Config: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, + "child-1": {Config: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, }, Priorities: []string{"child-0", "child-1"}, }, @@ -190,7 +191,7 @@ func (s) TestPriority_HighPriorityReady(t *testing.T) { // down, use 2; remove 2, use 1. func (s) TestPriority_SwitchPriority(t *testing.T) { cc := testutils.NewTestClientConn(t) - bb := balancer.Get(priorityBalancerName) + bb := balancer.Get(Name) pb := bb.Build(cc, balancer.BuildOptions{}) defer pb.Close() @@ -202,10 +203,10 @@ func (s) TestPriority_SwitchPriority(t *testing.T) { hierarchy.Set(resolver.Address{Addr: testBackendAddrStrs[1]}, []string{"child-1"}), }, }, - BalancerConfig: &lbConfig{ - Children: map[string]*child{ - "child-0": {&internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, - "child-1": {&internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, + BalancerConfig: &LBConfig{ + Children: map[string]*Child{ + "child-0": {Config: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, + "child-1": {Config: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, }, Priorities: []string{"child-0", "child-1"}, }, @@ -269,11 +270,11 @@ func (s) TestPriority_SwitchPriority(t *testing.T) { hierarchy.Set(resolver.Address{Addr: testBackendAddrStrs[2]}, []string{"child-2"}), }, }, - BalancerConfig: &lbConfig{ - Children: map[string]*child{ - "child-0": {&internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, - "child-1": {&internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, - "child-2": {&internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, + BalancerConfig: &LBConfig{ + Children: map[string]*Child{ + "child-0": {Config: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, + "child-1": {Config: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, + "child-2": {Config: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, }, Priorities: []string{"child-0", "child-1", "child-2"}, }, @@ -328,10 +329,10 @@ func (s) TestPriority_SwitchPriority(t *testing.T) { hierarchy.Set(resolver.Address{Addr: testBackendAddrStrs[1]}, []string{"child-1"}), }, }, - BalancerConfig: &lbConfig{ - Children: map[string]*child{ - "child-0": {&internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, - "child-1": {&internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, + BalancerConfig: &LBConfig{ + Children: map[string]*Child{ + "child-0": {Config: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, + "child-1": {Config: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, }, Priorities: []string{"child-0", "child-1"}, }, @@ -373,7 +374,7 @@ func (s) TestPriority_SwitchPriority(t *testing.T) { // use 0. func (s) TestPriority_HighPriorityToConnectingFromReady(t *testing.T) { cc := testutils.NewTestClientConn(t) - bb := balancer.Get(priorityBalancerName) + bb := balancer.Get(Name) pb := bb.Build(cc, balancer.BuildOptions{}) defer pb.Close() @@ -385,10 +386,10 @@ func (s) TestPriority_HighPriorityToConnectingFromReady(t *testing.T) { hierarchy.Set(resolver.Address{Addr: testBackendAddrStrs[1]}, []string{"child-1"}), }, }, - BalancerConfig: &lbConfig{ - Children: map[string]*child{ - "child-0": {&internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, - "child-1": {&internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, + BalancerConfig: &LBConfig{ + Children: map[string]*Child{ + "child-0": {Config: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, + "child-1": {Config: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, }, Priorities: []string{"child-0", "child-1"}, }, @@ -468,7 +469,7 @@ func (s) TestPriority_HighPriorityToConnectingFromReady(t *testing.T) { // Init 0 and 1; 0 and 1 both down; add 2, use 2. func (s) TestPriority_HigherDownWhileAddingLower(t *testing.T) { cc := testutils.NewTestClientConn(t) - bb := balancer.Get(priorityBalancerName) + bb := balancer.Get(Name) pb := bb.Build(cc, balancer.BuildOptions{}) defer pb.Close() @@ -480,10 +481,10 @@ func (s) TestPriority_HigherDownWhileAddingLower(t *testing.T) { hierarchy.Set(resolver.Address{Addr: testBackendAddrStrs[1]}, []string{"child-1"}), }, }, - BalancerConfig: &lbConfig{ - Children: map[string]*child{ - "child-0": {&internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, - "child-1": {&internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, + BalancerConfig: &LBConfig{ + Children: map[string]*Child{ + "child-0": {Config: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, + "child-1": {Config: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, }, Priorities: []string{"child-0", "child-1"}, }, @@ -534,11 +535,11 @@ func (s) TestPriority_HigherDownWhileAddingLower(t *testing.T) { hierarchy.Set(resolver.Address{Addr: testBackendAddrStrs[2]}, []string{"child-2"}), }, }, - BalancerConfig: &lbConfig{ - Children: map[string]*child{ - "child-0": {&internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, - "child-1": {&internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, - "child-2": {&internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, + BalancerConfig: &LBConfig{ + Children: map[string]*Child{ + "child-0": {Config: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, + "child-1": {Config: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, + "child-2": {Config: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, }, Priorities: []string{"child-0", "child-1", "child-2"}, }, @@ -579,7 +580,7 @@ func (s) TestPriority_HigherReadyCloseAllLower(t *testing.T) { // defer time.Sleep(10 * time.Millisecond) cc := testutils.NewTestClientConn(t) - bb := balancer.Get(priorityBalancerName) + bb := balancer.Get(Name) pb := bb.Build(cc, balancer.BuildOptions{}) defer pb.Close() @@ -592,11 +593,11 @@ func (s) TestPriority_HigherReadyCloseAllLower(t *testing.T) { hierarchy.Set(resolver.Address{Addr: testBackendAddrStrs[2]}, []string{"child-2"}), }, }, - BalancerConfig: &lbConfig{ - Children: map[string]*child{ - "child-0": {&internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, - "child-1": {&internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, - "child-2": {&internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, + BalancerConfig: &LBConfig{ + Children: map[string]*Child{ + "child-0": {Config: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, + "child-1": {Config: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, + "child-2": {Config: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, }, Priorities: []string{"child-0", "child-1", "child-2"}, }, @@ -687,15 +688,15 @@ func (s) TestPriority_HigherReadyCloseAllLower(t *testing.T) { func (s) TestPriority_InitTimeout(t *testing.T) { const testPriorityInitTimeout = time.Second defer func() func() { - old := defaultPriorityInitTimeout - defaultPriorityInitTimeout = testPriorityInitTimeout + old := DefaultPriorityInitTimeout + DefaultPriorityInitTimeout = testPriorityInitTimeout return func() { - defaultPriorityInitTimeout = old + DefaultPriorityInitTimeout = old } }()() cc := testutils.NewTestClientConn(t) - bb := balancer.Get(priorityBalancerName) + bb := balancer.Get(Name) pb := bb.Build(cc, balancer.BuildOptions{}) defer pb.Close() @@ -707,10 +708,10 @@ func (s) TestPriority_InitTimeout(t *testing.T) { hierarchy.Set(resolver.Address{Addr: testBackendAddrStrs[1]}, []string{"child-1"}), }, }, - BalancerConfig: &lbConfig{ - Children: map[string]*child{ - "child-0": {&internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, - "child-1": {&internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, + BalancerConfig: &LBConfig{ + Children: map[string]*Child{ + "child-0": {Config: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, + "child-1": {Config: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, }, Priorities: []string{"child-0", "child-1"}, }, @@ -757,15 +758,15 @@ func (s) TestPriority_InitTimeout(t *testing.T) { func (s) TestPriority_RemovesAllPriorities(t *testing.T) { const testPriorityInitTimeout = time.Second defer func() func() { - old := defaultPriorityInitTimeout - defaultPriorityInitTimeout = testPriorityInitTimeout + old := DefaultPriorityInitTimeout + DefaultPriorityInitTimeout = testPriorityInitTimeout return func() { - defaultPriorityInitTimeout = old + DefaultPriorityInitTimeout = old } }()() cc := testutils.NewTestClientConn(t) - bb := balancer.Get(priorityBalancerName) + bb := balancer.Get(Name) pb := bb.Build(cc, balancer.BuildOptions{}) defer pb.Close() @@ -777,10 +778,10 @@ func (s) TestPriority_RemovesAllPriorities(t *testing.T) { hierarchy.Set(resolver.Address{Addr: testBackendAddrStrs[1]}, []string{"child-1"}), }, }, - BalancerConfig: &lbConfig{ - Children: map[string]*child{ - "child-0": {&internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, - "child-1": {&internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, + BalancerConfig: &LBConfig{ + Children: map[string]*Child{ + "child-0": {Config: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, + "child-1": {Config: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, }, Priorities: []string{"child-0", "child-1"}, }, @@ -808,7 +809,7 @@ func (s) TestPriority_RemovesAllPriorities(t *testing.T) { ResolverState: resolver.State{ Addresses: nil, }, - BalancerConfig: &lbConfig{ + BalancerConfig: &LBConfig{ Children: nil, Priorities: nil, }, @@ -825,8 +826,8 @@ func (s) TestPriority_RemovesAllPriorities(t *testing.T) { // Test pick return TransientFailure. pFail := <-cc.NewPickerCh for i := 0; i < 5; i++ { - if _, err := pFail.Pick(balancer.PickInfo{}); err != errAllPrioritiesRemoved { - t.Fatalf("want pick error %v, got %v", errAllPrioritiesRemoved, err) + if _, err := pFail.Pick(balancer.PickInfo{}); err != ErrAllPrioritiesRemoved { + t.Fatalf("want pick error %v, got %v", ErrAllPrioritiesRemoved, err) } } @@ -838,10 +839,10 @@ func (s) TestPriority_RemovesAllPriorities(t *testing.T) { hierarchy.Set(resolver.Address{Addr: testBackendAddrStrs[3]}, []string{"child-1"}), }, }, - BalancerConfig: &lbConfig{ - Children: map[string]*child{ - "child-0": {&internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, - "child-1": {&internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, + BalancerConfig: &LBConfig{ + Children: map[string]*Child{ + "child-0": {Config: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, + "child-1": {Config: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, }, Priorities: []string{"child-0", "child-1"}, }, @@ -882,9 +883,9 @@ func (s) TestPriority_RemovesAllPriorities(t *testing.T) { hierarchy.Set(resolver.Address{Addr: testBackendAddrStrs[2]}, []string{"child-0"}), }, }, - BalancerConfig: &lbConfig{ - Children: map[string]*child{ - "child-0": {&internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, + BalancerConfig: &LBConfig{ + Children: map[string]*Child{ + "child-0": {Config: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, }, Priorities: []string{"child-0"}, }, @@ -933,7 +934,7 @@ func (s) TestPriority_RemovesAllPriorities(t *testing.T) { // will be used. func (s) TestPriority_HighPriorityNoEndpoints(t *testing.T) { cc := testutils.NewTestClientConn(t) - bb := balancer.Get(priorityBalancerName) + bb := balancer.Get(Name) pb := bb.Build(cc, balancer.BuildOptions{}) defer pb.Close() @@ -945,10 +946,10 @@ func (s) TestPriority_HighPriorityNoEndpoints(t *testing.T) { hierarchy.Set(resolver.Address{Addr: testBackendAddrStrs[1]}, []string{"child-1"}), }, }, - BalancerConfig: &lbConfig{ - Children: map[string]*child{ - "child-0": {&internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, - "child-1": {&internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, + BalancerConfig: &LBConfig{ + Children: map[string]*Child{ + "child-0": {Config: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, + "child-1": {Config: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, }, Priorities: []string{"child-0", "child-1"}, }, @@ -980,10 +981,10 @@ func (s) TestPriority_HighPriorityNoEndpoints(t *testing.T) { hierarchy.Set(resolver.Address{Addr: testBackendAddrStrs[1]}, []string{"child-1"}), }, }, - BalancerConfig: &lbConfig{ - Children: map[string]*child{ - "child-0": {&internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, - "child-1": {&internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, + BalancerConfig: &LBConfig{ + Children: map[string]*Child{ + "child-0": {Config: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, + "child-1": {Config: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, }, Priorities: []string{"child-0", "child-1"}, }, @@ -1027,12 +1028,12 @@ func (s) TestPriority_HighPriorityNoEndpoints(t *testing.T) { func (s) TestPriority_FirstPriorityUnavailable(t *testing.T) { const testPriorityInitTimeout = time.Second defer func(t time.Duration) { - defaultPriorityInitTimeout = t - }(defaultPriorityInitTimeout) - defaultPriorityInitTimeout = testPriorityInitTimeout + DefaultPriorityInitTimeout = t + }(DefaultPriorityInitTimeout) + DefaultPriorityInitTimeout = testPriorityInitTimeout cc := testutils.NewTestClientConn(t) - bb := balancer.Get(priorityBalancerName) + bb := balancer.Get(Name) pb := bb.Build(cc, balancer.BuildOptions{}) defer pb.Close() @@ -1043,9 +1044,9 @@ func (s) TestPriority_FirstPriorityUnavailable(t *testing.T) { hierarchy.Set(resolver.Address{Addr: testBackendAddrStrs[0]}, []string{"child-0"}), }, }, - BalancerConfig: &lbConfig{ - Children: map[string]*child{ - "child-0": {&internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, + BalancerConfig: &LBConfig{ + Children: map[string]*Child{ + "child-0": {Config: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, }, Priorities: []string{"child-0"}, }, @@ -1058,7 +1059,7 @@ func (s) TestPriority_FirstPriorityUnavailable(t *testing.T) { ResolverState: resolver.State{ Addresses: nil, }, - BalancerConfig: &lbConfig{ + BalancerConfig: &LBConfig{ Children: nil, Priorities: nil, }, @@ -1075,7 +1076,7 @@ func (s) TestPriority_FirstPriorityUnavailable(t *testing.T) { // Init a(p0) and b(p1); a(p0) is up, use a; move b to p0, a to p1, use b. func (s) TestPriority_MoveChildToHigherPriority(t *testing.T) { cc := testutils.NewTestClientConn(t) - bb := balancer.Get(priorityBalancerName) + bb := balancer.Get(Name) pb := bb.Build(cc, balancer.BuildOptions{}) defer pb.Close() @@ -1087,10 +1088,10 @@ func (s) TestPriority_MoveChildToHigherPriority(t *testing.T) { hierarchy.Set(resolver.Address{Addr: testBackendAddrStrs[1]}, []string{"child-1"}), }, }, - BalancerConfig: &lbConfig{ - Children: map[string]*child{ - "child-0": {&internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, - "child-1": {&internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, + BalancerConfig: &LBConfig{ + Children: map[string]*Child{ + "child-0": {Config: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, + "child-1": {Config: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, }, Priorities: []string{"child-0", "child-1"}, }, @@ -1124,10 +1125,10 @@ func (s) TestPriority_MoveChildToHigherPriority(t *testing.T) { hierarchy.Set(resolver.Address{Addr: testBackendAddrStrs[1]}, []string{"child-1"}), }, }, - BalancerConfig: &lbConfig{ - Children: map[string]*child{ - "child-0": {&internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, - "child-1": {&internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, + BalancerConfig: &LBConfig{ + Children: map[string]*Child{ + "child-0": {Config: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, + "child-1": {Config: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, }, Priorities: []string{"child-1", "child-0"}, }, @@ -1176,7 +1177,7 @@ func (s) TestPriority_MoveChildToHigherPriority(t *testing.T) { // Init a(p0) and b(p1); a(p0) is down, use b; move b to p0, a to p1, use b. func (s) TestPriority_MoveReadyChildToHigherPriority(t *testing.T) { cc := testutils.NewTestClientConn(t) - bb := balancer.Get(priorityBalancerName) + bb := balancer.Get(Name) pb := bb.Build(cc, balancer.BuildOptions{}) defer pb.Close() @@ -1188,10 +1189,10 @@ func (s) TestPriority_MoveReadyChildToHigherPriority(t *testing.T) { hierarchy.Set(resolver.Address{Addr: testBackendAddrStrs[1]}, []string{"child-1"}), }, }, - BalancerConfig: &lbConfig{ - Children: map[string]*child{ - "child-0": {&internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, - "child-1": {&internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, + BalancerConfig: &LBConfig{ + Children: map[string]*Child{ + "child-0": {Config: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, + "child-1": {Config: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, }, Priorities: []string{"child-0", "child-1"}, }, @@ -1240,10 +1241,10 @@ func (s) TestPriority_MoveReadyChildToHigherPriority(t *testing.T) { hierarchy.Set(resolver.Address{Addr: testBackendAddrStrs[1]}, []string{"child-1"}), }, }, - BalancerConfig: &lbConfig{ - Children: map[string]*child{ - "child-0": {&internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, - "child-1": {&internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, + BalancerConfig: &LBConfig{ + Children: map[string]*Child{ + "child-0": {Config: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, + "child-1": {Config: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, }, Priorities: []string{"child-1", "child-0"}, }, @@ -1276,7 +1277,7 @@ func (s) TestPriority_MoveReadyChildToHigherPriority(t *testing.T) { // Init a(p0) and b(p1); a(p0) is down, use b; move b to p0, a to p1, use b. func (s) TestPriority_RemoveReadyLowestChild(t *testing.T) { cc := testutils.NewTestClientConn(t) - bb := balancer.Get(priorityBalancerName) + bb := balancer.Get(Name) pb := bb.Build(cc, balancer.BuildOptions{}) defer pb.Close() @@ -1288,10 +1289,10 @@ func (s) TestPriority_RemoveReadyLowestChild(t *testing.T) { hierarchy.Set(resolver.Address{Addr: testBackendAddrStrs[1]}, []string{"child-1"}), }, }, - BalancerConfig: &lbConfig{ - Children: map[string]*child{ - "child-0": {&internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, - "child-1": {&internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, + BalancerConfig: &LBConfig{ + Children: map[string]*Child{ + "child-0": {Config: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, + "child-1": {Config: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, }, Priorities: []string{"child-0", "child-1"}, }, @@ -1338,9 +1339,9 @@ func (s) TestPriority_RemoveReadyLowestChild(t *testing.T) { hierarchy.Set(resolver.Address{Addr: testBackendAddrStrs[0]}, []string{"child-0"}), }, }, - BalancerConfig: &lbConfig{ - Children: map[string]*child{ - "child-0": {&internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, + BalancerConfig: &LBConfig{ + Children: map[string]*Child{ + "child-0": {Config: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, }, Priorities: []string{"child-0"}, }, @@ -1384,7 +1385,7 @@ func (s) TestPriority_ReadyChildRemovedButInCache(t *testing.T) { }()() cc := testutils.NewTestClientConn(t) - bb := balancer.Get(priorityBalancerName) + bb := balancer.Get(Name) pb := bb.Build(cc, balancer.BuildOptions{}) defer pb.Close() @@ -1395,9 +1396,9 @@ func (s) TestPriority_ReadyChildRemovedButInCache(t *testing.T) { hierarchy.Set(resolver.Address{Addr: testBackendAddrStrs[0]}, []string{"child-0"}), }, }, - BalancerConfig: &lbConfig{ - Children: map[string]*child{ - "child-0": {&internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, + BalancerConfig: &LBConfig{ + Children: map[string]*Child{ + "child-0": {Config: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, }, Priorities: []string{"child-0"}, }, @@ -1426,15 +1427,15 @@ func (s) TestPriority_ReadyChildRemovedButInCache(t *testing.T) { // be different. if err := pb.UpdateClientConnState(balancer.ClientConnState{ ResolverState: resolver.State{}, - BalancerConfig: &lbConfig{}, + BalancerConfig: &LBConfig{}, }); err != nil { t.Fatalf("failed to update ClientConn state: %v", err) } pFail := <-cc.NewPickerCh for i := 0; i < 5; i++ { - if _, err := pFail.Pick(balancer.PickInfo{}); err != errAllPrioritiesRemoved { - t.Fatalf("want pick error %v, got %v", errAllPrioritiesRemoved, err) + if _, err := pFail.Pick(balancer.PickInfo{}); err != ErrAllPrioritiesRemoved { + t.Fatalf("want pick error %v, got %v", ErrAllPrioritiesRemoved, err) } } @@ -1454,9 +1455,9 @@ func (s) TestPriority_ReadyChildRemovedButInCache(t *testing.T) { hierarchy.Set(resolver.Address{Addr: testBackendAddrStrs[0]}, []string{"child-0"}), }, }, - BalancerConfig: &lbConfig{ - Children: map[string]*child{ - "child-0": {&internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, + BalancerConfig: &LBConfig{ + Children: map[string]*Child{ + "child-0": {Config: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, }, Priorities: []string{"child-0"}, }, @@ -1487,7 +1488,7 @@ func (s) TestPriority_ReadyChildRemovedButInCache(t *testing.T) { // Init 0; 0 is up, use 0; change 0's policy, 0 is used. func (s) TestPriority_ChildPolicyChange(t *testing.T) { cc := testutils.NewTestClientConn(t) - bb := balancer.Get(priorityBalancerName) + bb := balancer.Get(Name) pb := bb.Build(cc, balancer.BuildOptions{}) defer pb.Close() @@ -1498,9 +1499,9 @@ func (s) TestPriority_ChildPolicyChange(t *testing.T) { hierarchy.Set(resolver.Address{Addr: testBackendAddrStrs[0]}, []string{"child-0"}), }, }, - BalancerConfig: &lbConfig{ - Children: map[string]*child{ - "child-0": {&internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, + BalancerConfig: &LBConfig{ + Children: map[string]*Child{ + "child-0": {Config: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}}, }, Priorities: []string{"child-0"}, }, @@ -1533,9 +1534,9 @@ func (s) TestPriority_ChildPolicyChange(t *testing.T) { hierarchy.Set(resolver.Address{Addr: testBackendAddrStrs[0]}, []string{"child-0"}), }, }, - BalancerConfig: &lbConfig{ - Children: map[string]*child{ - "child-0": {&internalserviceconfig.BalancerConfig{Name: testRRBalancerName}}, + BalancerConfig: &LBConfig{ + Children: map[string]*Child{ + "child-0": {Config: &internalserviceconfig.BalancerConfig{Name: testRRBalancerName}}, }, Priorities: []string{"child-0"}, }, @@ -1587,7 +1588,7 @@ func init() { // by acquiring a locked mutex. func (s) TestPriority_ChildPolicyUpdatePickerInline(t *testing.T) { cc := testutils.NewTestClientConn(t) - bb := balancer.Get(priorityBalancerName) + bb := balancer.Get(Name) pb := bb.Build(cc, balancer.BuildOptions{}) defer pb.Close() @@ -1598,9 +1599,9 @@ func (s) TestPriority_ChildPolicyUpdatePickerInline(t *testing.T) { hierarchy.Set(resolver.Address{Addr: testBackendAddrStrs[0]}, []string{"child-0"}), }, }, - BalancerConfig: &lbConfig{ - Children: map[string]*child{ - "child-0": {&internalserviceconfig.BalancerConfig{Name: inlineUpdateBalancerName}}, + BalancerConfig: &LBConfig{ + Children: map[string]*Child{ + "child-0": {Config: &internalserviceconfig.BalancerConfig{Name: inlineUpdateBalancerName}}, }, Priorities: []string{"child-0"}, }, @@ -1616,3 +1617,260 @@ func (s) TestPriority_ChildPolicyUpdatePickerInline(t *testing.T) { } } } + +// When the child policy's configured to ignore reresolution requests, the +// ResolveNow() calls from this child should be all ignored. +func (s) TestPriority_IgnoreReresolutionRequest(t *testing.T) { + cc := testutils.NewTestClientConn(t) + bb := balancer.Get(Name) + pb := bb.Build(cc, balancer.BuildOptions{}) + defer pb.Close() + + // One children, with priorities [0], with one backend, reresolution is + // ignored. + if err := pb.UpdateClientConnState(balancer.ClientConnState{ + ResolverState: resolver.State{ + Addresses: []resolver.Address{ + hierarchy.Set(resolver.Address{Addr: testBackendAddrStrs[0]}, []string{"child-0"}), + }, + }, + BalancerConfig: &LBConfig{ + Children: map[string]*Child{ + "child-0": { + Config: &internalserviceconfig.BalancerConfig{Name: resolveNowBalancerName}, + IgnoreReresolutionRequests: true, + }, + }, + Priorities: []string{"child-0"}, + }, + }); err != nil { + t.Fatalf("failed to update ClientConn state: %v", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + // This is the balancer.ClientConn that the inner resolverNowBalancer is + // built with. + balancerCCI, err := resolveNowBalancerCCCh.Receive(ctx) + if err != nil { + t.Fatalf("timeout waiting for ClientConn from balancer builder") + } + balancerCC := balancerCCI.(balancer.ClientConn) + + // Since IgnoreReresolutionRequests was set to true, all ResolveNow() calls + // should be ignored. + for i := 0; i < 5; i++ { + balancerCC.ResolveNow(resolver.ResolveNowOptions{}) + } + select { + case <-cc.ResolveNowCh: + t.Fatalf("got unexpected ResolveNow() call") + case <-time.After(time.Millisecond * 100): + } + + // Send another update to set IgnoreReresolutionRequests to false. + if err := pb.UpdateClientConnState(balancer.ClientConnState{ + ResolverState: resolver.State{ + Addresses: []resolver.Address{ + hierarchy.Set(resolver.Address{Addr: testBackendAddrStrs[0]}, []string{"child-0"}), + }, + }, + BalancerConfig: &LBConfig{ + Children: map[string]*Child{ + "child-0": { + Config: &internalserviceconfig.BalancerConfig{Name: resolveNowBalancerName}, + IgnoreReresolutionRequests: false, + }, + }, + Priorities: []string{"child-0"}, + }, + }); err != nil { + t.Fatalf("failed to update ClientConn state: %v", err) + } + + // Call ResolveNow() on the CC, it should be forwarded. + balancerCC.ResolveNow(resolver.ResolveNowOptions{}) + select { + case <-cc.ResolveNowCh: + case <-time.After(time.Second): + t.Fatalf("timeout waiting for ResolveNow()") + } + +} + +// When the child policy's configured to ignore reresolution requests, the +// ResolveNow() calls from this child should be all ignored, from the other +// children are forwarded. +func (s) TestPriority_IgnoreReresolutionRequestTwoChildren(t *testing.T) { + cc := testutils.NewTestClientConn(t) + bb := balancer.Get(Name) + pb := bb.Build(cc, balancer.BuildOptions{}) + defer pb.Close() + + // One children, with priorities [0, 1], each with one backend. + // Reresolution is ignored for p0. + if err := pb.UpdateClientConnState(balancer.ClientConnState{ + ResolverState: resolver.State{ + Addresses: []resolver.Address{ + hierarchy.Set(resolver.Address{Addr: testBackendAddrStrs[0]}, []string{"child-0"}), + hierarchy.Set(resolver.Address{Addr: testBackendAddrStrs[1]}, []string{"child-1"}), + }, + }, + BalancerConfig: &LBConfig{ + Children: map[string]*Child{ + "child-0": { + Config: &internalserviceconfig.BalancerConfig{Name: resolveNowBalancerName}, + IgnoreReresolutionRequests: true, + }, + "child-1": { + Config: &internalserviceconfig.BalancerConfig{Name: resolveNowBalancerName}, + }, + }, + Priorities: []string{"child-0", "child-1"}, + }, + }); err != nil { + t.Fatalf("failed to update ClientConn state: %v", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + // This is the balancer.ClientConn from p0. + balancerCCI0, err := resolveNowBalancerCCCh.Receive(ctx) + if err != nil { + t.Fatalf("timeout waiting for ClientConn from balancer builder 0") + } + balancerCC0 := balancerCCI0.(balancer.ClientConn) + + // Set p0 to transient failure, p1 will be started. + addrs0 := <-cc.NewSubConnAddrsCh + if got, want := addrs0[0].Addr, testBackendAddrStrs[0]; got != want { + t.Fatalf("sc is created with addr %v, want %v", got, want) + } + sc0 := <-cc.NewSubConnCh + pb.UpdateSubConnState(sc0, balancer.SubConnState{ConnectivityState: connectivity.TransientFailure}) + + // This is the balancer.ClientConn from p1. + ctx1, cancel1 := context.WithTimeout(context.Background(), time.Second) + defer cancel1() + balancerCCI1, err := resolveNowBalancerCCCh.Receive(ctx1) + if err != nil { + t.Fatalf("timeout waiting for ClientConn from balancer builder 1") + } + balancerCC1 := balancerCCI1.(balancer.ClientConn) + + // Since IgnoreReresolutionRequests was set to true for p0, ResolveNow() + // from p0 should all be ignored. + for i := 0; i < 5; i++ { + balancerCC0.ResolveNow(resolver.ResolveNowOptions{}) + } + select { + case <-cc.ResolveNowCh: + t.Fatalf("got unexpected ResolveNow() call") + case <-time.After(time.Millisecond * 100): + } + + // But IgnoreReresolutionRequests was false for p1, ResolveNow() from p1 + // should be forwarded. + balancerCC1.ResolveNow(resolver.ResolveNowOptions{}) + select { + case <-cc.ResolveNowCh: + case <-time.After(time.Second): + t.Fatalf("timeout waiting for ResolveNow()") + } +} + +const initIdleBalancerName = "test-init-Idle-balancer" + +var errsTestInitIdle = []error{ + fmt.Errorf("init Idle balancer error 0"), + fmt.Errorf("init Idle balancer error 1"), +} + +func init() { + for i := 0; i < 2; i++ { + ii := i + stub.Register(fmt.Sprintf("%s-%d", initIdleBalancerName, ii), stub.BalancerFuncs{ + UpdateClientConnState: func(bd *stub.BalancerData, opts balancer.ClientConnState) error { + bd.ClientConn.NewSubConn(opts.ResolverState.Addresses, balancer.NewSubConnOptions{}) + return nil + }, + UpdateSubConnState: func(bd *stub.BalancerData, sc balancer.SubConn, state balancer.SubConnState) { + err := fmt.Errorf("wrong picker error") + if state.ConnectivityState == connectivity.Idle { + err = errsTestInitIdle[ii] + } + bd.ClientConn.UpdateState(balancer.State{ + ConnectivityState: state.ConnectivityState, + Picker: &testutils.TestConstPicker{Err: err}, + }) + }, + }) + } +} + +// If the high priorities send initial pickers with Idle state, their pickers +// should get picks, because policies like ringhash starts in Idle, and doesn't +// connect. +// +// Init 0, 1; 0 is Idle, use 0; 0 is down, start 1; 1 is Idle, use 1. +func (s) TestPriority_HighPriorityInitIdle(t *testing.T) { + cc := testutils.NewTestClientConn(t) + bb := balancer.Get(Name) + pb := bb.Build(cc, balancer.BuildOptions{}) + defer pb.Close() + + // Two children, with priorities [0, 1], each with one backend. + if err := pb.UpdateClientConnState(balancer.ClientConnState{ + ResolverState: resolver.State{ + Addresses: []resolver.Address{ + hierarchy.Set(resolver.Address{Addr: testBackendAddrStrs[0]}, []string{"child-0"}), + hierarchy.Set(resolver.Address{Addr: testBackendAddrStrs[1]}, []string{"child-1"}), + }, + }, + BalancerConfig: &LBConfig{ + Children: map[string]*Child{ + "child-0": {Config: &internalserviceconfig.BalancerConfig{Name: fmt.Sprintf("%s-%d", initIdleBalancerName, 0)}}, + "child-1": {Config: &internalserviceconfig.BalancerConfig{Name: fmt.Sprintf("%s-%d", initIdleBalancerName, 1)}}, + }, + Priorities: []string{"child-0", "child-1"}, + }, + }); err != nil { + t.Fatalf("failed to update ClientConn state: %v", err) + } + + addrs0 := <-cc.NewSubConnAddrsCh + if got, want := addrs0[0].Addr, testBackendAddrStrs[0]; got != want { + t.Fatalf("sc is created with addr %v, want %v", got, want) + } + sc0 := <-cc.NewSubConnCh + + // Send an Idle state update to trigger an Idle picker update. + pb.UpdateSubConnState(sc0, balancer.SubConnState{ConnectivityState: connectivity.Idle}) + p0 := <-cc.NewPickerCh + if pr, err := p0.Pick(balancer.PickInfo{}); err != errsTestInitIdle[0] { + t.Fatalf("pick returned %v, %v, want _, %v", pr, err, errsTestInitIdle[0]) + } + + // Turn p0 down, to start p1. + pb.UpdateSubConnState(sc0, balancer.SubConnState{ConnectivityState: connectivity.TransientFailure}) + // Before 1 gets READY, picker should return NoSubConnAvailable, so RPCs + // will retry. + p1 := <-cc.NewPickerCh + for i := 0; i < 5; i++ { + if _, err := p1.Pick(balancer.PickInfo{}); err != balancer.ErrNoSubConnAvailable { + t.Fatalf("want pick error %v, got %v", balancer.ErrNoSubConnAvailable, err) + } + } + + addrs1 := <-cc.NewSubConnAddrsCh + if got, want := addrs1[0].Addr, testBackendAddrStrs[1]; got != want { + t.Fatalf("sc is created with addr %v, want %v", got, want) + } + sc1 := <-cc.NewSubConnCh + // Idle picker from p1 should also be forwarded. + pb.UpdateSubConnState(sc1, balancer.SubConnState{ConnectivityState: connectivity.Idle}) + p2 := <-cc.NewPickerCh + if pr, err := p2.Pick(balancer.PickInfo{}); err != errsTestInitIdle[1] { + t.Fatalf("pick returned %v, %v, want _, %v", pr, err, errsTestInitIdle[1]) + } +} diff --git a/xds/internal/balancer/priority/config.go b/xds/internal/balancer/priority/config.go index da085908c71..37f1c9a829a 100644 --- a/xds/internal/balancer/priority/config.go +++ b/xds/internal/balancer/priority/config.go @@ -26,24 +26,27 @@ import ( "google.golang.org/grpc/serviceconfig" ) -type child struct { - Config *internalserviceconfig.BalancerConfig +// Child is a child of priority balancer. +type Child struct { + Config *internalserviceconfig.BalancerConfig `json:"config,omitempty"` + IgnoreReresolutionRequests bool `json:"ignoreReresolutionRequests,omitempty"` } -type lbConfig struct { - serviceconfig.LoadBalancingConfig +// LBConfig represents priority balancer's config. +type LBConfig struct { + serviceconfig.LoadBalancingConfig `json:"-"` // Children is a map from the child balancer names to their configs. Child // names can be found in field Priorities. - Children map[string]*child + Children map[string]*Child `json:"children,omitempty"` // Priorities is a list of child balancer names. They are sorted from // highest priority to low. The type/config for each child can be found in // field Children, with the balancer name as the key. - Priorities []string + Priorities []string `json:"priorities,omitempty"` } -func parseConfig(c json.RawMessage) (*lbConfig, error) { - var cfg lbConfig +func parseConfig(c json.RawMessage) (*LBConfig, error) { + var cfg LBConfig if err := json.Unmarshal(c, &cfg); err != nil { return nil, err } diff --git a/xds/internal/balancer/priority/config_test.go b/xds/internal/balancer/priority/config_test.go index 15c4069dd1e..8316224c91b 100644 --- a/xds/internal/balancer/priority/config_test.go +++ b/xds/internal/balancer/priority/config_test.go @@ -30,7 +30,7 @@ func TestParseConfig(t *testing.T) { tests := []struct { name string js string - want *lbConfig + want *LBConfig wantErr bool }{ { @@ -63,26 +63,27 @@ func TestParseConfig(t *testing.T) { js: `{ "priorities": ["child-1", "child-2", "child-3"], "children": { - "child-1": {"config": [{"round_robin":{}}]}, + "child-1": {"config": [{"round_robin":{}}], "ignoreReresolutionRequests": true}, "child-2": {"config": [{"round_robin":{}}]}, "child-3": {"config": [{"round_robin":{}}]} } } `, - want: &lbConfig{ - Children: map[string]*child{ + want: &LBConfig{ + Children: map[string]*Child{ "child-1": { - &internalserviceconfig.BalancerConfig{ + Config: &internalserviceconfig.BalancerConfig{ Name: roundrobin.Name, }, + IgnoreReresolutionRequests: true, }, "child-2": { - &internalserviceconfig.BalancerConfig{ + Config: &internalserviceconfig.BalancerConfig{ Name: roundrobin.Name, }, }, "child-3": { - &internalserviceconfig.BalancerConfig{ + Config: &internalserviceconfig.BalancerConfig{ Name: roundrobin.Name, }, }, diff --git a/xds/internal/balancer/priority/ignore_resolve_now.go b/xds/internal/balancer/priority/ignore_resolve_now.go new file mode 100644 index 00000000000..9a9f4777269 --- /dev/null +++ b/xds/internal/balancer/priority/ignore_resolve_now.go @@ -0,0 +1,73 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package priority + +import ( + "sync/atomic" + + "google.golang.org/grpc/balancer" + "google.golang.org/grpc/resolver" +) + +type ignoreResolveNowBalancerBuilder struct { + balancer.Builder + ignoreResolveNow *uint32 +} + +// If `ignore` is true, all `ResolveNow()` from the balancer built from this +// builder will be ignored. +// +// `ignore` can be updated later by `updateIgnoreResolveNow`, and the update +// will be propagated to all the old and new balancers built with this. +func newIgnoreResolveNowBalancerBuilder(bb balancer.Builder, ignore bool) *ignoreResolveNowBalancerBuilder { + ret := &ignoreResolveNowBalancerBuilder{ + Builder: bb, + ignoreResolveNow: new(uint32), + } + ret.updateIgnoreResolveNow(ignore) + return ret +} + +func (irnbb *ignoreResolveNowBalancerBuilder) updateIgnoreResolveNow(b bool) { + if b { + atomic.StoreUint32(irnbb.ignoreResolveNow, 1) + return + } + atomic.StoreUint32(irnbb.ignoreResolveNow, 0) + +} + +func (irnbb *ignoreResolveNowBalancerBuilder) Build(cc balancer.ClientConn, opts balancer.BuildOptions) balancer.Balancer { + return irnbb.Builder.Build(&ignoreResolveNowClientConn{ + ClientConn: cc, + ignoreResolveNow: irnbb.ignoreResolveNow, + }, opts) +} + +type ignoreResolveNowClientConn struct { + balancer.ClientConn + ignoreResolveNow *uint32 +} + +func (i ignoreResolveNowClientConn) ResolveNow(o resolver.ResolveNowOptions) { + if atomic.LoadUint32(i.ignoreResolveNow) != 0 { + return + } + i.ClientConn.ResolveNow(o) +} diff --git a/xds/internal/balancer/priority/ignore_resolve_now_test.go b/xds/internal/balancer/priority/ignore_resolve_now_test.go new file mode 100644 index 00000000000..b7cecd6c1ff --- /dev/null +++ b/xds/internal/balancer/priority/ignore_resolve_now_test.go @@ -0,0 +1,104 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package priority + +import ( + "context" + "testing" + "time" + + "google.golang.org/grpc/balancer" + "google.golang.org/grpc/balancer/roundrobin" + grpctestutils "google.golang.org/grpc/internal/testutils" + "google.golang.org/grpc/resolver" + "google.golang.org/grpc/xds/internal/testutils" +) + +const resolveNowBalancerName = "test-resolve-now-balancer" + +var resolveNowBalancerCCCh = grpctestutils.NewChannel() + +type resolveNowBalancerBuilder struct { + balancer.Builder +} + +func (r *resolveNowBalancerBuilder) Build(cc balancer.ClientConn, opts balancer.BuildOptions) balancer.Balancer { + resolveNowBalancerCCCh.Send(cc) + return r.Builder.Build(cc, opts) +} + +func (r *resolveNowBalancerBuilder) Name() string { + return resolveNowBalancerName +} + +func init() { + balancer.Register(&resolveNowBalancerBuilder{ + Builder: balancer.Get(roundrobin.Name), + }) +} + +func (s) TestIgnoreResolveNowBalancerBuilder(t *testing.T) { + resolveNowBB := balancer.Get(resolveNowBalancerName) + // Create a build wrapper, but will not ignore ResolveNow(). + ignoreResolveNowBB := newIgnoreResolveNowBalancerBuilder(resolveNowBB, false) + + cc := testutils.NewTestClientConn(t) + tb := ignoreResolveNowBB.Build(cc, balancer.BuildOptions{}) + defer tb.Close() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + // This is the balancer.ClientConn that the inner resolverNowBalancer is + // built with. + balancerCCI, err := resolveNowBalancerCCCh.Receive(ctx) + if err != nil { + t.Fatalf("timeout waiting for ClientConn from balancer builder") + } + balancerCC := balancerCCI.(balancer.ClientConn) + + // Call ResolveNow() on the CC, it should be forwarded. + balancerCC.ResolveNow(resolver.ResolveNowOptions{}) + select { + case <-cc.ResolveNowCh: + case <-time.After(time.Second): + t.Fatalf("timeout waiting for ResolveNow()") + } + + // Update ignoreResolveNow to true, call ResolveNow() on the CC, they should + // all be ignored. + ignoreResolveNowBB.updateIgnoreResolveNow(true) + for i := 0; i < 5; i++ { + balancerCC.ResolveNow(resolver.ResolveNowOptions{}) + } + select { + case <-cc.ResolveNowCh: + t.Fatalf("got unexpected ResolveNow() call") + case <-time.After(time.Millisecond * 100): + } + + // Update ignoreResolveNow to false, new ResolveNow() calls should be + // forwarded. + ignoreResolveNowBB.updateIgnoreResolveNow(false) + balancerCC.ResolveNow(resolver.ResolveNowOptions{}) + select { + case <-cc.ResolveNowCh: + case <-time.After(time.Second): + t.Fatalf("timeout waiting for ResolveNow()") + } +} diff --git a/xds/internal/balancer/ringhash/config.go b/xds/internal/balancer/ringhash/config.go new file mode 100644 index 00000000000..5cb4aab3d9c --- /dev/null +++ b/xds/internal/balancer/ringhash/config.go @@ -0,0 +1,56 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package ringhash + +import ( + "encoding/json" + "fmt" + + "google.golang.org/grpc/serviceconfig" +) + +// LBConfig is the balancer config for ring_hash balancer. +type LBConfig struct { + serviceconfig.LoadBalancingConfig `json:"-"` + + MinRingSize uint64 `json:"minRingSize,omitempty"` + MaxRingSize uint64 `json:"maxRingSize,omitempty"` +} + +const ( + defaultMinSize = 1024 + defaultMaxSize = 8 * 1024 * 1024 // 8M +) + +func parseConfig(c json.RawMessage) (*LBConfig, error) { + var cfg LBConfig + if err := json.Unmarshal(c, &cfg); err != nil { + return nil, err + } + if cfg.MinRingSize == 0 { + cfg.MinRingSize = defaultMinSize + } + if cfg.MaxRingSize == 0 { + cfg.MaxRingSize = defaultMaxSize + } + if cfg.MinRingSize > cfg.MaxRingSize { + return nil, fmt.Errorf("min %v is greater than max %v", cfg.MinRingSize, cfg.MaxRingSize) + } + return &cfg, nil +} diff --git a/xds/internal/balancer/ringhash/config_test.go b/xds/internal/balancer/ringhash/config_test.go new file mode 100644 index 00000000000..a2a966dc318 --- /dev/null +++ b/xds/internal/balancer/ringhash/config_test.go @@ -0,0 +1,68 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package ringhash + +import ( + "testing" + + "github.com/google/go-cmp/cmp" +) + +func TestParseConfig(t *testing.T) { + tests := []struct { + name string + js string + want *LBConfig + wantErr bool + }{ + { + name: "OK", + js: `{"minRingSize": 1, "maxRingSize": 2}`, + want: &LBConfig{MinRingSize: 1, MaxRingSize: 2}, + }, + { + name: "OK with default min", + js: `{"maxRingSize": 2000}`, + want: &LBConfig{MinRingSize: defaultMinSize, MaxRingSize: 2000}, + }, + { + name: "OK with default max", + js: `{"minRingSize": 2000}`, + want: &LBConfig{MinRingSize: 2000, MaxRingSize: defaultMaxSize}, + }, + { + name: "min greater than max", + js: `{"minRingSize": 10, "maxRingSize": 2}`, + want: nil, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := parseConfig([]byte(tt.js)) + if (err != nil) != tt.wantErr { + t.Errorf("parseConfig() error = %v, wantErr %v", err, tt.wantErr) + return + } + if diff := cmp.Diff(got, tt.want); diff != "" { + t.Errorf("parseConfig() got unexpected output, diff (-got +want): %v", diff) + } + }) + } +} diff --git a/xds/internal/balancer/edsbalancer/logging.go b/xds/internal/balancer/ringhash/logging.go similarity index 83% rename from xds/internal/balancer/edsbalancer/logging.go rename to xds/internal/balancer/ringhash/logging.go index be4d0a512d1..64a1d467f55 100644 --- a/xds/internal/balancer/edsbalancer/logging.go +++ b/xds/internal/balancer/ringhash/logging.go @@ -1,6 +1,6 @@ /* * - * Copyright 2020 gRPC authors. + * Copyright 2021 gRPC authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,7 +16,7 @@ * */ -package edsbalancer +package ringhash import ( "fmt" @@ -25,10 +25,10 @@ import ( internalgrpclog "google.golang.org/grpc/internal/grpclog" ) -const prefix = "[eds-lb %p] " +const prefix = "[ring-hash-lb %p] " var logger = grpclog.Component("xds") -func prefixLogger(p *edsBalancer) *internalgrpclog.PrefixLogger { +func prefixLogger(p *ringhashBalancer) *internalgrpclog.PrefixLogger { return internalgrpclog.NewPrefixLogger(logger, fmt.Sprintf(prefix, p)) } diff --git a/xds/internal/balancer/ringhash/picker.go b/xds/internal/balancer/ringhash/picker.go new file mode 100644 index 00000000000..dcea6d46e51 --- /dev/null +++ b/xds/internal/balancer/ringhash/picker.go @@ -0,0 +1,154 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package ringhash + +import ( + "fmt" + + "google.golang.org/grpc/balancer" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/connectivity" + "google.golang.org/grpc/internal/grpclog" + "google.golang.org/grpc/status" +) + +type picker struct { + ring *ring + logger *grpclog.PrefixLogger +} + +func newPicker(ring *ring, logger *grpclog.PrefixLogger) *picker { + return &picker{ring: ring, logger: logger} +} + +// handleRICSResult is the return type of handleRICS. It's needed to wrap the +// returned error from Pick() in a struct. With this, if the return values are +// `balancer.PickResult, error, bool`, linter complains because error is not the +// last return value. +type handleRICSResult struct { + pr balancer.PickResult + err error +} + +// handleRICS generates pick result if the entry is in Ready, Idle, Connecting +// or Shutdown. TransientFailure will be handled specifically after this +// function returns. +// +// The first return value indicates if the state is in Ready, Idle, Connecting +// or Shutdown. If it's true, the PickResult and error should be returned from +// Pick() as is. +func (p *picker) handleRICS(e *ringEntry) (handleRICSResult, bool) { + switch state := e.sc.effectiveState(); state { + case connectivity.Ready: + return handleRICSResult{pr: balancer.PickResult{SubConn: e.sc.sc}}, true + case connectivity.Idle: + // Trigger Connect() and queue the pick. + e.sc.queueConnect() + return handleRICSResult{err: balancer.ErrNoSubConnAvailable}, true + case connectivity.Connecting: + return handleRICSResult{err: balancer.ErrNoSubConnAvailable}, true + case connectivity.TransientFailure: + // Return ok==false, so TransientFailure will be handled afterwards. + return handleRICSResult{}, false + case connectivity.Shutdown: + // Shutdown can happen in a race where the old picker is called. A new + // picker should already be sent. + return handleRICSResult{err: balancer.ErrNoSubConnAvailable}, true + default: + // Should never reach this. All the connectivity states are already + // handled in the cases. + p.logger.Errorf("SubConn has undefined connectivity state: %v", state) + return handleRICSResult{err: status.Errorf(codes.Unavailable, "SubConn has undefined connectivity state: %v", state)}, true + } +} + +func (p *picker) Pick(info balancer.PickInfo) (balancer.PickResult, error) { + e := p.ring.pick(getRequestHash(info.Ctx)) + if hr, ok := p.handleRICS(e); ok { + return hr.pr, hr.err + } + // ok was false, the entry is in transient failure. + return p.handleTransientFailure(e) +} + +func (p *picker) handleTransientFailure(e *ringEntry) (balancer.PickResult, error) { + // Queue a connect on the first picked SubConn. + e.sc.queueConnect() + + // Find next entry in the ring, skipping duplicate SubConns. + e2 := nextSkippingDuplicates(p.ring, e) + if e2 == nil { + // There's no next entry available, fail the pick. + return balancer.PickResult{}, fmt.Errorf("the only SubConn is in Transient Failure") + } + + // For the second SubConn, also check Ready/Idle/Connecting as if it's the + // first entry. + if hr, ok := p.handleRICS(e2); ok { + return hr.pr, hr.err + } + + // The second SubConn is also in TransientFailure. Queue a connect on it. + e2.sc.queueConnect() + + // If it gets here, this is after the second SubConn, and the second SubConn + // was in TransientFailure. + // + // Loop over all other SubConns: + // - If all SubConns so far are all TransientFailure, trigger Connect() on + // the TransientFailure SubConns, and keep going. + // - If there's one SubConn that's not in TransientFailure, keep checking + // the remaining SubConns (in case there's a Ready, which will be returned), + // but don't not trigger Connect() on the other SubConns. + var firstNonFailedFound bool + for ee := nextSkippingDuplicates(p.ring, e2); ee != e; ee = nextSkippingDuplicates(p.ring, ee) { + scState := ee.sc.effectiveState() + if scState == connectivity.Ready { + return balancer.PickResult{SubConn: ee.sc.sc}, nil + } + if firstNonFailedFound { + continue + } + if scState == connectivity.TransientFailure { + // This will queue a connect. + ee.sc.queueConnect() + continue + } + // This is a SubConn in a non-failure state. We continue to check the + // other SubConns, but remember that there was a non-failed SubConn + // seen. After this, Pick() will never trigger any SubConn to Connect(). + firstNonFailedFound = true + if scState == connectivity.Idle { + // This is the first non-failed SubConn, and it is in a real Idle + // state. Trigger it to Connect(). + ee.sc.queueConnect() + } + } + return balancer.PickResult{}, fmt.Errorf("no connection is Ready") +} + +func nextSkippingDuplicates(ring *ring, entry *ringEntry) *ringEntry { + for next := ring.next(entry); next != entry; next = ring.next(next) { + if next.sc != entry.sc { + return next + } + } + // There's no qualifying next entry. + return nil +} diff --git a/xds/internal/balancer/ringhash/picker_test.go b/xds/internal/balancer/ringhash/picker_test.go new file mode 100644 index 00000000000..c88698ebbdf --- /dev/null +++ b/xds/internal/balancer/ringhash/picker_test.go @@ -0,0 +1,285 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package ringhash + +import ( + "context" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "google.golang.org/grpc/balancer" + "google.golang.org/grpc/connectivity" + "google.golang.org/grpc/xds/internal/testutils" +) + +func newTestRing(cStats []connectivity.State) *ring { + var items []*ringEntry + for i, st := range cStats { + testSC := testutils.TestSubConns[i] + items = append(items, &ringEntry{ + idx: i, + hash: uint64((i + 1) * 10), + sc: &subConn{ + addr: testSC.String(), + sc: testSC, + state: st, + }, + }) + } + return &ring{items: items} +} + +func TestPickerPickFirstTwo(t *testing.T) { + tests := []struct { + name string + ring *ring + hash uint64 + wantSC balancer.SubConn + wantErr error + wantSCToConnect balancer.SubConn + }{ + { + name: "picked is Ready", + ring: newTestRing([]connectivity.State{connectivity.Ready, connectivity.Idle}), + hash: 5, + wantSC: testutils.TestSubConns[0], + }, + { + name: "picked is connecting, queue", + ring: newTestRing([]connectivity.State{connectivity.Connecting, connectivity.Idle}), + hash: 5, + wantErr: balancer.ErrNoSubConnAvailable, + }, + { + name: "picked is Idle, connect and queue", + ring: newTestRing([]connectivity.State{connectivity.Idle, connectivity.Idle}), + hash: 5, + wantErr: balancer.ErrNoSubConnAvailable, + wantSCToConnect: testutils.TestSubConns[0], + }, + { + name: "picked is TransientFailure, next is ready, return", + ring: newTestRing([]connectivity.State{connectivity.TransientFailure, connectivity.Ready}), + hash: 5, + wantSC: testutils.TestSubConns[1], + }, + { + name: "picked is TransientFailure, next is connecting, queue", + ring: newTestRing([]connectivity.State{connectivity.TransientFailure, connectivity.Connecting}), + hash: 5, + wantErr: balancer.ErrNoSubConnAvailable, + }, + { + name: "picked is TransientFailure, next is Idle, connect and queue", + ring: newTestRing([]connectivity.State{connectivity.TransientFailure, connectivity.Idle}), + hash: 5, + wantErr: balancer.ErrNoSubConnAvailable, + wantSCToConnect: testutils.TestSubConns[1], + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + p := &picker{ring: tt.ring} + got, err := p.Pick(balancer.PickInfo{ + Ctx: SetRequestHash(context.Background(), tt.hash), + }) + if err != tt.wantErr { + t.Errorf("Pick() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !cmp.Equal(got, balancer.PickResult{SubConn: tt.wantSC}, cmpOpts) { + t.Errorf("Pick() got = %v, want picked SubConn: %v", got, tt.wantSC) + } + if sc := tt.wantSCToConnect; sc != nil { + select { + case <-sc.(*testutils.TestSubConn).ConnectCh: + case <-time.After(defaultTestShortTimeout): + t.Errorf("timeout waiting for Connect() from SubConn %v", sc) + } + } + }) + } +} + +// TestPickerPickTriggerTFConnect covers that if the picked SubConn is +// TransientFailures, all SubConns until a non-TransientFailure are queued for +// Connect(). +func TestPickerPickTriggerTFConnect(t *testing.T) { + ring := newTestRing([]connectivity.State{ + connectivity.TransientFailure, connectivity.TransientFailure, connectivity.TransientFailure, connectivity.TransientFailure, + connectivity.Idle, connectivity.TransientFailure, connectivity.TransientFailure, connectivity.TransientFailure, + }) + p := &picker{ring: ring} + _, err := p.Pick(balancer.PickInfo{Ctx: SetRequestHash(context.Background(), 5)}) + if err == nil { + t.Fatalf("Pick() error = %v, want non-nil", err) + } + // The first 4 SubConns, all in TransientFailure, should be queued to + // connect. + for i := 0; i < 4; i++ { + it := ring.items[i] + if !it.sc.connectQueued { + t.Errorf("the %d-th SubConn is not queued for connect", i) + } + } + // The other SubConns, after the first Idle, should not be queued to + // connect. + for i := 5; i < len(ring.items); i++ { + it := ring.items[i] + if it.sc.connectQueued { + t.Errorf("the %d-th SubConn is unexpected queued for connect", i) + } + } +} + +// TestPickerPickTriggerTFReturnReady covers that if the picked SubConn is +// TransientFailure, SubConn 2 and 3 are TransientFailure, 4 is Ready. SubConn 2 +// and 3 will Connect(), and 4 will be returned. +func TestPickerPickTriggerTFReturnReady(t *testing.T) { + ring := newTestRing([]connectivity.State{ + connectivity.TransientFailure, connectivity.TransientFailure, connectivity.TransientFailure, connectivity.Ready, + }) + p := &picker{ring: ring} + pr, err := p.Pick(balancer.PickInfo{Ctx: SetRequestHash(context.Background(), 5)}) + if err != nil { + t.Fatalf("Pick() error = %v, want nil", err) + } + if wantSC := testutils.TestSubConns[3]; pr.SubConn != wantSC { + t.Fatalf("Pick() = %v, want %v", pr.SubConn, wantSC) + } + // The first 3 SubConns, all in TransientFailure, should be queued to + // connect. + for i := 0; i < 3; i++ { + it := ring.items[i] + if !it.sc.connectQueued { + t.Errorf("the %d-th SubConn is not queued for connect", i) + } + } +} + +// TestPickerPickTriggerTFWithIdle covers that if the picked SubConn is +// TransientFailure, SubConn 2 is TransientFailure, 3 is Idle (init Idle). Pick +// will be queue, SubConn 3 will Connect(), SubConn 4 and 5 (in TransientFailre) +// will not queue a Connect. +func TestPickerPickTriggerTFWithIdle(t *testing.T) { + ring := newTestRing([]connectivity.State{ + connectivity.TransientFailure, connectivity.TransientFailure, connectivity.Idle, connectivity.TransientFailure, connectivity.TransientFailure, + }) + p := &picker{ring: ring} + _, err := p.Pick(balancer.PickInfo{Ctx: SetRequestHash(context.Background(), 5)}) + if err == balancer.ErrNoSubConnAvailable { + t.Fatalf("Pick() error = %v, want %v", err, balancer.ErrNoSubConnAvailable) + } + // The first 2 SubConns, all in TransientFailure, should be queued to + // connect. + for i := 0; i < 2; i++ { + it := ring.items[i] + if !it.sc.connectQueued { + t.Errorf("the %d-th SubConn is not queued for connect", i) + } + } + // SubConn 3 was in Idle, so should Connect() + select { + case <-testutils.TestSubConns[2].ConnectCh: + case <-time.After(defaultTestShortTimeout): + t.Errorf("timeout waiting for Connect() from SubConn %v", testutils.TestSubConns[2]) + } + // The other SubConns, after the first Idle, should not be queued to + // connect. + for i := 3; i < len(ring.items); i++ { + it := ring.items[i] + if it.sc.connectQueued { + t.Errorf("the %d-th SubConn is unexpected queued for connect", i) + } + } +} + +func TestNextSkippingDuplicatesNoDup(t *testing.T) { + testRing := newTestRing([]connectivity.State{connectivity.Idle, connectivity.Idle}) + tests := []struct { + name string + ring *ring + cur *ringEntry + want *ringEntry + }{ + { + name: "no dup", + ring: testRing, + cur: testRing.items[0], + want: testRing.items[1], + }, + { + name: "only one entry", + ring: &ring{items: []*ringEntry{testRing.items[0]}}, + cur: testRing.items[0], + want: nil, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := nextSkippingDuplicates(tt.ring, tt.cur); !cmp.Equal(got, tt.want, cmpOpts) { + t.Errorf("nextSkippingDuplicates() = %v, want %v", got, tt.want) + } + }) + } +} + +// addDups adds duplicates of items[0] to the ring. +func addDups(r *ring, count int) *ring { + var ( + items []*ringEntry + idx int + ) + for i, it := range r.items { + itt := *it + itt.idx = idx + items = append(items, &itt) + idx++ + if i == 0 { + // Add duplicate of items[0] to the ring + for j := 0; j < count; j++ { + itt2 := *it + itt2.idx = idx + items = append(items, &itt2) + idx++ + } + } + } + return &ring{items: items} +} + +func TestNextSkippingDuplicatesMoreDup(t *testing.T) { + testRing := newTestRing([]connectivity.State{connectivity.Idle, connectivity.Idle}) + // Make a new ring with duplicate SubConns. + dupTestRing := addDups(testRing, 3) + if got := nextSkippingDuplicates(dupTestRing, dupTestRing.items[0]); !cmp.Equal(got, dupTestRing.items[len(dupTestRing.items)-1], cmpOpts) { + t.Errorf("nextSkippingDuplicates() = %v, want %v", got, dupTestRing.items[len(dupTestRing.items)-1]) + } +} + +func TestNextSkippingDuplicatesOnlyDup(t *testing.T) { + testRing := newTestRing([]connectivity.State{connectivity.Idle}) + // Make a new ring with only duplicate SubConns. + dupTestRing := addDups(testRing, 3) + // This ring only has duplicates of items[0], should return nil. + if got := nextSkippingDuplicates(dupTestRing, dupTestRing.items[0]); got != nil { + t.Errorf("nextSkippingDuplicates() = %v, want nil", got) + } +} diff --git a/xds/internal/balancer/ringhash/ring.go b/xds/internal/balancer/ringhash/ring.go new file mode 100644 index 00000000000..68e844cfb48 --- /dev/null +++ b/xds/internal/balancer/ringhash/ring.go @@ -0,0 +1,163 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package ringhash + +import ( + "fmt" + "math" + "sort" + "strconv" + + xxhash "github.com/cespare/xxhash/v2" + "google.golang.org/grpc/resolver" +) + +type ring struct { + items []*ringEntry +} + +type subConnWithWeight struct { + sc *subConn + weight float64 +} + +type ringEntry struct { + idx int + hash uint64 + sc *subConn +} + +// newRing creates a ring from the subConns. The ring size is limited by the +// passed in max/min. +// +// ring entries will be created for each subConn, and subConn with high weight +// (specified by the address) may have multiple entries. +// +// For example, for subConns with weights {a:3, b:3, c:4}, a generated ring of +// size 10 could be: +// - {idx:0 hash:3689675255460411075 b} +// - {idx:1 hash:4262906501694543955 c} +// - {idx:2 hash:5712155492001633497 c} +// - {idx:3 hash:8050519350657643659 b} +// - {idx:4 hash:8723022065838381142 b} +// - {idx:5 hash:11532782514799973195 a} +// - {idx:6 hash:13157034721563383607 c} +// - {idx:7 hash:14468677667651225770 c} +// - {idx:8 hash:17336016884672388720 a} +// - {idx:9 hash:18151002094784932496 a} +// +// To pick from a ring, a binary search will be done for the given target hash, +// and first item with hash >= given hash will be returned. +func newRing(subConns map[resolver.Address]*subConn, minRingSize, maxRingSize uint64) (*ring, error) { + // https://github.com/envoyproxy/envoy/blob/765c970f06a4c962961a0e03a467e165b276d50f/source/common/upstream/ring_hash_lb.cc#L114 + normalizedWeights, minWeight, err := normalizeWeights(subConns) + if err != nil { + return nil, err + } + // Normalized weights for {3,3,4} is {0.3,0.3,0.4}. + + // Scale up the size of the ring such that the least-weighted host gets a + // whole number of hashes on the ring. + // + // Note that size is limited by the input max/min. + scale := math.Min(math.Ceil(minWeight*float64(minRingSize))/minWeight, float64(maxRingSize)) + ringSize := math.Ceil(scale) + items := make([]*ringEntry, 0, int(ringSize)) + + // For each entry, scale*weight nodes are generated in the ring. + // + // Not all of these are whole numbers. E.g. for weights {a:3,b:3,c:4}, if + // ring size is 7, scale is 6.66. The numbers of nodes will be + // {a,a,b,b,c,c,c}. + // + // A hash is generated for each item, and later the results will be sorted + // based on the hash. + var ( + idx int + targetIdx float64 + ) + for _, scw := range normalizedWeights { + targetIdx += scale * scw.weight + for float64(idx) < targetIdx { + h := xxhash.Sum64String(scw.sc.addr + strconv.Itoa(len(items))) + items = append(items, &ringEntry{idx: idx, hash: h, sc: scw.sc}) + idx++ + } + } + + // Sort items based on hash, to prepare for binary search. + sort.Slice(items, func(i, j int) bool { return items[i].hash < items[j].hash }) + for i, ii := range items { + ii.idx = i + } + return &ring{items: items}, nil +} + +// normalizeWeights divides all the weights by the sum, so that the total weight +// is 1. +func normalizeWeights(subConns map[resolver.Address]*subConn) (_ []subConnWithWeight, min float64, _ error) { + if len(subConns) == 0 { + return nil, 0, fmt.Errorf("number of subconns is 0") + } + var weightSum uint32 + for a := range subConns { + // The address weight was moved from attributes to the Metadata field. + // This is necessary (all the attributes need to be stripped) for the + // balancer to detect identical {address+weight} combination. + weightSum += a.Metadata.(uint32) + } + if weightSum == 0 { + return nil, 0, fmt.Errorf("total weight of all subconns is 0") + } + weightSumF := float64(weightSum) + ret := make([]subConnWithWeight, 0, len(subConns)) + min = math.MaxFloat64 + for a, sc := range subConns { + nw := float64(a.Metadata.(uint32)) / weightSumF + ret = append(ret, subConnWithWeight{sc: sc, weight: nw}) + if nw < min { + min = nw + } + } + // Sort the addresses to return consistent results. + // + // Note: this might not be necessary, but this makes sure the ring is + // consistent as long as the addresses are the same, for example, in cases + // where an address is added and then removed, the RPCs will still pick the + // same old SubConn. + sort.Slice(ret, func(i, j int) bool { return ret[i].sc.addr < ret[j].sc.addr }) + return ret, min, nil +} + +// pick does a binary search. It returns the item with smallest index i that +// r.items[i].hash >= h. +func (r *ring) pick(h uint64) *ringEntry { + i := sort.Search(len(r.items), func(i int) bool { return r.items[i].hash >= h }) + if i == len(r.items) { + // If not found, and h is greater than the largest hash, return the + // first item. + i = 0 + } + return r.items[i] +} + +// next returns the next entry. +func (r *ring) next(e *ringEntry) *ringEntry { + return r.items[(e.idx+1)%len(r.items)] +} diff --git a/xds/internal/balancer/ringhash/ring_test.go b/xds/internal/balancer/ringhash/ring_test.go new file mode 100644 index 00000000000..2d664e05bb2 --- /dev/null +++ b/xds/internal/balancer/ringhash/ring_test.go @@ -0,0 +1,113 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package ringhash + +import ( + "fmt" + "math" + "testing" + + xxhash "github.com/cespare/xxhash/v2" + "google.golang.org/grpc/resolver" +) + +func testAddr(addr string, weight uint32) resolver.Address { + return resolver.Address{Addr: addr, Metadata: weight} +} + +func TestRingNew(t *testing.T) { + testAddrs := []resolver.Address{ + testAddr("a", 3), + testAddr("b", 3), + testAddr("c", 4), + } + var totalWeight float64 = 10 + testSubConnMap := map[resolver.Address]*subConn{ + testAddr("a", 3): {addr: "a"}, + testAddr("b", 3): {addr: "b"}, + testAddr("c", 4): {addr: "c"}, + } + for _, min := range []uint64{3, 4, 6, 8} { + for _, max := range []uint64{20, 8} { + t.Run(fmt.Sprintf("size-min-%v-max-%v", min, max), func(t *testing.T) { + r, _ := newRing(testSubConnMap, min, max) + totalCount := len(r.items) + if totalCount < int(min) || totalCount > int(max) { + t.Fatalf("unexpect size %v, want min %v, max %v", totalCount, min, max) + } + for _, a := range testAddrs { + var count int + for _, ii := range r.items { + if ii.sc.addr == a.Addr { + count++ + } + } + got := float64(count) / float64(totalCount) + want := float64(a.Metadata.(uint32)) / totalWeight + if !equalApproximately(got, want) { + t.Fatalf("unexpected item weight in ring: %v != %v", got, want) + } + } + }) + } + } +} + +func equalApproximately(x, y float64) bool { + delta := math.Abs(x - y) + mean := math.Abs(x+y) / 2.0 + return delta/mean < 0.25 +} + +func TestRingPick(t *testing.T) { + r, _ := newRing(map[resolver.Address]*subConn{ + {Addr: "a", Metadata: uint32(3)}: {addr: "a"}, + {Addr: "b", Metadata: uint32(3)}: {addr: "b"}, + {Addr: "c", Metadata: uint32(4)}: {addr: "c"}, + }, 10, 20) + for _, h := range []uint64{xxhash.Sum64String("1"), xxhash.Sum64String("2"), xxhash.Sum64String("3"), xxhash.Sum64String("4")} { + t.Run(fmt.Sprintf("picking-hash-%v", h), func(t *testing.T) { + e := r.pick(h) + var low uint64 + if e.idx > 0 { + low = r.items[e.idx-1].hash + } + high := e.hash + // h should be in [low, high). + if h < low || h >= high { + t.Fatalf("unexpected item picked, hash: %v, low: %v, high: %v", h, low, high) + } + }) + } +} + +func TestRingNext(t *testing.T) { + r, _ := newRing(map[resolver.Address]*subConn{ + {Addr: "a", Metadata: uint32(3)}: {addr: "a"}, + {Addr: "b", Metadata: uint32(3)}: {addr: "b"}, + {Addr: "c", Metadata: uint32(4)}: {addr: "c"}, + }, 10, 20) + + for _, e := range r.items { + ne := r.next(e) + if ne.idx != (e.idx+1)%len(r.items) { + t.Fatalf("next(%+v) returned unexpected %+v", e, ne) + } + } +} diff --git a/xds/internal/balancer/ringhash/ringhash.go b/xds/internal/balancer/ringhash/ringhash.go new file mode 100644 index 00000000000..f8a47f165bd --- /dev/null +++ b/xds/internal/balancer/ringhash/ringhash.go @@ -0,0 +1,434 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// Package ringhash implements the ringhash balancer. +package ringhash + +import ( + "encoding/json" + "errors" + "fmt" + "sync" + + "google.golang.org/grpc/balancer" + "google.golang.org/grpc/balancer/base" + "google.golang.org/grpc/balancer/weightedroundrobin" + "google.golang.org/grpc/connectivity" + "google.golang.org/grpc/internal/grpclog" + "google.golang.org/grpc/internal/pretty" + "google.golang.org/grpc/resolver" + "google.golang.org/grpc/serviceconfig" +) + +// Name is the name of the ring_hash balancer. +const Name = "ring_hash_experimental" + +func init() { + balancer.Register(bb{}) +} + +type bb struct{} + +func (bb) Build(cc balancer.ClientConn, bOpts balancer.BuildOptions) balancer.Balancer { + b := &ringhashBalancer{ + cc: cc, + subConns: make(map[resolver.Address]*subConn), + scStates: make(map[balancer.SubConn]*subConn), + csEvltr: &connectivityStateEvaluator{}, + } + b.logger = prefixLogger(b) + b.logger.Infof("Created") + return b +} + +func (bb) Name() string { + return Name +} + +func (bb) ParseConfig(c json.RawMessage) (serviceconfig.LoadBalancingConfig, error) { + return parseConfig(c) +} + +type subConn struct { + addr string + sc balancer.SubConn + + mu sync.RWMutex + // This is the actual state of this SubConn (as updated by the ClientConn). + // The effective state can be different, see comment of attemptedToConnect. + state connectivity.State + // failing is whether this SubConn is in a failing state. A subConn is + // considered to be in a failing state if it was previously in + // TransientFailure. + // + // This affects the effective connectivity state of this SubConn, e.g. + // - if the actual state is Idle or Connecting, but this SubConn is failing, + // the effective state is TransientFailure. + // + // This is used in pick(). E.g. if a subConn is Idle, but has failing as + // true, pick() will + // - consider this SubConn as TransientFailure, and check the state of the + // next SubConn. + // - trigger Connect() (note that normally a SubConn in real + // TransientFailure cannot Connect()) + // + // A subConn starts in non-failing (failing is false). A transition to + // TransientFailure sets failing to true (and it stays true). A transition + // to Ready sets failing to false. + failing bool + // connectQueued is true if a Connect() was queued for this SubConn while + // it's not in Idle (most likely was in TransientFailure). A Connect() will + // be triggered on this SubConn when it turns Idle. + // + // When connectivity state is updated to Idle for this SubConn, if + // connectQueued is true, Connect() will be called on the SubConn. + connectQueued bool +} + +// setState updates the state of this SubConn. +// +// It also handles the queued Connect(). If the new state is Idle, and a +// Connect() was queued, this SubConn will be triggered to Connect(). +func (sc *subConn) setState(s connectivity.State) { + sc.mu.Lock() + defer sc.mu.Unlock() + switch s { + case connectivity.Idle: + // Trigger Connect() if new state is Idle, and there is a queued connect. + if sc.connectQueued { + sc.connectQueued = false + sc.sc.Connect() + } + case connectivity.Connecting: + // Clear connectQueued if the SubConn isn't failing. This state + // transition is unlikely to happen, but handle this just in case. + sc.connectQueued = false + case connectivity.Ready: + // Clear connectQueued if the SubConn isn't failing. This state + // transition is unlikely to happen, but handle this just in case. + sc.connectQueued = false + // Set to a non-failing state. + sc.failing = false + case connectivity.TransientFailure: + // Set to a failing state. + sc.failing = true + } + sc.state = s +} + +// effectiveState returns the effective state of this SubConn. It can be +// different from the actual state, e.g. Idle while the subConn is failing is +// considered TransientFailure. Read comment of field failing for other cases. +func (sc *subConn) effectiveState() connectivity.State { + sc.mu.RLock() + defer sc.mu.RUnlock() + if sc.failing && (sc.state == connectivity.Idle || sc.state == connectivity.Connecting) { + return connectivity.TransientFailure + } + return sc.state +} + +// queueConnect sets a boolean so that when the SubConn state changes to Idle, +// it's Connect() will be triggered. If the SubConn state is already Idle, it +// will just call Connect(). +func (sc *subConn) queueConnect() { + sc.mu.Lock() + defer sc.mu.Unlock() + if sc.state == connectivity.Idle { + sc.sc.Connect() + return + } + // Queue this connect, and when this SubConn switches back to Idle (happens + // after backoff in TransientFailure), it will Connect(). + sc.connectQueued = true +} + +type ringhashBalancer struct { + cc balancer.ClientConn + logger *grpclog.PrefixLogger + + config *LBConfig + + subConns map[resolver.Address]*subConn // `attributes` is stripped from the keys of this map (the addresses) + scStates map[balancer.SubConn]*subConn + + // ring is always in sync with subConns. When subConns change, a new ring is + // generated. Note that address weights updates (they are keys in the + // subConns map) also regenerates the ring. + ring *ring + picker balancer.Picker + csEvltr *connectivityStateEvaluator + state connectivity.State + + resolverErr error // the last error reported by the resolver; cleared on successful resolution + connErr error // the last connection error; cleared upon leaving TransientFailure +} + +// updateAddresses creates new SubConns and removes SubConns, based on the +// address update. +// +// The return value is whether the new address list is different from the +// previous. True if +// - an address was added +// - an address was removed +// - an address's weight was updated +// +// Note that this function doesn't trigger SubConn connecting, so all the new +// SubConn states are Idle. +func (b *ringhashBalancer) updateAddresses(addrs []resolver.Address) bool { + var addrsUpdated bool + // addrsSet is the set converted from addrs, it's used for quick lookup of + // an address. + // + // Addresses in this map all have attributes stripped, but metadata set to + // the weight. So that weight change can be detected. + // + // TODO: this won't be necessary if there are ways to compare address + // attributes. + addrsSet := make(map[resolver.Address]struct{}) + for _, a := range addrs { + aNoAttrs := a + // Strip attributes but set Metadata to the weight. + aNoAttrs.Attributes = nil + w := weightedroundrobin.GetAddrInfo(a).Weight + if w == 0 { + // If weight is not set, use 1. + w = 1 + } + aNoAttrs.Metadata = w + addrsSet[aNoAttrs] = struct{}{} + if scInfo, ok := b.subConns[aNoAttrs]; !ok { + // When creating SubConn, the original address with attributes is + // passed through. So that connection configurations in attributes + // (like creds) will be used. + sc, err := b.cc.NewSubConn([]resolver.Address{a}, balancer.NewSubConnOptions{HealthCheckEnabled: true}) + if err != nil { + logger.Warningf("base.baseBalancer: failed to create new SubConn: %v", err) + continue + } + scs := &subConn{addr: a.Addr, sc: sc} + scs.setState(connectivity.Idle) + b.state = b.csEvltr.recordTransition(connectivity.Shutdown, connectivity.Idle) + b.subConns[aNoAttrs] = scs + b.scStates[sc] = scs + addrsUpdated = true + } else { + // Always update the subconn's address in case the attributes + // changed. The SubConn does a reflect.DeepEqual of the new and old + // addresses. So this is a noop if the current address is the same + // as the old one (including attributes). + b.subConns[aNoAttrs] = scInfo + b.cc.UpdateAddresses(scInfo.sc, []resolver.Address{a}) + } + } + for a, scInfo := range b.subConns { + // a was removed by resolver. + if _, ok := addrsSet[a]; !ok { + b.cc.RemoveSubConn(scInfo.sc) + delete(b.subConns, a) + addrsUpdated = true + // Keep the state of this sc in b.scStates until sc's state becomes Shutdown. + // The entry will be deleted in UpdateSubConnState. + } + } + return addrsUpdated +} + +func (b *ringhashBalancer) UpdateClientConnState(s balancer.ClientConnState) error { + b.logger.Infof("Received update from resolver, balancer config: %+v", pretty.ToJSON(s.BalancerConfig)) + if b.config == nil { + newConfig, ok := s.BalancerConfig.(*LBConfig) + if !ok { + return fmt.Errorf("unexpected balancer config with type: %T", s.BalancerConfig) + } + b.config = newConfig + } + + // Successful resolution; clear resolver error and ensure we return nil. + b.resolverErr = nil + if b.updateAddresses(s.ResolverState.Addresses) { + // If addresses were updated, no matter whether it resulted in SubConn + // creation/deletion, or just weight update, we will need to regenerate + // the ring. + var err error + b.ring, err = newRing(b.subConns, b.config.MinRingSize, b.config.MaxRingSize) + if err != nil { + panic(err) + } + b.regeneratePicker() + b.cc.UpdateState(balancer.State{ConnectivityState: b.state, Picker: b.picker}) + } + + // If resolver state contains no addresses, return an error so ClientConn + // will trigger re-resolve. Also records this as an resolver error, so when + // the overall state turns transient failure, the error message will have + // the zero address information. + if len(s.ResolverState.Addresses) == 0 { + b.ResolverError(errors.New("produced zero addresses")) + return balancer.ErrBadResolverState + } + return nil +} + +func (b *ringhashBalancer) ResolverError(err error) { + b.resolverErr = err + if len(b.subConns) == 0 { + b.state = connectivity.TransientFailure + } + + if b.state != connectivity.TransientFailure { + // The picker will not change since the balancer does not currently + // report an error. + return + } + b.regeneratePicker() + b.cc.UpdateState(balancer.State{ + ConnectivityState: b.state, + Picker: b.picker, + }) +} + +// UpdateSubConnState updates the per-SubConn state stored in the ring, and also +// the aggregated state. +// +// It triggers an update to cc when: +// - the new state is TransientFailure, to update the error message +// - it's possible that this is a noop, but sending an extra update is easier +// than comparing errors +// - the aggregated state is changed +// - the same picker will be sent again, but this update may trigger a re-pick +// for some RPCs. +func (b *ringhashBalancer) UpdateSubConnState(sc balancer.SubConn, state balancer.SubConnState) { + s := state.ConnectivityState + b.logger.Infof("handle SubConn state change: %p, %v", sc, s) + scs, ok := b.scStates[sc] + if !ok { + b.logger.Infof("got state changes for an unknown SubConn: %p, %v", sc, s) + return + } + oldSCState := scs.effectiveState() + scs.setState(s) + newSCState := scs.effectiveState() + + var sendUpdate bool + oldBalancerState := b.state + b.state = b.csEvltr.recordTransition(oldSCState, newSCState) + if oldBalancerState != b.state { + sendUpdate = true + } + + switch s { + case connectivity.Idle: + // When the overall state is TransientFailure, this will never get picks + // if there's a lower priority. Need to keep the SubConns connecting so + // there's a chance it will recover. + if b.state == connectivity.TransientFailure { + scs.queueConnect() + } + // No need to send an update. No queued RPC can be unblocked. If the + // overall state changed because of this, sendUpdate is already true. + case connectivity.Connecting: + // No need to send an update. No queued RPC can be unblocked. If the + // overall state changed because of this, sendUpdate is already true. + case connectivity.Ready: + // Resend the picker, there's no need to regenerate the picker because + // the ring didn't change. + sendUpdate = true + case connectivity.TransientFailure: + // Save error to be reported via picker. + b.connErr = state.ConnectionError + // Regenerate picker to update error message. + b.regeneratePicker() + sendUpdate = true + case connectivity.Shutdown: + // When an address was removed by resolver, b called RemoveSubConn but + // kept the sc's state in scStates. Remove state for this sc here. + delete(b.scStates, sc) + } + + if sendUpdate { + b.cc.UpdateState(balancer.State{ConnectivityState: b.state, Picker: b.picker}) + } +} + +// mergeErrors builds an error from the last connection error and the last +// resolver error. Must only be called if b.state is TransientFailure. +func (b *ringhashBalancer) mergeErrors() error { + // connErr must always be non-nil unless there are no SubConns, in which + // case resolverErr must be non-nil. + if b.connErr == nil { + return fmt.Errorf("last resolver error: %v", b.resolverErr) + } + if b.resolverErr == nil { + return fmt.Errorf("last connection error: %v", b.connErr) + } + return fmt.Errorf("last connection error: %v; last resolver error: %v", b.connErr, b.resolverErr) +} + +func (b *ringhashBalancer) regeneratePicker() { + if b.state == connectivity.TransientFailure { + b.picker = base.NewErrPicker(b.mergeErrors()) + return + } + b.picker = newPicker(b.ring, b.logger) +} + +func (b *ringhashBalancer) Close() {} + +// connectivityStateEvaluator takes the connectivity states of multiple SubConns +// and returns one aggregated connectivity state. +// +// It's not thread safe. +type connectivityStateEvaluator struct { + nums [5]uint64 +} + +// recordTransition records state change happening in subConn and based on that +// it evaluates what aggregated state should be. +// +// - If there is at least one subchannel in READY state, report READY. +// - If there are 2 or more subchannels in TRANSIENT_FAILURE state, report TRANSIENT_FAILURE. +// - If there is at least one subchannel in CONNECTING state, report CONNECTING. +// - If there is at least one subchannel in Idle state, report Idle. +// - Otherwise, report TRANSIENT_FAILURE. +// +// Note that if there are 1 connecting, 2 transient failure, the overall state +// is transient failure. This is because the second transient failure is a +// fallback of the first failing SubConn, and we want to report transient +// failure to failover to the lower priority. +func (cse *connectivityStateEvaluator) recordTransition(oldState, newState connectivity.State) connectivity.State { + // Update counters. + for idx, state := range []connectivity.State{oldState, newState} { + updateVal := 2*uint64(idx) - 1 // -1 for oldState and +1 for new. + cse.nums[state] += updateVal + } + + if cse.nums[connectivity.Ready] > 0 { + return connectivity.Ready + } + if cse.nums[connectivity.TransientFailure] > 1 { + return connectivity.TransientFailure + } + if cse.nums[connectivity.Connecting] > 0 { + return connectivity.Connecting + } + if cse.nums[connectivity.Idle] > 0 { + return connectivity.Idle + } + return connectivity.TransientFailure +} diff --git a/xds/internal/balancer/ringhash/ringhash_test.go b/xds/internal/balancer/ringhash/ringhash_test.go new file mode 100644 index 00000000000..fb85367e4a4 --- /dev/null +++ b/xds/internal/balancer/ringhash/ringhash_test.go @@ -0,0 +1,458 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package ringhash + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "google.golang.org/grpc/attributes" + "google.golang.org/grpc/balancer" + "google.golang.org/grpc/balancer/weightedroundrobin" + "google.golang.org/grpc/connectivity" + "google.golang.org/grpc/resolver" + "google.golang.org/grpc/xds/internal/testutils" +) + +var ( + cmpOpts = cmp.Options{ + cmp.AllowUnexported(testutils.TestSubConn{}, ringEntry{}, subConn{}), + cmpopts.IgnoreFields(subConn{}, "mu"), + } +) + +const ( + defaultTestTimeout = 10 * time.Second + defaultTestShortTimeout = 10 * time.Millisecond + + testBackendAddrsCount = 12 +) + +var ( + testBackendAddrStrs []string + testConfig = &LBConfig{MinRingSize: 1, MaxRingSize: 10} +) + +func init() { + for i := 0; i < testBackendAddrsCount; i++ { + testBackendAddrStrs = append(testBackendAddrStrs, fmt.Sprintf("%d.%d.%d.%d:%d", i, i, i, i, i)) + } +} + +func ctxWithHash(h uint64) context.Context { + return SetRequestHash(context.Background(), h) +} + +// setupTest creates the balancer, and does an initial sanity check. +func setupTest(t *testing.T, addrs []resolver.Address) (*testutils.TestClientConn, balancer.Balancer, balancer.Picker) { + t.Helper() + cc := testutils.NewTestClientConn(t) + builder := balancer.Get(Name) + b := builder.Build(cc, balancer.BuildOptions{}) + if b == nil { + t.Fatalf("builder.Build(%s) failed and returned nil", Name) + } + if err := b.UpdateClientConnState(balancer.ClientConnState{ + ResolverState: resolver.State{Addresses: addrs}, + BalancerConfig: testConfig, + }); err != nil { + t.Fatalf("UpdateClientConnState returned err: %v", err) + } + + for _, addr := range addrs { + addr1 := <-cc.NewSubConnAddrsCh + if want := []resolver.Address{addr}; !cmp.Equal(addr1, want, cmp.AllowUnexported(attributes.Attributes{})) { + t.Fatalf("got unexpected new subconn addrs: %v", cmp.Diff(addr1, want, cmp.AllowUnexported(attributes.Attributes{}))) + } + sc1 := <-cc.NewSubConnCh + // All the SubConns start in Idle, and should not Connect(). + select { + case <-sc1.(*testutils.TestSubConn).ConnectCh: + t.Errorf("unexpected Connect() from SubConn %v", sc1) + case <-time.After(defaultTestShortTimeout): + } + } + + // Should also have a picker, with all SubConns in Idle. + p1 := <-cc.NewPickerCh + return cc, b, p1 +} + +func TestOneSubConn(t *testing.T) { + wantAddr1 := resolver.Address{Addr: testBackendAddrStrs[0]} + cc, b, p0 := setupTest(t, []resolver.Address{wantAddr1}) + ring0 := p0.(*picker).ring + + firstHash := ring0.items[0].hash + // firstHash-1 will pick the first (and only) SubConn from the ring. + testHash := firstHash - 1 + // The first pick should be queued, and should trigger Connect() on the only + // SubConn. + if _, err := p0.Pick(balancer.PickInfo{Ctx: ctxWithHash(testHash)}); err != balancer.ErrNoSubConnAvailable { + t.Fatalf("first pick returned err %v, want %v", err, balancer.ErrNoSubConnAvailable) + } + sc0 := ring0.items[0].sc.sc + select { + case <-sc0.(*testutils.TestSubConn).ConnectCh: + case <-time.After(defaultTestTimeout): + t.Errorf("timeout waiting for Connect() from SubConn %v", sc0) + } + + // Send state updates to Ready. + b.UpdateSubConnState(sc0, balancer.SubConnState{ConnectivityState: connectivity.Connecting}) + b.UpdateSubConnState(sc0, balancer.SubConnState{ConnectivityState: connectivity.Ready}) + + // Test pick with one backend. + p1 := <-cc.NewPickerCh + for i := 0; i < 5; i++ { + gotSCSt, _ := p1.Pick(balancer.PickInfo{Ctx: ctxWithHash(testHash)}) + if !cmp.Equal(gotSCSt.SubConn, sc0, cmp.AllowUnexported(testutils.TestSubConn{})) { + t.Fatalf("picker.Pick, got %v, want SubConn=%v", gotSCSt, sc0) + } + } +} + +// TestThreeBackendsAffinity covers that there are 3 SubConns, RPCs with the +// same hash always pick the same SubConn. When the one picked is down, another +// one will be picked. +func TestThreeSubConnsAffinity(t *testing.T) { + wantAddrs := []resolver.Address{ + {Addr: testBackendAddrStrs[0]}, + {Addr: testBackendAddrStrs[1]}, + {Addr: testBackendAddrStrs[2]}, + } + cc, b, p0 := setupTest(t, wantAddrs) + // This test doesn't update addresses, so this ring will be used by all the + // pickers. + ring0 := p0.(*picker).ring + + firstHash := ring0.items[0].hash + // firstHash+1 will pick the second SubConn from the ring. + testHash := firstHash + 1 + // The first pick should be queued, and should trigger Connect() on the only + // SubConn. + if _, err := p0.Pick(balancer.PickInfo{Ctx: ctxWithHash(testHash)}); err != balancer.ErrNoSubConnAvailable { + t.Fatalf("first pick returned err %v, want %v", err, balancer.ErrNoSubConnAvailable) + } + // The picked SubConn should be the second in the ring. + sc0 := ring0.items[1].sc.sc + select { + case <-sc0.(*testutils.TestSubConn).ConnectCh: + case <-time.After(defaultTestTimeout): + t.Errorf("timeout waiting for Connect() from SubConn %v", sc0) + } + + // Send state updates to Ready. + b.UpdateSubConnState(sc0, balancer.SubConnState{ConnectivityState: connectivity.Connecting}) + b.UpdateSubConnState(sc0, balancer.SubConnState{ConnectivityState: connectivity.Ready}) + p1 := <-cc.NewPickerCh + for i := 0; i < 5; i++ { + gotSCSt, _ := p1.Pick(balancer.PickInfo{Ctx: ctxWithHash(testHash)}) + if !cmp.Equal(gotSCSt.SubConn, sc0, cmp.AllowUnexported(testutils.TestSubConn{})) { + t.Fatalf("picker.Pick, got %v, want SubConn=%v", gotSCSt, sc0) + } + } + + // Turn down the subConn in use. + b.UpdateSubConnState(sc0, balancer.SubConnState{ConnectivityState: connectivity.TransientFailure}) + p2 := <-cc.NewPickerCh + // Pick with the same hash should be queued, because the SubConn after the + // first picked is Idle. + if _, err := p2.Pick(balancer.PickInfo{Ctx: ctxWithHash(testHash)}); err != balancer.ErrNoSubConnAvailable { + t.Fatalf("first pick returned err %v, want %v", err, balancer.ErrNoSubConnAvailable) + } + + // The third SubConn in the ring should connect. + sc1 := ring0.items[2].sc.sc + select { + case <-sc1.(*testutils.TestSubConn).ConnectCh: + case <-time.After(defaultTestTimeout): + t.Errorf("timeout waiting for Connect() from SubConn %v", sc1) + } + + // Send state updates to Ready. + b.UpdateSubConnState(sc1, balancer.SubConnState{ConnectivityState: connectivity.Connecting}) + b.UpdateSubConnState(sc1, balancer.SubConnState{ConnectivityState: connectivity.Ready}) + // New picks should all return this SubConn. + p3 := <-cc.NewPickerCh + for i := 0; i < 5; i++ { + gotSCSt, _ := p3.Pick(balancer.PickInfo{Ctx: ctxWithHash(testHash)}) + if !cmp.Equal(gotSCSt.SubConn, sc1, cmp.AllowUnexported(testutils.TestSubConn{})) { + t.Fatalf("picker.Pick, got %v, want SubConn=%v", gotSCSt, sc1) + } + } + + // Now, after backoff, the first picked SubConn will turn Idle. + b.UpdateSubConnState(sc0, balancer.SubConnState{ConnectivityState: connectivity.Idle}) + // The picks above should have queued Connect() for the first picked + // SubConn, so this Idle state change will trigger a Connect(). + select { + case <-sc0.(*testutils.TestSubConn).ConnectCh: + case <-time.After(defaultTestTimeout): + t.Errorf("timeout waiting for Connect() from SubConn %v", sc0) + } + + // After the first picked SubConn turn Ready, new picks should return it + // again (even though the second picked SubConn is also Ready). + b.UpdateSubConnState(sc0, balancer.SubConnState{ConnectivityState: connectivity.Connecting}) + b.UpdateSubConnState(sc0, balancer.SubConnState{ConnectivityState: connectivity.Ready}) + p4 := <-cc.NewPickerCh + for i := 0; i < 5; i++ { + gotSCSt, _ := p4.Pick(balancer.PickInfo{Ctx: ctxWithHash(testHash)}) + if !cmp.Equal(gotSCSt.SubConn, sc0, cmp.AllowUnexported(testutils.TestSubConn{})) { + t.Fatalf("picker.Pick, got %v, want SubConn=%v", gotSCSt, sc0) + } + } +} + +// TestThreeBackendsAffinity covers that there are 3 SubConns, RPCs with the +// same hash always pick the same SubConn. Then try different hash to pick +// another backend, and verify the first hash still picks the first backend. +func TestThreeSubConnsAffinityMultiple(t *testing.T) { + wantAddrs := []resolver.Address{ + {Addr: testBackendAddrStrs[0]}, + {Addr: testBackendAddrStrs[1]}, + {Addr: testBackendAddrStrs[2]}, + } + cc, b, p0 := setupTest(t, wantAddrs) + // This test doesn't update addresses, so this ring will be used by all the + // pickers. + ring0 := p0.(*picker).ring + + firstHash := ring0.items[0].hash + // firstHash+1 will pick the second SubConn from the ring. + testHash := firstHash + 1 + // The first pick should be queued, and should trigger Connect() on the only + // SubConn. + if _, err := p0.Pick(balancer.PickInfo{Ctx: ctxWithHash(testHash)}); err != balancer.ErrNoSubConnAvailable { + t.Fatalf("first pick returned err %v, want %v", err, balancer.ErrNoSubConnAvailable) + } + sc0 := ring0.items[1].sc.sc + select { + case <-sc0.(*testutils.TestSubConn).ConnectCh: + case <-time.After(defaultTestTimeout): + t.Errorf("timeout waiting for Connect() from SubConn %v", sc0) + } + + // Send state updates to Ready. + b.UpdateSubConnState(sc0, balancer.SubConnState{ConnectivityState: connectivity.Connecting}) + b.UpdateSubConnState(sc0, balancer.SubConnState{ConnectivityState: connectivity.Ready}) + + // First hash should always pick sc0. + p1 := <-cc.NewPickerCh + for i := 0; i < 5; i++ { + gotSCSt, _ := p1.Pick(balancer.PickInfo{Ctx: ctxWithHash(testHash)}) + if !cmp.Equal(gotSCSt.SubConn, sc0, cmp.AllowUnexported(testutils.TestSubConn{})) { + t.Fatalf("picker.Pick, got %v, want SubConn=%v", gotSCSt, sc0) + } + } + + secondHash := ring0.items[1].hash + // secondHash+1 will pick the third SubConn from the ring. + testHash2 := secondHash + 1 + if _, err := p0.Pick(balancer.PickInfo{Ctx: ctxWithHash(testHash2)}); err != balancer.ErrNoSubConnAvailable { + t.Fatalf("first pick returned err %v, want %v", err, balancer.ErrNoSubConnAvailable) + } + sc1 := ring0.items[2].sc.sc + select { + case <-sc1.(*testutils.TestSubConn).ConnectCh: + case <-time.After(defaultTestTimeout): + t.Errorf("timeout waiting for Connect() from SubConn %v", sc1) + } + b.UpdateSubConnState(sc1, balancer.SubConnState{ConnectivityState: connectivity.Connecting}) + b.UpdateSubConnState(sc1, balancer.SubConnState{ConnectivityState: connectivity.Ready}) + + // With the new generated picker, hash2 always picks sc1. + p2 := <-cc.NewPickerCh + for i := 0; i < 5; i++ { + gotSCSt, _ := p2.Pick(balancer.PickInfo{Ctx: ctxWithHash(testHash2)}) + if !cmp.Equal(gotSCSt.SubConn, sc1, cmp.AllowUnexported(testutils.TestSubConn{})) { + t.Fatalf("picker.Pick, got %v, want SubConn=%v", gotSCSt, sc1) + } + } + // But the first hash still picks sc0. + for i := 0; i < 5; i++ { + gotSCSt, _ := p2.Pick(balancer.PickInfo{Ctx: ctxWithHash(testHash)}) + if !cmp.Equal(gotSCSt.SubConn, sc0, cmp.AllowUnexported(testutils.TestSubConn{})) { + t.Fatalf("picker.Pick, got %v, want SubConn=%v", gotSCSt, sc0) + } + } +} + +func TestAddrWeightChange(t *testing.T) { + wantAddrs := []resolver.Address{ + {Addr: testBackendAddrStrs[0]}, + {Addr: testBackendAddrStrs[1]}, + {Addr: testBackendAddrStrs[2]}, + } + cc, b, p0 := setupTest(t, wantAddrs) + ring0 := p0.(*picker).ring + + if err := b.UpdateClientConnState(balancer.ClientConnState{ + ResolverState: resolver.State{Addresses: wantAddrs}, + BalancerConfig: nil, + }); err != nil { + t.Fatalf("UpdateClientConnState returned err: %v", err) + } + select { + case <-cc.NewPickerCh: + t.Fatalf("unexpected picker after UpdateClientConn with the same addresses") + case <-time.After(defaultTestShortTimeout): + } + + // Delete an address, should send a new Picker. + if err := b.UpdateClientConnState(balancer.ClientConnState{ + ResolverState: resolver.State{Addresses: []resolver.Address{ + {Addr: testBackendAddrStrs[0]}, + {Addr: testBackendAddrStrs[1]}, + }}, + BalancerConfig: nil, + }); err != nil { + t.Fatalf("UpdateClientConnState returned err: %v", err) + } + var p1 balancer.Picker + select { + case p1 = <-cc.NewPickerCh: + case <-time.After(defaultTestTimeout): + t.Fatalf("timeout waiting for picker after UpdateClientConn with different addresses") + } + ring1 := p1.(*picker).ring + if ring1 == ring0 { + t.Fatalf("new picker after removing address has the same ring as before, want different") + } + + // Another update with the same addresses, but different weight. + if err := b.UpdateClientConnState(balancer.ClientConnState{ + ResolverState: resolver.State{Addresses: []resolver.Address{ + {Addr: testBackendAddrStrs[0]}, + weightedroundrobin.SetAddrInfo( + resolver.Address{Addr: testBackendAddrStrs[1]}, + weightedroundrobin.AddrInfo{Weight: 2}), + }}, + BalancerConfig: nil, + }); err != nil { + t.Fatalf("UpdateClientConnState returned err: %v", err) + } + var p2 balancer.Picker + select { + case p2 = <-cc.NewPickerCh: + case <-time.After(defaultTestTimeout): + t.Fatalf("timeout waiting for picker after UpdateClientConn with different addresses") + } + if p2.(*picker).ring == ring1 { + t.Fatalf("new picker after changing address weight has the same ring as before, want different") + } +} + +// TestSubConnToConnectWhenOverallTransientFailure covers the situation when the +// overall state is TransientFailure, the SubConns turning Idle will be +// triggered to Connect(). But not when the overall state is not +// TransientFailure. +func TestSubConnToConnectWhenOverallTransientFailure(t *testing.T) { + wantAddrs := []resolver.Address{ + {Addr: testBackendAddrStrs[0]}, + {Addr: testBackendAddrStrs[1]}, + {Addr: testBackendAddrStrs[2]}, + } + _, b, p0 := setupTest(t, wantAddrs) + ring0 := p0.(*picker).ring + + // Turn all SubConns to TransientFailure. + for _, it := range ring0.items { + b.UpdateSubConnState(it.sc.sc, balancer.SubConnState{ConnectivityState: connectivity.TransientFailure}) + } + + // The next one turning Idle should Connect(). + sc0 := ring0.items[0].sc.sc + b.UpdateSubConnState(sc0, balancer.SubConnState{ConnectivityState: connectivity.Idle}) + select { + case <-sc0.(*testutils.TestSubConn).ConnectCh: + case <-time.After(defaultTestTimeout): + t.Errorf("timeout waiting for Connect() from SubConn %v", sc0) + } + + // If this SubConn is ready. Other SubConns turning Idle will not Connect(). + b.UpdateSubConnState(sc0, balancer.SubConnState{ConnectivityState: connectivity.Connecting}) + b.UpdateSubConnState(sc0, balancer.SubConnState{ConnectivityState: connectivity.Ready}) + + // The third SubConn in the ring should connect. + sc1 := ring0.items[1].sc.sc + b.UpdateSubConnState(sc1, balancer.SubConnState{ConnectivityState: connectivity.Idle}) + select { + case <-sc1.(*testutils.TestSubConn).ConnectCh: + t.Errorf("unexpected Connect() from SubConn %v", sc1) + case <-time.After(defaultTestShortTimeout): + } +} + +func TestConnectivityStateEvaluatorRecordTransition(t *testing.T) { + tests := []struct { + name string + from, to []connectivity.State + want connectivity.State + }{ + { + name: "one ready", + from: []connectivity.State{connectivity.Idle}, + to: []connectivity.State{connectivity.Ready}, + want: connectivity.Ready, + }, + { + name: "one connecting", + from: []connectivity.State{connectivity.Idle}, + to: []connectivity.State{connectivity.Connecting}, + want: connectivity.Connecting, + }, + { + name: "one ready one transient failure", + from: []connectivity.State{connectivity.Idle, connectivity.Idle}, + to: []connectivity.State{connectivity.Ready, connectivity.TransientFailure}, + want: connectivity.Ready, + }, + { + name: "one connecting one transient failure", + from: []connectivity.State{connectivity.Idle, connectivity.Idle}, + to: []connectivity.State{connectivity.Connecting, connectivity.TransientFailure}, + want: connectivity.Connecting, + }, + { + name: "one connecting two transient failure", + from: []connectivity.State{connectivity.Idle, connectivity.Idle, connectivity.Idle}, + to: []connectivity.State{connectivity.Connecting, connectivity.TransientFailure, connectivity.TransientFailure}, + want: connectivity.TransientFailure, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cse := &connectivityStateEvaluator{} + var got connectivity.State + for i, fff := range tt.from { + ttt := tt.to[i] + got = cse.recordTransition(fff, ttt) + } + if got != tt.want { + t.Errorf("recordTransition() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/xds/internal/balancer/ringhash/util.go b/xds/internal/balancer/ringhash/util.go new file mode 100644 index 00000000000..92bb3ae5b79 --- /dev/null +++ b/xds/internal/balancer/ringhash/util.go @@ -0,0 +1,40 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package ringhash + +import "context" + +type clusterKey struct{} + +func getRequestHash(ctx context.Context) uint64 { + requestHash, _ := ctx.Value(clusterKey{}).(uint64) + return requestHash +} + +// GetRequestHashForTesting returns the request hash in the context; to be used +// for testing only. +func GetRequestHashForTesting(ctx context.Context) uint64 { + return getRequestHash(ctx) +} + +// SetRequestHash adds the request hash to the context for use in Ring Hash Load +// Balancing. +func SetRequestHash(ctx context.Context, requestHash uint64) context.Context { + return context.WithValue(ctx, clusterKey{}, requestHash) +} diff --git a/xds/internal/balancer/weightedtarget/weightedaggregator/aggregator.go b/xds/internal/balancer/weightedtarget/weightedaggregator/aggregator.go index 6c36e2a69cd..7e1d106e9ff 100644 --- a/xds/internal/balancer/weightedtarget/weightedaggregator/aggregator.go +++ b/xds/internal/balancer/weightedtarget/weightedaggregator/aggregator.go @@ -200,7 +200,9 @@ func (wbsa *Aggregator) BuildAndUpdate() { func (wbsa *Aggregator) build() balancer.State { wbsa.logger.Infof("Child pickers with config: %+v", wbsa.idToPickerState) m := wbsa.idToPickerState - var readyN, connectingN int + // TODO: use balancer.ConnectivityStateEvaluator to calculate the aggregated + // state. + var readyN, connectingN, idleN int readyPickerWithWeights := make([]weightedPickerState, 0, len(m)) for _, ps := range m { switch ps.stateToAggregate { @@ -209,6 +211,8 @@ func (wbsa *Aggregator) build() balancer.State { readyPickerWithWeights = append(readyPickerWithWeights, *ps) case connectivity.Connecting: connectingN++ + case connectivity.Idle: + idleN++ } } var aggregatedState connectivity.State @@ -217,6 +221,8 @@ func (wbsa *Aggregator) build() balancer.State { aggregatedState = connectivity.Ready case connectingN > 0: aggregatedState = connectivity.Connecting + case idleN > 0: + aggregatedState = connectivity.Idle default: aggregatedState = connectivity.TransientFailure } diff --git a/xds/internal/balancer/weightedtarget/weightedtarget.go b/xds/internal/balancer/weightedtarget/weightedtarget.go index 02b199258cd..f05e0aca19f 100644 --- a/xds/internal/balancer/weightedtarget/weightedtarget.go +++ b/xds/internal/balancer/weightedtarget/weightedtarget.go @@ -26,6 +26,7 @@ import ( "google.golang.org/grpc/balancer" "google.golang.org/grpc/internal/grpclog" "google.golang.org/grpc/internal/hierarchy" + "google.golang.org/grpc/internal/pretty" "google.golang.org/grpc/internal/wrr" "google.golang.org/grpc/resolver" "google.golang.org/grpc/serviceconfig" @@ -33,22 +34,23 @@ import ( "google.golang.org/grpc/xds/internal/balancer/weightedtarget/weightedaggregator" ) -const weightedTargetName = "weighted_target_experimental" +// Name is the name of the weighted_target balancer. +const Name = "weighted_target_experimental" -// newRandomWRR is the WRR constructor used to pick sub-pickers from +// NewRandomWRR is the WRR constructor used to pick sub-pickers from // sub-balancers. It's to be modified in tests. -var newRandomWRR = wrr.NewRandom +var NewRandomWRR = wrr.NewRandom func init() { - balancer.Register(&weightedTargetBB{}) + balancer.Register(bb{}) } -type weightedTargetBB struct{} +type bb struct{} -func (wt *weightedTargetBB) Build(cc balancer.ClientConn, bOpts balancer.BuildOptions) balancer.Balancer { +func (bb) Build(cc balancer.ClientConn, bOpts balancer.BuildOptions) balancer.Balancer { b := &weightedTargetBalancer{} b.logger = prefixLogger(b) - b.stateAggregator = weightedaggregator.New(cc, b.logger, newRandomWRR) + b.stateAggregator = weightedaggregator.New(cc, b.logger, NewRandomWRR) b.stateAggregator.Start() b.bg = balancergroup.New(cc, bOpts, b.stateAggregator, nil, b.logger) b.bg.Start() @@ -56,11 +58,11 @@ func (wt *weightedTargetBB) Build(cc balancer.ClientConn, bOpts balancer.BuildOp return b } -func (wt *weightedTargetBB) Name() string { - return weightedTargetName +func (bb) Name() string { + return Name } -func (wt *weightedTargetBB) ParseConfig(c json.RawMessage) (serviceconfig.LoadBalancingConfig, error) { +func (bb) ParseConfig(c json.RawMessage) (serviceconfig.LoadBalancingConfig, error) { return parseConfig(c) } @@ -75,14 +77,15 @@ type weightedTargetBalancer struct { bg *balancergroup.BalancerGroup stateAggregator *weightedaggregator.Aggregator - targets map[string]target + targets map[string]Target } // UpdateClientConnState takes the new targets in balancer group, -// creates/deletes sub-balancers and sends them update. Addresses are split into +// creates/deletes sub-balancers and sends them update. addresses are split into // groups based on hierarchy path. -func (w *weightedTargetBalancer) UpdateClientConnState(s balancer.ClientConnState) error { - newConfig, ok := s.BalancerConfig.(*lbConfig) +func (b *weightedTargetBalancer) UpdateClientConnState(s balancer.ClientConnState) error { + b.logger.Infof("Received update from resolver, balancer config: %+v", pretty.ToJSON(s.BalancerConfig)) + newConfig, ok := s.BalancerConfig.(*LBConfig) if !ok { return fmt.Errorf("unexpected balancer config with type: %T", s.BalancerConfig) } @@ -91,10 +94,10 @@ func (w *weightedTargetBalancer) UpdateClientConnState(s balancer.ClientConnStat var rebuildStateAndPicker bool // Remove sub-pickers and sub-balancers that are not in the new config. - for name := range w.targets { + for name := range b.targets { if _, ok := newConfig.Targets[name]; !ok { - w.stateAggregator.Remove(name) - w.bg.Remove(name) + b.stateAggregator.Remove(name) + b.bg.Remove(name) // Trigger a state/picker update, because we don't want `ClientConn` // to pick this sub-balancer anymore. rebuildStateAndPicker = true @@ -107,29 +110,39 @@ func (w *weightedTargetBalancer) UpdateClientConnState(s balancer.ClientConnStat // // For all sub-balancers, forward the address/balancer config update. for name, newT := range newConfig.Targets { - oldT, ok := w.targets[name] + oldT, ok := b.targets[name] if !ok { // If this is a new sub-balancer, add weights to the picker map. - w.stateAggregator.Add(name, newT.Weight) + b.stateAggregator.Add(name, newT.Weight) // Then add to the balancer group. - w.bg.Add(name, balancer.Get(newT.ChildPolicy.Name)) + b.bg.Add(name, balancer.Get(newT.ChildPolicy.Name)) // Not trigger a state/picker update. Wait for the new sub-balancer // to send its updates. + } else if newT.ChildPolicy.Name != oldT.ChildPolicy.Name { + // If the child policy name is differet, remove from balancer group + // and re-add. + b.stateAggregator.Remove(name) + b.bg.Remove(name) + b.stateAggregator.Add(name, newT.Weight) + b.bg.Add(name, balancer.Get(newT.ChildPolicy.Name)) + // Trigger a state/picker update, because we don't want `ClientConn` + // to pick this sub-balancer anymore. + rebuildStateAndPicker = true } else if newT.Weight != oldT.Weight { // If this is an existing sub-balancer, update weight if necessary. - w.stateAggregator.UpdateWeight(name, newT.Weight) + b.stateAggregator.UpdateWeight(name, newT.Weight) // Trigger a state/picker update, because we don't want `ClientConn` // should do picks with the new weights now. rebuildStateAndPicker = true } // Forwards all the update: - // - Addresses are from the map after splitting with hierarchy path, + // - addresses are from the map after splitting with hierarchy path, // - Top level service config and attributes are the same, // - Balancer config comes from the targets map. // // TODO: handle error? How to aggregate errors and return? - _ = w.bg.UpdateClientConnState(name, balancer.ClientConnState{ + _ = b.bg.UpdateClientConnState(name, balancer.ClientConnState{ ResolverState: resolver.State{ Addresses: addressesSplit[name], ServiceConfig: s.ResolverState.ServiceConfig, @@ -139,23 +152,27 @@ func (w *weightedTargetBalancer) UpdateClientConnState(s balancer.ClientConnStat }) } - w.targets = newConfig.Targets + b.targets = newConfig.Targets if rebuildStateAndPicker { - w.stateAggregator.BuildAndUpdate() + b.stateAggregator.BuildAndUpdate() } return nil } -func (w *weightedTargetBalancer) ResolverError(err error) { - w.bg.ResolverError(err) +func (b *weightedTargetBalancer) ResolverError(err error) { + b.bg.ResolverError(err) +} + +func (b *weightedTargetBalancer) UpdateSubConnState(sc balancer.SubConn, state balancer.SubConnState) { + b.bg.UpdateSubConnState(sc, state) } -func (w *weightedTargetBalancer) UpdateSubConnState(sc balancer.SubConn, state balancer.SubConnState) { - w.bg.UpdateSubConnState(sc, state) +func (b *weightedTargetBalancer) Close() { + b.stateAggregator.Stop() + b.bg.Close() } -func (w *weightedTargetBalancer) Close() { - w.stateAggregator.Stop() - w.bg.Close() +func (b *weightedTargetBalancer) ExitIdle() { + b.bg.ExitIdle() } diff --git a/xds/internal/balancer/weightedtarget/weightedtarget_config.go b/xds/internal/balancer/weightedtarget/weightedtarget_config.go index 747ce918bc6..52090cd67b0 100644 --- a/xds/internal/balancer/weightedtarget/weightedtarget_config.go +++ b/xds/internal/balancer/weightedtarget/weightedtarget_config.go @@ -25,30 +25,23 @@ import ( "google.golang.org/grpc/serviceconfig" ) -type target struct { +// Target represents one target with the weight and the child policy. +type Target struct { // Weight is the weight of the child policy. - Weight uint32 + Weight uint32 `json:"weight,omitempty"` // ChildPolicy is the child policy and it's config. - ChildPolicy *internalserviceconfig.BalancerConfig + ChildPolicy *internalserviceconfig.BalancerConfig `json:"childPolicy,omitempty"` } -// lbConfig is the balancer config for weighted_target. The proto representation -// is: -// -// message WeightedTargetConfig { -// message Target { -// uint32 weight = 1; -// repeated LoadBalancingConfig child_policy = 2; -// } -// map targets = 1; -// } -type lbConfig struct { - serviceconfig.LoadBalancingConfig - Targets map[string]target +// LBConfig is the balancer config for weighted_target. +type LBConfig struct { + serviceconfig.LoadBalancingConfig `json:"-"` + + Targets map[string]Target `json:"targets,omitempty"` } -func parseConfig(c json.RawMessage) (*lbConfig, error) { - var cfg lbConfig +func parseConfig(c json.RawMessage) (*LBConfig, error) { + var cfg LBConfig if err := json.Unmarshal(c, &cfg); err != nil { return nil, err } diff --git a/xds/internal/balancer/weightedtarget/weightedtarget_config_test.go b/xds/internal/balancer/weightedtarget/weightedtarget_config_test.go index 2208117f60e..c239a3ae5a4 100644 --- a/xds/internal/balancer/weightedtarget/weightedtarget_config_test.go +++ b/xds/internal/balancer/weightedtarget/weightedtarget_config_test.go @@ -24,7 +24,7 @@ import ( "github.com/google/go-cmp/cmp" "google.golang.org/grpc/balancer" internalserviceconfig "google.golang.org/grpc/internal/serviceconfig" - _ "google.golang.org/grpc/xds/internal/balancer/cdsbalancer" + "google.golang.org/grpc/xds/internal/balancer/priority" ) const ( @@ -32,31 +32,29 @@ const ( "targets": { "cluster_1" : { "weight":75, - "childPolicy":[{"cds_experimental":{"cluster":"cluster_1"}}] + "childPolicy":[{"priority_experimental":{"priorities": ["child-1"], "children": {"child-1": {"config": [{"round_robin":{}}]}}}}] }, "cluster_2" : { "weight":25, - "childPolicy":[{"cds_experimental":{"cluster":"cluster_2"}}] + "childPolicy":[{"priority_experimental":{"priorities": ["child-2"], "children": {"child-2": {"config": [{"round_robin":{}}]}}}}] } } }` - - cdsName = "cds_experimental" ) var ( - cdsConfigParser = balancer.Get(cdsName).(balancer.ConfigParser) - cdsConfigJSON1 = `{"cluster":"cluster_1"}` - cdsConfig1, _ = cdsConfigParser.ParseConfig([]byte(cdsConfigJSON1)) - cdsConfigJSON2 = `{"cluster":"cluster_2"}` - cdsConfig2, _ = cdsConfigParser.ParseConfig([]byte(cdsConfigJSON2)) + testConfigParser = balancer.Get(priority.Name).(balancer.ConfigParser) + testConfigJSON1 = `{"priorities": ["child-1"], "children": {"child-1": {"config": [{"round_robin":{}}]}}}` + testConfig1, _ = testConfigParser.ParseConfig([]byte(testConfigJSON1)) + testConfigJSON2 = `{"priorities": ["child-2"], "children": {"child-2": {"config": [{"round_robin":{}}]}}}` + testConfig2, _ = testConfigParser.ParseConfig([]byte(testConfigJSON2)) ) func Test_parseConfig(t *testing.T) { tests := []struct { name string js string - want *lbConfig + want *LBConfig wantErr bool }{ { @@ -68,20 +66,20 @@ func Test_parseConfig(t *testing.T) { { name: "OK", js: testJSONConfig, - want: &lbConfig{ - Targets: map[string]target{ + want: &LBConfig{ + Targets: map[string]Target{ "cluster_1": { Weight: 75, ChildPolicy: &internalserviceconfig.BalancerConfig{ - Name: cdsName, - Config: cdsConfig1, + Name: priority.Name, + Config: testConfig1, }, }, "cluster_2": { Weight: 25, ChildPolicy: &internalserviceconfig.BalancerConfig{ - Name: cdsName, - Config: cdsConfig2, + Name: priority.Name, + Config: testConfig2, }, }, }, diff --git a/xds/internal/balancer/weightedtarget/weightedtarget_test.go b/xds/internal/balancer/weightedtarget/weightedtarget_test.go index 7f9e566ca5b..b0e4df89588 100644 --- a/xds/internal/balancer/weightedtarget/weightedtarget_test.go +++ b/xds/internal/balancer/weightedtarget/weightedtarget_test.go @@ -29,6 +29,7 @@ import ( "google.golang.org/grpc/balancer" "google.golang.org/grpc/balancer/roundrobin" "google.golang.org/grpc/connectivity" + "google.golang.org/grpc/internal/balancer/stub" "google.golang.org/grpc/internal/hierarchy" "google.golang.org/grpc/resolver" "google.golang.org/grpc/serviceconfig" @@ -103,7 +104,7 @@ func init() { for i := 0; i < testBackendAddrsCount; i++ { testBackendAddrStrs = append(testBackendAddrStrs, fmt.Sprintf("%d.%d.%d.%d:%d", i, i, i, i, i)) } - wtbBuilder = balancer.Get(weightedTargetName) + wtbBuilder = balancer.Get(Name) wtbParser = wtbBuilder.(balancer.ConfigParser) balancergroup.DefaultSubBalancerCloseTimeout = time.Millisecond @@ -215,6 +216,46 @@ func TestWeightedTarget(t *testing.T) { if err := testutils.IsRoundRobin(want, subConnFromPicker(p2)); err != nil { t.Fatalf("want %v, got %v", want, err) } + + // Replace child policy of "cluster_1" to "round_robin". + config3, err := wtbParser.ParseConfig([]byte(`{"targets":{"cluster_2":{"weight":1,"childPolicy":[{"round_robin":""}]}}}`)) + if err != nil { + t.Fatalf("failed to parse balancer config: %v", err) + } + + // Send the config, and an address with hierarchy path ["cluster_1"]. + wantAddr4 := resolver.Address{Addr: testBackendAddrStrs[0], Attributes: nil} + if err := wtb.UpdateClientConnState(balancer.ClientConnState{ + ResolverState: resolver.State{Addresses: []resolver.Address{ + hierarchy.Set(wantAddr4, []string{"cluster_2"}), + }}, + BalancerConfig: config3, + }); err != nil { + t.Fatalf("failed to update ClientConn state: %v", err) + } + + // Verify that a subconn is created with the address, and the hierarchy path + // in the address is cleared. + addr4 := <-cc.NewSubConnAddrsCh + if want := []resolver.Address{ + hierarchy.Set(wantAddr4, []string{}), + }; !cmp.Equal(addr4, want, cmp.AllowUnexported(attributes.Attributes{})) { + t.Fatalf("got unexpected new subconn addrs: %v", cmp.Diff(addr4, want, cmp.AllowUnexported(attributes.Attributes{}))) + } + + // Send subconn state change. + sc4 := <-cc.NewSubConnCh + wtb.UpdateSubConnState(sc4, balancer.SubConnState{ConnectivityState: connectivity.Connecting}) + wtb.UpdateSubConnState(sc4, balancer.SubConnState{ConnectivityState: connectivity.Ready}) + + // Test pick with one backend. + p3 := <-cc.NewPickerCh + for i := 0; i < 5; i++ { + gotSCSt, _ := p3.Pick(balancer.PickInfo{}) + if !cmp.Equal(gotSCSt.SubConn, sc4, cmp.AllowUnexported(testutils.TestSubConn{})) { + t.Fatalf("picker.Pick, got %v, want SubConn=%v", gotSCSt, sc4) + } + } } func subConnFromPicker(p balancer.Picker) func() balancer.SubConn { @@ -223,3 +264,63 @@ func subConnFromPicker(p balancer.Picker) func() balancer.SubConn { return scst.SubConn } } + +const initIdleBalancerName = "test-init-Idle-balancer" + +var errTestInitIdle = fmt.Errorf("init Idle balancer error 0") + +func init() { + stub.Register(initIdleBalancerName, stub.BalancerFuncs{ + UpdateClientConnState: func(bd *stub.BalancerData, opts balancer.ClientConnState) error { + bd.ClientConn.NewSubConn(opts.ResolverState.Addresses, balancer.NewSubConnOptions{}) + return nil + }, + UpdateSubConnState: func(bd *stub.BalancerData, sc balancer.SubConn, state balancer.SubConnState) { + err := fmt.Errorf("wrong picker error") + if state.ConnectivityState == connectivity.Idle { + err = errTestInitIdle + } + bd.ClientConn.UpdateState(balancer.State{ + ConnectivityState: state.ConnectivityState, + Picker: &testutils.TestConstPicker{Err: err}, + }) + }, + }) +} + +// TestInitialIdle covers the case that if the child reports Idle, the overall +// state will be Idle. +func TestInitialIdle(t *testing.T) { + cc := testutils.NewTestClientConn(t) + wtb := wtbBuilder.Build(cc, balancer.BuildOptions{}) + + // Start with "cluster_1: round_robin". + config1, err := wtbParser.ParseConfig([]byte(`{"targets":{"cluster_1":{"weight":1,"childPolicy":[{"test-init-Idle-balancer":""}]}}}`)) + if err != nil { + t.Fatalf("failed to parse balancer config: %v", err) + } + + // Send the config, and an address with hierarchy path ["cluster_1"]. + wantAddrs := []resolver.Address{ + {Addr: testBackendAddrStrs[0], Attributes: nil}, + } + if err := wtb.UpdateClientConnState(balancer.ClientConnState{ + ResolverState: resolver.State{Addresses: []resolver.Address{ + hierarchy.Set(wantAddrs[0], []string{"cds:cluster_1"}), + }}, + BalancerConfig: config1, + }); err != nil { + t.Fatalf("failed to update ClientConn state: %v", err) + } + + // Verify that a subconn is created with the address, and the hierarchy path + // in the address is cleared. + for range wantAddrs { + sc := <-cc.NewSubConnCh + wtb.UpdateSubConnState(sc, balancer.SubConnState{ConnectivityState: connectivity.Idle}) + } + + if state1 := <-cc.NewStateCh; state1 != connectivity.Idle { + t.Fatalf("Received aggregated state: %v, want Idle", state1) + } +} diff --git a/xds/internal/client/cds_test.go b/xds/internal/client/cds_test.go deleted file mode 100644 index c5f1d76d32c..00000000000 --- a/xds/internal/client/cds_test.go +++ /dev/null @@ -1,833 +0,0 @@ -/* - * - * Copyright 2020 gRPC authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - */ - -package client - -import ( - "testing" - - v2xdspb "github.com/envoyproxy/go-control-plane/envoy/api/v2" - v2corepb "github.com/envoyproxy/go-control-plane/envoy/api/v2/core" - v3clusterpb "github.com/envoyproxy/go-control-plane/envoy/config/cluster/v3" - v3corepb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" - v3tlspb "github.com/envoyproxy/go-control-plane/envoy/extensions/transport_sockets/tls/v3" - v3matcherpb "github.com/envoyproxy/go-control-plane/envoy/type/matcher/v3" - "github.com/golang/protobuf/proto" - anypb "github.com/golang/protobuf/ptypes/any" - "github.com/google/go-cmp/cmp" - "github.com/google/go-cmp/cmp/cmpopts" - "google.golang.org/grpc/xds/internal/env" - "google.golang.org/grpc/xds/internal/version" - "google.golang.org/protobuf/types/known/wrapperspb" -) - -const ( - clusterName = "clusterName" - serviceName = "service" -) - -var emptyUpdate = ClusterUpdate{ServiceName: "", EnableLRS: false} - -func (s) TestValidateCluster_Failure(t *testing.T) { - tests := []struct { - name string - cluster *v3clusterpb.Cluster - wantUpdate ClusterUpdate - wantErr bool - }{ - { - name: "non-eds-cluster-type", - cluster: &v3clusterpb.Cluster{ - ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_STATIC}, - EdsClusterConfig: &v3clusterpb.Cluster_EdsClusterConfig{ - EdsConfig: &v3corepb.ConfigSource{ - ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{ - Ads: &v3corepb.AggregatedConfigSource{}, - }, - }, - }, - LbPolicy: v3clusterpb.Cluster_LEAST_REQUEST, - }, - wantUpdate: emptyUpdate, - wantErr: true, - }, - { - name: "no-eds-config", - cluster: &v3clusterpb.Cluster{ - ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_EDS}, - LbPolicy: v3clusterpb.Cluster_ROUND_ROBIN, - }, - wantUpdate: emptyUpdate, - wantErr: true, - }, - { - name: "no-ads-config-source", - cluster: &v3clusterpb.Cluster{ - ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_EDS}, - EdsClusterConfig: &v3clusterpb.Cluster_EdsClusterConfig{}, - LbPolicy: v3clusterpb.Cluster_ROUND_ROBIN, - }, - wantUpdate: emptyUpdate, - wantErr: true, - }, - { - name: "non-round-robin-lb-policy", - cluster: &v3clusterpb.Cluster{ - ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_EDS}, - EdsClusterConfig: &v3clusterpb.Cluster_EdsClusterConfig{ - EdsConfig: &v3corepb.ConfigSource{ - ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{ - Ads: &v3corepb.AggregatedConfigSource{}, - }, - }, - }, - LbPolicy: v3clusterpb.Cluster_LEAST_REQUEST, - }, - wantUpdate: emptyUpdate, - wantErr: true, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - if update, err := validateCluster(test.cluster); err == nil { - t.Errorf("validateCluster(%+v) = %v, wanted error", test.cluster, update) - } - }) - } -} - -func (s) TestValidateCluster_Success(t *testing.T) { - tests := []struct { - name string - cluster *v3clusterpb.Cluster - wantUpdate ClusterUpdate - }{ - { - name: "happy-case-no-service-name-no-lrs", - cluster: &v3clusterpb.Cluster{ - ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_EDS}, - EdsClusterConfig: &v3clusterpb.Cluster_EdsClusterConfig{ - EdsConfig: &v3corepb.ConfigSource{ - ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{ - Ads: &v3corepb.AggregatedConfigSource{}, - }, - }, - }, - LbPolicy: v3clusterpb.Cluster_ROUND_ROBIN, - }, - wantUpdate: emptyUpdate, - }, - { - name: "happy-case-no-lrs", - cluster: &v3clusterpb.Cluster{ - ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_EDS}, - EdsClusterConfig: &v3clusterpb.Cluster_EdsClusterConfig{ - EdsConfig: &v3corepb.ConfigSource{ - ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{ - Ads: &v3corepb.AggregatedConfigSource{}, - }, - }, - ServiceName: serviceName, - }, - LbPolicy: v3clusterpb.Cluster_ROUND_ROBIN, - }, - wantUpdate: ClusterUpdate{ServiceName: serviceName, EnableLRS: false}, - }, - { - name: "happiest-case", - cluster: &v3clusterpb.Cluster{ - Name: clusterName, - ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_EDS}, - EdsClusterConfig: &v3clusterpb.Cluster_EdsClusterConfig{ - EdsConfig: &v3corepb.ConfigSource{ - ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{ - Ads: &v3corepb.AggregatedConfigSource{}, - }, - }, - ServiceName: serviceName, - }, - LbPolicy: v3clusterpb.Cluster_ROUND_ROBIN, - LrsServer: &v3corepb.ConfigSource{ - ConfigSourceSpecifier: &v3corepb.ConfigSource_Self{ - Self: &v3corepb.SelfConfigSource{}, - }, - }, - }, - wantUpdate: ClusterUpdate{ServiceName: serviceName, EnableLRS: true}, - }, - { - name: "happiest-case-with-circuitbreakers", - cluster: &v3clusterpb.Cluster{ - Name: clusterName, - ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_EDS}, - EdsClusterConfig: &v3clusterpb.Cluster_EdsClusterConfig{ - EdsConfig: &v3corepb.ConfigSource{ - ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{ - Ads: &v3corepb.AggregatedConfigSource{}, - }, - }, - ServiceName: serviceName, - }, - LbPolicy: v3clusterpb.Cluster_ROUND_ROBIN, - CircuitBreakers: &v3clusterpb.CircuitBreakers{ - Thresholds: []*v3clusterpb.CircuitBreakers_Thresholds{ - { - Priority: v3corepb.RoutingPriority_DEFAULT, - MaxRequests: wrapperspb.UInt32(512), - }, - { - Priority: v3corepb.RoutingPriority_HIGH, - MaxRequests: nil, - }, - }, - }, - LrsServer: &v3corepb.ConfigSource{ - ConfigSourceSpecifier: &v3corepb.ConfigSource_Self{ - Self: &v3corepb.SelfConfigSource{}, - }, - }, - }, - wantUpdate: ClusterUpdate{ServiceName: serviceName, EnableLRS: true, MaxRequests: func() *uint32 { i := uint32(512); return &i }()}, - }, - } - - origCircuitBreakingSupport := env.CircuitBreakingSupport - env.CircuitBreakingSupport = true - defer func() { env.CircuitBreakingSupport = origCircuitBreakingSupport }() - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - update, err := validateCluster(test.cluster) - if err != nil { - t.Errorf("validateCluster(%+v) failed: %v", test.cluster, err) - } - if !cmp.Equal(update, test.wantUpdate, cmpopts.EquateEmpty()) { - t.Errorf("validateCluster(%+v) = %v, want: %v", test.cluster, update, test.wantUpdate) - } - }) - } -} - -func (s) TestValidateClusterWithSecurityConfig_EnvVarOff(t *testing.T) { - // Turn off the env var protection for client-side security. - origClientSideSecurityEnvVar := env.ClientSideSecuritySupport - env.ClientSideSecuritySupport = false - defer func() { env.ClientSideSecuritySupport = origClientSideSecurityEnvVar }() - - cluster := &v3clusterpb.Cluster{ - ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_EDS}, - EdsClusterConfig: &v3clusterpb.Cluster_EdsClusterConfig{ - EdsConfig: &v3corepb.ConfigSource{ - ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{ - Ads: &v3corepb.AggregatedConfigSource{}, - }, - }, - ServiceName: serviceName, - }, - LbPolicy: v3clusterpb.Cluster_ROUND_ROBIN, - TransportSocket: &v3corepb.TransportSocket{ - Name: "envoy.transport_sockets.tls", - ConfigType: &v3corepb.TransportSocket_TypedConfig{ - TypedConfig: &anypb.Any{ - TypeUrl: version.V3UpstreamTLSContextURL, - Value: func() []byte { - tls := &v3tlspb.UpstreamTlsContext{ - CommonTlsContext: &v3tlspb.CommonTlsContext{ - ValidationContextType: &v3tlspb.CommonTlsContext_ValidationContextCertificateProviderInstance{ - ValidationContextCertificateProviderInstance: &v3tlspb.CommonTlsContext_CertificateProviderInstance{ - InstanceName: "rootInstance", - CertificateName: "rootCert", - }, - }, - }, - } - mtls, _ := proto.Marshal(tls) - return mtls - }(), - }, - }, - }, - } - wantUpdate := ClusterUpdate{ - ServiceName: serviceName, - EnableLRS: false, - } - gotUpdate, err := validateCluster(cluster) - if err != nil { - t.Errorf("validateCluster() failed: %v", err) - } - if diff := cmp.Diff(wantUpdate, gotUpdate); diff != "" { - t.Errorf("validateCluster() returned unexpected diff (-want, got):\n%s", diff) - } -} - -func (s) TestValidateClusterWithSecurityConfig(t *testing.T) { - // Turn on the env var protection for client-side security. - origClientSideSecurityEnvVar := env.ClientSideSecuritySupport - env.ClientSideSecuritySupport = true - defer func() { env.ClientSideSecuritySupport = origClientSideSecurityEnvVar }() - - const ( - identityPluginInstance = "identityPluginInstance" - identityCertName = "identityCert" - rootPluginInstance = "rootPluginInstance" - rootCertName = "rootCert" - serviceName = "service" - san1 = "san1" - san2 = "san2" - ) - - tests := []struct { - name string - cluster *v3clusterpb.Cluster - wantUpdate ClusterUpdate - wantErr bool - }{ - { - name: "transport-socket-unsupported-name", - cluster: &v3clusterpb.Cluster{ - ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_EDS}, - EdsClusterConfig: &v3clusterpb.Cluster_EdsClusterConfig{ - EdsConfig: &v3corepb.ConfigSource{ - ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{ - Ads: &v3corepb.AggregatedConfigSource{}, - }, - }, - ServiceName: serviceName, - }, - LbPolicy: v3clusterpb.Cluster_ROUND_ROBIN, - TransportSocket: &v3corepb.TransportSocket{ - Name: "unsupported-foo", - ConfigType: &v3corepb.TransportSocket_TypedConfig{ - TypedConfig: &anypb.Any{ - TypeUrl: version.V3UpstreamTLSContextURL, - }, - }, - }, - }, - wantErr: true, - }, - { - name: "transport-socket-unsupported-typeURL", - cluster: &v3clusterpb.Cluster{ - ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_EDS}, - EdsClusterConfig: &v3clusterpb.Cluster_EdsClusterConfig{ - EdsConfig: &v3corepb.ConfigSource{ - ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{ - Ads: &v3corepb.AggregatedConfigSource{}, - }, - }, - ServiceName: serviceName, - }, - LbPolicy: v3clusterpb.Cluster_ROUND_ROBIN, - TransportSocket: &v3corepb.TransportSocket{ - ConfigType: &v3corepb.TransportSocket_TypedConfig{ - TypedConfig: &anypb.Any{ - TypeUrl: version.V3HTTPConnManagerURL, - }, - }, - }, - }, - wantErr: true, - }, - { - name: "transport-socket-unsupported-type", - cluster: &v3clusterpb.Cluster{ - ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_EDS}, - EdsClusterConfig: &v3clusterpb.Cluster_EdsClusterConfig{ - EdsConfig: &v3corepb.ConfigSource{ - ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{ - Ads: &v3corepb.AggregatedConfigSource{}, - }, - }, - ServiceName: serviceName, - }, - LbPolicy: v3clusterpb.Cluster_ROUND_ROBIN, - TransportSocket: &v3corepb.TransportSocket{ - ConfigType: &v3corepb.TransportSocket_TypedConfig{ - TypedConfig: &anypb.Any{ - TypeUrl: version.V3UpstreamTLSContextURL, - Value: []byte{1, 2, 3, 4}, - }, - }, - }, - }, - wantErr: true, - }, - { - name: "transport-socket-unsupported-validation-context", - cluster: &v3clusterpb.Cluster{ - ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_EDS}, - EdsClusterConfig: &v3clusterpb.Cluster_EdsClusterConfig{ - EdsConfig: &v3corepb.ConfigSource{ - ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{ - Ads: &v3corepb.AggregatedConfigSource{}, - }, - }, - ServiceName: serviceName, - }, - LbPolicy: v3clusterpb.Cluster_ROUND_ROBIN, - TransportSocket: &v3corepb.TransportSocket{ - ConfigType: &v3corepb.TransportSocket_TypedConfig{ - TypedConfig: &anypb.Any{ - TypeUrl: version.V3UpstreamTLSContextURL, - Value: func() []byte { - tls := &v3tlspb.UpstreamTlsContext{ - CommonTlsContext: &v3tlspb.CommonTlsContext{ - ValidationContextType: &v3tlspb.CommonTlsContext_ValidationContextSdsSecretConfig{ - ValidationContextSdsSecretConfig: &v3tlspb.SdsSecretConfig{ - Name: "foo-sds-secret", - }, - }, - }, - } - mtls, _ := proto.Marshal(tls) - return mtls - }(), - }, - }, - }, - }, - wantErr: true, - }, - { - name: "transport-socket-without-validation-context", - cluster: &v3clusterpb.Cluster{ - ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_EDS}, - EdsClusterConfig: &v3clusterpb.Cluster_EdsClusterConfig{ - EdsConfig: &v3corepb.ConfigSource{ - ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{ - Ads: &v3corepb.AggregatedConfigSource{}, - }, - }, - ServiceName: serviceName, - }, - LbPolicy: v3clusterpb.Cluster_ROUND_ROBIN, - TransportSocket: &v3corepb.TransportSocket{ - ConfigType: &v3corepb.TransportSocket_TypedConfig{ - TypedConfig: &anypb.Any{ - TypeUrl: version.V3UpstreamTLSContextURL, - Value: func() []byte { - tls := &v3tlspb.UpstreamTlsContext{ - CommonTlsContext: &v3tlspb.CommonTlsContext{}, - } - mtls, _ := proto.Marshal(tls) - return mtls - }(), - }, - }, - }, - }, - wantErr: true, - }, - { - name: "happy-case-with-no-identity-certs", - cluster: &v3clusterpb.Cluster{ - ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_EDS}, - EdsClusterConfig: &v3clusterpb.Cluster_EdsClusterConfig{ - EdsConfig: &v3corepb.ConfigSource{ - ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{ - Ads: &v3corepb.AggregatedConfigSource{}, - }, - }, - ServiceName: serviceName, - }, - LbPolicy: v3clusterpb.Cluster_ROUND_ROBIN, - TransportSocket: &v3corepb.TransportSocket{ - Name: "envoy.transport_sockets.tls", - ConfigType: &v3corepb.TransportSocket_TypedConfig{ - TypedConfig: &anypb.Any{ - TypeUrl: version.V3UpstreamTLSContextURL, - Value: func() []byte { - tls := &v3tlspb.UpstreamTlsContext{ - CommonTlsContext: &v3tlspb.CommonTlsContext{ - ValidationContextType: &v3tlspb.CommonTlsContext_ValidationContextCertificateProviderInstance{ - ValidationContextCertificateProviderInstance: &v3tlspb.CommonTlsContext_CertificateProviderInstance{ - InstanceName: rootPluginInstance, - CertificateName: rootCertName, - }, - }, - }, - } - mtls, _ := proto.Marshal(tls) - return mtls - }(), - }, - }, - }, - }, - wantUpdate: ClusterUpdate{ - ServiceName: serviceName, - EnableLRS: false, - SecurityCfg: &SecurityConfig{ - RootInstanceName: rootPluginInstance, - RootCertName: rootCertName, - }, - }, - }, - { - name: "happy-case-with-validation-context-provider-instance", - cluster: &v3clusterpb.Cluster{ - ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_EDS}, - EdsClusterConfig: &v3clusterpb.Cluster_EdsClusterConfig{ - EdsConfig: &v3corepb.ConfigSource{ - ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{ - Ads: &v3corepb.AggregatedConfigSource{}, - }, - }, - ServiceName: serviceName, - }, - LbPolicy: v3clusterpb.Cluster_ROUND_ROBIN, - TransportSocket: &v3corepb.TransportSocket{ - Name: "envoy.transport_sockets.tls", - ConfigType: &v3corepb.TransportSocket_TypedConfig{ - TypedConfig: &anypb.Any{ - TypeUrl: version.V3UpstreamTLSContextURL, - Value: func() []byte { - tls := &v3tlspb.UpstreamTlsContext{ - CommonTlsContext: &v3tlspb.CommonTlsContext{ - TlsCertificateCertificateProviderInstance: &v3tlspb.CommonTlsContext_CertificateProviderInstance{ - InstanceName: identityPluginInstance, - CertificateName: identityCertName, - }, - ValidationContextType: &v3tlspb.CommonTlsContext_ValidationContextCertificateProviderInstance{ - ValidationContextCertificateProviderInstance: &v3tlspb.CommonTlsContext_CertificateProviderInstance{ - InstanceName: rootPluginInstance, - CertificateName: rootCertName, - }, - }, - }, - } - mtls, _ := proto.Marshal(tls) - return mtls - }(), - }, - }, - }, - }, - wantUpdate: ClusterUpdate{ - ServiceName: serviceName, - EnableLRS: false, - SecurityCfg: &SecurityConfig{ - RootInstanceName: rootPluginInstance, - RootCertName: rootCertName, - IdentityInstanceName: identityPluginInstance, - IdentityCertName: identityCertName, - }, - }, - }, - { - name: "happy-case-with-combined-validation-context", - cluster: &v3clusterpb.Cluster{ - ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_EDS}, - EdsClusterConfig: &v3clusterpb.Cluster_EdsClusterConfig{ - EdsConfig: &v3corepb.ConfigSource{ - ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{ - Ads: &v3corepb.AggregatedConfigSource{}, - }, - }, - ServiceName: serviceName, - }, - LbPolicy: v3clusterpb.Cluster_ROUND_ROBIN, - TransportSocket: &v3corepb.TransportSocket{ - Name: "envoy.transport_sockets.tls", - ConfigType: &v3corepb.TransportSocket_TypedConfig{ - TypedConfig: &anypb.Any{ - TypeUrl: version.V3UpstreamTLSContextURL, - Value: func() []byte { - tls := &v3tlspb.UpstreamTlsContext{ - CommonTlsContext: &v3tlspb.CommonTlsContext{ - TlsCertificateCertificateProviderInstance: &v3tlspb.CommonTlsContext_CertificateProviderInstance{ - InstanceName: identityPluginInstance, - CertificateName: identityCertName, - }, - ValidationContextType: &v3tlspb.CommonTlsContext_CombinedValidationContext{ - CombinedValidationContext: &v3tlspb.CommonTlsContext_CombinedCertificateValidationContext{ - DefaultValidationContext: &v3tlspb.CertificateValidationContext{ - MatchSubjectAltNames: []*v3matcherpb.StringMatcher{ - {MatchPattern: &v3matcherpb.StringMatcher_Exact{Exact: san1}}, - {MatchPattern: &v3matcherpb.StringMatcher_Exact{Exact: san2}}, - }, - }, - ValidationContextCertificateProviderInstance: &v3tlspb.CommonTlsContext_CertificateProviderInstance{ - InstanceName: rootPluginInstance, - CertificateName: rootCertName, - }, - }, - }, - }, - } - mtls, _ := proto.Marshal(tls) - return mtls - }(), - }, - }, - }, - }, - wantUpdate: ClusterUpdate{ - ServiceName: serviceName, - EnableLRS: false, - SecurityCfg: &SecurityConfig{ - RootInstanceName: rootPluginInstance, - RootCertName: rootCertName, - IdentityInstanceName: identityPluginInstance, - IdentityCertName: identityCertName, - AcceptedSANs: []string{san1, san2}, - }, - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - update, err := validateCluster(test.cluster) - if ((err != nil) != test.wantErr) || !cmp.Equal(update, test.wantUpdate, cmpopts.EquateEmpty()) { - t.Errorf("validateCluster(%+v) = (%+v, %v), want: (%+v, %v)", test.cluster, update, err, test.wantUpdate, test.wantErr) - } - }) - } -} - -func (s) TestUnmarshalCluster(t *testing.T) { - const ( - v2ClusterName = "v2clusterName" - v3ClusterName = "v3clusterName" - v2Service = "v2Service" - v3Service = "v2Service" - ) - var ( - v2Cluster = &v2xdspb.Cluster{ - Name: v2ClusterName, - ClusterDiscoveryType: &v2xdspb.Cluster_Type{Type: v2xdspb.Cluster_EDS}, - EdsClusterConfig: &v2xdspb.Cluster_EdsClusterConfig{ - EdsConfig: &v2corepb.ConfigSource{ - ConfigSourceSpecifier: &v2corepb.ConfigSource_Ads{ - Ads: &v2corepb.AggregatedConfigSource{}, - }, - }, - ServiceName: v2Service, - }, - LbPolicy: v2xdspb.Cluster_ROUND_ROBIN, - LrsServer: &v2corepb.ConfigSource{ - ConfigSourceSpecifier: &v2corepb.ConfigSource_Self{ - Self: &v2corepb.SelfConfigSource{}, - }, - }, - } - v2ClusterAny = &anypb.Any{ - TypeUrl: version.V2ClusterURL, - Value: func() []byte { - mcl, _ := proto.Marshal(v2Cluster) - return mcl - }(), - } - - v3Cluster = &v3clusterpb.Cluster{ - Name: v3ClusterName, - ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_EDS}, - EdsClusterConfig: &v3clusterpb.Cluster_EdsClusterConfig{ - EdsConfig: &v3corepb.ConfigSource{ - ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{ - Ads: &v3corepb.AggregatedConfigSource{}, - }, - }, - ServiceName: v3Service, - }, - LbPolicy: v3clusterpb.Cluster_ROUND_ROBIN, - LrsServer: &v3corepb.ConfigSource{ - ConfigSourceSpecifier: &v3corepb.ConfigSource_Self{ - Self: &v3corepb.SelfConfigSource{}, - }, - }, - } - v3ClusterAny = &anypb.Any{ - TypeUrl: version.V3ClusterURL, - Value: func() []byte { - mcl, _ := proto.Marshal(v3Cluster) - return mcl - }(), - } - ) - const testVersion = "test-version-cds" - - tests := []struct { - name string - resources []*anypb.Any - wantUpdate map[string]ClusterUpdate - wantMD UpdateMetadata - wantErr bool - }{ - { - name: "non-cluster resource type", - resources: []*anypb.Any{{TypeUrl: version.V3HTTPConnManagerURL}}, - wantMD: UpdateMetadata{ - Status: ServiceStatusNACKed, - Version: testVersion, - ErrState: &UpdateErrorMetadata{ - Version: testVersion, - Err: errPlaceHolder, - }, - }, - wantErr: true, - }, - { - name: "badly marshaled cluster resource", - resources: []*anypb.Any{ - { - TypeUrl: version.V3ClusterURL, - Value: []byte{1, 2, 3, 4}, - }, - }, - wantMD: UpdateMetadata{ - Status: ServiceStatusNACKed, - Version: testVersion, - ErrState: &UpdateErrorMetadata{ - Version: testVersion, - Err: errPlaceHolder, - }, - }, - wantErr: true, - }, - { - name: "bad cluster resource", - resources: []*anypb.Any{ - { - TypeUrl: version.V3ClusterURL, - Value: func() []byte { - cl := &v3clusterpb.Cluster{ - Name: "test", - ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_STATIC}, - } - mcl, _ := proto.Marshal(cl) - return mcl - }(), - }, - }, - wantUpdate: map[string]ClusterUpdate{"test": {}}, - wantMD: UpdateMetadata{ - Status: ServiceStatusNACKed, - Version: testVersion, - ErrState: &UpdateErrorMetadata{ - Version: testVersion, - Err: errPlaceHolder, - }, - }, - wantErr: true, - }, - { - name: "v2 cluster", - resources: []*anypb.Any{v2ClusterAny}, - wantUpdate: map[string]ClusterUpdate{ - v2ClusterName: { - ServiceName: v2Service, EnableLRS: true, - Raw: v2ClusterAny, - }, - }, - wantMD: UpdateMetadata{ - Status: ServiceStatusACKed, - Version: testVersion, - }, - }, - { - name: "v3 cluster", - resources: []*anypb.Any{v3ClusterAny}, - wantUpdate: map[string]ClusterUpdate{ - v3ClusterName: { - ServiceName: v3Service, EnableLRS: true, - Raw: v3ClusterAny, - }, - }, - wantMD: UpdateMetadata{ - Status: ServiceStatusACKed, - Version: testVersion, - }, - }, - { - name: "multiple clusters", - resources: []*anypb.Any{v2ClusterAny, v3ClusterAny}, - wantUpdate: map[string]ClusterUpdate{ - v2ClusterName: { - ServiceName: v2Service, EnableLRS: true, - Raw: v2ClusterAny, - }, - v3ClusterName: { - ServiceName: v3Service, EnableLRS: true, - Raw: v3ClusterAny, - }, - }, - wantMD: UpdateMetadata{ - Status: ServiceStatusACKed, - Version: testVersion, - }, - }, - { - // To test that unmarshal keeps processing on errors. - name: "good and bad clusters", - resources: []*anypb.Any{ - v2ClusterAny, - { - // bad cluster resource - TypeUrl: version.V3ClusterURL, - Value: func() []byte { - cl := &v3clusterpb.Cluster{ - Name: "bad", - ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_STATIC}, - } - mcl, _ := proto.Marshal(cl) - return mcl - }(), - }, - v3ClusterAny, - }, - wantUpdate: map[string]ClusterUpdate{ - v2ClusterName: { - ServiceName: v2Service, EnableLRS: true, - Raw: v2ClusterAny, - }, - v3ClusterName: { - ServiceName: v3Service, EnableLRS: true, - Raw: v3ClusterAny, - }, - "bad": {}, - }, - wantMD: UpdateMetadata{ - Status: ServiceStatusNACKed, - Version: testVersion, - ErrState: &UpdateErrorMetadata{ - Version: testVersion, - Err: errPlaceHolder, - }, - }, - wantErr: true, - }, - } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - update, md, err := UnmarshalCluster(testVersion, test.resources, nil) - if (err != nil) != test.wantErr { - t.Fatalf("UnmarshalCluster(), got err: %v, wantErr: %v", err, test.wantErr) - } - if diff := cmp.Diff(update, test.wantUpdate, cmpOpts); diff != "" { - t.Errorf("got unexpected update, diff (-got +want): %v", diff) - } - if diff := cmp.Diff(md, test.wantMD, cmpOptsIgnoreDetails); diff != "" { - t.Errorf("got unexpected metadata, diff (-got +want): %v", diff) - } - }) - } -} diff --git a/xds/internal/client/lds_test.go b/xds/internal/client/lds_test.go deleted file mode 100644 index 26e79b78d13..00000000000 --- a/xds/internal/client/lds_test.go +++ /dev/null @@ -1,1628 +0,0 @@ -/* - * - * Copyright 2020 gRPC authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - */ - -package client - -import ( - "fmt" - "strings" - "testing" - "time" - - v1typepb "github.com/cncf/udpa/go/udpa/type/v1" - v3routepb "github.com/envoyproxy/go-control-plane/envoy/config/route/v3" - "github.com/golang/protobuf/proto" - "github.com/golang/protobuf/ptypes" - spb "github.com/golang/protobuf/ptypes/struct" - "github.com/google/go-cmp/cmp" - "google.golang.org/grpc/xds/internal/env" - "google.golang.org/grpc/xds/internal/httpfilter" - "google.golang.org/grpc/xds/internal/version" - "google.golang.org/protobuf/types/known/durationpb" - - v2xdspb "github.com/envoyproxy/go-control-plane/envoy/api/v2" - v2corepb "github.com/envoyproxy/go-control-plane/envoy/api/v2/core" - v3corepb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" - v2httppb "github.com/envoyproxy/go-control-plane/envoy/config/filter/network/http_connection_manager/v2" - v2listenerpb "github.com/envoyproxy/go-control-plane/envoy/config/listener/v2" - v3listenerpb "github.com/envoyproxy/go-control-plane/envoy/config/listener/v3" - v3httppb "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/network/http_connection_manager/v3" - v3tlspb "github.com/envoyproxy/go-control-plane/envoy/extensions/transport_sockets/tls/v3" - anypb "github.com/golang/protobuf/ptypes/any" - wrapperspb "github.com/golang/protobuf/ptypes/wrappers" -) - -func (s) TestUnmarshalListener_ClientSide(t *testing.T) { - const ( - v2LDSTarget = "lds.target.good:2222" - v3LDSTarget = "lds.target.good:3333" - v2RouteConfigName = "v2RouteConfig" - v3RouteConfigName = "v3RouteConfig" - ) - - var ( - v2Lis = &anypb.Any{ - TypeUrl: version.V2ListenerURL, - Value: func() []byte { - cm := &v2httppb.HttpConnectionManager{ - RouteSpecifier: &v2httppb.HttpConnectionManager_Rds{ - Rds: &v2httppb.Rds{ - ConfigSource: &v2corepb.ConfigSource{ - ConfigSourceSpecifier: &v2corepb.ConfigSource_Ads{Ads: &v2corepb.AggregatedConfigSource{}}, - }, - RouteConfigName: v2RouteConfigName, - }, - }, - } - mcm, _ := proto.Marshal(cm) - lis := &v2xdspb.Listener{ - Name: v2LDSTarget, - ApiListener: &v2listenerpb.ApiListener{ - ApiListener: &anypb.Any{ - TypeUrl: version.V2HTTPConnManagerURL, - Value: mcm, - }, - }, - } - mLis, _ := proto.Marshal(lis) - return mLis - }(), - } - customFilter = &v3httppb.HttpFilter{ - Name: "customFilter", - ConfigType: &v3httppb.HttpFilter_TypedConfig{TypedConfig: customFilterConfig}, - } - typedStructFilter = &v3httppb.HttpFilter{ - Name: "customFilter", - ConfigType: &v3httppb.HttpFilter_TypedConfig{TypedConfig: wrappedCustomFilterTypedStructConfig}, - } - customOptionalFilter = &v3httppb.HttpFilter{ - Name: "customFilter", - ConfigType: &v3httppb.HttpFilter_TypedConfig{TypedConfig: customFilterConfig}, - IsOptional: true, - } - customFilter2 = &v3httppb.HttpFilter{ - Name: "customFilter2", - ConfigType: &v3httppb.HttpFilter_TypedConfig{TypedConfig: customFilterConfig}, - } - errFilter = &v3httppb.HttpFilter{ - Name: "errFilter", - ConfigType: &v3httppb.HttpFilter_TypedConfig{TypedConfig: errFilterConfig}, - } - errOptionalFilter = &v3httppb.HttpFilter{ - Name: "errFilter", - ConfigType: &v3httppb.HttpFilter_TypedConfig{TypedConfig: errFilterConfig}, - IsOptional: true, - } - clientOnlyCustomFilter = &v3httppb.HttpFilter{ - Name: "clientOnlyCustomFilter", - ConfigType: &v3httppb.HttpFilter_TypedConfig{TypedConfig: clientOnlyCustomFilterConfig}, - } - serverOnlyCustomFilter = &v3httppb.HttpFilter{ - Name: "serverOnlyCustomFilter", - ConfigType: &v3httppb.HttpFilter_TypedConfig{TypedConfig: serverOnlyCustomFilterConfig}, - } - serverOnlyOptionalCustomFilter = &v3httppb.HttpFilter{ - Name: "serverOnlyOptionalCustomFilter", - ConfigType: &v3httppb.HttpFilter_TypedConfig{TypedConfig: serverOnlyCustomFilterConfig}, - IsOptional: true, - } - unknownFilter = &v3httppb.HttpFilter{ - Name: "unknownFilter", - ConfigType: &v3httppb.HttpFilter_TypedConfig{TypedConfig: unknownFilterConfig}, - } - unknownOptionalFilter = &v3httppb.HttpFilter{ - Name: "unknownFilter", - ConfigType: &v3httppb.HttpFilter_TypedConfig{TypedConfig: unknownFilterConfig}, - IsOptional: true, - } - v3LisWithFilters = func(fs ...*v3httppb.HttpFilter) *anypb.Any { - hcm := &v3httppb.HttpConnectionManager{ - RouteSpecifier: &v3httppb.HttpConnectionManager_Rds{ - Rds: &v3httppb.Rds{ - ConfigSource: &v3corepb.ConfigSource{ - ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{Ads: &v3corepb.AggregatedConfigSource{}}, - }, - RouteConfigName: v3RouteConfigName, - }, - }, - CommonHttpProtocolOptions: &v3corepb.HttpProtocolOptions{ - MaxStreamDuration: durationpb.New(time.Second), - }, - HttpFilters: fs, - } - return &anypb.Any{ - TypeUrl: version.V3ListenerURL, - Value: func() []byte { - mcm, _ := ptypes.MarshalAny(hcm) - lis := &v3listenerpb.Listener{ - Name: v3LDSTarget, - ApiListener: &v3listenerpb.ApiListener{ - ApiListener: mcm, - }, - } - mLis, _ := proto.Marshal(lis) - return mLis - }(), - } - } - ) - const testVersion = "test-version-lds-client" - - tests := []struct { - name string - resources []*anypb.Any - wantUpdate map[string]ListenerUpdate - wantMD UpdateMetadata - wantErr bool - disableFI bool // disable fault injection - }{ - { - name: "non-listener resource", - resources: []*anypb.Any{{TypeUrl: version.V3HTTPConnManagerURL}}, - wantMD: UpdateMetadata{ - Status: ServiceStatusNACKed, - Version: testVersion, - ErrState: &UpdateErrorMetadata{ - Version: testVersion, - Err: errPlaceHolder, - }, - }, - wantErr: true, - }, - { - name: "badly marshaled listener resource", - resources: []*anypb.Any{ - { - TypeUrl: version.V3ListenerURL, - Value: func() []byte { - lis := &v3listenerpb.Listener{ - Name: v3LDSTarget, - ApiListener: &v3listenerpb.ApiListener{ - ApiListener: &anypb.Any{ - TypeUrl: version.V3HTTPConnManagerURL, - Value: []byte{1, 2, 3, 4}, - }, - }, - } - mLis, _ := proto.Marshal(lis) - return mLis - }(), - }, - }, - wantUpdate: map[string]ListenerUpdate{v3LDSTarget: {}}, - wantMD: UpdateMetadata{ - Status: ServiceStatusNACKed, - Version: testVersion, - ErrState: &UpdateErrorMetadata{ - Version: testVersion, - Err: errPlaceHolder, - }, - }, - wantErr: true, - }, - { - name: "wrong type in apiListener", - resources: []*anypb.Any{ - { - TypeUrl: version.V3ListenerURL, - Value: func() []byte { - lis := &v3listenerpb.Listener{ - Name: v3LDSTarget, - ApiListener: &v3listenerpb.ApiListener{ - ApiListener: &anypb.Any{ - TypeUrl: version.V2ListenerURL, - Value: func() []byte { - cm := &v3httppb.HttpConnectionManager{ - RouteSpecifier: &v3httppb.HttpConnectionManager_Rds{ - Rds: &v3httppb.Rds{ - ConfigSource: &v3corepb.ConfigSource{ - ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{Ads: &v3corepb.AggregatedConfigSource{}}, - }, - RouteConfigName: v3RouteConfigName, - }, - }, - } - mcm, _ := proto.Marshal(cm) - return mcm - }(), - }, - }, - } - mLis, _ := proto.Marshal(lis) - return mLis - }(), - }, - }, - wantUpdate: map[string]ListenerUpdate{v3LDSTarget: {}}, - wantMD: UpdateMetadata{ - Status: ServiceStatusNACKed, - Version: testVersion, - ErrState: &UpdateErrorMetadata{ - Version: testVersion, - Err: errPlaceHolder, - }, - }, - wantErr: true, - }, - { - name: "empty httpConnMgr in apiListener", - resources: []*anypb.Any{ - { - TypeUrl: version.V3ListenerURL, - Value: func() []byte { - lis := &v3listenerpb.Listener{ - Name: v3LDSTarget, - ApiListener: &v3listenerpb.ApiListener{ - ApiListener: &anypb.Any{ - TypeUrl: version.V2ListenerURL, - Value: func() []byte { - cm := &v3httppb.HttpConnectionManager{ - RouteSpecifier: &v3httppb.HttpConnectionManager_Rds{ - Rds: &v3httppb.Rds{}, - }, - } - mcm, _ := proto.Marshal(cm) - return mcm - }(), - }, - }, - } - mLis, _ := proto.Marshal(lis) - return mLis - }(), - }, - }, - wantUpdate: map[string]ListenerUpdate{v3LDSTarget: {}}, - wantMD: UpdateMetadata{ - Status: ServiceStatusNACKed, - Version: testVersion, - ErrState: &UpdateErrorMetadata{ - Version: testVersion, - Err: errPlaceHolder, - }, - }, - wantErr: true, - }, - { - name: "scopedRoutes routeConfig in apiListener", - resources: []*anypb.Any{ - { - TypeUrl: version.V3ListenerURL, - Value: func() []byte { - lis := &v3listenerpb.Listener{ - Name: v3LDSTarget, - ApiListener: &v3listenerpb.ApiListener{ - ApiListener: &anypb.Any{ - TypeUrl: version.V2ListenerURL, - Value: func() []byte { - cm := &v3httppb.HttpConnectionManager{ - RouteSpecifier: &v3httppb.HttpConnectionManager_ScopedRoutes{}, - } - mcm, _ := proto.Marshal(cm) - return mcm - }(), - }, - }, - } - mLis, _ := proto.Marshal(lis) - return mLis - }(), - }, - }, - wantUpdate: map[string]ListenerUpdate{v3LDSTarget: {}}, - wantMD: UpdateMetadata{ - Status: ServiceStatusNACKed, - Version: testVersion, - ErrState: &UpdateErrorMetadata{ - Version: testVersion, - Err: errPlaceHolder, - }, - }, - wantErr: true, - }, - { - name: "rds.ConfigSource in apiListener is not ADS", - resources: []*anypb.Any{ - { - TypeUrl: version.V3ListenerURL, - Value: func() []byte { - lis := &v3listenerpb.Listener{ - Name: v3LDSTarget, - ApiListener: &v3listenerpb.ApiListener{ - ApiListener: &anypb.Any{ - TypeUrl: version.V2ListenerURL, - Value: func() []byte { - cm := &v3httppb.HttpConnectionManager{ - RouteSpecifier: &v3httppb.HttpConnectionManager_Rds{ - Rds: &v3httppb.Rds{ - ConfigSource: &v3corepb.ConfigSource{ - ConfigSourceSpecifier: &v3corepb.ConfigSource_Path{ - Path: "/some/path", - }, - }, - RouteConfigName: v3RouteConfigName, - }, - }, - } - mcm, _ := proto.Marshal(cm) - return mcm - }(), - }, - }, - } - mLis, _ := proto.Marshal(lis) - return mLis - }(), - }, - }, - wantUpdate: map[string]ListenerUpdate{v3LDSTarget: {}}, - wantMD: UpdateMetadata{ - Status: ServiceStatusNACKed, - Version: testVersion, - ErrState: &UpdateErrorMetadata{ - Version: testVersion, - Err: errPlaceHolder, - }, - }, - wantErr: true, - }, - { - name: "empty resource list", - wantMD: UpdateMetadata{ - Status: ServiceStatusACKed, - Version: testVersion, - }, - }, - { - name: "v3 with no filters", - resources: []*anypb.Any{v3LisWithFilters()}, - wantUpdate: map[string]ListenerUpdate{ - v3LDSTarget: {RouteConfigName: v3RouteConfigName, MaxStreamDuration: time.Second, Raw: v3LisWithFilters()}, - }, - wantMD: UpdateMetadata{ - Status: ServiceStatusACKed, - Version: testVersion, - }, - }, - { - name: "v3 with custom filter", - resources: []*anypb.Any{v3LisWithFilters(customFilter)}, - wantUpdate: map[string]ListenerUpdate{ - v3LDSTarget: { - RouteConfigName: v3RouteConfigName, MaxStreamDuration: time.Second, - HTTPFilters: []HTTPFilter{{ - Name: "customFilter", - Filter: httpFilter{}, - Config: filterConfig{Cfg: customFilterConfig}, - }}, - Raw: v3LisWithFilters(customFilter), - }, - }, - wantMD: UpdateMetadata{ - Status: ServiceStatusACKed, - Version: testVersion, - }, - }, - { - name: "v3 with custom filter in typed struct", - resources: []*anypb.Any{v3LisWithFilters(typedStructFilter)}, - wantUpdate: map[string]ListenerUpdate{ - v3LDSTarget: { - RouteConfigName: v3RouteConfigName, MaxStreamDuration: time.Second, - HTTPFilters: []HTTPFilter{{ - Name: "customFilter", - Filter: httpFilter{}, - Config: filterConfig{Cfg: customFilterTypedStructConfig}, - }}, - Raw: v3LisWithFilters(typedStructFilter), - }, - }, - wantMD: UpdateMetadata{ - Status: ServiceStatusACKed, - Version: testVersion, - }, - }, - { - name: "v3 with optional custom filter", - resources: []*anypb.Any{v3LisWithFilters(customOptionalFilter)}, - wantUpdate: map[string]ListenerUpdate{ - v3LDSTarget: { - RouteConfigName: v3RouteConfigName, MaxStreamDuration: time.Second, - HTTPFilters: []HTTPFilter{{ - Name: "customFilter", - Filter: httpFilter{}, - Config: filterConfig{Cfg: customFilterConfig}, - }}, - Raw: v3LisWithFilters(customOptionalFilter), - }, - }, - wantMD: UpdateMetadata{ - Status: ServiceStatusACKed, - Version: testVersion, - }, - }, - { - name: "v3 with custom filter, fault injection disabled", - resources: []*anypb.Any{v3LisWithFilters(customFilter)}, - wantUpdate: map[string]ListenerUpdate{ - v3LDSTarget: {RouteConfigName: v3RouteConfigName, MaxStreamDuration: time.Second, Raw: v3LisWithFilters(customFilter)}, - }, - wantMD: UpdateMetadata{ - Status: ServiceStatusACKed, - Version: testVersion, - }, - disableFI: true, - }, - { - name: "v3 with two filters with same name", - resources: []*anypb.Any{v3LisWithFilters(customFilter, customFilter)}, - wantUpdate: map[string]ListenerUpdate{v3LDSTarget: {}}, - wantMD: UpdateMetadata{ - Status: ServiceStatusNACKed, - Version: testVersion, - ErrState: &UpdateErrorMetadata{ - Version: testVersion, - Err: errPlaceHolder, - }, - }, - wantErr: true, - }, - { - name: "v3 with two filters - same type different name", - resources: []*anypb.Any{v3LisWithFilters(customFilter, customFilter2)}, - wantUpdate: map[string]ListenerUpdate{ - v3LDSTarget: { - RouteConfigName: v3RouteConfigName, MaxStreamDuration: time.Second, - HTTPFilters: []HTTPFilter{{ - Name: "customFilter", - Filter: httpFilter{}, - Config: filterConfig{Cfg: customFilterConfig}, - }, { - Name: "customFilter2", - Filter: httpFilter{}, - Config: filterConfig{Cfg: customFilterConfig}, - }}, - Raw: v3LisWithFilters(customFilter, customFilter2), - }, - }, - wantMD: UpdateMetadata{ - Status: ServiceStatusACKed, - Version: testVersion, - }, - }, - { - name: "v3 with server-only filter", - resources: []*anypb.Any{v3LisWithFilters(serverOnlyCustomFilter)}, - wantUpdate: map[string]ListenerUpdate{v3LDSTarget: {}}, - wantMD: UpdateMetadata{ - Status: ServiceStatusNACKed, - Version: testVersion, - ErrState: &UpdateErrorMetadata{ - Version: testVersion, - Err: errPlaceHolder, - }, - }, - wantErr: true, - }, - { - name: "v3 with optional server-only filter", - resources: []*anypb.Any{v3LisWithFilters(serverOnlyOptionalCustomFilter)}, - wantUpdate: map[string]ListenerUpdate{ - v3LDSTarget: { - RouteConfigName: v3RouteConfigName, - MaxStreamDuration: time.Second, - Raw: v3LisWithFilters(serverOnlyOptionalCustomFilter), - }, - }, - wantMD: UpdateMetadata{ - Status: ServiceStatusACKed, - Version: testVersion, - }, - }, - { - name: "v3 with client-only filter", - resources: []*anypb.Any{v3LisWithFilters(clientOnlyCustomFilter)}, - wantUpdate: map[string]ListenerUpdate{ - v3LDSTarget: { - RouteConfigName: v3RouteConfigName, MaxStreamDuration: time.Second, - HTTPFilters: []HTTPFilter{{ - Name: "clientOnlyCustomFilter", - Filter: clientOnlyHTTPFilter{}, - Config: filterConfig{Cfg: clientOnlyCustomFilterConfig}, - }}, - Raw: v3LisWithFilters(clientOnlyCustomFilter), - }, - }, - wantMD: UpdateMetadata{ - Status: ServiceStatusACKed, - Version: testVersion, - }, - }, - { - name: "v3 with err filter", - resources: []*anypb.Any{v3LisWithFilters(errFilter)}, - wantUpdate: map[string]ListenerUpdate{v3LDSTarget: {}}, - wantMD: UpdateMetadata{ - Status: ServiceStatusNACKed, - Version: testVersion, - ErrState: &UpdateErrorMetadata{ - Version: testVersion, - Err: errPlaceHolder, - }, - }, - wantErr: true, - }, - { - name: "v3 with optional err filter", - resources: []*anypb.Any{v3LisWithFilters(errOptionalFilter)}, - wantUpdate: map[string]ListenerUpdate{v3LDSTarget: {}}, - wantMD: UpdateMetadata{ - Status: ServiceStatusNACKed, - Version: testVersion, - ErrState: &UpdateErrorMetadata{ - Version: testVersion, - Err: errPlaceHolder, - }, - }, - wantErr: true, - }, - { - name: "v3 with unknown filter", - resources: []*anypb.Any{v3LisWithFilters(unknownFilter)}, - wantUpdate: map[string]ListenerUpdate{v3LDSTarget: {}}, - wantMD: UpdateMetadata{ - Status: ServiceStatusNACKed, - Version: testVersion, - ErrState: &UpdateErrorMetadata{ - Version: testVersion, - Err: errPlaceHolder, - }, - }, - wantErr: true, - }, - { - name: "v3 with unknown filter (optional)", - resources: []*anypb.Any{v3LisWithFilters(unknownOptionalFilter)}, - wantUpdate: map[string]ListenerUpdate{ - v3LDSTarget: { - RouteConfigName: v3RouteConfigName, - MaxStreamDuration: time.Second, - Raw: v3LisWithFilters(unknownOptionalFilter), - }, - }, - wantMD: UpdateMetadata{ - Status: ServiceStatusACKed, - Version: testVersion, - }, - }, - { - name: "v3 with error filter, fault injection disabled", - resources: []*anypb.Any{v3LisWithFilters(errFilter)}, - wantUpdate: map[string]ListenerUpdate{ - v3LDSTarget: { - RouteConfigName: v3RouteConfigName, - MaxStreamDuration: time.Second, - Raw: v3LisWithFilters(errFilter), - }, - }, - wantMD: UpdateMetadata{ - Status: ServiceStatusACKed, - Version: testVersion, - }, - disableFI: true, - }, - { - name: "v2 listener resource", - resources: []*anypb.Any{v2Lis}, - wantUpdate: map[string]ListenerUpdate{ - v2LDSTarget: {RouteConfigName: v2RouteConfigName, Raw: v2Lis}, - }, - wantMD: UpdateMetadata{ - Status: ServiceStatusACKed, - Version: testVersion, - }, - }, - { - name: "v3 listener resource", - resources: []*anypb.Any{v3LisWithFilters()}, - wantUpdate: map[string]ListenerUpdate{ - v3LDSTarget: {RouteConfigName: v3RouteConfigName, MaxStreamDuration: time.Second, Raw: v3LisWithFilters()}, - }, - wantMD: UpdateMetadata{ - Status: ServiceStatusACKed, - Version: testVersion, - }, - }, - { - name: "multiple listener resources", - resources: []*anypb.Any{v2Lis, v3LisWithFilters()}, - wantUpdate: map[string]ListenerUpdate{ - v2LDSTarget: {RouteConfigName: v2RouteConfigName, Raw: v2Lis}, - v3LDSTarget: {RouteConfigName: v3RouteConfigName, MaxStreamDuration: time.Second, Raw: v3LisWithFilters()}, - }, - wantMD: UpdateMetadata{ - Status: ServiceStatusACKed, - Version: testVersion, - }, - }, - { - // To test that unmarshal keeps processing on errors. - name: "good and bad listener resources", - resources: []*anypb.Any{ - v2Lis, - { - TypeUrl: version.V3ListenerURL, - Value: func() []byte { - lis := &v3listenerpb.Listener{ - Name: "bad", - ApiListener: &v3listenerpb.ApiListener{ - ApiListener: &anypb.Any{ - TypeUrl: version.V2ListenerURL, - Value: func() []byte { - cm := &v3httppb.HttpConnectionManager{ - RouteSpecifier: &v3httppb.HttpConnectionManager_ScopedRoutes{}, - } - mcm, _ := proto.Marshal(cm) - return mcm - }()}}} - mLis, _ := proto.Marshal(lis) - return mLis - }(), - }, - v3LisWithFilters(), - }, - wantUpdate: map[string]ListenerUpdate{ - v2LDSTarget: {RouteConfigName: v2RouteConfigName, Raw: v2Lis}, - v3LDSTarget: {RouteConfigName: v3RouteConfigName, MaxStreamDuration: time.Second, Raw: v3LisWithFilters()}, - "bad": {}, - }, - wantMD: UpdateMetadata{ - Status: ServiceStatusNACKed, - Version: testVersion, - ErrState: &UpdateErrorMetadata{ - Version: testVersion, - Err: errPlaceHolder, - }, - }, - wantErr: true, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - oldFI := env.FaultInjectionSupport - env.FaultInjectionSupport = !test.disableFI - - update, md, err := UnmarshalListener(testVersion, test.resources, nil) - if (err != nil) != test.wantErr { - t.Fatalf("UnmarshalListener(), got err: %v, wantErr: %v", err, test.wantErr) - } - if diff := cmp.Diff(update, test.wantUpdate, cmpOpts); diff != "" { - t.Errorf("got unexpected update, diff (-got +want): %v", diff) - } - if diff := cmp.Diff(md, test.wantMD, cmpOptsIgnoreDetails); diff != "" { - t.Errorf("got unexpected metadata, diff (-got +want): %v", diff) - } - env.FaultInjectionSupport = oldFI - }) - } -} - -func (s) TestUnmarshalListener_ServerSide(t *testing.T) { - const v3LDSTarget = "grpc/server?udpa.resource.listening_address=0.0.0.0:9999" - - var ( - listenerEmptyTransportSocket = &anypb.Any{ - TypeUrl: version.V3ListenerURL, - Value: func() []byte { - lis := &v3listenerpb.Listener{ - Name: v3LDSTarget, - Address: &v3corepb.Address{ - Address: &v3corepb.Address_SocketAddress{ - SocketAddress: &v3corepb.SocketAddress{ - Address: "0.0.0.0", - PortSpecifier: &v3corepb.SocketAddress_PortValue{ - PortValue: 9999, - }, - }, - }, - }, - FilterChains: []*v3listenerpb.FilterChain{ - { - Name: "filter-chain-1", - }, - }, - } - mLis, _ := proto.Marshal(lis) - return mLis - }(), - } - listenerNoValidationContext = &anypb.Any{ - TypeUrl: version.V3ListenerURL, - Value: func() []byte { - lis := &v3listenerpb.Listener{ - Name: v3LDSTarget, - Address: &v3corepb.Address{ - Address: &v3corepb.Address_SocketAddress{ - SocketAddress: &v3corepb.SocketAddress{ - Address: "0.0.0.0", - PortSpecifier: &v3corepb.SocketAddress_PortValue{ - PortValue: 9999, - }, - }, - }, - }, - FilterChains: []*v3listenerpb.FilterChain{ - { - Name: "filter-chain-1", - TransportSocket: &v3corepb.TransportSocket{ - Name: "envoy.transport_sockets.tls", - ConfigType: &v3corepb.TransportSocket_TypedConfig{ - TypedConfig: &anypb.Any{ - TypeUrl: version.V3DownstreamTLSContextURL, - Value: func() []byte { - tls := &v3tlspb.DownstreamTlsContext{ - CommonTlsContext: &v3tlspb.CommonTlsContext{ - TlsCertificateCertificateProviderInstance: &v3tlspb.CommonTlsContext_CertificateProviderInstance{ - InstanceName: "identityPluginInstance", - CertificateName: "identityCertName", - }, - }, - } - mtls, _ := proto.Marshal(tls) - return mtls - }(), - }, - }, - }, - }, - }, - } - mLis, _ := proto.Marshal(lis) - return mLis - }(), - } - listenerWithValidationContext = &anypb.Any{ - TypeUrl: version.V3ListenerURL, - Value: func() []byte { - lis := &v3listenerpb.Listener{ - Name: v3LDSTarget, - Address: &v3corepb.Address{ - Address: &v3corepb.Address_SocketAddress{ - SocketAddress: &v3corepb.SocketAddress{ - Address: "0.0.0.0", - PortSpecifier: &v3corepb.SocketAddress_PortValue{ - PortValue: 9999, - }, - }, - }, - }, - FilterChains: []*v3listenerpb.FilterChain{ - { - Name: "filter-chain-1", - TransportSocket: &v3corepb.TransportSocket{ - Name: "envoy.transport_sockets.tls", - ConfigType: &v3corepb.TransportSocket_TypedConfig{ - TypedConfig: &anypb.Any{ - TypeUrl: version.V3DownstreamTLSContextURL, - Value: func() []byte { - tls := &v3tlspb.DownstreamTlsContext{ - RequireClientCertificate: &wrapperspb.BoolValue{Value: true}, - CommonTlsContext: &v3tlspb.CommonTlsContext{ - TlsCertificateCertificateProviderInstance: &v3tlspb.CommonTlsContext_CertificateProviderInstance{ - InstanceName: "identityPluginInstance", - CertificateName: "identityCertName", - }, - ValidationContextType: &v3tlspb.CommonTlsContext_ValidationContextCertificateProviderInstance{ - ValidationContextCertificateProviderInstance: &v3tlspb.CommonTlsContext_CertificateProviderInstance{ - InstanceName: "rootPluginInstance", - CertificateName: "rootCertName", - }, - }, - }, - } - mtls, _ := proto.Marshal(tls) - return mtls - }(), - }, - }, - }, - }, - }, - } - mLis, _ := proto.Marshal(lis) - return mLis - }(), - } - ) - - const testVersion = "test-version-lds-server" - - tests := []struct { - name string - resources []*anypb.Any - wantUpdate map[string]ListenerUpdate - wantMD UpdateMetadata - wantErr string - }{ - { - name: "no address field", - resources: []*anypb.Any{ - { - TypeUrl: version.V3ListenerURL, - Value: func() []byte { - lis := &v3listenerpb.Listener{ - Name: v3LDSTarget, - } - mLis, _ := proto.Marshal(lis) - return mLis - }(), - }, - }, - wantUpdate: map[string]ListenerUpdate{v3LDSTarget: {}}, - wantMD: UpdateMetadata{ - Status: ServiceStatusNACKed, - Version: testVersion, - ErrState: &UpdateErrorMetadata{ - Version: testVersion, - Err: errPlaceHolder, - }, - }, - wantErr: "no address field in LDS response", - }, - { - name: "no socket address field", - resources: []*anypb.Any{ - { - TypeUrl: version.V3ListenerURL, - Value: func() []byte { - lis := &v3listenerpb.Listener{ - Name: v3LDSTarget, - Address: &v3corepb.Address{}, - } - mLis, _ := proto.Marshal(lis) - return mLis - }(), - }, - }, - wantUpdate: map[string]ListenerUpdate{v3LDSTarget: {}}, - wantMD: UpdateMetadata{ - Status: ServiceStatusNACKed, - Version: testVersion, - ErrState: &UpdateErrorMetadata{ - Version: testVersion, - Err: errPlaceHolder, - }, - }, - wantErr: "no socket_address field in LDS response", - }, - { - name: "listener name does not match expected format", - resources: []*anypb.Any{ - { - TypeUrl: version.V3ListenerURL, - Value: func() []byte { - lis := &v3listenerpb.Listener{ - Name: "foo", - Address: &v3corepb.Address{ - Address: &v3corepb.Address_SocketAddress{ - SocketAddress: &v3corepb.SocketAddress{ - Address: "0.0.0.0", - PortSpecifier: &v3corepb.SocketAddress_PortValue{ - PortValue: 9999, - }, - }, - }, - }, - } - mLis, _ := proto.Marshal(lis) - return mLis - }(), - }, - }, - wantUpdate: map[string]ListenerUpdate{"foo": {}}, - wantMD: UpdateMetadata{ - Status: ServiceStatusNACKed, - Version: testVersion, - ErrState: &UpdateErrorMetadata{ - Version: testVersion, - Err: errPlaceHolder, - }, - }, - wantErr: "no host:port in name field of LDS response", - }, - { - name: "host mismatch", - resources: []*anypb.Any{ - { - TypeUrl: version.V3ListenerURL, - Value: func() []byte { - lis := &v3listenerpb.Listener{ - Name: v3LDSTarget, - Address: &v3corepb.Address{ - Address: &v3corepb.Address_SocketAddress{ - SocketAddress: &v3corepb.SocketAddress{ - Address: "1.2.3.4", - PortSpecifier: &v3corepb.SocketAddress_PortValue{ - PortValue: 9999, - }, - }, - }, - }, - } - mLis, _ := proto.Marshal(lis) - return mLis - }(), - }, - }, - wantUpdate: map[string]ListenerUpdate{v3LDSTarget: {}}, - wantMD: UpdateMetadata{ - Status: ServiceStatusNACKed, - Version: testVersion, - ErrState: &UpdateErrorMetadata{ - Version: testVersion, - Err: errPlaceHolder, - }, - }, - wantErr: "socket_address host does not match the one in name", - }, - { - name: "port mismatch", - resources: []*anypb.Any{ - { - TypeUrl: version.V3ListenerURL, - Value: func() []byte { - lis := &v3listenerpb.Listener{ - Name: v3LDSTarget, - Address: &v3corepb.Address{ - Address: &v3corepb.Address_SocketAddress{ - SocketAddress: &v3corepb.SocketAddress{ - Address: "0.0.0.0", - PortSpecifier: &v3corepb.SocketAddress_PortValue{ - PortValue: 1234, - }, - }, - }, - }, - } - mLis, _ := proto.Marshal(lis) - return mLis - }(), - }, - }, - wantUpdate: map[string]ListenerUpdate{v3LDSTarget: {}}, - wantMD: UpdateMetadata{ - Status: ServiceStatusNACKed, - Version: testVersion, - ErrState: &UpdateErrorMetadata{ - Version: testVersion, - Err: errPlaceHolder, - }, - }, - wantErr: "socket_address port does not match the one in name", - }, - { - name: "unexpected number of filter chains", - resources: []*anypb.Any{ - { - TypeUrl: version.V3ListenerURL, - Value: func() []byte { - lis := &v3listenerpb.Listener{ - Name: v3LDSTarget, - Address: &v3corepb.Address{ - Address: &v3corepb.Address_SocketAddress{ - SocketAddress: &v3corepb.SocketAddress{ - Address: "0.0.0.0", - PortSpecifier: &v3corepb.SocketAddress_PortValue{ - PortValue: 9999, - }, - }, - }, - }, - FilterChains: []*v3listenerpb.FilterChain{ - {Name: "filter-chain-1"}, - {Name: "filter-chain-2"}, - }, - } - mLis, _ := proto.Marshal(lis) - return mLis - }(), - }, - }, - wantUpdate: map[string]ListenerUpdate{v3LDSTarget: {}}, - wantMD: UpdateMetadata{ - Status: ServiceStatusNACKed, - Version: testVersion, - ErrState: &UpdateErrorMetadata{ - Version: testVersion, - Err: errPlaceHolder, - }, - }, - wantErr: "filter chains count in LDS response does not match expected", - }, - { - name: "unexpected transport socket name", - resources: []*anypb.Any{ - { - TypeUrl: version.V3ListenerURL, - Value: func() []byte { - lis := &v3listenerpb.Listener{ - Name: v3LDSTarget, - Address: &v3corepb.Address{ - Address: &v3corepb.Address_SocketAddress{ - SocketAddress: &v3corepb.SocketAddress{ - Address: "0.0.0.0", - PortSpecifier: &v3corepb.SocketAddress_PortValue{ - PortValue: 9999, - }, - }, - }, - }, - FilterChains: []*v3listenerpb.FilterChain{ - { - Name: "filter-chain-1", - TransportSocket: &v3corepb.TransportSocket{ - Name: "unsupported-transport-socket-name", - }, - }, - }, - } - mLis, _ := proto.Marshal(lis) - return mLis - }(), - }, - }, - wantUpdate: map[string]ListenerUpdate{v3LDSTarget: {}}, - wantMD: UpdateMetadata{ - Status: ServiceStatusNACKed, - Version: testVersion, - ErrState: &UpdateErrorMetadata{ - Version: testVersion, - Err: errPlaceHolder, - }, - }, - wantErr: "transport_socket field has unexpected name", - }, - { - name: "unexpected transport socket typedConfig URL", - resources: []*anypb.Any{ - { - TypeUrl: version.V3ListenerURL, - Value: func() []byte { - lis := &v3listenerpb.Listener{ - Name: v3LDSTarget, - Address: &v3corepb.Address{ - Address: &v3corepb.Address_SocketAddress{ - SocketAddress: &v3corepb.SocketAddress{ - Address: "0.0.0.0", - PortSpecifier: &v3corepb.SocketAddress_PortValue{ - PortValue: 9999, - }, - }, - }, - }, - FilterChains: []*v3listenerpb.FilterChain{ - { - Name: "filter-chain-1", - TransportSocket: &v3corepb.TransportSocket{ - Name: "envoy.transport_sockets.tls", - ConfigType: &v3corepb.TransportSocket_TypedConfig{ - TypedConfig: &anypb.Any{ - TypeUrl: version.V3UpstreamTLSContextURL, - }, - }, - }, - }, - }, - } - mLis, _ := proto.Marshal(lis) - return mLis - }(), - }, - }, - wantUpdate: map[string]ListenerUpdate{v3LDSTarget: {}}, - wantMD: UpdateMetadata{ - Status: ServiceStatusNACKed, - Version: testVersion, - ErrState: &UpdateErrorMetadata{ - Version: testVersion, - Err: errPlaceHolder, - }, - }, - wantErr: "transport_socket field has unexpected typeURL", - }, - { - name: "badly marshaled transport socket", - resources: []*anypb.Any{ - { - TypeUrl: version.V3ListenerURL, - Value: func() []byte { - lis := &v3listenerpb.Listener{ - Name: v3LDSTarget, - Address: &v3corepb.Address{ - Address: &v3corepb.Address_SocketAddress{ - SocketAddress: &v3corepb.SocketAddress{ - Address: "0.0.0.0", - PortSpecifier: &v3corepb.SocketAddress_PortValue{ - PortValue: 9999, - }, - }, - }, - }, - FilterChains: []*v3listenerpb.FilterChain{ - { - Name: "filter-chain-1", - TransportSocket: &v3corepb.TransportSocket{ - Name: "envoy.transport_sockets.tls", - ConfigType: &v3corepb.TransportSocket_TypedConfig{ - TypedConfig: &anypb.Any{ - TypeUrl: version.V3DownstreamTLSContextURL, - Value: []byte{1, 2, 3, 4}, - }, - }, - }, - }, - }, - } - mLis, _ := proto.Marshal(lis) - return mLis - }(), - }, - }, - wantUpdate: map[string]ListenerUpdate{v3LDSTarget: {}}, - wantMD: UpdateMetadata{ - Status: ServiceStatusNACKed, - Version: testVersion, - ErrState: &UpdateErrorMetadata{ - Version: testVersion, - Err: errPlaceHolder, - }, - }, - wantErr: "failed to unmarshal DownstreamTlsContext in LDS response", - }, - { - name: "missing CommonTlsContext", - resources: []*anypb.Any{ - { - TypeUrl: version.V3ListenerURL, - Value: func() []byte { - lis := &v3listenerpb.Listener{ - Name: v3LDSTarget, - Address: &v3corepb.Address{ - Address: &v3corepb.Address_SocketAddress{ - SocketAddress: &v3corepb.SocketAddress{ - Address: "0.0.0.0", - PortSpecifier: &v3corepb.SocketAddress_PortValue{ - PortValue: 9999, - }, - }, - }, - }, - FilterChains: []*v3listenerpb.FilterChain{ - { - Name: "filter-chain-1", - TransportSocket: &v3corepb.TransportSocket{ - Name: "envoy.transport_sockets.tls", - ConfigType: &v3corepb.TransportSocket_TypedConfig{ - TypedConfig: &anypb.Any{ - TypeUrl: version.V3DownstreamTLSContextURL, - Value: func() []byte { - tls := &v3tlspb.DownstreamTlsContext{} - mtls, _ := proto.Marshal(tls) - return mtls - }(), - }, - }, - }, - }, - }, - } - mLis, _ := proto.Marshal(lis) - return mLis - }(), - }, - }, - wantUpdate: map[string]ListenerUpdate{v3LDSTarget: {}}, - wantMD: UpdateMetadata{ - Status: ServiceStatusNACKed, - Version: testVersion, - ErrState: &UpdateErrorMetadata{ - Version: testVersion, - Err: errPlaceHolder, - }, - }, - wantErr: "DownstreamTlsContext in LDS response does not contain a CommonTlsContext", - }, - { - name: "unsupported validation context in transport socket", - resources: []*anypb.Any{ - { - TypeUrl: version.V3ListenerURL, - Value: func() []byte { - lis := &v3listenerpb.Listener{ - Name: v3LDSTarget, - Address: &v3corepb.Address{ - Address: &v3corepb.Address_SocketAddress{ - SocketAddress: &v3corepb.SocketAddress{ - Address: "0.0.0.0", - PortSpecifier: &v3corepb.SocketAddress_PortValue{ - PortValue: 9999, - }, - }, - }, - }, - FilterChains: []*v3listenerpb.FilterChain{ - { - Name: "filter-chain-1", - TransportSocket: &v3corepb.TransportSocket{ - Name: "envoy.transport_sockets.tls", - ConfigType: &v3corepb.TransportSocket_TypedConfig{ - TypedConfig: &anypb.Any{ - TypeUrl: version.V3DownstreamTLSContextURL, - Value: func() []byte { - tls := &v3tlspb.DownstreamTlsContext{ - CommonTlsContext: &v3tlspb.CommonTlsContext{ - ValidationContextType: &v3tlspb.CommonTlsContext_ValidationContextSdsSecretConfig{ - ValidationContextSdsSecretConfig: &v3tlspb.SdsSecretConfig{ - Name: "foo-sds-secret", - }, - }, - }, - } - mtls, _ := proto.Marshal(tls) - return mtls - }(), - }, - }, - }, - }, - }, - } - mLis, _ := proto.Marshal(lis) - return mLis - }(), - }, - }, - wantUpdate: map[string]ListenerUpdate{v3LDSTarget: {}}, - wantMD: UpdateMetadata{ - Status: ServiceStatusNACKed, - Version: testVersion, - ErrState: &UpdateErrorMetadata{ - Version: testVersion, - Err: errPlaceHolder, - }, - }, - wantErr: "validation context contains unexpected type", - }, - { - name: "empty transport socket", - resources: []*anypb.Any{listenerEmptyTransportSocket}, - wantUpdate: map[string]ListenerUpdate{ - v3LDSTarget: {Raw: listenerEmptyTransportSocket}, - }, - wantMD: UpdateMetadata{ - Status: ServiceStatusACKed, - Version: testVersion, - }, - }, - { - name: "no identity and root certificate providers", - resources: []*anypb.Any{ - { - TypeUrl: version.V3ListenerURL, - Value: func() []byte { - lis := &v3listenerpb.Listener{ - Name: v3LDSTarget, - Address: &v3corepb.Address{ - Address: &v3corepb.Address_SocketAddress{ - SocketAddress: &v3corepb.SocketAddress{ - Address: "0.0.0.0", - PortSpecifier: &v3corepb.SocketAddress_PortValue{ - PortValue: 9999, - }, - }, - }, - }, - FilterChains: []*v3listenerpb.FilterChain{ - { - Name: "filter-chain-1", - TransportSocket: &v3corepb.TransportSocket{ - Name: "envoy.transport_sockets.tls", - ConfigType: &v3corepb.TransportSocket_TypedConfig{ - TypedConfig: &anypb.Any{ - TypeUrl: version.V3DownstreamTLSContextURL, - Value: func() []byte { - tls := &v3tlspb.DownstreamTlsContext{ - RequireClientCertificate: &wrapperspb.BoolValue{Value: true}, - CommonTlsContext: &v3tlspb.CommonTlsContext{ - TlsCertificateCertificateProviderInstance: &v3tlspb.CommonTlsContext_CertificateProviderInstance{ - InstanceName: "identityPluginInstance", - CertificateName: "identityCertName", - }, - }, - } - mtls, _ := proto.Marshal(tls) - return mtls - }(), - }, - }, - }, - }, - }, - } - mLis, _ := proto.Marshal(lis) - return mLis - }(), - }, - }, - wantUpdate: map[string]ListenerUpdate{v3LDSTarget: {}}, - wantMD: UpdateMetadata{ - Status: ServiceStatusNACKed, - Version: testVersion, - ErrState: &UpdateErrorMetadata{ - Version: testVersion, - Err: errPlaceHolder, - }, - }, - wantErr: "security configuration on the server-side does not contain root certificate provider instance name, but require_client_cert field is set", - }, - { - name: "no identity certificate provider with require_client_cert", - resources: []*anypb.Any{ - { - TypeUrl: version.V3ListenerURL, - Value: func() []byte { - lis := &v3listenerpb.Listener{ - Name: v3LDSTarget, - Address: &v3corepb.Address{ - Address: &v3corepb.Address_SocketAddress{ - SocketAddress: &v3corepb.SocketAddress{ - Address: "0.0.0.0", - PortSpecifier: &v3corepb.SocketAddress_PortValue{ - PortValue: 9999, - }, - }, - }, - }, - FilterChains: []*v3listenerpb.FilterChain{ - { - Name: "filter-chain-1", - TransportSocket: &v3corepb.TransportSocket{ - Name: "envoy.transport_sockets.tls", - ConfigType: &v3corepb.TransportSocket_TypedConfig{ - TypedConfig: &anypb.Any{ - TypeUrl: version.V3DownstreamTLSContextURL, - Value: func() []byte { - tls := &v3tlspb.DownstreamTlsContext{ - CommonTlsContext: &v3tlspb.CommonTlsContext{}, - } - mtls, _ := proto.Marshal(tls) - return mtls - }(), - }, - }, - }, - }, - }, - } - mLis, _ := proto.Marshal(lis) - return mLis - }(), - }, - }, - wantUpdate: map[string]ListenerUpdate{v3LDSTarget: {}}, - wantMD: UpdateMetadata{ - Status: ServiceStatusNACKed, - Version: testVersion, - ErrState: &UpdateErrorMetadata{ - Version: testVersion, - Err: errPlaceHolder, - }, - }, - wantErr: "security configuration on the server-side does not contain identity certificate provider instance name", - }, - { - name: "happy case with no validation context", - resources: []*anypb.Any{listenerNoValidationContext}, - wantUpdate: map[string]ListenerUpdate{ - v3LDSTarget: { - SecurityCfg: &SecurityConfig{ - IdentityInstanceName: "identityPluginInstance", - IdentityCertName: "identityCertName", - }, - Raw: listenerNoValidationContext, - }, - }, - wantMD: UpdateMetadata{ - Status: ServiceStatusACKed, - Version: testVersion, - }, - }, - { - name: "happy case with validation context provider instance", - resources: []*anypb.Any{listenerWithValidationContext}, - wantUpdate: map[string]ListenerUpdate{ - v3LDSTarget: { - SecurityCfg: &SecurityConfig{ - RootInstanceName: "rootPluginInstance", - RootCertName: "rootCertName", - IdentityInstanceName: "identityPluginInstance", - IdentityCertName: "identityCertName", - RequireClientCert: true, - }, - Raw: listenerWithValidationContext, - }, - }, - wantMD: UpdateMetadata{ - Status: ServiceStatusACKed, - Version: testVersion, - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - gotUpdate, md, err := UnmarshalListener(testVersion, test.resources, nil) - if (err != nil) != (test.wantErr != "") { - t.Fatalf("UnmarshalListener(), got err: %v, wantErr: %v", err, test.wantErr) - } - if err != nil && !strings.Contains(err.Error(), test.wantErr) { - t.Fatalf("UnmarshalListener() = %v wantErr: %q", err, test.wantErr) - } - if diff := cmp.Diff(gotUpdate, test.wantUpdate, cmpOpts); diff != "" { - t.Errorf("got unexpected update, diff (-got +want): %v", diff) - } - if diff := cmp.Diff(md, test.wantMD, cmpOptsIgnoreDetails); diff != "" { - t.Errorf("got unexpected metadata, diff (-got +want): %v", diff) - } - }) - } -} - -type filterConfig struct { - httpfilter.FilterConfig - Cfg proto.Message - Override proto.Message -} - -// httpFilter allows testing the http filter registry and parsing functionality. -type httpFilter struct { - httpfilter.ClientInterceptorBuilder - httpfilter.ServerInterceptorBuilder -} - -func (httpFilter) TypeURLs() []string { return []string{"custom.filter"} } - -func (httpFilter) ParseFilterConfig(cfg proto.Message) (httpfilter.FilterConfig, error) { - return filterConfig{Cfg: cfg}, nil -} - -func (httpFilter) ParseFilterConfigOverride(override proto.Message) (httpfilter.FilterConfig, error) { - return filterConfig{Override: override}, nil -} - -// errHTTPFilter returns errors no matter what is passed to ParseFilterConfig. -type errHTTPFilter struct { - httpfilter.ClientInterceptorBuilder -} - -func (errHTTPFilter) TypeURLs() []string { return []string{"err.custom.filter"} } - -func (errHTTPFilter) ParseFilterConfig(cfg proto.Message) (httpfilter.FilterConfig, error) { - return nil, fmt.Errorf("error from ParseFilterConfig") -} - -func (errHTTPFilter) ParseFilterConfigOverride(override proto.Message) (httpfilter.FilterConfig, error) { - return nil, fmt.Errorf("error from ParseFilterConfigOverride") -} - -func init() { - httpfilter.Register(httpFilter{}) - httpfilter.Register(errHTTPFilter{}) - httpfilter.Register(serverOnlyHTTPFilter{}) - httpfilter.Register(clientOnlyHTTPFilter{}) -} - -// serverOnlyHTTPFilter does not implement ClientInterceptorBuilder -type serverOnlyHTTPFilter struct { - httpfilter.ServerInterceptorBuilder -} - -func (serverOnlyHTTPFilter) TypeURLs() []string { return []string{"serverOnly.custom.filter"} } - -func (serverOnlyHTTPFilter) ParseFilterConfig(cfg proto.Message) (httpfilter.FilterConfig, error) { - return filterConfig{Cfg: cfg}, nil -} - -func (serverOnlyHTTPFilter) ParseFilterConfigOverride(override proto.Message) (httpfilter.FilterConfig, error) { - return filterConfig{Override: override}, nil -} - -// clientOnlyHTTPFilter does not implement ServerInterceptorBuilder -type clientOnlyHTTPFilter struct { - httpfilter.ClientInterceptorBuilder -} - -func (clientOnlyHTTPFilter) TypeURLs() []string { return []string{"clientOnly.custom.filter"} } - -func (clientOnlyHTTPFilter) ParseFilterConfig(cfg proto.Message) (httpfilter.FilterConfig, error) { - return filterConfig{Cfg: cfg}, nil -} - -func (clientOnlyHTTPFilter) ParseFilterConfigOverride(override proto.Message) (httpfilter.FilterConfig, error) { - return filterConfig{Override: override}, nil -} - -var customFilterConfig = &anypb.Any{ - TypeUrl: "custom.filter", - Value: []byte{1, 2, 3}, -} - -var errFilterConfig = &anypb.Any{ - TypeUrl: "err.custom.filter", - Value: []byte{1, 2, 3}, -} - -var serverOnlyCustomFilterConfig = &anypb.Any{ - TypeUrl: "serverOnly.custom.filter", - Value: []byte{1, 2, 3}, -} - -var clientOnlyCustomFilterConfig = &anypb.Any{ - TypeUrl: "clientOnly.custom.filter", - Value: []byte{1, 2, 3}, -} - -var customFilterTypedStructConfig = &v1typepb.TypedStruct{ - TypeUrl: "custom.filter", - Value: &spb.Struct{ - Fields: map[string]*spb.Value{ - "foo": {Kind: &spb.Value_StringValue{StringValue: "bar"}}, - }, - }, -} -var wrappedCustomFilterTypedStructConfig *anypb.Any - -func init() { - var err error - wrappedCustomFilterTypedStructConfig, err = ptypes.MarshalAny(customFilterTypedStructConfig) - if err != nil { - panic(err.Error()) - } -} - -var unknownFilterConfig = &anypb.Any{ - TypeUrl: "unknown.custom.filter", - Value: []byte{1, 2, 3}, -} - -func wrappedOptionalFilter(name string) *anypb.Any { - filter := &v3routepb.FilterConfig{ - IsOptional: true, - Config: &anypb.Any{ - TypeUrl: name, - Value: []byte{1, 2, 3}, - }, - } - w, err := ptypes.MarshalAny(filter) - if err != nil { - panic("error marshalling any: " + err.Error()) - } - return w -} diff --git a/xds/internal/client/requests_counter.go b/xds/internal/client/requests_counter.go deleted file mode 100644 index 7ef18345ed6..00000000000 --- a/xds/internal/client/requests_counter.go +++ /dev/null @@ -1,82 +0,0 @@ -/* - * - * Copyright 2020 gRPC authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - */ - -package client - -import ( - "fmt" - "sync" - "sync/atomic" -) - -type servicesRequestsCounter struct { - mu sync.Mutex - services map[string]*ServiceRequestsCounter -} - -var src = &servicesRequestsCounter{ - services: make(map[string]*ServiceRequestsCounter), -} - -// ServiceRequestsCounter is used to track the total inflight requests for a -// service with the provided name. -type ServiceRequestsCounter struct { - ServiceName string - numRequests uint32 -} - -// GetServiceRequestsCounter returns the ServiceRequestsCounter with the -// provided serviceName. If one does not exist, it creates it. -func GetServiceRequestsCounter(serviceName string) *ServiceRequestsCounter { - src.mu.Lock() - defer src.mu.Unlock() - c, ok := src.services[serviceName] - if !ok { - c = &ServiceRequestsCounter{ServiceName: serviceName} - src.services[serviceName] = c - } - return c -} - -// StartRequest starts a request for a service, incrementing its number of -// requests by 1. Returns an error if the max number of requests is exceeded. -func (c *ServiceRequestsCounter) StartRequest(max uint32) error { - if atomic.LoadUint32(&c.numRequests) >= max { - return fmt.Errorf("max requests %v exceeded on service %v", max, c.ServiceName) - } - atomic.AddUint32(&c.numRequests, 1) - return nil -} - -// EndRequest ends a request for a service, decrementing its number of requests -// by 1. -func (c *ServiceRequestsCounter) EndRequest() { - atomic.AddUint32(&c.numRequests, ^uint32(0)) -} - -// ClearCounterForTesting clears the counter for the service. Should be only -// used in tests. -func ClearCounterForTesting(serviceName string) { - src.mu.Lock() - defer src.mu.Unlock() - c, ok := src.services[serviceName] - if !ok { - return - } - c.numRequests = 0 -} diff --git a/xds/internal/client/singleton.go b/xds/internal/client/singleton.go deleted file mode 100644 index 5d92b4146bb..00000000000 --- a/xds/internal/client/singleton.go +++ /dev/null @@ -1,101 +0,0 @@ -/* - * - * Copyright 2020 gRPC authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - */ - -package client - -import ( - "fmt" - "sync" - "time" - - "google.golang.org/grpc/xds/internal/client/bootstrap" -) - -const defaultWatchExpiryTimeout = 15 * time.Second - -// This is the Client returned by New(). It contains one client implementation, -// and maintains the refcount. -var singletonClient = &Client{} - -// To override in tests. -var bootstrapNewConfig = bootstrap.NewConfig - -// Client is a full fledged gRPC client which queries a set of discovery APIs -// (collectively termed as xDS) on a remote management server, to discover -// various dynamic resources. -// -// The xds client is a singleton. It will be shared by the xds resolver and -// balancer implementations, across multiple ClientConns and Servers. -type Client struct { - *clientImpl - - // This mu protects all the fields, including the embedded clientImpl above. - mu sync.Mutex - refCount int -} - -// New returns a new xdsClient configured by the bootstrap file specified in env -// variable GRPC_XDS_BOOTSTRAP. -func New() (*Client, error) { - singletonClient.mu.Lock() - defer singletonClient.mu.Unlock() - // If the client implementation was created, increment ref count and return - // the client. - if singletonClient.clientImpl != nil { - singletonClient.refCount++ - return singletonClient, nil - } - - // Create the new client implementation. - config, err := bootstrapNewConfig() - if err != nil { - return nil, fmt.Errorf("xds: failed to read bootstrap file: %v", err) - } - c, err := newWithConfig(config, defaultWatchExpiryTimeout) - if err != nil { - return nil, err - } - - singletonClient.clientImpl = c - singletonClient.refCount++ - return singletonClient, nil -} - -// Close closes the client. It does ref count of the xds client implementation, -// and closes the gRPC connection to the management server when ref count -// reaches 0. -func (c *Client) Close() { - c.mu.Lock() - defer c.mu.Unlock() - c.refCount-- - if c.refCount == 0 { - c.clientImpl.Close() - // Set clientImpl back to nil. So if New() is called after this, a new - // implementation will be created. - c.clientImpl = nil - } -} - -// NewWithConfigForTesting is exported for testing only. -func NewWithConfigForTesting(config *bootstrap.Config, watchExpiryTimeout time.Duration) (*Client, error) { - cl, err := newWithConfig(config, watchExpiryTimeout) - if err != nil { - return nil, err - } - return &Client{clientImpl: cl, refCount: 1}, nil -} diff --git a/xds/internal/client/tests/README.md b/xds/internal/client/tests/README.md deleted file mode 100644 index 6dc940c103f..00000000000 --- a/xds/internal/client/tests/README.md +++ /dev/null @@ -1 +0,0 @@ -This package contains tests which cannot live in the `client` package because they need to import one of the API client packages (which itself has a dependency on the `client` package). diff --git a/xds/internal/client/watchers_listener_test.go b/xds/internal/client/watchers_listener_test.go deleted file mode 100644 index bf3a122da07..00000000000 --- a/xds/internal/client/watchers_listener_test.go +++ /dev/null @@ -1,358 +0,0 @@ -/* - * - * Copyright 2020 gRPC authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - */ - -package client - -import ( - "context" - "testing" - - "google.golang.org/grpc/internal/testutils" -) - -type ldsUpdateErr struct { - u ListenerUpdate - err error -} - -// TestLDSWatch covers the cases: -// - an update is received after a watch() -// - an update for another resource name -// - an update is received after cancel() -func (s) TestLDSWatch(t *testing.T) { - apiClientCh, cleanup := overrideNewAPIClient() - defer cleanup() - - client, err := newWithConfig(clientOpts(testXDSServer, false)) - if err != nil { - t.Fatalf("failed to create client: %v", err) - } - defer client.Close() - - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() - c, err := apiClientCh.Receive(ctx) - if err != nil { - t.Fatalf("timeout when waiting for API client to be created: %v", err) - } - apiClient := c.(*testAPIClient) - - ldsUpdateCh := testutils.NewChannel() - cancelWatch := client.WatchListener(testLDSName, func(update ListenerUpdate, err error) { - ldsUpdateCh.Send(ldsUpdateErr{u: update, err: err}) - }) - if _, err := apiClient.addWatches[ListenerResource].Receive(ctx); err != nil { - t.Fatalf("want new watch to start, got error %v", err) - } - - wantUpdate := ListenerUpdate{RouteConfigName: testRDSName} - client.NewListeners(map[string]ListenerUpdate{testLDSName: wantUpdate}, UpdateMetadata{}) - if err := verifyListenerUpdate(ctx, ldsUpdateCh, wantUpdate); err != nil { - t.Fatal(err) - } - - // Another update, with an extra resource for a different resource name. - client.NewListeners(map[string]ListenerUpdate{ - testLDSName: wantUpdate, - "randomName": {}, - }, UpdateMetadata{}) - if err := verifyListenerUpdate(ctx, ldsUpdateCh, wantUpdate); err != nil { - t.Fatal(err) - } - - // Cancel watch, and send update again. - cancelWatch() - client.NewListeners(map[string]ListenerUpdate{testLDSName: wantUpdate}, UpdateMetadata{}) - sCtx, sCancel := context.WithTimeout(ctx, defaultTestShortTimeout) - defer sCancel() - if u, err := ldsUpdateCh.Receive(sCtx); err != context.DeadlineExceeded { - t.Errorf("unexpected ListenerUpdate: %v, %v, want channel recv timeout", u, err) - } -} - -// TestLDSTwoWatchSameResourceName covers the case where an update is received -// after two watch() for the same resource name. -func (s) TestLDSTwoWatchSameResourceName(t *testing.T) { - apiClientCh, cleanup := overrideNewAPIClient() - defer cleanup() - - client, err := newWithConfig(clientOpts(testXDSServer, false)) - if err != nil { - t.Fatalf("failed to create client: %v", err) - } - defer client.Close() - - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() - c, err := apiClientCh.Receive(ctx) - if err != nil { - t.Fatalf("timeout when waiting for API client to be created: %v", err) - } - apiClient := c.(*testAPIClient) - - const count = 2 - var ( - ldsUpdateChs []*testutils.Channel - cancelLastWatch func() - ) - - for i := 0; i < count; i++ { - ldsUpdateCh := testutils.NewChannel() - ldsUpdateChs = append(ldsUpdateChs, ldsUpdateCh) - cancelLastWatch = client.WatchListener(testLDSName, func(update ListenerUpdate, err error) { - ldsUpdateCh.Send(ldsUpdateErr{u: update, err: err}) - }) - - if i == 0 { - // A new watch is registered on the underlying API client only for - // the first iteration because we are using the same resource name. - if _, err := apiClient.addWatches[ListenerResource].Receive(ctx); err != nil { - t.Fatalf("want new watch to start, got error %v", err) - } - } - } - - wantUpdate := ListenerUpdate{RouteConfigName: testRDSName} - client.NewListeners(map[string]ListenerUpdate{testLDSName: wantUpdate}, UpdateMetadata{}) - for i := 0; i < count; i++ { - if err := verifyListenerUpdate(ctx, ldsUpdateChs[i], wantUpdate); err != nil { - t.Fatal(err) - } - } - - // Cancel the last watch, and send update again. - cancelLastWatch() - client.NewListeners(map[string]ListenerUpdate{testLDSName: wantUpdate}, UpdateMetadata{}) - for i := 0; i < count-1; i++ { - if err := verifyListenerUpdate(ctx, ldsUpdateChs[i], wantUpdate); err != nil { - t.Fatal(err) - } - } - - sCtx, sCancel := context.WithTimeout(ctx, defaultTestShortTimeout) - defer sCancel() - if u, err := ldsUpdateChs[count-1].Receive(sCtx); err != context.DeadlineExceeded { - t.Errorf("unexpected ListenerUpdate: %v, %v, want channel recv timeout", u, err) - } -} - -// TestLDSThreeWatchDifferentResourceName covers the case where an update is -// received after three watch() for different resource names. -func (s) TestLDSThreeWatchDifferentResourceName(t *testing.T) { - apiClientCh, cleanup := overrideNewAPIClient() - defer cleanup() - - client, err := newWithConfig(clientOpts(testXDSServer, false)) - if err != nil { - t.Fatalf("failed to create client: %v", err) - } - defer client.Close() - - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() - c, err := apiClientCh.Receive(ctx) - if err != nil { - t.Fatalf("timeout when waiting for API client to be created: %v", err) - } - apiClient := c.(*testAPIClient) - - var ldsUpdateChs []*testutils.Channel - const count = 2 - - // Two watches for the same name. - for i := 0; i < count; i++ { - ldsUpdateCh := testutils.NewChannel() - ldsUpdateChs = append(ldsUpdateChs, ldsUpdateCh) - client.WatchListener(testLDSName+"1", func(update ListenerUpdate, err error) { - ldsUpdateCh.Send(ldsUpdateErr{u: update, err: err}) - }) - - if i == 0 { - // A new watch is registered on the underlying API client only for - // the first iteration because we are using the same resource name. - if _, err := apiClient.addWatches[ListenerResource].Receive(ctx); err != nil { - t.Fatalf("want new watch to start, got error %v", err) - } - } - } - - // Third watch for a different name. - ldsUpdateCh2 := testutils.NewChannel() - client.WatchListener(testLDSName+"2", func(update ListenerUpdate, err error) { - ldsUpdateCh2.Send(ldsUpdateErr{u: update, err: err}) - }) - if _, err := apiClient.addWatches[ListenerResource].Receive(ctx); err != nil { - t.Fatalf("want new watch to start, got error %v", err) - } - - wantUpdate1 := ListenerUpdate{RouteConfigName: testRDSName + "1"} - wantUpdate2 := ListenerUpdate{RouteConfigName: testRDSName + "2"} - client.NewListeners(map[string]ListenerUpdate{ - testLDSName + "1": wantUpdate1, - testLDSName + "2": wantUpdate2, - }, UpdateMetadata{}) - - for i := 0; i < count; i++ { - if err := verifyListenerUpdate(ctx, ldsUpdateChs[i], wantUpdate1); err != nil { - t.Fatal(err) - } - } - if err := verifyListenerUpdate(ctx, ldsUpdateCh2, wantUpdate2); err != nil { - t.Fatal(err) - } -} - -// TestLDSWatchAfterCache covers the case where watch is called after the update -// is in cache. -func (s) TestLDSWatchAfterCache(t *testing.T) { - apiClientCh, cleanup := overrideNewAPIClient() - defer cleanup() - - client, err := newWithConfig(clientOpts(testXDSServer, false)) - if err != nil { - t.Fatalf("failed to create client: %v", err) - } - defer client.Close() - - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() - c, err := apiClientCh.Receive(ctx) - if err != nil { - t.Fatalf("timeout when waiting for API client to be created: %v", err) - } - apiClient := c.(*testAPIClient) - - ldsUpdateCh := testutils.NewChannel() - client.WatchListener(testLDSName, func(update ListenerUpdate, err error) { - ldsUpdateCh.Send(ldsUpdateErr{u: update, err: err}) - }) - if _, err := apiClient.addWatches[ListenerResource].Receive(ctx); err != nil { - t.Fatalf("want new watch to start, got error %v", err) - } - - wantUpdate := ListenerUpdate{RouteConfigName: testRDSName} - client.NewListeners(map[string]ListenerUpdate{testLDSName: wantUpdate}, UpdateMetadata{}) - if err := verifyListenerUpdate(ctx, ldsUpdateCh, wantUpdate); err != nil { - t.Fatal(err) - } - - // Another watch for the resource in cache. - ldsUpdateCh2 := testutils.NewChannel() - client.WatchListener(testLDSName, func(update ListenerUpdate, err error) { - ldsUpdateCh2.Send(ldsUpdateErr{u: update, err: err}) - }) - sCtx, sCancel := context.WithTimeout(ctx, defaultTestShortTimeout) - defer sCancel() - if n, err := apiClient.addWatches[ListenerResource].Receive(sCtx); err != context.DeadlineExceeded { - t.Fatalf("want no new watch to start (recv timeout), got resource name: %v error %v", n, err) - } - - // New watch should receive the update. - if err := verifyListenerUpdate(ctx, ldsUpdateCh2, wantUpdate); err != nil { - t.Fatal(err) - } - - // Old watch should see nothing. - sCtx, sCancel = context.WithTimeout(ctx, defaultTestShortTimeout) - defer sCancel() - if u, err := ldsUpdateCh.Receive(sCtx); err != context.DeadlineExceeded { - t.Errorf("unexpected ListenerUpdate: %v, %v, want channel recv timeout", u, err) - } -} - -// TestLDSResourceRemoved covers the cases: -// - an update is received after a watch() -// - another update is received, with one resource removed -// - this should trigger callback with resource removed error -// - one more update without the removed resource -// - the callback (above) shouldn't receive any update -func (s) TestLDSResourceRemoved(t *testing.T) { - apiClientCh, cleanup := overrideNewAPIClient() - defer cleanup() - - client, err := newWithConfig(clientOpts(testXDSServer, false)) - if err != nil { - t.Fatalf("failed to create client: %v", err) - } - defer client.Close() - - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() - c, err := apiClientCh.Receive(ctx) - if err != nil { - t.Fatalf("timeout when waiting for API client to be created: %v", err) - } - apiClient := c.(*testAPIClient) - - ldsUpdateCh1 := testutils.NewChannel() - client.WatchListener(testLDSName+"1", func(update ListenerUpdate, err error) { - ldsUpdateCh1.Send(ldsUpdateErr{u: update, err: err}) - }) - if _, err := apiClient.addWatches[ListenerResource].Receive(ctx); err != nil { - t.Fatalf("want new watch to start, got error %v", err) - } - // Another watch for a different name. - ldsUpdateCh2 := testutils.NewChannel() - client.WatchListener(testLDSName+"2", func(update ListenerUpdate, err error) { - ldsUpdateCh2.Send(ldsUpdateErr{u: update, err: err}) - }) - if _, err := apiClient.addWatches[ListenerResource].Receive(ctx); err != nil { - t.Fatalf("want new watch to start, got error %v", err) - } - - wantUpdate1 := ListenerUpdate{RouteConfigName: testEDSName + "1"} - wantUpdate2 := ListenerUpdate{RouteConfigName: testEDSName + "2"} - client.NewListeners(map[string]ListenerUpdate{ - testLDSName + "1": wantUpdate1, - testLDSName + "2": wantUpdate2, - }, UpdateMetadata{}) - if err := verifyListenerUpdate(ctx, ldsUpdateCh1, wantUpdate1); err != nil { - t.Fatal(err) - } - if err := verifyListenerUpdate(ctx, ldsUpdateCh2, wantUpdate2); err != nil { - t.Fatal(err) - } - - // Send another update to remove resource 1. - client.NewListeners(map[string]ListenerUpdate{testLDSName + "2": wantUpdate2}, UpdateMetadata{}) - - // Watcher 1 should get an error. - if u, err := ldsUpdateCh1.Receive(ctx); err != nil || ErrType(u.(ldsUpdateErr).err) != ErrorTypeResourceNotFound { - t.Errorf("unexpected ListenerUpdate: %v, error receiving from channel: %v, want update with error resource not found", u, err) - } - - // Watcher 2 should get the same update again. - if err := verifyListenerUpdate(ctx, ldsUpdateCh2, wantUpdate2); err != nil { - t.Fatal(err) - } - - // Send one more update without resource 1. - client.NewListeners(map[string]ListenerUpdate{testLDSName + "2": wantUpdate2}, UpdateMetadata{}) - - // Watcher 1 should not see an update. - sCtx, sCancel := context.WithTimeout(ctx, defaultTestShortTimeout) - defer sCancel() - if u, err := ldsUpdateCh1.Receive(sCtx); err != context.DeadlineExceeded { - t.Errorf("unexpected ListenerUpdate: %v, want receiving from channel timeout", u) - } - - // Watcher 2 should get the same update again. - if err := verifyListenerUpdate(ctx, ldsUpdateCh2, wantUpdate2); err != nil { - t.Fatal(err) - } -} diff --git a/xds/internal/client/xds.go b/xds/internal/client/xds.go deleted file mode 100644 index 78dc139e8fe..00000000000 --- a/xds/internal/client/xds.go +++ /dev/null @@ -1,915 +0,0 @@ -/* - * - * Copyright 2020 gRPC authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - */ - -package client - -import ( - "errors" - "fmt" - "net" - "strconv" - "strings" - "time" - - v1typepb "github.com/cncf/udpa/go/udpa/type/v1" - v3clusterpb "github.com/envoyproxy/go-control-plane/envoy/config/cluster/v3" - v3corepb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" - v3endpointpb "github.com/envoyproxy/go-control-plane/envoy/config/endpoint/v3" - v3listenerpb "github.com/envoyproxy/go-control-plane/envoy/config/listener/v3" - v3routepb "github.com/envoyproxy/go-control-plane/envoy/config/route/v3" - v3httppb "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/network/http_connection_manager/v3" - v3tlspb "github.com/envoyproxy/go-control-plane/envoy/extensions/transport_sockets/tls/v3" - v3typepb "github.com/envoyproxy/go-control-plane/envoy/type/v3" - "github.com/golang/protobuf/proto" - "github.com/golang/protobuf/ptypes" - "google.golang.org/protobuf/types/known/anypb" - - "google.golang.org/grpc/internal/grpclog" - "google.golang.org/grpc/xds/internal" - "google.golang.org/grpc/xds/internal/env" - "google.golang.org/grpc/xds/internal/httpfilter" - "google.golang.org/grpc/xds/internal/version" -) - -// TransportSocket proto message has a `name` field which is expected to be set -// to this value by the management server. -const transportSocketName = "envoy.transport_sockets.tls" - -// UnmarshalListener processes resources received in an LDS response, validates -// them, and transforms them into a native struct which contains only fields we -// are interested in. -func UnmarshalListener(version string, resources []*anypb.Any, logger *grpclog.PrefixLogger) (map[string]ListenerUpdate, UpdateMetadata, error) { - update := make(map[string]ListenerUpdate) - md, err := processAllResources(version, resources, logger, update) - return update, md, err -} - -func unmarshalListenerResource(r *anypb.Any, logger *grpclog.PrefixLogger) (string, ListenerUpdate, error) { - if !IsListenerResource(r.GetTypeUrl()) { - return "", ListenerUpdate{}, fmt.Errorf("unexpected resource type: %q ", r.GetTypeUrl()) - } - // TODO: Pass version.TransportAPI instead of relying upon the type URL - v2 := r.GetTypeUrl() == version.V2ListenerURL - lis := &v3listenerpb.Listener{} - if err := proto.Unmarshal(r.GetValue(), lis); err != nil { - return "", ListenerUpdate{}, fmt.Errorf("failed to unmarshal resource: %v", err) - } - logger.Infof("Resource with name: %v, type: %T, contains: %v", lis.GetName(), lis, lis) - - lu, err := processListener(lis, v2) - if err != nil { - return lis.GetName(), ListenerUpdate{}, err - } - lu.Raw = r - return lis.GetName(), *lu, nil -} - -func processListener(lis *v3listenerpb.Listener, v2 bool) (*ListenerUpdate, error) { - if lis.GetApiListener() != nil { - return processClientSideListener(lis, v2) - } - return processServerSideListener(lis) -} - -// processClientSideListener checks if the provided Listener proto meets -// the expected criteria. If so, it returns a non-empty routeConfigName. -func processClientSideListener(lis *v3listenerpb.Listener, v2 bool) (*ListenerUpdate, error) { - update := &ListenerUpdate{} - - apiLisAny := lis.GetApiListener().GetApiListener() - if !IsHTTPConnManagerResource(apiLisAny.GetTypeUrl()) { - return nil, fmt.Errorf("unexpected resource type: %q", apiLisAny.GetTypeUrl()) - } - apiLis := &v3httppb.HttpConnectionManager{} - if err := proto.Unmarshal(apiLisAny.GetValue(), apiLis); err != nil { - return nil, fmt.Errorf("failed to unmarshal api_listner: %v", err) - } - - switch apiLis.RouteSpecifier.(type) { - case *v3httppb.HttpConnectionManager_Rds: - if apiLis.GetRds().GetConfigSource().GetAds() == nil { - return nil, fmt.Errorf("ConfigSource is not ADS: %+v", lis) - } - name := apiLis.GetRds().GetRouteConfigName() - if name == "" { - return nil, fmt.Errorf("empty route_config_name: %+v", lis) - } - update.RouteConfigName = name - case *v3httppb.HttpConnectionManager_RouteConfig: - // TODO: Add support for specifying the RouteConfiguration inline - // in the LDS response. - return nil, fmt.Errorf("LDS response contains RDS config inline. Not supported for now: %+v", apiLis) - case nil: - return nil, fmt.Errorf("no RouteSpecifier: %+v", apiLis) - default: - return nil, fmt.Errorf("unsupported type %T for RouteSpecifier", apiLis.RouteSpecifier) - } - - if v2 { - return update, nil - } - - // The following checks and fields only apply to xDS protocol versions v3+. - - update.MaxStreamDuration = apiLis.GetCommonHttpProtocolOptions().GetMaxStreamDuration().AsDuration() - - var err error - if update.HTTPFilters, err = processHTTPFilters(apiLis.GetHttpFilters(), false); err != nil { - return nil, err - } - - return update, nil -} - -func unwrapHTTPFilterConfig(config *anypb.Any) (proto.Message, string, error) { - // The real type name is inside the TypedStruct. - s := new(v1typepb.TypedStruct) - if !ptypes.Is(config, s) { - return config, config.GetTypeUrl(), nil - } - if err := ptypes.UnmarshalAny(config, s); err != nil { - return nil, "", fmt.Errorf("error unmarshalling TypedStruct filter config: %v", err) - } - return s, s.GetTypeUrl(), nil -} - -func validateHTTPFilterConfig(cfg *anypb.Any, lds, optional bool) (httpfilter.Filter, httpfilter.FilterConfig, error) { - config, typeURL, err := unwrapHTTPFilterConfig(cfg) - if err != nil { - return nil, nil, err - } - filterBuilder := httpfilter.Get(typeURL) - if filterBuilder == nil { - if optional { - return nil, nil, nil - } - return nil, nil, fmt.Errorf("no filter implementation found for %q", typeURL) - } - parseFunc := filterBuilder.ParseFilterConfig - if !lds { - parseFunc = filterBuilder.ParseFilterConfigOverride - } - filterConfig, err := parseFunc(config) - if err != nil { - return nil, nil, fmt.Errorf("error parsing config for filter %q: %v", typeURL, err) - } - return filterBuilder, filterConfig, nil -} - -func processHTTPFilterOverrides(cfgs map[string]*anypb.Any) (map[string]httpfilter.FilterConfig, error) { - if !env.FaultInjectionSupport || len(cfgs) == 0 { - return nil, nil - } - m := make(map[string]httpfilter.FilterConfig) - for name, cfg := range cfgs { - optional := false - s := new(v3routepb.FilterConfig) - if ptypes.Is(cfg, s) { - if err := ptypes.UnmarshalAny(cfg, s); err != nil { - return nil, fmt.Errorf("filter override %q: error unmarshalling FilterConfig: %v", name, err) - } - cfg = s.GetConfig() - optional = s.GetIsOptional() - } - - httpFilter, config, err := validateHTTPFilterConfig(cfg, false, optional) - if err != nil { - return nil, fmt.Errorf("filter override %q: %v", name, err) - } - if httpFilter == nil { - // Optional configs are ignored. - continue - } - m[name] = config - } - return m, nil -} - -func processHTTPFilters(filters []*v3httppb.HttpFilter, server bool) ([]HTTPFilter, error) { - if !env.FaultInjectionSupport { - return nil, nil - } - - ret := make([]HTTPFilter, 0, len(filters)) - seenNames := make(map[string]bool, len(filters)) - for _, filter := range filters { - name := filter.GetName() - if name == "" { - return nil, errors.New("filter missing name field") - } - if seenNames[name] { - return nil, fmt.Errorf("duplicate filter name %q", name) - } - seenNames[name] = true - - httpFilter, config, err := validateHTTPFilterConfig(filter.GetTypedConfig(), true, filter.GetIsOptional()) - if err != nil { - return nil, err - } - if httpFilter == nil { - // Optional configs are ignored. - continue - } - if server { - if _, ok := httpFilter.(httpfilter.ServerInterceptorBuilder); !ok { - if filter.GetIsOptional() { - continue - } - return nil, fmt.Errorf("HTTP filter %q not supported server-side", name) - } - } else if _, ok := httpFilter.(httpfilter.ClientInterceptorBuilder); !ok { - if filter.GetIsOptional() { - continue - } - return nil, fmt.Errorf("HTTP filter %q not supported client-side", name) - } - - // Save name/config - ret = append(ret, HTTPFilter{Name: name, Filter: httpFilter, Config: config}) - } - return ret, nil -} - -func processServerSideListener(lis *v3listenerpb.Listener) (*ListenerUpdate, error) { - // Make sure that an address encoded in the received listener resource, and - // that it matches the one specified in the name. Listener names on the - // server-side as in the following format: - // grpc/server?udpa.resource.listening_address=IP:Port. - addr := lis.GetAddress() - if addr == nil { - return nil, fmt.Errorf("no address field in LDS response: %+v", lis) - } - sockAddr := addr.GetSocketAddress() - if sockAddr == nil { - return nil, fmt.Errorf("no socket_address field in LDS response: %+v", lis) - } - host, port, err := getAddressFromName(lis.GetName()) - if err != nil { - return nil, fmt.Errorf("no host:port in name field of LDS response: %+v, error: %v", lis, err) - } - if h := sockAddr.GetAddress(); host != h { - return nil, fmt.Errorf("socket_address host does not match the one in name. Got %q, want %q", h, host) - } - if p := strconv.Itoa(int(sockAddr.GetPortValue())); port != p { - return nil, fmt.Errorf("socket_address port does not match the one in name. Got %q, want %q", p, port) - } - - // Make sure the listener resource contains a single filter chain. We do not - // support multiple filter chains and picking the best match from the list. - fcs := lis.GetFilterChains() - if n := len(fcs); n != 1 { - return nil, fmt.Errorf("filter chains count in LDS response does not match expected. Got %d, want 1", n) - } - fc := fcs[0] - - // If the transport_socket field is not specified, it means that the control - // plane has not sent us any security config. This is fine and the server - // will use the fallback credentials configured as part of the - // xdsCredentials. - ts := fc.GetTransportSocket() - if ts == nil { - return &ListenerUpdate{}, nil - } - if name := ts.GetName(); name != transportSocketName { - return nil, fmt.Errorf("transport_socket field has unexpected name: %s", name) - } - any := ts.GetTypedConfig() - if any == nil || any.TypeUrl != version.V3DownstreamTLSContextURL { - return nil, fmt.Errorf("transport_socket field has unexpected typeURL: %s", any.TypeUrl) - } - downstreamCtx := &v3tlspb.DownstreamTlsContext{} - if err := proto.Unmarshal(any.GetValue(), downstreamCtx); err != nil { - return nil, fmt.Errorf("failed to unmarshal DownstreamTlsContext in LDS response: %v", err) - } - if downstreamCtx.GetCommonTlsContext() == nil { - return nil, errors.New("DownstreamTlsContext in LDS response does not contain a CommonTlsContext") - } - sc, err := securityConfigFromCommonTLSContext(downstreamCtx.GetCommonTlsContext()) - if err != nil { - return nil, err - } - if sc.IdentityInstanceName == "" { - return nil, errors.New("security configuration on the server-side does not contain identity certificate provider instance name") - } - sc.RequireClientCert = downstreamCtx.GetRequireClientCertificate().GetValue() - if sc.RequireClientCert && sc.RootInstanceName == "" { - return nil, errors.New("security configuration on the server-side does not contain root certificate provider instance name, but require_client_cert field is set") - } - return &ListenerUpdate{SecurityCfg: sc}, nil -} - -func getAddressFromName(name string) (host string, port string, err error) { - parts := strings.SplitN(name, "udpa.resource.listening_address=", 2) - if len(parts) != 2 { - return "", "", fmt.Errorf("udpa.resource_listening_address not found in name: %v", name) - } - return net.SplitHostPort(parts[1]) -} - -// UnmarshalRouteConfig processes resources received in an RDS response, -// validates them, and transforms them into a native struct which contains only -// fields we are interested in. The provided hostname determines the route -// configuration resources of interest. -func UnmarshalRouteConfig(version string, resources []*anypb.Any, logger *grpclog.PrefixLogger) (map[string]RouteConfigUpdate, UpdateMetadata, error) { - update := make(map[string]RouteConfigUpdate) - md, err := processAllResources(version, resources, logger, update) - return update, md, err -} - -func unmarshalRouteConfigResource(r *anypb.Any, logger *grpclog.PrefixLogger) (string, RouteConfigUpdate, error) { - if !IsRouteConfigResource(r.GetTypeUrl()) { - return "", RouteConfigUpdate{}, fmt.Errorf("unexpected resource type: %q ", r.GetTypeUrl()) - } - rc := &v3routepb.RouteConfiguration{} - if err := proto.Unmarshal(r.GetValue(), rc); err != nil { - return "", RouteConfigUpdate{}, fmt.Errorf("failed to unmarshal resource: %v", err) - } - logger.Infof("Resource with name: %v, type: %T, contains: %v.", rc.GetName(), rc, rc) - - // TODO: Pass version.TransportAPI instead of relying upon the type URL - v2 := r.GetTypeUrl() == version.V2RouteConfigURL - u, err := generateRDSUpdateFromRouteConfiguration(rc, logger, v2) - if err != nil { - return rc.GetName(), RouteConfigUpdate{}, err - } - u.Raw = r - return rc.GetName(), u, nil -} - -// generateRDSUpdateFromRouteConfiguration checks if the provided -// RouteConfiguration meets the expected criteria. If so, it returns a -// RouteConfigUpdate with nil error. -// -// A RouteConfiguration resource is considered valid when only if it contains a -// VirtualHost whose domain field matches the server name from the URI passed -// to the gRPC channel, and it contains a clusterName or a weighted cluster. -// -// The RouteConfiguration includes a list of VirtualHosts, which may have zero -// or more elements. We are interested in the element whose domains field -// matches the server name specified in the "xds:" URI. The only field in the -// VirtualHost proto that the we are interested in is the list of routes. We -// only look at the last route in the list (the default route), whose match -// field must be empty and whose route field must be set. Inside that route -// message, the cluster field will contain the clusterName or weighted clusters -// we are looking for. -func generateRDSUpdateFromRouteConfiguration(rc *v3routepb.RouteConfiguration, logger *grpclog.PrefixLogger, v2 bool) (RouteConfigUpdate, error) { - var vhs []*VirtualHost - for _, vh := range rc.GetVirtualHosts() { - routes, err := routesProtoToSlice(vh.Routes, logger, v2) - if err != nil { - return RouteConfigUpdate{}, fmt.Errorf("received route is invalid: %v", err) - } - vhOut := &VirtualHost{ - Domains: vh.GetDomains(), - Routes: routes, - } - if !v2 { - cfgs, err := processHTTPFilterOverrides(vh.GetTypedPerFilterConfig()) - if err != nil { - return RouteConfigUpdate{}, fmt.Errorf("virtual host %+v: %v", vh, err) - } - vhOut.HTTPFilterConfigOverride = cfgs - } - vhs = append(vhs, vhOut) - } - return RouteConfigUpdate{VirtualHosts: vhs}, nil -} - -func routesProtoToSlice(routes []*v3routepb.Route, logger *grpclog.PrefixLogger, v2 bool) ([]*Route, error) { - var routesRet []*Route - - for _, r := range routes { - match := r.GetMatch() - if match == nil { - return nil, fmt.Errorf("route %+v doesn't have a match", r) - } - - if len(match.GetQueryParameters()) != 0 { - // Ignore route with query parameters. - logger.Warningf("route %+v has query parameter matchers, the route will be ignored", r) - continue - } - - pathSp := match.GetPathSpecifier() - if pathSp == nil { - return nil, fmt.Errorf("route %+v doesn't have a path specifier", r) - } - - var route Route - switch pt := pathSp.(type) { - case *v3routepb.RouteMatch_Prefix: - route.Prefix = &pt.Prefix - case *v3routepb.RouteMatch_Path: - route.Path = &pt.Path - case *v3routepb.RouteMatch_SafeRegex: - route.Regex = &pt.SafeRegex.Regex - default: - return nil, fmt.Errorf("route %+v has an unrecognized path specifier: %+v", r, pt) - } - - if caseSensitive := match.GetCaseSensitive(); caseSensitive != nil { - route.CaseInsensitive = !caseSensitive.Value - } - - for _, h := range match.GetHeaders() { - var header HeaderMatcher - switch ht := h.GetHeaderMatchSpecifier().(type) { - case *v3routepb.HeaderMatcher_ExactMatch: - header.ExactMatch = &ht.ExactMatch - case *v3routepb.HeaderMatcher_SafeRegexMatch: - header.RegexMatch = &ht.SafeRegexMatch.Regex - case *v3routepb.HeaderMatcher_RangeMatch: - header.RangeMatch = &Int64Range{ - Start: ht.RangeMatch.Start, - End: ht.RangeMatch.End, - } - case *v3routepb.HeaderMatcher_PresentMatch: - header.PresentMatch = &ht.PresentMatch - case *v3routepb.HeaderMatcher_PrefixMatch: - header.PrefixMatch = &ht.PrefixMatch - case *v3routepb.HeaderMatcher_SuffixMatch: - header.SuffixMatch = &ht.SuffixMatch - default: - return nil, fmt.Errorf("route %+v has an unrecognized header matcher: %+v", r, ht) - } - header.Name = h.GetName() - invert := h.GetInvertMatch() - header.InvertMatch = &invert - route.Headers = append(route.Headers, &header) - } - - if fr := match.GetRuntimeFraction(); fr != nil { - d := fr.GetDefaultValue() - n := d.GetNumerator() - switch d.GetDenominator() { - case v3typepb.FractionalPercent_HUNDRED: - n *= 10000 - case v3typepb.FractionalPercent_TEN_THOUSAND: - n *= 100 - case v3typepb.FractionalPercent_MILLION: - } - route.Fraction = &n - } - - route.WeightedClusters = make(map[string]WeightedCluster) - action := r.GetRoute() - switch a := action.GetClusterSpecifier().(type) { - case *v3routepb.RouteAction_Cluster: - route.WeightedClusters[a.Cluster] = WeightedCluster{Weight: 1} - case *v3routepb.RouteAction_WeightedClusters: - wcs := a.WeightedClusters - var totalWeight uint32 - for _, c := range wcs.Clusters { - w := c.GetWeight().GetValue() - if w == 0 { - continue - } - wc := WeightedCluster{Weight: w} - if !v2 { - cfgs, err := processHTTPFilterOverrides(c.GetTypedPerFilterConfig()) - if err != nil { - return nil, fmt.Errorf("route %+v, action %+v: %v", r, a, err) - } - wc.HTTPFilterConfigOverride = cfgs - } - route.WeightedClusters[c.GetName()] = wc - totalWeight += w - } - if totalWeight != wcs.GetTotalWeight().GetValue() { - return nil, fmt.Errorf("route %+v, action %+v, weights of clusters do not add up to total total weight, got: %v, want %v", r, a, wcs.GetTotalWeight().GetValue(), totalWeight) - } - if totalWeight == 0 { - return nil, fmt.Errorf("route %+v, action %+v, has no valid cluster in WeightedCluster action", r, a) - } - case *v3routepb.RouteAction_ClusterHeader: - continue - } - - msd := action.GetMaxStreamDuration() - // Prefer grpc_timeout_header_max, if set. - dur := msd.GetGrpcTimeoutHeaderMax() - if dur == nil { - dur = msd.GetMaxStreamDuration() - } - if dur != nil { - d := dur.AsDuration() - route.MaxStreamDuration = &d - } - - if !v2 { - cfgs, err := processHTTPFilterOverrides(r.GetTypedPerFilterConfig()) - if err != nil { - return nil, fmt.Errorf("route %+v: %v", r, err) - } - route.HTTPFilterConfigOverride = cfgs - } - routesRet = append(routesRet, &route) - } - return routesRet, nil -} - -// UnmarshalCluster processes resources received in an CDS response, validates -// them, and transforms them into a native struct which contains only fields we -// are interested in. -func UnmarshalCluster(version string, resources []*anypb.Any, logger *grpclog.PrefixLogger) (map[string]ClusterUpdate, UpdateMetadata, error) { - update := make(map[string]ClusterUpdate) - md, err := processAllResources(version, resources, logger, update) - return update, md, err -} - -func unmarshalClusterResource(r *anypb.Any, logger *grpclog.PrefixLogger) (string, ClusterUpdate, error) { - if !IsClusterResource(r.GetTypeUrl()) { - return "", ClusterUpdate{}, fmt.Errorf("unexpected resource type: %q ", r.GetTypeUrl()) - } - - cluster := &v3clusterpb.Cluster{} - if err := proto.Unmarshal(r.GetValue(), cluster); err != nil { - return "", ClusterUpdate{}, fmt.Errorf("failed to unmarshal resource: %v", err) - } - logger.Infof("Resource with name: %v, type: %T, contains: %v", cluster.GetName(), cluster, cluster) - - cu, err := validateCluster(cluster) - if err != nil { - return cluster.GetName(), ClusterUpdate{}, err - } - cu.Raw = r - // If the Cluster message in the CDS response did not contain a - // serviceName, we will just use the clusterName for EDS. - if cu.ServiceName == "" { - cu.ServiceName = cluster.GetName() - } - return cluster.GetName(), cu, nil -} - -func validateCluster(cluster *v3clusterpb.Cluster) (ClusterUpdate, error) { - emptyUpdate := ClusterUpdate{ServiceName: "", EnableLRS: false} - switch { - case cluster.GetType() != v3clusterpb.Cluster_EDS: - return emptyUpdate, fmt.Errorf("unexpected cluster type %v in response: %+v", cluster.GetType(), cluster) - case cluster.GetEdsClusterConfig().GetEdsConfig().GetAds() == nil: - return emptyUpdate, fmt.Errorf("unexpected edsConfig in response: %+v", cluster) - case cluster.GetLbPolicy() != v3clusterpb.Cluster_ROUND_ROBIN: - return emptyUpdate, fmt.Errorf("unexpected lbPolicy %v in response: %+v", cluster.GetLbPolicy(), cluster) - } - - // Process security configuration received from the control plane iff the - // corresponding environment variable is set. - var sc *SecurityConfig - if env.ClientSideSecuritySupport { - var err error - if sc, err = securityConfigFromCluster(cluster); err != nil { - return emptyUpdate, err - } - } - - return ClusterUpdate{ - ServiceName: cluster.GetEdsClusterConfig().GetServiceName(), - EnableLRS: cluster.GetLrsServer().GetSelf() != nil, - SecurityCfg: sc, - MaxRequests: circuitBreakersFromCluster(cluster), - }, nil -} - -// securityConfigFromCluster extracts the relevant security configuration from -// the received Cluster resource. -func securityConfigFromCluster(cluster *v3clusterpb.Cluster) (*SecurityConfig, error) { - // The Cluster resource contains a `transport_socket` field, which contains - // a oneof `typed_config` field of type `protobuf.Any`. The any proto - // contains a marshaled representation of an `UpstreamTlsContext` message. - ts := cluster.GetTransportSocket() - if ts == nil { - return nil, nil - } - if name := ts.GetName(); name != transportSocketName { - return nil, fmt.Errorf("transport_socket field has unexpected name: %s", name) - } - any := ts.GetTypedConfig() - if any == nil || any.TypeUrl != version.V3UpstreamTLSContextURL { - return nil, fmt.Errorf("transport_socket field has unexpected typeURL: %s", any.TypeUrl) - } - upstreamCtx := &v3tlspb.UpstreamTlsContext{} - if err := proto.Unmarshal(any.GetValue(), upstreamCtx); err != nil { - return nil, fmt.Errorf("failed to unmarshal UpstreamTlsContext in CDS response: %v", err) - } - if upstreamCtx.GetCommonTlsContext() == nil { - return nil, errors.New("UpstreamTlsContext in CDS response does not contain a CommonTlsContext") - } - - sc, err := securityConfigFromCommonTLSContext(upstreamCtx.GetCommonTlsContext()) - if err != nil { - return nil, err - } - if sc.RootInstanceName == "" { - return nil, errors.New("security configuration on the client-side does not contain root certificate provider instance name") - } - return sc, nil -} - -// common is expected to be not nil. -func securityConfigFromCommonTLSContext(common *v3tlspb.CommonTlsContext) (*SecurityConfig, error) { - // The `CommonTlsContext` contains a - // `tls_certificate_certificate_provider_instance` field of type - // `CertificateProviderInstance`, which contains the provider instance name - // and the certificate name to fetch identity certs. - sc := &SecurityConfig{} - if identity := common.GetTlsCertificateCertificateProviderInstance(); identity != nil { - sc.IdentityInstanceName = identity.GetInstanceName() - sc.IdentityCertName = identity.GetCertificateName() - } - - // The `CommonTlsContext` contains a `validation_context_type` field which - // is a oneof. We can get the values that we are interested in from two of - // those possible values: - // - combined validation context: - // - contains a default validation context which holds the list of - // accepted SANs. - // - contains certificate provider instance configuration - // - certificate provider instance configuration - // - in this case, we do not get a list of accepted SANs. - switch t := common.GetValidationContextType().(type) { - case *v3tlspb.CommonTlsContext_CombinedValidationContext: - combined := common.GetCombinedValidationContext() - if def := combined.GetDefaultValidationContext(); def != nil { - for _, matcher := range def.GetMatchSubjectAltNames() { - // We only support exact matches for now. - if exact := matcher.GetExact(); exact != "" { - sc.AcceptedSANs = append(sc.AcceptedSANs, exact) - } - } - } - if pi := combined.GetValidationContextCertificateProviderInstance(); pi != nil { - sc.RootInstanceName = pi.GetInstanceName() - sc.RootCertName = pi.GetCertificateName() - } - case *v3tlspb.CommonTlsContext_ValidationContextCertificateProviderInstance: - pi := common.GetValidationContextCertificateProviderInstance() - sc.RootInstanceName = pi.GetInstanceName() - sc.RootCertName = pi.GetCertificateName() - case nil: - // It is valid for the validation context to be nil on the server side. - default: - return nil, fmt.Errorf("validation context contains unexpected type: %T", t) - } - return sc, nil -} - -// circuitBreakersFromCluster extracts the circuit breakers configuration from -// the received cluster resource. Returns nil if no CircuitBreakers or no -// Thresholds in CircuitBreakers. -func circuitBreakersFromCluster(cluster *v3clusterpb.Cluster) *uint32 { - if !env.CircuitBreakingSupport { - return nil - } - for _, threshold := range cluster.GetCircuitBreakers().GetThresholds() { - if threshold.GetPriority() != v3corepb.RoutingPriority_DEFAULT { - continue - } - maxRequestsPb := threshold.GetMaxRequests() - if maxRequestsPb == nil { - return nil - } - maxRequests := maxRequestsPb.GetValue() - return &maxRequests - } - return nil -} - -// UnmarshalEndpoints processes resources received in an EDS response, -// validates them, and transforms them into a native struct which contains only -// fields we are interested in. -func UnmarshalEndpoints(version string, resources []*anypb.Any, logger *grpclog.PrefixLogger) (map[string]EndpointsUpdate, UpdateMetadata, error) { - update := make(map[string]EndpointsUpdate) - md, err := processAllResources(version, resources, logger, update) - return update, md, err -} - -func unmarshalEndpointsResource(r *anypb.Any, logger *grpclog.PrefixLogger) (string, EndpointsUpdate, error) { - if !IsEndpointsResource(r.GetTypeUrl()) { - return "", EndpointsUpdate{}, fmt.Errorf("unexpected resource type: %q ", r.GetTypeUrl()) - } - - cla := &v3endpointpb.ClusterLoadAssignment{} - if err := proto.Unmarshal(r.GetValue(), cla); err != nil { - return "", EndpointsUpdate{}, fmt.Errorf("failed to unmarshal resource: %v", err) - } - logger.Infof("Resource with name: %v, type: %T, contains: %v", cla.GetClusterName(), cla, cla) - - u, err := parseEDSRespProto(cla) - if err != nil { - return cla.GetClusterName(), EndpointsUpdate{}, err - } - u.Raw = r - return cla.GetClusterName(), u, nil -} - -func parseAddress(socketAddress *v3corepb.SocketAddress) string { - return net.JoinHostPort(socketAddress.GetAddress(), strconv.Itoa(int(socketAddress.GetPortValue()))) -} - -func parseDropPolicy(dropPolicy *v3endpointpb.ClusterLoadAssignment_Policy_DropOverload) OverloadDropConfig { - percentage := dropPolicy.GetDropPercentage() - var ( - numerator = percentage.GetNumerator() - denominator uint32 - ) - switch percentage.GetDenominator() { - case v3typepb.FractionalPercent_HUNDRED: - denominator = 100 - case v3typepb.FractionalPercent_TEN_THOUSAND: - denominator = 10000 - case v3typepb.FractionalPercent_MILLION: - denominator = 1000000 - } - return OverloadDropConfig{ - Category: dropPolicy.GetCategory(), - Numerator: numerator, - Denominator: denominator, - } -} - -func parseEndpoints(lbEndpoints []*v3endpointpb.LbEndpoint) []Endpoint { - endpoints := make([]Endpoint, 0, len(lbEndpoints)) - for _, lbEndpoint := range lbEndpoints { - endpoints = append(endpoints, Endpoint{ - HealthStatus: EndpointHealthStatus(lbEndpoint.GetHealthStatus()), - Address: parseAddress(lbEndpoint.GetEndpoint().GetAddress().GetSocketAddress()), - Weight: lbEndpoint.GetLoadBalancingWeight().GetValue(), - }) - } - return endpoints -} - -func parseEDSRespProto(m *v3endpointpb.ClusterLoadAssignment) (EndpointsUpdate, error) { - ret := EndpointsUpdate{} - for _, dropPolicy := range m.GetPolicy().GetDropOverloads() { - ret.Drops = append(ret.Drops, parseDropPolicy(dropPolicy)) - } - priorities := make(map[uint32]struct{}) - for _, locality := range m.Endpoints { - l := locality.GetLocality() - if l == nil { - return EndpointsUpdate{}, fmt.Errorf("EDS response contains a locality without ID, locality: %+v", locality) - } - lid := internal.LocalityID{ - Region: l.Region, - Zone: l.Zone, - SubZone: l.SubZone, - } - priority := locality.GetPriority() - priorities[priority] = struct{}{} - ret.Localities = append(ret.Localities, Locality{ - ID: lid, - Endpoints: parseEndpoints(locality.GetLbEndpoints()), - Weight: locality.GetLoadBalancingWeight().GetValue(), - Priority: priority, - }) - } - for i := 0; i < len(priorities); i++ { - if _, ok := priorities[uint32(i)]; !ok { - return EndpointsUpdate{}, fmt.Errorf("priority %v missing (with different priorities %v received)", i, priorities) - } - } - return ret, nil -} - -// processAllResources unmarshals and validates the resources, populates the -// provided ret (a map), and returns metadata and error. -// -// The type of the resource is determined by the type of ret. E.g. -// map[string]ListenerUpdate means this is for LDS. -func processAllResources(version string, resources []*anypb.Any, logger *grpclog.PrefixLogger, ret interface{}) (UpdateMetadata, error) { - timestamp := time.Now() - md := UpdateMetadata{ - Version: version, - Timestamp: timestamp, - } - var topLevelErrors []error - perResourceErrors := make(map[string]error) - - for _, r := range resources { - switch ret2 := ret.(type) { - case map[string]ListenerUpdate: - name, update, err := unmarshalListenerResource(r, logger) - if err == nil { - ret2[name] = update - continue - } - if name == "" { - topLevelErrors = append(topLevelErrors, err) - continue - } - perResourceErrors[name] = err - // Add place holder in the map so we know this resource name was in - // the response. - ret2[name] = ListenerUpdate{} - case map[string]RouteConfigUpdate: - name, update, err := unmarshalRouteConfigResource(r, logger) - if err == nil { - ret2[name] = update - continue - } - if name == "" { - topLevelErrors = append(topLevelErrors, err) - continue - } - perResourceErrors[name] = err - // Add place holder in the map so we know this resource name was in - // the response. - ret2[name] = RouteConfigUpdate{} - case map[string]ClusterUpdate: - name, update, err := unmarshalClusterResource(r, logger) - if err == nil { - ret2[name] = update - continue - } - if name == "" { - topLevelErrors = append(topLevelErrors, err) - continue - } - perResourceErrors[name] = err - // Add place holder in the map so we know this resource name was in - // the response. - ret2[name] = ClusterUpdate{} - case map[string]EndpointsUpdate: - name, update, err := unmarshalEndpointsResource(r, logger) - if err == nil { - ret2[name] = update - continue - } - if name == "" { - topLevelErrors = append(topLevelErrors, err) - continue - } - perResourceErrors[name] = err - // Add place holder in the map so we know this resource name was in - // the response. - ret2[name] = EndpointsUpdate{} - } - } - - if len(topLevelErrors) == 0 && len(perResourceErrors) == 0 { - md.Status = ServiceStatusACKed - return md, nil - } - - var typeStr string - switch ret.(type) { - case map[string]ListenerUpdate: - typeStr = "LDS" - case map[string]RouteConfigUpdate: - typeStr = "RDS" - case map[string]ClusterUpdate: - typeStr = "CDS" - case map[string]EndpointsUpdate: - typeStr = "EDS" - } - - md.Status = ServiceStatusNACKed - errRet := combineErrors(typeStr, topLevelErrors, perResourceErrors) - md.ErrState = &UpdateErrorMetadata{ - Version: version, - Err: errRet, - Timestamp: timestamp, - } - return md, errRet -} - -func combineErrors(rType string, topLevelErrors []error, perResourceErrors map[string]error) error { - var errStrB strings.Builder - errStrB.WriteString(fmt.Sprintf("error parsing %q response: ", rType)) - if len(topLevelErrors) > 0 { - errStrB.WriteString("top level errors: ") - for i, err := range topLevelErrors { - if i != 0 { - errStrB.WriteString(";\n") - } - errStrB.WriteString(err.Error()) - } - } - if len(perResourceErrors) > 0 { - var i int - for name, err := range perResourceErrors { - if i != 0 { - errStrB.WriteString(";\n") - } - i++ - errStrB.WriteString(fmt.Sprintf("resource %q: %v", name, err.Error())) - } - } - return errors.New(errStrB.String()) -} diff --git a/xds/internal/httpfilter/fault/fault.go b/xds/internal/httpfilter/fault/fault.go new file mode 100644 index 00000000000..725b50a76a8 --- /dev/null +++ b/xds/internal/httpfilter/fault/fault.go @@ -0,0 +1,301 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// Package fault implements the Envoy Fault Injection HTTP filter. +package fault + +import ( + "context" + "errors" + "fmt" + "io" + "strconv" + "sync/atomic" + "time" + + "github.com/golang/protobuf/proto" + "github.com/golang/protobuf/ptypes" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/internal/grpcrand" + iresolver "google.golang.org/grpc/internal/resolver" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" + "google.golang.org/grpc/xds/internal/httpfilter" + "google.golang.org/protobuf/types/known/anypb" + + cpb "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/common/fault/v3" + fpb "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/http/fault/v3" + tpb "github.com/envoyproxy/go-control-plane/envoy/type/v3" +) + +const headerAbortHTTPStatus = "x-envoy-fault-abort-request" +const headerAbortGRPCStatus = "x-envoy-fault-abort-grpc-request" +const headerAbortPercentage = "x-envoy-fault-abort-request-percentage" + +const headerDelayPercentage = "x-envoy-fault-delay-request-percentage" +const headerDelayDuration = "x-envoy-fault-delay-request" + +var statusMap = map[int]codes.Code{ + 400: codes.Internal, + 401: codes.Unauthenticated, + 403: codes.PermissionDenied, + 404: codes.Unimplemented, + 429: codes.Unavailable, + 502: codes.Unavailable, + 503: codes.Unavailable, + 504: codes.Unavailable, +} + +func init() { + httpfilter.Register(builder{}) +} + +type builder struct { +} + +type config struct { + httpfilter.FilterConfig + config *fpb.HTTPFault +} + +func (builder) TypeURLs() []string { + return []string{"type.googleapis.com/envoy.extensions.filters.http.fault.v3.HTTPFault"} +} + +// Parsing is the same for the base config and the override config. +func parseConfig(cfg proto.Message) (httpfilter.FilterConfig, error) { + if cfg == nil { + return nil, fmt.Errorf("fault: nil configuration message provided") + } + any, ok := cfg.(*anypb.Any) + if !ok { + return nil, fmt.Errorf("fault: error parsing config %v: unknown type %T", cfg, cfg) + } + msg := new(fpb.HTTPFault) + if err := ptypes.UnmarshalAny(any, msg); err != nil { + return nil, fmt.Errorf("fault: error parsing config %v: %v", cfg, err) + } + return config{config: msg}, nil +} + +func (builder) ParseFilterConfig(cfg proto.Message) (httpfilter.FilterConfig, error) { + return parseConfig(cfg) +} + +func (builder) ParseFilterConfigOverride(override proto.Message) (httpfilter.FilterConfig, error) { + return parseConfig(override) +} + +func (builder) IsTerminal() bool { + return false +} + +var _ httpfilter.ClientInterceptorBuilder = builder{} + +func (builder) BuildClientInterceptor(cfg, override httpfilter.FilterConfig) (iresolver.ClientInterceptor, error) { + if cfg == nil { + return nil, fmt.Errorf("fault: nil config provided") + } + + c, ok := cfg.(config) + if !ok { + return nil, fmt.Errorf("fault: incorrect config type provided (%T): %v", cfg, cfg) + } + + if override != nil { + // override completely replaces the listener configuration; but we + // still validate the listener config type. + c, ok = override.(config) + if !ok { + return nil, fmt.Errorf("fault: incorrect override config type provided (%T): %v", override, override) + } + } + + icfg := c.config + if (icfg.GetMaxActiveFaults() != nil && icfg.GetMaxActiveFaults().GetValue() == 0) || + (icfg.GetDelay() == nil && icfg.GetAbort() == nil) { + return nil, nil + } + return &interceptor{config: icfg}, nil +} + +type interceptor struct { + config *fpb.HTTPFault +} + +var activeFaults uint32 // global active faults; accessed atomically + +func (i *interceptor) NewStream(ctx context.Context, ri iresolver.RPCInfo, done func(), newStream func(ctx context.Context, done func()) (iresolver.ClientStream, error)) (iresolver.ClientStream, error) { + if maxAF := i.config.GetMaxActiveFaults(); maxAF != nil { + defer atomic.AddUint32(&activeFaults, ^uint32(0)) // decrement counter + if af := atomic.AddUint32(&activeFaults, 1); af > maxAF.GetValue() { + // Would exceed maximum active fault limit. + return newStream(ctx, done) + } + } + + if err := injectDelay(ctx, i.config.GetDelay()); err != nil { + return nil, err + } + + if err := injectAbort(ctx, i.config.GetAbort()); err != nil { + if err == errOKStream { + return &okStream{ctx: ctx}, nil + } + return nil, err + } + return newStream(ctx, done) +} + +// For overriding in tests +var randIntn = grpcrand.Intn +var newTimer = time.NewTimer + +func injectDelay(ctx context.Context, delayCfg *cpb.FaultDelay) error { + numerator, denominator := splitPct(delayCfg.GetPercentage()) + var delay time.Duration + switch delayType := delayCfg.GetFaultDelaySecifier().(type) { + case *cpb.FaultDelay_FixedDelay: + delay = delayType.FixedDelay.AsDuration() + case *cpb.FaultDelay_HeaderDelay_: + md, _ := metadata.FromOutgoingContext(ctx) + v := md[headerDelayDuration] + if v == nil { + // No delay configured for this RPC. + return nil + } + ms, ok := parseIntFromMD(v) + if !ok { + // Malformed header; no delay. + return nil + } + delay = time.Duration(ms) * time.Millisecond + if v := md[headerDelayPercentage]; v != nil { + if num, ok := parseIntFromMD(v); ok && num < numerator { + numerator = num + } + } + } + if delay == 0 || randIntn(denominator) >= numerator { + return nil + } + t := newTimer(delay) + select { + case <-t.C: + case <-ctx.Done(): + t.Stop() + return ctx.Err() + } + return nil +} + +func injectAbort(ctx context.Context, abortCfg *fpb.FaultAbort) error { + numerator, denominator := splitPct(abortCfg.GetPercentage()) + code := codes.OK + okCode := false + switch errType := abortCfg.GetErrorType().(type) { + case *fpb.FaultAbort_HttpStatus: + code, okCode = grpcFromHTTP(int(errType.HttpStatus)) + case *fpb.FaultAbort_GrpcStatus: + code, okCode = sanitizeGRPCCode(codes.Code(errType.GrpcStatus)), true + case *fpb.FaultAbort_HeaderAbort_: + md, _ := metadata.FromOutgoingContext(ctx) + if v := md[headerAbortHTTPStatus]; v != nil { + // HTTP status has priority over gRPC status. + if httpStatus, ok := parseIntFromMD(v); ok { + code, okCode = grpcFromHTTP(httpStatus) + } + } else if v := md[headerAbortGRPCStatus]; v != nil { + if grpcStatus, ok := parseIntFromMD(v); ok { + code, okCode = sanitizeGRPCCode(codes.Code(grpcStatus)), true + } + } + if v := md[headerAbortPercentage]; v != nil { + if num, ok := parseIntFromMD(v); ok && num < numerator { + numerator = num + } + } + } + if !okCode || randIntn(denominator) >= numerator { + return nil + } + if code == codes.OK { + return errOKStream + } + return status.Errorf(code, "RPC terminated due to fault injection") +} + +var errOKStream = errors.New("stream terminated early with OK status") + +// parseIntFromMD returns the integer in the last header or nil if parsing +// failed. +func parseIntFromMD(header []string) (int, bool) { + if len(header) == 0 { + return 0, false + } + v, err := strconv.Atoi(header[len(header)-1]) + return v, err == nil +} + +func splitPct(fp *tpb.FractionalPercent) (num int, den int) { + if fp == nil { + return 0, 100 + } + num = int(fp.GetNumerator()) + switch fp.GetDenominator() { + case tpb.FractionalPercent_HUNDRED: + return num, 100 + case tpb.FractionalPercent_TEN_THOUSAND: + return num, 10 * 1000 + case tpb.FractionalPercent_MILLION: + return num, 1000 * 1000 + } + return num, 100 +} + +func grpcFromHTTP(httpStatus int) (codes.Code, bool) { + if httpStatus < 200 || httpStatus >= 600 { + // Malformed; ignore this fault type. + return codes.OK, false + } + if c := statusMap[httpStatus]; c != codes.OK { + // OK = 0/the default for the map. + return c, true + } + // All undefined HTTP status codes convert to Unknown. HTTP status of 200 + // is "success", but gRPC converts to Unknown due to missing grpc status. + return codes.Unknown, true +} + +func sanitizeGRPCCode(c codes.Code) codes.Code { + if c > codes.Code(16) { + return codes.Unknown + } + return c +} + +type okStream struct { + ctx context.Context +} + +func (*okStream) Header() (metadata.MD, error) { return nil, nil } +func (*okStream) Trailer() metadata.MD { return nil } +func (*okStream) CloseSend() error { return nil } +func (o *okStream) Context() context.Context { return o.ctx } +func (*okStream) SendMsg(m interface{}) error { return io.EOF } +func (*okStream) RecvMsg(m interface{}) error { return io.EOF } diff --git a/xds/internal/httpfilter/fault/fault_test.go b/xds/internal/httpfilter/fault/fault_test.go new file mode 100644 index 00000000000..c2959054da9 --- /dev/null +++ b/xds/internal/httpfilter/fault/fault_test.go @@ -0,0 +1,672 @@ +//go:build !386 +// +build !386 + +/* + * + * Copyright 2020 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// Package xds_test contains e2e tests for xDS use. +package fault + +import ( + "context" + "fmt" + "io" + "net" + "reflect" + "testing" + "time" + + "github.com/golang/protobuf/ptypes" + "github.com/google/uuid" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/internal/grpcrand" + "google.golang.org/grpc/internal/grpctest" + "google.golang.org/grpc/internal/testutils" + "google.golang.org/grpc/internal/xds" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" + xtestutils "google.golang.org/grpc/xds/internal/testutils" + "google.golang.org/grpc/xds/internal/testutils/e2e" + "google.golang.org/protobuf/types/known/wrapperspb" + + v3listenerpb "github.com/envoyproxy/go-control-plane/envoy/config/listener/v3" + cpb "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/common/fault/v3" + fpb "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/http/fault/v3" + v3httppb "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/network/http_connection_manager/v3" + tpb "github.com/envoyproxy/go-control-plane/envoy/type/v3" + testpb "google.golang.org/grpc/test/grpc_testing" + + _ "google.golang.org/grpc/xds/internal/balancer" // Register the balancers. + _ "google.golang.org/grpc/xds/internal/resolver" // Register the xds_resolver. + _ "google.golang.org/grpc/xds/internal/xdsclient/v3" // Register the v3 xDS API client. +) + +const defaultTestTimeout = 10 * time.Second + +type s struct { + grpctest.Tester +} + +func Test(t *testing.T) { + grpctest.RunSubTests(t, s{}) +} + +type testService struct { + testpb.TestServiceServer +} + +func (*testService) EmptyCall(context.Context, *testpb.Empty) (*testpb.Empty, error) { + return &testpb.Empty{}, nil +} + +func (*testService) FullDuplexCall(stream testpb.TestService_FullDuplexCallServer) error { + // End RPC after client does a CloseSend. + for { + if _, err := stream.Recv(); err == io.EOF { + return nil + } else if err != nil { + return err + } + } +} + +// clientSetup performs a bunch of steps common to all xDS server tests here: +// - spin up an xDS management server on a local port +// - spin up a gRPC server and register the test service on it +// - create a local TCP listener and start serving on it +// +// Returns the following: +// - the management server: tests use this to configure resources +// - nodeID expected by the management server: this is set in the Node proto +// sent by the xdsClient for queries. +// - the port the server is listening on +// - cleanup function to be invoked by the tests when done +func clientSetup(t *testing.T) (*e2e.ManagementServer, string, uint32, func()) { + // Spin up a xDS management server on a local port. + nodeID := uuid.New().String() + fs, err := e2e.StartManagementServer() + if err != nil { + t.Fatal(err) + } + + // Create a bootstrap file in a temporary directory. + bootstrapCleanup, err := xds.SetupBootstrapFile(xds.BootstrapOptions{ + Version: xds.TransportV3, + NodeID: nodeID, + ServerURI: fs.Address, + ServerListenerResourceNameTemplate: "grpc/server", + }) + if err != nil { + t.Fatal(err) + } + + // Initialize a gRPC server and register the stubServer on it. + server := grpc.NewServer() + testpb.RegisterTestServiceServer(server, &testService{}) + + // Create a local listener and pass it to Serve(). + lis, err := xtestutils.LocalTCPListener() + if err != nil { + t.Fatalf("xtestutils.LocalTCPListener() failed: %v", err) + } + + go func() { + if err := server.Serve(lis); err != nil { + t.Errorf("Serve() failed: %v", err) + } + }() + + return fs, nodeID, uint32(lis.Addr().(*net.TCPAddr).Port), func() { + fs.Stop() + bootstrapCleanup() + server.Stop() + } +} + +func (s) TestFaultInjection_Unary(t *testing.T) { + type subcase struct { + name string + code codes.Code + repeat int + randIn []int // Intn calls per-repeat (not per-subcase) + delays []time.Duration // NewTimer calls per-repeat (not per-subcase) + md metadata.MD + } + testCases := []struct { + name string + cfgs []*fpb.HTTPFault + randOutInc int + want []subcase + }{{ + name: "max faults zero", + cfgs: []*fpb.HTTPFault{{ + MaxActiveFaults: wrapperspb.UInt32(0), + Abort: &fpb.FaultAbort{ + Percentage: &tpb.FractionalPercent{Numerator: 100, Denominator: tpb.FractionalPercent_HUNDRED}, + ErrorType: &fpb.FaultAbort_GrpcStatus{GrpcStatus: uint32(codes.Aborted)}, + }, + }}, + randOutInc: 5, + want: []subcase{{ + code: codes.OK, + repeat: 25, + }}, + }, { + name: "no abort or delay", + cfgs: []*fpb.HTTPFault{{}}, + randOutInc: 5, + want: []subcase{{ + code: codes.OK, + repeat: 25, + }}, + }, { + name: "abort always", + cfgs: []*fpb.HTTPFault{{ + Abort: &fpb.FaultAbort{ + Percentage: &tpb.FractionalPercent{Numerator: 100, Denominator: tpb.FractionalPercent_HUNDRED}, + ErrorType: &fpb.FaultAbort_GrpcStatus{GrpcStatus: uint32(codes.Aborted)}, + }, + }}, + randOutInc: 5, + want: []subcase{{ + code: codes.Aborted, + randIn: []int{100}, + repeat: 25, + }}, + }, { + name: "abort 10%", + cfgs: []*fpb.HTTPFault{{ + Abort: &fpb.FaultAbort{ + Percentage: &tpb.FractionalPercent{Numerator: 100000, Denominator: tpb.FractionalPercent_MILLION}, + ErrorType: &fpb.FaultAbort_GrpcStatus{GrpcStatus: uint32(codes.Aborted)}, + }, + }}, + randOutInc: 50000, + want: []subcase{{ + name: "[0,10]%", + code: codes.Aborted, + randIn: []int{1000000}, + repeat: 2, + }, { + name: "(10,100]%", + code: codes.OK, + randIn: []int{1000000}, + repeat: 18, + }, { + name: "[0,10]% again", + code: codes.Aborted, + randIn: []int{1000000}, + repeat: 2, + }}, + }, { + name: "delay always", + cfgs: []*fpb.HTTPFault{{ + Delay: &cpb.FaultDelay{ + Percentage: &tpb.FractionalPercent{Numerator: 100, Denominator: tpb.FractionalPercent_HUNDRED}, + FaultDelaySecifier: &cpb.FaultDelay_FixedDelay{FixedDelay: ptypes.DurationProto(time.Second)}, + }, + }}, + randOutInc: 5, + want: []subcase{{ + randIn: []int{100}, + repeat: 25, + delays: []time.Duration{time.Second}, + }}, + }, { + name: "delay 10%", + cfgs: []*fpb.HTTPFault{{ + Delay: &cpb.FaultDelay{ + Percentage: &tpb.FractionalPercent{Numerator: 1000, Denominator: tpb.FractionalPercent_TEN_THOUSAND}, + FaultDelaySecifier: &cpb.FaultDelay_FixedDelay{FixedDelay: ptypes.DurationProto(time.Second)}, + }, + }}, + randOutInc: 500, + want: []subcase{{ + name: "[0,10]%", + randIn: []int{10000}, + repeat: 2, + delays: []time.Duration{time.Second}, + }, { + name: "(10,100]%", + randIn: []int{10000}, + repeat: 18, + }, { + name: "[0,10]% again", + randIn: []int{10000}, + repeat: 2, + delays: []time.Duration{time.Second}, + }}, + }, { + name: "delay 80%, abort 50%", + cfgs: []*fpb.HTTPFault{{ + Delay: &cpb.FaultDelay{ + Percentage: &tpb.FractionalPercent{Numerator: 80, Denominator: tpb.FractionalPercent_HUNDRED}, + FaultDelaySecifier: &cpb.FaultDelay_FixedDelay{FixedDelay: ptypes.DurationProto(3 * time.Second)}, + }, + Abort: &fpb.FaultAbort{ + Percentage: &tpb.FractionalPercent{Numerator: 50, Denominator: tpb.FractionalPercent_HUNDRED}, + ErrorType: &fpb.FaultAbort_GrpcStatus{GrpcStatus: uint32(codes.Unimplemented)}, + }, + }}, + randOutInc: 5, + want: []subcase{{ + name: "50% delay and abort", + code: codes.Unimplemented, + randIn: []int{100, 100}, + repeat: 10, + delays: []time.Duration{3 * time.Second}, + }, { + name: "30% delay, no abort", + randIn: []int{100, 100}, + repeat: 6, + delays: []time.Duration{3 * time.Second}, + }, { + name: "20% success", + randIn: []int{100, 100}, + repeat: 4, + }, { + name: "50% delay and abort again", + code: codes.Unimplemented, + randIn: []int{100, 100}, + repeat: 10, + delays: []time.Duration{3 * time.Second}, + }}, + }, { + name: "header abort", + cfgs: []*fpb.HTTPFault{{ + Abort: &fpb.FaultAbort{ + Percentage: &tpb.FractionalPercent{Numerator: 80, Denominator: tpb.FractionalPercent_HUNDRED}, + ErrorType: &fpb.FaultAbort_HeaderAbort_{}, + }, + }}, + randOutInc: 10, + want: []subcase{{ + name: "30% abort; [0,30]%", + md: metadata.MD{ + headerAbortGRPCStatus: []string{fmt.Sprintf("%d", codes.DataLoss)}, + headerAbortPercentage: []string{"30"}, + }, + code: codes.DataLoss, + randIn: []int{100}, + repeat: 3, + }, { + name: "30% abort; (30,60]%", + md: metadata.MD{ + headerAbortGRPCStatus: []string{fmt.Sprintf("%d", codes.DataLoss)}, + headerAbortPercentage: []string{"30"}, + }, + randIn: []int{100}, + repeat: 3, + }, { + name: "80% abort; (60,80]%", + md: metadata.MD{ + headerAbortGRPCStatus: []string{fmt.Sprintf("%d", codes.DataLoss)}, + headerAbortPercentage: []string{"80"}, + }, + code: codes.DataLoss, + randIn: []int{100}, + repeat: 2, + }, { + name: "cannot exceed percentage in filter", + md: metadata.MD{ + headerAbortGRPCStatus: []string{fmt.Sprintf("%d", codes.DataLoss)}, + headerAbortPercentage: []string{"100"}, + }, + randIn: []int{100}, + repeat: 2, + }, { + name: "HTTP Status 404", + md: metadata.MD{ + headerAbortHTTPStatus: []string{"404"}, + headerAbortPercentage: []string{"100"}, + }, + code: codes.Unimplemented, + randIn: []int{100}, + repeat: 1, + }, { + name: "HTTP Status 429", + md: metadata.MD{ + headerAbortHTTPStatus: []string{"429"}, + headerAbortPercentage: []string{"100"}, + }, + code: codes.Unavailable, + randIn: []int{100}, + repeat: 1, + }, { + name: "HTTP Status 200", + md: metadata.MD{ + headerAbortHTTPStatus: []string{"200"}, + headerAbortPercentage: []string{"100"}, + }, + // No GRPC status, but HTTP Status of 200 translates to Unknown, + // per spec in statuscodes.md. + code: codes.Unknown, + randIn: []int{100}, + repeat: 1, + }, { + name: "gRPC Status OK", + md: metadata.MD{ + headerAbortGRPCStatus: []string{fmt.Sprintf("%d", codes.OK)}, + headerAbortPercentage: []string{"100"}, + }, + // This should be Unimplemented (mismatched request/response + // count), per spec in statuscodes.md, but grpc-go currently + // returns io.EOF which status.Code() converts to Unknown + code: codes.Unknown, + randIn: []int{100}, + repeat: 1, + }, { + name: "invalid header results in no abort", + md: metadata.MD{ + headerAbortGRPCStatus: []string{"error"}, + headerAbortPercentage: []string{"100"}, + }, + repeat: 1, + }, { + name: "invalid header results in default percentage", + md: metadata.MD{ + headerAbortGRPCStatus: []string{fmt.Sprintf("%d", codes.DataLoss)}, + headerAbortPercentage: []string{"error"}, + }, + code: codes.DataLoss, + randIn: []int{100}, + repeat: 1, + }}, + }, { + name: "header delay", + cfgs: []*fpb.HTTPFault{{ + Delay: &cpb.FaultDelay{ + Percentage: &tpb.FractionalPercent{Numerator: 80, Denominator: tpb.FractionalPercent_HUNDRED}, + FaultDelaySecifier: &cpb.FaultDelay_HeaderDelay_{}, + }, + }}, + randOutInc: 10, + want: []subcase{{ + name: "30% delay; [0,30]%", + md: metadata.MD{ + headerDelayDuration: []string{"2"}, + headerDelayPercentage: []string{"30"}, + }, + randIn: []int{100}, + delays: []time.Duration{2 * time.Millisecond}, + repeat: 3, + }, { + name: "30% delay; (30, 60]%", + md: metadata.MD{ + headerDelayDuration: []string{"2"}, + headerDelayPercentage: []string{"30"}, + }, + randIn: []int{100}, + repeat: 3, + }, { + name: "invalid header results in no delay", + md: metadata.MD{ + headerDelayDuration: []string{"error"}, + headerDelayPercentage: []string{"80"}, + }, + repeat: 1, + }, { + name: "invalid header results in default percentage", + md: metadata.MD{ + headerDelayDuration: []string{"2"}, + headerDelayPercentage: []string{"error"}, + }, + randIn: []int{100}, + delays: []time.Duration{2 * time.Millisecond}, + repeat: 1, + }, { + name: "invalid header results in default percentage", + md: metadata.MD{ + headerDelayDuration: []string{"2"}, + headerDelayPercentage: []string{"error"}, + }, + randIn: []int{100}, + repeat: 1, + }, { + name: "cannot exceed percentage in filter", + md: metadata.MD{ + headerDelayDuration: []string{"2"}, + headerDelayPercentage: []string{"100"}, + }, + randIn: []int{100}, + repeat: 1, + }}, + }, { + name: "abort then delay filters", + cfgs: []*fpb.HTTPFault{{ + Abort: &fpb.FaultAbort{ + Percentage: &tpb.FractionalPercent{Numerator: 50, Denominator: tpb.FractionalPercent_HUNDRED}, + ErrorType: &fpb.FaultAbort_GrpcStatus{GrpcStatus: uint32(codes.Unimplemented)}, + }, + }, { + Delay: &cpb.FaultDelay{ + Percentage: &tpb.FractionalPercent{Numerator: 80, Denominator: tpb.FractionalPercent_HUNDRED}, + FaultDelaySecifier: &cpb.FaultDelay_FixedDelay{FixedDelay: ptypes.DurationProto(time.Second)}, + }, + }}, + randOutInc: 10, + want: []subcase{{ + name: "50% delay and abort (abort skips delay)", + code: codes.Unimplemented, + randIn: []int{100}, + repeat: 5, + }, { + name: "30% delay, no abort", + randIn: []int{100, 100}, + repeat: 3, + delays: []time.Duration{time.Second}, + }, { + name: "20% success", + randIn: []int{100, 100}, + repeat: 2, + }}, + }} + + fs, nodeID, port, cleanup := clientSetup(t) + defer cleanup() + + for tcNum, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + defer func() { randIntn = grpcrand.Intn; newTimer = time.NewTimer }() + var intnCalls []int + var newTimerCalls []time.Duration + randOut := 0 + randIntn = func(n int) int { + intnCalls = append(intnCalls, n) + return randOut % n + } + + newTimer = func(d time.Duration) *time.Timer { + newTimerCalls = append(newTimerCalls, d) + return time.NewTimer(0) + } + + serviceName := fmt.Sprintf("myservice%d", tcNum) + resources := e2e.DefaultClientResources(e2e.ResourceParams{ + DialTarget: serviceName, + NodeID: nodeID, + Host: "localhost", + Port: port, + SecLevel: e2e.SecurityLevelNone, + }) + hcm := new(v3httppb.HttpConnectionManager) + err := ptypes.UnmarshalAny(resources.Listeners[0].GetApiListener().GetApiListener(), hcm) + if err != nil { + t.Fatal(err) + } + routerFilter := hcm.HttpFilters[len(hcm.HttpFilters)-1] + + hcm.HttpFilters = nil + for i, cfg := range tc.cfgs { + hcm.HttpFilters = append(hcm.HttpFilters, e2e.HTTPFilter(fmt.Sprintf("fault%d", i), cfg)) + } + hcm.HttpFilters = append(hcm.HttpFilters, routerFilter) + hcmAny := testutils.MarshalAny(hcm) + resources.Listeners[0].ApiListener.ApiListener = hcmAny + resources.Listeners[0].FilterChains[0].Filters[0].ConfigType = &v3listenerpb.Filter_TypedConfig{TypedConfig: hcmAny} + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + if err := fs.Update(ctx, resources); err != nil { + t.Fatal(err) + } + + // Create a ClientConn and run the test case. + cc, err := grpc.Dial("xds:///"+serviceName, grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + t.Fatalf("failed to dial local test server: %v", err) + } + defer cc.Close() + + client := testpb.NewTestServiceClient(cc) + count := 0 + for _, want := range tc.want { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + if want.repeat == 0 { + t.Fatalf("invalid repeat count") + } + for n := 0; n < want.repeat; n++ { + intnCalls = nil + newTimerCalls = nil + ctx = metadata.NewOutgoingContext(ctx, want.md) + _, err := client.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(true)) + t.Logf("%v: RPC %d: err: %v, intnCalls: %v, newTimerCalls: %v", want.name, count, err, intnCalls, newTimerCalls) + if status.Code(err) != want.code || !reflect.DeepEqual(intnCalls, want.randIn) || !reflect.DeepEqual(newTimerCalls, want.delays) { + t.Fatalf("WANTED code: %v, intnCalls: %v, newTimerCalls: %v", want.code, want.randIn, want.delays) + } + randOut += tc.randOutInc + count++ + } + } + }) + } +} + +func (s) TestFaultInjection_MaxActiveFaults(t *testing.T) { + fs, nodeID, port, cleanup := clientSetup(t) + defer cleanup() + resources := e2e.DefaultClientResources(e2e.ResourceParams{ + DialTarget: "myservice", + NodeID: nodeID, + Host: "localhost", + Port: port, + SecLevel: e2e.SecurityLevelNone, + }) + hcm := new(v3httppb.HttpConnectionManager) + err := ptypes.UnmarshalAny(resources.Listeners[0].GetApiListener().GetApiListener(), hcm) + if err != nil { + t.Fatal(err) + } + + defer func() { newTimer = time.NewTimer }() + timers := make(chan *time.Timer, 2) + newTimer = func(d time.Duration) *time.Timer { + t := time.NewTimer(24 * time.Hour) // Will reset to fire. + timers <- t + return t + } + + hcm.HttpFilters = append([]*v3httppb.HttpFilter{ + e2e.HTTPFilter("fault", &fpb.HTTPFault{ + MaxActiveFaults: wrapperspb.UInt32(2), + Delay: &cpb.FaultDelay{ + Percentage: &tpb.FractionalPercent{Numerator: 100, Denominator: tpb.FractionalPercent_HUNDRED}, + FaultDelaySecifier: &cpb.FaultDelay_FixedDelay{FixedDelay: ptypes.DurationProto(time.Second)}, + }, + })}, + hcm.HttpFilters...) + hcmAny := testutils.MarshalAny(hcm) + resources.Listeners[0].ApiListener.ApiListener = hcmAny + resources.Listeners[0].FilterChains[0].Filters[0].ConfigType = &v3listenerpb.Filter_TypedConfig{TypedConfig: hcmAny} + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + if err := fs.Update(ctx, resources); err != nil { + t.Fatal(err) + } + + // Create a ClientConn + cc, err := grpc.Dial("xds:///myservice", grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + t.Fatalf("failed to dial local test server: %v", err) + } + defer cc.Close() + + client := testpb.NewTestServiceClient(cc) + + streams := make(chan testpb.TestService_FullDuplexCallClient, 5) // startStream() is called 5 times + startStream := func() { + str, err := client.FullDuplexCall(ctx) + if err != nil { + t.Error("RPC error:", err) + } + streams <- str + } + endStream := func() { + str := <-streams + str.CloseSend() + if _, err := str.Recv(); err != io.EOF { + t.Error("stream error:", err) + } + } + releaseStream := func() { + timer := <-timers + timer.Reset(0) + } + + // Start three streams; two should delay. + go startStream() + go startStream() + go startStream() + + // End one of the streams. Ensure the others are blocked on creation. + endStream() + + select { + case <-streams: + t.Errorf("unexpected second stream created before delay expires") + case <-time.After(50 * time.Millisecond): + // Wait a short time to ensure no other streams were started yet. + } + + // Start one more; it should not be blocked. + go startStream() + endStream() + + // Expire one stream's delay; it should be created. + releaseStream() + endStream() + + // Another new stream should delay. + go startStream() + select { + case <-streams: + t.Errorf("unexpected second stream created before delay expires") + case <-time.After(50 * time.Millisecond): + // Wait a short time to ensure no other streams were started yet. + } + + // Expire both pending timers and end the two streams. + releaseStream() + releaseStream() + endStream() + endStream() +} diff --git a/xds/internal/httpfilter/httpfilter.go b/xds/internal/httpfilter/httpfilter.go index 6650241fab7..b4399f9faeb 100644 --- a/xds/internal/httpfilter/httpfilter.go +++ b/xds/internal/httpfilter/httpfilter.go @@ -50,6 +50,9 @@ type Filter interface { // not accept a custom type. The resulting FilterConfig will later be // passed to Build. ParseFilterConfigOverride(proto.Message) (FilterConfig, error) + // IsTerminal returns whether this Filter is terminal or not (i.e. it must + // be last filter in the filter chain). + IsTerminal() bool } // ClientInterceptorBuilder constructs a Client Interceptor. If this type is @@ -65,9 +68,6 @@ type ClientInterceptorBuilder interface { // ServerInterceptorBuilder constructs a Server Interceptor. If this type is // implemented by a Filter, it is capable of working on a server. -// -// Server side filters are not currently supported, but this interface is -// defined for clarity. type ServerInterceptorBuilder interface { // BuildServerInterceptor uses the FilterConfigs produced above to produce // an HTTP filter interceptor for servers. config will always be non-nil, @@ -94,6 +94,11 @@ func Register(b Filter) { } } +// UnregisterForTesting unregisters the HTTP Filter for testing purposes. +func UnregisterForTesting(typeURL string) { + delete(m, typeURL) +} + // Get returns the HTTPFilter registered with typeURL. // // If no filter is register with typeURL, nil will be returned. diff --git a/xds/internal/httpfilter/rbac/rbac.go b/xds/internal/httpfilter/rbac/rbac.go new file mode 100644 index 00000000000..e92e2e64421 --- /dev/null +++ b/xds/internal/httpfilter/rbac/rbac.go @@ -0,0 +1,220 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// Package rbac implements the Envoy RBAC HTTP filter. +package rbac + +import ( + "context" + "errors" + "fmt" + "strings" + + "github.com/golang/protobuf/proto" + "github.com/golang/protobuf/ptypes" + "google.golang.org/grpc/internal/resolver" + "google.golang.org/grpc/internal/xds/env" + "google.golang.org/grpc/internal/xds/rbac" + "google.golang.org/grpc/xds/internal/httpfilter" + "google.golang.org/protobuf/types/known/anypb" + + v3rbacpb "github.com/envoyproxy/go-control-plane/envoy/config/rbac/v3" + rpb "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/http/rbac/v3" +) + +func init() { + if env.RBACSupport { + httpfilter.Register(builder{}) + } +} + +// RegisterForTesting registers the RBAC HTTP Filter for testing purposes, regardless +// of the RBAC environment variable. This is needed because there is no way to set the RBAC +// environment variable to true in a test before init() in this package is run. +func RegisterForTesting() { + httpfilter.Register(builder{}) +} + +// UnregisterForTesting unregisters the RBAC HTTP Filter for testing purposes. This is needed because +// there is no way to unregister the HTTP Filter after registering it solely for testing purposes using +// rbac.RegisterForTesting() +func UnregisterForTesting() { + for _, typeURL := range builder.TypeURLs(builder{}) { + httpfilter.UnregisterForTesting(typeURL) + } +} + +type builder struct { +} + +type config struct { + httpfilter.FilterConfig + config *rpb.RBAC +} + +func (builder) TypeURLs() []string { + return []string{ + "type.googleapis.com/envoy.extensions.filters.http.rbac.v3.RBAC", + "type.googleapis.com/envoy.extensions.filters.http.rbac.v3.RBACPerRoute", + } +} + +// Parsing is the same for the base config and the override config. +func parseConfig(rbacCfg *rpb.RBAC) (httpfilter.FilterConfig, error) { + // All the validation logic described in A41. + for _, policy := range rbacCfg.GetRules().GetPolicies() { + // "Policy.condition and Policy.checked_condition must cause a + // validation failure if present." - A41 + if policy.Condition != nil { + return nil, errors.New("rbac: Policy.condition is present") + } + if policy.CheckedCondition != nil { + return nil, errors.New("rbac: policy.CheckedCondition is present") + } + + // "It is also a validation failure if Permission or Principal has a + // header matcher for a grpc- prefixed header name or :scheme." - A41 + for _, principal := range policy.Principals { + if principal.GetHeader() != nil { + name := principal.GetHeader().GetName() + if name == ":scheme" || strings.HasPrefix(name, "grpc-") { + return nil, fmt.Errorf("rbac: principal header matcher for %v is :scheme or starts with grpc", name) + } + } + } + for _, permission := range policy.Permissions { + if permission.GetHeader() != nil { + name := permission.GetHeader().GetName() + if name == ":scheme" || strings.HasPrefix(name, "grpc-") { + return nil, fmt.Errorf("rbac: permission header matcher for %v is :scheme or starts with grpc", name) + } + } + } + } + return config{config: rbacCfg}, nil +} + +func (builder) ParseFilterConfig(cfg proto.Message) (httpfilter.FilterConfig, error) { + if cfg == nil { + return nil, fmt.Errorf("rbac: nil configuration message provided") + } + any, ok := cfg.(*anypb.Any) + if !ok { + return nil, fmt.Errorf("rbac: error parsing config %v: unknown type %T", cfg, cfg) + } + msg := new(rpb.RBAC) + if err := ptypes.UnmarshalAny(any, msg); err != nil { + return nil, fmt.Errorf("rbac: error parsing config %v: %v", cfg, err) + } + return parseConfig(msg) +} + +func (builder) ParseFilterConfigOverride(override proto.Message) (httpfilter.FilterConfig, error) { + if override == nil { + return nil, fmt.Errorf("rbac: nil configuration message provided") + } + any, ok := override.(*anypb.Any) + if !ok { + return nil, fmt.Errorf("rbac: error parsing override config %v: unknown type %T", override, override) + } + msg := new(rpb.RBACPerRoute) + if err := ptypes.UnmarshalAny(any, msg); err != nil { + return nil, fmt.Errorf("rbac: error parsing override config %v: %v", override, err) + } + return parseConfig(msg.Rbac) +} + +func (builder) IsTerminal() bool { + return false +} + +var _ httpfilter.ServerInterceptorBuilder = builder{} + +// BuildServerInterceptor is an optional interface builder implements in order +// to signify it works server side. +func (builder) BuildServerInterceptor(cfg httpfilter.FilterConfig, override httpfilter.FilterConfig) (resolver.ServerInterceptor, error) { + if cfg == nil { + return nil, fmt.Errorf("rbac: nil config provided") + } + + c, ok := cfg.(config) + if !ok { + return nil, fmt.Errorf("rbac: incorrect config type provided (%T): %v", cfg, cfg) + } + + if override != nil { + // override completely replaces the listener configuration; but we + // still validate the listener config type. + c, ok = override.(config) + if !ok { + return nil, fmt.Errorf("rbac: incorrect override config type provided (%T): %v", override, override) + } + } + + icfg := c.config + // "If absent, no enforcing RBAC policy will be applied" - RBAC + // Documentation for Rules field. + if icfg.Rules == nil { + return nil, nil + } + + // "At this time, if the RBAC.action is Action.LOG then the policy will be + // completely ignored, as if RBAC was not configurated." - A41 + if icfg.Rules.Action == v3rbacpb.RBAC_LOG { + return nil, nil + } + + // "Envoy aliases :authority and Host in its header map implementation, so + // they should be treated equivalent for the RBAC matchers; there must be no + // behavior change depending on which of the two header names is used in the + // RBAC policy." - A41. Loop through config's principals and policies, change + // any header matcher with value "host" to :authority", as that is what + // grpc-go shifts both headers to in transport layer. + for _, policy := range icfg.Rules.GetPolicies() { + for _, principal := range policy.Principals { + if principal.GetHeader() != nil { + name := principal.GetHeader().GetName() + if name == "host" { + principal.GetHeader().Name = ":authority" + } + } + } + for _, permission := range policy.Permissions { + if permission.GetHeader() != nil { + name := permission.GetHeader().GetName() + if name == "host" { + permission.GetHeader().Name = ":authority" + } + } + } + } + + ce, err := rbac.NewChainEngine([]*v3rbacpb.RBAC{icfg.Rules}) + if err != nil { + return nil, fmt.Errorf("error constructing matching engine: %v", err) + } + return &interceptor{chainEngine: ce}, nil +} + +type interceptor struct { + chainEngine *rbac.ChainEngine +} + +func (i *interceptor) AllowRPC(ctx context.Context) error { + return i.chainEngine.IsAuthorized(ctx) +} diff --git a/xds/internal/httpfilter/router/router.go b/xds/internal/httpfilter/router/router.go index 26e3acb5a4f..1ac6518170f 100644 --- a/xds/internal/httpfilter/router/router.go +++ b/xds/internal/httpfilter/router/router.go @@ -73,7 +73,14 @@ func (builder) ParseFilterConfigOverride(override proto.Message) (httpfilter.Fil return config{}, nil } -var _ httpfilter.ClientInterceptorBuilder = builder{} +func (builder) IsTerminal() bool { + return true +} + +var ( + _ httpfilter.ClientInterceptorBuilder = builder{} + _ httpfilter.ServerInterceptorBuilder = builder{} +) func (builder) BuildClientInterceptor(cfg, override httpfilter.FilterConfig) (iresolver.ClientInterceptor, error) { if _, ok := cfg.(config); !ok { @@ -88,6 +95,18 @@ func (builder) BuildClientInterceptor(cfg, override httpfilter.FilterConfig) (ir return nil, nil } +func (builder) BuildServerInterceptor(cfg, override httpfilter.FilterConfig) (iresolver.ServerInterceptor, error) { + if _, ok := cfg.(config); !ok { + return nil, fmt.Errorf("router: incorrect config type provided (%T): %v", cfg, cfg) + } + if override != nil { + return nil, fmt.Errorf("router: unexpected override configuration specified: %v", override) + } + // The gRPC router is currently unimplemented on the server side. So we + // return a nil HTTPFilter, which will not be invoked. + return nil, nil +} + // The gRPC router filter does not currently support any configuration. Verify // type only. type config struct { diff --git a/xds/internal/internal.go b/xds/internal/internal.go index e4284ee02e0..0cccd382410 100644 --- a/xds/internal/internal.go +++ b/xds/internal/internal.go @@ -22,6 +22,8 @@ package internal import ( "encoding/json" "fmt" + + "google.golang.org/grpc/resolver" ) // LocalityID is xds.Locality without XXX fields, so it can be used as map @@ -53,3 +55,19 @@ func LocalityIDFromString(s string) (ret LocalityID, _ error) { } return ret, nil } + +type localityKeyType string + +const localityKey = localityKeyType("grpc.xds.internal.address.locality") + +// GetLocalityID returns the locality ID of addr. +func GetLocalityID(addr resolver.Address) LocalityID { + path, _ := addr.Attributes.Value(localityKey).(LocalityID) + return path +} + +// SetLocalityID sets locality ID in addr to l. +func SetLocalityID(addr resolver.Address, l LocalityID) resolver.Address { + addr.Attributes = addr.Attributes.WithValues(localityKey, l) + return addr +} diff --git a/xds/internal/resolver/matcher.go b/xds/internal/resolver/matcher.go deleted file mode 100644 index b7b5f3db0e3..00000000000 --- a/xds/internal/resolver/matcher.go +++ /dev/null @@ -1,161 +0,0 @@ -/* - * - * Copyright 2020 gRPC authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - */ - -package resolver - -import ( - "fmt" - "regexp" - "strings" - - "google.golang.org/grpc/internal/grpcrand" - "google.golang.org/grpc/internal/grpcutil" - iresolver "google.golang.org/grpc/internal/resolver" - "google.golang.org/grpc/metadata" - xdsclient "google.golang.org/grpc/xds/internal/client" -) - -func routeToMatcher(r *xdsclient.Route) (*compositeMatcher, error) { - var pathMatcher pathMatcherInterface - switch { - case r.Regex != nil: - re, err := regexp.Compile(*r.Regex) - if err != nil { - return nil, fmt.Errorf("failed to compile regex %q", *r.Regex) - } - pathMatcher = newPathRegexMatcher(re) - case r.Path != nil: - pathMatcher = newPathExactMatcher(*r.Path, r.CaseInsensitive) - case r.Prefix != nil: - pathMatcher = newPathPrefixMatcher(*r.Prefix, r.CaseInsensitive) - default: - return nil, fmt.Errorf("illegal route: missing path_matcher") - } - - var headerMatchers []headerMatcherInterface - for _, h := range r.Headers { - var matcherT headerMatcherInterface - switch { - case h.ExactMatch != nil && *h.ExactMatch != "": - matcherT = newHeaderExactMatcher(h.Name, *h.ExactMatch) - case h.RegexMatch != nil && *h.RegexMatch != "": - re, err := regexp.Compile(*h.RegexMatch) - if err != nil { - return nil, fmt.Errorf("failed to compile regex %q, skipping this matcher", *h.RegexMatch) - } - matcherT = newHeaderRegexMatcher(h.Name, re) - case h.PrefixMatch != nil && *h.PrefixMatch != "": - matcherT = newHeaderPrefixMatcher(h.Name, *h.PrefixMatch) - case h.SuffixMatch != nil && *h.SuffixMatch != "": - matcherT = newHeaderSuffixMatcher(h.Name, *h.SuffixMatch) - case h.RangeMatch != nil: - matcherT = newHeaderRangeMatcher(h.Name, h.RangeMatch.Start, h.RangeMatch.End) - case h.PresentMatch != nil: - matcherT = newHeaderPresentMatcher(h.Name, *h.PresentMatch) - default: - return nil, fmt.Errorf("illegal route: missing header_match_specifier") - } - if h.InvertMatch != nil && *h.InvertMatch { - matcherT = newInvertMatcher(matcherT) - } - headerMatchers = append(headerMatchers, matcherT) - } - - var fractionMatcher *fractionMatcher - if r.Fraction != nil { - fractionMatcher = newFractionMatcher(*r.Fraction) - } - return newCompositeMatcher(pathMatcher, headerMatchers, fractionMatcher), nil -} - -// compositeMatcher.match returns true if all matchers return true. -type compositeMatcher struct { - pm pathMatcherInterface - hms []headerMatcherInterface - fm *fractionMatcher -} - -func newCompositeMatcher(pm pathMatcherInterface, hms []headerMatcherInterface, fm *fractionMatcher) *compositeMatcher { - return &compositeMatcher{pm: pm, hms: hms, fm: fm} -} - -func (a *compositeMatcher) match(info iresolver.RPCInfo) bool { - if a.pm != nil && !a.pm.match(info.Method) { - return false - } - - // Call headerMatchers even if md is nil, because routes may match - // non-presence of some headers. - var md metadata.MD - if info.Context != nil { - md, _ = metadata.FromOutgoingContext(info.Context) - if extraMD, ok := grpcutil.ExtraMetadata(info.Context); ok { - md = metadata.Join(md, extraMD) - // Remove all binary headers. They are hard to match with. May need - // to add back if asked by users. - for k := range md { - if strings.HasSuffix(k, "-bin") { - delete(md, k) - } - } - } - } - for _, m := range a.hms { - if !m.match(md) { - return false - } - } - - if a.fm != nil && !a.fm.match() { - return false - } - return true -} - -func (a *compositeMatcher) String() string { - var ret string - if a.pm != nil { - ret += a.pm.String() - } - for _, m := range a.hms { - ret += m.String() - } - if a.fm != nil { - ret += a.fm.String() - } - return ret -} - -type fractionMatcher struct { - fraction int64 // real fraction is fraction/1,000,000. -} - -func newFractionMatcher(fraction uint32) *fractionMatcher { - return &fractionMatcher{fraction: int64(fraction)} -} - -var grpcrandInt63n = grpcrand.Int63n - -func (fm *fractionMatcher) match() bool { - t := grpcrandInt63n(1000000) - return t <= fm.fraction -} - -func (fm *fractionMatcher) String() string { - return fmt.Sprintf("fraction:%v", fm.fraction) -} diff --git a/xds/internal/resolver/matcher_header.go b/xds/internal/resolver/matcher_header.go deleted file mode 100644 index 05a92788d7b..00000000000 --- a/xds/internal/resolver/matcher_header.go +++ /dev/null @@ -1,188 +0,0 @@ -/* - * - * Copyright 2020 gRPC authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - */ - -package resolver - -import ( - "fmt" - "regexp" - "strconv" - "strings" - - "google.golang.org/grpc/metadata" -) - -type headerMatcherInterface interface { - match(metadata.MD) bool - String() string -} - -// mdValuesFromOutgoingCtx retrieves metadata from context. If there are -// multiple values, the values are concatenated with "," (comma and no space). -// -// All header matchers only match against the comma-concatenated string. -func mdValuesFromOutgoingCtx(md metadata.MD, key string) (string, bool) { - vs, ok := md[key] - if !ok { - return "", false - } - return strings.Join(vs, ","), true -} - -type headerExactMatcher struct { - key string - exact string -} - -func newHeaderExactMatcher(key, exact string) *headerExactMatcher { - return &headerExactMatcher{key: key, exact: exact} -} - -func (hem *headerExactMatcher) match(md metadata.MD) bool { - v, ok := mdValuesFromOutgoingCtx(md, hem.key) - if !ok { - return false - } - return v == hem.exact -} - -func (hem *headerExactMatcher) String() string { - return fmt.Sprintf("headerExact:%v:%v", hem.key, hem.exact) -} - -type headerRegexMatcher struct { - key string - re *regexp.Regexp -} - -func newHeaderRegexMatcher(key string, re *regexp.Regexp) *headerRegexMatcher { - return &headerRegexMatcher{key: key, re: re} -} - -func (hrm *headerRegexMatcher) match(md metadata.MD) bool { - v, ok := mdValuesFromOutgoingCtx(md, hrm.key) - if !ok { - return false - } - return hrm.re.MatchString(v) -} - -func (hrm *headerRegexMatcher) String() string { - return fmt.Sprintf("headerRegex:%v:%v", hrm.key, hrm.re.String()) -} - -type headerRangeMatcher struct { - key string - start, end int64 // represents [start, end). -} - -func newHeaderRangeMatcher(key string, start, end int64) *headerRangeMatcher { - return &headerRangeMatcher{key: key, start: start, end: end} -} - -func (hrm *headerRangeMatcher) match(md metadata.MD) bool { - v, ok := mdValuesFromOutgoingCtx(md, hrm.key) - if !ok { - return false - } - if i, err := strconv.ParseInt(v, 10, 64); err == nil && i >= hrm.start && i < hrm.end { - return true - } - return false -} - -func (hrm *headerRangeMatcher) String() string { - return fmt.Sprintf("headerRange:%v:[%d,%d)", hrm.key, hrm.start, hrm.end) -} - -type headerPresentMatcher struct { - key string - present bool -} - -func newHeaderPresentMatcher(key string, present bool) *headerPresentMatcher { - return &headerPresentMatcher{key: key, present: present} -} - -func (hpm *headerPresentMatcher) match(md metadata.MD) bool { - vs, ok := mdValuesFromOutgoingCtx(md, hpm.key) - present := ok && len(vs) > 0 - return present == hpm.present -} - -func (hpm *headerPresentMatcher) String() string { - return fmt.Sprintf("headerPresent:%v:%v", hpm.key, hpm.present) -} - -type headerPrefixMatcher struct { - key string - prefix string -} - -func newHeaderPrefixMatcher(key string, prefix string) *headerPrefixMatcher { - return &headerPrefixMatcher{key: key, prefix: prefix} -} - -func (hpm *headerPrefixMatcher) match(md metadata.MD) bool { - v, ok := mdValuesFromOutgoingCtx(md, hpm.key) - if !ok { - return false - } - return strings.HasPrefix(v, hpm.prefix) -} - -func (hpm *headerPrefixMatcher) String() string { - return fmt.Sprintf("headerPrefix:%v:%v", hpm.key, hpm.prefix) -} - -type headerSuffixMatcher struct { - key string - suffix string -} - -func newHeaderSuffixMatcher(key string, suffix string) *headerSuffixMatcher { - return &headerSuffixMatcher{key: key, suffix: suffix} -} - -func (hsm *headerSuffixMatcher) match(md metadata.MD) bool { - v, ok := mdValuesFromOutgoingCtx(md, hsm.key) - if !ok { - return false - } - return strings.HasSuffix(v, hsm.suffix) -} - -func (hsm *headerSuffixMatcher) String() string { - return fmt.Sprintf("headerSuffix:%v:%v", hsm.key, hsm.suffix) -} - -type invertMatcher struct { - m headerMatcherInterface -} - -func newInvertMatcher(m headerMatcherInterface) *invertMatcher { - return &invertMatcher{m: m} -} - -func (i *invertMatcher) match(md metadata.MD) bool { - return !i.m.match(md) -} - -func (i *invertMatcher) String() string { - return fmt.Sprintf("invert{%s}", i.m) -} diff --git a/xds/internal/resolver/serviceconfig.go b/xds/internal/resolver/serviceconfig.go index 7c1ec853e4c..ddf699f938b 100644 --- a/xds/internal/resolver/serviceconfig.go +++ b/xds/internal/resolver/serviceconfig.go @@ -22,18 +22,25 @@ import ( "context" "encoding/json" "fmt" + "math/bits" + "strings" "sync/atomic" "time" + xxhash "github.com/cespare/xxhash/v2" "google.golang.org/grpc/codes" + "google.golang.org/grpc/internal/grpcrand" iresolver "google.golang.org/grpc/internal/resolver" + "google.golang.org/grpc/internal/serviceconfig" "google.golang.org/grpc/internal/wrr" + "google.golang.org/grpc/internal/xds/env" + "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" "google.golang.org/grpc/xds/internal/balancer/clustermanager" - xdsclient "google.golang.org/grpc/xds/internal/client" - "google.golang.org/grpc/xds/internal/env" + "google.golang.org/grpc/xds/internal/balancer/ringhash" "google.golang.org/grpc/xds/internal/httpfilter" "google.golang.org/grpc/xds/internal/httpfilter/router" + "google.golang.org/grpc/xds/internal/xdsclient" ) const ( @@ -76,7 +83,7 @@ func (r *xdsResolver) pruneActiveClusters() { // serviceConfigJSON produces a service config in JSON format representing all // the clusters referenced in activeClusters. This includes clusters with zero // references, so they must be pruned first. -func serviceConfigJSON(activeClusters map[string]*clusterInfo) (string, error) { +func serviceConfigJSON(activeClusters map[string]*clusterInfo) ([]byte, error) { // Generate children (all entries in activeClusters). children := make(map[string]xdsChildConfig) for cluster := range activeClusters { @@ -93,14 +100,16 @@ func serviceConfigJSON(activeClusters map[string]*clusterInfo) (string, error) { bs, err := json.Marshal(sc) if err != nil { - return "", fmt.Errorf("failed to marshal json: %v", err) + return nil, fmt.Errorf("failed to marshal json: %v", err) } - return string(bs), nil + return bs, nil } type virtualHost struct { // map from filter name to its config httpFilterConfigOverride map[string]httpfilter.FilterConfig + // retry policy present in virtual host + retryConfig *xdsclient.RetryConfig } // routeCluster holds information about a cluster as referenced by a route. @@ -111,11 +120,13 @@ type routeCluster struct { } type route struct { - m *compositeMatcher // converted from route matchers - clusters wrr.WRR // holds *routeCluster entries + m *xdsclient.CompositeMatcher // converted from route matchers + clusters wrr.WRR // holds *routeCluster entries maxStreamDuration time.Duration // map from filter name to its config httpFilterConfigOverride map[string]httpfilter.FilterConfig + retryConfig *xdsclient.RetryConfig + hashPolicies []*xdsclient.HashPolicy } func (r route) String() string { @@ -139,7 +150,7 @@ func (cs *configSelector) SelectConfig(rpcInfo iresolver.RPCInfo) (*iresolver.RP var rt *route // Loop through routes in order and select first match. for _, r := range cs.routes { - if r.m.match(rpcInfo) { + if r.m.Match(rpcInfo) { rt = &r break } @@ -161,9 +172,15 @@ func (cs *configSelector) SelectConfig(rpcInfo iresolver.RPCInfo) (*iresolver.RP return nil, err } + lbCtx := clustermanager.SetPickedCluster(rpcInfo.Context, cluster.name) + // Request Hashes are only applicable for a Ring Hash LB. + if env.RingHashSupport { + lbCtx = ringhash.SetRequestHash(lbCtx, cs.generateHash(rpcInfo, rt.hashPolicies)) + } + config := &iresolver.RPCConfig{ - // Communicate to the LB policy the chosen cluster. - Context: clustermanager.SetPickedCluster(rpcInfo.Context, cluster.name), + // Communicate to the LB policy the chosen cluster and request hash, if Ring Hash LB policy. + Context: lbCtx, OnCommitted: func() { // When the RPC is committed, the cluster is no longer required. // Decrease its ref. @@ -179,13 +196,83 @@ func (cs *configSelector) SelectConfig(rpcInfo iresolver.RPCInfo) (*iresolver.RP Interceptor: interceptor, } - if env.TimeoutSupport && rt.maxStreamDuration != 0 { + if rt.maxStreamDuration != 0 { config.MethodConfig.Timeout = &rt.maxStreamDuration } + if rt.retryConfig != nil { + config.MethodConfig.RetryPolicy = retryConfigToPolicy(rt.retryConfig) + } else if cs.virtualHost.retryConfig != nil { + config.MethodConfig.RetryPolicy = retryConfigToPolicy(cs.virtualHost.retryConfig) + } return config, nil } +func retryConfigToPolicy(config *xdsclient.RetryConfig) *serviceconfig.RetryPolicy { + return &serviceconfig.RetryPolicy{ + MaxAttempts: int(config.NumRetries) + 1, + InitialBackoff: config.RetryBackoff.BaseInterval, + MaxBackoff: config.RetryBackoff.MaxInterval, + BackoffMultiplier: 2, + RetryableStatusCodes: config.RetryOn, + } +} + +func (cs *configSelector) generateHash(rpcInfo iresolver.RPCInfo, hashPolicies []*xdsclient.HashPolicy) uint64 { + var hash uint64 + var generatedHash bool + for _, policy := range hashPolicies { + var policyHash uint64 + var generatedPolicyHash bool + switch policy.HashPolicyType { + case xdsclient.HashPolicyTypeHeader: + md, ok := metadata.FromOutgoingContext(rpcInfo.Context) + if !ok { + continue + } + values := md.Get(policy.HeaderName) + // If the header isn't present, no-op. + if len(values) == 0 { + continue + } + joinedValues := strings.Join(values, ",") + if policy.Regex != nil { + joinedValues = policy.Regex.ReplaceAllString(joinedValues, policy.RegexSubstitution) + } + policyHash = xxhash.Sum64String(joinedValues) + generatedHash = true + generatedPolicyHash = true + case xdsclient.HashPolicyTypeChannelID: + // Hash the ClientConn pointer which logically uniquely + // identifies the client. + policyHash = xxhash.Sum64String(fmt.Sprintf("%p", &cs.r.cc)) + generatedHash = true + generatedPolicyHash = true + } + + // Deterministically combine the hash policies. Rotating prevents + // duplicate hash policies from cancelling each other out and preserves + // the 64 bits of entropy. + if generatedPolicyHash { + hash = bits.RotateLeft64(hash, 1) + hash = hash ^ policyHash + } + + // If terminal policy and a hash has already been generated, ignore the + // rest of the policies and use that hash already generated. + if policy.Terminal && generatedHash { + break + } + } + + if generatedHash { + return hash + } + // If no generated hash return a random long. In the grand scheme of things + // this logically will map to choosing a random backend to route request to. + return grpcrand.Uint64() +} + func (cs *configSelector) newInterceptor(rt *route, cluster *routeCluster) (iresolver.ClientInterceptor, error) { if len(cs.httpFilterConfig) == 0 { return nil, nil @@ -254,8 +341,11 @@ var newWRR = wrr.NewRandom // r.activeClusters for previously-unseen clusters. func (r *xdsResolver) newConfigSelector(su serviceUpdate) (*configSelector, error) { cs := &configSelector{ - r: r, - virtualHost: virtualHost{httpFilterConfigOverride: su.virtualHost.HTTPFilterConfigOverride}, + r: r, + virtualHost: virtualHost{ + httpFilterConfigOverride: su.virtualHost.HTTPFilterConfigOverride, + retryConfig: su.virtualHost.RetryConfig, + }, routes: make([]route, len(su.virtualHost.Routes)), clusters: make(map[string]*clusterInfo), httpFilterConfig: su.ldsConfig.httpFilterConfig, @@ -282,7 +372,7 @@ func (r *xdsResolver) newConfigSelector(su serviceUpdate) (*configSelector, erro cs.routes[i].clusters = clusters var err error - cs.routes[i].m, err = routeToMatcher(rt) + cs.routes[i].m, err = xdsclient.RouteToMatcher(rt) if err != nil { return nil, err } @@ -293,6 +383,8 @@ func (r *xdsResolver) newConfigSelector(su serviceUpdate) (*configSelector, erro } cs.routes[i].httpFilterConfigOverride = rt.HTTPFilterConfigOverride + cs.routes[i].retryConfig = rt.RetryConfig + cs.routes[i].hashPolicies = rt.HashPolicies } // Account for this config selector's clusters. Do this after no further diff --git a/xds/internal/resolver/serviceconfig_test.go b/xds/internal/resolver/serviceconfig_test.go index 1e253841e80..a1a48944dc4 100644 --- a/xds/internal/resolver/serviceconfig_test.go +++ b/xds/internal/resolver/serviceconfig_test.go @@ -19,10 +19,17 @@ package resolver import ( + "context" + "fmt" + "regexp" "testing" + xxhash "github.com/cespare/xxhash/v2" "github.com/google/go-cmp/cmp" + iresolver "google.golang.org/grpc/internal/resolver" + "google.golang.org/grpc/metadata" _ "google.golang.org/grpc/xds/internal/balancer/cdsbalancer" // To parse LB config + "google.golang.org/grpc/xds/internal/xdsclient" ) func (s) TestPruneActiveClusters(t *testing.T) { @@ -41,3 +48,70 @@ func (s) TestPruneActiveClusters(t *testing.T) { t.Fatalf("r.activeClusters = %v; want %v\nDiffs: %v", r.activeClusters, want, d) } } + +func (s) TestGenerateRequestHash(t *testing.T) { + cs := &configSelector{ + r: &xdsResolver{ + cc: &testClientConn{}, + }, + } + tests := []struct { + name string + hashPolicies []*xdsclient.HashPolicy + requestHashWant uint64 + rpcInfo iresolver.RPCInfo + }{ + // TestGenerateRequestHashHeaders tests generating request hashes for + // hash policies that specify to hash headers. + { + name: "test-generate-request-hash-headers", + hashPolicies: []*xdsclient.HashPolicy{{ + HashPolicyType: xdsclient.HashPolicyTypeHeader, + HeaderName: ":path", + Regex: func() *regexp.Regexp { return regexp.MustCompile("/products") }(), // Will replace /products with /new-products, to test find and replace functionality. + RegexSubstitution: "/new-products", + }}, + requestHashWant: xxhash.Sum64String("/new-products"), + rpcInfo: iresolver.RPCInfo{ + Context: metadata.NewOutgoingContext(context.Background(), metadata.Pairs(":path", "/products")), + Method: "/some-method", + }, + }, + // TestGenerateHashChannelID tests generating request hashes for hash + // policies that specify to hash something that uniquely identifies the + // ClientConn (the pointer). + { + name: "test-generate-request-hash-channel-id", + hashPolicies: []*xdsclient.HashPolicy{{ + HashPolicyType: xdsclient.HashPolicyTypeChannelID, + }}, + requestHashWant: xxhash.Sum64String(fmt.Sprintf("%p", &cs.r.cc)), + rpcInfo: iresolver.RPCInfo{}, + }, + // TestGenerateRequestHashEmptyString tests generating request hashes + // for hash policies that specify to hash headers and replace empty + // strings in the headers. + { + name: "test-generate-request-hash-empty-string", + hashPolicies: []*xdsclient.HashPolicy{{ + HashPolicyType: xdsclient.HashPolicyTypeHeader, + HeaderName: ":path", + Regex: func() *regexp.Regexp { return regexp.MustCompile("") }(), + RegexSubstitution: "e", + }}, + requestHashWant: xxhash.Sum64String("eaebece"), + rpcInfo: iresolver.RPCInfo{ + Context: metadata.NewOutgoingContext(context.Background(), metadata.Pairs(":path", "abc")), + Method: "/some-method", + }, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + requestHashGot := cs.generateHash(test.rpcInfo, test.hashPolicies) + if requestHashGot != test.requestHashWant { + t.Fatalf("requestHashGot = %v, requestHashWant = %v", requestHashGot, test.requestHashWant) + } + }) + } +} diff --git a/xds/internal/resolver/watch_service.go b/xds/internal/resolver/watch_service.go index 913ac4ced15..da0bf95f3b9 100644 --- a/xds/internal/resolver/watch_service.go +++ b/xds/internal/resolver/watch_service.go @@ -20,12 +20,12 @@ package resolver import ( "fmt" - "strings" "sync" "time" "google.golang.org/grpc/internal/grpclog" - xdsclient "google.golang.org/grpc/xds/internal/client" + "google.golang.org/grpc/internal/pretty" + "google.golang.org/grpc/xds/internal/xdsclient" ) // serviceUpdate contains information received from the LDS/RDS responses which @@ -53,7 +53,7 @@ type ldsConfig struct { // Note that during race (e.g. an xDS response is received while the user is // calling cancel()), there's a small window where the callback can be called // after the watcher is canceled. The caller needs to handle this case. -func watchService(c xdsClientInterface, serviceName string, cb func(serviceUpdate, error), logger *grpclog.PrefixLogger) (cancel func()) { +func watchService(c xdsclient.XDSClient, serviceName string, cb func(serviceUpdate, error), logger *grpclog.PrefixLogger) (cancel func()) { w := &serviceUpdateWatcher{ logger: logger, c: c, @@ -69,7 +69,7 @@ func watchService(c xdsClientInterface, serviceName string, cb func(serviceUpdat // callback at the right time. type serviceUpdateWatcher struct { logger *grpclog.PrefixLogger - c xdsClientInterface + c xdsclient.XDSClient serviceName string ldsCancel func() serviceCb func(serviceUpdate, error) @@ -82,7 +82,7 @@ type serviceUpdateWatcher struct { } func (w *serviceUpdateWatcher) handleLDSResp(update xdsclient.ListenerUpdate, err error) { - w.logger.Infof("received LDS update: %+v, err: %v", update, err) + w.logger.Infof("received LDS update: %+v, err: %v", pretty.ToJSON(update), err) w.mu.Lock() defer w.mu.Unlock() if w.closed { @@ -110,13 +110,37 @@ func (w *serviceUpdateWatcher) handleLDSResp(update xdsclient.ListenerUpdate, er httpFilterConfig: update.HTTPFilters, } + if update.InlineRouteConfig != nil { + // If there was an RDS watch, cancel it. + w.rdsName = "" + if w.rdsCancel != nil { + w.rdsCancel() + w.rdsCancel = nil + } + + // Handle the inline RDS update as if it's from an RDS watch. + w.updateVirtualHostsFromRDS(*update.InlineRouteConfig) + return + } + + // RDS name from update is not an empty string, need RDS to fetch the + // routes. + if w.rdsName == update.RouteConfigName { // If the new RouteConfigName is same as the previous, don't cancel and // restart the RDS watch. // // If the route name did change, then we must wait until the first RDS // update before reporting this LDS config. - w.serviceCb(w.lastUpdate, nil) + if w.lastUpdate.virtualHost != nil { + // We want to send an update with the new fields from the new LDS + // (e.g. max stream duration), and old fields from the the previous + // RDS. + // + // But note that this should only happen when virtual host is set, + // which means an RDS was received. + w.serviceCb(w.lastUpdate, nil) + } return } w.rdsName = update.RouteConfigName @@ -126,8 +150,20 @@ func (w *serviceUpdateWatcher) handleLDSResp(update xdsclient.ListenerUpdate, er w.rdsCancel = w.c.WatchRouteConfig(update.RouteConfigName, w.handleRDSResp) } +func (w *serviceUpdateWatcher) updateVirtualHostsFromRDS(update xdsclient.RouteConfigUpdate) { + matchVh := xdsclient.FindBestMatchingVirtualHost(w.serviceName, update.VirtualHosts) + if matchVh == nil { + // No matching virtual host found. + w.serviceCb(serviceUpdate{}, fmt.Errorf("no matching virtual host found for %q", w.serviceName)) + return + } + + w.lastUpdate.virtualHost = matchVh + w.serviceCb(w.lastUpdate, nil) +} + func (w *serviceUpdateWatcher) handleRDSResp(update xdsclient.RouteConfigUpdate, err error) { - w.logger.Infof("received RDS update: %+v, err: %v", update, err) + w.logger.Infof("received RDS update: %+v, err: %v", pretty.ToJSON(update), err) w.mu.Lock() defer w.mu.Unlock() if w.closed { @@ -142,16 +178,7 @@ func (w *serviceUpdateWatcher) handleRDSResp(update xdsclient.RouteConfigUpdate, w.serviceCb(serviceUpdate{}, err) return } - - matchVh := findBestMatchingVirtualHost(w.serviceName, update.VirtualHosts) - if matchVh == nil { - // No matching virtual host found. - w.serviceCb(serviceUpdate{}, fmt.Errorf("no matching virtual host found for %q", w.serviceName)) - return - } - - w.lastUpdate.virtualHost = matchVh - w.serviceCb(w.lastUpdate, nil) + w.updateVirtualHostsFromRDS(update) } func (w *serviceUpdateWatcher) close() { @@ -164,97 +191,3 @@ func (w *serviceUpdateWatcher) close() { w.rdsCancel = nil } } - -type domainMatchType int - -const ( - domainMatchTypeInvalid domainMatchType = iota - domainMatchTypeUniversal - domainMatchTypePrefix - domainMatchTypeSuffix - domainMatchTypeExact -) - -// Exact > Suffix > Prefix > Universal > Invalid. -func (t domainMatchType) betterThan(b domainMatchType) bool { - return t > b -} - -func matchTypeForDomain(d string) domainMatchType { - if d == "" { - return domainMatchTypeInvalid - } - if d == "*" { - return domainMatchTypeUniversal - } - if strings.HasPrefix(d, "*") { - return domainMatchTypeSuffix - } - if strings.HasSuffix(d, "*") { - return domainMatchTypePrefix - } - if strings.Contains(d, "*") { - return domainMatchTypeInvalid - } - return domainMatchTypeExact -} - -func match(domain, host string) (domainMatchType, bool) { - switch typ := matchTypeForDomain(domain); typ { - case domainMatchTypeInvalid: - return typ, false - case domainMatchTypeUniversal: - return typ, true - case domainMatchTypePrefix: - // abc.* - return typ, strings.HasPrefix(host, strings.TrimSuffix(domain, "*")) - case domainMatchTypeSuffix: - // *.123 - return typ, strings.HasSuffix(host, strings.TrimPrefix(domain, "*")) - case domainMatchTypeExact: - return typ, domain == host - default: - return domainMatchTypeInvalid, false - } -} - -// findBestMatchingVirtualHost returns the virtual host whose domains field best -// matches host -// -// The domains field support 4 different matching pattern types: -// - Exact match -// - Suffix match (e.g. “*ABC”) -// - Prefix match (e.g. “ABC*) -// - Universal match (e.g. “*”) -// -// The best match is defined as: -// - A match is better if it’s matching pattern type is better -// - Exact match > suffix match > prefix match > universal match -// - If two matches are of the same pattern type, the longer match is better -// - This is to compare the length of the matching pattern, e.g. “*ABCDE” > -// “*ABC” -func findBestMatchingVirtualHost(host string, vHosts []*xdsclient.VirtualHost) *xdsclient.VirtualHost { - var ( - matchVh *xdsclient.VirtualHost - matchType = domainMatchTypeInvalid - matchLen int - ) - for _, vh := range vHosts { - for _, domain := range vh.Domains { - typ, matched := match(domain, host) - if typ == domainMatchTypeInvalid { - // The rds response is invalid. - return nil - } - if matchType.betterThan(typ) || matchType == typ && matchLen >= len(domain) || !matched { - // The previous match has better type, or the previous match has - // better length, or this domain isn't a match. - continue - } - matchVh = vh - matchType = typ - matchLen = len(domain) - } - } - return matchVh -} diff --git a/xds/internal/resolver/watch_service_test.go b/xds/internal/resolver/watch_service_test.go index 705a3d35ae1..1bf65c4d450 100644 --- a/xds/internal/resolver/watch_service_test.go +++ b/xds/internal/resolver/watch_service_test.go @@ -27,57 +27,11 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "google.golang.org/grpc/internal/testutils" - xdsclient "google.golang.org/grpc/xds/internal/client" "google.golang.org/grpc/xds/internal/testutils/fakeclient" + "google.golang.org/grpc/xds/internal/xdsclient" "google.golang.org/protobuf/proto" ) -func (s) TestMatchTypeForDomain(t *testing.T) { - tests := []struct { - d string - want domainMatchType - }{ - {d: "", want: domainMatchTypeInvalid}, - {d: "*", want: domainMatchTypeUniversal}, - {d: "bar.*", want: domainMatchTypePrefix}, - {d: "*.abc.com", want: domainMatchTypeSuffix}, - {d: "foo.bar.com", want: domainMatchTypeExact}, - {d: "foo.*.com", want: domainMatchTypeInvalid}, - } - for _, tt := range tests { - if got := matchTypeForDomain(tt.d); got != tt.want { - t.Errorf("matchTypeForDomain(%q) = %v, want %v", tt.d, got, tt.want) - } - } -} - -func (s) TestMatch(t *testing.T) { - tests := []struct { - name string - domain string - host string - wantTyp domainMatchType - wantMatched bool - }{ - {name: "invalid-empty", domain: "", host: "", wantTyp: domainMatchTypeInvalid, wantMatched: false}, - {name: "invalid", domain: "a.*.b", host: "", wantTyp: domainMatchTypeInvalid, wantMatched: false}, - {name: "universal", domain: "*", host: "abc.com", wantTyp: domainMatchTypeUniversal, wantMatched: true}, - {name: "prefix-match", domain: "abc.*", host: "abc.123", wantTyp: domainMatchTypePrefix, wantMatched: true}, - {name: "prefix-no-match", domain: "abc.*", host: "abcd.123", wantTyp: domainMatchTypePrefix, wantMatched: false}, - {name: "suffix-match", domain: "*.123", host: "abc.123", wantTyp: domainMatchTypeSuffix, wantMatched: true}, - {name: "suffix-no-match", domain: "*.123", host: "abc.1234", wantTyp: domainMatchTypeSuffix, wantMatched: false}, - {name: "exact-match", domain: "foo.bar", host: "foo.bar", wantTyp: domainMatchTypeExact, wantMatched: true}, - {name: "exact-no-match", domain: "foo.bar.com", host: "foo.bar", wantTyp: domainMatchTypeExact, wantMatched: false}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if gotTyp, gotMatched := match(tt.domain, tt.host); gotTyp != tt.wantTyp || gotMatched != tt.wantMatched { - t.Errorf("match() = %v, %v, want %v, %v", gotTyp, gotMatched, tt.wantTyp, tt.wantMatched) - } - }) - } -} - func (s) TestFindBestMatchingVirtualHost(t *testing.T) { var ( oneExactMatch = &xdsclient.VirtualHost{ @@ -121,7 +75,7 @@ func (s) TestFindBestMatchingVirtualHost(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if got := findBestMatchingVirtualHost(tt.host, tt.vHosts); !cmp.Equal(got, tt.want, cmp.Comparer(proto.Equal)) { + if got := xdsclient.FindBestMatchingVirtualHost(tt.host, tt.vHosts); !cmp.Equal(got, tt.want, cmp.Comparer(proto.Equal)) { t.Errorf("findBestMatchingxdsclient.VirtualHost() = %v, want %v", got, tt.want) } }) @@ -167,7 +121,7 @@ func (s) TestServiceWatch(t *testing.T) { waitForWatchRouteConfig(ctx, t, xdsC, routeStr) wantUpdate := serviceUpdate{virtualHost: &xdsclient.VirtualHost{Domains: []string{"target"}, Routes: []*xdsclient.Route{{Prefix: newStringP(""), WeightedClusters: map[string]xdsclient.WeightedCluster{cluster: {Weight: 1}}}}}} - xdsC.InvokeWatchRouteConfigCallback(xdsclient.RouteConfigUpdate{ + xdsC.InvokeWatchRouteConfigCallback("", xdsclient.RouteConfigUpdate{ VirtualHosts: []*xdsclient.VirtualHost{ { Domains: []string{targetStr}, @@ -185,7 +139,7 @@ func (s) TestServiceWatch(t *testing.T) { WeightedClusters: map[string]xdsclient.WeightedCluster{cluster: {Weight: 1}}, }}, }} - xdsC.InvokeWatchRouteConfigCallback(xdsclient.RouteConfigUpdate{ + xdsC.InvokeWatchRouteConfigCallback("", xdsclient.RouteConfigUpdate{ VirtualHosts: []*xdsclient.VirtualHost{ { Domains: []string{targetStr}, @@ -221,7 +175,7 @@ func (s) TestServiceWatchLDSUpdate(t *testing.T) { waitForWatchRouteConfig(ctx, t, xdsC, routeStr) wantUpdate := serviceUpdate{virtualHost: &xdsclient.VirtualHost{Domains: []string{"target"}, Routes: []*xdsclient.Route{{Prefix: newStringP(""), WeightedClusters: map[string]xdsclient.WeightedCluster{cluster: {Weight: 1}}}}}} - xdsC.InvokeWatchRouteConfigCallback(xdsclient.RouteConfigUpdate{ + xdsC.InvokeWatchRouteConfigCallback("", xdsclient.RouteConfigUpdate{ VirtualHosts: []*xdsclient.VirtualHost{ { Domains: []string{targetStr}, @@ -235,14 +189,14 @@ func (s) TestServiceWatchLDSUpdate(t *testing.T) { // Another LDS update with a different RDS_name. xdsC.InvokeWatchListenerCallback(xdsclient.ListenerUpdate{RouteConfigName: routeStr + "2"}, nil) - if err := xdsC.WaitForCancelRouteConfigWatch(ctx); err != nil { + if _, err := xdsC.WaitForCancelRouteConfigWatch(ctx); err != nil { t.Fatalf("wait for cancel route watch failed: %v, want nil", err) } waitForWatchRouteConfig(ctx, t, xdsC, routeStr+"2") // RDS update for the new name. wantUpdate2 := serviceUpdate{virtualHost: &xdsclient.VirtualHost{Domains: []string{"target"}, Routes: []*xdsclient.Route{{Prefix: newStringP(""), WeightedClusters: map[string]xdsclient.WeightedCluster{cluster + "2": {Weight: 1}}}}}} - xdsC.InvokeWatchRouteConfigCallback(xdsclient.RouteConfigUpdate{ + xdsC.InvokeWatchRouteConfigCallback(routeStr+"2", xdsclient.RouteConfigUpdate{ VirtualHosts: []*xdsclient.VirtualHost{ { Domains: []string{targetStr}, @@ -277,7 +231,7 @@ func (s) TestServiceWatchLDSUpdateMaxStreamDuration(t *testing.T) { WeightedClusters: map[string]xdsclient.WeightedCluster{cluster: {Weight: 1}}}}}, ldsConfig: ldsConfig{maxStreamDuration: time.Second}, } - xdsC.InvokeWatchRouteConfigCallback(xdsclient.RouteConfigUpdate{ + xdsC.InvokeWatchRouteConfigCallback("", xdsclient.RouteConfigUpdate{ VirtualHosts: []*xdsclient.VirtualHost{ { Domains: []string{targetStr}, @@ -301,7 +255,7 @@ func (s) TestServiceWatchLDSUpdateMaxStreamDuration(t *testing.T) { Prefix: newStringP(""), WeightedClusters: map[string]xdsclient.WeightedCluster{cluster + "2": {Weight: 1}}}}, }} - xdsC.InvokeWatchRouteConfigCallback(xdsclient.RouteConfigUpdate{ + xdsC.InvokeWatchRouteConfigCallback("", xdsclient.RouteConfigUpdate{ VirtualHosts: []*xdsclient.VirtualHost{ { Domains: []string{targetStr}, @@ -335,7 +289,7 @@ func (s) TestServiceNotCancelRDSOnSameLDSUpdate(t *testing.T) { Prefix: newStringP(""), WeightedClusters: map[string]xdsclient.WeightedCluster{cluster: {Weight: 1}}}}, }} - xdsC.InvokeWatchRouteConfigCallback(xdsclient.RouteConfigUpdate{ + xdsC.InvokeWatchRouteConfigCallback("", xdsclient.RouteConfigUpdate{ VirtualHosts: []*xdsclient.VirtualHost{ { Domains: []string{targetStr}, @@ -352,7 +306,85 @@ func (s) TestServiceNotCancelRDSOnSameLDSUpdate(t *testing.T) { xdsC.InvokeWatchListenerCallback(xdsclient.ListenerUpdate{RouteConfigName: routeStr}, nil) sCtx, sCancel := context.WithTimeout(ctx, defaultTestShortTimeout) defer sCancel() - if err := xdsC.WaitForCancelRouteConfigWatch(sCtx); err != context.DeadlineExceeded { + if _, err := xdsC.WaitForCancelRouteConfigWatch(sCtx); err != context.DeadlineExceeded { + t.Fatalf("wait for cancel route watch failed: %v, want nil", err) + } +} + +// TestServiceWatchInlineRDS covers the cases switching between: +// - LDS update contains RDS name to watch +// - LDS update contains inline RDS resource +func (s) TestServiceWatchInlineRDS(t *testing.T) { + serviceUpdateCh := testutils.NewChannel() + xdsC := fakeclient.NewClient() + cancelWatch := watchService(xdsC, targetStr, func(update serviceUpdate, err error) { + serviceUpdateCh.Send(serviceUpdateErr{u: update, err: err}) + }, nil) + defer cancelWatch() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + // First LDS update is LDS with RDS name to watch. + waitForWatchListener(ctx, t, xdsC, targetStr) + xdsC.InvokeWatchListenerCallback(xdsclient.ListenerUpdate{RouteConfigName: routeStr}, nil) + waitForWatchRouteConfig(ctx, t, xdsC, routeStr) + wantUpdate := serviceUpdate{virtualHost: &xdsclient.VirtualHost{Domains: []string{"target"}, Routes: []*xdsclient.Route{{Prefix: newStringP(""), WeightedClusters: map[string]xdsclient.WeightedCluster{cluster: {Weight: 1}}}}}} + xdsC.InvokeWatchRouteConfigCallback("", xdsclient.RouteConfigUpdate{ + VirtualHosts: []*xdsclient.VirtualHost{ + { + Domains: []string{targetStr}, + Routes: []*xdsclient.Route{{Prefix: newStringP(""), WeightedClusters: map[string]xdsclient.WeightedCluster{cluster: {Weight: 1}}}}, + }, + }, + }, nil) + if err := verifyServiceUpdate(ctx, serviceUpdateCh, wantUpdate); err != nil { + t.Fatal(err) + } + + // Switch LDS resp to a LDS with inline RDS resource + wantVirtualHosts2 := &xdsclient.VirtualHost{Domains: []string{"target"}, + Routes: []*xdsclient.Route{{ + Path: newStringP(""), + WeightedClusters: map[string]xdsclient.WeightedCluster{cluster: {Weight: 1}}, + }}, + } + wantUpdate2 := serviceUpdate{virtualHost: wantVirtualHosts2} + xdsC.InvokeWatchListenerCallback(xdsclient.ListenerUpdate{InlineRouteConfig: &xdsclient.RouteConfigUpdate{ + VirtualHosts: []*xdsclient.VirtualHost{wantVirtualHosts2}, + }}, nil) + // This inline RDS resource should cause the RDS watch to be canceled. + if _, err := xdsC.WaitForCancelRouteConfigWatch(ctx); err != nil { t.Fatalf("wait for cancel route watch failed: %v, want nil", err) } + if err := verifyServiceUpdate(ctx, serviceUpdateCh, wantUpdate2); err != nil { + t.Fatal(err) + } + + // Switch LDS update back to LDS with RDS name to watch. + xdsC.InvokeWatchListenerCallback(xdsclient.ListenerUpdate{RouteConfigName: routeStr}, nil) + waitForWatchRouteConfig(ctx, t, xdsC, routeStr) + xdsC.InvokeWatchRouteConfigCallback("", xdsclient.RouteConfigUpdate{ + VirtualHosts: []*xdsclient.VirtualHost{ + { + Domains: []string{targetStr}, + Routes: []*xdsclient.Route{{Prefix: newStringP(""), WeightedClusters: map[string]xdsclient.WeightedCluster{cluster: {Weight: 1}}}}, + }, + }, + }, nil) + if err := verifyServiceUpdate(ctx, serviceUpdateCh, wantUpdate); err != nil { + t.Fatal(err) + } + + // Switch LDS resp to a LDS with inline RDS resource again. + xdsC.InvokeWatchListenerCallback(xdsclient.ListenerUpdate{InlineRouteConfig: &xdsclient.RouteConfigUpdate{ + VirtualHosts: []*xdsclient.VirtualHost{wantVirtualHosts2}, + }}, nil) + // This inline RDS resource should cause the RDS watch to be canceled. + if _, err := xdsC.WaitForCancelRouteConfigWatch(ctx); err != nil { + t.Fatalf("wait for cancel route watch failed: %v, want nil", err) + } + if err := verifyServiceUpdate(ctx, serviceUpdateCh, wantUpdate2); err != nil { + t.Fatal(err) + } } diff --git a/xds/internal/resolver/xds_resolver.go b/xds/internal/resolver/xds_resolver.go index d8c09db69b5..19ee01773e8 100644 --- a/xds/internal/resolver/xds_resolver.go +++ b/xds/internal/resolver/xds_resolver.go @@ -26,23 +26,35 @@ import ( "google.golang.org/grpc/credentials" "google.golang.org/grpc/internal/grpclog" "google.golang.org/grpc/internal/grpcsync" - "google.golang.org/grpc/resolver" - "google.golang.org/grpc/xds/internal/client/bootstrap" - + "google.golang.org/grpc/internal/pretty" iresolver "google.golang.org/grpc/internal/resolver" - xdsclient "google.golang.org/grpc/xds/internal/client" + "google.golang.org/grpc/resolver" + "google.golang.org/grpc/xds/internal/xdsclient" ) const xdsScheme = "xds" +// NewBuilder creates a new xds resolver builder using a specific xds bootstrap +// config, so tests can use multiple xds clients in different ClientConns at +// the same time. +func NewBuilder(config []byte) (resolver.Builder, error) { + return &xdsResolverBuilder{ + newXDSClient: func() (xdsclient.XDSClient, error) { + return xdsclient.NewClientWithBootstrapContents(config) + }, + }, nil +} + // For overriding in unittests. -var newXDSClient = func() (xdsClientInterface, error) { return xdsclient.New() } +var newXDSClient = func() (xdsclient.XDSClient, error) { return xdsclient.New() } func init() { resolver.Register(&xdsResolverBuilder{}) } -type xdsResolverBuilder struct{} +type xdsResolverBuilder struct { + newXDSClient func() (xdsclient.XDSClient, error) +} // Build helps implement the resolver.Builder interface. // @@ -59,6 +71,11 @@ func (b *xdsResolverBuilder) Build(t resolver.Target, cc resolver.ClientConn, op r.logger = prefixLogger((r)) r.logger.Infof("Creating resolver for target: %+v", t) + newXDSClient := newXDSClient + if b.newXDSClient != nil { + newXDSClient = b.newXDSClient + } + client, err := newXDSClient() if err != nil { return nil, fmt.Errorf("xds: failed to create xds-client: %v", err) @@ -100,15 +117,6 @@ func (*xdsResolverBuilder) Scheme() string { return xdsScheme } -// xdsClientInterface contains methods from xdsClient.Client which are used by -// the resolver. This will be faked out in unittests. -type xdsClientInterface interface { - WatchListener(serviceName string, cb func(xdsclient.ListenerUpdate, error)) func() - WatchRouteConfig(routeName string, cb func(xdsclient.RouteConfigUpdate, error)) func() - BootstrapConfig() *bootstrap.Config - Close() -} - // suWithError wraps the ServiceUpdate and error received through a watch API // callback, so that it can pushed onto the update channel as a single entity. type suWithError struct { @@ -130,7 +138,7 @@ type xdsResolver struct { logger *grpclog.PrefixLogger // The underlying xdsClient which performs all xDS requests and responses. - client xdsClientInterface + client xdsclient.XDSClient // A channel for the watch API callback to write service updates on to. The // updates are read by the run goroutine and passed on to the ClientConn. updateCh chan suWithError @@ -171,13 +179,13 @@ func (r *xdsResolver) sendNewServiceConfig(cs *configSelector) bool { r.cc.ReportError(err) return false } - r.logger.Infof("Received update on resource %v from xds-client %p, generated service config: %v", r.target.Endpoint, r.client, sc) + r.logger.Infof("Received update on resource %v from xds-client %p, generated service config: %v", r.target.Endpoint, r.client, pretty.FormatJSON(sc)) // Send the update to the ClientConn. state := iresolver.SetConfigSelector(resolver.State{ - ServiceConfig: r.cc.ParseServiceConfig(sc), + ServiceConfig: r.cc.ParseServiceConfig(string(sc)), }, cs) - r.cc.UpdateState(state) + r.cc.UpdateState(xdsclient.SetClient(state, r.client)) return true } diff --git a/xds/internal/resolver/xds_resolver_test.go b/xds/internal/resolver/xds_resolver_test.go index 53ea17042aa..90e6c1d4db0 100644 --- a/xds/internal/resolver/xds_resolver_test.go +++ b/xds/internal/resolver/xds_resolver_test.go @@ -26,6 +26,7 @@ import ( "testing" "time" + xxhash "github.com/cespare/xxhash/v2" "github.com/google/go-cmp/cmp" "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials/insecure" @@ -36,19 +37,20 @@ import ( iresolver "google.golang.org/grpc/internal/resolver" "google.golang.org/grpc/internal/testutils" "google.golang.org/grpc/internal/wrr" + "google.golang.org/grpc/internal/xds/env" + "google.golang.org/grpc/metadata" "google.golang.org/grpc/resolver" "google.golang.org/grpc/serviceconfig" "google.golang.org/grpc/status" _ "google.golang.org/grpc/xds/internal/balancer/cdsbalancer" // To parse LB config "google.golang.org/grpc/xds/internal/balancer/clustermanager" - "google.golang.org/grpc/xds/internal/client" - xdsclient "google.golang.org/grpc/xds/internal/client" - "google.golang.org/grpc/xds/internal/client/bootstrap" - "google.golang.org/grpc/xds/internal/env" + "google.golang.org/grpc/xds/internal/balancer/ringhash" "google.golang.org/grpc/xds/internal/httpfilter" "google.golang.org/grpc/xds/internal/httpfilter/router" xdstestutils "google.golang.org/grpc/xds/internal/testutils" "google.golang.org/grpc/xds/internal/testutils/fakeclient" + "google.golang.org/grpc/xds/internal/xdsclient" + "google.golang.org/grpc/xds/internal/xdsclient/bootstrap" ) const ( @@ -88,8 +90,9 @@ type testClientConn struct { errorCh *testutils.Channel } -func (t *testClientConn) UpdateState(s resolver.State) { +func (t *testClientConn) UpdateState(s resolver.State) error { t.stateCh.Send(s) + return nil } func (t *testClientConn) ReportError(err error) { @@ -112,19 +115,19 @@ func newTestClientConn() *testClientConn { func (s) TestResolverBuilder(t *testing.T) { tests := []struct { name string - xdsClientFunc func() (xdsClientInterface, error) + xdsClientFunc func() (xdsclient.XDSClient, error) wantErr bool }{ { name: "simple-good", - xdsClientFunc: func() (xdsClientInterface, error) { + xdsClientFunc: func() (xdsclient.XDSClient, error) { return fakeclient.NewClient(), nil }, wantErr: false, }, { name: "newXDSClient-throws-error", - xdsClientFunc: func() (xdsClientInterface, error) { + xdsClientFunc: func() (xdsclient.XDSClient, error) { return nil, errors.New("newXDSClient-throws-error") }, wantErr: true, @@ -165,7 +168,7 @@ func (s) TestResolverBuilder_xdsCredsBootstrapMismatch(t *testing.T) { // Fake out the xdsClient creation process by providing a fake, which does // not have any certificate provider configuration. oldClientMaker := newXDSClient - newXDSClient = func() (xdsClientInterface, error) { + newXDSClient = func() (xdsclient.XDSClient, error) { fc := fakeclient.NewClient() fc.SetBootstrapConfig(&bootstrap.Config{}) return fc, nil @@ -192,7 +195,7 @@ func (s) TestResolverBuilder_xdsCredsBootstrapMismatch(t *testing.T) { } type setupOpts struct { - xdsClientFunc func() (xdsClientInterface, error) + xdsClientFunc func() (xdsclient.XDSClient, error) } func testSetup(t *testing.T, opts setupOpts) (*xdsResolver, *testClientConn, func()) { @@ -252,7 +255,7 @@ func waitForWatchRouteConfig(ctx context.Context, t *testing.T, xdsC *fakeclient func (s) TestXDSResolverWatchCallbackAfterClose(t *testing.T) { xdsC := fakeclient.NewClient() xdsR, tcc, cancel := testSetup(t, setupOpts{ - xdsClientFunc: func() (xdsClientInterface, error) { return xdsC, nil }, + xdsClientFunc: func() (xdsclient.XDSClient, error) { return xdsC, nil }, }) defer cancel() @@ -265,11 +268,11 @@ func (s) TestXDSResolverWatchCallbackAfterClose(t *testing.T) { // Call the watchAPI callback after closing the resolver, and make sure no // update is triggerred on the ClientConn. xdsR.Close() - xdsC.InvokeWatchRouteConfigCallback(xdsclient.RouteConfigUpdate{ + xdsC.InvokeWatchRouteConfigCallback("", xdsclient.RouteConfigUpdate{ VirtualHosts: []*xdsclient.VirtualHost{ { Domains: []string{targetStr}, - Routes: []*client.Route{{Prefix: newStringP(""), WeightedClusters: map[string]xdsclient.WeightedCluster{cluster: {Weight: 1}}}}, + Routes: []*xdsclient.Route{{Prefix: newStringP(""), WeightedClusters: map[string]xdsclient.WeightedCluster{cluster: {Weight: 1}}}}, }, }, }, nil) @@ -279,17 +282,29 @@ func (s) TestXDSResolverWatchCallbackAfterClose(t *testing.T) { } } +// TestXDSResolverCloseClosesXDSClient tests that the XDS resolver's Close +// method closes the XDS client. +func (s) TestXDSResolverCloseClosesXDSClient(t *testing.T) { + xdsC := fakeclient.NewClient() + xdsR, _, cancel := testSetup(t, setupOpts{ + xdsClientFunc: func() (xdsclient.XDSClient, error) { return xdsC, nil }, + }) + defer cancel() + xdsR.Close() + if !xdsC.Closed.HasFired() { + t.Fatalf("xds client not closed by xds resolver Close method") + } +} + // TestXDSResolverBadServiceUpdate tests the case the xdsClient returns a bad // service update. func (s) TestXDSResolverBadServiceUpdate(t *testing.T) { xdsC := fakeclient.NewClient() xdsR, tcc, cancel := testSetup(t, setupOpts{ - xdsClientFunc: func() (xdsClientInterface, error) { return xdsC, nil }, + xdsClientFunc: func() (xdsclient.XDSClient, error) { return xdsC, nil }, }) - defer func() { - cancel() - xdsR.Close() - }() + defer xdsR.Close() + defer cancel() ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() @@ -300,7 +315,7 @@ func (s) TestXDSResolverBadServiceUpdate(t *testing.T) { // Invoke the watchAPI callback with a bad service update and wait for the // ReportError method to be called on the ClientConn. suErr := errors.New("bad serviceupdate") - xdsC.InvokeWatchRouteConfigCallback(xdsclient.RouteConfigUpdate{}, suErr) + xdsC.InvokeWatchRouteConfigCallback("", xdsclient.RouteConfigUpdate{}, suErr) if gotErrVal, gotErr := tcc.errorCh.Receive(ctx); gotErr != nil || gotErrVal != suErr { t.Fatalf("ClientConn.ReportError() received %v, want %v", gotErrVal, suErr) @@ -312,12 +327,10 @@ func (s) TestXDSResolverBadServiceUpdate(t *testing.T) { func (s) TestXDSResolverGoodServiceUpdate(t *testing.T) { xdsC := fakeclient.NewClient() xdsR, tcc, cancel := testSetup(t, setupOpts{ - xdsClientFunc: func() (xdsClientInterface, error) { return xdsC, nil }, + xdsClientFunc: func() (xdsclient.XDSClient, error) { return xdsC, nil }, }) - defer func() { - cancel() - xdsR.Close() - }() + defer xdsR.Close() + defer cancel() ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() @@ -332,7 +345,7 @@ func (s) TestXDSResolverGoodServiceUpdate(t *testing.T) { wantClusters map[string]bool }{ { - routes: []*client.Route{{Prefix: newStringP(""), WeightedClusters: map[string]xdsclient.WeightedCluster{"test-cluster-1": {Weight: 1}}}}, + routes: []*xdsclient.Route{{Prefix: newStringP(""), WeightedClusters: map[string]xdsclient.WeightedCluster{"test-cluster-1": {Weight: 1}}}}, wantJSON: `{"loadBalancingConfig":[{ "xds_cluster_manager_experimental":{ "children":{ @@ -344,7 +357,7 @@ func (s) TestXDSResolverGoodServiceUpdate(t *testing.T) { wantClusters: map[string]bool{"test-cluster-1": true}, }, { - routes: []*client.Route{{Prefix: newStringP(""), WeightedClusters: map[string]xdsclient.WeightedCluster{ + routes: []*xdsclient.Route{{Prefix: newStringP(""), WeightedClusters: map[string]xdsclient.WeightedCluster{ "cluster_1": {Weight: 75}, "cluster_2": {Weight: 25}, }}}, @@ -368,7 +381,7 @@ func (s) TestXDSResolverGoodServiceUpdate(t *testing.T) { wantClusters: map[string]bool{"cluster_1": true, "cluster_2": true}, }, { - routes: []*client.Route{{Prefix: newStringP(""), WeightedClusters: map[string]xdsclient.WeightedCluster{ + routes: []*xdsclient.Route{{Prefix: newStringP(""), WeightedClusters: map[string]xdsclient.WeightedCluster{ "cluster_1": {Weight: 75}, "cluster_2": {Weight: 25}, }}}, @@ -391,7 +404,7 @@ func (s) TestXDSResolverGoodServiceUpdate(t *testing.T) { } { // Invoke the watchAPI callback with a good service update and wait for the // UpdateState method to be called on the ClientConn. - xdsC.InvokeWatchRouteConfigCallback(xdsclient.RouteConfigUpdate{ + xdsC.InvokeWatchRouteConfigCallback("", xdsclient.RouteConfigUpdate{ VirtualHosts: []*xdsclient.VirtualHost{ { Domains: []string{targetStr}, @@ -404,7 +417,7 @@ func (s) TestXDSResolverGoodServiceUpdate(t *testing.T) { defer cancel() gotState, err := tcc.stateCh.Receive(ctx) if err != nil { - t.Fatalf("ClientConn.UpdateState returned error: %v", err) + t.Fatalf("Error waiting for UpdateState to be called: %v", err) } rState := gotState.(resolver.State) if err := rState.ServiceConfig.Err; err != nil { @@ -443,12 +456,77 @@ func (s) TestXDSResolverGoodServiceUpdate(t *testing.T) { } } +// TestXDSResolverRequestHash tests a case where a resolver receives a RouteConfig update +// with a HashPolicy specifying to generate a hash. The configSelector generated should +// successfully generate a Hash. +func (s) TestXDSResolverRequestHash(t *testing.T) { + oldRH := env.RingHashSupport + env.RingHashSupport = true + defer func() { env.RingHashSupport = oldRH }() + + xdsC := fakeclient.NewClient() + xdsR, tcc, cancel := testSetup(t, setupOpts{ + xdsClientFunc: func() (xdsclient.XDSClient, error) { return xdsC, nil }, + }) + defer xdsR.Close() + defer cancel() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + waitForWatchListener(ctx, t, xdsC, targetStr) + xdsC.InvokeWatchListenerCallback(xdsclient.ListenerUpdate{RouteConfigName: routeStr, HTTPFilters: routerFilterList}, nil) + waitForWatchRouteConfig(ctx, t, xdsC, routeStr) + // Invoke watchAPI callback with a good service update (with hash policies + // specified) and wait for UpdateState method to be called on ClientConn. + xdsC.InvokeWatchRouteConfigCallback("", xdsclient.RouteConfigUpdate{ + VirtualHosts: []*xdsclient.VirtualHost{ + { + Domains: []string{targetStr}, + Routes: []*xdsclient.Route{{ + Prefix: newStringP(""), + WeightedClusters: map[string]xdsclient.WeightedCluster{ + "cluster_1": {Weight: 75}, + "cluster_2": {Weight: 25}, + }, + HashPolicies: []*xdsclient.HashPolicy{{ + HashPolicyType: xdsclient.HashPolicyTypeHeader, + HeaderName: ":path", + }}, + }}, + }, + }, + }, nil) + + ctx, cancel = context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + gotState, err := tcc.stateCh.Receive(ctx) + if err != nil { + t.Fatalf("Error waiting for UpdateState to be called: %v", err) + } + rState := gotState.(resolver.State) + cs := iresolver.GetConfigSelector(rState) + if cs == nil { + t.Error("received nil config selector") + } + // Selecting a config when there was a hash policy specified in the route + // that will be selected should put a request hash in the config's context. + res, err := cs.SelectConfig(iresolver.RPCInfo{Context: metadata.NewOutgoingContext(context.Background(), metadata.Pairs(":path", "/products"))}) + if err != nil { + t.Fatalf("Unexpected error from cs.SelectConfig(_): %v", err) + } + requestHashGot := ringhash.GetRequestHashForTesting(res.Context) + requestHashWant := xxhash.Sum64String("/products") + if requestHashGot != requestHashWant { + t.Fatalf("requestHashGot = %v, requestHashWant = %v", requestHashGot, requestHashWant) + } +} + // TestXDSResolverRemovedWithRPCs tests the case where a config selector sends // an empty update to the resolver after the resource is removed. func (s) TestXDSResolverRemovedWithRPCs(t *testing.T) { xdsC := fakeclient.NewClient() xdsR, tcc, cancel := testSetup(t, setupOpts{ - xdsClientFunc: func() (xdsClientInterface, error) { return xdsC, nil }, + xdsClientFunc: func() (xdsclient.XDSClient, error) { return xdsC, nil }, }) defer cancel() defer xdsR.Close() @@ -461,18 +539,18 @@ func (s) TestXDSResolverRemovedWithRPCs(t *testing.T) { // Invoke the watchAPI callback with a good service update and wait for the // UpdateState method to be called on the ClientConn. - xdsC.InvokeWatchRouteConfigCallback(xdsclient.RouteConfigUpdate{ + xdsC.InvokeWatchRouteConfigCallback("", xdsclient.RouteConfigUpdate{ VirtualHosts: []*xdsclient.VirtualHost{ { Domains: []string{targetStr}, - Routes: []*client.Route{{Prefix: newStringP(""), WeightedClusters: map[string]xdsclient.WeightedCluster{"test-cluster-1": {Weight: 1}}}}, + Routes: []*xdsclient.Route{{Prefix: newStringP(""), WeightedClusters: map[string]xdsclient.WeightedCluster{"test-cluster-1": {Weight: 1}}}}, }, }, }, nil) gotState, err := tcc.stateCh.Receive(ctx) if err != nil { - t.Fatalf("ClientConn.UpdateState returned error: %v", err) + t.Fatalf("Error waiting for UpdateState to be called: %v", err) } rState := gotState.(resolver.State) if err := rState.ServiceConfig.Err; err != nil { @@ -492,10 +570,10 @@ func (s) TestXDSResolverRemovedWithRPCs(t *testing.T) { // Delete the resource suErr := xdsclient.NewErrorf(xdsclient.ErrorTypeResourceNotFound, "resource removed error") - xdsC.InvokeWatchRouteConfigCallback(xdsclient.RouteConfigUpdate{}, suErr) + xdsC.InvokeWatchRouteConfigCallback("", xdsclient.RouteConfigUpdate{}, suErr) if _, err = tcc.stateCh.Receive(ctx); err != nil { - t.Fatalf("ClientConn.UpdateState returned error: %v", err) + t.Fatalf("Error waiting for UpdateState to be called: %v", err) } // "Finish the RPC"; this could cause a panic if the resolver doesn't @@ -508,7 +586,7 @@ func (s) TestXDSResolverRemovedWithRPCs(t *testing.T) { func (s) TestXDSResolverRemovedResource(t *testing.T) { xdsC := fakeclient.NewClient() xdsR, tcc, cancel := testSetup(t, setupOpts{ - xdsClientFunc: func() (xdsClientInterface, error) { return xdsC, nil }, + xdsClientFunc: func() (xdsclient.XDSClient, error) { return xdsC, nil }, }) defer cancel() defer xdsR.Close() @@ -521,11 +599,11 @@ func (s) TestXDSResolverRemovedResource(t *testing.T) { // Invoke the watchAPI callback with a good service update and wait for the // UpdateState method to be called on the ClientConn. - xdsC.InvokeWatchRouteConfigCallback(xdsclient.RouteConfigUpdate{ + xdsC.InvokeWatchRouteConfigCallback("", xdsclient.RouteConfigUpdate{ VirtualHosts: []*xdsclient.VirtualHost{ { Domains: []string{targetStr}, - Routes: []*client.Route{{Prefix: newStringP(""), WeightedClusters: map[string]xdsclient.WeightedCluster{"test-cluster-1": {Weight: 1}}}}, + Routes: []*xdsclient.Route{{Prefix: newStringP(""), WeightedClusters: map[string]xdsclient.WeightedCluster{"test-cluster-1": {Weight: 1}}}}, }, }, }, nil) @@ -541,7 +619,7 @@ func (s) TestXDSResolverRemovedResource(t *testing.T) { gotState, err := tcc.stateCh.Receive(ctx) if err != nil { - t.Fatalf("ClientConn.UpdateState returned error: %v", err) + t.Fatalf("Error waiting for UpdateState to be called: %v", err) } rState := gotState.(resolver.State) if err := rState.ServiceConfig.Err; err != nil { @@ -571,10 +649,10 @@ func (s) TestXDSResolverRemovedResource(t *testing.T) { // Delete the resource. The channel should receive a service config with the // original cluster but with an erroring config selector. suErr := xdsclient.NewErrorf(xdsclient.ErrorTypeResourceNotFound, "resource removed error") - xdsC.InvokeWatchRouteConfigCallback(xdsclient.RouteConfigUpdate{}, suErr) + xdsC.InvokeWatchRouteConfigCallback("", xdsclient.RouteConfigUpdate{}, suErr) if gotState, err = tcc.stateCh.Receive(ctx); err != nil { - t.Fatalf("ClientConn.UpdateState returned error: %v", err) + t.Fatalf("Error waiting for UpdateState to be called: %v", err) } rState = gotState.(resolver.State) if err := rState.ServiceConfig.Err; err != nil { @@ -599,7 +677,7 @@ func (s) TestXDSResolverRemovedResource(t *testing.T) { // In the meantime, an empty ServiceConfig update should have been sent. if gotState, err = tcc.stateCh.Receive(ctx); err != nil { - t.Fatalf("ClientConn.UpdateState returned error: %v", err) + t.Fatalf("Error waiting for UpdateState to be called: %v", err) } rState = gotState.(resolver.State) if err := rState.ServiceConfig.Err; err != nil { @@ -616,12 +694,10 @@ func (s) TestXDSResolverRemovedResource(t *testing.T) { func (s) TestXDSResolverWRR(t *testing.T) { xdsC := fakeclient.NewClient() xdsR, tcc, cancel := testSetup(t, setupOpts{ - xdsClientFunc: func() (xdsClientInterface, error) { return xdsC, nil }, + xdsClientFunc: func() (xdsclient.XDSClient, error) { return xdsC, nil }, }) - defer func() { - cancel() - xdsR.Close() - }() + defer xdsR.Close() + defer cancel() ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() @@ -634,11 +710,11 @@ func (s) TestXDSResolverWRR(t *testing.T) { // Invoke the watchAPI callback with a good service update and wait for the // UpdateState method to be called on the ClientConn. - xdsC.InvokeWatchRouteConfigCallback(xdsclient.RouteConfigUpdate{ + xdsC.InvokeWatchRouteConfigCallback("", xdsclient.RouteConfigUpdate{ VirtualHosts: []*xdsclient.VirtualHost{ { Domains: []string{targetStr}, - Routes: []*client.Route{{Prefix: newStringP(""), WeightedClusters: map[string]xdsclient.WeightedCluster{ + Routes: []*xdsclient.Route{{Prefix: newStringP(""), WeightedClusters: map[string]xdsclient.WeightedCluster{ "A": {Weight: 5}, "B": {Weight: 10}, }}}, @@ -648,7 +724,7 @@ func (s) TestXDSResolverWRR(t *testing.T) { gotState, err := tcc.stateCh.Receive(ctx) if err != nil { - t.Fatalf("ClientConn.UpdateState returned error: %v", err) + t.Fatalf("Error waiting for UpdateState to be called: %v", err) } rState := gotState.(resolver.State) if err := rState.ServiceConfig.Err; err != nil { @@ -676,15 +752,12 @@ func (s) TestXDSResolverWRR(t *testing.T) { } func (s) TestXDSResolverMaxStreamDuration(t *testing.T) { - defer func(old bool) { env.TimeoutSupport = old }(env.TimeoutSupport) xdsC := fakeclient.NewClient() xdsR, tcc, cancel := testSetup(t, setupOpts{ - xdsClientFunc: func() (xdsClientInterface, error) { return xdsC, nil }, + xdsClientFunc: func() (xdsclient.XDSClient, error) { return xdsC, nil }, }) - defer func() { - cancel() - xdsR.Close() - }() + defer xdsR.Close() + defer cancel() ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() @@ -697,11 +770,11 @@ func (s) TestXDSResolverMaxStreamDuration(t *testing.T) { // Invoke the watchAPI callback with a good service update and wait for the // UpdateState method to be called on the ClientConn. - xdsC.InvokeWatchRouteConfigCallback(xdsclient.RouteConfigUpdate{ + xdsC.InvokeWatchRouteConfigCallback("", xdsclient.RouteConfigUpdate{ VirtualHosts: []*xdsclient.VirtualHost{ { Domains: []string{targetStr}, - Routes: []*client.Route{{ + Routes: []*xdsclient.Route{{ Prefix: newStringP("/foo"), WeightedClusters: map[string]xdsclient.WeightedCluster{"A": {Weight: 1}}, MaxStreamDuration: newDurationP(5 * time.Second), @@ -719,7 +792,7 @@ func (s) TestXDSResolverMaxStreamDuration(t *testing.T) { gotState, err := tcc.stateCh.Receive(ctx) if err != nil { - t.Fatalf("ClientConn.UpdateState returned error: %v", err) + t.Fatalf("Error waiting for UpdateState to be called: %v", err) } rState := gotState.(resolver.State) if err := rState.ServiceConfig.Err; err != nil { @@ -732,35 +805,25 @@ func (s) TestXDSResolverMaxStreamDuration(t *testing.T) { } testCases := []struct { - name string - method string - timeoutSupport bool - want *time.Duration + name string + method string + want *time.Duration }{{ - name: "RDS setting", - method: "/foo/method", - timeoutSupport: true, - want: newDurationP(5 * time.Second), - }, { - name: "timeout support disabled", - method: "/foo/method", - timeoutSupport: false, - want: nil, + name: "RDS setting", + method: "/foo/method", + want: newDurationP(5 * time.Second), }, { - name: "explicit zero in RDS; ignore LDS", - method: "/bar/method", - timeoutSupport: true, - want: nil, + name: "explicit zero in RDS; ignore LDS", + method: "/bar/method", + want: nil, }, { - name: "no config in RDS; fallback to LDS", - method: "/baz/method", - timeoutSupport: true, - want: newDurationP(time.Second), + name: "no config in RDS; fallback to LDS", + method: "/baz/method", + want: newDurationP(time.Second), }} for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - env.TimeoutSupport = tc.timeoutSupport req := iresolver.RPCInfo{ Method: tc.method, Context: context.Background(), @@ -784,12 +847,10 @@ func (s) TestXDSResolverMaxStreamDuration(t *testing.T) { func (s) TestXDSResolverDelayedOnCommitted(t *testing.T) { xdsC := fakeclient.NewClient() xdsR, tcc, cancel := testSetup(t, setupOpts{ - xdsClientFunc: func() (xdsClientInterface, error) { return xdsC, nil }, + xdsClientFunc: func() (xdsclient.XDSClient, error) { return xdsC, nil }, }) - defer func() { - cancel() - xdsR.Close() - }() + defer xdsR.Close() + defer cancel() ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() @@ -799,18 +860,18 @@ func (s) TestXDSResolverDelayedOnCommitted(t *testing.T) { // Invoke the watchAPI callback with a good service update and wait for the // UpdateState method to be called on the ClientConn. - xdsC.InvokeWatchRouteConfigCallback(xdsclient.RouteConfigUpdate{ + xdsC.InvokeWatchRouteConfigCallback("", xdsclient.RouteConfigUpdate{ VirtualHosts: []*xdsclient.VirtualHost{ { Domains: []string{targetStr}, - Routes: []*client.Route{{Prefix: newStringP(""), WeightedClusters: map[string]xdsclient.WeightedCluster{"test-cluster-1": {Weight: 1}}}}, + Routes: []*xdsclient.Route{{Prefix: newStringP(""), WeightedClusters: map[string]xdsclient.WeightedCluster{"test-cluster-1": {Weight: 1}}}}, }, }, }, nil) gotState, err := tcc.stateCh.Receive(ctx) if err != nil { - t.Fatalf("ClientConn.UpdateState returned error: %v", err) + t.Fatalf("Error waiting for UpdateState to be called: %v", err) } rState := gotState.(resolver.State) if err := rState.ServiceConfig.Err; err != nil { @@ -849,27 +910,28 @@ func (s) TestXDSResolverDelayedOnCommitted(t *testing.T) { // Perform TWO updates to ensure the old config selector does not hold a // reference to test-cluster-1. - xdsC.InvokeWatchRouteConfigCallback(xdsclient.RouteConfigUpdate{ + xdsC.InvokeWatchRouteConfigCallback("", xdsclient.RouteConfigUpdate{ VirtualHosts: []*xdsclient.VirtualHost{ { Domains: []string{targetStr}, - Routes: []*client.Route{{Prefix: newStringP(""), WeightedClusters: map[string]xdsclient.WeightedCluster{"NEW": {Weight: 1}}}}, + Routes: []*xdsclient.Route{{Prefix: newStringP(""), WeightedClusters: map[string]xdsclient.WeightedCluster{"NEW": {Weight: 1}}}}, }, }, }, nil) - xdsC.InvokeWatchRouteConfigCallback(xdsclient.RouteConfigUpdate{ + tcc.stateCh.Receive(ctx) // Ignore the first update. + + xdsC.InvokeWatchRouteConfigCallback("", xdsclient.RouteConfigUpdate{ VirtualHosts: []*xdsclient.VirtualHost{ { Domains: []string{targetStr}, - Routes: []*client.Route{{Prefix: newStringP(""), WeightedClusters: map[string]xdsclient.WeightedCluster{"NEW": {Weight: 1}}}}, + Routes: []*xdsclient.Route{{Prefix: newStringP(""), WeightedClusters: map[string]xdsclient.WeightedCluster{"NEW": {Weight: 1}}}}, }, }, }, nil) - tcc.stateCh.Receive(ctx) // Ignore the first update gotState, err = tcc.stateCh.Receive(ctx) if err != nil { - t.Fatalf("ClientConn.UpdateState returned error: %v", err) + t.Fatalf("Error waiting for UpdateState to be called: %v", err) } rState = gotState.(resolver.State) if err := rState.ServiceConfig.Err; err != nil { @@ -897,17 +959,17 @@ func (s) TestXDSResolverDelayedOnCommitted(t *testing.T) { // test-cluster-1. res.OnCommitted() - xdsC.InvokeWatchRouteConfigCallback(xdsclient.RouteConfigUpdate{ + xdsC.InvokeWatchRouteConfigCallback("", xdsclient.RouteConfigUpdate{ VirtualHosts: []*xdsclient.VirtualHost{ { Domains: []string{targetStr}, - Routes: []*client.Route{{Prefix: newStringP(""), WeightedClusters: map[string]xdsclient.WeightedCluster{"NEW": {Weight: 1}}}}, + Routes: []*xdsclient.Route{{Prefix: newStringP(""), WeightedClusters: map[string]xdsclient.WeightedCluster{"NEW": {Weight: 1}}}}, }, }, }, nil) gotState, err = tcc.stateCh.Receive(ctx) if err != nil { - t.Fatalf("ClientConn.UpdateState returned error: %v", err) + t.Fatalf("Error waiting for UpdateState to be called: %v", err) } rState = gotState.(resolver.State) if err := rState.ServiceConfig.Err; err != nil { @@ -934,12 +996,10 @@ func (s) TestXDSResolverDelayedOnCommitted(t *testing.T) { func (s) TestXDSResolverGoodUpdateAfterError(t *testing.T) { xdsC := fakeclient.NewClient() xdsR, tcc, cancel := testSetup(t, setupOpts{ - xdsClientFunc: func() (xdsClientInterface, error) { return xdsC, nil }, + xdsClientFunc: func() (xdsclient.XDSClient, error) { return xdsC, nil }, }) - defer func() { - cancel() - xdsR.Close() - }() + defer xdsR.Close() + defer cancel() ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() @@ -950,7 +1010,7 @@ func (s) TestXDSResolverGoodUpdateAfterError(t *testing.T) { // Invoke the watchAPI callback with a bad service update and wait for the // ReportError method to be called on the ClientConn. suErr := errors.New("bad serviceupdate") - xdsC.InvokeWatchRouteConfigCallback(xdsclient.RouteConfigUpdate{}, suErr) + xdsC.InvokeWatchRouteConfigCallback("", xdsclient.RouteConfigUpdate{}, suErr) if gotErrVal, gotErr := tcc.errorCh.Receive(ctx); gotErr != nil || gotErrVal != suErr { t.Fatalf("ClientConn.ReportError() received %v, want %v", gotErrVal, suErr) @@ -958,17 +1018,17 @@ func (s) TestXDSResolverGoodUpdateAfterError(t *testing.T) { // Invoke the watchAPI callback with a good service update and wait for the // UpdateState method to be called on the ClientConn. - xdsC.InvokeWatchRouteConfigCallback(xdsclient.RouteConfigUpdate{ + xdsC.InvokeWatchRouteConfigCallback("", xdsclient.RouteConfigUpdate{ VirtualHosts: []*xdsclient.VirtualHost{ { Domains: []string{targetStr}, - Routes: []*client.Route{{Prefix: newStringP(""), WeightedClusters: map[string]xdsclient.WeightedCluster{cluster: {Weight: 1}}}}, + Routes: []*xdsclient.Route{{Prefix: newStringP(""), WeightedClusters: map[string]xdsclient.WeightedCluster{cluster: {Weight: 1}}}}, }, }, }, nil) gotState, err := tcc.stateCh.Receive(ctx) if err != nil { - t.Fatalf("ClientConn.UpdateState returned error: %v", err) + t.Fatalf("Error waiting for UpdateState to be called: %v", err) } rState := gotState.(resolver.State) if err := rState.ServiceConfig.Err; err != nil { @@ -978,7 +1038,7 @@ func (s) TestXDSResolverGoodUpdateAfterError(t *testing.T) { // Invoke the watchAPI callback with a bad service update and wait for the // ReportError method to be called on the ClientConn. suErr2 := errors.New("bad serviceupdate 2") - xdsC.InvokeWatchRouteConfigCallback(xdsclient.RouteConfigUpdate{}, suErr2) + xdsC.InvokeWatchRouteConfigCallback("", xdsclient.RouteConfigUpdate{}, suErr2) if gotErrVal, gotErr := tcc.errorCh.Receive(ctx); gotErr != nil || gotErrVal != suErr2 { t.Fatalf("ClientConn.ReportError() received %v, want %v", gotErrVal, suErr2) } @@ -990,12 +1050,10 @@ func (s) TestXDSResolverGoodUpdateAfterError(t *testing.T) { func (s) TestXDSResolverResourceNotFoundError(t *testing.T) { xdsC := fakeclient.NewClient() xdsR, tcc, cancel := testSetup(t, setupOpts{ - xdsClientFunc: func() (xdsClientInterface, error) { return xdsC, nil }, + xdsClientFunc: func() (xdsclient.XDSClient, error) { return xdsC, nil }, }) - defer func() { - cancel() - xdsR.Close() - }() + defer xdsR.Close() + defer cancel() ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() @@ -1006,7 +1064,7 @@ func (s) TestXDSResolverResourceNotFoundError(t *testing.T) { // Invoke the watchAPI callback with a bad service update and wait for the // ReportError method to be called on the ClientConn. suErr := xdsclient.NewErrorf(xdsclient.ErrorTypeResourceNotFound, "resource removed error") - xdsC.InvokeWatchRouteConfigCallback(xdsclient.RouteConfigUpdate{}, suErr) + xdsC.InvokeWatchRouteConfigCallback("", xdsclient.RouteConfigUpdate{}, suErr) if gotErrVal, gotErr := tcc.errorCh.Receive(ctx); gotErr != context.DeadlineExceeded { t.Fatalf("ClientConn.ReportError() received %v, %v, want channel recv timeout", gotErrVal, gotErr) @@ -1016,7 +1074,7 @@ func (s) TestXDSResolverResourceNotFoundError(t *testing.T) { defer cancel() gotState, err := tcc.stateCh.Receive(ctx) if err != nil { - t.Fatalf("ClientConn.UpdateState returned error: %v", err) + t.Fatalf("Error waiting for UpdateState to be called: %v", err) } rState := gotState.(resolver.State) wantParsedConfig := internal.ParseServiceConfigForTesting.(func(string) *serviceconfig.ParseResult)("{}") @@ -1030,6 +1088,46 @@ func (s) TestXDSResolverResourceNotFoundError(t *testing.T) { } } +// TestXDSResolverMultipleLDSUpdates tests the case where two LDS updates with +// the same RDS name to watch are received without an RDS in between. Those LDS +// updates shouldn't trigger service config update. +// +// This test case also makes sure the resolver doesn't panic. +func (s) TestXDSResolverMultipleLDSUpdates(t *testing.T) { + xdsC := fakeclient.NewClient() + xdsR, tcc, cancel := testSetup(t, setupOpts{ + xdsClientFunc: func() (xdsclient.XDSClient, error) { return xdsC, nil }, + }) + defer xdsR.Close() + defer cancel() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + waitForWatchListener(ctx, t, xdsC, targetStr) + xdsC.InvokeWatchListenerCallback(xdsclient.ListenerUpdate{RouteConfigName: routeStr, HTTPFilters: routerFilterList}, nil) + waitForWatchRouteConfig(ctx, t, xdsC, routeStr) + defer replaceRandNumGenerator(0)() + + // Send a new LDS update, with the same fields. + xdsC.InvokeWatchListenerCallback(xdsclient.ListenerUpdate{RouteConfigName: routeStr, HTTPFilters: routerFilterList}, nil) + ctx, cancel = context.WithTimeout(context.Background(), defaultTestShortTimeout) + defer cancel() + // Should NOT trigger a state update. + gotState, err := tcc.stateCh.Receive(ctx) + if err == nil { + t.Fatalf("ClientConn.UpdateState received %v, want timeout error", gotState) + } + + // Send a new LDS update, with the same RDS name, but different fields. + xdsC.InvokeWatchListenerCallback(xdsclient.ListenerUpdate{RouteConfigName: routeStr, MaxStreamDuration: time.Second, HTTPFilters: routerFilterList}, nil) + ctx, cancel = context.WithTimeout(context.Background(), defaultTestShortTimeout) + defer cancel() + gotState, err = tcc.stateCh.Receive(ctx) + if err == nil { + t.Fatalf("ClientConn.UpdateState received %v, want timeout error", gotState) + } +} + type filterBuilder struct { httpfilter.Filter // embedded as we do not need to implement registry / parsing in this test. path *[]string @@ -1173,12 +1271,10 @@ func (s) TestXDSResolverHTTPFilters(t *testing.T) { t.Run(tc.name, func(t *testing.T) { xdsC := fakeclient.NewClient() xdsR, tcc, cancel := testSetup(t, setupOpts{ - xdsClientFunc: func() (xdsClientInterface, error) { return xdsC, nil }, + xdsClientFunc: func() (xdsclient.XDSClient, error) { return xdsC, nil }, }) - defer func() { - cancel() - xdsR.Close() - }() + defer xdsR.Close() + defer cancel() ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() @@ -1197,11 +1293,11 @@ func (s) TestXDSResolverHTTPFilters(t *testing.T) { // Invoke the watchAPI callback with a good service update and wait for the // UpdateState method to be called on the ClientConn. - xdsC.InvokeWatchRouteConfigCallback(xdsclient.RouteConfigUpdate{ + xdsC.InvokeWatchRouteConfigCallback("", xdsclient.RouteConfigUpdate{ VirtualHosts: []*xdsclient.VirtualHost{ { Domains: []string{targetStr}, - Routes: []*client.Route{{ + Routes: []*xdsclient.Route{{ Prefix: newStringP("1"), WeightedClusters: map[string]xdsclient.WeightedCluster{ "A": {Weight: 1}, "B": {Weight: 1}, @@ -1220,7 +1316,7 @@ func (s) TestXDSResolverHTTPFilters(t *testing.T) { gotState, err := tcc.stateCh.Receive(ctx) if err != nil { - t.Fatalf("ClientConn.UpdateState returned error: %v", err) + t.Fatalf("Error waiting for UpdateState to be called: %v", err) } rState := gotState.(resolver.State) if err := rState.ServiceConfig.Err; err != nil { @@ -1295,13 +1391,13 @@ func (s) TestXDSResolverHTTPFilters(t *testing.T) { func replaceRandNumGenerator(start int64) func() { nextInt := start - grpcrandInt63n = func(int64) (ret int64) { + xdsclient.RandInt63n = func(int64) (ret int64) { ret = nextInt nextInt++ return } return func() { - grpcrandInt63n = grpcrand.Int63n + xdsclient.RandInt63n = grpcrand.Int63n } } diff --git a/xds/internal/server/conn_wrapper.go b/xds/internal/server/conn_wrapper.go new file mode 100644 index 00000000000..dd0374dc88e --- /dev/null +++ b/xds/internal/server/conn_wrapper.go @@ -0,0 +1,165 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package server + +import ( + "errors" + "fmt" + "net" + "sync" + "time" + + "google.golang.org/grpc/credentials/tls/certprovider" + xdsinternal "google.golang.org/grpc/internal/credentials/xds" + "google.golang.org/grpc/xds/internal/xdsclient" +) + +// connWrapper is a thin wrapper around a net.Conn returned by Accept(). It +// provides the following additional functionality: +// 1. A way to retrieve the configured deadline. This is required by the +// ServerHandshake() method of the xdsCredentials when it attempts to read +// key material from the certificate providers. +// 2. Implements the XDSHandshakeInfo() method used by the xdsCredentials to +// retrieve the configured certificate providers. +// 3. xDS filter_chain matching logic to select appropriate security +// configuration for the incoming connection. +type connWrapper struct { + net.Conn + + // The specific filter chain picked for handling this connection. + filterChain *xdsclient.FilterChain + + // A reference fo the listenerWrapper on which this connection was accepted. + parent *listenerWrapper + + // The certificate providers created for this connection. + rootProvider, identityProvider certprovider.Provider + + // The connection deadline as configured by the grpc.Server on the rawConn + // that is returned by a call to Accept(). This is set to the connection + // timeout value configured by the user (or to a default value) before + // initiating the transport credential handshake, and set to zero after + // completing the HTTP2 handshake. + deadlineMu sync.Mutex + deadline time.Time + + // The virtual hosts with matchable routes and instantiated HTTP Filters per + // route. + virtualHosts []xdsclient.VirtualHostWithInterceptors +} + +// VirtualHosts returns the virtual hosts to be used for server side routing. +func (c *connWrapper) VirtualHosts() []xdsclient.VirtualHostWithInterceptors { + return c.virtualHosts +} + +// SetDeadline makes a copy of the passed in deadline and forwards the call to +// the underlying rawConn. +func (c *connWrapper) SetDeadline(t time.Time) error { + c.deadlineMu.Lock() + c.deadline = t + c.deadlineMu.Unlock() + return c.Conn.SetDeadline(t) +} + +// GetDeadline returns the configured deadline. This will be invoked by the +// ServerHandshake() method of the XdsCredentials, which needs a deadline to +// pass to the certificate provider. +func (c *connWrapper) GetDeadline() time.Time { + c.deadlineMu.Lock() + t := c.deadline + c.deadlineMu.Unlock() + return t +} + +// XDSHandshakeInfo returns a HandshakeInfo with appropriate security +// configuration for this connection. This method is invoked by the +// ServerHandshake() method of the XdsCredentials. +func (c *connWrapper) XDSHandshakeInfo() (*xdsinternal.HandshakeInfo, error) { + // Ideally this should never happen, since xdsCredentials are the only ones + // which will invoke this method at handshake time. But to be on the safe + // side, we avoid acting on the security configuration received from the + // control plane when the user has not configured the use of xDS + // credentials, by checking the value of this flag. + if !c.parent.xdsCredsInUse { + return nil, errors.New("user has not configured xDS credentials") + } + + if c.filterChain.SecurityCfg == nil { + // If the security config is empty, this means that the control plane + // did not provide any security configuration and therefore we should + // return an empty HandshakeInfo here so that the xdsCreds can use the + // configured fallback credentials. + return xdsinternal.NewHandshakeInfo(nil, nil), nil + } + + cpc := c.parent.xdsC.BootstrapConfig().CertProviderConfigs + // Identity provider name is mandatory on the server-side, and this is + // enforced when the resource is received at the XDSClient layer. + secCfg := c.filterChain.SecurityCfg + ip, err := buildProviderFunc(cpc, secCfg.IdentityInstanceName, secCfg.IdentityCertName, true, false) + if err != nil { + return nil, err + } + // Root provider name is optional and required only when doing mTLS. + var rp certprovider.Provider + if instance, cert := secCfg.RootInstanceName, secCfg.RootCertName; instance != "" { + rp, err = buildProviderFunc(cpc, instance, cert, false, true) + if err != nil { + return nil, err + } + } + c.identityProvider = ip + c.rootProvider = rp + + xdsHI := xdsinternal.NewHandshakeInfo(c.rootProvider, c.identityProvider) + xdsHI.SetRequireClientCert(secCfg.RequireClientCert) + return xdsHI, nil +} + +// Close closes the providers and the underlying connection. +func (c *connWrapper) Close() error { + if c.identityProvider != nil { + c.identityProvider.Close() + } + if c.rootProvider != nil { + c.rootProvider.Close() + } + return c.Conn.Close() +} + +func buildProviderFunc(configs map[string]*certprovider.BuildableConfig, instanceName, certName string, wantIdentity, wantRoot bool) (certprovider.Provider, error) { + cfg, ok := configs[instanceName] + if !ok { + return nil, fmt.Errorf("certificate provider instance %q not found in bootstrap file", instanceName) + } + provider, err := cfg.Build(certprovider.BuildOptions{ + CertName: certName, + WantIdentity: wantIdentity, + WantRoot: wantRoot, + }) + if err != nil { + // This error is not expected since the bootstrap process parses the + // config and makes sure that it is acceptable to the plugin. Still, it + // is possible that the plugin parses the config successfully, but its + // Build() method errors out. + return nil, fmt.Errorf("failed to get security plugin instance (%+v): %v", cfg, err) + } + return provider, nil +} diff --git a/xds/internal/server/listener_wrapper.go b/xds/internal/server/listener_wrapper.go new file mode 100644 index 00000000000..99c9a753230 --- /dev/null +++ b/xds/internal/server/listener_wrapper.go @@ -0,0 +1,442 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// Package server contains internal server-side functionality used by the public +// facing xds package. +package server + +import ( + "errors" + "fmt" + "net" + "sync" + "sync/atomic" + "time" + "unsafe" + + "google.golang.org/grpc/backoff" + "google.golang.org/grpc/connectivity" + "google.golang.org/grpc/grpclog" + internalbackoff "google.golang.org/grpc/internal/backoff" + internalgrpclog "google.golang.org/grpc/internal/grpclog" + "google.golang.org/grpc/internal/grpcsync" + "google.golang.org/grpc/internal/xds/env" + "google.golang.org/grpc/xds/internal/xdsclient" + "google.golang.org/grpc/xds/internal/xdsclient/bootstrap" +) + +var ( + logger = grpclog.Component("xds") + + // Backoff strategy for temporary errors received from Accept(). If this + // needs to be configurable, we can inject it through ListenerWrapperParams. + bs = internalbackoff.Exponential{Config: backoff.Config{ + BaseDelay: 5 * time.Millisecond, + Multiplier: 2.0, + MaxDelay: 1 * time.Second, + }} + backoffFunc = bs.Backoff +) + +// ServingModeCallback is the callback that users can register to get notified +// about the server's serving mode changes. The callback is invoked with the +// address of the listener and its new mode. The err parameter is set to a +// non-nil error if the server has transitioned into not-serving mode. +type ServingModeCallback func(addr net.Addr, mode connectivity.ServingMode, err error) + +// DrainCallback is the callback that an xDS-enabled server registers to get +// notified about updates to the Listener configuration. The server is expected +// to gracefully shutdown existing connections, thereby forcing clients to +// reconnect and have the new configuration applied to the newly created +// connections. +type DrainCallback func(addr net.Addr) + +func prefixLogger(p *listenerWrapper) *internalgrpclog.PrefixLogger { + return internalgrpclog.NewPrefixLogger(logger, fmt.Sprintf("[xds-server-listener %p] ", p)) +} + +// XDSClient wraps the methods on the XDSClient which are required by +// the listenerWrapper. +type XDSClient interface { + WatchListener(string, func(xdsclient.ListenerUpdate, error)) func() + WatchRouteConfig(string, func(xdsclient.RouteConfigUpdate, error)) func() + BootstrapConfig() *bootstrap.Config +} + +// ListenerWrapperParams wraps parameters required to create a listenerWrapper. +type ListenerWrapperParams struct { + // Listener is the net.Listener passed by the user that is to be wrapped. + Listener net.Listener + // ListenerResourceName is the xDS Listener resource to request. + ListenerResourceName string + // XDSCredsInUse specifies whether or not the user expressed interest to + // receive security configuration from the control plane. + XDSCredsInUse bool + // XDSClient provides the functionality from the XDSClient required here. + XDSClient XDSClient + // ModeCallback is the callback to invoke when the serving mode changes. + ModeCallback ServingModeCallback + // DrainCallback is the callback to invoke when the Listener gets a LDS + // update. + DrainCallback DrainCallback +} + +// NewListenerWrapper creates a new listenerWrapper with params. It returns a +// net.Listener and a channel which is written to, indicating that the former is +// ready to be passed to grpc.Serve(). +// +// Only TCP listeners are supported. +func NewListenerWrapper(params ListenerWrapperParams) (net.Listener, <-chan struct{}) { + lw := &listenerWrapper{ + Listener: params.Listener, + name: params.ListenerResourceName, + xdsCredsInUse: params.XDSCredsInUse, + xdsC: params.XDSClient, + modeCallback: params.ModeCallback, + drainCallback: params.DrainCallback, + isUnspecifiedAddr: params.Listener.Addr().(*net.TCPAddr).IP.IsUnspecified(), + + closed: grpcsync.NewEvent(), + goodUpdate: grpcsync.NewEvent(), + ldsUpdateCh: make(chan ldsUpdateWithError, 1), + rdsUpdateCh: make(chan rdsHandlerUpdate, 1), + } + lw.logger = prefixLogger(lw) + + // Serve() verifies that Addr() returns a valid TCPAddr. So, it is safe to + // ignore the error from SplitHostPort(). + lisAddr := lw.Listener.Addr().String() + lw.addr, lw.port, _ = net.SplitHostPort(lisAddr) + + lw.rdsHandler = newRDSHandler(lw.xdsC, lw.rdsUpdateCh) + + cancelWatch := lw.xdsC.WatchListener(lw.name, lw.handleListenerUpdate) + lw.logger.Infof("Watch started on resource name %v", lw.name) + lw.cancelWatch = func() { + cancelWatch() + lw.logger.Infof("Watch cancelled on resource name %v", lw.name) + } + go lw.run() + return lw, lw.goodUpdate.Done() +} + +type ldsUpdateWithError struct { + update xdsclient.ListenerUpdate + err error +} + +// listenerWrapper wraps the net.Listener associated with the listening address +// passed to Serve(). It also contains all other state associated with this +// particular invocation of Serve(). +type listenerWrapper struct { + net.Listener + logger *internalgrpclog.PrefixLogger + + name string + xdsCredsInUse bool + xdsC XDSClient + cancelWatch func() + modeCallback ServingModeCallback + drainCallback DrainCallback + + // Set to true if the listener is bound to the IP_ANY address (which is + // "0.0.0.0" for IPv4 and "::" for IPv6). + isUnspecifiedAddr bool + // Listening address and port. Used to validate the socket address in the + // Listener resource received from the control plane. + addr, port string + + // This is used to notify that a good update has been received and that + // Serve() can be invoked on the underlying gRPC server. Using an event + // instead of a vanilla channel simplifies the update handler as it need not + // keep track of whether the received update is the first one or not. + goodUpdate *grpcsync.Event + // A small race exists in the XDSClient code between the receipt of an xDS + // response and the user cancelling the associated watch. In this window, + // the registered callback may be invoked after the watch is canceled, and + // the user is expected to work around this. This event signifies that the + // listener is closed (and hence the watch is cancelled), and we drop any + // updates received in the callback if this event has fired. + closed *grpcsync.Event + + // mu guards access to the current serving mode and the filter chains. The + // reason for using an rw lock here is that these fields are read in + // Accept() for all incoming connections, but writes happen rarely (when we + // get a Listener resource update). + mu sync.RWMutex + // Current serving mode. + mode connectivity.ServingMode + // Filter chains received as part of the last good update. + filterChains *xdsclient.FilterChainManager + + // rdsHandler is used for any dynamic RDS resources specified in a LDS + // update. + rdsHandler *rdsHandler + // rdsUpdates are the RDS resources received from the management + // server, keyed on the RouteName of the RDS resource. + rdsUpdates unsafe.Pointer // map[string]xdsclient.RouteConfigUpdate + // ldsUpdateCh is a channel for XDSClient LDS updates. + ldsUpdateCh chan ldsUpdateWithError + // rdsUpdateCh is a channel for XDSClient RDS updates. + rdsUpdateCh chan rdsHandlerUpdate +} + +// Accept blocks on an Accept() on the underlying listener, and wraps the +// returned net.connWrapper with the configured certificate providers. +func (l *listenerWrapper) Accept() (net.Conn, error) { + var retries int + for { + conn, err := l.Listener.Accept() + if err != nil { + // Temporary() method is implemented by certain error types returned + // from the net package, and it is useful for us to not shutdown the + // server in these conditions. The listen queue being full is one + // such case. + if ne, ok := err.(interface{ Temporary() bool }); !ok || !ne.Temporary() { + return nil, err + } + retries++ + timer := time.NewTimer(backoffFunc(retries)) + select { + case <-timer.C: + case <-l.closed.Done(): + timer.Stop() + // Continuing here will cause us to call Accept() again + // which will return a non-temporary error. + continue + } + continue + } + // Reset retries after a successful Accept(). + retries = 0 + + // Since the net.Conn represents an incoming connection, the source and + // destination address can be retrieved from the local address and + // remote address of the net.Conn respectively. + destAddr, ok1 := conn.LocalAddr().(*net.TCPAddr) + srcAddr, ok2 := conn.RemoteAddr().(*net.TCPAddr) + if !ok1 || !ok2 { + // If the incoming connection is not a TCP connection, which is + // really unexpected since we check whether the provided listener is + // a TCP listener in Serve(), we return an error which would cause + // us to stop serving. + return nil, fmt.Errorf("received connection with non-TCP address (local: %T, remote %T)", conn.LocalAddr(), conn.RemoteAddr()) + } + + l.mu.RLock() + if l.mode == connectivity.ServingModeNotServing { + // Close connections as soon as we accept them when we are in + // "not-serving" mode. Since we accept a net.Listener from the user + // in Serve(), we cannot close the listener when we move to + // "not-serving". Closing the connection immediately upon accepting + // is one of the other ways to implement the "not-serving" mode as + // outlined in gRFC A36. + l.mu.RUnlock() + conn.Close() + continue + } + fc, err := l.filterChains.Lookup(xdsclient.FilterChainLookupParams{ + IsUnspecifiedListener: l.isUnspecifiedAddr, + DestAddr: destAddr.IP, + SourceAddr: srcAddr.IP, + SourcePort: srcAddr.Port, + }) + l.mu.RUnlock() + if err != nil { + // When a matching filter chain is not found, we close the + // connection right away, but do not return an error back to + // `grpc.Serve()` from where this Accept() was invoked. Returning an + // error to `grpc.Serve()` causes the server to shutdown. If we want + // to avoid the server from shutting down, we would need to return + // an error type which implements the `Temporary() bool` method, + // which is invoked by `grpc.Serve()` to see if the returned error + // represents a temporary condition. In the case of a temporary + // error, `grpc.Serve()` method sleeps for a small duration and + // therefore ends up blocking all connection attempts during that + // time frame, which is also not ideal for an error like this. + l.logger.Warningf("connection from %s to %s failed to find any matching filter chain", conn.RemoteAddr().String(), conn.LocalAddr().String()) + conn.Close() + continue + } + if !env.RBACSupport { + return &connWrapper{Conn: conn, filterChain: fc, parent: l}, nil + } + var rc xdsclient.RouteConfigUpdate + if fc.InlineRouteConfig != nil { + rc = *fc.InlineRouteConfig + } else { + rcPtr := atomic.LoadPointer(&l.rdsUpdates) + rcuPtr := (*map[string]xdsclient.RouteConfigUpdate)(rcPtr) + // This shouldn't happen, but this error protects against a panic. + if rcuPtr == nil { + return nil, errors.New("route configuration pointer is nil") + } + rcu := *rcuPtr + rc = rcu[fc.RouteConfigName] + } + // The filter chain will construct a usuable route table on each + // connection accept. This is done because preinstantiating every route + // table before it is needed for a connection would potentially lead to + // a lot of cpu time and memory allocated for route tables that will + // never be used. There was also a thought to cache this configuration, + // and reuse it for the next accepted connection. However, this would + // lead to a lot of code complexity (RDS Updates for a given route name + // can come it at any time), and connections aren't accepted too often, + // so this reinstantation of the Route Configuration is an acceptable + // tradeoff for simplicity. + vhswi, err := fc.ConstructUsableRouteConfiguration(rc) + if err != nil { + l.logger.Warningf("route configuration construction: %v", err) + conn.Close() + continue + } + return &connWrapper{Conn: conn, filterChain: fc, parent: l, virtualHosts: vhswi}, nil + } +} + +// Close closes the underlying listener. It also cancels the xDS watch +// registered in Serve() and closes any certificate provider instances created +// based on security configuration received in the LDS response. +func (l *listenerWrapper) Close() error { + l.closed.Fire() + l.Listener.Close() + if l.cancelWatch != nil { + l.cancelWatch() + } + l.rdsHandler.close() + return nil +} + +// run is a long running goroutine which handles all xds updates. LDS and RDS +// push updates onto a channel which is read and acted upon from this goroutine. +func (l *listenerWrapper) run() { + for { + select { + case <-l.closed.Done(): + return + case u := <-l.ldsUpdateCh: + l.handleLDSUpdate(u) + case u := <-l.rdsUpdateCh: + l.handleRDSUpdate(u) + } + } +} + +// handleLDSUpdate is the callback which handles LDS Updates. It writes the +// received update to the update channel, which is picked up by the run +// goroutine. +func (l *listenerWrapper) handleListenerUpdate(update xdsclient.ListenerUpdate, err error) { + if l.closed.HasFired() { + l.logger.Warningf("Resource %q received update: %v with error: %v, after listener was closed", l.name, update, err) + return + } + // Remove any existing entry in ldsUpdateCh and replace with the new one, as the only update + // listener cares about is most recent update. + select { + case <-l.ldsUpdateCh: + default: + } + l.ldsUpdateCh <- ldsUpdateWithError{update: update, err: err} +} + +// handleRDSUpdate handles a full rds update from rds handler. On a successful +// update, the server will switch to ServingModeServing as the full +// configuration (both LDS and RDS) has been received. +func (l *listenerWrapper) handleRDSUpdate(update rdsHandlerUpdate) { + if l.closed.HasFired() { + l.logger.Warningf("RDS received update: %v with error: %v, after listener was closed", update.updates, update.err) + return + } + if update.err != nil { + l.logger.Warningf("Received error for rds names specified in resource %q: %+v", l.name, update.err) + if xdsclient.ErrType(update.err) == xdsclient.ErrorTypeResourceNotFound { + l.switchMode(nil, connectivity.ServingModeNotServing, update.err) + } + // For errors which are anything other than "resource-not-found", we + // continue to use the old configuration. + return + } + atomic.StorePointer(&l.rdsUpdates, unsafe.Pointer(&update.updates)) + + l.switchMode(l.filterChains, connectivity.ServingModeServing, nil) + l.goodUpdate.Fire() +} + +func (l *listenerWrapper) handleLDSUpdate(update ldsUpdateWithError) { + if update.err != nil { + l.logger.Warningf("Received error for resource %q: %+v", l.name, update.err) + if xdsclient.ErrType(update.err) == xdsclient.ErrorTypeResourceNotFound { + l.switchMode(nil, connectivity.ServingModeNotServing, update.err) + } + // For errors which are anything other than "resource-not-found", we + // continue to use the old configuration. + return + } + l.logger.Infof("Received update for resource %q: %+v", l.name, update.update) + + // Make sure that the socket address on the received Listener resource + // matches the address of the net.Listener passed to us by the user. This + // check is done here instead of at the XDSClient layer because of the + // following couple of reasons: + // - XDSClient cannot know the listening address of every listener in the + // system, and hence cannot perform this check. + // - this is a very context-dependent check and only the server has the + // appropriate context to perform this check. + // + // What this means is that the XDSClient has ACKed a resource which can push + // the server into a "not serving" mode. This is not ideal, but this is + // what we have decided to do. See gRPC A36 for more details. + ilc := update.update.InboundListenerCfg + if ilc.Address != l.addr || ilc.Port != l.port { + l.switchMode(nil, connectivity.ServingModeNotServing, fmt.Errorf("address (%s:%s) in Listener update does not match listening address: (%s:%s)", ilc.Address, ilc.Port, l.addr, l.port)) + return + } + + // "Updates to a Listener cause all older connections on that Listener to be + // gracefully shut down with a grace period of 10 minutes for long-lived + // RPC's, such that clients will reconnect and have the updated + // configuration apply." - A36 Note that this is not the same as moving the + // Server's state to ServingModeNotServing. That prevents new connections + // from being accepted, whereas here we simply want the clients to reconnect + // to get the updated configuration. + if env.RBACSupport { + if l.drainCallback != nil { + l.drainCallback(l.Listener.Addr()) + } + } + l.rdsHandler.updateRouteNamesToWatch(ilc.FilterChains.RouteConfigNames) + // If there are no dynamic RDS Configurations still needed to be received + // from the management server, this listener has all the configuration + // needed, and is ready to serve. + if len(ilc.FilterChains.RouteConfigNames) == 0 { + l.switchMode(ilc.FilterChains, connectivity.ServingModeServing, nil) + l.goodUpdate.Fire() + } +} + +func (l *listenerWrapper) switchMode(fcs *xdsclient.FilterChainManager, newMode connectivity.ServingMode, err error) { + l.mu.Lock() + defer l.mu.Unlock() + + l.filterChains = fcs + l.mode = newMode + if l.modeCallback != nil { + l.modeCallback(l.Listener.Addr(), newMode, err) + } + l.logger.Warningf("Listener %q entering mode: %q due to error: %v", l.Addr(), newMode, err) +} diff --git a/xds/internal/server/listener_wrapper_test.go b/xds/internal/server/listener_wrapper_test.go new file mode 100644 index 00000000000..38372936366 --- /dev/null +++ b/xds/internal/server/listener_wrapper_test.go @@ -0,0 +1,484 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package server + +import ( + "context" + "errors" + "net" + "strconv" + "testing" + "time" + + v3corepb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" + v3listenerpb "github.com/envoyproxy/go-control-plane/envoy/config/listener/v3" + v3routepb "github.com/envoyproxy/go-control-plane/envoy/config/route/v3" + v3httppb "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/network/http_connection_manager/v3" + v3tlspb "github.com/envoyproxy/go-control-plane/envoy/extensions/transport_sockets/tls/v3" + wrapperspb "github.com/golang/protobuf/ptypes/wrappers" + "google.golang.org/grpc/internal/grpctest" + "google.golang.org/grpc/internal/testutils" + "google.golang.org/grpc/internal/xds/env" + _ "google.golang.org/grpc/xds/internal/httpfilter/router" + "google.golang.org/grpc/xds/internal/testutils/e2e" + "google.golang.org/grpc/xds/internal/testutils/fakeclient" + "google.golang.org/grpc/xds/internal/xdsclient" +) + +const ( + fakeListenerHost = "0.0.0.0" + fakeListenerPort = 50051 + testListenerResourceName = "lds.target.1.2.3.4:1111" + defaultTestTimeout = 1 * time.Second + defaultTestShortTimeout = 10 * time.Millisecond +) + +var listenerWithRouteConfiguration = &v3listenerpb.Listener{ + FilterChains: []*v3listenerpb.FilterChain{ + { + FilterChainMatch: &v3listenerpb.FilterChainMatch{ + PrefixRanges: []*v3corepb.CidrRange{ + { + AddressPrefix: "192.168.0.0", + PrefixLen: &wrapperspb.UInt32Value{ + Value: uint32(16), + }, + }, + }, + SourceType: v3listenerpb.FilterChainMatch_SAME_IP_OR_LOOPBACK, + SourcePrefixRanges: []*v3corepb.CidrRange{ + { + AddressPrefix: "192.168.0.0", + PrefixLen: &wrapperspb.UInt32Value{ + Value: uint32(16), + }, + }, + }, + SourcePorts: []uint32{80}, + }, + Filters: []*v3listenerpb.Filter{ + { + Name: "filter-1", + ConfigType: &v3listenerpb.Filter_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3httppb.HttpConnectionManager{ + RouteSpecifier: &v3httppb.HttpConnectionManager_Rds{ + Rds: &v3httppb.Rds{ + ConfigSource: &v3corepb.ConfigSource{ + ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{Ads: &v3corepb.AggregatedConfigSource{}}, + }, + RouteConfigName: "route-1", + }, + }, + HttpFilters: []*v3httppb.HttpFilter{e2e.RouterHTTPFilter}, + }), + }, + }, + }, + }, + }, +} + +var listenerWithFilterChains = &v3listenerpb.Listener{ + FilterChains: []*v3listenerpb.FilterChain{ + { + FilterChainMatch: &v3listenerpb.FilterChainMatch{ + PrefixRanges: []*v3corepb.CidrRange{ + { + AddressPrefix: "192.168.0.0", + PrefixLen: &wrapperspb.UInt32Value{ + Value: uint32(16), + }, + }, + }, + SourceType: v3listenerpb.FilterChainMatch_SAME_IP_OR_LOOPBACK, + SourcePrefixRanges: []*v3corepb.CidrRange{ + { + AddressPrefix: "192.168.0.0", + PrefixLen: &wrapperspb.UInt32Value{ + Value: uint32(16), + }, + }, + }, + SourcePorts: []uint32{80}, + }, + TransportSocket: &v3corepb.TransportSocket{ + Name: "envoy.transport_sockets.tls", + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3tlspb.DownstreamTlsContext{ + CommonTlsContext: &v3tlspb.CommonTlsContext{ + TlsCertificateCertificateProviderInstance: &v3tlspb.CommonTlsContext_CertificateProviderInstance{ + InstanceName: "identityPluginInstance", + CertificateName: "identityCertName", + }, + }, + }), + }, + }, + Filters: []*v3listenerpb.Filter{ + { + Name: "filter-1", + ConfigType: &v3listenerpb.Filter_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3httppb.HttpConnectionManager{ + RouteSpecifier: &v3httppb.HttpConnectionManager_RouteConfig{ + RouteConfig: &v3routepb.RouteConfiguration{ + Name: "routeName", + VirtualHosts: []*v3routepb.VirtualHost{{ + Domains: []string{"lds.target.good:3333"}, + Routes: []*v3routepb.Route{{ + Match: &v3routepb.RouteMatch{ + PathSpecifier: &v3routepb.RouteMatch_Prefix{Prefix: "/"}, + }, + Action: &v3routepb.Route_NonForwardingAction{}, + }}}}}, + }, + HttpFilters: []*v3httppb.HttpFilter{e2e.RouterHTTPFilter}, + }), + }, + }, + }, + }, + }, +} + +type s struct { + grpctest.Tester +} + +func Test(t *testing.T) { + grpctest.RunSubTests(t, s{}) +} + +type tempError struct{} + +func (tempError) Error() string { + return "listenerWrapper test temporary error" +} + +func (tempError) Temporary() bool { + return true +} + +// connAndErr wraps a net.Conn and an error. +type connAndErr struct { + conn net.Conn + err error +} + +// fakeListener allows the user to inject conns returned by Accept(). +type fakeListener struct { + acceptCh chan connAndErr + closeCh *testutils.Channel +} + +func (fl *fakeListener) Accept() (net.Conn, error) { + cne, ok := <-fl.acceptCh + if !ok { + return nil, errors.New("a non-temporary error") + } + return cne.conn, cne.err +} + +func (fl *fakeListener) Close() error { + fl.closeCh.Send(nil) + return nil +} + +func (fl *fakeListener) Addr() net.Addr { + return &net.TCPAddr{ + IP: net.IPv4(0, 0, 0, 0), + Port: fakeListenerPort, + } +} + +// fakeConn overrides LocalAddr, RemoteAddr and Close methods. +type fakeConn struct { + net.Conn + local, remote net.Addr + closeCh *testutils.Channel +} + +func (fc *fakeConn) LocalAddr() net.Addr { + return fc.local +} + +func (fc *fakeConn) RemoteAddr() net.Addr { + return fc.remote +} + +func (fc *fakeConn) Close() error { + fc.closeCh.Send(nil) + return nil +} + +func newListenerWrapper(t *testing.T) (*listenerWrapper, <-chan struct{}, *fakeclient.Client, *fakeListener, func()) { + t.Helper() + + // Create a listener wrapper with a fake listener and fake XDSClient and + // verify that it extracts the host and port from the passed in listener. + lis := &fakeListener{ + acceptCh: make(chan connAndErr, 1), + closeCh: testutils.NewChannel(), + } + xdsC := fakeclient.NewClient() + lParams := ListenerWrapperParams{ + Listener: lis, + ListenerResourceName: testListenerResourceName, + XDSClient: xdsC, + } + l, readyCh := NewListenerWrapper(lParams) + if l == nil { + t.Fatalf("NewListenerWrapper(%+v) returned nil", lParams) + } + lw, ok := l.(*listenerWrapper) + if !ok { + t.Fatalf("NewListenerWrapper(%+v) returned listener of type %T want *listenerWrapper", lParams, l) + } + if lw.addr != fakeListenerHost || lw.port != strconv.Itoa(fakeListenerPort) { + t.Fatalf("listenerWrapper has host:port %s:%s, want %s:%d", lw.addr, lw.port, fakeListenerHost, fakeListenerPort) + } + return lw, readyCh, xdsC, lis, func() { l.Close() } +} + +func (s) TestNewListenerWrapper(t *testing.T) { + _, readyCh, xdsC, _, cleanup := newListenerWrapper(t) + defer cleanup() + + // Verify that the listener wrapper registers a listener watch for the + // expected Listener resource name. + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + name, err := xdsC.WaitForWatchListener(ctx) + if err != nil { + t.Fatalf("error when waiting for a watch on a Listener resource: %v", err) + } + if name != testListenerResourceName { + t.Fatalf("listenerWrapper registered a lds watch on %s, want %s", name, testListenerResourceName) + } + + // Push an error to the listener update handler. + xdsC.InvokeWatchListenerCallback(xdsclient.ListenerUpdate{}, errors.New("bad listener update")) + timer := time.NewTimer(defaultTestShortTimeout) + select { + case <-timer.C: + timer.Stop() + case <-readyCh: + t.Fatalf("ready channel written to after receipt of a bad Listener update") + } + + fcm, err := xdsclient.NewFilterChainManager(listenerWithFilterChains) + if err != nil { + t.Fatalf("xdsclient.NewFilterChainManager() failed with error: %v", err) + } + + // Push an update whose address does not match the address to which our + // listener is bound, and verify that the ready channel is not written to. + xdsC.InvokeWatchListenerCallback(xdsclient.ListenerUpdate{ + InboundListenerCfg: &xdsclient.InboundListenerConfig{ + Address: "10.0.0.1", + Port: "50051", + FilterChains: fcm, + }}, nil) + timer = time.NewTimer(defaultTestShortTimeout) + select { + case <-timer.C: + timer.Stop() + case <-readyCh: + t.Fatalf("ready channel written to after receipt of a bad Listener update") + } + + // Push a good update, and verify that the ready channel is written to. + // Since there are no dynamic RDS updates needed to be received, the + // ListenerWrapper does not have to wait for anything else before telling + // that it is ready. + xdsC.InvokeWatchListenerCallback(xdsclient.ListenerUpdate{ + InboundListenerCfg: &xdsclient.InboundListenerConfig{ + Address: fakeListenerHost, + Port: strconv.Itoa(fakeListenerPort), + FilterChains: fcm, + }}, nil) + + select { + case <-ctx.Done(): + t.Fatalf("timeout waiting for the ready channel to be written to after receipt of a good Listener update") + case <-readyCh: + } +} + +// TestNewListenerWrapperWithRouteUpdate tests the scenario where the listener +// gets built, starts a watch, that watch returns a list of Route Names to +// return, than receives an update from the rds handler. Only after receiving +// the update from the rds handler should it move the server to +// ServingModeServing. +func (s) TestNewListenerWrapperWithRouteUpdate(t *testing.T) { + oldRBAC := env.RBACSupport + env.RBACSupport = true + defer func() { + env.RBACSupport = oldRBAC + }() + _, readyCh, xdsC, _, cleanup := newListenerWrapper(t) + defer cleanup() + + // Verify that the listener wrapper registers a listener watch for the + // expected Listener resource name. + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + name, err := xdsC.WaitForWatchListener(ctx) + if err != nil { + t.Fatalf("error when waiting for a watch on a Listener resource: %v", err) + } + if name != testListenerResourceName { + t.Fatalf("listenerWrapper registered a lds watch on %s, want %s", name, testListenerResourceName) + } + fcm, err := xdsclient.NewFilterChainManager(listenerWithRouteConfiguration) + if err != nil { + t.Fatalf("xdsclient.NewFilterChainManager() failed with error: %v", err) + } + + // Push a good update which contains a Filter Chain that specifies dynamic + // RDS Resources that need to be received. This should ping rds handler + // about which rds names to start, which will eventually start a watch on + // xds client for rds name "route-1". + xdsC.InvokeWatchListenerCallback(xdsclient.ListenerUpdate{ + InboundListenerCfg: &xdsclient.InboundListenerConfig{ + Address: fakeListenerHost, + Port: strconv.Itoa(fakeListenerPort), + FilterChains: fcm, + }}, nil) + + // This should start a watch on xds client for rds name "route-1". + routeName, err := xdsC.WaitForWatchRouteConfig(ctx) + if err != nil { + t.Fatalf("error when waiting for a watch on a Route resource: %v", err) + } + if routeName != "route-1" { + t.Fatalf("listenerWrapper registered a lds watch on %s, want %s", routeName, "route-1") + } + + // This shouldn't invoke good update channel, as has not received rds updates yet. + timer := time.NewTimer(defaultTestShortTimeout) + select { + case <-timer.C: + timer.Stop() + case <-readyCh: + t.Fatalf("ready channel written to without rds configuration specified") + } + + // Invoke rds callback for the started rds watch. This valid rds callback + // should trigger the listener wrapper to fire GoodUpdate, as it has + // received both it's LDS Configuration and also RDS Configuration, + // specified in LDS Configuration. + xdsC.InvokeWatchRouteConfigCallback("route-1", xdsclient.RouteConfigUpdate{}, nil) + + // All of the xDS updates have completed, so can expect to send a ping on + // good update channel. + select { + case <-ctx.Done(): + t.Fatalf("timeout waiting for the ready channel to be written to after receipt of a good rds update") + case <-readyCh: + } +} + +func (s) TestListenerWrapper_Accept(t *testing.T) { + boCh := testutils.NewChannel() + origBackoffFunc := backoffFunc + backoffFunc = func(v int) time.Duration { + boCh.Send(v) + return 0 + } + defer func() { backoffFunc = origBackoffFunc }() + + lw, readyCh, xdsC, lis, cleanup := newListenerWrapper(t) + defer cleanup() + + // Push a good update with a filter chain which accepts local connections on + // 192.168.0.0/16 subnet and port 80. + fcm, err := xdsclient.NewFilterChainManager(listenerWithFilterChains) + if err != nil { + t.Fatalf("xdsclient.NewFilterChainManager() failed with error: %v", err) + } + xdsC.InvokeWatchListenerCallback(xdsclient.ListenerUpdate{ + InboundListenerCfg: &xdsclient.InboundListenerConfig{ + Address: fakeListenerHost, + Port: strconv.Itoa(fakeListenerPort), + FilterChains: fcm, + }}, nil) + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + defer close(lis.acceptCh) + select { + case <-ctx.Done(): + t.Fatalf("timeout waiting for the ready channel to be written to after receipt of a good Listener update") + case <-readyCh: + } + + // Push a non-temporary error into Accept(). + nonTempErr := errors.New("a non-temporary error") + lis.acceptCh <- connAndErr{err: nonTempErr} + if _, err := lw.Accept(); err != nonTempErr { + t.Fatalf("listenerWrapper.Accept() returned error: %v, want: %v", err, nonTempErr) + } + + // Invoke Accept() in a goroutine since we expect it to swallow: + // 1. temporary errors returned from the underlying listener + // 2. errors related to finding a matching filter chain for the incoming + // connection. + errCh := testutils.NewChannel() + go func() { + conn, err := lw.Accept() + if err != nil { + errCh.Send(err) + return + } + if _, ok := conn.(*connWrapper); !ok { + errCh.Send(errors.New("listenerWrapper.Accept() returned a Conn of type %T, want *connWrapper")) + return + } + errCh.Send(nil) + }() + + // Push a temporary error into Accept() and verify that it backs off. + lis.acceptCh <- connAndErr{err: tempError{}} + if _, err := boCh.Receive(ctx); err != nil { + t.Fatalf("error when waiting for Accept() to backoff on temporary errors: %v", err) + } + + // Push a fakeConn which does not match any filter chains configured on the + // received Listener resource. Verify that the conn is closed. + fc := &fakeConn{ + local: &net.TCPAddr{IP: net.IPv4(192, 168, 1, 2), Port: 79}, + remote: &net.TCPAddr{IP: net.IPv4(10, 1, 1, 1), Port: 80}, + closeCh: testutils.NewChannel(), + } + lis.acceptCh <- connAndErr{conn: fc} + if _, err := fc.closeCh.Receive(ctx); err != nil { + t.Fatalf("error when waiting for conn to be closed on no filter chain match: %v", err) + } + + // Push a fakeConn which matches the filter chains configured on the + // received Listener resource. Verify that Accept() returns. + fc = &fakeConn{ + local: &net.TCPAddr{IP: net.IPv4(192, 168, 1, 2)}, + remote: &net.TCPAddr{IP: net.IPv4(192, 168, 1, 2), Port: 80}, + closeCh: testutils.NewChannel(), + } + lis.acceptCh <- connAndErr{conn: fc} + if _, err := errCh.Receive(ctx); err != nil { + t.Fatalf("error when waiting for Accept() to return the conn on filter chain match: %v", err) + } +} diff --git a/xds/internal/server/rds_handler.go b/xds/internal/server/rds_handler.go new file mode 100644 index 00000000000..cc676c4ca05 --- /dev/null +++ b/xds/internal/server/rds_handler.go @@ -0,0 +1,133 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package server + +import ( + "sync" + + "google.golang.org/grpc/xds/internal/xdsclient" +) + +// rdsHandlerUpdate wraps the full RouteConfigUpdate that are dynamically +// queried for a given server side listener. +type rdsHandlerUpdate struct { + updates map[string]xdsclient.RouteConfigUpdate + err error +} + +// rdsHandler handles any RDS queries that need to be started for a given server +// side listeners Filter Chains (i.e. not inline). +type rdsHandler struct { + xdsC XDSClient + + mu sync.Mutex + updates map[string]xdsclient.RouteConfigUpdate + cancels map[string]func() + + // For a rdsHandler update, the only update wrapped listener cares about is + // most recent one, so this channel will be opportunistically drained before + // sending any new updates. + updateChannel chan rdsHandlerUpdate +} + +// newRDSHandler creates a new rdsHandler to watch for RDS resources. +// listenerWrapper updates the list of route names to watch by calling +// updateRouteNamesToWatch() upon receipt of new Listener configuration. +func newRDSHandler(xdsC XDSClient, ch chan rdsHandlerUpdate) *rdsHandler { + return &rdsHandler{ + xdsC: xdsC, + updateChannel: ch, + updates: make(map[string]xdsclient.RouteConfigUpdate), + cancels: make(map[string]func()), + } +} + +// updateRouteNamesToWatch handles a list of route names to watch for a given +// server side listener (if a filter chain specifies dynamic RDS configuration). +// This function handles all the logic with respect to any routes that may have +// been added or deleted as compared to what was previously present. +func (rh *rdsHandler) updateRouteNamesToWatch(routeNamesToWatch map[string]bool) { + rh.mu.Lock() + defer rh.mu.Unlock() + // Add and start watches for any routes for any new routes in + // routeNamesToWatch. + for routeName := range routeNamesToWatch { + if _, ok := rh.cancels[routeName]; !ok { + func(routeName string) { + rh.cancels[routeName] = rh.xdsC.WatchRouteConfig(routeName, func(update xdsclient.RouteConfigUpdate, err error) { + rh.handleRouteUpdate(routeName, update, err) + }) + }(routeName) + } + } + + // Delete and cancel watches for any routes from persisted routeNamesToWatch + // that are no longer present. + for routeName := range rh.cancels { + if _, ok := routeNamesToWatch[routeName]; !ok { + rh.cancels[routeName]() + delete(rh.cancels, routeName) + delete(rh.updates, routeName) + } + } + + // If the full list (determined by length) of updates are now successfully + // updated, the listener is ready to be updated. + if len(rh.updates) == len(rh.cancels) && len(routeNamesToWatch) != 0 { + drainAndPush(rh.updateChannel, rdsHandlerUpdate{updates: rh.updates}) + } +} + +// handleRouteUpdate persists the route config for a given route name, and also +// sends an update to the Listener Wrapper on an error received or if the rds +// handler has a full collection of updates. +func (rh *rdsHandler) handleRouteUpdate(routeName string, update xdsclient.RouteConfigUpdate, err error) { + if err != nil { + drainAndPush(rh.updateChannel, rdsHandlerUpdate{err: err}) + return + } + rh.mu.Lock() + defer rh.mu.Unlock() + rh.updates[routeName] = update + + // If the full list (determined by length) of updates have successfully + // updated, the listener is ready to be updated. + if len(rh.updates) == len(rh.cancels) { + drainAndPush(rh.updateChannel, rdsHandlerUpdate{updates: rh.updates}) + } +} + +func drainAndPush(ch chan rdsHandlerUpdate, update rdsHandlerUpdate) { + select { + case <-ch: + default: + } + ch <- update +} + +// close() is meant to be called by wrapped listener when the wrapped listener +// is closed, and it cleans up resources by canceling all the active RDS +// watches. +func (rh *rdsHandler) close() { + rh.mu.Lock() + defer rh.mu.Unlock() + for _, cancel := range rh.cancels { + cancel() + } +} diff --git a/xds/internal/server/rds_handler_test.go b/xds/internal/server/rds_handler_test.go new file mode 100644 index 00000000000..d1daffd940c --- /dev/null +++ b/xds/internal/server/rds_handler_test.go @@ -0,0 +1,401 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package server + +import ( + "context" + "errors" + "fmt" + "testing" + + "github.com/google/go-cmp/cmp" + "google.golang.org/grpc/xds/internal/testutils/fakeclient" + "google.golang.org/grpc/xds/internal/xdsclient" +) + +const ( + route1 = "route1" + route2 = "route2" + route3 = "route3" +) + +// setupTests creates a rds handler with a fake xds client for control over the +// xds client. +func setupTests() (*rdsHandler, *fakeclient.Client, chan rdsHandlerUpdate) { + xdsC := fakeclient.NewClient() + ch := make(chan rdsHandlerUpdate, 1) + rh := newRDSHandler(xdsC, ch) + return rh, xdsC, ch +} + +// waitForFuncWithNames makes sure that a blocking function returns the correct +// set of names, where order doesn't matter. This takes away nondeterminism from +// ranging through a map. +func waitForFuncWithNames(ctx context.Context, f func(context.Context) (string, error), names ...string) error { + wantNames := make(map[string]bool, len(names)) + for _, name := range names { + wantNames[name] = true + } + gotNames := make(map[string]bool, len(names)) + for range wantNames { + name, err := f(ctx) + if err != nil { + return err + } + gotNames[name] = true + } + if !cmp.Equal(gotNames, wantNames) { + return fmt.Errorf("got routeNames %v, want %v", gotNames, wantNames) + } + return nil +} + +// TestSuccessCaseOneRDSWatch tests the simplest scenario: the rds handler +// receives a single route name, starts a watch for that route name, gets a +// successful update, and then writes an update to the update channel for +// listener to pick up. +func (s) TestSuccessCaseOneRDSWatch(t *testing.T) { + rh, fakeClient, ch := setupTests() + // When you first update the rds handler with a list of a single Route names + // that needs dynamic RDS Configuration, this Route name has not been seen + // before, so the RDS Handler should start a watch on that RouteName. + rh.updateRouteNamesToWatch(map[string]bool{route1: true}) + // The RDS Handler should start a watch for that routeName. + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + gotRoute, err := fakeClient.WaitForWatchRouteConfig(ctx) + if err != nil { + t.Fatalf("xdsClient.WatchRDS failed with error: %v", err) + } + if gotRoute != route1 { + t.Fatalf("xdsClient.WatchRDS called for route: %v, want %v", gotRoute, route1) + } + rdsUpdate := xdsclient.RouteConfigUpdate{} + // Invoke callback with the xds client with a certain route update. Due to + // this route update updating every route name that rds handler handles, + // this should write to the update channel to send to the listener. + fakeClient.InvokeWatchRouteConfigCallback(route1, rdsUpdate, nil) + rhuWant := map[string]xdsclient.RouteConfigUpdate{route1: rdsUpdate} + select { + case rhu := <-ch: + if diff := cmp.Diff(rhu.updates, rhuWant); diff != "" { + t.Fatalf("got unexpected route update, diff (-got, +want): %v", diff) + } + case <-ctx.Done(): + t.Fatal("Timed out waiting for update from update channel.") + } + // Close the rds handler. This is meant to be called when the lis wrapper is + // closed, and the call should cancel all the watches present (for this + // test, a single watch). + rh.close() + routeNameDeleted, err := fakeClient.WaitForCancelRouteConfigWatch(ctx) + if err != nil { + t.Fatalf("xdsClient.CancelRDS failed with error: %v", err) + } + if routeNameDeleted != route1 { + t.Fatalf("xdsClient.CancelRDS called for route %v, want %v", routeNameDeleted, route1) + } +} + +// TestSuccessCaseTwoUpdates tests the case where the rds handler receives an +// update with a single Route, then receives a second update with two routes. +// The handler should start a watch for the added route, and if received a RDS +// update for that route it should send an update with both RDS updates present. +func (s) TestSuccessCaseTwoUpdates(t *testing.T) { + rh, fakeClient, ch := setupTests() + + rh.updateRouteNamesToWatch(map[string]bool{route1: true}) + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + gotRoute, err := fakeClient.WaitForWatchRouteConfig(ctx) + if err != nil { + t.Fatalf("xdsClient.WatchRDS failed with error: %v", err) + } + if gotRoute != route1 { + t.Fatalf("xdsClient.WatchRDS called for route: %v, want %v", gotRoute, route1) + } + + // Update the RDSHandler with route names which adds a route name to watch. + // This should trigger the RDSHandler to start a watch for the added route + // name to watch. + rh.updateRouteNamesToWatch(map[string]bool{route1: true, route2: true}) + gotRoute, err = fakeClient.WaitForWatchRouteConfig(ctx) + if err != nil { + t.Fatalf("xdsClient.WatchRDS failed with error: %v", err) + } + if gotRoute != route2 { + t.Fatalf("xdsClient.WatchRDS called for route: %v, want %v", gotRoute, route2) + } + + // Invoke the callback with an update for route 1. This shouldn't cause the + // handler to write an update, as it has not received RouteConfigurations + // for every RouteName. + rdsUpdate1 := xdsclient.RouteConfigUpdate{} + fakeClient.InvokeWatchRouteConfigCallback(route1, rdsUpdate1, nil) + + // The RDS Handler should not send an update. + sCtx, sCtxCancel := context.WithTimeout(context.Background(), defaultTestShortTimeout) + defer sCtxCancel() + select { + case <-ch: + t.Fatal("RDS Handler wrote an update to updateChannel when it shouldn't have, as each route name has not received an update yet") + case <-sCtx.Done(): + } + + // Invoke the callback with an update for route 2. This should cause the + // handler to write an update, as it has received RouteConfigurations for + // every RouteName. + rdsUpdate2 := xdsclient.RouteConfigUpdate{} + fakeClient.InvokeWatchRouteConfigCallback(route2, rdsUpdate2, nil) + // The RDS Handler should then update the listener wrapper with an update + // with two route configurations, as both route names the RDS Handler handles + // have received an update. + rhuWant := map[string]xdsclient.RouteConfigUpdate{route1: rdsUpdate1, route2: rdsUpdate2} + select { + case rhu := <-ch: + if diff := cmp.Diff(rhu.updates, rhuWant); diff != "" { + t.Fatalf("got unexpected route update, diff (-got, +want): %v", diff) + } + case <-ctx.Done(): + t.Fatal("Timed out waiting for the rds handler update to be written to the update buffer.") + } + + // Close the rds handler. This is meant to be called when the lis wrapper is + // closed, and the call should cancel all the watches present (for this + // test, two watches on route1 and route2). + rh.close() + if err = waitForFuncWithNames(ctx, fakeClient.WaitForCancelRouteConfigWatch, route1, route2); err != nil { + t.Fatalf("Error while waiting for names: %v", err) + } +} + +// TestSuccessCaseDeletedRoute tests the case where the rds handler receives an +// update with two routes, then receives an update with only one route. The RDS +// Handler is expected to cancel the watch for the route no longer present. +func (s) TestSuccessCaseDeletedRoute(t *testing.T) { + rh, fakeClient, ch := setupTests() + + rh.updateRouteNamesToWatch(map[string]bool{route1: true, route2: true}) + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + // Will start two watches. + if err := waitForFuncWithNames(ctx, fakeClient.WaitForWatchRouteConfig, route1, route2); err != nil { + t.Fatalf("Error while waiting for names: %v", err) + } + + // Update the RDSHandler with route names which deletes a route name to + // watch. This should trigger the RDSHandler to cancel the watch for the + // deleted route name to watch. + rh.updateRouteNamesToWatch(map[string]bool{route1: true}) + // This should delete the watch for route2. + routeNameDeleted, err := fakeClient.WaitForCancelRouteConfigWatch(ctx) + if err != nil { + t.Fatalf("xdsClient.CancelRDS failed with error %v", err) + } + if routeNameDeleted != route2 { + t.Fatalf("xdsClient.CancelRDS called for route %v, want %v", routeNameDeleted, route2) + } + + rdsUpdate := xdsclient.RouteConfigUpdate{} + // Invoke callback with the xds client with a certain route update. Due to + // this route update updating every route name that rds handler handles, + // this should write to the update channel to send to the listener. + fakeClient.InvokeWatchRouteConfigCallback(route1, rdsUpdate, nil) + rhuWant := map[string]xdsclient.RouteConfigUpdate{route1: rdsUpdate} + select { + case rhu := <-ch: + if diff := cmp.Diff(rhu.updates, rhuWant); diff != "" { + t.Fatalf("got unexpected route update, diff (-got, +want): %v", diff) + } + case <-ctx.Done(): + t.Fatal("Timed out waiting for update from update channel.") + } + + rh.close() + routeNameDeleted, err = fakeClient.WaitForCancelRouteConfigWatch(ctx) + if err != nil { + t.Fatalf("xdsClient.CancelRDS failed with error: %v", err) + } + if routeNameDeleted != route1 { + t.Fatalf("xdsClient.CancelRDS called for route %v, want %v", routeNameDeleted, route1) + } +} + +// TestSuccessCaseTwoUpdatesAddAndDeleteRoute tests the case where the rds +// handler receives an update with two routes, and then receives an update with +// two routes, one previously there and one added (i.e. 12 -> 23). This should +// cause the route that is no longer there to be deleted and cancelled, and the +// route that was added should have a watch started for it. +func (s) TestSuccessCaseTwoUpdatesAddAndDeleteRoute(t *testing.T) { + rh, fakeClient, ch := setupTests() + + rh.updateRouteNamesToWatch(map[string]bool{route1: true, route2: true}) + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + if err := waitForFuncWithNames(ctx, fakeClient.WaitForWatchRouteConfig, route1, route2); err != nil { + t.Fatalf("Error while waiting for names: %v", err) + } + + // Update the rds handler with two routes, one which was already there and a new route. + // This should cause the rds handler to delete/cancel watch for route 1 and start a watch + // for route 3. + rh.updateRouteNamesToWatch(map[string]bool{route2: true, route3: true}) + + // Start watch comes first, which should be for route3 as was just added. + gotRoute, err := fakeClient.WaitForWatchRouteConfig(ctx) + if err != nil { + t.Fatalf("xdsClient.WatchRDS failed with error: %v", err) + } + if gotRoute != route3 { + t.Fatalf("xdsClient.WatchRDS called for route: %v, want %v", gotRoute, route3) + } + + // Then route 1 should be deleted/cancelled watch for, as it is no longer present + // in the new RouteName to watch map. + routeNameDeleted, err := fakeClient.WaitForCancelRouteConfigWatch(ctx) + if err != nil { + t.Fatalf("xdsClient.CancelRDS failed with error: %v", err) + } + if routeNameDeleted != route1 { + t.Fatalf("xdsClient.CancelRDS called for route %v, want %v", routeNameDeleted, route1) + } + + // Invoke the callback with an update for route 2. This shouldn't cause the + // handler to write an update, as it has not received RouteConfigurations + // for every RouteName. + rdsUpdate2 := xdsclient.RouteConfigUpdate{} + fakeClient.InvokeWatchRouteConfigCallback(route2, rdsUpdate2, nil) + + // The RDS Handler should not send an update. + sCtx, sCtxCancel := context.WithTimeout(context.Background(), defaultTestShortTimeout) + defer sCtxCancel() + select { + case <-ch: + t.Fatalf("RDS Handler wrote an update to updateChannel when it shouldn't have, as each route name has not received an update yet") + case <-sCtx.Done(): + } + + // Invoke the callback with an update for route 3. This should cause the + // handler to write an update, as it has received RouteConfigurations for + // every RouteName. + rdsUpdate3 := xdsclient.RouteConfigUpdate{} + fakeClient.InvokeWatchRouteConfigCallback(route3, rdsUpdate3, nil) + // The RDS Handler should then update the listener wrapper with an update + // with two route configurations, as both route names the RDS Handler handles + // have received an update. + rhuWant := map[string]xdsclient.RouteConfigUpdate{route2: rdsUpdate2, route3: rdsUpdate3} + select { + case rhu := <-rh.updateChannel: + if diff := cmp.Diff(rhu.updates, rhuWant); diff != "" { + t.Fatalf("got unexpected route update, diff (-got, +want): %v", diff) + } + case <-ctx.Done(): + t.Fatal("Timed out waiting for the rds handler update to be written to the update buffer.") + } + // Close the rds handler. This is meant to be called when the lis wrapper is + // closed, and the call should cancel all the watches present (for this + // test, two watches on route2 and route3). + rh.close() + if err = waitForFuncWithNames(ctx, fakeClient.WaitForCancelRouteConfigWatch, route2, route3); err != nil { + t.Fatalf("Error while waiting for names: %v", err) + } +} + +// TestSuccessCaseSecondUpdateMakesRouteFull tests the scenario where the rds handler gets +// told to watch three rds configurations, gets two successful updates, then gets told to watch +// only those two. The rds handler should then write an update to update buffer. +func (s) TestSuccessCaseSecondUpdateMakesRouteFull(t *testing.T) { + rh, fakeClient, ch := setupTests() + + rh.updateRouteNamesToWatch(map[string]bool{route1: true, route2: true, route3: true}) + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + if err := waitForFuncWithNames(ctx, fakeClient.WaitForWatchRouteConfig, route1, route2, route3); err != nil { + t.Fatalf("Error while waiting for names: %v", err) + } + + // Invoke the callbacks for two of the three watches. Since RDS is not full, + // this shouldn't trigger rds handler to write an update to update buffer. + fakeClient.InvokeWatchRouteConfigCallback(route1, xdsclient.RouteConfigUpdate{}, nil) + fakeClient.InvokeWatchRouteConfigCallback(route2, xdsclient.RouteConfigUpdate{}, nil) + + // The RDS Handler should not send an update. + sCtx, sCtxCancel := context.WithTimeout(context.Background(), defaultTestShortTimeout) + defer sCtxCancel() + select { + case <-rh.updateChannel: + t.Fatalf("RDS Handler wrote an update to updateChannel when it shouldn't have, as each route name has not received an update yet") + case <-sCtx.Done(): + } + + // Tell the rds handler to now only watch Route 1 and Route 2. This should + // trigger the rds handler to write an update to the update buffer as it now + // has full rds configuration. + rh.updateRouteNamesToWatch(map[string]bool{route1: true, route2: true}) + // Route 3 should be deleted/cancelled watch for, as it is no longer present + // in the new RouteName to watch map. + routeNameDeleted, err := fakeClient.WaitForCancelRouteConfigWatch(ctx) + if err != nil { + t.Fatalf("xdsClient.CancelRDS failed with error: %v", err) + } + if routeNameDeleted != route3 { + t.Fatalf("xdsClient.CancelRDS called for route %v, want %v", routeNameDeleted, route1) + } + rhuWant := map[string]xdsclient.RouteConfigUpdate{route1: {}, route2: {}} + select { + case rhu := <-ch: + if diff := cmp.Diff(rhu.updates, rhuWant); diff != "" { + t.Fatalf("got unexpected route update, diff (-got, +want): %v", diff) + } + case <-ctx.Done(): + t.Fatal("Timed out waiting for the rds handler update to be written to the update buffer.") + } +} + +// TestErrorReceived tests the case where the rds handler receives a route name +// to watch, then receives an update with an error. This error should be then +// written to the update channel. +func (s) TestErrorReceived(t *testing.T) { + rh, fakeClient, ch := setupTests() + + rh.updateRouteNamesToWatch(map[string]bool{route1: true}) + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + gotRoute, err := fakeClient.WaitForWatchRouteConfig(ctx) + if err != nil { + t.Fatalf("xdsClient.WatchRDS failed with error %v", err) + } + if gotRoute != route1 { + t.Fatalf("xdsClient.WatchRDS called for route: %v, want %v", gotRoute, route1) + } + + rdsErr := errors.New("some error") + fakeClient.InvokeWatchRouteConfigCallback(route1, xdsclient.RouteConfigUpdate{}, rdsErr) + select { + case rhu := <-ch: + if rhu.err.Error() != "some error" { + t.Fatalf("Did not receive the expected error, instead received: %v", rhu.err.Error()) + } + case <-ctx.Done(): + t.Fatal("Timed out waiting for update from update channel") + } +} diff --git a/xds/internal/test/e2e/README.md b/xds/internal/test/e2e/README.md new file mode 100644 index 00000000000..33cffa0da56 --- /dev/null +++ b/xds/internal/test/e2e/README.md @@ -0,0 +1,19 @@ +Build client and server binaries. + +```sh +go build -o ./binaries/client ../../../../interop/xds/client/ +go build -o ./binaries/server ../../../../interop/xds/server/ +``` + +Run the test + +```sh +go test . -v +``` + +The client/server paths are flags + +```sh +go test . -v -client=$HOME/grpc-java/interop-testing/build/install/grpc-interop-testing/bin/xds-test-client +``` +Note that grpc logs are only turned on for Go. diff --git a/xds/internal/test/e2e/controlplane.go b/xds/internal/test/e2e/controlplane.go new file mode 100644 index 00000000000..247991b83d3 --- /dev/null +++ b/xds/internal/test/e2e/controlplane.go @@ -0,0 +1,62 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package e2e + +import ( + "fmt" + + "github.com/google/uuid" + xdsinternal "google.golang.org/grpc/internal/xds" + "google.golang.org/grpc/xds/internal/testutils/e2e" +) + +type controlPlane struct { + server *e2e.ManagementServer + nodeID string + bootstrapContent string +} + +func newControlPlane(testName string) (*controlPlane, error) { + // Spin up an xDS management server on a local port. + server, err := e2e.StartManagementServer() + if err != nil { + return nil, fmt.Errorf("failed to spin up the xDS management server: %v", err) + } + + nodeID := uuid.New().String() + bootstrapContentBytes, err := xdsinternal.BootstrapContents(xdsinternal.BootstrapOptions{ + Version: xdsinternal.TransportV3, + NodeID: nodeID, + ServerURI: server.Address, + ServerListenerResourceNameTemplate: e2e.ServerListenerResourceNameTemplate, + }) + if err != nil { + server.Stop() + return nil, fmt.Errorf("failed to create bootstrap file: %v", err) + } + + return &controlPlane{ + server: server, + nodeID: nodeID, + bootstrapContent: string(bootstrapContentBytes), + }, nil +} + +func (cp *controlPlane) stop() { + cp.server.Stop() +} diff --git a/xds/internal/test/e2e/e2e.go b/xds/internal/test/e2e/e2e.go new file mode 100644 index 00000000000..ade6339bf53 --- /dev/null +++ b/xds/internal/test/e2e/e2e.go @@ -0,0 +1,178 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package e2e implements xds e2e tests using go-control-plane. +package e2e + +import ( + "context" + "fmt" + "io" + "os" + "os/exec" + + "google.golang.org/grpc" + channelzgrpc "google.golang.org/grpc/channelz/grpc_channelz_v1" + channelzpb "google.golang.org/grpc/channelz/grpc_channelz_v1" + testgrpc "google.golang.org/grpc/interop/grpc_testing" + testpb "google.golang.org/grpc/interop/grpc_testing" +) + +func cmd(path string, logger io.Writer, args []string, env []string) (*exec.Cmd, error) { + cmd := exec.Command(path, args...) + cmd.Env = append(os.Environ(), env...) + cmd.Stdout = logger + cmd.Stderr = logger + return cmd, nil +} + +const ( + clientStatsPort = 60363 // TODO: make this different per-test, only needed for parallel tests. +) + +type client struct { + cmd *exec.Cmd + + target string + statsCC *grpc.ClientConn +} + +// newClient create a client with the given target and bootstrap content. +func newClient(target, binaryPath, bootstrap string, logger io.Writer, flags ...string) (*client, error) { + cmd, err := cmd( + binaryPath, + logger, + append([]string{ + "--server=" + target, + "--print_response=true", + "--qps=100", + fmt.Sprintf("--stats_port=%d", clientStatsPort), + }, flags...), // Append any flags from caller. + []string{ + "GRPC_GO_LOG_VERBOSITY_LEVEL=99", + "GRPC_GO_LOG_SEVERITY_LEVEL=info", + "GRPC_XDS_BOOTSTRAP_CONFIG=" + bootstrap, // The bootstrap content doesn't need to be quoted. + }, + ) + if err != nil { + return nil, fmt.Errorf("failed to run client cmd: %v", err) + } + cmd.Start() + + cc, err := grpc.Dial(fmt.Sprintf("localhost:%d", clientStatsPort), grpc.WithInsecure(), grpc.WithDefaultCallOptions(grpc.WaitForReady(true))) + if err != nil { + return nil, err + } + return &client{ + cmd: cmd, + target: target, + statsCC: cc, + }, nil +} + +func (c *client) clientStats(ctx context.Context) (*testpb.LoadBalancerStatsResponse, error) { + ccc := testgrpc.NewLoadBalancerStatsServiceClient(c.statsCC) + return ccc.GetClientStats(ctx, &testpb.LoadBalancerStatsRequest{ + NumRpcs: 100, + TimeoutSec: 10, + }) +} + +func (c *client) configRPCs(ctx context.Context, req *testpb.ClientConfigureRequest) error { + ccc := testgrpc.NewXdsUpdateClientConfigureServiceClient(c.statsCC) + _, err := ccc.Configure(ctx, req) + return err +} + +func (c *client) channelzSubChannels(ctx context.Context) ([]*channelzpb.Subchannel, error) { + ccc := channelzgrpc.NewChannelzClient(c.statsCC) + r, err := ccc.GetTopChannels(ctx, &channelzpb.GetTopChannelsRequest{}) + if err != nil { + return nil, err + } + + var ret []*channelzpb.Subchannel + for _, cc := range r.Channel { + if cc.Data.Target != c.target { + continue + } + for _, sc := range cc.SubchannelRef { + rr, err := ccc.GetSubchannel(ctx, &channelzpb.GetSubchannelRequest{SubchannelId: sc.SubchannelId}) + if err != nil { + return nil, err + } + ret = append(ret, rr.Subchannel) + } + } + return ret, nil +} + +func (c *client) stop() { + c.cmd.Process.Kill() + c.cmd.Wait() +} + +const ( + serverPort = 50051 // TODO: make this different per-test, only needed for parallel tests. +) + +type server struct { + cmd *exec.Cmd + port int +} + +// newServer creates multiple servers with the given bootstrap content. +// +// Each server gets a different hostname, in the format of +// -. +func newServers(hostnamePrefix, binaryPath, bootstrap string, logger io.Writer, count int) (_ []*server, err error) { + var ret []*server + defer func() { + if err != nil { + for _, s := range ret { + s.stop() + } + } + }() + for i := 0; i < count; i++ { + port := serverPort + i + cmd, err := cmd( + binaryPath, + logger, + []string{ + fmt.Sprintf("--port=%d", port), + fmt.Sprintf("--host_name_override=%s-%d", hostnamePrefix, i), + }, + []string{ + "GRPC_GO_LOG_VERBOSITY_LEVEL=99", + "GRPC_GO_LOG_SEVERITY_LEVEL=info", + "GRPC_XDS_BOOTSTRAP_CONFIG=" + bootstrap, // The bootstrap content doesn't need to be quoted., + }, + ) + if err != nil { + return nil, fmt.Errorf("failed to run server cmd: %v", err) + } + cmd.Start() + ret = append(ret, &server{cmd: cmd, port: port}) + } + return ret, nil +} + +func (s *server) stop() { + s.cmd.Process.Kill() + s.cmd.Wait() +} diff --git a/xds/internal/test/e2e/e2e_test.go b/xds/internal/test/e2e/e2e_test.go new file mode 100644 index 00000000000..6984566db2e --- /dev/null +++ b/xds/internal/test/e2e/e2e_test.go @@ -0,0 +1,257 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package e2e + +import ( + "bytes" + "context" + "flag" + "fmt" + "os" + "strconv" + "testing" + "time" + + v3clusterpb "github.com/envoyproxy/go-control-plane/envoy/config/cluster/v3" + v3routepb "github.com/envoyproxy/go-control-plane/envoy/config/route/v3" + channelzpb "google.golang.org/grpc/channelz/grpc_channelz_v1" + testpb "google.golang.org/grpc/interop/grpc_testing" + "google.golang.org/grpc/xds/internal/testutils/e2e" +) + +var ( + clientPath = flag.String("client", "./binaries/client", "The interop client") + serverPath = flag.String("server", "./binaries/server", "The interop server") +) + +type testOpts struct { + testName string + backendCount int + clientFlags []string +} + +func setup(t *testing.T, opts testOpts) (*controlPlane, *client, []*server) { + t.Helper() + if _, err := os.Stat(*clientPath); os.IsNotExist(err) { + t.Skip("skipped because client is not found") + } + if _, err := os.Stat(*serverPath); os.IsNotExist(err) { + t.Skip("skipped because server is not found") + } + backendCount := 1 + if opts.backendCount != 0 { + backendCount = opts.backendCount + } + + cp, err := newControlPlane(opts.testName) + if err != nil { + t.Fatalf("failed to start control-plane: %v", err) + } + t.Cleanup(cp.stop) + + var clientLog bytes.Buffer + c, err := newClient(fmt.Sprintf("xds:///%s", opts.testName), *clientPath, cp.bootstrapContent, &clientLog, opts.clientFlags...) + if err != nil { + t.Fatalf("failed to start client: %v", err) + } + t.Cleanup(c.stop) + + var serverLog bytes.Buffer + servers, err := newServers(opts.testName, *serverPath, cp.bootstrapContent, &serverLog, backendCount) + if err != nil { + t.Fatalf("failed to start server: %v", err) + } + t.Cleanup(func() { + for _, s := range servers { + s.stop() + } + }) + t.Cleanup(func() { + // TODO: find a better way to print the log. They are long, and hide the failure. + t.Logf("\n----- client logs -----\n%v", clientLog.String()) + t.Logf("\n----- server logs -----\n%v", serverLog.String()) + }) + return cp, c, servers +} + +func TestPingPong(t *testing.T) { + const testName = "pingpong" + cp, c, _ := setup(t, testOpts{testName: testName}) + + resources := e2e.DefaultClientResources(e2e.ResourceParams{ + DialTarget: testName, + NodeID: cp.nodeID, + Host: "localhost", + Port: serverPort, + SecLevel: e2e.SecurityLevelNone, + }) + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + if err := cp.server.Update(ctx, resources); err != nil { + t.Fatalf("failed to update control plane resources: %v", err) + } + + st, err := c.clientStats(ctx) + if err != nil { + t.Fatalf("failed to get client stats: %v", err) + } + if st.NumFailures != 0 { + t.Fatalf("Got %v failures: %+v", st.NumFailures, st) + } +} + +// TestAffinity covers the affinity tests with ringhash policy. +// - client is configured to use ringhash, with 3 backends +// - all RPCs will hash a specific metadata header +// - verify that +// - all RPCs with the same metadata value are sent to the same backend +// - only one backend is Ready +// - send more RPCs with different metadata values until a new backend is picked, and verify that +// - only two backends are in Ready +func TestAffinity(t *testing.T) { + const ( + testName = "affinity" + backendCount = 3 + testMDKey = "xds_md" + testMDValue = "unary_yranu" + ) + cp, c, servers := setup(t, testOpts{ + testName: testName, + backendCount: backendCount, + clientFlags: []string{"--rpc=EmptyCall", fmt.Sprintf("--metadata=EmptyCall:%s:%s", testMDKey, testMDValue)}, + }) + + resources := e2e.DefaultClientResources(e2e.ResourceParams{ + DialTarget: testName, + NodeID: cp.nodeID, + Host: "localhost", + Port: serverPort, + SecLevel: e2e.SecurityLevelNone, + }) + + // Update EDS to multiple backends. + var ports []uint32 + for _, s := range servers { + ports = append(ports, uint32(s.port)) + } + edsMsg := resources.Endpoints[0] + resources.Endpoints[0] = e2e.DefaultEndpoint( + edsMsg.ClusterName, + "localhost", + ports, + ) + + // Update CDS lbpolicy to ringhash. + cdsMsg := resources.Clusters[0] + cdsMsg.LbPolicy = v3clusterpb.Cluster_RING_HASH + + // Update RDS to hash the header. + rdsMsg := resources.Routes[0] + rdsMsg.VirtualHosts[0].Routes[0].Action = &v3routepb.Route_Route{Route: &v3routepb.RouteAction{ + ClusterSpecifier: &v3routepb.RouteAction_Cluster{Cluster: cdsMsg.Name}, + HashPolicy: []*v3routepb.RouteAction_HashPolicy{{ + PolicySpecifier: &v3routepb.RouteAction_HashPolicy_Header_{ + Header: &v3routepb.RouteAction_HashPolicy_Header{ + HeaderName: testMDKey, + }, + }, + }}, + }} + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + if err := cp.server.Update(ctx, resources); err != nil { + t.Fatalf("failed to update control plane resources: %v", err) + } + + // Note: We can skip CSDS check because there's no long delay as in TD. + // + // The client stats check doesn't race with the xds resource update because + // there's only one version of xds resource, updated at the beginning of the + // test. So there's no need to retry the stats call. + // + // In the future, we may add tests that update xds in the middle. Then we + // either need to retry clientStats(), or make a CSDS check before so the + // result is stable. + + st, err := c.clientStats(ctx) + if err != nil { + t.Fatalf("failed to get client stats: %v", err) + } + if st.NumFailures != 0 { + t.Fatalf("Got %v failures: %+v", st.NumFailures, st) + } + if len(st.RpcsByPeer) != 1 { + t.Fatalf("more than 1 backends got traffic: %v, want 1", st.RpcsByPeer) + } + + // Call channelz to verify that only one subchannel is in state Ready. + scs, err := c.channelzSubChannels(ctx) + if err != nil { + t.Fatalf("failed to fetch channelz: %v", err) + } + verifySubConnStates(t, scs, map[channelzpb.ChannelConnectivityState_State]int{ + channelzpb.ChannelConnectivityState_READY: 1, + channelzpb.ChannelConnectivityState_IDLE: 2, + }) + + // Send Unary call with different metadata value with integers starting from + // 0. Stop when a second peer is picked. + var ( + diffPeerPicked bool + mdValue int + ) + for !diffPeerPicked { + if err := c.configRPCs(ctx, &testpb.ClientConfigureRequest{ + Types: []testpb.ClientConfigureRequest_RpcType{ + testpb.ClientConfigureRequest_EMPTY_CALL, + testpb.ClientConfigureRequest_UNARY_CALL, + }, + Metadata: []*testpb.ClientConfigureRequest_Metadata{ + {Type: testpb.ClientConfigureRequest_EMPTY_CALL, Key: testMDKey, Value: testMDValue}, + {Type: testpb.ClientConfigureRequest_UNARY_CALL, Key: testMDKey, Value: strconv.Itoa(mdValue)}, + }, + }); err != nil { + t.Fatalf("failed to configure RPC: %v", err) + } + + st, err := c.clientStats(ctx) + if err != nil { + t.Fatalf("failed to get client stats: %v", err) + } + if st.NumFailures != 0 { + t.Fatalf("Got %v failures: %+v", st.NumFailures, st) + } + if len(st.RpcsByPeer) == 2 { + break + } + + mdValue++ + } + + // Call channelz to verify that only one subchannel is in state Ready. + scs2, err := c.channelzSubChannels(ctx) + if err != nil { + t.Fatalf("failed to fetch channelz: %v", err) + } + verifySubConnStates(t, scs2, map[channelzpb.ChannelConnectivityState_State]int{ + channelzpb.ChannelConnectivityState_READY: 2, + channelzpb.ChannelConnectivityState_IDLE: 1, + }) +} diff --git a/internal/credentials/syscallconn_appengine.go b/xds/internal/test/e2e/e2e_utils.go similarity index 50% rename from internal/credentials/syscallconn_appengine.go rename to xds/internal/test/e2e/e2e_utils.go index a6144cd661c..34b0ee9eb09 100644 --- a/internal/credentials/syscallconn_appengine.go +++ b/xds/internal/test/e2e/e2e_utils.go @@ -1,8 +1,6 @@ -// +build appengine - /* * - * Copyright 2018 gRPC authors. + * Copyright 2021 gRPC authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,16 +13,24 @@ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. - * */ -package credentials +package e2e import ( - "net" + "testing" + + "github.com/google/go-cmp/cmp" + channelzpb "google.golang.org/grpc/channelz/grpc_channelz_v1" ) -// WrapSyscallConn returns newConn on appengine. -func WrapSyscallConn(rawConn, newConn net.Conn) net.Conn { - return newConn +func verifySubConnStates(t *testing.T, scs []*channelzpb.Subchannel, want map[channelzpb.ChannelConnectivityState_State]int) { + t.Helper() + var scStatsCount = map[channelzpb.ChannelConnectivityState_State]int{} + for _, sc := range scs { + scStatsCount[sc.Data.State.State]++ + } + if diff := cmp.Diff(scStatsCount, want); diff != "" { + t.Fatalf("got unexpected number of subchannels in state Ready, %v, scs: %v", diff, scs) + } } diff --git a/xds/internal/test/e2e/run.sh b/xds/internal/test/e2e/run.sh new file mode 100755 index 00000000000..4363d6cbd94 --- /dev/null +++ b/xds/internal/test/e2e/run.sh @@ -0,0 +1,6 @@ +#!/bin/bash + +mkdir binaries +go build -o ./binaries/client ../../../../interop/xds/client/ +go build -o ./binaries/server ../../../../interop/xds/server/ +go test . -v diff --git a/xds/internal/test/xds_client_affinity_test.go b/xds/internal/test/xds_client_affinity_test.go new file mode 100644 index 00000000000..e9ddfe157b1 --- /dev/null +++ b/xds/internal/test/xds_client_affinity_test.go @@ -0,0 +1,136 @@ +//go:build !386 +// +build !386 + +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package xds_test + +import ( + "context" + "fmt" + "testing" + + v3clusterpb "github.com/envoyproxy/go-control-plane/envoy/config/cluster/v3" + v3corepb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" + v3routepb "github.com/envoyproxy/go-control-plane/envoy/config/route/v3" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/internal/xds/env" + testpb "google.golang.org/grpc/test/grpc_testing" + "google.golang.org/grpc/xds/internal/testutils/e2e" +) + +const hashHeaderName = "session_id" + +// hashRouteConfig returns a RouteConfig resource with hash policy set to +// header "session_id". +func hashRouteConfig(routeName, ldsTarget, clusterName string) *v3routepb.RouteConfiguration { + return &v3routepb.RouteConfiguration{ + Name: routeName, + VirtualHosts: []*v3routepb.VirtualHost{{ + Domains: []string{ldsTarget}, + Routes: []*v3routepb.Route{{ + Match: &v3routepb.RouteMatch{PathSpecifier: &v3routepb.RouteMatch_Prefix{Prefix: "/"}}, + Action: &v3routepb.Route_Route{Route: &v3routepb.RouteAction{ + ClusterSpecifier: &v3routepb.RouteAction_Cluster{Cluster: clusterName}, + HashPolicy: []*v3routepb.RouteAction_HashPolicy{{ + PolicySpecifier: &v3routepb.RouteAction_HashPolicy_Header_{ + Header: &v3routepb.RouteAction_HashPolicy_Header{ + HeaderName: hashHeaderName, + }, + }, + Terminal: true, + }}, + }}, + }}, + }}, + } +} + +// ringhashCluster returns a Cluster resource that picks ringhash as the lb +// policy. +func ringhashCluster(clusterName, edsServiceName string) *v3clusterpb.Cluster { + return &v3clusterpb.Cluster{ + Name: clusterName, + ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_EDS}, + EdsClusterConfig: &v3clusterpb.Cluster_EdsClusterConfig{ + EdsConfig: &v3corepb.ConfigSource{ + ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{ + Ads: &v3corepb.AggregatedConfigSource{}, + }, + }, + ServiceName: edsServiceName, + }, + LbPolicy: v3clusterpb.Cluster_RING_HASH, + } +} + +// TestClientSideAffinitySanityCheck tests that the affinity config can be +// propagated to pick the ring_hash policy. It doesn't test the affinity +// behavior in ring_hash policy. +func (s) TestClientSideAffinitySanityCheck(t *testing.T) { + defer func() func() { + old := env.RingHashSupport + env.RingHashSupport = true + return func() { env.RingHashSupport = old } + }()() + + managementServer, nodeID, _, resolver, cleanup1 := setupManagementServer(t) + defer cleanup1() + + port, cleanup2 := clientSetup(t, &testService{}) + defer cleanup2() + + const serviceName = "my-service-client-side-xds" + resources := e2e.DefaultClientResources(e2e.ResourceParams{ + DialTarget: serviceName, + NodeID: nodeID, + Host: "localhost", + Port: port, + SecLevel: e2e.SecurityLevelNone, + }) + // Replace RDS and CDS resources with ringhash config, but keep the resource + // names. + resources.Routes = []*v3routepb.RouteConfiguration{hashRouteConfig( + resources.Routes[0].Name, + resources.Listeners[0].Name, + resources.Clusters[0].Name, + )} + resources.Clusters = []*v3clusterpb.Cluster{ringhashCluster( + resources.Clusters[0].Name, + resources.Clusters[0].EdsClusterConfig.ServiceName, + )} + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + if err := managementServer.Update(ctx, resources); err != nil { + t.Fatal(err) + } + + // Create a ClientConn and make a successful RPC. + cc, err := grpc.Dial(fmt.Sprintf("xds:///%s", serviceName), grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithResolvers(resolver)) + if err != nil { + t.Fatalf("failed to dial local test server: %v", err) + } + defer cc.Close() + + client := testpb.NewTestServiceClient(cc) + if _, err := client.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(true)); err != nil { + t.Fatalf("rpc EmptyCall() failed: %v", err) + } +} diff --git a/xds/internal/test/xds_client_integration_test.go b/xds/internal/test/xds_client_integration_test.go index c3ea71acaa3..23ea1546935 100644 --- a/xds/internal/test/xds_client_integration_test.go +++ b/xds/internal/test/xds_client_integration_test.go @@ -1,3 +1,4 @@ +//go:build !386 // +build !386 /* @@ -22,51 +23,35 @@ package xds_test import ( "context" + "fmt" "net" "testing" - "github.com/google/uuid" "google.golang.org/grpc" + "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/internal/stubserver" + "google.golang.org/grpc/internal/xds/env" + "google.golang.org/grpc/status" "google.golang.org/grpc/xds/internal/testutils" "google.golang.org/grpc/xds/internal/testutils/e2e" + v3routepb "github.com/envoyproxy/go-control-plane/envoy/config/route/v3" + wrapperspb "github.com/golang/protobuf/ptypes/wrappers" testpb "google.golang.org/grpc/test/grpc_testing" ) // clientSetup performs a bunch of steps common to all xDS client tests here: -// - spin up an xDS management server on a local port // - spin up a gRPC server and register the test service on it // - create a local TCP listener and start serving on it // // Returns the following: -// - the management server: tests use this to configure resources -// - nodeID expected by the management server: this is set in the Node proto -// sent by the xdsClient for queries. // - the port the server is listening on // - cleanup function to be invoked by the tests when done -func clientSetup(t *testing.T) (*e2e.ManagementServer, string, uint32, func()) { - // Spin up a xDS management server on a local port. - nodeID := uuid.New().String() - fs, err := e2e.StartManagementServer() - if err != nil { - t.Fatal(err) - } - - // Create a bootstrap file in a temporary directory. - bootstrapCleanup, err := e2e.SetupBootstrapFile(e2e.BootstrapOptions{ - Version: e2e.TransportV3, - NodeID: nodeID, - ServerURI: fs.Address, - ServerResourceNameID: "grpc/server", - }) - if err != nil { - t.Fatal(err) - } - +func clientSetup(t *testing.T, tss testpb.TestServiceServer) (uint32, func()) { // Initialize a gRPC server and register the stubServer on it. server := grpc.NewServer() - testpb.RegisterTestServiceServer(server, &testService{}) + testpb.RegisterTestServiceServer(server, tss) // Create a local listener and pass it to Serve(). lis, err := testutils.LocalTCPListener() @@ -80,33 +65,190 @@ func clientSetup(t *testing.T) (*e2e.ManagementServer, string, uint32, func()) { } }() - return fs, nodeID, uint32(lis.Addr().(*net.TCPAddr).Port), func() { - fs.Stop() - bootstrapCleanup() + return uint32(lis.Addr().(*net.TCPAddr).Port), func() { server.Stop() } } func (s) TestClientSideXDS(t *testing.T) { - fs, nodeID, port, cleanup := clientSetup(t) - defer cleanup() + managementServer, nodeID, _, resolver, cleanup1 := setupManagementServer(t) + defer cleanup1() - resources := e2e.DefaultClientResources("myservice", nodeID, "localhost", port) - if err := fs.Update(resources); err != nil { + port, cleanup2 := clientSetup(t, &testService{}) + defer cleanup2() + + const serviceName = "my-service-client-side-xds" + resources := e2e.DefaultClientResources(e2e.ResourceParams{ + DialTarget: serviceName, + NodeID: nodeID, + Host: "localhost", + Port: port, + SecLevel: e2e.SecurityLevelNone, + }) + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + if err := managementServer.Update(ctx, resources); err != nil { t.Fatal(err) } // Create a ClientConn and make a successful RPC. - cc, err := grpc.Dial("xds:///myservice", grpc.WithTransportCredentials(insecure.NewCredentials())) + cc, err := grpc.Dial(fmt.Sprintf("xds:///%s", serviceName), grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithResolvers(resolver)) if err != nil { t.Fatalf("failed to dial local test server: %v", err) } defer cc.Close() client := testpb.NewTestServiceClient(cc) - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() if _, err := client.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(true)); err != nil { t.Fatalf("rpc EmptyCall() failed: %v", err) } } + +func (s) TestClientSideRetry(t *testing.T) { + if !env.RetrySupport { + // Skip this test if retry is not enabled. + return + } + + ctr := 0 + errs := []codes.Code{codes.ResourceExhausted} + ss := &stubserver.StubServer{ + EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) { + defer func() { ctr++ }() + if ctr < len(errs) { + return nil, status.Errorf(errs[ctr], "this should be retried") + } + return &testpb.Empty{}, nil + }, + } + + managementServer, nodeID, _, resolver, cleanup1 := setupManagementServer(t) + defer cleanup1() + + port, cleanup2 := clientSetup(t, ss) + defer cleanup2() + + const serviceName = "my-service-client-side-xds" + resources := e2e.DefaultClientResources(e2e.ResourceParams{ + DialTarget: serviceName, + NodeID: nodeID, + Host: "localhost", + Port: port, + SecLevel: e2e.SecurityLevelNone, + }) + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + if err := managementServer.Update(ctx, resources); err != nil { + t.Fatal(err) + } + + // Create a ClientConn and make a successful RPC. + cc, err := grpc.Dial(fmt.Sprintf("xds:///%s", serviceName), grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithResolvers(resolver)) + if err != nil { + t.Fatalf("failed to dial local test server: %v", err) + } + defer cc.Close() + + client := testpb.NewTestServiceClient(cc) + defer cancel() + if _, err := client.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(true)); status.Code(err) != codes.ResourceExhausted { + t.Fatalf("rpc EmptyCall() = _, %v; want _, ResourceExhausted", err) + } + + testCases := []struct { + name string + vhPolicy *v3routepb.RetryPolicy + routePolicy *v3routepb.RetryPolicy + errs []codes.Code // the errors returned by the server for each RPC + tryAgainErr codes.Code // the error that would be returned if we are still using the old retry policies. + errWant codes.Code + }{{ + name: "virtualHost only, fail", + vhPolicy: &v3routepb.RetryPolicy{ + RetryOn: "resource-exhausted,unavailable", + NumRetries: &wrapperspb.UInt32Value{Value: 1}, + }, + errs: []codes.Code{codes.ResourceExhausted, codes.Unavailable}, + routePolicy: nil, + tryAgainErr: codes.ResourceExhausted, + errWant: codes.Unavailable, + }, { + name: "virtualHost only", + vhPolicy: &v3routepb.RetryPolicy{ + RetryOn: "resource-exhausted, unavailable", + NumRetries: &wrapperspb.UInt32Value{Value: 2}, + }, + errs: []codes.Code{codes.ResourceExhausted, codes.Unavailable}, + routePolicy: nil, + tryAgainErr: codes.Unavailable, + errWant: codes.OK, + }, { + name: "virtualHost+route, fail", + vhPolicy: &v3routepb.RetryPolicy{ + RetryOn: "resource-exhausted,unavailable", + NumRetries: &wrapperspb.UInt32Value{Value: 2}, + }, + routePolicy: &v3routepb.RetryPolicy{ + RetryOn: "resource-exhausted", + NumRetries: &wrapperspb.UInt32Value{Value: 2}, + }, + errs: []codes.Code{codes.ResourceExhausted, codes.Unavailable}, + tryAgainErr: codes.OK, + errWant: codes.Unavailable, + }, { + name: "virtualHost+route", + vhPolicy: &v3routepb.RetryPolicy{ + RetryOn: "resource-exhausted", + NumRetries: &wrapperspb.UInt32Value{Value: 2}, + }, + routePolicy: &v3routepb.RetryPolicy{ + RetryOn: "unavailable", + NumRetries: &wrapperspb.UInt32Value{Value: 2}, + }, + errs: []codes.Code{codes.Unavailable}, + tryAgainErr: codes.Unavailable, + errWant: codes.OK, + }, { + name: "virtualHost+route, not enough attempts", + vhPolicy: &v3routepb.RetryPolicy{ + RetryOn: "unavailable", + NumRetries: &wrapperspb.UInt32Value{Value: 2}, + }, + routePolicy: &v3routepb.RetryPolicy{ + RetryOn: "unavailable", + NumRetries: &wrapperspb.UInt32Value{Value: 1}, + }, + errs: []codes.Code{codes.Unavailable, codes.Unavailable}, + tryAgainErr: codes.OK, + errWant: codes.Unavailable, + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + errs = tc.errs + + // Confirm tryAgainErr is correct before updating resources. + ctr = 0 + _, err := client.EmptyCall(ctx, &testpb.Empty{}) + if code := status.Code(err); code != tc.tryAgainErr { + t.Fatalf("with old retry policy: EmptyCall() = _, %v; want _, %v", err, tc.tryAgainErr) + } + + resources.Routes[0].VirtualHosts[0].RetryPolicy = tc.vhPolicy + resources.Routes[0].VirtualHosts[0].Routes[0].GetRoute().RetryPolicy = tc.routePolicy + if err := managementServer.Update(ctx, resources); err != nil { + t.Fatal(err) + } + + for { + ctr = 0 + _, err := client.EmptyCall(ctx, &testpb.Empty{}) + if code := status.Code(err); code == tc.tryAgainErr { + continue + } else if code != tc.errWant { + t.Fatalf("rpc EmptyCall() = _, %v; want _, %v", err, tc.errWant) + } + break + } + }) + } +} diff --git a/xds/internal/test/xds_integration_test.go b/xds/internal/test/xds_integration_test.go index ae306ae7864..4b7cca3b828 100644 --- a/xds/internal/test/xds_integration_test.go +++ b/xds/internal/test/xds_integration_test.go @@ -1,3 +1,4 @@ +//go:build !386 // +build !386 /* @@ -23,15 +24,32 @@ package xds_test import ( "context" + "crypto/tls" + "crypto/x509" + "encoding/json" + "fmt" + "io/ioutil" + "os" + "path" "testing" "time" + "github.com/google/uuid" + + "google.golang.org/grpc/credentials" "google.golang.org/grpc/internal/grpctest" + "google.golang.org/grpc/resolver" + "google.golang.org/grpc/testdata" + "google.golang.org/grpc/xds" + "google.golang.org/grpc/xds/internal/testutils/e2e" + + xdsinternal "google.golang.org/grpc/internal/xds" testpb "google.golang.org/grpc/test/grpc_testing" ) const ( - defaultTestTimeout = 10 * time.Second + defaultTestTimeout = 10 * time.Second + defaultTestShortTimeout = 100 * time.Millisecond ) type s struct { @@ -49,3 +67,134 @@ type testService struct { func (*testService) EmptyCall(context.Context, *testpb.Empty) (*testpb.Empty, error) { return &testpb.Empty{}, nil } + +func (*testService) UnaryCall(context.Context, *testpb.SimpleRequest) (*testpb.SimpleResponse, error) { + return &testpb.SimpleResponse{}, nil +} + +func createTmpFile(src, dst string) error { + data, err := ioutil.ReadFile(src) + if err != nil { + return fmt.Errorf("ioutil.ReadFile(%q) failed: %v", src, err) + } + if err := ioutil.WriteFile(dst, data, os.ModePerm); err != nil { + return fmt.Errorf("ioutil.WriteFile(%q) failed: %v", dst, err) + } + return nil +} + +// createTempDirWithFiles creates a temporary directory under the system default +// tempDir with the given dirSuffix. It also reads from certSrc, keySrc and +// rootSrc files are creates appropriate files under the newly create tempDir. +// Returns the name of the created tempDir. +func createTmpDirWithFiles(dirSuffix, certSrc, keySrc, rootSrc string) (string, error) { + // Create a temp directory. Passing an empty string for the first argument + // uses the system temp directory. + dir, err := ioutil.TempDir("", dirSuffix) + if err != nil { + return "", fmt.Errorf("ioutil.TempDir() failed: %v", err) + } + + if err := createTmpFile(testdata.Path(certSrc), path.Join(dir, certFile)); err != nil { + return "", err + } + if err := createTmpFile(testdata.Path(keySrc), path.Join(dir, keyFile)); err != nil { + return "", err + } + if err := createTmpFile(testdata.Path(rootSrc), path.Join(dir, rootFile)); err != nil { + return "", err + } + return dir, nil +} + +// createClientTLSCredentials creates client-side TLS transport credentials. +func createClientTLSCredentials(t *testing.T) credentials.TransportCredentials { + t.Helper() + + cert, err := tls.LoadX509KeyPair(testdata.Path("x509/client1_cert.pem"), testdata.Path("x509/client1_key.pem")) + if err != nil { + t.Fatalf("tls.LoadX509KeyPair(x509/client1_cert.pem, x509/client1_key.pem) failed: %v", err) + } + b, err := ioutil.ReadFile(testdata.Path("x509/server_ca_cert.pem")) + if err != nil { + t.Fatalf("ioutil.ReadFile(x509/server_ca_cert.pem) failed: %v", err) + } + roots := x509.NewCertPool() + if !roots.AppendCertsFromPEM(b) { + t.Fatal("failed to append certificates") + } + return credentials.NewTLS(&tls.Config{ + Certificates: []tls.Certificate{cert}, + RootCAs: roots, + ServerName: "x.test.example.com", + }) +} + +// setupManagement server performs the following: +// - spin up an xDS management server on a local port +// - set up certificates for consumption by the file_watcher plugin +// - creates a bootstrap file in a temporary location +// - creates an xDS resolver using the above bootstrap contents +// +// Returns the following: +// - management server +// - nodeID to be used by the client when connecting to the management server +// - bootstrap contents to be used by the client +// - xDS resolver builder to be used by the client +// - a cleanup function to be invoked at the end of the test +func setupManagementServer(t *testing.T) (*e2e.ManagementServer, string, []byte, resolver.Builder, func()) { + t.Helper() + + // Spin up an xDS management server on a local port. + server, err := e2e.StartManagementServer() + if err != nil { + t.Fatalf("Failed to spin up the xDS management server: %v", err) + } + defer func() { + if err != nil { + server.Stop() + } + }() + + // Create a directory to hold certs and key files used on the server side. + serverDir, err := createTmpDirWithFiles("testServerSideXDS*", "x509/server1_cert.pem", "x509/server1_key.pem", "x509/client_ca_cert.pem") + if err != nil { + server.Stop() + t.Fatal(err) + } + + // Create a directory to hold certs and key files used on the client side. + clientDir, err := createTmpDirWithFiles("testClientSideXDS*", "x509/client1_cert.pem", "x509/client1_key.pem", "x509/server_ca_cert.pem") + if err != nil { + server.Stop() + t.Fatal(err) + } + + // Create certificate providers section of the bootstrap config with entries + // for both the client and server sides. + cpc := map[string]json.RawMessage{ + e2e.ServerSideCertProviderInstance: e2e.DefaultFileWatcherConfig(path.Join(serverDir, certFile), path.Join(serverDir, keyFile), path.Join(serverDir, rootFile)), + e2e.ClientSideCertProviderInstance: e2e.DefaultFileWatcherConfig(path.Join(clientDir, certFile), path.Join(clientDir, keyFile), path.Join(clientDir, rootFile)), + } + + // Create a bootstrap file in a temporary directory. + nodeID := uuid.New().String() + bootstrapContents, err := xdsinternal.BootstrapContents(xdsinternal.BootstrapOptions{ + Version: xdsinternal.TransportV3, + NodeID: nodeID, + ServerURI: server.Address, + CertificateProviders: cpc, + ServerListenerResourceNameTemplate: e2e.ServerListenerResourceNameTemplate, + }) + if err != nil { + server.Stop() + t.Fatalf("Failed to create bootstrap file: %v", err) + } + resolver, err := xds.NewXDSResolverWithConfigForTesting(bootstrapContents) + if err != nil { + server.Stop() + t.Fatalf("Failed to create xDS resolver for testing: %v", err) + } + + return server, nodeID, bootstrapContents, resolver, func() { server.Stop() } +} diff --git a/xds/internal/test/xds_security_config_nack_test.go b/xds/internal/test/xds_security_config_nack_test.go new file mode 100644 index 00000000000..7b8e36c3f3a --- /dev/null +++ b/xds/internal/test/xds_security_config_nack_test.go @@ -0,0 +1,372 @@ +//go:build !386 +// +build !386 + +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package xds_test + +import ( + "context" + "fmt" + "testing" + + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + xdscreds "google.golang.org/grpc/credentials/xds" + "google.golang.org/grpc/internal/testutils" + "google.golang.org/grpc/xds/internal/testutils/e2e" + + v3corepb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" + v3tlspb "github.com/envoyproxy/go-control-plane/envoy/extensions/transport_sockets/tls/v3" + testpb "google.golang.org/grpc/test/grpc_testing" +) + +func (s) TestUnmarshalListener_WithUpdateValidatorFunc(t *testing.T) { + const ( + serviceName = "my-service-client-side-xds" + missingIdentityProviderInstance = "missing-identity-provider-instance" + missingRootProviderInstance = "missing-root-provider-instance" + ) + managementServer, nodeID, bootstrapContents, resolver, cleanup1 := setupManagementServer(t) + defer cleanup1() + + lis, cleanup2 := setupGRPCServer(t, bootstrapContents) + defer cleanup2() + + // Grab the host and port of the server and create client side xDS + // resources corresponding to it. + host, port, err := hostPortFromListener(lis) + if err != nil { + t.Fatalf("failed to retrieve host and port of server: %v", err) + } + + // Create xDS resources to be consumed on the client side. This + // includes the listener, route configuration, cluster (with + // security configuration) and endpoint resources. + resources := e2e.DefaultClientResources(e2e.ResourceParams{ + DialTarget: serviceName, + NodeID: nodeID, + Host: host, + Port: port, + SecLevel: e2e.SecurityLevelMTLS, + }) + + tests := []struct { + name string + securityConfig *v3corepb.TransportSocket + wantErr bool + }{ + { + name: "both identity and root providers are not present in bootstrap", + securityConfig: &v3corepb.TransportSocket{ + Name: "envoy.transport_sockets.tls", + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3tlspb.DownstreamTlsContext{ + CommonTlsContext: &v3tlspb.CommonTlsContext{ + TlsCertificateProviderInstance: &v3tlspb.CertificateProviderPluginInstance{ + InstanceName: missingIdentityProviderInstance, + }, + ValidationContextType: &v3tlspb.CommonTlsContext_ValidationContext{ + ValidationContext: &v3tlspb.CertificateValidationContext{ + CaCertificateProviderInstance: &v3tlspb.CertificateProviderPluginInstance{ + InstanceName: missingRootProviderInstance, + }, + }, + }, + }, + }), + }, + }, + wantErr: true, + }, + { + name: "only identity provider is not present in bootstrap", + securityConfig: &v3corepb.TransportSocket{ + Name: "envoy.transport_sockets.tls", + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3tlspb.DownstreamTlsContext{ + CommonTlsContext: &v3tlspb.CommonTlsContext{ + TlsCertificateProviderInstance: &v3tlspb.CertificateProviderPluginInstance{ + InstanceName: missingIdentityProviderInstance, + }, + ValidationContextType: &v3tlspb.CommonTlsContext_ValidationContext{ + ValidationContext: &v3tlspb.CertificateValidationContext{ + CaCertificateProviderInstance: &v3tlspb.CertificateProviderPluginInstance{ + InstanceName: e2e.ServerSideCertProviderInstance, + }, + }, + }, + }, + }), + }, + }, + wantErr: true, + }, + { + name: "only root provider is not present in bootstrap", + securityConfig: &v3corepb.TransportSocket{ + Name: "envoy.transport_sockets.tls", + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3tlspb.DownstreamTlsContext{ + CommonTlsContext: &v3tlspb.CommonTlsContext{ + TlsCertificateProviderInstance: &v3tlspb.CertificateProviderPluginInstance{ + InstanceName: e2e.ServerSideCertProviderInstance, + }, + ValidationContextType: &v3tlspb.CommonTlsContext_ValidationContext{ + ValidationContext: &v3tlspb.CertificateValidationContext{ + CaCertificateProviderInstance: &v3tlspb.CertificateProviderPluginInstance{ + InstanceName: missingRootProviderInstance, + }, + }, + }, + }, + }), + }, + }, + wantErr: true, + }, + { + name: "both identity and root providers are present in bootstrap", + securityConfig: &v3corepb.TransportSocket{ + Name: "envoy.transport_sockets.tls", + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3tlspb.DownstreamTlsContext{ + CommonTlsContext: &v3tlspb.CommonTlsContext{ + TlsCertificateProviderInstance: &v3tlspb.CertificateProviderPluginInstance{ + InstanceName: e2e.ServerSideCertProviderInstance, + }, + ValidationContextType: &v3tlspb.CommonTlsContext_ValidationContext{ + ValidationContext: &v3tlspb.CertificateValidationContext{ + CaCertificateProviderInstance: &v3tlspb.CertificateProviderPluginInstance{ + InstanceName: e2e.ServerSideCertProviderInstance, + }, + }, + }, + }, + }), + }, + }, + wantErr: false, + }, + } + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + // Create an inbound xDS listener resource for the server side. + inboundLis := e2e.DefaultServerListener(host, port, e2e.SecurityLevelMTLS) + for _, fc := range inboundLis.GetFilterChains() { + fc.TransportSocket = test.securityConfig + } + + // Setup the management server with client and server resources. + if len(resources.Listeners) == 1 { + resources.Listeners = append(resources.Listeners, inboundLis) + } else { + resources.Listeners[1] = inboundLis + } + if err := managementServer.Update(ctx, resources); err != nil { + t.Fatal(err) + } + + // Create client-side xDS credentials with an insecure fallback. + creds, err := xdscreds.NewClientCredentials(xdscreds.ClientOptions{FallbackCreds: insecure.NewCredentials()}) + if err != nil { + t.Fatal(err) + } + + // Create a ClientConn with the xds scheme and make an RPC. + cc, err := grpc.DialContext(ctx, fmt.Sprintf("xds:///%s", serviceName), grpc.WithTransportCredentials(creds), grpc.WithResolvers(resolver)) + if err != nil { + t.Fatalf("failed to dial local test server: %v", err) + } + defer cc.Close() + + // Make a context with a shorter timeout from the top level test + // context for cases where we expect failures. + timeout := defaultTestTimeout + if test.wantErr { + timeout = defaultTestShortTimeout + } + ctx2, cancel2 := context.WithTimeout(ctx, timeout) + defer cancel2() + client := testpb.NewTestServiceClient(cc) + if _, err := client.EmptyCall(ctx2, &testpb.Empty{}, grpc.WaitForReady(true)); (err != nil) != test.wantErr { + t.Fatalf("EmptyCall() returned err: %v, wantErr %v", err, test.wantErr) + } + }) + } +} + +func (s) TestUnmarshalCluster_WithUpdateValidatorFunc(t *testing.T) { + const ( + serviceName = "my-service-client-side-xds" + missingIdentityProviderInstance = "missing-identity-provider-instance" + missingRootProviderInstance = "missing-root-provider-instance" + ) + + tests := []struct { + name string + securityConfig *v3corepb.TransportSocket + wantErr bool + }{ + { + name: "both identity and root providers are not present in bootstrap", + securityConfig: &v3corepb.TransportSocket{ + Name: "envoy.transport_sockets.tls", + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3tlspb.UpstreamTlsContext{ + CommonTlsContext: &v3tlspb.CommonTlsContext{ + TlsCertificateProviderInstance: &v3tlspb.CertificateProviderPluginInstance{ + InstanceName: missingIdentityProviderInstance, + }, + ValidationContextType: &v3tlspb.CommonTlsContext_ValidationContext{ + ValidationContext: &v3tlspb.CertificateValidationContext{ + CaCertificateProviderInstance: &v3tlspb.CertificateProviderPluginInstance{ + InstanceName: missingRootProviderInstance, + }, + }, + }, + }, + }), + }, + }, + wantErr: true, + }, + { + name: "only identity provider is not present in bootstrap", + securityConfig: &v3corepb.TransportSocket{ + Name: "envoy.transport_sockets.tls", + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3tlspb.UpstreamTlsContext{ + CommonTlsContext: &v3tlspb.CommonTlsContext{ + TlsCertificateProviderInstance: &v3tlspb.CertificateProviderPluginInstance{ + InstanceName: missingIdentityProviderInstance, + }, + ValidationContextType: &v3tlspb.CommonTlsContext_ValidationContext{ + ValidationContext: &v3tlspb.CertificateValidationContext{ + CaCertificateProviderInstance: &v3tlspb.CertificateProviderPluginInstance{ + InstanceName: e2e.ClientSideCertProviderInstance, + }, + }, + }, + }, + }), + }, + }, + wantErr: true, + }, + { + name: "only root provider is not present in bootstrap", + securityConfig: &v3corepb.TransportSocket{ + Name: "envoy.transport_sockets.tls", + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3tlspb.UpstreamTlsContext{ + CommonTlsContext: &v3tlspb.CommonTlsContext{ + TlsCertificateProviderInstance: &v3tlspb.CertificateProviderPluginInstance{ + InstanceName: e2e.ClientSideCertProviderInstance, + }, + ValidationContextType: &v3tlspb.CommonTlsContext_ValidationContext{ + ValidationContext: &v3tlspb.CertificateValidationContext{ + CaCertificateProviderInstance: &v3tlspb.CertificateProviderPluginInstance{ + InstanceName: missingRootProviderInstance, + }, + }, + }, + }, + }), + }, + }, + wantErr: true, + }, + { + name: "both identity and root providers are present in bootstrap", + securityConfig: &v3corepb.TransportSocket{ + Name: "envoy.transport_sockets.tls", + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3tlspb.UpstreamTlsContext{ + CommonTlsContext: &v3tlspb.CommonTlsContext{ + TlsCertificateProviderInstance: &v3tlspb.CertificateProviderPluginInstance{ + InstanceName: e2e.ClientSideCertProviderInstance, + }, + ValidationContextType: &v3tlspb.CommonTlsContext_ValidationContext{ + ValidationContext: &v3tlspb.CertificateValidationContext{ + CaCertificateProviderInstance: &v3tlspb.CertificateProviderPluginInstance{ + InstanceName: e2e.ClientSideCertProviderInstance, + }, + }, + }, + }, + }), + }, + }, + wantErr: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + // setupManagementServer() sets up a bootstrap file with certificate + // provider instance names: `e2e.ServerSideCertProviderInstance` and + // `e2e.ClientSideCertProviderInstance`. + managementServer, nodeID, _, resolver, cleanup1 := setupManagementServer(t) + defer cleanup1() + + port, cleanup2 := clientSetup(t, &testService{}) + defer cleanup2() + + // This creates a `Cluster` resource with a security config which + // refers to `e2e.ClientSideCertProviderInstance` for both root and + // identity certs. + resources := e2e.DefaultClientResources(e2e.ResourceParams{ + DialTarget: serviceName, + NodeID: nodeID, + Host: "localhost", + Port: port, + SecLevel: e2e.SecurityLevelMTLS, + }) + resources.Clusters[0].TransportSocket = test.securityConfig + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + if err := managementServer.Update(ctx, resources); err != nil { + t.Fatal(err) + } + + cc, err := grpc.Dial(fmt.Sprintf("xds:///%s", serviceName), grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithResolvers(resolver)) + if err != nil { + t.Fatalf("failed to dial local test server: %v", err) + } + defer cc.Close() + + // Make a context with a shorter timeout from the top level test + // context for cases where we expect failures. + timeout := defaultTestTimeout + if test.wantErr { + timeout = defaultTestShortTimeout + } + ctx2, cancel2 := context.WithTimeout(ctx, timeout) + defer cancel2() + client := testpb.NewTestServiceClient(cc) + if _, err := client.EmptyCall(ctx2, &testpb.Empty{}, grpc.WaitForReady(true)); (err != nil) != test.wantErr { + t.Fatalf("EmptyCall() returned err: %v, wantErr %v", err, test.wantErr) + } + }) + } +} diff --git a/xds/internal/test/xds_server_integration_test.go b/xds/internal/test/xds_server_integration_test.go index 169daf19d26..707a9605d82 100644 --- a/xds/internal/test/xds_server_integration_test.go +++ b/xds/internal/test/xds_server_integration_test.go @@ -1,3 +1,4 @@ +//go:build !386 // +build !386 /* @@ -23,36 +24,35 @@ package xds_test import ( "context" - "crypto/tls" - "crypto/x509" "fmt" - "io/ioutil" "net" - "os" - "path" "strconv" + "strings" "testing" - v3corepb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" - v3listenerpb "github.com/envoyproxy/go-control-plane/envoy/config/listener/v3" - v3tlspb "github.com/envoyproxy/go-control-plane/envoy/extensions/transport_sockets/tls/v3" - wrapperspb "github.com/golang/protobuf/ptypes/wrappers" - "github.com/google/uuid" - "google.golang.org/protobuf/proto" - "google.golang.org/protobuf/types/known/anypb" - "google.golang.org/grpc" "google.golang.org/grpc/codes" - "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/insecure" - xdscreds "google.golang.org/grpc/credentials/xds" + "google.golang.org/grpc/internal/testutils" + "google.golang.org/grpc/internal/xds/env" "google.golang.org/grpc/status" - testpb "google.golang.org/grpc/test/grpc_testing" - "google.golang.org/grpc/testdata" "google.golang.org/grpc/xds" - "google.golang.org/grpc/xds/internal/testutils" + "google.golang.org/grpc/xds/internal/httpfilter/rbac" "google.golang.org/grpc/xds/internal/testutils/e2e" - "google.golang.org/grpc/xds/internal/version" + + v3corepb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" + v3listenerpb "github.com/envoyproxy/go-control-plane/envoy/config/listener/v3" + v3rbacpb "github.com/envoyproxy/go-control-plane/envoy/config/rbac/v3" + v3routepb "github.com/envoyproxy/go-control-plane/envoy/config/route/v3" + rpb "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/http/rbac/v3" + v3routerpb "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/http/router/v3" + v3httppb "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/network/http_connection_manager/v3" + v3matcherpb "github.com/envoyproxy/go-control-plane/envoy/type/matcher/v3" + anypb "github.com/golang/protobuf/ptypes/any" + wrapperspb "github.com/golang/protobuf/ptypes/wrappers" + xdscreds "google.golang.org/grpc/credentials/xds" + testpb "google.golang.org/grpc/test/grpc_testing" + xdstestutils "google.golang.org/grpc/xds/internal/testutils" ) const ( @@ -62,156 +62,587 @@ const ( rootFile = "ca.pem" ) -func createTmpFile(t *testing.T, src, dst string) { +// setupGRPCServer performs the following: +// - spin up an xDS-enabled gRPC server, configure it with xdsCredentials and +// register the test service on it +// - create a local TCP listener and start serving on it +// +// Returns the following: +// - local listener on which the xDS-enabled gRPC server is serving on +// - cleanup function to be invoked by the tests when done +func setupGRPCServer(t *testing.T, bootstrapContents []byte) (net.Listener, func()) { t.Helper() - data, err := ioutil.ReadFile(src) + // Configure xDS credentials to be used on the server-side. + creds, err := xdscreds.NewServerCredentials(xdscreds.ServerOptions{ + FallbackCreds: insecure.NewCredentials(), + }) if err != nil { - t.Fatalf("ioutil.ReadFile(%q) failed: %v", src, err) - } - if err := ioutil.WriteFile(dst, data, os.ModePerm); err != nil { - t.Fatalf("ioutil.WriteFile(%q) failed: %v", dst, err) + t.Fatal(err) } - t.Logf("Wrote file at: %s", dst) - t.Logf("%s", string(data)) -} -// createTempDirWithFiles creates a temporary directory under the system default -// tempDir with the given dirSuffix. It also reads from certSrc, keySrc and -// rootSrc files are creates appropriate files under the newly create tempDir. -// Returns the name of the created tempDir. -func createTmpDirWithFiles(t *testing.T, dirSuffix, certSrc, keySrc, rootSrc string) string { - t.Helper() + // Initialize an xDS-enabled gRPC server and register the stubServer on it. + server := xds.NewGRPCServer(grpc.Creds(creds), xds.BootstrapContentsForTesting(bootstrapContents)) + testpb.RegisterTestServiceServer(server, &testService{}) - // Create a temp directory. Passing an empty string for the first argument - // uses the system temp directory. - dir, err := ioutil.TempDir("", dirSuffix) + // Create a local listener and pass it to Serve(). + lis, err := xdstestutils.LocalTCPListener() if err != nil { - t.Fatalf("ioutil.TempDir() failed: %v", err) + t.Fatalf("testutils.LocalTCPListener() failed: %v", err) } - t.Logf("Using tmpdir: %s", dir) - createTmpFile(t, testdata.Path(certSrc), path.Join(dir, certFile)) - createTmpFile(t, testdata.Path(keySrc), path.Join(dir, keyFile)) - createTmpFile(t, testdata.Path(rootSrc), path.Join(dir, rootFile)) - return dir + go func() { + if err := server.Serve(lis); err != nil { + t.Errorf("Serve() failed: %v", err) + } + }() + + return lis, func() { + server.Stop() + } } -// createClientTLSCredentials creates client-side TLS transport credentials. -func createClientTLSCredentials(t *testing.T) credentials.TransportCredentials { - cert, err := tls.LoadX509KeyPair(testdata.Path("x509/client1_cert.pem"), testdata.Path("x509/client1_key.pem")) +func hostPortFromListener(lis net.Listener) (string, uint32, error) { + host, p, err := net.SplitHostPort(lis.Addr().String()) if err != nil { - t.Fatalf("tls.LoadX509KeyPair(x509/client1_cert.pem, x509/client1_key.pem) failed: %v", err) + return "", 0, fmt.Errorf("net.SplitHostPort(%s) failed: %v", lis.Addr().String(), err) } - b, err := ioutil.ReadFile(testdata.Path("x509/server_ca_cert.pem")) + port, err := strconv.ParseInt(p, 10, 32) if err != nil { - t.Fatalf("ioutil.ReadFile(x509/server_ca_cert.pem) failed: %v", err) + return "", 0, fmt.Errorf("strconv.ParseInt(%s, 10, 32) failed: %v", p, err) } - roots := x509.NewCertPool() - if !roots.AppendCertsFromPEM(b) { - t.Fatal("failed to append certificates") - } - return credentials.NewTLS(&tls.Config{ - Certificates: []tls.Certificate{cert}, - RootCAs: roots, - ServerName: "x.test.example.com", - }) + return host, uint32(port), nil } -// commonSetup performs a bunch of steps common to all xDS server tests here: -// - spin up an xDS management server on a local port -// - set up certificates for consumption by the file_watcher plugin -// - spin up an xDS-enabled gRPC server, configure it with xdsCredentials and -// register the test service on it -// - create a local TCP listener and start serving on it +// TestServerSideXDS_Fallback is an e2e test which verifies xDS credentials +// fallback functionality. // -// Returns the following: -// - the management server: tests use this to configure resources -// - nodeID expected by the management server: this is set in the Node proto -// sent by the xdsClient used on the xDS-enabled gRPC server -// - local listener on which the xDS-enabled gRPC server is serving on -// - cleanup function to be invoked by the tests when done -func commonSetup(t *testing.T) (*e2e.ManagementServer, string, net.Listener, func()) { - t.Helper() +// The following sequence of events happen as part of this test: +// - An xDS-enabled gRPC server is created and xDS credentials are configured. +// - xDS is enabled on the client by the use of the xds:/// scheme, and xDS +// credentials are configured. +// - Control plane is configured to not send any security configuration to both +// the client and the server. This results in both of them using the +// configured fallback credentials (which is insecure creds in this case). +func (s) TestServerSideXDS_Fallback(t *testing.T) { + managementServer, nodeID, bootstrapContents, resolver, cleanup1 := setupManagementServer(t) + defer cleanup1() + + lis, cleanup2 := setupGRPCServer(t, bootstrapContents) + defer cleanup2() - // Spin up a xDS management server on a local port. - nodeID := uuid.New().String() - fs, err := e2e.StartManagementServer() + // Grab the host and port of the server and create client side xDS resources + // corresponding to it. This contains default resources with no security + // configuration in the Cluster resources. + host, port, err := hostPortFromListener(lis) if err != nil { - t.Fatal(err) + t.Fatalf("failed to retrieve host and port of server: %v", err) } + const serviceName = "my-service-fallback" + resources := e2e.DefaultClientResources(e2e.ResourceParams{ + DialTarget: serviceName, + NodeID: nodeID, + Host: host, + Port: port, + SecLevel: e2e.SecurityLevelNone, + }) - // Create certificate and key files in a temporary directory and generate - // certificate provider configuration for a file_watcher plugin. - tmpdir := createTmpDirWithFiles(t, "testServerSideXDS*", "x509/server1_cert.pem", "x509/server1_key.pem", "x509/client_ca_cert.pem") - cpc := e2e.DefaultFileWatcherConfig(path.Join(tmpdir, certFile), path.Join(tmpdir, keyFile), path.Join(tmpdir, rootFile)) + // Create an inbound xDS listener resource for the server side that does not + // contain any security configuration. This should force the server-side + // xdsCredentials to use fallback. + inboundLis := e2e.DefaultServerListener(host, port, e2e.SecurityLevelNone) + resources.Listeners = append(resources.Listeners, inboundLis) - // Create a bootstrap file in a temporary directory. - bootstrapCleanup, err := e2e.SetupBootstrapFile(e2e.BootstrapOptions{ - Version: e2e.TransportV3, - NodeID: nodeID, - ServerURI: fs.Address, - CertificateProviders: cpc, - ServerResourceNameID: "grpc/server", + // Setup the management server with client and server-side resources. + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + if err := managementServer.Update(ctx, resources); err != nil { + t.Fatal(err) + } + + // Create client-side xDS credentials with an insecure fallback. + creds, err := xdscreds.NewClientCredentials(xdscreds.ClientOptions{ + FallbackCreds: insecure.NewCredentials(), }) if err != nil { t.Fatal(err) } - // Configure xDS credentials to be used on the server-side. - creds, err := xdscreds.NewServerCredentials(xdscreds.ServerOptions{ + // Create a ClientConn with the xds scheme and make a successful RPC. + cc, err := grpc.DialContext(ctx, fmt.Sprintf("xds:///%s", serviceName), grpc.WithTransportCredentials(creds), grpc.WithResolvers(resolver)) + if err != nil { + t.Fatalf("failed to dial local test server: %v", err) + } + defer cc.Close() + + client := testpb.NewTestServiceClient(cc) + if _, err := client.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(true)); err != nil { + t.Errorf("rpc EmptyCall() failed: %v", err) + } +} + +// TestServerSideXDS_FileWatcherCerts is an e2e test which verifies xDS +// credentials with file watcher certificate provider. +// +// The following sequence of events happen as part of this test: +// - An xDS-enabled gRPC server is created and xDS credentials are configured. +// - xDS is enabled on the client by the use of the xds:/// scheme, and xDS +// credentials are configured. +// - Control plane is configured to send security configuration to both the +// client and the server, pointing to the file watcher certificate provider. +// We verify both TLS and mTLS scenarios. +func (s) TestServerSideXDS_FileWatcherCerts(t *testing.T) { + tests := []struct { + name string + secLevel e2e.SecurityLevel + }{ + { + name: "tls", + secLevel: e2e.SecurityLevelTLS, + }, + { + name: "mtls", + secLevel: e2e.SecurityLevelMTLS, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + managementServer, nodeID, bootstrapContents, resolver, cleanup1 := setupManagementServer(t) + defer cleanup1() + + lis, cleanup2 := setupGRPCServer(t, bootstrapContents) + defer cleanup2() + + // Grab the host and port of the server and create client side xDS + // resources corresponding to it. + host, port, err := hostPortFromListener(lis) + if err != nil { + t.Fatalf("failed to retrieve host and port of server: %v", err) + } + + // Create xDS resources to be consumed on the client side. This + // includes the listener, route configuration, cluster (with + // security configuration) and endpoint resources. + serviceName := "my-service-file-watcher-certs-" + test.name + resources := e2e.DefaultClientResources(e2e.ResourceParams{ + DialTarget: serviceName, + NodeID: nodeID, + Host: host, + Port: port, + SecLevel: test.secLevel, + }) + + // Create an inbound xDS listener resource for the server side that + // contains security configuration pointing to the file watcher + // plugin. + inboundLis := e2e.DefaultServerListener(host, port, test.secLevel) + resources.Listeners = append(resources.Listeners, inboundLis) + + // Setup the management server with client and server resources. + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + if err := managementServer.Update(ctx, resources); err != nil { + t.Fatal(err) + } + + // Create client-side xDS credentials with an insecure fallback. + creds, err := xdscreds.NewClientCredentials(xdscreds.ClientOptions{ + FallbackCreds: insecure.NewCredentials(), + }) + if err != nil { + t.Fatal(err) + } + + // Create a ClientConn with the xds scheme and make an RPC. + cc, err := grpc.DialContext(ctx, fmt.Sprintf("xds:///%s", serviceName), grpc.WithTransportCredentials(creds), grpc.WithResolvers(resolver)) + if err != nil { + t.Fatalf("failed to dial local test server: %v", err) + } + defer cc.Close() + + client := testpb.NewTestServiceClient(cc) + if _, err := client.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(true)); err != nil { + t.Fatalf("rpc EmptyCall() failed: %v", err) + } + }) + } +} + +// TestServerSideXDS_SecurityConfigChange is an e2e test where xDS is enabled on +// the server-side and xdsCredentials are configured for security. The control +// plane initially does not any security configuration. This forces the +// xdsCredentials to use fallback creds, which is this case is insecure creds. +// We verify that a client connecting with TLS creds is not able to successfully +// make an RPC. The control plane then sends a listener resource with security +// configuration pointing to the use of the file_watcher plugin and we verify +// that the same client is now able to successfully make an RPC. +func (s) TestServerSideXDS_SecurityConfigChange(t *testing.T) { + managementServer, nodeID, bootstrapContents, resolver, cleanup1 := setupManagementServer(t) + defer cleanup1() + + lis, cleanup2 := setupGRPCServer(t, bootstrapContents) + defer cleanup2() + + // Grab the host and port of the server and create client side xDS resources + // corresponding to it. This contains default resources with no security + // configuration in the Cluster resource. This should force the xDS + // credentials on the client to use its fallback. + host, port, err := hostPortFromListener(lis) + if err != nil { + t.Fatalf("failed to retrieve host and port of server: %v", err) + } + const serviceName = "my-service-security-config-change" + resources := e2e.DefaultClientResources(e2e.ResourceParams{ + DialTarget: serviceName, + NodeID: nodeID, + Host: host, + Port: port, + SecLevel: e2e.SecurityLevelNone, + }) + + // Create an inbound xDS listener resource for the server side that does not + // contain any security configuration. This should force the xDS credentials + // on server to use its fallback. + inboundLis := e2e.DefaultServerListener(host, port, e2e.SecurityLevelNone) + resources.Listeners = append(resources.Listeners, inboundLis) + + // Setup the management server with client and server-side resources. + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + if err := managementServer.Update(ctx, resources); err != nil { + t.Fatal(err) + } + + // Create client-side xDS credentials with an insecure fallback. + xdsCreds, err := xdscreds.NewClientCredentials(xdscreds.ClientOptions{ FallbackCreds: insecure.NewCredentials(), }) if err != nil { t.Fatal(err) } - // Initialize an xDS-enabled gRPC server and register the stubServer on it. - server := xds.NewGRPCServer(grpc.Creds(creds)) - testpb.RegisterTestServiceServer(server, &testService{}) + // Create a ClientConn with the xds scheme and make a successful RPC. + xdsCC, err := grpc.DialContext(ctx, fmt.Sprintf("xds:///%s", serviceName), grpc.WithTransportCredentials(xdsCreds), grpc.WithResolvers(resolver)) + if err != nil { + t.Fatalf("failed to dial local test server: %v", err) + } + defer xdsCC.Close() - // Create a local listener and pass it to Serve(). - lis, err := testutils.LocalTCPListener() + client := testpb.NewTestServiceClient(xdsCC) + if _, err := client.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(true)); err != nil { + t.Fatalf("rpc EmptyCall() failed: %v", err) + } + + // Create a ClientConn with TLS creds. This should fail since the server is + // using fallback credentials which in this case in insecure creds. + tlsCreds := createClientTLSCredentials(t) + tlsCC, err := grpc.DialContext(ctx, lis.Addr().String(), grpc.WithTransportCredentials(tlsCreds)) if err != nil { - t.Fatalf("testutils.LocalTCPListener() failed: %v", err) + t.Fatalf("failed to dial local test server: %v", err) } + defer tlsCC.Close() - go func() { - if err := server.Serve(lis); err != nil { - t.Errorf("Serve() failed: %v", err) - } - }() + // We don't set 'waitForReady` here since we want this call to failfast. + client = testpb.NewTestServiceClient(tlsCC) + if _, err := client.EmptyCall(ctx, &testpb.Empty{}); status.Code(err) != codes.Unavailable { + t.Fatal("rpc EmptyCall() succeeded when expected to fail") + } - return fs, nodeID, lis, func() { - fs.Stop() - bootstrapCleanup() - server.Stop() + // Switch server and client side resources with ones that contain required + // security configuration for mTLS with a file watcher certificate provider. + resources = e2e.DefaultClientResources(e2e.ResourceParams{ + DialTarget: serviceName, + NodeID: nodeID, + Host: host, + Port: port, + SecLevel: e2e.SecurityLevelMTLS, + }) + inboundLis = e2e.DefaultServerListener(host, port, e2e.SecurityLevelMTLS) + resources.Listeners = append(resources.Listeners, inboundLis) + if err := managementServer.Update(ctx, resources); err != nil { + t.Fatal(err) + } + + // Make another RPC with `waitForReady` set and expect this to succeed. + if _, err := client.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(true)); err != nil { + t.Fatalf("rpc EmptyCall() failed: %v", err) } } -func hostPortFromListener(t *testing.T, lis net.Listener) (string, uint32) { - t.Helper() +// TestServerSideXDS_RouteConfiguration is an e2e test which verifies routing +// functionality. The xDS enabled server will be set up with route configuration +// where the route configuration has routes with the correct routing actions +// (NonForwardingAction), and the RPC's matching those routes should proceed as +// normal. +func (s) TestServerSideXDS_RouteConfiguration(t *testing.T) { + oldRBAC := env.RBACSupport + env.RBACSupport = true + defer func() { + env.RBACSupport = oldRBAC + }() + managementServer, nodeID, bootstrapContents, resolver, cleanup1 := setupManagementServer(t) + defer cleanup1() - host, p, err := net.SplitHostPort(lis.Addr().String()) + lis, cleanup2 := setupGRPCServer(t, bootstrapContents) + defer cleanup2() + + host, port, err := hostPortFromListener(lis) if err != nil { - t.Fatalf("net.SplitHostPort(%s) failed: %v", lis.Addr().String(), err) + t.Fatalf("failed to retrieve host and port of server: %v", err) } - port, err := strconv.ParseInt(p, 10, 32) + const serviceName = "my-service-fallback" + resources := e2e.DefaultClientResources(e2e.ResourceParams{ + DialTarget: serviceName, + NodeID: nodeID, + Host: host, + Port: port, + SecLevel: e2e.SecurityLevelNone, + }) + + // Create an inbound xDS listener resource with route configuration which + // selectively will allow RPC's through or not. This will test routing in + // xds(Unary|Stream)Interceptors. + vhs := []*v3routepb.VirtualHost{ + // Virtual host that will never be matched to test Virtual Host selection. + { + Domains: []string{"this will not match*"}, + Routes: []*v3routepb.Route{ + { + Match: &v3routepb.RouteMatch{ + PathSpecifier: &v3routepb.RouteMatch_Prefix{Prefix: "/"}, + }, + Action: &v3routepb.Route_NonForwardingAction{}, + }, + }, + }, + // This Virtual Host will actually get matched to. + { + Domains: []string{"*"}, + Routes: []*v3routepb.Route{ + // A routing rule that can be selectively triggered based on properties about incoming RPC. + { + Match: &v3routepb.RouteMatch{ + PathSpecifier: &v3routepb.RouteMatch_Prefix{Prefix: "/grpc.testing.TestService/EmptyCall"}, + // "Fully-qualified RPC method name with leading slash. Same as :path header". + }, + // Correct Action, so RPC's that match this route should proceed to interceptor processing. + Action: &v3routepb.Route_NonForwardingAction{}, + }, + // This routing rule is matched the same way as the one above, + // except has an incorrect action for the server side. However, + // since routing chooses the first route which matches an + // incoming RPC, this should never get invoked (iteration + // through this route slice is deterministic). + { + Match: &v3routepb.RouteMatch{ + PathSpecifier: &v3routepb.RouteMatch_Prefix{Prefix: "/grpc.testing.TestService/EmptyCall"}, + // "Fully-qualified RPC method name with leading slash. Same as :path header". + }, + // Incorrect Action, so RPC's that match this route should get denied. + Action: &v3routepb.Route_Route{ + Route: &v3routepb.RouteAction{ClusterSpecifier: &v3routepb.RouteAction_Cluster{Cluster: ""}}, + }, + }, + // Another routing rule that can be selectively triggered based on incoming RPC. + { + Match: &v3routepb.RouteMatch{ + PathSpecifier: &v3routepb.RouteMatch_Prefix{Prefix: "/grpc.testing.TestService/UnaryCall"}, + }, + // Wrong action (!Non_Forwarding_Action) so RPC's that match this route should get denied. + Action: &v3routepb.Route_Route{ + Route: &v3routepb.RouteAction{ClusterSpecifier: &v3routepb.RouteAction_Cluster{Cluster: ""}}, + }, + }, + // Another routing rule that can be selectively triggered based on incoming RPC. + { + Match: &v3routepb.RouteMatch{ + PathSpecifier: &v3routepb.RouteMatch_Prefix{Prefix: "/grpc.testing.TestService/StreamingInputCall"}, + }, + // Wrong action (!Non_Forwarding_Action) so RPC's that match this route should get denied. + Action: &v3routepb.Route_Route{ + Route: &v3routepb.RouteAction{ClusterSpecifier: &v3routepb.RouteAction_Cluster{Cluster: ""}}, + }, + }, + // Not matching route, this is be able to get invoked logically (i.e. doesn't have to match the Route configurations above). + }}, + } + inboundLis := &v3listenerpb.Listener{ + Name: fmt.Sprintf(e2e.ServerListenerResourceNameTemplate, net.JoinHostPort(host, strconv.Itoa(int(port)))), + Address: &v3corepb.Address{ + Address: &v3corepb.Address_SocketAddress{ + SocketAddress: &v3corepb.SocketAddress{ + Address: host, + PortSpecifier: &v3corepb.SocketAddress_PortValue{ + PortValue: port, + }, + }, + }, + }, + FilterChains: []*v3listenerpb.FilterChain{ + { + Name: "v4-wildcard", + FilterChainMatch: &v3listenerpb.FilterChainMatch{ + PrefixRanges: []*v3corepb.CidrRange{ + { + AddressPrefix: "0.0.0.0", + PrefixLen: &wrapperspb.UInt32Value{ + Value: uint32(0), + }, + }, + }, + SourceType: v3listenerpb.FilterChainMatch_SAME_IP_OR_LOOPBACK, + SourcePrefixRanges: []*v3corepb.CidrRange{ + { + AddressPrefix: "0.0.0.0", + PrefixLen: &wrapperspb.UInt32Value{ + Value: uint32(0), + }, + }, + }, + }, + Filters: []*v3listenerpb.Filter{ + { + Name: "filter-1", + ConfigType: &v3listenerpb.Filter_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3httppb.HttpConnectionManager{ + HttpFilters: []*v3httppb.HttpFilter{e2e.HTTPFilter("router", &v3routerpb.Router{})}, + RouteSpecifier: &v3httppb.HttpConnectionManager_RouteConfig{ + RouteConfig: &v3routepb.RouteConfiguration{ + Name: "routeName", + VirtualHosts: vhs, + }, + }, + }), + }, + }, + }, + }, + { + Name: "v6-wildcard", + FilterChainMatch: &v3listenerpb.FilterChainMatch{ + PrefixRanges: []*v3corepb.CidrRange{ + { + AddressPrefix: "::", + PrefixLen: &wrapperspb.UInt32Value{ + Value: uint32(0), + }, + }, + }, + SourceType: v3listenerpb.FilterChainMatch_SAME_IP_OR_LOOPBACK, + SourcePrefixRanges: []*v3corepb.CidrRange{ + { + AddressPrefix: "::", + PrefixLen: &wrapperspb.UInt32Value{ + Value: uint32(0), + }, + }, + }, + }, + Filters: []*v3listenerpb.Filter{ + { + Name: "filter-1", + ConfigType: &v3listenerpb.Filter_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3httppb.HttpConnectionManager{ + HttpFilters: []*v3httppb.HttpFilter{e2e.HTTPFilter("router", &v3routerpb.Router{})}, + RouteSpecifier: &v3httppb.HttpConnectionManager_RouteConfig{ + RouteConfig: &v3routepb.RouteConfiguration{ + Name: "routeName", + VirtualHosts: vhs, + }, + }, + }), + }, + }, + }, + }, + }, + } + resources.Listeners = append(resources.Listeners, inboundLis) + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + // Setup the management server with client and server-side resources. + if err := managementServer.Update(ctx, resources); err != nil { + t.Fatal(err) + } + + cc, err := grpc.DialContext(ctx, fmt.Sprintf("xds:///%s", serviceName), grpc.WithInsecure(), grpc.WithResolvers(resolver)) if err != nil { - t.Fatalf("strconv.ParseInt(%s, 10, 32) failed: %v", p, err) + t.Fatalf("failed to dial local test server: %v", err) + } + defer cc.Close() + + client := testpb.NewTestServiceClient(cc) + + // This Empty Call should match to a route with a correct action + // (NonForwardingAction). Thus, this RPC should proceed as normal. There is + // a routing rule that this RPC would match to that has an incorrect action, + // but the server should only use the first route matched to with the + // correct action. + if _, err = client.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(true)); err != nil { + t.Fatalf("rpc EmptyCall() failed: %v", err) } - return host, uint32(port) + // This Unary Call should match to a route with an incorrect action. Thus, + // this RPC should not go through as per A36, and this call should receive + // an error with codes.Unavailable. + if _, err = client.UnaryCall(ctx, &testpb.SimpleRequest{}); status.Code(err) != codes.Unavailable { + t.Fatalf("client.UnaryCall() = _, %v, want _, error code %s", err, codes.Unavailable) + } + + // This Streaming Call should match to a route with an incorrect action. + // Thus, this RPC should not go through as per A36, and this call should + // receive an error with codes.Unavailable. + stream, err := client.StreamingInputCall(ctx) + if err != nil { + t.Fatalf("StreamingInputCall(_) = _, %v, want ", err) + } + if _, err = stream.CloseAndRecv(); status.Code(err) != codes.Unavailable || !strings.Contains(err.Error(), "the incoming RPC matched to a route that was not of action type non forwarding") { + t.Fatalf("streaming RPC should have been denied") + } + + // This Full Duplex should not match to a route, and thus should return an + // error and not proceed. + dStream, err := client.FullDuplexCall(ctx) + if err != nil { + t.Fatalf("FullDuplexCall(_) = _, %v, want ", err) + } + if _, err = dStream.Recv(); status.Code(err) != codes.Unavailable || !strings.Contains(err.Error(), "the incoming RPC did not match a configured Route") { + t.Fatalf("streaming RPC should have been denied") + } } -// listenerResourceWithoutSecurityConfig returns a listener resource with no -// security configuration, and name and address fields matching the passed in -// net.Listener. -func listenerResourceWithoutSecurityConfig(t *testing.T, lis net.Listener) *v3listenerpb.Listener { - host, port := hostPortFromListener(t, lis) +// serverListenerWithRBACHTTPFilters returns an xds Listener resource with HTTP Filters defined in the HCM, and a route +// configuration that always matches to a route and a VH. +func serverListenerWithRBACHTTPFilters(host string, port uint32, rbacCfg *rpb.RBAC) *v3listenerpb.Listener { + // Rather than declare typed config inline, take a HCM proto and append the + // RBAC Filters to it. + hcm := &v3httppb.HttpConnectionManager{ + RouteSpecifier: &v3httppb.HttpConnectionManager_RouteConfig{ + RouteConfig: &v3routepb.RouteConfiguration{ + Name: "routeName", + VirtualHosts: []*v3routepb.VirtualHost{{ + Domains: []string{"*"}, + Routes: []*v3routepb.Route{{ + Match: &v3routepb.RouteMatch{ + PathSpecifier: &v3routepb.RouteMatch_Prefix{Prefix: "/"}, + }, + Action: &v3routepb.Route_NonForwardingAction{}, + }}, + // This tests override parsing + building when RBAC Filter + // passed both normal and override config. + TypedPerFilterConfig: map[string]*anypb.Any{ + "rbac": testutils.MarshalAny(&rpb.RBACPerRoute{Rbac: rbacCfg}), + }, + }}}, + }, + } + hcm.HttpFilters = nil + hcm.HttpFilters = append(hcm.HttpFilters, e2e.HTTPFilter("rbac", rbacCfg)) + hcm.HttpFilters = append(hcm.HttpFilters, e2e.RouterHTTPFilter) + return &v3listenerpb.Listener{ - // This needs to match the name we are querying for. - Name: fmt.Sprintf("grpc/server?udpa.resource.listening_address=%s", lis.Addr().String()), + Name: fmt.Sprintf(e2e.ServerListenerResourceNameTemplate, net.JoinHostPort(host, strconv.Itoa(int(port)))), Address: &v3corepb.Address{ Address: &v3corepb.Address_SocketAddress{ SocketAddress: &v3corepb.SocketAddress{ @@ -224,176 +655,563 @@ func listenerResourceWithoutSecurityConfig(t *testing.T, lis net.Listener) *v3li }, FilterChains: []*v3listenerpb.FilterChain{ { - Name: "filter-chain-1", + Name: "v4-wildcard", + FilterChainMatch: &v3listenerpb.FilterChainMatch{ + PrefixRanges: []*v3corepb.CidrRange{ + { + AddressPrefix: "0.0.0.0", + PrefixLen: &wrapperspb.UInt32Value{ + Value: uint32(0), + }, + }, + }, + SourceType: v3listenerpb.FilterChainMatch_SAME_IP_OR_LOOPBACK, + SourcePrefixRanges: []*v3corepb.CidrRange{ + { + AddressPrefix: "0.0.0.0", + PrefixLen: &wrapperspb.UInt32Value{ + Value: uint32(0), + }, + }, + }, + }, + Filters: []*v3listenerpb.Filter{ + { + Name: "filter-1", + ConfigType: &v3listenerpb.Filter_TypedConfig{ + TypedConfig: testutils.MarshalAny(hcm), + }, + }, + }, + }, + { + Name: "v6-wildcard", + FilterChainMatch: &v3listenerpb.FilterChainMatch{ + PrefixRanges: []*v3corepb.CidrRange{ + { + AddressPrefix: "::", + PrefixLen: &wrapperspb.UInt32Value{ + Value: uint32(0), + }, + }, + }, + SourceType: v3listenerpb.FilterChainMatch_SAME_IP_OR_LOOPBACK, + SourcePrefixRanges: []*v3corepb.CidrRange{ + { + AddressPrefix: "::", + PrefixLen: &wrapperspb.UInt32Value{ + Value: uint32(0), + }, + }, + }, + }, + Filters: []*v3listenerpb.Filter{ + { + Name: "filter-1", + ConfigType: &v3listenerpb.Filter_TypedConfig{ + TypedConfig: testutils.MarshalAny(hcm), + }, + }, + }, + }, + }, + } +} + +// TestRBACHTTPFilter tests the xds configured RBAC HTTP Filter. It sets up the +// full end to end flow, and makes sure certain RPC's are successful and proceed +// as normal and certain RPC's are denied by the RBAC HTTP Filter which gets +// called by hooked xds interceptors. +func (s) TestRBACHTTPFilter(t *testing.T) { + oldRBAC := env.RBACSupport + env.RBACSupport = true + defer func() { + env.RBACSupport = oldRBAC + }() + rbac.RegisterForTesting() + defer rbac.UnregisterForTesting() + tests := []struct { + name string + rbacCfg *rpb.RBAC + wantStatusEmptyCall codes.Code + wantStatusUnaryCall codes.Code + }{ + // This test tests an RBAC HTTP Filter which is configured to allow any RPC. + // Any RPC passing through this RBAC HTTP Filter should proceed as normal. + { + name: "allow-anything", + rbacCfg: &rpb.RBAC{ + Rules: &v3rbacpb.RBAC{ + Action: v3rbacpb.RBAC_ALLOW, + Policies: map[string]*v3rbacpb.Policy{ + "anyone": { + Permissions: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_Any{Any: true}}, + }, + Principals: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_Any{Any: true}}, + }, + }, + }, + }, + }, + wantStatusEmptyCall: codes.OK, + wantStatusUnaryCall: codes.OK, + }, + // This test tests an RBAC HTTP Filter which is configured to allow only + // RPC's with certain paths ("UnaryCall"). Only unary calls passing + // through this RBAC HTTP Filter should proceed as normal, and any + // others should be denied. + { + name: "allow-certain-path", + rbacCfg: &rpb.RBAC{ + Rules: &v3rbacpb.RBAC{ + Action: v3rbacpb.RBAC_ALLOW, + Policies: map[string]*v3rbacpb.Policy{ + "certain-path": { + Permissions: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_UrlPath{UrlPath: &v3matcherpb.PathMatcher{Rule: &v3matcherpb.PathMatcher_Path{Path: &v3matcherpb.StringMatcher{MatchPattern: &v3matcherpb.StringMatcher_Exact{Exact: "/grpc.testing.TestService/UnaryCall"}}}}}}, + }, + Principals: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_Any{Any: true}}, + }, + }, + }, + }, + }, + wantStatusEmptyCall: codes.PermissionDenied, + wantStatusUnaryCall: codes.OK, + }, + // This test that a RBAC Config with nil rules means that every RPC is + // allowed. This maps to the line "If absent, no enforcing RBAC policy + // will be applied" from the RBAC Proto documentation for the Rules + // field. + { + name: "absent-rules", + rbacCfg: &rpb.RBAC{ + Rules: nil, + }, + wantStatusEmptyCall: codes.OK, + wantStatusUnaryCall: codes.OK, + }, + // The two tests below test that configuring the xDS RBAC HTTP Filter + // with :authority and host header matchers end up being logically + // equivalent. This represents functionality from this line in A41 - + // "As documented for HeaderMatcher, Envoy aliases :authority and Host + // in its header map implementation, so they should be treated + // equivalent for the RBAC matchers; there must be no behavior change + // depending on which of the two header names is used in the RBAC + // policy." + + // This test tests an xDS RBAC Filter with an :authority header matcher. + { + name: "match-on-authority", + rbacCfg: &rpb.RBAC{ + Rules: &v3rbacpb.RBAC{ + Action: v3rbacpb.RBAC_ALLOW, + Policies: map[string]*v3rbacpb.Policy{ + "match-on-authority": { + Permissions: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_Header{Header: &v3routepb.HeaderMatcher{Name: ":authority", HeaderMatchSpecifier: &v3routepb.HeaderMatcher_PrefixMatch{PrefixMatch: "my-service-fallback"}}}}, + }, + Principals: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_Any{Any: true}}, + }, + }, + }, + }, + }, + wantStatusEmptyCall: codes.OK, + wantStatusUnaryCall: codes.OK, + }, + // This test tests that configuring an xDS RBAC Filter with a host + // header matcher has the same behavior as if it was configured with + // :authority. Since host and authority are aliased, this should still + // continue to match on incoming RPC's :authority, just as the test + // above. + { + name: "match-on-host", + rbacCfg: &rpb.RBAC{ + Rules: &v3rbacpb.RBAC{ + Action: v3rbacpb.RBAC_ALLOW, + Policies: map[string]*v3rbacpb.Policy{ + "match-on-authority": { + Permissions: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_Header{Header: &v3routepb.HeaderMatcher{Name: "host", HeaderMatchSpecifier: &v3routepb.HeaderMatcher_PrefixMatch{PrefixMatch: "my-service-fallback"}}}}, + }, + Principals: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_Any{Any: true}}, + }, + }, + }, + }, + }, + wantStatusEmptyCall: codes.OK, + wantStatusUnaryCall: codes.OK, + }, + // This test tests that the RBAC HTTP Filter hard codes the :method + // header to POST. Since the RBAC Configuration says to deny every RPC + // with a method :POST, every RPC tried should be denied. + { + name: "deny-post", + rbacCfg: &rpb.RBAC{ + Rules: &v3rbacpb.RBAC{ + Action: v3rbacpb.RBAC_DENY, + Policies: map[string]*v3rbacpb.Policy{ + "post-method": { + Permissions: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_Header{Header: &v3routepb.HeaderMatcher{Name: ":method", HeaderMatchSpecifier: &v3routepb.HeaderMatcher_ExactMatch{ExactMatch: "POST"}}}}, + }, + Principals: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_Any{Any: true}}, + }, + }, + }, + }, }, + wantStatusEmptyCall: codes.PermissionDenied, + wantStatusUnaryCall: codes.PermissionDenied, + }, + // This test tests that RBAC ignores the TE: trailers header (which is + // hardcoded in http2_client.go for every RPC). Since the RBAC + // Configuration says to only ALLOW RPC's with a TE: Trailers, every RPC + // tried should be denied. + { + name: "allow-only-te", + rbacCfg: &rpb.RBAC{ + Rules: &v3rbacpb.RBAC{ + Action: v3rbacpb.RBAC_ALLOW, + Policies: map[string]*v3rbacpb.Policy{ + "post-method": { + Permissions: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_Header{Header: &v3routepb.HeaderMatcher{Name: "TE", HeaderMatchSpecifier: &v3routepb.HeaderMatcher_ExactMatch{ExactMatch: "trailers"}}}}, + }, + Principals: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_Any{Any: true}}, + }, + }, + }, + }, + }, + wantStatusEmptyCall: codes.PermissionDenied, + wantStatusUnaryCall: codes.PermissionDenied, + }, + // This test tests that an RBAC Config with Action.LOG configured allows + // every RPC through. This maps to the line "At this time, if the + // RBAC.action is Action.LOG then the policy will be completely ignored, + // as if RBAC was not configurated." from A41 + { + name: "action-log", + rbacCfg: &rpb.RBAC{ + Rules: &v3rbacpb.RBAC{ + Action: v3rbacpb.RBAC_LOG, + Policies: map[string]*v3rbacpb.Policy{ + "anyone": { + Permissions: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_Any{Any: true}}, + }, + Principals: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_Any{Any: true}}, + }, + }, + }, + }, + }, + wantStatusEmptyCall: codes.OK, + wantStatusUnaryCall: codes.OK, }, } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + func() { + managementServer, nodeID, bootstrapContents, resolver, cleanup1 := setupManagementServer(t) + defer cleanup1() + + lis, cleanup2 := setupGRPCServer(t, bootstrapContents) + defer cleanup2() + + host, port, err := hostPortFromListener(lis) + if err != nil { + t.Fatalf("failed to retrieve host and port of server: %v", err) + } + const serviceName = "my-service-fallback" + resources := e2e.DefaultClientResources(e2e.ResourceParams{ + DialTarget: serviceName, + NodeID: nodeID, + Host: host, + Port: port, + SecLevel: e2e.SecurityLevelNone, + }) + inboundLis := serverListenerWithRBACHTTPFilters(host, port, test.rbacCfg) + resources.Listeners = append(resources.Listeners, inboundLis) + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + // Setup the management server with client and server-side resources. + if err := managementServer.Update(ctx, resources); err != nil { + t.Fatal(err) + } + + cc, err := grpc.DialContext(ctx, fmt.Sprintf("xds:///%s", serviceName), grpc.WithInsecure(), grpc.WithResolvers(resolver)) + if err != nil { + t.Fatalf("failed to dial local test server: %v", err) + } + defer cc.Close() + + client := testpb.NewTestServiceClient(cc) + + if _, err := client.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(true)); status.Code(err) != test.wantStatusEmptyCall { + t.Fatalf("EmptyCall() returned err with status: %v, wantStatusEmptyCall: %v", status.Code(err), test.wantStatusEmptyCall) + } + + if _, err := client.UnaryCall(ctx, &testpb.SimpleRequest{}); status.Code(err) != test.wantStatusUnaryCall { + t.Fatalf("UnaryCall() returned err with status: %v, wantStatusUnaryCall: %v", err, test.wantStatusUnaryCall) + } + + // Toggle the RBAC Env variable off, this should disable RBAC and allow any RPC"s through (will not go through + // routing or processed by HTTP Filters and thus will never get denied by RBAC). + env.RBACSupport = false + if _, err := client.EmptyCall(ctx, &testpb.Empty{}); status.Code(err) != codes.OK { + t.Fatalf("EmptyCall() returned err with status: %v, once RBAC is disabled all RPC's should proceed as normal", status.Code(err)) + } + if _, err := client.UnaryCall(ctx, &testpb.SimpleRequest{}); status.Code(err) != codes.OK { + t.Fatalf("UnaryCall() returned err with status: %v, once RBAC is disabled all RPC's should proceed as normal", status.Code(err)) + } + // Toggle RBAC back on for next iterations. + env.RBACSupport = true + }() + }) + } } -// listenerResourceWithSecurityConfig returns a listener resource with security -// configuration pointing to the use of the file_watcher certificate provider -// plugin, and name and address fields matching the passed in net.Listener. -func listenerResourceWithSecurityConfig(t *testing.T, lis net.Listener) *v3listenerpb.Listener { - host, port := hostPortFromListener(t, lis) +// serverListenerWithBadRouteConfiguration returns an xds Listener resource with +// a Route Configuration that will never successfully match in order to test +// RBAC Environment variable being toggled on and off. +func serverListenerWithBadRouteConfiguration(host string, port uint32) *v3listenerpb.Listener { return &v3listenerpb.Listener{ - // This needs to match the name we are querying for. - Name: fmt.Sprintf("grpc/server?udpa.resource.listening_address=%s", lis.Addr().String()), + Name: fmt.Sprintf(e2e.ServerListenerResourceNameTemplate, net.JoinHostPort(host, strconv.Itoa(int(port)))), Address: &v3corepb.Address{ Address: &v3corepb.Address_SocketAddress{ SocketAddress: &v3corepb.SocketAddress{ Address: host, PortSpecifier: &v3corepb.SocketAddress_PortValue{ PortValue: port, - }}}}, + }, + }, + }, + }, FilterChains: []*v3listenerpb.FilterChain{ { - Name: "filter-chain-1", - TransportSocket: &v3corepb.TransportSocket{ - Name: "envoy.transport_sockets.tls", - ConfigType: &v3corepb.TransportSocket_TypedConfig{ - TypedConfig: &anypb.Any{ - TypeUrl: version.V3DownstreamTLSContextURL, - Value: func() []byte { - tls := &v3tlspb.DownstreamTlsContext{ - RequireClientCertificate: &wrapperspb.BoolValue{Value: true}, - CommonTlsContext: &v3tlspb.CommonTlsContext{ - TlsCertificateCertificateProviderInstance: &v3tlspb.CommonTlsContext_CertificateProviderInstance{ - InstanceName: "google_cloud_private_spiffe", - }, - ValidationContextType: &v3tlspb.CommonTlsContext_ValidationContextCertificateProviderInstance{ - ValidationContextCertificateProviderInstance: &v3tlspb.CommonTlsContext_CertificateProviderInstance{ - InstanceName: "google_cloud_private_spiffe", - }}}} - mtls, _ := proto.Marshal(tls) - return mtls - }(), - }}}}}, + Name: "v4-wildcard", + FilterChainMatch: &v3listenerpb.FilterChainMatch{ + PrefixRanges: []*v3corepb.CidrRange{ + { + AddressPrefix: "0.0.0.0", + PrefixLen: &wrapperspb.UInt32Value{ + Value: uint32(0), + }, + }, + }, + SourceType: v3listenerpb.FilterChainMatch_SAME_IP_OR_LOOPBACK, + SourcePrefixRanges: []*v3corepb.CidrRange{ + { + AddressPrefix: "0.0.0.0", + PrefixLen: &wrapperspb.UInt32Value{ + Value: uint32(0), + }, + }, + }, + }, + Filters: []*v3listenerpb.Filter{ + { + Name: "filter-1", + ConfigType: &v3listenerpb.Filter_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3httppb.HttpConnectionManager{ + RouteSpecifier: &v3httppb.HttpConnectionManager_RouteConfig{ + RouteConfig: &v3routepb.RouteConfiguration{ + Name: "routeName", + VirtualHosts: []*v3routepb.VirtualHost{{ + // Incoming RPC's will try and match to Virtual Hosts based on their :authority header. + // Thus, incoming RPC's will never match to a Virtual Host (server side requires matching + // to a VH/Route of type Non Forwarding Action to proceed normally), and all incoming RPC's + // with this route configuration will be denied. + Domains: []string{"will-never-match"}, + Routes: []*v3routepb.Route{{ + Match: &v3routepb.RouteMatch{ + PathSpecifier: &v3routepb.RouteMatch_Prefix{Prefix: "/"}, + }, + Action: &v3routepb.Route_NonForwardingAction{}, + }}}}}, + }, + HttpFilters: []*v3httppb.HttpFilter{e2e.RouterHTTPFilter}, + }), + }, + }, + }, + }, + { + Name: "v6-wildcard", + FilterChainMatch: &v3listenerpb.FilterChainMatch{ + PrefixRanges: []*v3corepb.CidrRange{ + { + AddressPrefix: "::", + PrefixLen: &wrapperspb.UInt32Value{ + Value: uint32(0), + }, + }, + }, + SourceType: v3listenerpb.FilterChainMatch_SAME_IP_OR_LOOPBACK, + SourcePrefixRanges: []*v3corepb.CidrRange{ + { + AddressPrefix: "::", + PrefixLen: &wrapperspb.UInt32Value{ + Value: uint32(0), + }, + }, + }, + }, + Filters: []*v3listenerpb.Filter{ + { + Name: "filter-1", + ConfigType: &v3listenerpb.Filter_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3httppb.HttpConnectionManager{ + RouteSpecifier: &v3httppb.HttpConnectionManager_RouteConfig{ + RouteConfig: &v3routepb.RouteConfiguration{ + Name: "routeName", + VirtualHosts: []*v3routepb.VirtualHost{{ + // Incoming RPC's will try and match to Virtual Hosts based on their :authority header. + // Thus, incoming RPC's will never match to a Virtual Host (server side requires matching + // to a VH/Route of type Non Forwarding Action to proceed normally), and all incoming RPC's + // with this route configuration will be denied. + Domains: []string{"will-never-match"}, + Routes: []*v3routepb.Route{{ + Match: &v3routepb.RouteMatch{ + PathSpecifier: &v3routepb.RouteMatch_Prefix{Prefix: "/"}, + }, + Action: &v3routepb.Route_NonForwardingAction{}, + }}}}}, + }, + HttpFilters: []*v3httppb.HttpFilter{e2e.RouterHTTPFilter}, + }), + }, + }, + }, + }, + }, } } -// TestServerSideXDS_Fallback is an e2e test where xDS is enabled on the -// server-side and xdsCredentials are configured for security. The control plane -// does not provide any security configuration and therefore the xdsCredentials -// uses fallback credentials, which in this case is insecure creds. -func (s) TestServerSideXDS_Fallback(t *testing.T) { - fs, nodeID, lis, cleanup := commonSetup(t) - defer cleanup() +func (s) TestRBACToggledOn_WithBadRouteConfiguration(t *testing.T) { + // Turn RBAC support on. + oldRBAC := env.RBACSupport + env.RBACSupport = true + defer func() { + env.RBACSupport = oldRBAC + }() - // Setup the fake management server to respond with a Listener resource that - // does not contain any security configuration. This should force the - // server-side xdsCredentials to use fallback. - listener := listenerResourceWithoutSecurityConfig(t, lis) - if err := fs.Update(e2e.UpdateOptions{ - NodeID: nodeID, - Listeners: []*v3listenerpb.Listener{listener}, - }); err != nil { - t.Error(err) - } + managementServer, nodeID, bootstrapContents, resolver, cleanup1 := setupManagementServer(t) + defer cleanup1() - // Create a ClientConn and make a successful RPC. - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() - cc, err := grpc.DialContext(ctx, lis.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials())) + lis, cleanup2 := setupGRPCServer(t, bootstrapContents) + defer cleanup2() + + host, port, err := hostPortFromListener(lis) if err != nil { - t.Fatalf("failed to dial local test server: %v", err) + t.Fatalf("failed to retrieve host and port of server: %v", err) } - defer cc.Close() + const serviceName = "my-service-fallback" - client := testpb.NewTestServiceClient(cc) - if _, err := client.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(true)); err != nil { - t.Fatalf("rpc EmptyCall() failed: %v", err) - } -} + // The inbound listener needs a route table that will never match on a VH, + // and thus shouldn't allow incoming RPC's to proceed. + resources := e2e.DefaultClientResources(e2e.ResourceParams{ + DialTarget: serviceName, + NodeID: nodeID, + Host: host, + Port: port, + SecLevel: e2e.SecurityLevelNone, + }) + // Since RBAC support is turned ON, all the RPC's should get denied with + // status code Unavailable due to not matching to a route of type Non + // Forwarding Action (Route Table not configured properly). + inboundLis := serverListenerWithBadRouteConfiguration(host, port) + resources.Listeners = append(resources.Listeners, inboundLis) -// TestServerSideXDS_FileWatcherCerts is an e2e test where xDS is enabled on the -// server-side and xdsCredentials are configured for security. The control plane -// sends security configuration pointing to the use of the file_watcher plugin, -// and we verify that a client connecting with TLS creds is able to successfully -// make an RPC. -func (s) TestServerSideXDS_FileWatcherCerts(t *testing.T) { - fs, nodeID, lis, cleanup := commonSetup(t) - defer cleanup() - - // Setup the fake management server to respond with a Listener resource with - // security configuration pointing to the file watcher plugin and requiring - // mTLS. - listener := listenerResourceWithSecurityConfig(t, lis) - if err := fs.Update(e2e.UpdateOptions{ - NodeID: nodeID, - Listeners: []*v3listenerpb.Listener{listener}, - }); err != nil { - t.Error(err) - } - - // Create a ClientConn with TLS creds and make a successful RPC. - clientCreds := createClientTLSCredentials(t) ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() - cc, err := grpc.DialContext(ctx, lis.Addr().String(), grpc.WithTransportCredentials(clientCreds)) + // Setup the management server with client and server-side resources. + if err := managementServer.Update(ctx, resources); err != nil { + t.Fatal(err) + } + + cc, err := grpc.DialContext(ctx, fmt.Sprintf("xds:///%s", serviceName), grpc.WithInsecure(), grpc.WithResolvers(resolver)) if err != nil { t.Fatalf("failed to dial local test server: %v", err) } defer cc.Close() client := testpb.NewTestServiceClient(cc) - if _, err := client.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(true)); err != nil { - t.Fatalf("rpc EmptyCall() failed: %v", err) + if _, err := client.EmptyCall(ctx, &testpb.Empty{}); status.Code(err) != codes.Unavailable { + t.Fatalf("EmptyCall() returned err with status: %v, if RBAC is disabled all RPC's should proceed as normal", status.Code(err)) + } + if _, err := client.UnaryCall(ctx, &testpb.SimpleRequest{}); status.Code(err) != codes.Unavailable { + t.Fatalf("UnaryCall() returned err with status: %v, if RBAC is disabled all RPC's should proceed as normal", status.Code(err)) } } -// TestServerSideXDS_SecurityConfigChange is an e2e test where xDS is enabled on -// the server-side and xdsCredentials are configured for security. The control -// plane initially does not any security configuration. This forces the -// xdsCredentials to use fallback creds, which is this case is insecure creds. -// We verify that a client connecting with TLS creds is not able to successfully -// make an RPC. The control plan then sends a listener resource with security -// configuration pointing to the use of the file_watcher plugin and we verify -// that the same client is now able to successfully make an RPC. -func (s) TestServerSideXDS_SecurityConfigChange(t *testing.T) { - fs, nodeID, lis, cleanup := commonSetup(t) - defer cleanup() +func (s) TestRBACToggledOff_WithBadRouteConfiguration(t *testing.T) { + // Turn RBAC support off. + oldRBAC := env.RBACSupport + env.RBACSupport = false + defer func() { + env.RBACSupport = oldRBAC + }() - // Setup the fake management server to respond with a Listener resource that - // does not contain any security configuration. This should force the - // server-side xdsCredentials to use fallback. - listener := listenerResourceWithoutSecurityConfig(t, lis) - if err := fs.Update(e2e.UpdateOptions{ - NodeID: nodeID, - Listeners: []*v3listenerpb.Listener{listener}, - }); err != nil { - t.Error(err) + managementServer, nodeID, bootstrapContents, resolver, cleanup1 := setupManagementServer(t) + defer cleanup1() + + lis, cleanup2 := setupGRPCServer(t, bootstrapContents) + defer cleanup2() + + host, port, err := hostPortFromListener(lis) + if err != nil { + t.Fatalf("failed to retrieve host and port of server: %v", err) } + const serviceName = "my-service-fallback" + + // The inbound listener needs a route table that will never match on a VH, + // and thus shouldn't allow incoming RPC's to proceed. + resources := e2e.DefaultClientResources(e2e.ResourceParams{ + DialTarget: serviceName, + NodeID: nodeID, + Host: host, + Port: port, + SecLevel: e2e.SecurityLevelNone, + }) + // This bad route configuration shouldn't affect incoming RPC's from + // proceeding as normal, as the configuration shouldn't be parsed due to the + // RBAC Environment variable not being set to true. + inboundLis := serverListenerWithBadRouteConfiguration(host, port) + resources.Listeners = append(resources.Listeners, inboundLis) - // Create a ClientConn with TLS creds. This should fail since the server is - // using fallback credentials which in this case in insecure creds. - clientCreds := createClientTLSCredentials(t) ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() - cc, err := grpc.DialContext(ctx, lis.Addr().String(), grpc.WithTransportCredentials(clientCreds)) + // Setup the management server with client and server-side resources. + if err := managementServer.Update(ctx, resources); err != nil { + t.Fatal(err) + } + + cc, err := grpc.DialContext(ctx, fmt.Sprintf("xds:///%s", serviceName), grpc.WithInsecure(), grpc.WithResolvers(resolver)) if err != nil { t.Fatalf("failed to dial local test server: %v", err) } defer cc.Close() - // We don't set 'waitForReady` here since we want this call to failfast. client := testpb.NewTestServiceClient(cc) - if _, err := client.EmptyCall(ctx, &testpb.Empty{}); status.Convert(err).Code() != codes.Unavailable { - t.Fatal("rpc EmptyCall() succeeded when expected to fail") - } - - // Setup the fake management server to respond with a Listener resource with - // security configuration pointing to the file watcher plugin and requiring - // mTLS. - listener = listenerResourceWithSecurityConfig(t, lis) - if err := fs.Update(e2e.UpdateOptions{ - NodeID: nodeID, - Listeners: []*v3listenerpb.Listener{listener}, - }); err != nil { - t.Error(err) + if _, err := client.EmptyCall(ctx, &testpb.Empty{}); status.Code(err) != codes.OK { + t.Fatalf("EmptyCall() returned err with status: %v, if RBAC is disabled all RPC's should proceed as normal", status.Code(err)) } - - // Make another RPC with `waitForReady` set and expect this to succeed. - if _, err := client.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(true)); err != nil { - t.Fatalf("rpc EmptyCall() failed: %v", err) + if _, err := client.UnaryCall(ctx, &testpb.SimpleRequest{}); status.Code(err) != codes.OK { + t.Fatalf("UnaryCall() returned err with status: %v, if RBAC is disabled all RPC's should proceed as normal", status.Code(err)) } } diff --git a/xds/internal/test/xds_server_serving_mode_test.go b/xds/internal/test/xds_server_serving_mode_test.go new file mode 100644 index 00000000000..ac4b3929cb6 --- /dev/null +++ b/xds/internal/test/xds_server_serving_mode_test.go @@ -0,0 +1,388 @@ +//go:build !386 +// +build !386 + +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// Package xds_test contains e2e tests for xDS use. +package xds_test + +import ( + "context" + "fmt" + "net" + "testing" + "time" + + v3listenerpb "github.com/envoyproxy/go-control-plane/envoy/config/listener/v3" + "google.golang.org/grpc" + "google.golang.org/grpc/connectivity" + "google.golang.org/grpc/credentials/insecure" + xdscreds "google.golang.org/grpc/credentials/xds" + testpb "google.golang.org/grpc/test/grpc_testing" + "google.golang.org/grpc/xds" + xdstestutils "google.golang.org/grpc/xds/internal/testutils" + "google.golang.org/grpc/xds/internal/testutils/e2e" +) + +// TestServerSideXDS_RedundantUpdateSuppression tests the scenario where the +// control plane sends the same resource update. It verifies that the mode +// change callback is not invoked and client connections to the server are not +// recycled. +func (s) TestServerSideXDS_RedundantUpdateSuppression(t *testing.T) { + managementServer, nodeID, bootstrapContents, _, cleanup := setupManagementServer(t) + defer cleanup() + + creds, err := xdscreds.NewServerCredentials(xdscreds.ServerOptions{FallbackCreds: insecure.NewCredentials()}) + if err != nil { + t.Fatal(err) + } + lis, err := xdstestutils.LocalTCPListener() + if err != nil { + t.Fatalf("testutils.LocalTCPListener() failed: %v", err) + } + updateCh := make(chan connectivity.ServingMode, 1) + + // Create a server option to get notified about serving mode changes. + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + modeChangeOpt := xds.ServingModeCallback(func(addr net.Addr, args xds.ServingModeChangeArgs) { + t.Logf("serving mode for listener %q changed to %q, err: %v", addr.String(), args.Mode, args.Err) + updateCh <- args.Mode + }) + + // Initialize an xDS-enabled gRPC server and register the stubServer on it. + server := xds.NewGRPCServer(grpc.Creds(creds), modeChangeOpt, xds.BootstrapContentsForTesting(bootstrapContents)) + defer server.Stop() + testpb.RegisterTestServiceServer(server, &testService{}) + + // Setup the management server to respond with the listener resources. + host, port, err := hostPortFromListener(lis) + if err != nil { + t.Fatalf("failed to retrieve host and port of server: %v", err) + } + listener := e2e.DefaultServerListener(host, port, e2e.SecurityLevelNone) + resources := e2e.UpdateOptions{ + NodeID: nodeID, + Listeners: []*v3listenerpb.Listener{listener}, + } + if err := managementServer.Update(ctx, resources); err != nil { + t.Fatal(err) + } + + go func() { + if err := server.Serve(lis); err != nil { + t.Errorf("Serve() failed: %v", err) + } + }() + + // Wait for the listener to move to "serving" mode. + select { + case <-ctx.Done(): + t.Fatalf("timed out waiting for a mode change update: %v", err) + case mode := <-updateCh: + if mode != connectivity.ServingModeServing { + t.Fatalf("listener received new mode %v, want %v", mode, connectivity.ServingModeServing) + } + } + + // Create a ClientConn and make a successful RPCs. + cc, err := grpc.Dial(lis.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + t.Fatalf("failed to dial local test server: %v", err) + } + defer cc.Close() + waitForSuccessfulRPC(ctx, t, cc) + + // Start a goroutine to make sure that we do not see any connectivity state + // changes on the client connection. If redundant updates are not + // suppressed, server will recycle client connections. + errCh := make(chan error, 1) + go func() { + if cc.WaitForStateChange(ctx, connectivity.Ready) { + errCh <- fmt.Errorf("unexpected connectivity state change {%s --> %s} on the client connection", connectivity.Ready, cc.GetState()) + return + } + errCh <- nil + }() + + // Update the management server with the same listener resource. This will + // update the resource version though, and should result in a the management + // server sending the same resource to the xDS-enabled gRPC server. + if err := managementServer.Update(ctx, e2e.UpdateOptions{ + NodeID: nodeID, + Listeners: []*v3listenerpb.Listener{listener}, + }); err != nil { + t.Fatal(err) + } + + // Since redundant resource updates are suppressed, we should not see the + // mode change callback being invoked. + sCtx, sCancel := context.WithTimeout(ctx, defaultTestShortTimeout) + defer sCancel() + select { + case <-sCtx.Done(): + case mode := <-updateCh: + t.Fatalf("unexpected mode change callback with new mode %v", mode) + } + + // Make sure RPCs continue to succeed. + waitForSuccessfulRPC(ctx, t, cc) + + // Cancel the context to ensure that the WaitForStateChange call exits early + // and returns false. + cancel() + if err := <-errCh; err != nil { + t.Fatal(err) + } +} + +// TestServerSideXDS_ServingModeChanges tests the serving mode functionality in +// xDS enabled gRPC servers. It verifies that appropriate mode changes happen in +// the server, and also verifies behavior of clientConns under these modes. +func (s) TestServerSideXDS_ServingModeChanges(t *testing.T) { + managementServer, nodeID, bootstrapContents, _, cleanup := setupManagementServer(t) + defer cleanup() + + // Configure xDS credentials to be used on the server-side. + creds, err := xdscreds.NewServerCredentials(xdscreds.ServerOptions{ + FallbackCreds: insecure.NewCredentials(), + }) + if err != nil { + t.Fatal(err) + } + + // Create two local listeners and pass it to Serve(). + lis1, err := xdstestutils.LocalTCPListener() + if err != nil { + t.Fatalf("testutils.LocalTCPListener() failed: %v", err) + } + lis2, err := xdstestutils.LocalTCPListener() + if err != nil { + t.Fatalf("testutils.LocalTCPListener() failed: %v", err) + } + + // Create a couple of channels on which mode updates will be pushed. + updateCh1 := make(chan connectivity.ServingMode, 1) + updateCh2 := make(chan connectivity.ServingMode, 1) + + // Create a server option to get notified about serving mode changes, and + // push the updated mode on the channels created above. + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + modeChangeOpt := xds.ServingModeCallback(func(addr net.Addr, args xds.ServingModeChangeArgs) { + t.Logf("serving mode for listener %q changed to %q, err: %v", addr.String(), args.Mode, args.Err) + switch addr.String() { + case lis1.Addr().String(): + updateCh1 <- args.Mode + case lis2.Addr().String(): + updateCh2 <- args.Mode + default: + t.Logf("serving mode callback invoked for unknown listener address: %q", addr.String()) + } + }) + + // Initialize an xDS-enabled gRPC server and register the stubServer on it. + server := xds.NewGRPCServer(grpc.Creds(creds), modeChangeOpt, xds.BootstrapContentsForTesting(bootstrapContents)) + defer server.Stop() + testpb.RegisterTestServiceServer(server, &testService{}) + + // Setup the management server to respond with server-side Listener + // resources for both listeners. + host1, port1, err := hostPortFromListener(lis1) + if err != nil { + t.Fatalf("failed to retrieve host and port of server: %v", err) + } + listener1 := e2e.DefaultServerListener(host1, port1, e2e.SecurityLevelNone) + host2, port2, err := hostPortFromListener(lis2) + if err != nil { + t.Fatalf("failed to retrieve host and port of server: %v", err) + } + listener2 := e2e.DefaultServerListener(host2, port2, e2e.SecurityLevelNone) + resources := e2e.UpdateOptions{ + NodeID: nodeID, + Listeners: []*v3listenerpb.Listener{listener1, listener2}, + } + if err := managementServer.Update(ctx, resources); err != nil { + t.Fatal(err) + } + + go func() { + if err := server.Serve(lis1); err != nil { + t.Errorf("Serve() failed: %v", err) + } + }() + go func() { + if err := server.Serve(lis2); err != nil { + t.Errorf("Serve() failed: %v", err) + } + }() + + // Wait for both listeners to move to "serving" mode. + select { + case <-ctx.Done(): + t.Fatalf("timed out waiting for a mode change update: %v", err) + case mode := <-updateCh1: + if mode != connectivity.ServingModeServing { + t.Errorf("listener received new mode %v, want %v", mode, connectivity.ServingModeServing) + } + } + select { + case <-ctx.Done(): + t.Fatalf("timed out waiting for a mode change update: %v", err) + case mode := <-updateCh2: + if mode != connectivity.ServingModeServing { + t.Errorf("listener received new mode %v, want %v", mode, connectivity.ServingModeServing) + } + } + + // Create a ClientConn to the first listener and make a successful RPCs. + cc1, err := grpc.Dial(lis1.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + t.Fatalf("failed to dial local test server: %v", err) + } + defer cc1.Close() + waitForSuccessfulRPC(ctx, t, cc1) + + // Create a ClientConn to the second listener and make a successful RPCs. + cc2, err := grpc.Dial(lis2.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + t.Fatalf("failed to dial local test server: %v", err) + } + defer cc2.Close() + waitForSuccessfulRPC(ctx, t, cc2) + + // Update the management server to remove the second listener resource. This + // should push only the second listener into "not-serving" mode. + if err := managementServer.Update(ctx, e2e.UpdateOptions{ + NodeID: nodeID, + Listeners: []*v3listenerpb.Listener{listener1}, + }); err != nil { + t.Error(err) + } + + // Wait for lis2 to move to "not-serving" mode. + select { + case <-ctx.Done(): + t.Fatalf("timed out waiting for a mode change update: %v", err) + case mode := <-updateCh2: + if mode != connectivity.ServingModeNotServing { + t.Errorf("listener received new mode %v, want %v", mode, connectivity.ServingModeNotServing) + } + } + + // Make sure RPCs succeed on cc1 and fail on cc2. + waitForSuccessfulRPC(ctx, t, cc1) + waitForFailedRPC(ctx, t, cc2) + + // Update the management server to remove the first listener resource as + // well. This should push the first listener into "not-serving" mode. Second + // listener is already in "not-serving" mode. + if err := managementServer.Update(ctx, e2e.UpdateOptions{ + NodeID: nodeID, + Listeners: []*v3listenerpb.Listener{}, + }); err != nil { + t.Error(err) + } + + // Wait for lis1 to move to "not-serving" mode. lis2 was already removed + // from the xdsclient's resource cache. So, lis2's callback will not be + // invoked this time around. + select { + case <-ctx.Done(): + t.Fatalf("timed out waiting for a mode change update: %v", err) + case mode := <-updateCh1: + if mode != connectivity.ServingModeNotServing { + t.Errorf("listener received new mode %v, want %v", mode, connectivity.ServingModeNotServing) + } + } + + // Make sure RPCs fail on both. + waitForFailedRPC(ctx, t, cc1) + waitForFailedRPC(ctx, t, cc2) + + // Make sure new connection attempts to "not-serving" servers fail. We use a + // short timeout since we expect this to fail. + sCtx, sCancel := context.WithTimeout(ctx, defaultTestShortTimeout) + defer sCancel() + if _, err := grpc.DialContext(sCtx, lis1.Addr().String(), grpc.WithBlock(), grpc.WithTransportCredentials(insecure.NewCredentials())); err == nil { + t.Fatal("successfully created clientConn to a server in \"not-serving\" state") + } + + // Update the management server with both listener resources. + if err := managementServer.Update(ctx, e2e.UpdateOptions{ + NodeID: nodeID, + Listeners: []*v3listenerpb.Listener{listener1, listener2}, + }); err != nil { + t.Error(err) + } + + // Wait for both listeners to move to "serving" mode. + select { + case <-ctx.Done(): + t.Fatalf("timed out waiting for a mode change update: %v", err) + case mode := <-updateCh1: + if mode != connectivity.ServingModeServing { + t.Errorf("listener received new mode %v, want %v", mode, connectivity.ServingModeServing) + } + } + select { + case <-ctx.Done(): + t.Fatalf("timed out waiting for a mode change update: %v", err) + case mode := <-updateCh2: + if mode != connectivity.ServingModeServing { + t.Errorf("listener received new mode %v, want %v", mode, connectivity.ServingModeServing) + } + } + + // The clientConns created earlier should be able to make RPCs now. + waitForSuccessfulRPC(ctx, t, cc1) + waitForSuccessfulRPC(ctx, t, cc2) +} + +func waitForSuccessfulRPC(ctx context.Context, t *testing.T, cc *grpc.ClientConn) { + t.Helper() + + c := testpb.NewTestServiceClient(cc) + if _, err := c.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(true)); err != nil { + t.Fatalf("rpc EmptyCall() failed: %v", err) + } +} + +func waitForFailedRPC(ctx context.Context, t *testing.T, cc *grpc.ClientConn) { + t.Helper() + + // Attempt one RPC before waiting for the ticker to expire. + c := testpb.NewTestServiceClient(cc) + if _, err := c.EmptyCall(ctx, &testpb.Empty{}); err != nil { + return + } + + ticker := time.NewTimer(1 * time.Second) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + t.Fatalf("failure when waiting for RPCs to fail: %v", ctx.Err()) + case <-ticker.C: + if _, err := c.EmptyCall(ctx, &testpb.Empty{}); err != nil { + return + } + } + } +} diff --git a/xds/internal/testutils/balancer.go b/xds/internal/testutils/balancer.go index dab84a84e07..ff74da71cc9 100644 --- a/xds/internal/testutils/balancer.go +++ b/xds/internal/testutils/balancer.go @@ -46,21 +46,28 @@ var TestSubConns []*TestSubConn func init() { for i := 0; i < TestSubConnsCount; i++ { TestSubConns = append(TestSubConns, &TestSubConn{ - id: fmt.Sprintf("sc%d", i), + id: fmt.Sprintf("sc%d", i), + ConnectCh: make(chan struct{}, 1), }) } } // TestSubConn implements the SubConn interface, to be used in tests. type TestSubConn struct { - id string + id string + ConnectCh chan struct{} } // UpdateAddresses is a no-op. func (tsc *TestSubConn) UpdateAddresses([]resolver.Address) {} // Connect is a no-op. -func (tsc *TestSubConn) Connect() {} +func (tsc *TestSubConn) Connect() { + select { + case tsc.ConnectCh <- struct{}{}: + default: + } +} // String implements stringer to print human friendly error message. func (tsc *TestSubConn) String() string { @@ -76,8 +83,9 @@ type TestClientConn struct { RemoveSubConnCh chan balancer.SubConn // the last 10 subconn removed. UpdateAddressesAddrsCh chan []resolver.Address // last updated address via UpdateAddresses(). - NewPickerCh chan balancer.Picker // the last picker updated. - NewStateCh chan connectivity.State // the last state. + NewPickerCh chan balancer.Picker // the last picker updated. + NewStateCh chan connectivity.State // the last state. + ResolveNowCh chan resolver.ResolveNowOptions // the last ResolveNow(). subConnIdx int } @@ -92,8 +100,9 @@ func NewTestClientConn(t *testing.T) *TestClientConn { RemoveSubConnCh: make(chan balancer.SubConn, 10), UpdateAddressesAddrsCh: make(chan []resolver.Address, 1), - NewPickerCh: make(chan balancer.Picker, 1), - NewStateCh: make(chan connectivity.State, 1), + NewPickerCh: make(chan balancer.Picker, 1), + NewStateCh: make(chan connectivity.State, 1), + ResolveNowCh: make(chan resolver.ResolveNowOptions, 1), } } @@ -151,8 +160,12 @@ func (tcc *TestClientConn) UpdateState(bs balancer.State) { } // ResolveNow panics. -func (tcc *TestClientConn) ResolveNow(resolver.ResolveNowOptions) { - panic("not implemented") +func (tcc *TestClientConn) ResolveNow(o resolver.ResolveNowOptions) { + select { + case <-tcc.ResolveNowCh: + default: + } + tcc.ResolveNowCh <- o } // Target panics. diff --git a/xds/internal/testutils/e2e/bootstrap.go b/xds/internal/testutils/e2e/bootstrap.go index 25993f19fc3..99702032f81 100644 --- a/xds/internal/testutils/e2e/bootstrap.go +++ b/xds/internal/testutils/e2e/bootstrap.go @@ -21,98 +21,13 @@ package e2e import ( "encoding/json" "fmt" - "io/ioutil" - "os" - - "google.golang.org/grpc/xds/internal/env" ) -// TransportAPI refers to the API version for xDS transport protocol. -type TransportAPI int - -const ( - // TransportV2 refers to the v2 xDS transport protocol. - TransportV2 TransportAPI = iota - // TransportV3 refers to the v3 xDS transport protocol. - TransportV3 -) - -// BootstrapOptions wraps the parameters passed to SetupBootstrapFile. -type BootstrapOptions struct { - // Version is the xDS transport protocol version. - Version TransportAPI - // NodeID is the node identifier of the gRPC client/server node in the - // proxyless service mesh. - NodeID string - // ServerURI is the address of the management server. - ServerURI string - // ServerResourceNameID is the Listener resource name to fetch. - ServerResourceNameID string - // CertificateProviders is the certificate providers configuration. - CertificateProviders map[string]json.RawMessage -} - -// SetupBootstrapFile creates a temporary file with bootstrap contents, based on -// the passed in options, and updates the bootstrap environment variable to -// point to this file. -// -// Returns a cleanup function which will be non-nil if the setup process was -// completed successfully. It is the responsibility of the caller to invoke the -// cleanup function at the end of the test. -func SetupBootstrapFile(opts BootstrapOptions) (func(), error) { - f, err := ioutil.TempFile("", "test_xds_bootstrap_*") - if err != nil { - return nil, fmt.Errorf("failed to created bootstrap file: %v", err) - } - - cfg := &bootstrapConfig{ - XdsServers: []server{ - { - ServerURI: opts.ServerURI, - ChannelCreds: []creds{ - { - Type: "insecure", - }, - }, - }, - }, - Node: node{ - ID: opts.NodeID, - }, - CertificateProviders: opts.CertificateProviders, - GRPCServerResourceNameID: opts.ServerResourceNameID, - } - switch opts.Version { - case TransportV2: - // TODO: Add any v2 specific fields. - case TransportV3: - cfg.XdsServers[0].ServerFeatures = append(cfg.XdsServers[0].ServerFeatures, "xds_v3") - default: - return nil, fmt.Errorf("unsupported xDS transport protocol version: %v", opts.Version) - } - - bootstrapContents, err := json.MarshalIndent(cfg, "", " ") - if err != nil { - return nil, fmt.Errorf("failed to created bootstrap file: %v", err) - } - if err := ioutil.WriteFile(f.Name(), bootstrapContents, 0644); err != nil { - return nil, fmt.Errorf("failed to created bootstrap file: %v", err) - } - logger.Infof("Created bootstrap file at %q with contents: %s\n", f.Name(), bootstrapContents) - - origBootstrapFileName := env.BootstrapFileName - env.BootstrapFileName = f.Name() - return func() { - os.Remove(f.Name()) - env.BootstrapFileName = origBootstrapFileName - }, nil -} - // DefaultFileWatcherConfig is a helper function to create a default certificate // provider plugin configuration. The test is expected to have setup the files // appropriately before this configuration is used to instantiate providers. -func DefaultFileWatcherConfig(certPath, keyPath, caPath string) map[string]json.RawMessage { - cfg := fmt.Sprintf(`{ +func DefaultFileWatcherConfig(certPath, keyPath, caPath string) json.RawMessage { + return json.RawMessage(fmt.Sprintf(`{ "plugin_name": "file_watcher", "config": { "certificate_file": %q, @@ -120,30 +35,5 @@ func DefaultFileWatcherConfig(certPath, keyPath, caPath string) map[string]json. "ca_certificate_file": %q, "refresh_interval": "600s" } - }`, certPath, keyPath, caPath) - return map[string]json.RawMessage{ - "google_cloud_private_spiffe": json.RawMessage(cfg), - } -} - -type bootstrapConfig struct { - XdsServers []server `json:"xds_servers,omitempty"` - Node node `json:"node,omitempty"` - CertificateProviders map[string]json.RawMessage `json:"certificate_providers,omitempty"` - GRPCServerResourceNameID string `json:"grpc_server_resource_name_id,omitempty"` -} - -type server struct { - ServerURI string `json:"server_uri,omitempty"` - ChannelCreds []creds `json:"channel_creds,omitempty"` - ServerFeatures []string `json:"server_features,omitempty"` -} - -type creds struct { - Type string `json:"type,omitempty"` - Config interface{} `json:"config,omitempty"` -} - -type node struct { - ID string `json:"id,omitempty"` + }`, certPath, keyPath, caPath)) } diff --git a/xds/internal/testutils/e2e/clientresources.go b/xds/internal/testutils/e2e/clientresources.go index 79424b13b91..f3f7f6307c5 100644 --- a/xds/internal/testutils/e2e/clientresources.go +++ b/xds/internal/testutils/e2e/clientresources.go @@ -19,9 +19,13 @@ package e2e import ( + "fmt" + "net" + "strconv" + "github.com/envoyproxy/go-control-plane/pkg/wellknown" "github.com/golang/protobuf/proto" - "github.com/golang/protobuf/ptypes" + "google.golang.org/grpc/internal/testutils" v3clusterpb "github.com/envoyproxy/go-control-plane/envoy/config/cluster/v3" v3corepb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" @@ -30,37 +34,76 @@ import ( v3routepb "github.com/envoyproxy/go-control-plane/envoy/config/route/v3" v3routerpb "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/http/router/v3" v3httppb "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/network/http_connection_manager/v3" - anypb "github.com/golang/protobuf/ptypes/any" + v3tlspb "github.com/envoyproxy/go-control-plane/envoy/extensions/transport_sockets/tls/v3" wrapperspb "github.com/golang/protobuf/ptypes/wrappers" ) -func any(m proto.Message) *anypb.Any { - a, err := ptypes.MarshalAny(m) - if err != nil { - panic("error marshalling any: " + err.Error()) - } - return a +const ( + // ServerListenerResourceNameTemplate is the Listener resource name template + // used on the server side. + ServerListenerResourceNameTemplate = "grpc/server?xds.resource.listening_address=%s" + // ClientSideCertProviderInstance is the certificate provider instance name + // used in the Cluster resource on the client side. + ClientSideCertProviderInstance = "client-side-certificate-provider-instance" + // ServerSideCertProviderInstance is the certificate provider instance name + // used in the Listener resource on the server side. + ServerSideCertProviderInstance = "server-side-certificate-provider-instance" +) + +// SecurityLevel allows the test to control the security level to be used in the +// resource returned by this package. +type SecurityLevel int + +const ( + // SecurityLevelNone is used when no security configuration is required. + SecurityLevelNone SecurityLevel = iota + // SecurityLevelTLS is used when security configuration corresponding to TLS + // is required. Only the server presents an identity certificate in this + // configuration. + SecurityLevelTLS + // SecurityLevelMTLS is used when security ocnfiguration corresponding to + // mTLS is required. Both client and server present identity certificates in + // this configuration. + SecurityLevelMTLS +) + +// ResourceParams wraps the arguments to be passed to DefaultClientResources. +type ResourceParams struct { + // DialTarget is the client's dial target. This is used as the name of the + // Listener resource. + DialTarget string + // NodeID is the id of the xdsClient to which this update is to be pushed. + NodeID string + // Host is the host of the default Endpoint resource. + Host string + // port is the port of the default Endpoint resource. + Port uint32 + // SecLevel controls the security configuration in the Cluster resource. + SecLevel SecurityLevel } // DefaultClientResources returns a set of resources (LDS, RDS, CDS, EDS) for a // client to generically connect to one server. -func DefaultClientResources(target, nodeID, host string, port uint32) UpdateOptions { - const routeConfigName = "route" - const clusterName = "cluster" - const endpointsName = "endpoints" - +func DefaultClientResources(params ResourceParams) UpdateOptions { + routeConfigName := "route-" + params.DialTarget + clusterName := "cluster-" + params.DialTarget + endpointsName := "endpoints-" + params.DialTarget return UpdateOptions{ - NodeID: nodeID, - Listeners: []*v3listenerpb.Listener{DefaultListener(target, routeConfigName)}, - Routes: []*v3routepb.RouteConfiguration{DefaultRouteConfig(routeConfigName, target, clusterName)}, - Clusters: []*v3clusterpb.Cluster{DefaultCluster(clusterName, endpointsName)}, - Endpoints: []*v3endpointpb.ClusterLoadAssignment{DefaultEndpoint(endpointsName, host, port)}, + NodeID: params.NodeID, + Listeners: []*v3listenerpb.Listener{DefaultClientListener(params.DialTarget, routeConfigName)}, + Routes: []*v3routepb.RouteConfiguration{DefaultRouteConfig(routeConfigName, params.DialTarget, clusterName)}, + Clusters: []*v3clusterpb.Cluster{DefaultCluster(clusterName, endpointsName, params.SecLevel)}, + Endpoints: []*v3endpointpb.ClusterLoadAssignment{DefaultEndpoint(endpointsName, params.Host, []uint32{params.Port})}, } } -// DefaultListener returns a basic xds Listener resource. -func DefaultListener(target, routeName string) *v3listenerpb.Listener { - hcm := any(&v3httppb.HttpConnectionManager{ +// RouterHTTPFilter is the HTTP Filter configuration for the Router filter. +var RouterHTTPFilter = HTTPFilter("router", &v3routerpb.Router{}) + +// DefaultClientListener returns a basic xds Listener resource to be used on +// the client side. +func DefaultClientListener(target, routeName string) *v3listenerpb.Listener { + hcm := testutils.MarshalAny(&v3httppb.HttpConnectionManager{ RouteSpecifier: &v3httppb.HttpConnectionManager_Rds{Rds: &v3httppb.Rds{ ConfigSource: &v3corepb.ConfigSource{ ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{Ads: &v3corepb.AggregatedConfigSource{}}, @@ -82,12 +125,164 @@ func DefaultListener(target, routeName string) *v3listenerpb.Listener { } } +// DefaultServerListener returns a basic xds Listener resource to be used on +// the server side. +func DefaultServerListener(host string, port uint32, secLevel SecurityLevel) *v3listenerpb.Listener { + var tlsContext *v3tlspb.DownstreamTlsContext + switch secLevel { + case SecurityLevelNone: + case SecurityLevelTLS: + tlsContext = &v3tlspb.DownstreamTlsContext{ + CommonTlsContext: &v3tlspb.CommonTlsContext{ + TlsCertificateCertificateProviderInstance: &v3tlspb.CommonTlsContext_CertificateProviderInstance{ + InstanceName: ServerSideCertProviderInstance, + }, + }, + } + case SecurityLevelMTLS: + tlsContext = &v3tlspb.DownstreamTlsContext{ + RequireClientCertificate: &wrapperspb.BoolValue{Value: true}, + CommonTlsContext: &v3tlspb.CommonTlsContext{ + TlsCertificateCertificateProviderInstance: &v3tlspb.CommonTlsContext_CertificateProviderInstance{ + InstanceName: ServerSideCertProviderInstance, + }, + ValidationContextType: &v3tlspb.CommonTlsContext_ValidationContextCertificateProviderInstance{ + ValidationContextCertificateProviderInstance: &v3tlspb.CommonTlsContext_CertificateProviderInstance{ + InstanceName: ServerSideCertProviderInstance, + }, + }, + }, + } + } + + var ts *v3corepb.TransportSocket + if tlsContext != nil { + ts = &v3corepb.TransportSocket{ + Name: "envoy.transport_sockets.tls", + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(tlsContext), + }, + } + } + return &v3listenerpb.Listener{ + Name: fmt.Sprintf(ServerListenerResourceNameTemplate, net.JoinHostPort(host, strconv.Itoa(int(port)))), + Address: &v3corepb.Address{ + Address: &v3corepb.Address_SocketAddress{ + SocketAddress: &v3corepb.SocketAddress{ + Address: host, + PortSpecifier: &v3corepb.SocketAddress_PortValue{ + PortValue: port, + }, + }, + }, + }, + FilterChains: []*v3listenerpb.FilterChain{ + { + Name: "v4-wildcard", + FilterChainMatch: &v3listenerpb.FilterChainMatch{ + PrefixRanges: []*v3corepb.CidrRange{ + { + AddressPrefix: "0.0.0.0", + PrefixLen: &wrapperspb.UInt32Value{ + Value: uint32(0), + }, + }, + }, + SourceType: v3listenerpb.FilterChainMatch_SAME_IP_OR_LOOPBACK, + SourcePrefixRanges: []*v3corepb.CidrRange{ + { + AddressPrefix: "0.0.0.0", + PrefixLen: &wrapperspb.UInt32Value{ + Value: uint32(0), + }, + }, + }, + }, + Filters: []*v3listenerpb.Filter{ + { + Name: "filter-1", + ConfigType: &v3listenerpb.Filter_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3httppb.HttpConnectionManager{ + RouteSpecifier: &v3httppb.HttpConnectionManager_RouteConfig{ + RouteConfig: &v3routepb.RouteConfiguration{ + Name: "routeName", + VirtualHosts: []*v3routepb.VirtualHost{{ + // This "*" string matches on any incoming authority. This is to ensure any + // incoming RPC matches to Route_NonForwardingAction and will proceed as + // normal. + Domains: []string{"*"}, + Routes: []*v3routepb.Route{{ + Match: &v3routepb.RouteMatch{ + PathSpecifier: &v3routepb.RouteMatch_Prefix{Prefix: "/"}, + }, + Action: &v3routepb.Route_NonForwardingAction{}, + }}}}}, + }, + HttpFilters: []*v3httppb.HttpFilter{RouterHTTPFilter}, + }), + }, + }, + }, + TransportSocket: ts, + }, + { + Name: "v6-wildcard", + FilterChainMatch: &v3listenerpb.FilterChainMatch{ + PrefixRanges: []*v3corepb.CidrRange{ + { + AddressPrefix: "::", + PrefixLen: &wrapperspb.UInt32Value{ + Value: uint32(0), + }, + }, + }, + SourceType: v3listenerpb.FilterChainMatch_SAME_IP_OR_LOOPBACK, + SourcePrefixRanges: []*v3corepb.CidrRange{ + { + AddressPrefix: "::", + PrefixLen: &wrapperspb.UInt32Value{ + Value: uint32(0), + }, + }, + }, + }, + Filters: []*v3listenerpb.Filter{ + { + Name: "filter-1", + ConfigType: &v3listenerpb.Filter_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3httppb.HttpConnectionManager{ + RouteSpecifier: &v3httppb.HttpConnectionManager_RouteConfig{ + RouteConfig: &v3routepb.RouteConfiguration{ + Name: "routeName", + VirtualHosts: []*v3routepb.VirtualHost{{ + // This "*" string matches on any incoming authority. This is to ensure any + // incoming RPC matches to Route_NonForwardingAction and will proceed as + // normal. + Domains: []string{"*"}, + Routes: []*v3routepb.Route{{ + Match: &v3routepb.RouteMatch{ + PathSpecifier: &v3routepb.RouteMatch_Prefix{Prefix: "/"}, + }, + Action: &v3routepb.Route_NonForwardingAction{}, + }}}}}, + }, + HttpFilters: []*v3httppb.HttpFilter{RouterHTTPFilter}, + }), + }, + }, + }, + TransportSocket: ts, + }, + }, + } +} + // HTTPFilter constructs an xds HttpFilter with the provided name and config. func HTTPFilter(name string, config proto.Message) *v3httppb.HttpFilter { return &v3httppb.HttpFilter{ Name: name, ConfigType: &v3httppb.HttpFilter_TypedConfig{ - TypedConfig: any(config), + TypedConfig: testutils.MarshalAny(config), }, } } @@ -109,8 +304,36 @@ func DefaultRouteConfig(routeName, ldsTarget, clusterName string) *v3routepb.Rou } // DefaultCluster returns a basic xds Cluster resource. -func DefaultCluster(clusterName, edsServiceName string) *v3clusterpb.Cluster { - return &v3clusterpb.Cluster{ +func DefaultCluster(clusterName, edsServiceName string, secLevel SecurityLevel) *v3clusterpb.Cluster { + var tlsContext *v3tlspb.UpstreamTlsContext + switch secLevel { + case SecurityLevelNone: + case SecurityLevelTLS: + tlsContext = &v3tlspb.UpstreamTlsContext{ + CommonTlsContext: &v3tlspb.CommonTlsContext{ + ValidationContextType: &v3tlspb.CommonTlsContext_ValidationContextCertificateProviderInstance{ + ValidationContextCertificateProviderInstance: &v3tlspb.CommonTlsContext_CertificateProviderInstance{ + InstanceName: ClientSideCertProviderInstance, + }, + }, + }, + } + case SecurityLevelMTLS: + tlsContext = &v3tlspb.UpstreamTlsContext{ + CommonTlsContext: &v3tlspb.CommonTlsContext{ + ValidationContextType: &v3tlspb.CommonTlsContext_ValidationContextCertificateProviderInstance{ + ValidationContextCertificateProviderInstance: &v3tlspb.CommonTlsContext_CertificateProviderInstance{ + InstanceName: ClientSideCertProviderInstance, + }, + }, + TlsCertificateCertificateProviderInstance: &v3tlspb.CommonTlsContext_CertificateProviderInstance{ + InstanceName: ClientSideCertProviderInstance, + }, + }, + } + } + + cluster := &v3clusterpb.Cluster{ Name: clusterName, ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_EDS}, EdsClusterConfig: &v3clusterpb.Cluster_EdsClusterConfig{ @@ -123,24 +346,37 @@ func DefaultCluster(clusterName, edsServiceName string) *v3clusterpb.Cluster { }, LbPolicy: v3clusterpb.Cluster_ROUND_ROBIN, } + if tlsContext != nil { + cluster.TransportSocket = &v3corepb.TransportSocket{ + Name: "envoy.transport_sockets.tls", + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(tlsContext), + }, + } + } + return cluster } // DefaultEndpoint returns a basic xds Endpoint resource. -func DefaultEndpoint(clusterName string, host string, port uint32) *v3endpointpb.ClusterLoadAssignment { +func DefaultEndpoint(clusterName string, host string, ports []uint32) *v3endpointpb.ClusterLoadAssignment { + var lbEndpoints []*v3endpointpb.LbEndpoint + for _, port := range ports { + lbEndpoints = append(lbEndpoints, &v3endpointpb.LbEndpoint{ + HostIdentifier: &v3endpointpb.LbEndpoint_Endpoint{Endpoint: &v3endpointpb.Endpoint{ + Address: &v3corepb.Address{Address: &v3corepb.Address_SocketAddress{ + SocketAddress: &v3corepb.SocketAddress{ + Protocol: v3corepb.SocketAddress_TCP, + Address: host, + PortSpecifier: &v3corepb.SocketAddress_PortValue{PortValue: port}}, + }}, + }}, + }) + } return &v3endpointpb.ClusterLoadAssignment{ ClusterName: clusterName, Endpoints: []*v3endpointpb.LocalityLbEndpoints{{ - Locality: &v3corepb.Locality{SubZone: "subzone"}, - LbEndpoints: []*v3endpointpb.LbEndpoint{{ - HostIdentifier: &v3endpointpb.LbEndpoint_Endpoint{Endpoint: &v3endpointpb.Endpoint{ - Address: &v3corepb.Address{Address: &v3corepb.Address_SocketAddress{ - SocketAddress: &v3corepb.SocketAddress{ - Protocol: v3corepb.SocketAddress_TCP, - Address: host, - PortSpecifier: &v3corepb.SocketAddress_PortValue{PortValue: uint32(port)}}, - }}, - }}, - }}, + Locality: &v3corepb.Locality{SubZone: "subzone"}, + LbEndpoints: lbEndpoints, LoadBalancingWeight: &wrapperspb.UInt32Value{Value: 1}, Priority: 0, }}, diff --git a/xds/internal/testutils/e2e/server.go b/xds/internal/testutils/e2e/server.go index cc55c595cae..e47dcc5213c 100644 --- a/xds/internal/testutils/e2e/server.go +++ b/xds/internal/testutils/e2e/server.go @@ -33,6 +33,7 @@ import ( v3discoverygrpc "github.com/envoyproxy/go-control-plane/envoy/service/discovery/v3" "github.com/envoyproxy/go-control-plane/pkg/cache/types" v3cache "github.com/envoyproxy/go-control-plane/pkg/cache/v3" + v3resource "github.com/envoyproxy/go-control-plane/pkg/resource/v3" v3server "github.com/envoyproxy/go-control-plane/pkg/server/v3" "google.golang.org/grpc" @@ -45,10 +46,22 @@ var logger = grpclog.Component("xds-e2e") // envoyproxy/go-control-plane/pkg/log. This is passed to the Snapshot cache. type serverLogger struct{} -func (l serverLogger) Debugf(format string, args ...interface{}) { logger.Infof(format, args...) } -func (l serverLogger) Infof(format string, args ...interface{}) { logger.Infof(format, args...) } -func (l serverLogger) Warnf(format string, args ...interface{}) { logger.Warningf(format, args...) } -func (l serverLogger) Errorf(format string, args ...interface{}) { logger.Errorf(format, args...) } +func (l serverLogger) Debugf(format string, args ...interface{}) { + msg := fmt.Sprintf(format, args...) + logger.InfoDepth(1, msg) +} +func (l serverLogger) Infof(format string, args ...interface{}) { + msg := fmt.Sprintf(format, args...) + logger.InfoDepth(1, msg) +} +func (l serverLogger) Warnf(format string, args ...interface{}) { + msg := fmt.Sprintf(format, args...) + logger.WarningDepth(1, msg) +} +func (l serverLogger) Errorf(format string, args ...interface{}) { + msg := fmt.Sprintf(format, args...) + logger.ErrorDepth(1, msg) +} // ManagementServer is a thin wrapper around the xDS control plane // implementation provided by envoyproxy/go-control-plane. @@ -113,22 +126,38 @@ type UpdateOptions struct { Clusters []*v3clusterpb.Cluster Routes []*v3routepb.RouteConfiguration Listeners []*v3listenerpb.Listener + // SkipValidation indicates whether we want to skip validation (by not + // calling snapshot.Consistent()). It can be useful for negative tests, + // where we send updates that the client will NACK. + SkipValidation bool } // Update changes the resource snapshot held by the management server, which // updates connected clients as required. -func (s *ManagementServer) Update(opts UpdateOptions) error { +func (s *ManagementServer) Update(ctx context.Context, opts UpdateOptions) error { s.version++ // Create a snapshot with the passed in resources. - snapshot := v3cache.NewSnapshot(strconv.Itoa(s.version), resourceSlice(opts.Endpoints), resourceSlice(opts.Clusters), resourceSlice(opts.Routes), resourceSlice(opts.Listeners), nil /*runtimes*/, nil /*secrets*/) - if err := snapshot.Consistent(); err != nil { - return fmt.Errorf("failed to create new resource snapshot: %v", err) + resources := map[v3resource.Type][]types.Resource{ + v3resource.ListenerType: resourceSlice(opts.Listeners), + v3resource.RouteType: resourceSlice(opts.Routes), + v3resource.ClusterType: resourceSlice(opts.Clusters), + v3resource.EndpointType: resourceSlice(opts.Endpoints), + } + snapshot, err := v3cache.NewSnapshot(strconv.Itoa(s.version), resources) + if err != nil { + return fmt.Errorf("failed to create new snapshot cache: %v", err) + + } + if !opts.SkipValidation { + if err := snapshot.Consistent(); err != nil { + return fmt.Errorf("failed to create new resource snapshot: %v", err) + } } logger.Infof("Created new resource snapshot...") // Update the cache with the new resource snapshot. - if err := s.cache.SetSnapshot(opts.NodeID, snapshot); err != nil { + if err := s.cache.SetSnapshot(ctx, opts.NodeID, snapshot); err != nil { return fmt.Errorf("failed to update resource snapshot in management server: %v", err) } logger.Infof("Updated snapshot cache with resource snapshot...") @@ -141,7 +170,6 @@ func (s *ManagementServer) Stop() { s.cancel() } s.gs.Stop() - logger.Infof("Stopped the xDS management server...") } // resourceSlice accepts a slice of any type of proto messages and returns a diff --git a/xds/internal/testutils/fakeclient/client.go b/xds/internal/testutils/fakeclient/client.go index 0978125b8ae..b582fd9bee9 100644 --- a/xds/internal/testutils/fakeclient/client.go +++ b/xds/internal/testutils/fakeclient/client.go @@ -22,15 +22,21 @@ package fakeclient import ( "context" + "google.golang.org/grpc/internal/grpcsync" "google.golang.org/grpc/internal/testutils" - xdsclient "google.golang.org/grpc/xds/internal/client" - "google.golang.org/grpc/xds/internal/client/bootstrap" - "google.golang.org/grpc/xds/internal/client/load" + "google.golang.org/grpc/xds/internal/xdsclient" + "google.golang.org/grpc/xds/internal/xdsclient/bootstrap" + "google.golang.org/grpc/xds/internal/xdsclient/load" ) // Client is a fake implementation of an xds client. It exposes a bunch of // channels to signal the occurrence of various events. type Client struct { + // Embed XDSClient so this fake client implements the interface, but it's + // never set (it's always nil). This may cause nil panic since not all the + // methods are implemented. + xdsclient.XDSClient + name string ldsWatchCh *testutils.Channel rdsWatchCh *testutils.Channel @@ -41,14 +47,16 @@ type Client struct { cdsCancelCh *testutils.Channel edsCancelCh *testutils.Channel loadReportCh *testutils.Channel - closeCh *testutils.Channel + lrsCancelCh *testutils.Channel loadStore *load.Store bootstrapCfg *bootstrap.Config - ldsCb func(xdsclient.ListenerUpdate, error) - rdsCb func(xdsclient.RouteConfigUpdate, error) - cdsCb func(xdsclient.ClusterUpdate, error) - edsCb func(xdsclient.EndpointsUpdate, error) + ldsCb func(xdsclient.ListenerUpdate, error) + rdsCbs map[string]func(xdsclient.RouteConfigUpdate, error) + cdsCbs map[string]func(xdsclient.ClusterUpdate, error) + edsCbs map[string]func(xdsclient.EndpointsUpdate, error) + + Closed *grpcsync.Event // fired when Close is called. } // WatchListener registers a LDS watch. @@ -87,10 +95,10 @@ func (xdsC *Client) WaitForCancelListenerWatch(ctx context.Context) error { // WatchRouteConfig registers a RDS watch. func (xdsC *Client) WatchRouteConfig(routeName string, callback func(xdsclient.RouteConfigUpdate, error)) func() { - xdsC.rdsCb = callback + xdsC.rdsCbs[routeName] = callback xdsC.rdsWatchCh.Send(routeName) return func() { - xdsC.rdsCancelCh.Send(nil) + xdsC.rdsCancelCh.Send(routeName) } } @@ -108,23 +116,39 @@ func (xdsC *Client) WaitForWatchRouteConfig(ctx context.Context) (string, error) // // Not thread safe with WatchRouteConfig. Only call this after // WaitForWatchRouteConfig. -func (xdsC *Client) InvokeWatchRouteConfigCallback(update xdsclient.RouteConfigUpdate, err error) { - xdsC.rdsCb(update, err) +func (xdsC *Client) InvokeWatchRouteConfigCallback(name string, update xdsclient.RouteConfigUpdate, err error) { + if len(xdsC.rdsCbs) != 1 { + xdsC.rdsCbs[name](update, err) + return + } + // Keeps functionality with previous usage of this on client side, if single + // callback call that callback. + var routeName string + for route := range xdsC.rdsCbs { + routeName = route + } + xdsC.rdsCbs[routeName](update, err) } // WaitForCancelRouteConfigWatch waits for a RDS watch to be cancelled and returns // context.DeadlineExceeded otherwise. -func (xdsC *Client) WaitForCancelRouteConfigWatch(ctx context.Context) error { - _, err := xdsC.rdsCancelCh.Receive(ctx) - return err +func (xdsC *Client) WaitForCancelRouteConfigWatch(ctx context.Context) (string, error) { + val, err := xdsC.rdsCancelCh.Receive(ctx) + if err != nil { + return "", err + } + return val.(string), err } // WatchCluster registers a CDS watch. func (xdsC *Client) WatchCluster(clusterName string, callback func(xdsclient.ClusterUpdate, error)) func() { - xdsC.cdsCb = callback + // Due to the tree like structure of aggregate clusters, there can be multiple callbacks persisted for each cluster + // node. However, the client doesn't care about the parent child relationship between the nodes, only that it invokes + // the right callback for a particular cluster. + xdsC.cdsCbs[clusterName] = callback xdsC.cdsWatchCh.Send(clusterName) return func() { - xdsC.cdsCancelCh.Send(nil) + xdsC.cdsCancelCh.Send(clusterName) } } @@ -143,22 +167,36 @@ func (xdsC *Client) WaitForWatchCluster(ctx context.Context) (string, error) { // Not thread safe with WatchCluster. Only call this after // WaitForWatchCluster. func (xdsC *Client) InvokeWatchClusterCallback(update xdsclient.ClusterUpdate, err error) { - xdsC.cdsCb(update, err) + // Keeps functionality with previous usage of this, if single callback call that callback. + if len(xdsC.cdsCbs) == 1 { + var clusterName string + for cluster := range xdsC.cdsCbs { + clusterName = cluster + } + xdsC.cdsCbs[clusterName](update, err) + } else { + // Have what callback you call with the update determined by the service name in the ClusterUpdate. Left up to the + // caller to make sure the cluster update matches with a persisted callback. + xdsC.cdsCbs[update.ClusterName](update, err) + } } // WaitForCancelClusterWatch waits for a CDS watch to be cancelled and returns // context.DeadlineExceeded otherwise. -func (xdsC *Client) WaitForCancelClusterWatch(ctx context.Context) error { - _, err := xdsC.cdsCancelCh.Receive(ctx) - return err +func (xdsC *Client) WaitForCancelClusterWatch(ctx context.Context) (string, error) { + clusterNameReceived, err := xdsC.cdsCancelCh.Receive(ctx) + if err != nil { + return "", err + } + return clusterNameReceived.(string), err } // WatchEndpoints registers an EDS watch for provided clusterName. func (xdsC *Client) WatchEndpoints(clusterName string, callback func(xdsclient.EndpointsUpdate, error)) (cancel func()) { - xdsC.edsCb = callback + xdsC.edsCbs[clusterName] = callback xdsC.edsWatchCh.Send(clusterName) return func() { - xdsC.edsCancelCh.Send(nil) + xdsC.edsCancelCh.Send(clusterName) } } @@ -176,15 +214,28 @@ func (xdsC *Client) WaitForWatchEDS(ctx context.Context) (string, error) { // // Not thread safe with WatchEndpoints. Only call this after // WaitForWatchEDS. -func (xdsC *Client) InvokeWatchEDSCallback(update xdsclient.EndpointsUpdate, err error) { - xdsC.edsCb(update, err) +func (xdsC *Client) InvokeWatchEDSCallback(name string, update xdsclient.EndpointsUpdate, err error) { + if len(xdsC.edsCbs) != 1 { + // This may panic if name isn't found. But it's fine for tests. + xdsC.edsCbs[name](update, err) + return + } + // Keeps functionality with previous usage of this, if single callback call + // that callback. + for n := range xdsC.edsCbs { + name = n + } + xdsC.edsCbs[name](update, err) } // WaitForCancelEDSWatch waits for a EDS watch to be cancelled and returns // context.DeadlineExceeded otherwise. -func (xdsC *Client) WaitForCancelEDSWatch(ctx context.Context) error { - _, err := xdsC.edsCancelCh.Receive(ctx) - return err +func (xdsC *Client) WaitForCancelEDSWatch(ctx context.Context) (string, error) { + edsNameReceived, err := xdsC.edsCancelCh.Receive(ctx) + if err != nil { + return "", err + } + return edsNameReceived.(string), err } // ReportLoadArgs wraps the arguments passed to ReportLoad. @@ -196,7 +247,16 @@ type ReportLoadArgs struct { // ReportLoad starts reporting load about clusterName to server. func (xdsC *Client) ReportLoad(server string) (loadStore *load.Store, cancel func()) { xdsC.loadReportCh.Send(ReportLoadArgs{Server: server}) - return xdsC.loadStore, func() {} + return xdsC.loadStore, func() { + xdsC.lrsCancelCh.Send(nil) + } +} + +// WaitForCancelReportLoad waits for a load report to be cancelled and returns +// context.DeadlineExceeded otherwise. +func (xdsC *Client) WaitForCancelReportLoad(ctx context.Context) error { + _, err := xdsC.lrsCancelCh.Receive(ctx) + return err } // LoadStore returns the underlying load data store. @@ -208,19 +268,15 @@ func (xdsC *Client) LoadStore() *load.Store { // returns the arguments passed to it. func (xdsC *Client) WaitForReportLoad(ctx context.Context) (ReportLoadArgs, error) { val, err := xdsC.loadReportCh.Receive(ctx) - return val.(ReportLoadArgs), err + if err != nil { + return ReportLoadArgs{}, err + } + return val.(ReportLoadArgs), nil } -// Close closes the xds client. +// Close fires xdsC.Closed, indicating it was called. func (xdsC *Client) Close() { - xdsC.closeCh.Send(nil) -} - -// WaitForClose waits for Close to be invoked on this client and returns -// context.DeadlineExceeded otherwise. -func (xdsC *Client) WaitForClose(ctx context.Context) error { - _, err := xdsC.closeCh.Receive(ctx) - return err + xdsC.Closed.Fire() } // BootstrapConfig returns the bootstrap config. @@ -250,15 +306,19 @@ func NewClientWithName(name string) *Client { return &Client{ name: name, ldsWatchCh: testutils.NewChannel(), - rdsWatchCh: testutils.NewChannel(), - cdsWatchCh: testutils.NewChannel(), - edsWatchCh: testutils.NewChannel(), + rdsWatchCh: testutils.NewChannelWithSize(10), + cdsWatchCh: testutils.NewChannelWithSize(10), + edsWatchCh: testutils.NewChannelWithSize(10), ldsCancelCh: testutils.NewChannel(), - rdsCancelCh: testutils.NewChannel(), - cdsCancelCh: testutils.NewChannel(), - edsCancelCh: testutils.NewChannel(), + rdsCancelCh: testutils.NewChannelWithSize(10), + cdsCancelCh: testutils.NewChannelWithSize(10), + edsCancelCh: testutils.NewChannelWithSize(10), loadReportCh: testutils.NewChannel(), - closeCh: testutils.NewChannel(), + lrsCancelCh: testutils.NewChannel(), loadStore: load.NewStore(), + rdsCbs: make(map[string]func(xdsclient.RouteConfigUpdate, error)), + cdsCbs: make(map[string]func(xdsclient.ClusterUpdate, error)), + edsCbs: make(map[string]func(xdsclient.EndpointsUpdate, error)), + Closed: grpcsync.NewEvent(), } } diff --git a/xds/internal/testutils/protos.go b/xds/internal/testutils/protos.go index e0dba0e2b30..fc3cdf307fc 100644 --- a/xds/internal/testutils/protos.go +++ b/xds/internal/testutils/protos.go @@ -59,7 +59,7 @@ type ClusterLoadAssignmentBuilder struct { // NewClusterLoadAssignmentBuilder creates a ClusterLoadAssignmentBuilder. func NewClusterLoadAssignmentBuilder(clusterName string, dropPercents map[string]uint32) *ClusterLoadAssignmentBuilder { - var drops []*v2xdspb.ClusterLoadAssignment_Policy_DropOverload + drops := make([]*v2xdspb.ClusterLoadAssignment_Policy_DropOverload, 0, len(dropPercents)) for n, d := range dropPercents { drops = append(drops, &v2xdspb.ClusterLoadAssignment_Policy_DropOverload{ Category: n, @@ -88,7 +88,7 @@ type AddLocalityOptions struct { // AddLocality adds a locality to the builder. func (clab *ClusterLoadAssignmentBuilder) AddLocality(subzone string, weight uint32, priority uint32, addrsWithPort []string, opts *AddLocalityOptions) { - var lbEndPoints []*v2endpointpb.LbEndpoint + lbEndPoints := make([]*v2endpointpb.LbEndpoint, 0, len(addrsWithPort)) for i, a := range addrsWithPort { host, portStr, err := net.SplitHostPort(a) if err != nil { diff --git a/xds/internal/xdsclient/attributes.go b/xds/internal/xdsclient/attributes.go new file mode 100644 index 00000000000..d2357df0727 --- /dev/null +++ b/xds/internal/xdsclient/attributes.go @@ -0,0 +1,59 @@ +/* + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package xdsclient + +import ( + "google.golang.org/grpc/resolver" + "google.golang.org/grpc/xds/internal/xdsclient/bootstrap" + "google.golang.org/grpc/xds/internal/xdsclient/load" +) + +type clientKeyType string + +const clientKey = clientKeyType("grpc.xds.internal.client.Client") + +// XDSClient is a full fledged gRPC client which queries a set of discovery APIs +// (collectively termed as xDS) on a remote management server, to discover +// various dynamic resources. +type XDSClient interface { + WatchListener(string, func(ListenerUpdate, error)) func() + WatchRouteConfig(string, func(RouteConfigUpdate, error)) func() + WatchCluster(string, func(ClusterUpdate, error)) func() + WatchEndpoints(clusterName string, edsCb func(EndpointsUpdate, error)) (cancel func()) + ReportLoad(server string) (*load.Store, func()) + + DumpLDS() (string, map[string]UpdateWithMD) + DumpRDS() (string, map[string]UpdateWithMD) + DumpCDS() (string, map[string]UpdateWithMD) + DumpEDS() (string, map[string]UpdateWithMD) + + BootstrapConfig() *bootstrap.Config + Close() +} + +// FromResolverState returns the Client from state, or nil if not present. +func FromResolverState(state resolver.State) XDSClient { + cs, _ := state.Attributes.Value(clientKey).(XDSClient) + return cs +} + +// SetClient sets c in state and returns the new state. +func SetClient(state resolver.State, c XDSClient) resolver.State { + state.Attributes = state.Attributes.WithValues(clientKey, c) + return state +} diff --git a/xds/internal/client/bootstrap/bootstrap.go b/xds/internal/xdsclient/bootstrap/bootstrap.go similarity index 88% rename from xds/internal/client/bootstrap/bootstrap.go rename to xds/internal/xdsclient/bootstrap/bootstrap.go index c48bb797f36..fa229d99593 100644 --- a/xds/internal/client/bootstrap/bootstrap.go +++ b/xds/internal/xdsclient/bootstrap/bootstrap.go @@ -35,7 +35,8 @@ import ( "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/credentials/tls/certprovider" "google.golang.org/grpc/internal" - "google.golang.org/grpc/xds/internal/env" + "google.golang.org/grpc/internal/pretty" + "google.golang.org/grpc/internal/xds/env" "google.golang.org/grpc/xds/internal/version" ) @@ -79,10 +80,12 @@ type Config struct { // CertProviderConfigs contains a mapping from certificate provider plugin // instance names to parsed buildable configs. CertProviderConfigs map[string]*certprovider.BuildableConfig - // ServerResourceNameID contains the value to be used as the id in the - // resource name used to fetch the Listener resource on the xDS-enabled gRPC - // server. - ServerResourceNameID string + // ServerListenerResourceNameTemplate is a template for the name of the + // Listener resource to subscribe to for a gRPC server. If the token `%s` is + // present in the string, it will be replaced with the server's listening + // "IP:port" (e.g., "0.0.0.0:8080", "[::]:8080"). For example, a value of + // "example/resource/%s" could become "example/resource/0.0.0.0:8080". + ServerListenerResourceNameTemplate string } type channelCreds struct { @@ -124,16 +127,18 @@ func bootstrapConfigFromEnvVariable() ([]byte, error) { // // The format of the bootstrap file will be as follows: // { -// "xds_server": { -// "server_uri": , -// "channel_creds": [ -// { -// "type": , -// "config": -// } -// ], -// "server_features": [ ... ], -// }, +// "xds_servers": [ +// { +// "server_uri": , +// "channel_creds": [ +// { +// "type": , +// "config": +// } +// ], +// "server_features": [ ... ], +// } +// ], // "node": , // "certificate_providers" : { // "default": { @@ -145,7 +150,7 @@ func bootstrapConfigFromEnvVariable() ([]byte, error) { // "config": { foo plugin config in JSON } // } // }, -// "grpc_server_resource_name_id": "grpc/server" +// "server_listener_resource_name_template": "grpc/server?xds.resource.listening_address=%s" // } // // Currently, we support exactly one type of credential, which is @@ -157,13 +162,19 @@ func bootstrapConfigFromEnvVariable() ([]byte, error) { // fields left unspecified, in which case the caller should use some sane // defaults. func NewConfig() (*Config, error) { - config := &Config{} - data, err := bootstrapConfigFromEnvVariable() if err != nil { return nil, fmt.Errorf("xds: Failed to read bootstrap config: %v", err) } logger.Debugf("Bootstrap content: %s", data) + return NewConfigFromContents(data) +} + +// NewConfigFromContents returns a new Config using the specified bootstrap +// file contents instead of reading the environment variable. This is only +// suitable for testing purposes. +func NewConfigFromContents(data []byte) (*Config, error) { + config := &Config{} var jsonData map[string]json.RawMessage if err := json.Unmarshal(data, &jsonData); err != nil { @@ -241,8 +252,8 @@ func NewConfig() (*Config, error) { configs[instance] = bc } config.CertProviderConfigs = configs - case "grpc_server_resource_name_id": - if err := json.Unmarshal(v, &config.ServerResourceNameID); err != nil { + case "server_listener_resource_name_template": + if err := json.Unmarshal(v, &config.ServerListenerResourceNameTemplate); err != nil { return nil, fmt.Errorf("xds: json.Unmarshal(%v) for field %q failed during bootstrap: %v", string(v), k, err) } } @@ -268,7 +279,7 @@ func NewConfig() (*Config, error) { if err := config.updateNodeProto(); err != nil { return nil, err } - logger.Infof("Bootstrap config for creating xds-client: %+v", config) + logger.Infof("Bootstrap config for creating xds-client: %v", pretty.ToJSON(config)) return config, nil } diff --git a/xds/internal/client/bootstrap/bootstrap_test.go b/xds/internal/xdsclient/bootstrap/bootstrap_test.go similarity index 95% rename from xds/internal/client/bootstrap/bootstrap_test.go rename to xds/internal/xdsclient/bootstrap/bootstrap_test.go index e881594b877..501d62102d2 100644 --- a/xds/internal/client/bootstrap/bootstrap_test.go +++ b/xds/internal/xdsclient/bootstrap/bootstrap_test.go @@ -36,7 +36,7 @@ import ( "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/credentials/tls/certprovider" "google.golang.org/grpc/internal" - "google.golang.org/grpc/xds/internal/env" + "google.golang.org/grpc/internal/xds/env" "google.golang.org/grpc/xds/internal/version" ) @@ -240,8 +240,8 @@ func (c *Config) compare(want *Config) error { if diff := cmp.Diff(want.NodeProto, c.NodeProto, cmp.Comparer(proto.Equal)); diff != "" { return fmt.Errorf("config.NodeProto diff (-want, +got):\n%s", diff) } - if c.ServerResourceNameID != want.ServerResourceNameID { - return fmt.Errorf("config.ServerResourceNameID is %q, want %q", c.ServerResourceNameID, want.ServerResourceNameID) + if c.ServerListenerResourceNameTemplate != want.ServerListenerResourceNameTemplate { + return fmt.Errorf("config.ServerListenerResourceNameTemplate is %q, want %q", c.ServerListenerResourceNameTemplate, want.ServerListenerResourceNameTemplate) } // A vanilla cmp.Equal or cmp.Diff will not produce useful error message @@ -711,9 +711,9 @@ func TestNewConfigWithCertificateProviders(t *testing.T) { } } -func TestNewConfigWithServerResourceNameID(t *testing.T) { +func TestNewConfigWithServerListenerResourceNameTemplate(t *testing.T) { cancel := setupBootstrapOverride(map[string]string{ - "badServerResourceNameID": ` + "badServerListenerResourceNameTemplate:": ` { "node": { "id": "ENVOY_NODE_ID", @@ -727,9 +727,9 @@ func TestNewConfigWithServerResourceNameID(t *testing.T) { { "type": "google_default" } ] }], - "grpc_server_resource_name_id": 123456789 + "server_listener_resource_name_template": 123456789 }`, - "goodServerResourceNameID": ` + "goodServerListenerResourceNameTemplate": ` { "node": { "id": "ENVOY_NODE_ID", @@ -743,7 +743,7 @@ func TestNewConfigWithServerResourceNameID(t *testing.T) { { "type": "google_default" } ] }], - "grpc_server_resource_name_id": "grpc/server" + "server_listener_resource_name_template": "grpc/server?xds.resource.listening_address=%s" }`, }) defer cancel() @@ -754,17 +754,17 @@ func TestNewConfigWithServerResourceNameID(t *testing.T) { wantErr bool }{ { - name: "badServerResourceNameID", + name: "badServerListenerResourceNameTemplate", wantErr: true, }, { - name: "goodServerResourceNameID", + name: "goodServerListenerResourceNameTemplate", wantConfig: &Config{ - BalancerName: "trafficdirector.googleapis.com:443", - Creds: grpc.WithCredentialsBundle(google.NewComputeEngineCredentials()), - TransportAPI: version.TransportV2, - NodeProto: v2NodeProto, - ServerResourceNameID: "grpc/server", + BalancerName: "trafficdirector.googleapis.com:443", + Creds: grpc.WithCredentialsBundle(google.NewComputeEngineCredentials()), + TransportAPI: version.TransportV2, + NodeProto: v2NodeProto, + ServerListenerResourceNameTemplate: "grpc/server?xds.resource.listening_address=%s", }, }, } diff --git a/xds/internal/client/bootstrap/logging.go b/xds/internal/xdsclient/bootstrap/logging.go similarity index 100% rename from xds/internal/client/bootstrap/logging.go rename to xds/internal/xdsclient/bootstrap/logging.go diff --git a/xds/internal/client/callback.go b/xds/internal/xdsclient/callback.go similarity index 60% rename from xds/internal/client/callback.go rename to xds/internal/xdsclient/callback.go index da8e2f62d6c..0c2665e84c0 100644 --- a/xds/internal/client/callback.go +++ b/xds/internal/xdsclient/callback.go @@ -16,7 +16,12 @@ * */ -package client +package xdsclient + +import ( + "google.golang.org/grpc/internal/pretty" + "google.golang.org/protobuf/proto" +) type watcherInfoWithUpdate struct { wi *watchInfo @@ -74,39 +79,49 @@ func (c *clientImpl) callCallback(wiu *watcherInfoWithUpdate) { // // A response can contain multiple resources. They will be parsed and put in a // map from resource name to the resource content. -func (c *clientImpl) NewListeners(updates map[string]ListenerUpdate, metadata UpdateMetadata) { +func (c *clientImpl) NewListeners(updates map[string]ListenerUpdateErrTuple, metadata UpdateMetadata) { c.mu.Lock() defer c.mu.Unlock() + c.ldsVersion = metadata.Version if metadata.ErrState != nil { - // On NACK, update overall version to the NACKed resp. c.ldsVersion = metadata.ErrState.Version - for name := range updates { - if _, ok := c.ldsWatchers[name]; ok { + } + for name, uErr := range updates { + if s, ok := c.ldsWatchers[name]; ok { + if uErr.Err != nil { // On error, keep previous version for each resource. But update // status and error. mdCopy := c.ldsMD[name] mdCopy.ErrState = metadata.ErrState mdCopy.Status = metadata.Status c.ldsMD[name] = mdCopy - // TODO: send the NACK error to the watcher. + for wi := range s { + wi.newError(uErr.Err) + } + continue } - } - return - } - - // If no error received, the status is ACK. - c.ldsVersion = metadata.Version - for name, update := range updates { - if s, ok := c.ldsWatchers[name]; ok { - // Only send the update if this is not an error. - for wi := range s { - wi.newUpdate(update) + // If we get here, it means that the update is a valid one. Notify + // watchers only if this is a first time update or it is different + // from the one currently cached. + if cur, ok := c.ldsCache[name]; !ok || !proto.Equal(cur.Raw, uErr.Update.Raw) { + for wi := range s { + wi.newUpdate(uErr.Update) + } } // Sync cache. - c.logger.Debugf("LDS resource with name %v, value %+v added to cache", name, update) - c.ldsCache[name] = update - c.ldsMD[name] = metadata + c.logger.Debugf("LDS resource with name %v, value %+v added to cache", name, pretty.ToJSON(uErr)) + c.ldsCache[name] = uErr.Update + // Set status to ACK, and clear error state. The metadata might be a + // NACK metadata because some other resources in the same response + // are invalid. + mdCopy := metadata + mdCopy.Status = ServiceStatusACKed + mdCopy.ErrState = nil + if metadata.ErrState != nil { + mdCopy.Version = metadata.ErrState.Version + } + c.ldsMD[name] = mdCopy } } // Resources not in the new update were removed by the server, so delete @@ -133,39 +148,50 @@ func (c *clientImpl) NewListeners(updates map[string]ListenerUpdate, metadata Up // // A response can contain multiple resources. They will be parsed and put in a // map from resource name to the resource content. -func (c *clientImpl) NewRouteConfigs(updates map[string]RouteConfigUpdate, metadata UpdateMetadata) { +func (c *clientImpl) NewRouteConfigs(updates map[string]RouteConfigUpdateErrTuple, metadata UpdateMetadata) { c.mu.Lock() defer c.mu.Unlock() + // If no error received, the status is ACK. + c.rdsVersion = metadata.Version if metadata.ErrState != nil { - // On NACK, update overall version to the NACKed resp. c.rdsVersion = metadata.ErrState.Version - for name := range updates { - if _, ok := c.rdsWatchers[name]; ok { + } + for name, uErr := range updates { + if s, ok := c.rdsWatchers[name]; ok { + if uErr.Err != nil { // On error, keep previous version for each resource. But update // status and error. mdCopy := c.rdsMD[name] mdCopy.ErrState = metadata.ErrState mdCopy.Status = metadata.Status c.rdsMD[name] = mdCopy - // TODO: send the NACK error to the watcher. + for wi := range s { + wi.newError(uErr.Err) + } + continue } - } - return - } - - // If no error received, the status is ACK. - c.rdsVersion = metadata.Version - for name, update := range updates { - if s, ok := c.rdsWatchers[name]; ok { - // Only send the update if this is not an error. - for wi := range s { - wi.newUpdate(update) + // If we get here, it means that the update is a valid one. Notify + // watchers only if this is a first time update or it is different + // from the one currently cached. + if cur, ok := c.rdsCache[name]; !ok || !proto.Equal(cur.Raw, uErr.Update.Raw) { + for wi := range s { + wi.newUpdate(uErr.Update) + } } // Sync cache. - c.logger.Debugf("RDS resource with name %v, value %+v added to cache", name, update) - c.rdsCache[name] = update - c.rdsMD[name] = metadata + c.logger.Debugf("RDS resource with name %v, value %+v added to cache", name, pretty.ToJSON(uErr)) + c.rdsCache[name] = uErr.Update + // Set status to ACK, and clear error state. The metadata might be a + // NACK metadata because some other resources in the same response + // are invalid. + mdCopy := metadata + mdCopy.Status = ServiceStatusACKed + mdCopy.ErrState = nil + if metadata.ErrState != nil { + mdCopy.Version = metadata.ErrState.Version + } + c.rdsMD[name] = mdCopy } } } @@ -175,39 +201,51 @@ func (c *clientImpl) NewRouteConfigs(updates map[string]RouteConfigUpdate, metad // // A response can contain multiple resources. They will be parsed and put in a // map from resource name to the resource content. -func (c *clientImpl) NewClusters(updates map[string]ClusterUpdate, metadata UpdateMetadata) { +func (c *clientImpl) NewClusters(updates map[string]ClusterUpdateErrTuple, metadata UpdateMetadata) { c.mu.Lock() defer c.mu.Unlock() + c.cdsVersion = metadata.Version if metadata.ErrState != nil { - // On NACK, update overall version to the NACKed resp. c.cdsVersion = metadata.ErrState.Version - for name := range updates { - if _, ok := c.cdsWatchers[name]; ok { + } + for name, uErr := range updates { + if s, ok := c.cdsWatchers[name]; ok { + if uErr.Err != nil { // On error, keep previous version for each resource. But update // status and error. mdCopy := c.cdsMD[name] mdCopy.ErrState = metadata.ErrState mdCopy.Status = metadata.Status c.cdsMD[name] = mdCopy - // TODO: send the NACK error to the watcher. + for wi := range s { + // Send the watcher the individual error, instead of the + // overall combined error from the metadata.ErrState. + wi.newError(uErr.Err) + } + continue } - } - return - } - - // If no error received, the status is ACK. - c.cdsVersion = metadata.Version - for name, update := range updates { - if s, ok := c.cdsWatchers[name]; ok { - // Only send the update if this is not an error. - for wi := range s { - wi.newUpdate(update) + // If we get here, it means that the update is a valid one. Notify + // watchers only if this is a first time update or it is different + // from the one currently cached. + if cur, ok := c.cdsCache[name]; !ok || !proto.Equal(cur.Raw, uErr.Update.Raw) { + for wi := range s { + wi.newUpdate(uErr.Update) + } } // Sync cache. - c.logger.Debugf("CDS resource with name %v, value %+v added to cache", name, update) - c.cdsCache[name] = update - c.cdsMD[name] = metadata + c.logger.Debugf("CDS resource with name %v, value %+v added to cache", name, pretty.ToJSON(uErr)) + c.cdsCache[name] = uErr.Update + // Set status to ACK, and clear error state. The metadata might be a + // NACK metadata because some other resources in the same response + // are invalid. + mdCopy := metadata + mdCopy.Status = ServiceStatusACKed + mdCopy.ErrState = nil + if metadata.ErrState != nil { + mdCopy.Version = metadata.ErrState.Version + } + c.cdsMD[name] = mdCopy } } // Resources not in the new update were removed by the server, so delete @@ -234,39 +272,64 @@ func (c *clientImpl) NewClusters(updates map[string]ClusterUpdate, metadata Upda // // A response can contain multiple resources. They will be parsed and put in a // map from resource name to the resource content. -func (c *clientImpl) NewEndpoints(updates map[string]EndpointsUpdate, metadata UpdateMetadata) { +func (c *clientImpl) NewEndpoints(updates map[string]EndpointsUpdateErrTuple, metadata UpdateMetadata) { c.mu.Lock() defer c.mu.Unlock() + c.edsVersion = metadata.Version if metadata.ErrState != nil { - // On NACK, update overall version to the NACKed resp. c.edsVersion = metadata.ErrState.Version - for name := range updates { - if _, ok := c.edsWatchers[name]; ok { + } + for name, uErr := range updates { + if s, ok := c.edsWatchers[name]; ok { + if uErr.Err != nil { // On error, keep previous version for each resource. But update // status and error. mdCopy := c.edsMD[name] mdCopy.ErrState = metadata.ErrState mdCopy.Status = metadata.Status c.edsMD[name] = mdCopy - // TODO: send the NACK error to the watcher. + for wi := range s { + // Send the watcher the individual error, instead of the + // overall combined error from the metadata.ErrState. + wi.newError(uErr.Err) + } + continue + } + // If we get here, it means that the update is a valid one. Notify + // watchers only if this is a first time update or it is different + // from the one currently cached. + if cur, ok := c.edsCache[name]; !ok || !proto.Equal(cur.Raw, uErr.Update.Raw) { + for wi := range s { + wi.newUpdate(uErr.Update) + } } + // Sync cache. + c.logger.Debugf("EDS resource with name %v, value %+v added to cache", name, pretty.ToJSON(uErr)) + c.edsCache[name] = uErr.Update + // Set status to ACK, and clear error state. The metadata might be a + // NACK metadata because some other resources in the same response + // are invalid. + mdCopy := metadata + mdCopy.Status = ServiceStatusACKed + mdCopy.ErrState = nil + if metadata.ErrState != nil { + mdCopy.Version = metadata.ErrState.Version + } + c.edsMD[name] = mdCopy } - return } +} - // If no error received, the status is ACK. - c.edsVersion = metadata.Version - for name, update := range updates { - if s, ok := c.edsWatchers[name]; ok { - // Only send the update if this is not an error. - for wi := range s { - wi.newUpdate(update) - } - // Sync cache. - c.logger.Debugf("EDS resource with name %v, value %+v added to cache", name, update) - c.edsCache[name] = update - c.edsMD[name] = metadata +// NewConnectionError is called by the underlying xdsAPIClient when it receives +// a connection error. The error will be forwarded to all the resource watchers. +func (c *clientImpl) NewConnectionError(err error) { + c.mu.Lock() + defer c.mu.Unlock() + + for _, s := range c.edsWatchers { + for wi := range s { + wi.newError(NewErrorf(ErrorTypeConnection, "xds: error received from xDS stream: %v", err)) } } } diff --git a/xds/internal/xdsclient/cds_test.go b/xds/internal/xdsclient/cds_test.go new file mode 100644 index 00000000000..21e3b05b908 --- /dev/null +++ b/xds/internal/xdsclient/cds_test.go @@ -0,0 +1,1590 @@ +/* + * + * Copyright 2020 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package xdsclient + +import ( + "regexp" + "strings" + "testing" + + v2xdspb "github.com/envoyproxy/go-control-plane/envoy/api/v2" + v2corepb "github.com/envoyproxy/go-control-plane/envoy/api/v2/core" + v3clusterpb "github.com/envoyproxy/go-control-plane/envoy/config/cluster/v3" + v3corepb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" + v3endpointpb "github.com/envoyproxy/go-control-plane/envoy/config/endpoint/v3" + v3aggregateclusterpb "github.com/envoyproxy/go-control-plane/envoy/extensions/clusters/aggregate/v3" + v3tlspb "github.com/envoyproxy/go-control-plane/envoy/extensions/transport_sockets/tls/v3" + v3matcherpb "github.com/envoyproxy/go-control-plane/envoy/type/matcher/v3" + anypb "github.com/golang/protobuf/ptypes/any" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "google.golang.org/grpc/internal/testutils" + "google.golang.org/grpc/internal/xds/env" + "google.golang.org/grpc/internal/xds/matcher" + "google.golang.org/grpc/xds/internal/version" + "google.golang.org/protobuf/types/known/wrapperspb" +) + +const ( + clusterName = "clusterName" + serviceName = "service" +) + +var emptyUpdate = ClusterUpdate{ClusterName: clusterName, EnableLRS: false} + +func (s) TestValidateCluster_Failure(t *testing.T) { + tests := []struct { + name string + cluster *v3clusterpb.Cluster + wantUpdate ClusterUpdate + wantErr bool + }{ + { + name: "non-supported-cluster-type-static", + cluster: &v3clusterpb.Cluster{ + ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_STATIC}, + EdsClusterConfig: &v3clusterpb.Cluster_EdsClusterConfig{ + EdsConfig: &v3corepb.ConfigSource{ + ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{ + Ads: &v3corepb.AggregatedConfigSource{}, + }, + }, + }, + LbPolicy: v3clusterpb.Cluster_LEAST_REQUEST, + }, + wantUpdate: emptyUpdate, + wantErr: true, + }, + { + name: "non-supported-cluster-type-original-dst", + cluster: &v3clusterpb.Cluster{ + ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_ORIGINAL_DST}, + EdsClusterConfig: &v3clusterpb.Cluster_EdsClusterConfig{ + EdsConfig: &v3corepb.ConfigSource{ + ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{ + Ads: &v3corepb.AggregatedConfigSource{}, + }, + }, + }, + LbPolicy: v3clusterpb.Cluster_LEAST_REQUEST, + }, + wantUpdate: emptyUpdate, + wantErr: true, + }, + { + name: "no-eds-config", + cluster: &v3clusterpb.Cluster{ + ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_EDS}, + LbPolicy: v3clusterpb.Cluster_ROUND_ROBIN, + }, + wantUpdate: emptyUpdate, + wantErr: true, + }, + { + name: "no-ads-config-source", + cluster: &v3clusterpb.Cluster{ + ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_EDS}, + EdsClusterConfig: &v3clusterpb.Cluster_EdsClusterConfig{}, + LbPolicy: v3clusterpb.Cluster_ROUND_ROBIN, + }, + wantUpdate: emptyUpdate, + wantErr: true, + }, + { + name: "non-round-robin-or-ring-hash-lb-policy", + cluster: &v3clusterpb.Cluster{ + ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_EDS}, + EdsClusterConfig: &v3clusterpb.Cluster_EdsClusterConfig{ + EdsConfig: &v3corepb.ConfigSource{ + ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{ + Ads: &v3corepb.AggregatedConfigSource{}, + }, + }, + }, + LbPolicy: v3clusterpb.Cluster_LEAST_REQUEST, + }, + wantUpdate: emptyUpdate, + wantErr: true, + }, + { + name: "logical-dns-multiple-localities", + cluster: &v3clusterpb.Cluster{ + Name: clusterName, + ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_LOGICAL_DNS}, + LbPolicy: v3clusterpb.Cluster_ROUND_ROBIN, + LoadAssignment: &v3endpointpb.ClusterLoadAssignment{ + Endpoints: []*v3endpointpb.LocalityLbEndpoints{ + // Invalid if there are more than one locality. + {LbEndpoints: nil}, + {LbEndpoints: nil}, + }, + }, + }, + wantUpdate: emptyUpdate, + wantErr: true, + }, + { + name: "ring-hash-hash-function-not-xx-hash", + cluster: &v3clusterpb.Cluster{ + LbPolicy: v3clusterpb.Cluster_RING_HASH, + LbConfig: &v3clusterpb.Cluster_RingHashLbConfig_{ + RingHashLbConfig: &v3clusterpb.Cluster_RingHashLbConfig{ + HashFunction: v3clusterpb.Cluster_RingHashLbConfig_MURMUR_HASH_2, + }, + }, + }, + wantUpdate: emptyUpdate, + wantErr: true, + }, + { + name: "ring-hash-min-bound-greater-than-max", + cluster: &v3clusterpb.Cluster{ + LbPolicy: v3clusterpb.Cluster_RING_HASH, + LbConfig: &v3clusterpb.Cluster_RingHashLbConfig_{ + RingHashLbConfig: &v3clusterpb.Cluster_RingHashLbConfig{ + MinimumRingSize: wrapperspb.UInt64(100), + MaximumRingSize: wrapperspb.UInt64(10), + }, + }, + }, + wantUpdate: emptyUpdate, + wantErr: true, + }, + { + name: "ring-hash-min-bound-greater-than-upper-bound", + cluster: &v3clusterpb.Cluster{ + LbPolicy: v3clusterpb.Cluster_RING_HASH, + LbConfig: &v3clusterpb.Cluster_RingHashLbConfig_{ + RingHashLbConfig: &v3clusterpb.Cluster_RingHashLbConfig{ + MinimumRingSize: wrapperspb.UInt64(ringHashSizeUpperBound + 1), + }, + }, + }, + wantUpdate: emptyUpdate, + wantErr: true, + }, + { + name: "ring-hash-max-bound-greater-than-upper-bound", + cluster: &v3clusterpb.Cluster{ + LbPolicy: v3clusterpb.Cluster_RING_HASH, + LbConfig: &v3clusterpb.Cluster_RingHashLbConfig_{ + RingHashLbConfig: &v3clusterpb.Cluster_RingHashLbConfig{ + MaximumRingSize: wrapperspb.UInt64(ringHashSizeUpperBound + 1), + }, + }, + }, + wantUpdate: emptyUpdate, + wantErr: true, + }, + } + + oldAggregateAndDNSSupportEnv := env.AggregateAndDNSSupportEnv + env.AggregateAndDNSSupportEnv = true + defer func() { env.AggregateAndDNSSupportEnv = oldAggregateAndDNSSupportEnv }() + oldRingHashSupport := env.RingHashSupport + env.RingHashSupport = true + defer func() { env.RingHashSupport = oldRingHashSupport }() + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if update, err := validateClusterAndConstructClusterUpdate(test.cluster); err == nil { + t.Errorf("validateClusterAndConstructClusterUpdate(%+v) = %v, wanted error", test.cluster, update) + } + }) + } +} + +func (s) TestValidateCluster_Success(t *testing.T) { + tests := []struct { + name string + cluster *v3clusterpb.Cluster + wantUpdate ClusterUpdate + }{ + { + name: "happy-case-logical-dns", + cluster: &v3clusterpb.Cluster{ + Name: clusterName, + ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_LOGICAL_DNS}, + LbPolicy: v3clusterpb.Cluster_ROUND_ROBIN, + LoadAssignment: &v3endpointpb.ClusterLoadAssignment{ + Endpoints: []*v3endpointpb.LocalityLbEndpoints{{ + LbEndpoints: []*v3endpointpb.LbEndpoint{{ + HostIdentifier: &v3endpointpb.LbEndpoint_Endpoint{ + Endpoint: &v3endpointpb.Endpoint{ + Address: &v3corepb.Address{ + Address: &v3corepb.Address_SocketAddress{ + SocketAddress: &v3corepb.SocketAddress{ + Address: "dns_host", + PortSpecifier: &v3corepb.SocketAddress_PortValue{ + PortValue: 8080, + }, + }, + }, + }, + }, + }, + }}, + }}, + }, + }, + wantUpdate: ClusterUpdate{ + ClusterName: clusterName, + ClusterType: ClusterTypeLogicalDNS, + DNSHostName: "dns_host:8080", + }, + }, + { + name: "happy-case-aggregate-v3", + cluster: &v3clusterpb.Cluster{ + Name: clusterName, + ClusterDiscoveryType: &v3clusterpb.Cluster_ClusterType{ + ClusterType: &v3clusterpb.Cluster_CustomClusterType{ + Name: "envoy.clusters.aggregate", + TypedConfig: testutils.MarshalAny(&v3aggregateclusterpb.ClusterConfig{ + Clusters: []string{"a", "b", "c"}, + }), + }, + }, + LbPolicy: v3clusterpb.Cluster_ROUND_ROBIN, + }, + wantUpdate: ClusterUpdate{ + ClusterName: clusterName, EnableLRS: false, ClusterType: ClusterTypeAggregate, + PrioritizedClusterNames: []string{"a", "b", "c"}, + }, + }, + { + name: "happy-case-no-service-name-no-lrs", + cluster: &v3clusterpb.Cluster{ + Name: clusterName, + ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_EDS}, + EdsClusterConfig: &v3clusterpb.Cluster_EdsClusterConfig{ + EdsConfig: &v3corepb.ConfigSource{ + ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{ + Ads: &v3corepb.AggregatedConfigSource{}, + }, + }, + }, + LbPolicy: v3clusterpb.Cluster_ROUND_ROBIN, + }, + wantUpdate: emptyUpdate, + }, + { + name: "happy-case-no-lrs", + cluster: &v3clusterpb.Cluster{ + Name: clusterName, + ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_EDS}, + EdsClusterConfig: &v3clusterpb.Cluster_EdsClusterConfig{ + EdsConfig: &v3corepb.ConfigSource{ + ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{ + Ads: &v3corepb.AggregatedConfigSource{}, + }, + }, + ServiceName: serviceName, + }, + LbPolicy: v3clusterpb.Cluster_ROUND_ROBIN, + }, + wantUpdate: ClusterUpdate{ClusterName: clusterName, EDSServiceName: serviceName, EnableLRS: false}, + }, + { + name: "happiest-case", + cluster: &v3clusterpb.Cluster{ + Name: clusterName, + ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_EDS}, + EdsClusterConfig: &v3clusterpb.Cluster_EdsClusterConfig{ + EdsConfig: &v3corepb.ConfigSource{ + ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{ + Ads: &v3corepb.AggregatedConfigSource{}, + }, + }, + ServiceName: serviceName, + }, + LbPolicy: v3clusterpb.Cluster_ROUND_ROBIN, + LrsServer: &v3corepb.ConfigSource{ + ConfigSourceSpecifier: &v3corepb.ConfigSource_Self{ + Self: &v3corepb.SelfConfigSource{}, + }, + }, + }, + wantUpdate: ClusterUpdate{ClusterName: clusterName, EDSServiceName: serviceName, EnableLRS: true}, + }, + { + name: "happiest-case-with-circuitbreakers", + cluster: &v3clusterpb.Cluster{ + Name: clusterName, + ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_EDS}, + EdsClusterConfig: &v3clusterpb.Cluster_EdsClusterConfig{ + EdsConfig: &v3corepb.ConfigSource{ + ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{ + Ads: &v3corepb.AggregatedConfigSource{}, + }, + }, + ServiceName: serviceName, + }, + LbPolicy: v3clusterpb.Cluster_ROUND_ROBIN, + CircuitBreakers: &v3clusterpb.CircuitBreakers{ + Thresholds: []*v3clusterpb.CircuitBreakers_Thresholds{ + { + Priority: v3corepb.RoutingPriority_DEFAULT, + MaxRequests: wrapperspb.UInt32(512), + }, + { + Priority: v3corepb.RoutingPriority_HIGH, + MaxRequests: nil, + }, + }, + }, + LrsServer: &v3corepb.ConfigSource{ + ConfigSourceSpecifier: &v3corepb.ConfigSource_Self{ + Self: &v3corepb.SelfConfigSource{}, + }, + }, + }, + wantUpdate: ClusterUpdate{ClusterName: clusterName, EDSServiceName: serviceName, EnableLRS: true, MaxRequests: func() *uint32 { i := uint32(512); return &i }()}, + }, + { + name: "happiest-case-with-ring-hash-lb-policy-with-default-config", + cluster: &v3clusterpb.Cluster{ + Name: clusterName, + ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_EDS}, + EdsClusterConfig: &v3clusterpb.Cluster_EdsClusterConfig{ + EdsConfig: &v3corepb.ConfigSource{ + ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{ + Ads: &v3corepb.AggregatedConfigSource{}, + }, + }, + ServiceName: serviceName, + }, + LbPolicy: v3clusterpb.Cluster_RING_HASH, + LrsServer: &v3corepb.ConfigSource{ + ConfigSourceSpecifier: &v3corepb.ConfigSource_Self{ + Self: &v3corepb.SelfConfigSource{}, + }, + }, + }, + wantUpdate: ClusterUpdate{ + ClusterName: clusterName, EDSServiceName: serviceName, EnableLRS: true, + LBPolicy: &ClusterLBPolicyRingHash{MinimumRingSize: defaultRingHashMinSize, MaximumRingSize: defaultRingHashMaxSize}, + }, + }, + { + name: "happiest-case-with-ring-hash-lb-policy-with-none-default-config", + cluster: &v3clusterpb.Cluster{ + Name: clusterName, + ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_EDS}, + EdsClusterConfig: &v3clusterpb.Cluster_EdsClusterConfig{ + EdsConfig: &v3corepb.ConfigSource{ + ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{ + Ads: &v3corepb.AggregatedConfigSource{}, + }, + }, + ServiceName: serviceName, + }, + LbPolicy: v3clusterpb.Cluster_RING_HASH, + LbConfig: &v3clusterpb.Cluster_RingHashLbConfig_{ + RingHashLbConfig: &v3clusterpb.Cluster_RingHashLbConfig{ + MinimumRingSize: wrapperspb.UInt64(10), + MaximumRingSize: wrapperspb.UInt64(100), + }, + }, + LrsServer: &v3corepb.ConfigSource{ + ConfigSourceSpecifier: &v3corepb.ConfigSource_Self{ + Self: &v3corepb.SelfConfigSource{}, + }, + }, + }, + wantUpdate: ClusterUpdate{ + ClusterName: clusterName, EDSServiceName: serviceName, EnableLRS: true, + LBPolicy: &ClusterLBPolicyRingHash{MinimumRingSize: 10, MaximumRingSize: 100}, + }, + }, + } + + oldAggregateAndDNSSupportEnv := env.AggregateAndDNSSupportEnv + env.AggregateAndDNSSupportEnv = true + defer func() { env.AggregateAndDNSSupportEnv = oldAggregateAndDNSSupportEnv }() + oldRingHashSupport := env.RingHashSupport + env.RingHashSupport = true + defer func() { env.RingHashSupport = oldRingHashSupport }() + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + update, err := validateClusterAndConstructClusterUpdate(test.cluster) + if err != nil { + t.Errorf("validateClusterAndConstructClusterUpdate(%+v) failed: %v", test.cluster, err) + } + if diff := cmp.Diff(update, test.wantUpdate, cmpopts.EquateEmpty()); diff != "" { + t.Errorf("validateClusterAndConstructClusterUpdate(%+v) got diff: %v (-got, +want)", test.cluster, diff) + } + }) + } +} + +func (s) TestValidateClusterWithSecurityConfig_EnvVarOff(t *testing.T) { + // Turn off the env var protection for client-side security. + origClientSideSecurityEnvVar := env.ClientSideSecuritySupport + env.ClientSideSecuritySupport = false + defer func() { env.ClientSideSecuritySupport = origClientSideSecurityEnvVar }() + + cluster := &v3clusterpb.Cluster{ + Name: clusterName, + ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_EDS}, + EdsClusterConfig: &v3clusterpb.Cluster_EdsClusterConfig{ + EdsConfig: &v3corepb.ConfigSource{ + ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{ + Ads: &v3corepb.AggregatedConfigSource{}, + }, + }, + ServiceName: serviceName, + }, + LbPolicy: v3clusterpb.Cluster_ROUND_ROBIN, + TransportSocket: &v3corepb.TransportSocket{ + Name: "envoy.transport_sockets.tls", + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3tlspb.UpstreamTlsContext{ + CommonTlsContext: &v3tlspb.CommonTlsContext{ + ValidationContextType: &v3tlspb.CommonTlsContext_ValidationContextCertificateProviderInstance{ + ValidationContextCertificateProviderInstance: &v3tlspb.CommonTlsContext_CertificateProviderInstance{ + InstanceName: "rootInstance", + CertificateName: "rootCert", + }, + }, + }, + }), + }, + }, + } + wantUpdate := ClusterUpdate{ + ClusterName: clusterName, + EDSServiceName: serviceName, + EnableLRS: false, + } + gotUpdate, err := validateClusterAndConstructClusterUpdate(cluster) + if err != nil { + t.Errorf("validateClusterAndConstructClusterUpdate() failed: %v", err) + } + if diff := cmp.Diff(wantUpdate, gotUpdate); diff != "" { + t.Errorf("validateClusterAndConstructClusterUpdate() returned unexpected diff (-want, got):\n%s", diff) + } +} + +func (s) TestSecurityConfigFromCommonTLSContextUsingNewFields_ErrorCases(t *testing.T) { + tests := []struct { + name string + common *v3tlspb.CommonTlsContext + server bool + wantErr string + }{ + { + name: "unsupported-tls_certificates-field-for-identity-certs", + common: &v3tlspb.CommonTlsContext{ + TlsCertificates: []*v3tlspb.TlsCertificate{ + {CertificateChain: &v3corepb.DataSource{}}, + }, + }, + wantErr: "unsupported field tls_certificates is set in CommonTlsContext message", + }, + { + name: "unsupported-tls_certificates_sds_secret_configs-field-for-identity-certs", + common: &v3tlspb.CommonTlsContext{ + TlsCertificateSdsSecretConfigs: []*v3tlspb.SdsSecretConfig{ + {Name: "sds-secrets-config"}, + }, + }, + wantErr: "unsupported field tls_certificate_sds_secret_configs is set in CommonTlsContext message", + }, + { + name: "unsupported-sds-validation-context", + common: &v3tlspb.CommonTlsContext{ + ValidationContextType: &v3tlspb.CommonTlsContext_ValidationContextSdsSecretConfig{ + ValidationContextSdsSecretConfig: &v3tlspb.SdsSecretConfig{ + Name: "foo-sds-secret", + }, + }, + }, + wantErr: "validation context contains unexpected type", + }, + { + name: "missing-ca_certificate_provider_instance-in-validation-context", + common: &v3tlspb.CommonTlsContext{ + ValidationContextType: &v3tlspb.CommonTlsContext_ValidationContext{ + ValidationContext: &v3tlspb.CertificateValidationContext{}, + }, + }, + wantErr: "expected field ca_certificate_provider_instance is missing in CommonTlsContext message", + }, + { + name: "unsupported-field-verify_certificate_spki-in-validation-context", + common: &v3tlspb.CommonTlsContext{ + ValidationContextType: &v3tlspb.CommonTlsContext_ValidationContext{ + ValidationContext: &v3tlspb.CertificateValidationContext{ + CaCertificateProviderInstance: &v3tlspb.CertificateProviderPluginInstance{ + InstanceName: "rootPluginInstance", + CertificateName: "rootCertName", + }, + VerifyCertificateSpki: []string{"spki"}, + }, + }, + }, + wantErr: "unsupported verify_certificate_spki field in CommonTlsContext message", + }, + { + name: "unsupported-field-verify_certificate_hash-in-validation-context", + common: &v3tlspb.CommonTlsContext{ + ValidationContextType: &v3tlspb.CommonTlsContext_ValidationContext{ + ValidationContext: &v3tlspb.CertificateValidationContext{ + CaCertificateProviderInstance: &v3tlspb.CertificateProviderPluginInstance{ + InstanceName: "rootPluginInstance", + CertificateName: "rootCertName", + }, + VerifyCertificateHash: []string{"hash"}, + }, + }, + }, + wantErr: "unsupported verify_certificate_hash field in CommonTlsContext message", + }, + { + name: "unsupported-field-require_signed_certificate_timestamp-in-validation-context", + common: &v3tlspb.CommonTlsContext{ + ValidationContextType: &v3tlspb.CommonTlsContext_ValidationContext{ + ValidationContext: &v3tlspb.CertificateValidationContext{ + CaCertificateProviderInstance: &v3tlspb.CertificateProviderPluginInstance{ + InstanceName: "rootPluginInstance", + CertificateName: "rootCertName", + }, + RequireSignedCertificateTimestamp: &wrapperspb.BoolValue{Value: true}, + }, + }, + }, + wantErr: "unsupported require_sugned_ceritificate_timestamp field in CommonTlsContext message", + }, + { + name: "unsupported-field-crl-in-validation-context", + common: &v3tlspb.CommonTlsContext{ + ValidationContextType: &v3tlspb.CommonTlsContext_ValidationContext{ + ValidationContext: &v3tlspb.CertificateValidationContext{ + CaCertificateProviderInstance: &v3tlspb.CertificateProviderPluginInstance{ + InstanceName: "rootPluginInstance", + CertificateName: "rootCertName", + }, + Crl: &v3corepb.DataSource{}, + }, + }, + }, + wantErr: "unsupported crl field in CommonTlsContext message", + }, + { + name: "unsupported-field-custom_validator_config-in-validation-context", + common: &v3tlspb.CommonTlsContext{ + ValidationContextType: &v3tlspb.CommonTlsContext_ValidationContext{ + ValidationContext: &v3tlspb.CertificateValidationContext{ + CaCertificateProviderInstance: &v3tlspb.CertificateProviderPluginInstance{ + InstanceName: "rootPluginInstance", + CertificateName: "rootCertName", + }, + CustomValidatorConfig: &v3corepb.TypedExtensionConfig{}, + }, + }, + }, + wantErr: "unsupported custom_validator_config field in CommonTlsContext message", + }, + { + name: "invalid-match_subject_alt_names-field-in-validation-context", + common: &v3tlspb.CommonTlsContext{ + ValidationContextType: &v3tlspb.CommonTlsContext_ValidationContext{ + ValidationContext: &v3tlspb.CertificateValidationContext{ + CaCertificateProviderInstance: &v3tlspb.CertificateProviderPluginInstance{ + InstanceName: "rootPluginInstance", + CertificateName: "rootCertName", + }, + MatchSubjectAltNames: []*v3matcherpb.StringMatcher{ + {MatchPattern: &v3matcherpb.StringMatcher_Prefix{Prefix: ""}}, + }, + }, + }, + }, + wantErr: "empty prefix is not allowed in StringMatcher", + }, + { + name: "unsupported-field-matching-subject-alt-names-in-validation-context-of-server", + common: &v3tlspb.CommonTlsContext{ + ValidationContextType: &v3tlspb.CommonTlsContext_ValidationContext{ + ValidationContext: &v3tlspb.CertificateValidationContext{ + CaCertificateProviderInstance: &v3tlspb.CertificateProviderPluginInstance{ + InstanceName: "rootPluginInstance", + CertificateName: "rootCertName", + }, + MatchSubjectAltNames: []*v3matcherpb.StringMatcher{ + {MatchPattern: &v3matcherpb.StringMatcher_Prefix{Prefix: "sanPrefix"}}, + }, + }, + }, + }, + server: true, + wantErr: "match_subject_alt_names field in validation context is not supported on the server", + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + _, err := securityConfigFromCommonTLSContextUsingNewFields(test.common, test.server) + if err == nil { + t.Fatal("securityConfigFromCommonTLSContextUsingNewFields() succeeded when expected to fail") + } + if !strings.Contains(err.Error(), test.wantErr) { + t.Fatalf("securityConfigFromCommonTLSContextUsingNewFields() returned err: %v, wantErr: %v", err, test.wantErr) + } + }) + } +} + +func (s) TestValidateClusterWithSecurityConfig(t *testing.T) { + const ( + identityPluginInstance = "identityPluginInstance" + identityCertName = "identityCert" + rootPluginInstance = "rootPluginInstance" + rootCertName = "rootCert" + clusterName = "cluster" + serviceName = "service" + sanExact = "san-exact" + sanPrefix = "san-prefix" + sanSuffix = "san-suffix" + sanRegexBad = "??" + sanRegexGood = "san?regex?" + sanContains = "san-contains" + ) + var sanRE = regexp.MustCompile(sanRegexGood) + + tests := []struct { + name string + cluster *v3clusterpb.Cluster + wantUpdate ClusterUpdate + wantErr bool + }{ + { + name: "transport-socket-matches", + cluster: &v3clusterpb.Cluster{ + ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_EDS}, + EdsClusterConfig: &v3clusterpb.Cluster_EdsClusterConfig{ + EdsConfig: &v3corepb.ConfigSource{ + ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{ + Ads: &v3corepb.AggregatedConfigSource{}, + }, + }, + ServiceName: serviceName, + }, + LbPolicy: v3clusterpb.Cluster_ROUND_ROBIN, + TransportSocketMatches: []*v3clusterpb.Cluster_TransportSocketMatch{ + {Name: "transport-socket-match-1"}, + }, + }, + wantErr: true, + }, + { + name: "transport-socket-unsupported-name", + cluster: &v3clusterpb.Cluster{ + ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_EDS}, + EdsClusterConfig: &v3clusterpb.Cluster_EdsClusterConfig{ + EdsConfig: &v3corepb.ConfigSource{ + ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{ + Ads: &v3corepb.AggregatedConfigSource{}, + }, + }, + ServiceName: serviceName, + }, + LbPolicy: v3clusterpb.Cluster_ROUND_ROBIN, + TransportSocket: &v3corepb.TransportSocket{ + Name: "unsupported-foo", + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: &anypb.Any{ + TypeUrl: version.V3UpstreamTLSContextURL, + }, + }, + }, + }, + wantErr: true, + }, + { + name: "transport-socket-unsupported-typeURL", + cluster: &v3clusterpb.Cluster{ + ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_EDS}, + EdsClusterConfig: &v3clusterpb.Cluster_EdsClusterConfig{ + EdsConfig: &v3corepb.ConfigSource{ + ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{ + Ads: &v3corepb.AggregatedConfigSource{}, + }, + }, + ServiceName: serviceName, + }, + LbPolicy: v3clusterpb.Cluster_ROUND_ROBIN, + TransportSocket: &v3corepb.TransportSocket{ + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: &anypb.Any{ + TypeUrl: version.V3HTTPConnManagerURL, + }, + }, + }, + }, + wantErr: true, + }, + { + name: "transport-socket-unsupported-type", + cluster: &v3clusterpb.Cluster{ + ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_EDS}, + EdsClusterConfig: &v3clusterpb.Cluster_EdsClusterConfig{ + EdsConfig: &v3corepb.ConfigSource{ + ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{ + Ads: &v3corepb.AggregatedConfigSource{}, + }, + }, + ServiceName: serviceName, + }, + LbPolicy: v3clusterpb.Cluster_ROUND_ROBIN, + TransportSocket: &v3corepb.TransportSocket{ + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: &anypb.Any{ + TypeUrl: version.V3UpstreamTLSContextURL, + Value: []byte{1, 2, 3, 4}, + }, + }, + }, + }, + wantErr: true, + }, + { + name: "transport-socket-unsupported-tls-params-field", + cluster: &v3clusterpb.Cluster{ + ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_EDS}, + EdsClusterConfig: &v3clusterpb.Cluster_EdsClusterConfig{ + EdsConfig: &v3corepb.ConfigSource{ + ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{ + Ads: &v3corepb.AggregatedConfigSource{}, + }, + }, + ServiceName: serviceName, + }, + LbPolicy: v3clusterpb.Cluster_ROUND_ROBIN, + TransportSocket: &v3corepb.TransportSocket{ + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3tlspb.UpstreamTlsContext{ + CommonTlsContext: &v3tlspb.CommonTlsContext{ + TlsParams: &v3tlspb.TlsParameters{}, + }, + }), + }, + }, + }, + wantErr: true, + }, + { + name: "transport-socket-unsupported-custom-handshaker-field", + cluster: &v3clusterpb.Cluster{ + ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_EDS}, + EdsClusterConfig: &v3clusterpb.Cluster_EdsClusterConfig{ + EdsConfig: &v3corepb.ConfigSource{ + ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{ + Ads: &v3corepb.AggregatedConfigSource{}, + }, + }, + ServiceName: serviceName, + }, + LbPolicy: v3clusterpb.Cluster_ROUND_ROBIN, + TransportSocket: &v3corepb.TransportSocket{ + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3tlspb.UpstreamTlsContext{ + CommonTlsContext: &v3tlspb.CommonTlsContext{ + CustomHandshaker: &v3corepb.TypedExtensionConfig{}, + }, + }), + }, + }, + }, + wantErr: true, + }, + { + name: "transport-socket-unsupported-validation-context", + cluster: &v3clusterpb.Cluster{ + ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_EDS}, + EdsClusterConfig: &v3clusterpb.Cluster_EdsClusterConfig{ + EdsConfig: &v3corepb.ConfigSource{ + ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{ + Ads: &v3corepb.AggregatedConfigSource{}, + }, + }, + ServiceName: serviceName, + }, + LbPolicy: v3clusterpb.Cluster_ROUND_ROBIN, + TransportSocket: &v3corepb.TransportSocket{ + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3tlspb.UpstreamTlsContext{ + CommonTlsContext: &v3tlspb.CommonTlsContext{ + ValidationContextType: &v3tlspb.CommonTlsContext_ValidationContextSdsSecretConfig{ + ValidationContextSdsSecretConfig: &v3tlspb.SdsSecretConfig{ + Name: "foo-sds-secret", + }, + }, + }, + }), + }, + }, + }, + wantErr: true, + }, + { + name: "transport-socket-without-validation-context", + cluster: &v3clusterpb.Cluster{ + ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_EDS}, + EdsClusterConfig: &v3clusterpb.Cluster_EdsClusterConfig{ + EdsConfig: &v3corepb.ConfigSource{ + ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{ + Ads: &v3corepb.AggregatedConfigSource{}, + }, + }, + ServiceName: serviceName, + }, + LbPolicy: v3clusterpb.Cluster_ROUND_ROBIN, + TransportSocket: &v3corepb.TransportSocket{ + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3tlspb.UpstreamTlsContext{ + CommonTlsContext: &v3tlspb.CommonTlsContext{}, + }), + }, + }, + }, + wantErr: true, + }, + { + name: "empty-prefix-in-matching-SAN", + cluster: &v3clusterpb.Cluster{ + ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_EDS}, + EdsClusterConfig: &v3clusterpb.Cluster_EdsClusterConfig{ + EdsConfig: &v3corepb.ConfigSource{ + ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{ + Ads: &v3corepb.AggregatedConfigSource{}, + }, + }, + ServiceName: serviceName, + }, + LbPolicy: v3clusterpb.Cluster_ROUND_ROBIN, + TransportSocket: &v3corepb.TransportSocket{ + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3tlspb.UpstreamTlsContext{ + CommonTlsContext: &v3tlspb.CommonTlsContext{ + ValidationContextType: &v3tlspb.CommonTlsContext_CombinedValidationContext{ + CombinedValidationContext: &v3tlspb.CommonTlsContext_CombinedCertificateValidationContext{ + DefaultValidationContext: &v3tlspb.CertificateValidationContext{ + MatchSubjectAltNames: []*v3matcherpb.StringMatcher{ + {MatchPattern: &v3matcherpb.StringMatcher_Prefix{Prefix: ""}}, + }, + }, + ValidationContextCertificateProviderInstance: &v3tlspb.CommonTlsContext_CertificateProviderInstance{ + InstanceName: rootPluginInstance, + CertificateName: rootCertName, + }, + }, + }, + }, + }), + }, + }, + }, + wantErr: true, + }, + { + name: "empty-suffix-in-matching-SAN", + cluster: &v3clusterpb.Cluster{ + ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_EDS}, + EdsClusterConfig: &v3clusterpb.Cluster_EdsClusterConfig{ + EdsConfig: &v3corepb.ConfigSource{ + ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{ + Ads: &v3corepb.AggregatedConfigSource{}, + }, + }, + ServiceName: serviceName, + }, + LbPolicy: v3clusterpb.Cluster_ROUND_ROBIN, + TransportSocket: &v3corepb.TransportSocket{ + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3tlspb.UpstreamTlsContext{ + CommonTlsContext: &v3tlspb.CommonTlsContext{ + ValidationContextType: &v3tlspb.CommonTlsContext_CombinedValidationContext{ + CombinedValidationContext: &v3tlspb.CommonTlsContext_CombinedCertificateValidationContext{ + DefaultValidationContext: &v3tlspb.CertificateValidationContext{ + MatchSubjectAltNames: []*v3matcherpb.StringMatcher{ + {MatchPattern: &v3matcherpb.StringMatcher_Suffix{Suffix: ""}}, + }, + }, + ValidationContextCertificateProviderInstance: &v3tlspb.CommonTlsContext_CertificateProviderInstance{ + InstanceName: rootPluginInstance, + CertificateName: rootCertName, + }, + }, + }, + }, + }), + }, + }, + }, + wantErr: true, + }, + { + name: "empty-contains-in-matching-SAN", + cluster: &v3clusterpb.Cluster{ + ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_EDS}, + EdsClusterConfig: &v3clusterpb.Cluster_EdsClusterConfig{ + EdsConfig: &v3corepb.ConfigSource{ + ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{ + Ads: &v3corepb.AggregatedConfigSource{}, + }, + }, + ServiceName: serviceName, + }, + LbPolicy: v3clusterpb.Cluster_ROUND_ROBIN, + TransportSocket: &v3corepb.TransportSocket{ + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3tlspb.UpstreamTlsContext{ + CommonTlsContext: &v3tlspb.CommonTlsContext{ + ValidationContextType: &v3tlspb.CommonTlsContext_CombinedValidationContext{ + CombinedValidationContext: &v3tlspb.CommonTlsContext_CombinedCertificateValidationContext{ + DefaultValidationContext: &v3tlspb.CertificateValidationContext{ + MatchSubjectAltNames: []*v3matcherpb.StringMatcher{ + {MatchPattern: &v3matcherpb.StringMatcher_Contains{Contains: ""}}, + }, + }, + ValidationContextCertificateProviderInstance: &v3tlspb.CommonTlsContext_CertificateProviderInstance{ + InstanceName: rootPluginInstance, + CertificateName: rootCertName, + }, + }, + }, + }, + }), + }, + }, + }, + wantErr: true, + }, + { + name: "invalid-regex-in-matching-SAN", + cluster: &v3clusterpb.Cluster{ + ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_EDS}, + EdsClusterConfig: &v3clusterpb.Cluster_EdsClusterConfig{ + EdsConfig: &v3corepb.ConfigSource{ + ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{ + Ads: &v3corepb.AggregatedConfigSource{}, + }, + }, + ServiceName: serviceName, + }, + LbPolicy: v3clusterpb.Cluster_ROUND_ROBIN, + TransportSocket: &v3corepb.TransportSocket{ + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3tlspb.UpstreamTlsContext{ + CommonTlsContext: &v3tlspb.CommonTlsContext{ + ValidationContextType: &v3tlspb.CommonTlsContext_CombinedValidationContext{ + CombinedValidationContext: &v3tlspb.CommonTlsContext_CombinedCertificateValidationContext{ + DefaultValidationContext: &v3tlspb.CertificateValidationContext{ + MatchSubjectAltNames: []*v3matcherpb.StringMatcher{ + {MatchPattern: &v3matcherpb.StringMatcher_SafeRegex{SafeRegex: &v3matcherpb.RegexMatcher{Regex: sanRegexBad}}}, + }, + }, + ValidationContextCertificateProviderInstance: &v3tlspb.CommonTlsContext_CertificateProviderInstance{ + InstanceName: rootPluginInstance, + CertificateName: rootCertName, + }, + }, + }, + }, + }), + }, + }, + }, + wantErr: true, + }, + { + name: "invalid-regex-in-matching-SAN-with-new-fields", + cluster: &v3clusterpb.Cluster{ + ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_EDS}, + EdsClusterConfig: &v3clusterpb.Cluster_EdsClusterConfig{ + EdsConfig: &v3corepb.ConfigSource{ + ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{ + Ads: &v3corepb.AggregatedConfigSource{}, + }, + }, + ServiceName: serviceName, + }, + LbPolicy: v3clusterpb.Cluster_ROUND_ROBIN, + TransportSocket: &v3corepb.TransportSocket{ + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3tlspb.UpstreamTlsContext{ + CommonTlsContext: &v3tlspb.CommonTlsContext{ + ValidationContextType: &v3tlspb.CommonTlsContext_CombinedValidationContext{ + CombinedValidationContext: &v3tlspb.CommonTlsContext_CombinedCertificateValidationContext{ + DefaultValidationContext: &v3tlspb.CertificateValidationContext{ + MatchSubjectAltNames: []*v3matcherpb.StringMatcher{ + {MatchPattern: &v3matcherpb.StringMatcher_SafeRegex{SafeRegex: &v3matcherpb.RegexMatcher{Regex: sanRegexBad}}}, + }, + CaCertificateProviderInstance: &v3tlspb.CertificateProviderPluginInstance{ + InstanceName: rootPluginInstance, + CertificateName: rootCertName, + }, + }, + }, + }, + }, + }), + }, + }, + }, + wantErr: true, + }, + { + name: "happy-case-with-no-identity-certs-using-deprecated-fields", + cluster: &v3clusterpb.Cluster{ + Name: clusterName, + ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_EDS}, + EdsClusterConfig: &v3clusterpb.Cluster_EdsClusterConfig{ + EdsConfig: &v3corepb.ConfigSource{ + ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{ + Ads: &v3corepb.AggregatedConfigSource{}, + }, + }, + ServiceName: serviceName, + }, + LbPolicy: v3clusterpb.Cluster_ROUND_ROBIN, + TransportSocket: &v3corepb.TransportSocket{ + Name: "envoy.transport_sockets.tls", + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3tlspb.UpstreamTlsContext{ + CommonTlsContext: &v3tlspb.CommonTlsContext{ + ValidationContextType: &v3tlspb.CommonTlsContext_ValidationContextCertificateProviderInstance{ + ValidationContextCertificateProviderInstance: &v3tlspb.CommonTlsContext_CertificateProviderInstance{ + InstanceName: rootPluginInstance, + CertificateName: rootCertName, + }, + }, + }, + }), + }, + }, + }, + wantUpdate: ClusterUpdate{ + ClusterName: clusterName, + EDSServiceName: serviceName, + EnableLRS: false, + SecurityCfg: &SecurityConfig{ + RootInstanceName: rootPluginInstance, + RootCertName: rootCertName, + }, + }, + }, + { + name: "happy-case-with-no-identity-certs-using-new-fields", + cluster: &v3clusterpb.Cluster{ + Name: clusterName, + ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_EDS}, + EdsClusterConfig: &v3clusterpb.Cluster_EdsClusterConfig{ + EdsConfig: &v3corepb.ConfigSource{ + ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{ + Ads: &v3corepb.AggregatedConfigSource{}, + }, + }, + ServiceName: serviceName, + }, + LbPolicy: v3clusterpb.Cluster_ROUND_ROBIN, + TransportSocket: &v3corepb.TransportSocket{ + Name: "envoy.transport_sockets.tls", + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3tlspb.UpstreamTlsContext{ + CommonTlsContext: &v3tlspb.CommonTlsContext{ + ValidationContextType: &v3tlspb.CommonTlsContext_ValidationContext{ + ValidationContext: &v3tlspb.CertificateValidationContext{ + CaCertificateProviderInstance: &v3tlspb.CertificateProviderPluginInstance{ + InstanceName: rootPluginInstance, + CertificateName: rootCertName, + }, + }, + }, + }, + }), + }, + }, + }, + wantUpdate: ClusterUpdate{ + ClusterName: clusterName, + EDSServiceName: serviceName, + EnableLRS: false, + SecurityCfg: &SecurityConfig{ + RootInstanceName: rootPluginInstance, + RootCertName: rootCertName, + }, + }, + }, + { + name: "happy-case-with-validation-context-provider-instance-using-deprecated-fields", + cluster: &v3clusterpb.Cluster{ + Name: clusterName, + ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_EDS}, + EdsClusterConfig: &v3clusterpb.Cluster_EdsClusterConfig{ + EdsConfig: &v3corepb.ConfigSource{ + ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{ + Ads: &v3corepb.AggregatedConfigSource{}, + }, + }, + ServiceName: serviceName, + }, + LbPolicy: v3clusterpb.Cluster_ROUND_ROBIN, + TransportSocket: &v3corepb.TransportSocket{ + Name: "envoy.transport_sockets.tls", + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3tlspb.UpstreamTlsContext{ + CommonTlsContext: &v3tlspb.CommonTlsContext{ + TlsCertificateCertificateProviderInstance: &v3tlspb.CommonTlsContext_CertificateProviderInstance{ + InstanceName: identityPluginInstance, + CertificateName: identityCertName, + }, + ValidationContextType: &v3tlspb.CommonTlsContext_ValidationContextCertificateProviderInstance{ + ValidationContextCertificateProviderInstance: &v3tlspb.CommonTlsContext_CertificateProviderInstance{ + InstanceName: rootPluginInstance, + CertificateName: rootCertName, + }, + }, + }, + }), + }, + }, + }, + wantUpdate: ClusterUpdate{ + ClusterName: clusterName, + EDSServiceName: serviceName, + EnableLRS: false, + SecurityCfg: &SecurityConfig{ + RootInstanceName: rootPluginInstance, + RootCertName: rootCertName, + IdentityInstanceName: identityPluginInstance, + IdentityCertName: identityCertName, + }, + }, + }, + { + name: "happy-case-with-validation-context-provider-instance-using-new-fields", + cluster: &v3clusterpb.Cluster{ + Name: clusterName, + ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_EDS}, + EdsClusterConfig: &v3clusterpb.Cluster_EdsClusterConfig{ + EdsConfig: &v3corepb.ConfigSource{ + ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{ + Ads: &v3corepb.AggregatedConfigSource{}, + }, + }, + ServiceName: serviceName, + }, + LbPolicy: v3clusterpb.Cluster_ROUND_ROBIN, + TransportSocket: &v3corepb.TransportSocket{ + Name: "envoy.transport_sockets.tls", + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3tlspb.UpstreamTlsContext{ + CommonTlsContext: &v3tlspb.CommonTlsContext{ + TlsCertificateProviderInstance: &v3tlspb.CertificateProviderPluginInstance{ + InstanceName: identityPluginInstance, + CertificateName: identityCertName, + }, + ValidationContextType: &v3tlspb.CommonTlsContext_ValidationContext{ + ValidationContext: &v3tlspb.CertificateValidationContext{ + CaCertificateProviderInstance: &v3tlspb.CertificateProviderPluginInstance{ + InstanceName: rootPluginInstance, + CertificateName: rootCertName, + }, + }, + }, + }, + }), + }, + }, + }, + wantUpdate: ClusterUpdate{ + ClusterName: clusterName, + EDSServiceName: serviceName, + EnableLRS: false, + SecurityCfg: &SecurityConfig{ + RootInstanceName: rootPluginInstance, + RootCertName: rootCertName, + IdentityInstanceName: identityPluginInstance, + IdentityCertName: identityCertName, + }, + }, + }, + { + name: "happy-case-with-combined-validation-context-using-deprecated-fields", + cluster: &v3clusterpb.Cluster{ + Name: clusterName, + ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_EDS}, + EdsClusterConfig: &v3clusterpb.Cluster_EdsClusterConfig{ + EdsConfig: &v3corepb.ConfigSource{ + ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{ + Ads: &v3corepb.AggregatedConfigSource{}, + }, + }, + ServiceName: serviceName, + }, + LbPolicy: v3clusterpb.Cluster_ROUND_ROBIN, + TransportSocket: &v3corepb.TransportSocket{ + Name: "envoy.transport_sockets.tls", + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3tlspb.UpstreamTlsContext{ + CommonTlsContext: &v3tlspb.CommonTlsContext{ + TlsCertificateCertificateProviderInstance: &v3tlspb.CommonTlsContext_CertificateProviderInstance{ + InstanceName: identityPluginInstance, + CertificateName: identityCertName, + }, + ValidationContextType: &v3tlspb.CommonTlsContext_CombinedValidationContext{ + CombinedValidationContext: &v3tlspb.CommonTlsContext_CombinedCertificateValidationContext{ + DefaultValidationContext: &v3tlspb.CertificateValidationContext{ + MatchSubjectAltNames: []*v3matcherpb.StringMatcher{ + { + MatchPattern: &v3matcherpb.StringMatcher_Exact{Exact: sanExact}, + IgnoreCase: true, + }, + {MatchPattern: &v3matcherpb.StringMatcher_Prefix{Prefix: sanPrefix}}, + {MatchPattern: &v3matcherpb.StringMatcher_Suffix{Suffix: sanSuffix}}, + {MatchPattern: &v3matcherpb.StringMatcher_SafeRegex{SafeRegex: &v3matcherpb.RegexMatcher{Regex: sanRegexGood}}}, + {MatchPattern: &v3matcherpb.StringMatcher_Contains{Contains: sanContains}}, + }, + }, + ValidationContextCertificateProviderInstance: &v3tlspb.CommonTlsContext_CertificateProviderInstance{ + InstanceName: rootPluginInstance, + CertificateName: rootCertName, + }, + }, + }, + }, + }), + }, + }, + }, + wantUpdate: ClusterUpdate{ + ClusterName: clusterName, + EDSServiceName: serviceName, + EnableLRS: false, + SecurityCfg: &SecurityConfig{ + RootInstanceName: rootPluginInstance, + RootCertName: rootCertName, + IdentityInstanceName: identityPluginInstance, + IdentityCertName: identityCertName, + SubjectAltNameMatchers: []matcher.StringMatcher{ + matcher.StringMatcherForTesting(newStringP(sanExact), nil, nil, nil, nil, true), + matcher.StringMatcherForTesting(nil, newStringP(sanPrefix), nil, nil, nil, false), + matcher.StringMatcherForTesting(nil, nil, newStringP(sanSuffix), nil, nil, false), + matcher.StringMatcherForTesting(nil, nil, nil, nil, sanRE, false), + matcher.StringMatcherForTesting(nil, nil, nil, newStringP(sanContains), nil, false), + }, + }, + }, + }, + { + name: "happy-case-with-combined-validation-context-using-new-fields", + cluster: &v3clusterpb.Cluster{ + Name: clusterName, + ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_EDS}, + EdsClusterConfig: &v3clusterpb.Cluster_EdsClusterConfig{ + EdsConfig: &v3corepb.ConfigSource{ + ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{ + Ads: &v3corepb.AggregatedConfigSource{}, + }, + }, + ServiceName: serviceName, + }, + LbPolicy: v3clusterpb.Cluster_ROUND_ROBIN, + TransportSocket: &v3corepb.TransportSocket{ + Name: "envoy.transport_sockets.tls", + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3tlspb.UpstreamTlsContext{ + CommonTlsContext: &v3tlspb.CommonTlsContext{ + TlsCertificateProviderInstance: &v3tlspb.CertificateProviderPluginInstance{ + InstanceName: identityPluginInstance, + CertificateName: identityCertName, + }, + ValidationContextType: &v3tlspb.CommonTlsContext_CombinedValidationContext{ + CombinedValidationContext: &v3tlspb.CommonTlsContext_CombinedCertificateValidationContext{ + DefaultValidationContext: &v3tlspb.CertificateValidationContext{ + MatchSubjectAltNames: []*v3matcherpb.StringMatcher{ + { + MatchPattern: &v3matcherpb.StringMatcher_Exact{Exact: sanExact}, + IgnoreCase: true, + }, + {MatchPattern: &v3matcherpb.StringMatcher_Prefix{Prefix: sanPrefix}}, + {MatchPattern: &v3matcherpb.StringMatcher_Suffix{Suffix: sanSuffix}}, + {MatchPattern: &v3matcherpb.StringMatcher_SafeRegex{SafeRegex: &v3matcherpb.RegexMatcher{Regex: sanRegexGood}}}, + {MatchPattern: &v3matcherpb.StringMatcher_Contains{Contains: sanContains}}, + }, + CaCertificateProviderInstance: &v3tlspb.CertificateProviderPluginInstance{ + InstanceName: rootPluginInstance, + CertificateName: rootCertName, + }, + }, + }, + }, + }, + }), + }, + }, + }, + wantUpdate: ClusterUpdate{ + ClusterName: clusterName, + EDSServiceName: serviceName, + EnableLRS: false, + SecurityCfg: &SecurityConfig{ + RootInstanceName: rootPluginInstance, + RootCertName: rootCertName, + IdentityInstanceName: identityPluginInstance, + IdentityCertName: identityCertName, + SubjectAltNameMatchers: []matcher.StringMatcher{ + matcher.StringMatcherForTesting(newStringP(sanExact), nil, nil, nil, nil, true), + matcher.StringMatcherForTesting(nil, newStringP(sanPrefix), nil, nil, nil, false), + matcher.StringMatcherForTesting(nil, nil, newStringP(sanSuffix), nil, nil, false), + matcher.StringMatcherForTesting(nil, nil, nil, nil, sanRE, false), + matcher.StringMatcherForTesting(nil, nil, nil, newStringP(sanContains), nil, false), + }, + }, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + update, err := validateClusterAndConstructClusterUpdate(test.cluster) + if (err != nil) != test.wantErr { + t.Errorf("validateClusterAndConstructClusterUpdate() returned err %v wantErr %v)", err, test.wantErr) + } + if diff := cmp.Diff(test.wantUpdate, update, cmpopts.EquateEmpty(), cmp.AllowUnexported(regexp.Regexp{})); diff != "" { + t.Errorf("validateClusterAndConstructClusterUpdate() returned unexpected diff (-want, +got):\n%s", diff) + } + }) + } +} + +func (s) TestUnmarshalCluster(t *testing.T) { + const ( + v2ClusterName = "v2clusterName" + v3ClusterName = "v3clusterName" + v2Service = "v2Service" + v3Service = "v2Service" + ) + var ( + v2ClusterAny = testutils.MarshalAny(&v2xdspb.Cluster{ + Name: v2ClusterName, + ClusterDiscoveryType: &v2xdspb.Cluster_Type{Type: v2xdspb.Cluster_EDS}, + EdsClusterConfig: &v2xdspb.Cluster_EdsClusterConfig{ + EdsConfig: &v2corepb.ConfigSource{ + ConfigSourceSpecifier: &v2corepb.ConfigSource_Ads{ + Ads: &v2corepb.AggregatedConfigSource{}, + }, + }, + ServiceName: v2Service, + }, + LbPolicy: v2xdspb.Cluster_ROUND_ROBIN, + LrsServer: &v2corepb.ConfigSource{ + ConfigSourceSpecifier: &v2corepb.ConfigSource_Self{ + Self: &v2corepb.SelfConfigSource{}, + }, + }, + }) + + v3ClusterAny = testutils.MarshalAny(&v3clusterpb.Cluster{ + Name: v3ClusterName, + ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_EDS}, + EdsClusterConfig: &v3clusterpb.Cluster_EdsClusterConfig{ + EdsConfig: &v3corepb.ConfigSource{ + ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{ + Ads: &v3corepb.AggregatedConfigSource{}, + }, + }, + ServiceName: v3Service, + }, + LbPolicy: v3clusterpb.Cluster_ROUND_ROBIN, + LrsServer: &v3corepb.ConfigSource{ + ConfigSourceSpecifier: &v3corepb.ConfigSource_Self{ + Self: &v3corepb.SelfConfigSource{}, + }, + }, + }) + ) + const testVersion = "test-version-cds" + + tests := []struct { + name string + resources []*anypb.Any + wantUpdate map[string]ClusterUpdateErrTuple + wantMD UpdateMetadata + wantErr bool + }{ + { + name: "non-cluster resource type", + resources: []*anypb.Any{{TypeUrl: version.V3HTTPConnManagerURL}}, + wantMD: UpdateMetadata{ + Status: ServiceStatusNACKed, + Version: testVersion, + ErrState: &UpdateErrorMetadata{ + Version: testVersion, + Err: cmpopts.AnyError, + }, + }, + wantErr: true, + }, + { + name: "badly marshaled cluster resource", + resources: []*anypb.Any{ + { + TypeUrl: version.V3ClusterURL, + Value: []byte{1, 2, 3, 4}, + }, + }, + wantMD: UpdateMetadata{ + Status: ServiceStatusNACKed, + Version: testVersion, + ErrState: &UpdateErrorMetadata{ + Version: testVersion, + Err: cmpopts.AnyError, + }, + }, + wantErr: true, + }, + { + name: "bad cluster resource", + resources: []*anypb.Any{ + testutils.MarshalAny(&v3clusterpb.Cluster{ + Name: "test", + ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_STATIC}, + }), + }, + wantUpdate: map[string]ClusterUpdateErrTuple{ + "test": {Err: cmpopts.AnyError}, + }, + wantMD: UpdateMetadata{ + Status: ServiceStatusNACKed, + Version: testVersion, + ErrState: &UpdateErrorMetadata{ + Version: testVersion, + Err: cmpopts.AnyError, + }, + }, + wantErr: true, + }, + { + name: "v2 cluster", + resources: []*anypb.Any{v2ClusterAny}, + wantUpdate: map[string]ClusterUpdateErrTuple{ + v2ClusterName: {Update: ClusterUpdate{ + ClusterName: v2ClusterName, + EDSServiceName: v2Service, EnableLRS: true, + Raw: v2ClusterAny, + }}, + }, + wantMD: UpdateMetadata{ + Status: ServiceStatusACKed, + Version: testVersion, + }, + }, + { + name: "v3 cluster", + resources: []*anypb.Any{v3ClusterAny}, + wantUpdate: map[string]ClusterUpdateErrTuple{ + v3ClusterName: {Update: ClusterUpdate{ + ClusterName: v3ClusterName, + EDSServiceName: v3Service, EnableLRS: true, + Raw: v3ClusterAny, + }}, + }, + wantMD: UpdateMetadata{ + Status: ServiceStatusACKed, + Version: testVersion, + }, + }, + { + name: "multiple clusters", + resources: []*anypb.Any{v2ClusterAny, v3ClusterAny}, + wantUpdate: map[string]ClusterUpdateErrTuple{ + v2ClusterName: {Update: ClusterUpdate{ + ClusterName: v2ClusterName, + EDSServiceName: v2Service, EnableLRS: true, + Raw: v2ClusterAny, + }}, + v3ClusterName: {Update: ClusterUpdate{ + ClusterName: v3ClusterName, + EDSServiceName: v3Service, EnableLRS: true, + Raw: v3ClusterAny, + }}, + }, + wantMD: UpdateMetadata{ + Status: ServiceStatusACKed, + Version: testVersion, + }, + }, + { + // To test that unmarshal keeps processing on errors. + name: "good and bad clusters", + resources: []*anypb.Any{ + v2ClusterAny, + // bad cluster resource + testutils.MarshalAny(&v3clusterpb.Cluster{ + Name: "bad", + ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_STATIC}, + }), + v3ClusterAny, + }, + wantUpdate: map[string]ClusterUpdateErrTuple{ + v2ClusterName: {Update: ClusterUpdate{ + ClusterName: v2ClusterName, + EDSServiceName: v2Service, EnableLRS: true, + Raw: v2ClusterAny, + }}, + v3ClusterName: {Update: ClusterUpdate{ + ClusterName: v3ClusterName, + EDSServiceName: v3Service, EnableLRS: true, + Raw: v3ClusterAny, + }}, + "bad": {Err: cmpopts.AnyError}, + }, + wantMD: UpdateMetadata{ + Status: ServiceStatusNACKed, + Version: testVersion, + ErrState: &UpdateErrorMetadata{ + Version: testVersion, + Err: cmpopts.AnyError, + }, + }, + wantErr: true, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + opts := &UnmarshalOptions{ + Version: testVersion, + Resources: test.resources, + } + update, md, err := UnmarshalCluster(opts) + if (err != nil) != test.wantErr { + t.Fatalf("UnmarshalCluster(%+v), got err: %v, wantErr: %v", opts, err, test.wantErr) + } + if diff := cmp.Diff(update, test.wantUpdate, cmpOpts); diff != "" { + t.Errorf("got unexpected update, diff (-got +want): %v", diff) + } + if diff := cmp.Diff(md, test.wantMD, cmpOptsIgnoreDetails); diff != "" { + t.Errorf("got unexpected metadata, diff (-got +want): %v", diff) + } + }) + } +} diff --git a/xds/internal/client/client.go b/xds/internal/xdsclient/client.go similarity index 68% rename from xds/internal/client/client.go rename to xds/internal/xdsclient/client.go index 21881cc6eae..3230c66c06e 100644 --- a/xds/internal/client/client.go +++ b/xds/internal/xdsclient/client.go @@ -16,14 +16,15 @@ * */ -// Package client implements a full fledged gRPC client for the xDS API used by -// the xds resolver and balancer implementations. -package client +// Package xdsclient implements a full fledged gRPC client for the xDS API used +// by the xds resolver and balancer implementations. +package xdsclient import ( "context" "errors" "fmt" + "regexp" "sync" "time" @@ -32,8 +33,10 @@ import ( "github.com/golang/protobuf/proto" "google.golang.org/protobuf/types/known/anypb" - "google.golang.org/grpc/xds/internal/client/load" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/internal/xds/matcher" "google.golang.org/grpc/xds/internal/httpfilter" + "google.golang.org/grpc/xds/internal/xdsclient/load" "google.golang.org/grpc" "google.golang.org/grpc/internal/backoff" @@ -42,8 +45,8 @@ import ( "google.golang.org/grpc/internal/grpcsync" "google.golang.org/grpc/keepalive" "google.golang.org/grpc/xds/internal" - "google.golang.org/grpc/xds/internal/client/bootstrap" "google.golang.org/grpc/xds/internal/version" + "google.golang.org/grpc/xds/internal/xdsclient/bootstrap" ) var ( @@ -69,12 +72,20 @@ func getAPIClientBuilder(version version.TransportAPI) APIClientBuilder { return nil } +// UpdateValidatorFunc performs validations on update structs using +// context/logic available at the xdsClient layer. Since these validation are +// performed on internal update structs, they can be shared between different +// API clients. +type UpdateValidatorFunc func(interface{}) error + // BuildOptions contains options to be passed to client builders. type BuildOptions struct { // Parent is a top-level xDS client which has the intelligence to take // appropriate action based on xDS responses received from the management // server. Parent UpdateHandler + // Validator performs post unmarshal validation checks. + Validator UpdateValidatorFunc // NodeProto contains the Node proto to be used in xDS requests. The actual // type depends on the transport protocol version used. NodeProto proto.Message @@ -131,14 +142,17 @@ type loadReportingOptions struct { // resource updates from an APIClient for a specific version. type UpdateHandler interface { // NewListeners handles updates to xDS listener resources. - NewListeners(map[string]ListenerUpdate, UpdateMetadata) + NewListeners(map[string]ListenerUpdateErrTuple, UpdateMetadata) // NewRouteConfigs handles updates to xDS RouteConfiguration resources. - NewRouteConfigs(map[string]RouteConfigUpdate, UpdateMetadata) + NewRouteConfigs(map[string]RouteConfigUpdateErrTuple, UpdateMetadata) // NewClusters handles updates to xDS Cluster resources. - NewClusters(map[string]ClusterUpdate, UpdateMetadata) + NewClusters(map[string]ClusterUpdateErrTuple, UpdateMetadata) // NewEndpoints handles updates to xDS ClusterLoadAssignment (or tersely // referred to as Endpoints) resources. - NewEndpoints(map[string]EndpointsUpdate, UpdateMetadata) + NewEndpoints(map[string]EndpointsUpdateErrTuple, UpdateMetadata) + // NewConnectionError handles connection errors from the xDS stream. The + // error will be reported to all the resource watchers. + NewConnectionError(err error) } // ServiceStatus is the status of the update. @@ -193,9 +207,15 @@ type UpdateMetadata struct { type ListenerUpdate struct { // RouteConfigName is the route configuration name corresponding to the // target which is being watched through LDS. + // + // Only one of RouteConfigName and InlineRouteConfig is set. RouteConfigName string - // SecurityCfg contains security configuration sent by the control plane. - SecurityCfg *SecurityConfig + // InlineRouteConfig is the inline route configuration (RDS response) + // returned inside LDS. + // + // Only one of RouteConfigName and InlineRouteConfig is set. + InlineRouteConfig *RouteConfigUpdate + // MaxStreamDuration contains the HTTP connection manager's // common_http_protocol_options.max_stream_duration field, or zero if // unset. @@ -203,6 +223,8 @@ type ListenerUpdate struct { // HTTPFilters is a list of HTTP filters (name, config) from the LDS // response. HTTPFilters []HTTPFilter + // InboundListenerCfg contains inbound listener configuration. + InboundListenerCfg *InboundListenerConfig // Raw is the resource from the xds response. Raw *anypb.Any @@ -221,15 +243,23 @@ type HTTPFilter struct { Config httpfilter.FilterConfig } -func (lu *ListenerUpdate) String() string { - return fmt.Sprintf("{RouteConfigName: %q, SecurityConfig: %+v", lu.RouteConfigName, lu.SecurityCfg) +// InboundListenerConfig contains information about the inbound listener, i.e +// the server-side listener. +type InboundListenerConfig struct { + // Address is the local address on which the inbound listener is expected to + // accept incoming connections. + Address string + // Port is the local port on which the inbound listener is expected to + // accept incoming connections. + Port string + // FilterChains is the list of filter chains associated with this listener. + FilterChains *FilterChainManager } // RouteConfigUpdate contains information received in an RDS response, which is // of interest to the registered RDS watcher. type RouteConfigUpdate struct { VirtualHosts []*VirtualHost - // Raw is the resource from the xds response. Raw *anypb.Any } @@ -248,18 +278,81 @@ type VirtualHost struct { // may be unused if the matching Route contains an override for that // filter. HTTPFilterConfigOverride map[string]httpfilter.FilterConfig + RetryConfig *RetryConfig +} + +// RetryConfig contains all retry-related configuration in either a VirtualHost +// or Route. +type RetryConfig struct { + // RetryOn is a set of status codes on which to retry. Only Canceled, + // DeadlineExceeded, Internal, ResourceExhausted, and Unavailable are + // supported; any other values will be omitted. + RetryOn map[codes.Code]bool + NumRetries uint32 // maximum number of retry attempts + RetryBackoff RetryBackoff // retry backoff policy +} + +// RetryBackoff describes the backoff policy for retries. +type RetryBackoff struct { + BaseInterval time.Duration // initial backoff duration between attempts + MaxInterval time.Duration // maximum backoff duration +} + +// HashPolicyType specifies the type of HashPolicy from a received RDS Response. +type HashPolicyType int + +const ( + // HashPolicyTypeHeader specifies to hash a Header in the incoming request. + HashPolicyTypeHeader HashPolicyType = iota + // HashPolicyTypeChannelID specifies to hash a unique Identifier of the + // Channel. In grpc-go, this will be done using the ClientConn pointer. + HashPolicyTypeChannelID +) + +// HashPolicy specifies the HashPolicy if the upstream cluster uses a hashing +// load balancer. +type HashPolicy struct { + HashPolicyType HashPolicyType + Terminal bool + // Fields used for type HEADER. + HeaderName string + Regex *regexp.Regexp + RegexSubstitution string } +// RouteAction is the action of the route from a received RDS response. +type RouteAction int + +const ( + // RouteActionUnsupported are routing types currently unsupported by grpc. + // According to A36, "A Route with an inappropriate action causes RPCs + // matching that route to fail." + RouteActionUnsupported RouteAction = iota + // RouteActionRoute is the expected route type on the client side. Route + // represents routing a request to some upstream cluster. On the client + // side, if an RPC matches to a route that is not RouteActionRoute, the RPC + // will fail according to A36. + RouteActionRoute + // RouteActionNonForwardingAction is the expected route type on the server + // side. NonForwardingAction represents when a route will generate a + // response directly, without forwarding to an upstream host. + RouteActionNonForwardingAction +) + // Route is both a specification of how to match a request as well as an // indication of the action to take upon match. type Route struct { - Path, Prefix, Regex *string + Path *string + Prefix *string + Regex *regexp.Regexp // Indicates if prefix/path matching should be case insensitive. The default // is false (case sensitive). CaseInsensitive bool Headers []*HeaderMatcher Fraction *uint32 + HashPolicies []*HashPolicy + // If the matchers above indicate a match, the below configuration is used. WeightedClusters map[string]WeightedCluster // If MaxStreamDuration is nil, it indicates neither of the route action's @@ -273,6 +366,9 @@ type Route struct { // unused if the matching WeightedCluster contains an override for that // filter. HTTPFilterConfigOverride map[string]httpfilter.FilterConfig + RetryConfig *RetryConfig + + RouteAction RouteAction } // WeightedCluster contains settings for an xds RouteAction.WeightedCluster. @@ -286,20 +382,20 @@ type WeightedCluster struct { // HeaderMatcher represents header matchers. type HeaderMatcher struct { - Name string `json:"name"` - InvertMatch *bool `json:"invertMatch,omitempty"` - ExactMatch *string `json:"exactMatch,omitempty"` - RegexMatch *string `json:"regexMatch,omitempty"` - PrefixMatch *string `json:"prefixMatch,omitempty"` - SuffixMatch *string `json:"suffixMatch,omitempty"` - RangeMatch *Int64Range `json:"rangeMatch,omitempty"` - PresentMatch *bool `json:"presentMatch,omitempty"` + Name string + InvertMatch *bool + ExactMatch *string + RegexMatch *regexp.Regexp + PrefixMatch *string + SuffixMatch *string + RangeMatch *Int64Range + PresentMatch *bool } // Int64Range is a range for header range match. type Int64Range struct { - Start int64 `json:"start"` - End int64 `json:"end"` + Start int64 + End int64 } // SecurityConfig contains the security configuration received as part of the @@ -322,29 +418,107 @@ type SecurityConfig struct { // IdentityCertName is the certificate name to be passed to the plugin // (looked up from the bootstrap file) while fetching identity certificates. IdentityCertName string - // AcceptedSANs is a list of Subject Alternative Names. During the TLS - // handshake, the SAN present in the peer certificate is compared against - // this list, and the handshake succeeds only if a match is found. Used only - // on the client-side. - AcceptedSANs []string + // SubjectAltNameMatchers is an optional list of match criteria for SANs + // specified on the peer certificate. Used only on the client-side. + // + // Some intricacies: + // - If this field is empty, then any peer certificate is accepted. + // - If the peer certificate contains a wildcard DNS SAN, and an `exact` + // matcher is configured, a wildcard DNS match is performed instead of a + // regular string comparison. + SubjectAltNameMatchers []matcher.StringMatcher // RequireClientCert indicates if the server handshake process expects the // client to present a certificate. Set to true when performing mTLS. Used // only on the server-side. RequireClientCert bool } +// Equal returns true if sc is equal to other. +func (sc *SecurityConfig) Equal(other *SecurityConfig) bool { + switch { + case sc == nil && other == nil: + return true + case (sc != nil) != (other != nil): + return false + } + switch { + case sc.RootInstanceName != other.RootInstanceName: + return false + case sc.RootCertName != other.RootCertName: + return false + case sc.IdentityInstanceName != other.IdentityInstanceName: + return false + case sc.IdentityCertName != other.IdentityCertName: + return false + case sc.RequireClientCert != other.RequireClientCert: + return false + default: + if len(sc.SubjectAltNameMatchers) != len(other.SubjectAltNameMatchers) { + return false + } + for i := 0; i < len(sc.SubjectAltNameMatchers); i++ { + if !sc.SubjectAltNameMatchers[i].Equal(other.SubjectAltNameMatchers[i]) { + return false + } + } + } + return true +} + +// ClusterType is the type of cluster from a received CDS response. +type ClusterType int + +const ( + // ClusterTypeEDS represents the EDS cluster type, which will delegate endpoint + // discovery to the management server. + ClusterTypeEDS ClusterType = iota + // ClusterTypeLogicalDNS represents the Logical DNS cluster type, which essentially + // maps to the gRPC behavior of using the DNS resolver with pick_first LB policy. + ClusterTypeLogicalDNS + // ClusterTypeAggregate represents the Aggregate Cluster type, which provides a + // prioritized list of clusters to use. It is used for failover between clusters + // with a different configuration. + ClusterTypeAggregate +) + +// ClusterLBPolicyRingHash represents ring_hash lb policy, and also contains its +// config. +type ClusterLBPolicyRingHash struct { + MinimumRingSize uint64 + MaximumRingSize uint64 +} + // ClusterUpdate contains information from a received CDS response, which is of // interest to the registered CDS watcher. type ClusterUpdate struct { - // ServiceName is the service name corresponding to the clusterName which - // is being watched for through CDS. - ServiceName string + ClusterType ClusterType + // ClusterName is the clusterName being watched for through CDS. + ClusterName string + // EDSServiceName is an optional name for EDS. If it's not set, the balancer + // should watch ClusterName for the EDS resources. + EDSServiceName string // EnableLRS indicates whether or not load should be reported through LRS. EnableLRS bool // SecurityCfg contains security configuration sent by the control plane. SecurityCfg *SecurityConfig // MaxRequests for circuit breaking, if any (otherwise nil). MaxRequests *uint32 + // DNSHostName is used only for cluster type DNS. It's the DNS name to + // resolve in "host:port" form + DNSHostName string + // PrioritizedClusterNames is used only for cluster type aggregate. It represents + // a prioritized list of cluster names. + PrioritizedClusterNames []string + + // LBPolicy is the lb policy for this cluster. + // + // This only support round_robin and ring_hash. + // - if it's nil, the lb policy is round_robin + // - if it's not nil, the lb policy is ring_hash, the this field has the config. + // + // When we add more support policies, this can be made an interface, and + // will be set to different types based on the policy type. + LBPolicy *ClusterLBPolicyRingHash // Raw is the resource from the xds response. Raw *anypb.Any @@ -514,6 +688,7 @@ func newWithConfig(config *bootstrap.Config, watchExpiryTimeout time.Duration) ( apiClient, err := newAPIClient(config.TransportAPI, cc, BuildOptions{ Parent: c, + Validator: c.updateValidator, NodeProto: config.NodeProto, Backoff: backoff.DefaultExponential.Backoff, Logger: c.logger, @@ -529,7 +704,7 @@ func newWithConfig(config *bootstrap.Config, watchExpiryTimeout time.Duration) ( // BootstrapConfig returns the configuration read from the bootstrap file. // Callers must treat the return value as read-only. -func (c *Client) BootstrapConfig() *bootstrap.Config { +func (c *clientRefCounted) BootstrapConfig() *bootstrap.Config { return c.config } @@ -567,6 +742,64 @@ func (c *clientImpl) Close() { c.logger.Infof("Shutdown") } +func (c *clientImpl) filterChainUpdateValidator(fc *FilterChain) error { + if fc == nil { + return nil + } + return c.securityConfigUpdateValidator(fc.SecurityCfg) +} + +func (c *clientImpl) securityConfigUpdateValidator(sc *SecurityConfig) error { + if sc == nil { + return nil + } + if sc.IdentityInstanceName != "" { + if _, ok := c.config.CertProviderConfigs[sc.IdentityInstanceName]; !ok { + return fmt.Errorf("identitiy certificate provider instance name %q missing in bootstrap configuration", sc.IdentityInstanceName) + } + } + if sc.RootInstanceName != "" { + if _, ok := c.config.CertProviderConfigs[sc.RootInstanceName]; !ok { + return fmt.Errorf("root certificate provider instance name %q missing in bootstrap configuration", sc.RootInstanceName) + } + } + return nil +} + +func (c *clientImpl) updateValidator(u interface{}) error { + switch update := u.(type) { + case ListenerUpdate: + if update.InboundListenerCfg == nil || update.InboundListenerCfg.FilterChains == nil { + return nil + } + + fcm := update.InboundListenerCfg.FilterChains + for _, dst := range fcm.dstPrefixMap { + for _, srcType := range dst.srcTypeArr { + if srcType == nil { + continue + } + for _, src := range srcType.srcPrefixMap { + for _, fc := range src.srcPortMap { + if err := c.filterChainUpdateValidator(fc); err != nil { + return err + } + } + } + } + } + return c.filterChainUpdateValidator(fcm.def) + case ClusterUpdate: + return c.securityConfigUpdateValidator(update.SecurityCfg) + default: + // We currently invoke this update validation function only for LDS and + // CDS updates. In the future, if we wish to invoke it for other xDS + // updates, corresponding plumbing needs to be added to those unmarshal + // functions. + } + return nil +} + // ResourceType identifies resources in a transport protocol agnostic way. These // will be used in transport version agnostic code, while the versioned API // clients will map these to appropriate version URLs. diff --git a/xds/internal/client/client_test.go b/xds/internal/xdsclient/client_test.go similarity index 78% rename from xds/internal/client/client_test.go rename to xds/internal/xdsclient/client_test.go index 8275ea60e0d..7c3423cd5ad 100644 --- a/xds/internal/client/client_test.go +++ b/xds/internal/xdsclient/client_test.go @@ -16,7 +16,7 @@ * */ -package client +package xdsclient import ( "context" @@ -26,15 +26,16 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" + "google.golang.org/protobuf/types/known/anypb" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/internal/grpcsync" "google.golang.org/grpc/internal/grpctest" "google.golang.org/grpc/internal/testutils" - "google.golang.org/grpc/xds/internal/client/bootstrap" xdstestutils "google.golang.org/grpc/xds/internal/testutils" "google.golang.org/grpc/xds/internal/version" + "google.golang.org/grpc/xds/internal/xdsclient/bootstrap" "google.golang.org/protobuf/testing/protocmp" ) @@ -62,19 +63,11 @@ const ( var ( cmpOpts = cmp.Options{ cmpopts.EquateEmpty(), + cmp.FilterValues(func(x, y error) bool { return true }, cmpopts.EquateErrors()), cmp.Comparer(func(a, b time.Time) bool { return true }), - cmp.Comparer(func(x, y error) bool { - if x == nil || y == nil { - return x == nil && y == nil - } - return x.Error() == y.Error() - }), protocmp.Transform(), } - // When comparing NACK UpdateMetadata, we only care if error is nil, but not - // the details in error. - errPlaceHolder = fmt.Errorf("error whose details don't matter") cmpOptsIgnoreDetails = cmp.Options{ cmp.Comparer(func(a, b time.Time) bool { return true }), cmp.Comparer(func(x, y error) bool { @@ -170,7 +163,7 @@ func (s) TestWatchCallAnotherWatch(t *testing.T) { clusterUpdateCh := testutils.NewChannel() firstTime := true client.WatchCluster(testCDSName, func(update ClusterUpdate, err error) { - clusterUpdateCh.Send(clusterUpdateErr{u: update, err: err}) + clusterUpdateCh.Send(ClusterUpdateErrTuple{Update: update, Err: err}) // Calls another watch inline, to ensure there's deadlock. client.WatchCluster("another-random-name", func(ClusterUpdate, error) {}) @@ -183,63 +176,89 @@ func (s) TestWatchCallAnotherWatch(t *testing.T) { t.Fatalf("want new watch to start, got error %v", err) } - wantUpdate := ClusterUpdate{ServiceName: testEDSName} - client.NewClusters(map[string]ClusterUpdate{testCDSName: wantUpdate}, UpdateMetadata{}) - if err := verifyClusterUpdate(ctx, clusterUpdateCh, wantUpdate); err != nil { + wantUpdate := ClusterUpdate{ClusterName: testEDSName} + client.NewClusters(map[string]ClusterUpdateErrTuple{testCDSName: {Update: wantUpdate}}, UpdateMetadata{}) + if err := verifyClusterUpdate(ctx, clusterUpdateCh, wantUpdate, nil); err != nil { t.Fatal(err) } - wantUpdate2 := ClusterUpdate{ServiceName: testEDSName + "2"} - client.NewClusters(map[string]ClusterUpdate{testCDSName: wantUpdate2}, UpdateMetadata{}) - if err := verifyClusterUpdate(ctx, clusterUpdateCh, wantUpdate2); err != nil { + // The second update needs to be different in the underlying resource proto + // for the watch callback to be invoked. + wantUpdate2 := ClusterUpdate{ClusterName: testEDSName + "2", Raw: &anypb.Any{}} + client.NewClusters(map[string]ClusterUpdateErrTuple{testCDSName: {Update: wantUpdate2}}, UpdateMetadata{}) + if err := verifyClusterUpdate(ctx, clusterUpdateCh, wantUpdate2, nil); err != nil { t.Fatal(err) } } -func verifyListenerUpdate(ctx context.Context, updateCh *testutils.Channel, wantUpdate ListenerUpdate) error { +func verifyListenerUpdate(ctx context.Context, updateCh *testutils.Channel, wantUpdate ListenerUpdate, wantErr error) error { u, err := updateCh.Receive(ctx) if err != nil { return fmt.Errorf("timeout when waiting for listener update: %v", err) } - gotUpdate := u.(ldsUpdateErr) - if gotUpdate.err != nil || !cmp.Equal(gotUpdate.u, wantUpdate) { - return fmt.Errorf("unexpected endpointsUpdate: (%v, %v), want: (%v, nil)", gotUpdate.u, gotUpdate.err, wantUpdate) + gotUpdate := u.(ListenerUpdateErrTuple) + if wantErr != nil { + if gotUpdate.Err != wantErr { + return fmt.Errorf("unexpected error: %v, want %v", gotUpdate.Err, wantErr) + } + return nil + } + if gotUpdate.Err != nil || !cmp.Equal(gotUpdate.Update, wantUpdate, protocmp.Transform()) { + return fmt.Errorf("unexpected endpointsUpdate: (%v, %v), want: (%v, nil)", gotUpdate.Update, gotUpdate.Err, wantUpdate) } return nil } -func verifyRouteConfigUpdate(ctx context.Context, updateCh *testutils.Channel, wantUpdate RouteConfigUpdate) error { +func verifyRouteConfigUpdate(ctx context.Context, updateCh *testutils.Channel, wantUpdate RouteConfigUpdate, wantErr error) error { u, err := updateCh.Receive(ctx) if err != nil { return fmt.Errorf("timeout when waiting for route configuration update: %v", err) } - gotUpdate := u.(rdsUpdateErr) - if gotUpdate.err != nil || !cmp.Equal(gotUpdate.u, wantUpdate) { - return fmt.Errorf("unexpected route config update: (%v, %v), want: (%v, nil)", gotUpdate.u, gotUpdate.err, wantUpdate) + gotUpdate := u.(RouteConfigUpdateErrTuple) + if wantErr != nil { + if gotUpdate.Err != wantErr { + return fmt.Errorf("unexpected error: %v, want %v", gotUpdate.Err, wantErr) + } + return nil + } + if gotUpdate.Err != nil || !cmp.Equal(gotUpdate.Update, wantUpdate, protocmp.Transform()) { + return fmt.Errorf("unexpected route config update: (%v, %v), want: (%v, nil)", gotUpdate.Update, gotUpdate.Err, wantUpdate) } return nil } -func verifyClusterUpdate(ctx context.Context, updateCh *testutils.Channel, wantUpdate ClusterUpdate) error { +func verifyClusterUpdate(ctx context.Context, updateCh *testutils.Channel, wantUpdate ClusterUpdate, wantErr error) error { u, err := updateCh.Receive(ctx) if err != nil { return fmt.Errorf("timeout when waiting for cluster update: %v", err) } - gotUpdate := u.(clusterUpdateErr) - if gotUpdate.err != nil || !cmp.Equal(gotUpdate.u, wantUpdate) { - return fmt.Errorf("unexpected clusterUpdate: (%v, %v), want: (%v, nil)", gotUpdate.u, gotUpdate.err, wantUpdate) + gotUpdate := u.(ClusterUpdateErrTuple) + if wantErr != nil { + if gotUpdate.Err != wantErr { + return fmt.Errorf("unexpected error: %v, want %v", gotUpdate.Err, wantErr) + } + return nil + } + if !cmp.Equal(gotUpdate.Update, wantUpdate, protocmp.Transform()) { + return fmt.Errorf("unexpected clusterUpdate: (%v, %v), want: (%v, nil)", gotUpdate.Update, gotUpdate.Err, wantUpdate) } return nil } -func verifyEndpointsUpdate(ctx context.Context, updateCh *testutils.Channel, wantUpdate EndpointsUpdate) error { +func verifyEndpointsUpdate(ctx context.Context, updateCh *testutils.Channel, wantUpdate EndpointsUpdate, wantErr error) error { u, err := updateCh.Receive(ctx) if err != nil { return fmt.Errorf("timeout when waiting for endpoints update: %v", err) } - gotUpdate := u.(endpointsUpdateErr) - if gotUpdate.err != nil || !cmp.Equal(gotUpdate.u, wantUpdate, cmpopts.EquateEmpty()) { - return fmt.Errorf("unexpected endpointsUpdate: (%v, %v), want: (%v, nil)", gotUpdate.u, gotUpdate.err, wantUpdate) + gotUpdate := u.(EndpointsUpdateErrTuple) + if wantErr != nil { + if gotUpdate.Err != wantErr { + return fmt.Errorf("unexpected error: %v, want %v", gotUpdate.Err, wantErr) + } + return nil + } + if gotUpdate.Err != nil || !cmp.Equal(gotUpdate.Update, wantUpdate, cmpopts.EquateEmpty(), protocmp.Transform()) { + return fmt.Errorf("unexpected endpointsUpdate: (%v, %v), want: (%v, nil)", gotUpdate.Update, gotUpdate.Err, wantUpdate) } return nil } @@ -261,7 +280,7 @@ func (s) TestClientNewSingleton(t *testing.T) { defer cleanup() // The first New(). Should create a Client and a new APIClient. - client, err := New() + client, err := newRefCounted() if err != nil { t.Fatalf("failed to create client: %v", err) } @@ -278,7 +297,7 @@ func (s) TestClientNewSingleton(t *testing.T) { // and should not create new API client. const count = 9 for i := 0; i < count; i++ { - tc, terr := New() + tc, terr := newRefCounted() if terr != nil { client.Close() t.Fatalf("%d-th call to New() failed with error: %v", i, terr) @@ -322,7 +341,7 @@ func (s) TestClientNewSingleton(t *testing.T) { // Call New() again after the previous Client is actually closed. Should // create a Client and a new APIClient. - client2, err2 := New() + client2, err2 := newRefCounted() if err2 != nil { t.Fatalf("failed to create client: %v", err) } diff --git a/xds/internal/client/dump.go b/xds/internal/xdsclient/dump.go similarity index 99% rename from xds/internal/client/dump.go rename to xds/internal/xdsclient/dump.go index 3fd18f6103b..db9b474f370 100644 --- a/xds/internal/client/dump.go +++ b/xds/internal/xdsclient/dump.go @@ -16,7 +16,7 @@ * */ -package client +package xdsclient import anypb "github.com/golang/protobuf/ptypes/any" diff --git a/xds/internal/client/tests/dump_test.go b/xds/internal/xdsclient/dump_test.go similarity index 76% rename from xds/internal/client/tests/dump_test.go rename to xds/internal/xdsclient/dump_test.go index 58220866eb1..d03479ca4ad 100644 --- a/xds/internal/client/tests/dump_test.go +++ b/xds/internal/xdsclient/dump_test.go @@ -16,7 +16,7 @@ * */ -package tests_test +package xdsclient_test import ( "fmt" @@ -28,7 +28,6 @@ import ( v3listenerpb "github.com/envoyproxy/go-control-plane/envoy/config/listener/v3" v3routepb "github.com/envoyproxy/go-control-plane/envoy/config/route/v3" v3httppb "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/network/http_connection_manager/v3" - "github.com/golang/protobuf/ptypes" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "google.golang.org/protobuf/testing/protocmp" @@ -37,9 +36,10 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" - xdsclient "google.golang.org/grpc/xds/internal/client" - "google.golang.org/grpc/xds/internal/client/bootstrap" + "google.golang.org/grpc/internal/testutils" xdstestutils "google.golang.org/grpc/xds/internal/testutils" + "google.golang.org/grpc/xds/internal/xdsclient" + "google.golang.org/grpc/xds/internal/xdsclient/bootstrap" ) const defaultTestWatchExpiryTimeout = 500 * time.Millisecond @@ -56,29 +56,22 @@ func (s) TestLDSConfigDump(t *testing.T) { listenersT := &v3listenerpb.Listener{ Name: ldsTargets[i], ApiListener: &v3listenerpb.ApiListener{ - ApiListener: func() *anypb.Any { - mcm, _ := ptypes.MarshalAny(&v3httppb.HttpConnectionManager{ - RouteSpecifier: &v3httppb.HttpConnectionManager_Rds{ - Rds: &v3httppb.Rds{ - ConfigSource: &v3corepb.ConfigSource{ - ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{Ads: &v3corepb.AggregatedConfigSource{}}, - }, - RouteConfigName: routeConfigNames[i], + ApiListener: testutils.MarshalAny(&v3httppb.HttpConnectionManager{ + RouteSpecifier: &v3httppb.HttpConnectionManager_Rds{ + Rds: &v3httppb.Rds{ + ConfigSource: &v3corepb.ConfigSource{ + ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{Ads: &v3corepb.AggregatedConfigSource{}}, }, + RouteConfigName: routeConfigNames[i], }, - CommonHttpProtocolOptions: &v3corepb.HttpProtocolOptions{ - MaxStreamDuration: durationpb.New(time.Second), - }, - }) - return mcm - }(), + }, + CommonHttpProtocolOptions: &v3corepb.HttpProtocolOptions{ + MaxStreamDuration: durationpb.New(time.Second), + }, + }), }, } - anyT, err := ptypes.MarshalAny(listenersT) - if err != nil { - t.Fatalf("failed to marshal proto to any: %v", err) - } - listenerRaws[ldsTargets[i]] = anyT + listenerRaws[ldsTargets[i]] = testutils.MarshalAny(listenersT) } client, err := xdsclient.NewWithConfigForTesting(&bootstrap.Config{ @@ -90,6 +83,7 @@ func (s) TestLDSConfigDump(t *testing.T) { t.Fatalf("failed to create client: %v", err) } defer client.Close() + updateHandler := client.(xdsclient.UpdateHandler) // Expected unknown. if err := compareDump(client.DumpLDS, "", map[string]xdsclient.UpdateWithMD{}); err != nil { @@ -107,16 +101,16 @@ func (s) TestLDSConfigDump(t *testing.T) { t.Fatalf(err.Error()) } - update0 := make(map[string]xdsclient.ListenerUpdate) + update0 := make(map[string]xdsclient.ListenerUpdateErrTuple) want0 := make(map[string]xdsclient.UpdateWithMD) for n, r := range listenerRaws { - update0[n] = xdsclient.ListenerUpdate{Raw: r} + update0[n] = xdsclient.ListenerUpdateErrTuple{Update: xdsclient.ListenerUpdate{Raw: r}} want0[n] = xdsclient.UpdateWithMD{ - MD: xdsclient.UpdateMetadata{Version: testVersion}, + MD: xdsclient.UpdateMetadata{Status: xdsclient.ServiceStatusACKed, Version: testVersion}, Raw: r, } } - client.NewListeners(update0, xdsclient.UpdateMetadata{Version: testVersion}) + updateHandler.NewListeners(update0, xdsclient.UpdateMetadata{Status: xdsclient.ServiceStatusACKed, Version: testVersion}) // Expect ACK. if err := compareDump(client.DumpLDS, testVersion, want0); err != nil { @@ -125,11 +119,13 @@ func (s) TestLDSConfigDump(t *testing.T) { const nackVersion = "lds-version-nack" var nackErr = fmt.Errorf("lds nack error") - client.NewListeners( - map[string]xdsclient.ListenerUpdate{ - ldsTargets[0]: {}, + updateHandler.NewListeners( + map[string]xdsclient.ListenerUpdateErrTuple{ + ldsTargets[0]: {Err: nackErr}, + ldsTargets[1]: {Update: xdsclient.ListenerUpdate{Raw: listenerRaws[ldsTargets[1]]}}, }, xdsclient.UpdateMetadata{ + Status: xdsclient.ServiceStatusNACKed, ErrState: &xdsclient.UpdateErrorMetadata{ Version: nackVersion, Err: nackErr, @@ -143,6 +139,7 @@ func (s) TestLDSConfigDump(t *testing.T) { // message, as well as the NACK error. wantDump[ldsTargets[0]] = xdsclient.UpdateWithMD{ MD: xdsclient.UpdateMetadata{ + Status: xdsclient.ServiceStatusNACKed, Version: testVersion, ErrState: &xdsclient.UpdateErrorMetadata{ Version: nackVersion, @@ -153,7 +150,7 @@ func (s) TestLDSConfigDump(t *testing.T) { } wantDump[ldsTargets[1]] = xdsclient.UpdateWithMD{ - MD: xdsclient.UpdateMetadata{Version: testVersion}, + MD: xdsclient.UpdateMetadata{Status: xdsclient.ServiceStatusACKed, Version: nackVersion}, Raw: listenerRaws[ldsTargets[1]], } if err := compareDump(client.DumpLDS, nackVersion, wantDump); err != nil { @@ -188,11 +185,7 @@ func (s) TestRDSConfigDump(t *testing.T) { }, } - anyT, err := ptypes.MarshalAny(routeConfigT) - if err != nil { - t.Fatalf("failed to marshal proto to any: %v", err) - } - routeRaws[rdsTargets[i]] = anyT + routeRaws[rdsTargets[i]] = testutils.MarshalAny(routeConfigT) } client, err := xdsclient.NewWithConfigForTesting(&bootstrap.Config{ @@ -204,6 +197,7 @@ func (s) TestRDSConfigDump(t *testing.T) { t.Fatalf("failed to create client: %v", err) } defer client.Close() + updateHandler := client.(xdsclient.UpdateHandler) // Expected unknown. if err := compareDump(client.DumpRDS, "", map[string]xdsclient.UpdateWithMD{}); err != nil { @@ -221,16 +215,16 @@ func (s) TestRDSConfigDump(t *testing.T) { t.Fatalf(err.Error()) } - update0 := make(map[string]xdsclient.RouteConfigUpdate) + update0 := make(map[string]xdsclient.RouteConfigUpdateErrTuple) want0 := make(map[string]xdsclient.UpdateWithMD) for n, r := range routeRaws { - update0[n] = xdsclient.RouteConfigUpdate{Raw: r} + update0[n] = xdsclient.RouteConfigUpdateErrTuple{Update: xdsclient.RouteConfigUpdate{Raw: r}} want0[n] = xdsclient.UpdateWithMD{ - MD: xdsclient.UpdateMetadata{Version: testVersion}, + MD: xdsclient.UpdateMetadata{Status: xdsclient.ServiceStatusACKed, Version: testVersion}, Raw: r, } } - client.NewRouteConfigs(update0, xdsclient.UpdateMetadata{Version: testVersion}) + updateHandler.NewRouteConfigs(update0, xdsclient.UpdateMetadata{Status: xdsclient.ServiceStatusACKed, Version: testVersion}) // Expect ACK. if err := compareDump(client.DumpRDS, testVersion, want0); err != nil { @@ -239,11 +233,13 @@ func (s) TestRDSConfigDump(t *testing.T) { const nackVersion = "rds-version-nack" var nackErr = fmt.Errorf("rds nack error") - client.NewRouteConfigs( - map[string]xdsclient.RouteConfigUpdate{ - rdsTargets[0]: {}, + updateHandler.NewRouteConfigs( + map[string]xdsclient.RouteConfigUpdateErrTuple{ + rdsTargets[0]: {Err: nackErr}, + rdsTargets[1]: {Update: xdsclient.RouteConfigUpdate{Raw: routeRaws[rdsTargets[1]]}}, }, xdsclient.UpdateMetadata{ + Status: xdsclient.ServiceStatusNACKed, ErrState: &xdsclient.UpdateErrorMetadata{ Version: nackVersion, Err: nackErr, @@ -257,6 +253,7 @@ func (s) TestRDSConfigDump(t *testing.T) { // message, as well as the NACK error. wantDump[rdsTargets[0]] = xdsclient.UpdateWithMD{ MD: xdsclient.UpdateMetadata{ + Status: xdsclient.ServiceStatusNACKed, Version: testVersion, ErrState: &xdsclient.UpdateErrorMetadata{ Version: nackVersion, @@ -266,7 +263,7 @@ func (s) TestRDSConfigDump(t *testing.T) { Raw: routeRaws[rdsTargets[0]], } wantDump[rdsTargets[1]] = xdsclient.UpdateWithMD{ - MD: xdsclient.UpdateMetadata{Version: testVersion}, + MD: xdsclient.UpdateMetadata{Status: xdsclient.ServiceStatusACKed, Version: nackVersion}, Raw: routeRaws[rdsTargets[1]], } if err := compareDump(client.DumpRDS, nackVersion, wantDump); err != nil { @@ -302,11 +299,7 @@ func (s) TestCDSConfigDump(t *testing.T) { }, } - anyT, err := ptypes.MarshalAny(clusterT) - if err != nil { - t.Fatalf("failed to marshal proto to any: %v", err) - } - clusterRaws[cdsTargets[i]] = anyT + clusterRaws[cdsTargets[i]] = testutils.MarshalAny(clusterT) } client, err := xdsclient.NewWithConfigForTesting(&bootstrap.Config{ @@ -318,6 +311,7 @@ func (s) TestCDSConfigDump(t *testing.T) { t.Fatalf("failed to create client: %v", err) } defer client.Close() + updateHandler := client.(xdsclient.UpdateHandler) // Expected unknown. if err := compareDump(client.DumpCDS, "", map[string]xdsclient.UpdateWithMD{}); err != nil { @@ -335,16 +329,16 @@ func (s) TestCDSConfigDump(t *testing.T) { t.Fatalf(err.Error()) } - update0 := make(map[string]xdsclient.ClusterUpdate) + update0 := make(map[string]xdsclient.ClusterUpdateErrTuple) want0 := make(map[string]xdsclient.UpdateWithMD) for n, r := range clusterRaws { - update0[n] = xdsclient.ClusterUpdate{Raw: r} + update0[n] = xdsclient.ClusterUpdateErrTuple{Update: xdsclient.ClusterUpdate{Raw: r}} want0[n] = xdsclient.UpdateWithMD{ - MD: xdsclient.UpdateMetadata{Version: testVersion}, + MD: xdsclient.UpdateMetadata{Status: xdsclient.ServiceStatusACKed, Version: testVersion}, Raw: r, } } - client.NewClusters(update0, xdsclient.UpdateMetadata{Version: testVersion}) + updateHandler.NewClusters(update0, xdsclient.UpdateMetadata{Status: xdsclient.ServiceStatusACKed, Version: testVersion}) // Expect ACK. if err := compareDump(client.DumpCDS, testVersion, want0); err != nil { @@ -353,11 +347,13 @@ func (s) TestCDSConfigDump(t *testing.T) { const nackVersion = "cds-version-nack" var nackErr = fmt.Errorf("cds nack error") - client.NewClusters( - map[string]xdsclient.ClusterUpdate{ - cdsTargets[0]: {}, + updateHandler.NewClusters( + map[string]xdsclient.ClusterUpdateErrTuple{ + cdsTargets[0]: {Err: nackErr}, + cdsTargets[1]: {Update: xdsclient.ClusterUpdate{Raw: clusterRaws[cdsTargets[1]]}}, }, xdsclient.UpdateMetadata{ + Status: xdsclient.ServiceStatusNACKed, ErrState: &xdsclient.UpdateErrorMetadata{ Version: nackVersion, Err: nackErr, @@ -371,6 +367,7 @@ func (s) TestCDSConfigDump(t *testing.T) { // message, as well as the NACK error. wantDump[cdsTargets[0]] = xdsclient.UpdateWithMD{ MD: xdsclient.UpdateMetadata{ + Status: xdsclient.ServiceStatusNACKed, Version: testVersion, ErrState: &xdsclient.UpdateErrorMetadata{ Version: nackVersion, @@ -380,7 +377,7 @@ func (s) TestCDSConfigDump(t *testing.T) { Raw: clusterRaws[cdsTargets[0]], } wantDump[cdsTargets[1]] = xdsclient.UpdateWithMD{ - MD: xdsclient.UpdateMetadata{Version: testVersion}, + MD: xdsclient.UpdateMetadata{Status: xdsclient.ServiceStatusACKed, Version: nackVersion}, Raw: clusterRaws[cdsTargets[1]], } if err := compareDump(client.DumpCDS, nackVersion, wantDump); err != nil { @@ -402,11 +399,7 @@ func (s) TestEDSConfigDump(t *testing.T) { clab0.AddLocality(localityNames[i], 1, 1, []string{addrs[i]}, nil) claT := clab0.Build() - anyT, err := ptypes.MarshalAny(claT) - if err != nil { - t.Fatalf("failed to marshal proto to any: %v", err) - } - endpointRaws[edsTargets[i]] = anyT + endpointRaws[edsTargets[i]] = testutils.MarshalAny(claT) } client, err := xdsclient.NewWithConfigForTesting(&bootstrap.Config{ @@ -418,6 +411,7 @@ func (s) TestEDSConfigDump(t *testing.T) { t.Fatalf("failed to create client: %v", err) } defer client.Close() + updateHandler := client.(xdsclient.UpdateHandler) // Expected unknown. if err := compareDump(client.DumpEDS, "", map[string]xdsclient.UpdateWithMD{}); err != nil { @@ -435,16 +429,16 @@ func (s) TestEDSConfigDump(t *testing.T) { t.Fatalf(err.Error()) } - update0 := make(map[string]xdsclient.EndpointsUpdate) + update0 := make(map[string]xdsclient.EndpointsUpdateErrTuple) want0 := make(map[string]xdsclient.UpdateWithMD) for n, r := range endpointRaws { - update0[n] = xdsclient.EndpointsUpdate{Raw: r} + update0[n] = xdsclient.EndpointsUpdateErrTuple{Update: xdsclient.EndpointsUpdate{Raw: r}} want0[n] = xdsclient.UpdateWithMD{ - MD: xdsclient.UpdateMetadata{Version: testVersion}, + MD: xdsclient.UpdateMetadata{Status: xdsclient.ServiceStatusACKed, Version: testVersion}, Raw: r, } } - client.NewEndpoints(update0, xdsclient.UpdateMetadata{Version: testVersion}) + updateHandler.NewEndpoints(update0, xdsclient.UpdateMetadata{Status: xdsclient.ServiceStatusACKed, Version: testVersion}) // Expect ACK. if err := compareDump(client.DumpEDS, testVersion, want0); err != nil { @@ -453,11 +447,13 @@ func (s) TestEDSConfigDump(t *testing.T) { const nackVersion = "eds-version-nack" var nackErr = fmt.Errorf("eds nack error") - client.NewEndpoints( - map[string]xdsclient.EndpointsUpdate{ - edsTargets[0]: {}, + updateHandler.NewEndpoints( + map[string]xdsclient.EndpointsUpdateErrTuple{ + edsTargets[0]: {Err: nackErr}, + edsTargets[1]: {Update: xdsclient.EndpointsUpdate{Raw: endpointRaws[edsTargets[1]]}}, }, xdsclient.UpdateMetadata{ + Status: xdsclient.ServiceStatusNACKed, ErrState: &xdsclient.UpdateErrorMetadata{ Version: nackVersion, Err: nackErr, @@ -471,6 +467,7 @@ func (s) TestEDSConfigDump(t *testing.T) { // message, as well as the NACK error. wantDump[edsTargets[0]] = xdsclient.UpdateWithMD{ MD: xdsclient.UpdateMetadata{ + Status: xdsclient.ServiceStatusNACKed, Version: testVersion, ErrState: &xdsclient.UpdateErrorMetadata{ Version: nackVersion, @@ -480,7 +477,7 @@ func (s) TestEDSConfigDump(t *testing.T) { Raw: endpointRaws[edsTargets[0]], } wantDump[edsTargets[1]] = xdsclient.UpdateWithMD{ - MD: xdsclient.UpdateMetadata{Version: testVersion}, + MD: xdsclient.UpdateMetadata{Status: xdsclient.ServiceStatusACKed, Version: nackVersion}, Raw: endpointRaws[edsTargets[1]], } if err := compareDump(client.DumpEDS, nackVersion, wantDump); err != nil { diff --git a/xds/internal/client/eds_test.go b/xds/internal/xdsclient/eds_test.go similarity index 82% rename from xds/internal/client/eds_test.go rename to xds/internal/xdsclient/eds_test.go index daa5d6525e1..d0af8a988d8 100644 --- a/xds/internal/client/eds_test.go +++ b/xds/internal/xdsclient/eds_test.go @@ -16,7 +16,7 @@ * */ -package client +package xdsclient import ( "fmt" @@ -27,10 +27,11 @@ import ( v3corepb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" v3endpointpb "github.com/envoyproxy/go-control-plane/envoy/config/endpoint/v3" v3typepb "github.com/envoyproxy/go-control-plane/envoy/type/v3" - "github.com/golang/protobuf/proto" anypb "github.com/golang/protobuf/ptypes/any" wrapperspb "github.com/golang/protobuf/ptypes/wrappers" "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "google.golang.org/grpc/internal/testutils" "google.golang.org/grpc/xds/internal" "google.golang.org/grpc/xds/internal/version" ) @@ -120,29 +121,24 @@ func (s) TestEDSParseRespProto(t *testing.T) { } func (s) TestUnmarshalEndpoints(t *testing.T) { - var v3EndpointsAny = &anypb.Any{ - TypeUrl: version.V3EndpointsURL, - Value: func() []byte { - clab0 := newClaBuilder("test", nil) - clab0.addLocality("locality-1", 1, 1, []string{"addr1:314"}, &addLocalityOptions{ - Health: []v3corepb.HealthStatus{v3corepb.HealthStatus_UNHEALTHY}, - Weight: []uint32{271}, - }) - clab0.addLocality("locality-2", 1, 0, []string{"addr2:159"}, &addLocalityOptions{ - Health: []v3corepb.HealthStatus{v3corepb.HealthStatus_DRAINING}, - Weight: []uint32{828}, - }) - e := clab0.Build() - me, _ := proto.Marshal(e) - return me - }(), - } + var v3EndpointsAny = testutils.MarshalAny(func() *v3endpointpb.ClusterLoadAssignment { + clab0 := newClaBuilder("test", nil) + clab0.addLocality("locality-1", 1, 1, []string{"addr1:314"}, &addLocalityOptions{ + Health: []v3corepb.HealthStatus{v3corepb.HealthStatus_UNHEALTHY}, + Weight: []uint32{271}, + }) + clab0.addLocality("locality-2", 1, 0, []string{"addr2:159"}, &addLocalityOptions{ + Health: []v3corepb.HealthStatus{v3corepb.HealthStatus_DRAINING}, + Weight: []uint32{828}, + }) + return clab0.Build() + }()) const testVersion = "test-version-eds" tests := []struct { name string resources []*anypb.Any - wantUpdate map[string]EndpointsUpdate + wantUpdate map[string]EndpointsUpdateErrTuple wantMD UpdateMetadata wantErr bool }{ @@ -154,7 +150,7 @@ func (s) TestUnmarshalEndpoints(t *testing.T) { Version: testVersion, ErrState: &UpdateErrorMetadata{ Version: testVersion, - Err: errPlaceHolder, + Err: cmpopts.AnyError, }, }, wantErr: true, @@ -172,33 +168,26 @@ func (s) TestUnmarshalEndpoints(t *testing.T) { Version: testVersion, ErrState: &UpdateErrorMetadata{ Version: testVersion, - Err: errPlaceHolder, + Err: cmpopts.AnyError, }, }, wantErr: true, }, { name: "bad endpoints resource", - resources: []*anypb.Any{ - { - TypeUrl: version.V3EndpointsURL, - Value: func() []byte { - clab0 := newClaBuilder("test", nil) - clab0.addLocality("locality-1", 1, 0, []string{"addr1:314"}, nil) - clab0.addLocality("locality-2", 1, 2, []string{"addr2:159"}, nil) - e := clab0.Build() - me, _ := proto.Marshal(e) - return me - }(), - }, - }, - wantUpdate: map[string]EndpointsUpdate{"test": {}}, + resources: []*anypb.Any{testutils.MarshalAny(func() *v3endpointpb.ClusterLoadAssignment { + clab0 := newClaBuilder("test", nil) + clab0.addLocality("locality-1", 1, 0, []string{"addr1:314"}, nil) + clab0.addLocality("locality-2", 1, 2, []string{"addr2:159"}, nil) + return clab0.Build() + }())}, + wantUpdate: map[string]EndpointsUpdateErrTuple{"test": {Err: cmpopts.AnyError}}, wantMD: UpdateMetadata{ Status: ServiceStatusNACKed, Version: testVersion, ErrState: &UpdateErrorMetadata{ Version: testVersion, - Err: errPlaceHolder, + Err: cmpopts.AnyError, }, }, wantErr: true, @@ -206,8 +195,8 @@ func (s) TestUnmarshalEndpoints(t *testing.T) { { name: "v3 endpoints", resources: []*anypb.Any{v3EndpointsAny}, - wantUpdate: map[string]EndpointsUpdate{ - "test": { + wantUpdate: map[string]EndpointsUpdateErrTuple{ + "test": {Update: EndpointsUpdate{ Drops: nil, Localities: []Locality{ { @@ -232,7 +221,7 @@ func (s) TestUnmarshalEndpoints(t *testing.T) { }, }, Raw: v3EndpointsAny, - }, + }}, }, wantMD: UpdateMetadata{ Status: ServiceStatusACKed, @@ -244,21 +233,15 @@ func (s) TestUnmarshalEndpoints(t *testing.T) { name: "good and bad endpoints", resources: []*anypb.Any{ v3EndpointsAny, - { - // bad endpoints resource - TypeUrl: version.V3EndpointsURL, - Value: func() []byte { - clab0 := newClaBuilder("bad", nil) - clab0.addLocality("locality-1", 1, 0, []string{"addr1:314"}, nil) - clab0.addLocality("locality-2", 1, 2, []string{"addr2:159"}, nil) - e := clab0.Build() - me, _ := proto.Marshal(e) - return me - }(), - }, + testutils.MarshalAny(func() *v3endpointpb.ClusterLoadAssignment { + clab0 := newClaBuilder("bad", nil) + clab0.addLocality("locality-1", 1, 0, []string{"addr1:314"}, nil) + clab0.addLocality("locality-2", 1, 2, []string{"addr2:159"}, nil) + return clab0.Build() + }()), }, - wantUpdate: map[string]EndpointsUpdate{ - "test": { + wantUpdate: map[string]EndpointsUpdateErrTuple{ + "test": {Update: EndpointsUpdate{ Drops: nil, Localities: []Locality{ { @@ -283,15 +266,15 @@ func (s) TestUnmarshalEndpoints(t *testing.T) { }, }, Raw: v3EndpointsAny, - }, - "bad": {}, + }}, + "bad": {Err: cmpopts.AnyError}, }, wantMD: UpdateMetadata{ Status: ServiceStatusNACKed, Version: testVersion, ErrState: &UpdateErrorMetadata{ Version: testVersion, - Err: errPlaceHolder, + Err: cmpopts.AnyError, }, }, wantErr: true, @@ -299,9 +282,13 @@ func (s) TestUnmarshalEndpoints(t *testing.T) { } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - update, md, err := UnmarshalEndpoints(testVersion, test.resources, nil) + opts := &UnmarshalOptions{ + Version: testVersion, + Resources: test.resources, + } + update, md, err := UnmarshalEndpoints(opts) if (err != nil) != test.wantErr { - t.Fatalf("UnmarshalEndpoints(), got err: %v, wantErr: %v", err, test.wantErr) + t.Fatalf("UnmarshalEndpoints(%+v), got err: %v, wantErr: %v", opts, err, test.wantErr) } if diff := cmp.Diff(update, test.wantUpdate, cmpOpts); diff != "" { t.Errorf("got unexpected update, diff (-got +want): %v", diff) diff --git a/xds/internal/client/errors.go b/xds/internal/xdsclient/errors.go similarity index 98% rename from xds/internal/client/errors.go rename to xds/internal/xdsclient/errors.go index 34ae2738db0..4d6cdaaf9b4 100644 --- a/xds/internal/client/errors.go +++ b/xds/internal/xdsclient/errors.go @@ -16,7 +16,7 @@ * */ -package client +package xdsclient import "fmt" diff --git a/xds/internal/xdsclient/filter_chain.go b/xds/internal/xdsclient/filter_chain.go new file mode 100644 index 00000000000..f2b29f52a44 --- /dev/null +++ b/xds/internal/xdsclient/filter_chain.go @@ -0,0 +1,852 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package xdsclient + +import ( + "errors" + "fmt" + "net" + + v3listenerpb "github.com/envoyproxy/go-control-plane/envoy/config/listener/v3" + v3httppb "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/network/http_connection_manager/v3" + v3tlspb "github.com/envoyproxy/go-control-plane/envoy/extensions/transport_sockets/tls/v3" + "github.com/golang/protobuf/proto" + "github.com/golang/protobuf/ptypes" + "google.golang.org/grpc/internal/resolver" + "google.golang.org/grpc/internal/xds/env" + "google.golang.org/grpc/xds/internal/httpfilter" + "google.golang.org/grpc/xds/internal/version" +) + +const ( + // Used as the map key for unspecified prefixes. The actual value of this + // key is immaterial. + unspecifiedPrefixMapKey = "unspecified" + + // An unspecified destination or source prefix should be considered a less + // specific match than a wildcard prefix, `0.0.0.0/0` or `::/0`. Also, an + // unspecified prefix should match most v4 and v6 addresses compared to the + // wildcard prefixes which match only a specific network (v4 or v6). + // + // We use these constants when looking up the most specific prefix match. A + // wildcard prefix will match 0 bits, and to make sure that a wildcard + // prefix is considered a more specific match than an unspecified prefix, we + // use a value of -1 for the latter. + noPrefixMatch = -2 + unspecifiedPrefixMatch = -1 +) + +// FilterChain captures information from within a FilterChain message in a +// Listener resource. +type FilterChain struct { + // SecurityCfg contains transport socket security configuration. + SecurityCfg *SecurityConfig + // HTTPFilters represent the HTTP Filters that comprise this FilterChain. + HTTPFilters []HTTPFilter + // RouteConfigName is the route configuration name for this FilterChain. + // + // Only one of RouteConfigName and InlineRouteConfig is set. + RouteConfigName string + // InlineRouteConfig is the inline route configuration (RDS response) + // returned for this filter chain. + // + // Only one of RouteConfigName and InlineRouteConfig is set. + InlineRouteConfig *RouteConfigUpdate +} + +// VirtualHostWithInterceptors captures information present in a VirtualHost +// update, and also contains routes with instantiated HTTP Filters. +type VirtualHostWithInterceptors struct { + // Domains are the domain names which map to this Virtual Host. On the + // server side, this will be dictated by the :authority header of the + // incoming RPC. + Domains []string + // Routes are the Routes for this Virtual Host. + Routes []RouteWithInterceptors +} + +// RouteWithInterceptors captures information in a Route, and contains +// a usable matcher and also instantiated HTTP Filters. +type RouteWithInterceptors struct { + // M is the matcher used to match to this route. + M *CompositeMatcher + // RouteAction is the type of routing action to initiate once matched to. + RouteAction RouteAction + // Interceptors are interceptors instantiated for this route. These will be + // constructed from a combination of the top level configuration and any + // HTTP Filter overrides present in Virtual Host or Route. + Interceptors []resolver.ServerInterceptor +} + +// ConstructUsableRouteConfiguration takes Route Configuration and converts it +// into matchable route configuration, with instantiated HTTP Filters per route. +func (f *FilterChain) ConstructUsableRouteConfiguration(config RouteConfigUpdate) ([]VirtualHostWithInterceptors, error) { + vhs := make([]VirtualHostWithInterceptors, len(config.VirtualHosts)) + for _, vh := range config.VirtualHosts { + vhwi, err := f.convertVirtualHost(vh) + if err != nil { + return nil, fmt.Errorf("virtual host construction: %v", err) + } + vhs = append(vhs, vhwi) + } + return vhs, nil +} + +func (f *FilterChain) convertVirtualHost(virtualHost *VirtualHost) (VirtualHostWithInterceptors, error) { + rs := make([]RouteWithInterceptors, len(virtualHost.Routes)) + for i, r := range virtualHost.Routes { + var err error + rs[i].RouteAction = r.RouteAction + rs[i].M, err = RouteToMatcher(r) + if err != nil { + return VirtualHostWithInterceptors{}, fmt.Errorf("matcher construction: %v", err) + } + for _, filter := range f.HTTPFilters { + // Route is highest priority on server side, as there is no concept + // of an upstream cluster on server side. + override := r.HTTPFilterConfigOverride[filter.Name] + if override == nil { + // Virtual Host is second priority. + override = virtualHost.HTTPFilterConfigOverride[filter.Name] + } + sb, ok := filter.Filter.(httpfilter.ServerInterceptorBuilder) + if !ok { + // Should not happen if it passed xdsClient validation. + return VirtualHostWithInterceptors{}, fmt.Errorf("filter does not support use in server") + } + si, err := sb.BuildServerInterceptor(filter.Config, override) + if err != nil { + return VirtualHostWithInterceptors{}, fmt.Errorf("filter construction: %v", err) + } + if si != nil { + rs[i].Interceptors = append(rs[i].Interceptors, si) + } + } + } + return VirtualHostWithInterceptors{Domains: virtualHost.Domains, Routes: rs}, nil +} + +// SourceType specifies the connection source IP match type. +type SourceType int + +const ( + // SourceTypeAny matches connection attempts from any source. + SourceTypeAny SourceType = iota + // SourceTypeSameOrLoopback matches connection attempts from the same host. + SourceTypeSameOrLoopback + // SourceTypeExternal matches connection attempts from a different host. + SourceTypeExternal +) + +// FilterChainManager contains all the match criteria specified through all +// filter chains in a single Listener resource. It also contains the default +// filter chain specified in the Listener resource. It provides two important +// pieces of functionality: +// 1. Validate the filter chains in an incoming Listener resource to make sure +// that there aren't filter chains which contain the same match criteria. +// 2. As part of performing the above validation, it builds an internal data +// structure which will if used to look up the matching filter chain at +// connection time. +// +// The logic specified in the documentation around the xDS FilterChainMatch +// proto mentions 8 criteria to match on. +// The following order applies: +// +// 1. Destination port. +// 2. Destination IP address. +// 3. Server name (e.g. SNI for TLS protocol), +// 4. Transport protocol. +// 5. Application protocols (e.g. ALPN for TLS protocol). +// 6. Source type (e.g. any, local or external network). +// 7. Source IP address. +// 8. Source port. +type FilterChainManager struct { + // Destination prefix is the first match criteria that we support. + // Therefore, this multi-stage map is indexed on destination prefixes + // specified in the match criteria. + // Unspecified destination prefix matches end up as a wildcard entry here + // with a key of 0.0.0.0/0. + dstPrefixMap map[string]*destPrefixEntry + + // At connection time, we do not have the actual destination prefix to match + // on. We only have the real destination address of the incoming connection. + // This means that we cannot use the above map at connection time. This list + // contains the map entries from the above map that we can use at connection + // time to find matching destination prefixes in O(n) time. + // + // TODO: Implement LC-trie to support logarithmic time lookups. If that + // involves too much time/effort, sort this slice based on the netmask size. + dstPrefixes []*destPrefixEntry + + def *FilterChain // Default filter chain, if specified. + + // RouteConfigNames are the route configuration names which need to be + // dynamically queried for RDS Configuration for any FilterChains which + // specify to load RDS Configuration dynamically. + RouteConfigNames map[string]bool +} + +// destPrefixEntry is the value type of the map indexed on destination prefixes. +type destPrefixEntry struct { + // The actual destination prefix. Set to nil for unspecified prefixes. + net *net.IPNet + // We need to keep track of the transport protocols seen as part of the + // config validation (and internal structure building) phase. The only two + // values that we support are empty string and "raw_buffer", with the latter + // taking preference. Once we have seen one filter chain with "raw_buffer", + // we can drop everything other filter chain with an empty transport + // protocol. + rawBufferSeen bool + // For each specified source type in the filter chain match criteria, this + // array points to the set of specified source prefixes. + // Unspecified source type matches end up as a wildcard entry here with an + // index of 0, which actually represents the source type `ANY`. + srcTypeArr sourceTypesArray +} + +// An array for the fixed number of source types that we have. +type sourceTypesArray [3]*sourcePrefixes + +// sourcePrefixes contains source prefix related information specified in the +// match criteria. These are pointed to by the array of source types. +type sourcePrefixes struct { + // These are very similar to the 'dstPrefixMap' and 'dstPrefixes' field of + // FilterChainManager. Go there for more info. + srcPrefixMap map[string]*sourcePrefixEntry + srcPrefixes []*sourcePrefixEntry +} + +// sourcePrefixEntry contains match criteria per source prefix. +type sourcePrefixEntry struct { + // The actual destination prefix. Set to nil for unspecified prefixes. + net *net.IPNet + // Mapping from source ports specified in the match criteria to the actual + // filter chain. Unspecified source port matches en up as a wildcard entry + // here with a key of 0. + srcPortMap map[int]*FilterChain +} + +// NewFilterChainManager parses the received Listener resource and builds a +// FilterChainManager. Returns a non-nil error on validation failures. +// +// This function is only exported so that tests outside of this package can +// create a FilterChainManager. +func NewFilterChainManager(lis *v3listenerpb.Listener) (*FilterChainManager, error) { + // Parse all the filter chains and build the internal data structures. + fci := &FilterChainManager{ + dstPrefixMap: make(map[string]*destPrefixEntry), + RouteConfigNames: make(map[string]bool), + } + if err := fci.addFilterChains(lis.GetFilterChains()); err != nil { + return nil, err + } + // Build the source and dest prefix slices used by Lookup(). + fcSeen := false + for _, dstPrefix := range fci.dstPrefixMap { + fci.dstPrefixes = append(fci.dstPrefixes, dstPrefix) + for _, st := range dstPrefix.srcTypeArr { + if st == nil { + continue + } + for _, srcPrefix := range st.srcPrefixMap { + st.srcPrefixes = append(st.srcPrefixes, srcPrefix) + for _, fc := range srcPrefix.srcPortMap { + if fc != nil { + fcSeen = true + } + } + } + } + } + + // Retrieve the default filter chain. The match criteria specified on the + // default filter chain is never used. The default filter chain simply gets + // used when none of the other filter chains match. + var def *FilterChain + if dfc := lis.GetDefaultFilterChain(); dfc != nil { + var err error + if def, err = fci.filterChainFromProto(dfc); err != nil { + return nil, err + } + } + fci.def = def + + // If there are no supported filter chains and no default filter chain, we + // fail here. This will call the Listener resource to be NACK'ed. + if !fcSeen && fci.def == nil { + return nil, fmt.Errorf("no supported filter chains and no default filter chain") + } + return fci, nil +} + +// addFilterChains parses the filter chains in fcs and adds the required +// internal data structures corresponding to the match criteria. +func (fci *FilterChainManager) addFilterChains(fcs []*v3listenerpb.FilterChain) error { + for _, fc := range fcs { + fcm := fc.GetFilterChainMatch() + if fcm.GetDestinationPort().GetValue() != 0 { + // Destination port is the first match criteria and we do not + // support filter chains which contains this match criteria. + logger.Warningf("Dropping filter chain %+v since it contains unsupported destination_port match field", fc) + continue + } + + // Build the internal representation of the filter chain match fields. + if err := fci.addFilterChainsForDestPrefixes(fc); err != nil { + return err + } + } + + return nil +} + +func (fci *FilterChainManager) addFilterChainsForDestPrefixes(fc *v3listenerpb.FilterChain) error { + ranges := fc.GetFilterChainMatch().GetPrefixRanges() + dstPrefixes := make([]*net.IPNet, 0, len(ranges)) + for _, pr := range ranges { + cidr := fmt.Sprintf("%s/%d", pr.GetAddressPrefix(), pr.GetPrefixLen().GetValue()) + _, ipnet, err := net.ParseCIDR(cidr) + if err != nil { + return fmt.Errorf("failed to parse destination prefix range: %+v", pr) + } + dstPrefixes = append(dstPrefixes, ipnet) + } + + if len(dstPrefixes) == 0 { + // Use the unspecified entry when destination prefix is unspecified, and + // set the `net` field to nil. + if fci.dstPrefixMap[unspecifiedPrefixMapKey] == nil { + fci.dstPrefixMap[unspecifiedPrefixMapKey] = &destPrefixEntry{} + } + return fci.addFilterChainsForServerNames(fci.dstPrefixMap[unspecifiedPrefixMapKey], fc) + } + for _, prefix := range dstPrefixes { + p := prefix.String() + if fci.dstPrefixMap[p] == nil { + fci.dstPrefixMap[p] = &destPrefixEntry{net: prefix} + } + if err := fci.addFilterChainsForServerNames(fci.dstPrefixMap[p], fc); err != nil { + return err + } + } + return nil +} + +func (fci *FilterChainManager) addFilterChainsForServerNames(dstEntry *destPrefixEntry, fc *v3listenerpb.FilterChain) error { + // Filter chains specifying server names in their match criteria always fail + // a match at connection time. So, these filter chains can be dropped now. + if len(fc.GetFilterChainMatch().GetServerNames()) != 0 { + logger.Warningf("Dropping filter chain %+v since it contains unsupported server_names match field", fc) + return nil + } + + return fci.addFilterChainsForTransportProtocols(dstEntry, fc) +} + +func (fci *FilterChainManager) addFilterChainsForTransportProtocols(dstEntry *destPrefixEntry, fc *v3listenerpb.FilterChain) error { + tp := fc.GetFilterChainMatch().GetTransportProtocol() + switch { + case tp != "" && tp != "raw_buffer": + // Only allow filter chains with transport protocol set to empty string + // or "raw_buffer". + logger.Warningf("Dropping filter chain %+v since it contains unsupported value for transport_protocols match field", fc) + return nil + case tp == "" && dstEntry.rawBufferSeen: + // If we have already seen filter chains with transport protocol set to + // "raw_buffer", we can drop filter chains with transport protocol set + // to empty string, since the former takes precedence. + logger.Warningf("Dropping filter chain %+v since it contains unsupported value for transport_protocols match field", fc) + return nil + case tp != "" && !dstEntry.rawBufferSeen: + // This is the first "raw_buffer" that we are seeing. Set the bit and + // reset the source types array which might contain entries for filter + // chains with transport protocol set to empty string. + dstEntry.rawBufferSeen = true + dstEntry.srcTypeArr = sourceTypesArray{} + } + return fci.addFilterChainsForApplicationProtocols(dstEntry, fc) +} + +func (fci *FilterChainManager) addFilterChainsForApplicationProtocols(dstEntry *destPrefixEntry, fc *v3listenerpb.FilterChain) error { + if len(fc.GetFilterChainMatch().GetApplicationProtocols()) != 0 { + logger.Warningf("Dropping filter chain %+v since it contains unsupported application_protocols match field", fc) + return nil + } + return fci.addFilterChainsForSourceType(dstEntry, fc) +} + +// addFilterChainsForSourceType adds source types to the internal data +// structures and delegates control to addFilterChainsForSourcePrefixes to +// continue building the internal data structure. +func (fci *FilterChainManager) addFilterChainsForSourceType(dstEntry *destPrefixEntry, fc *v3listenerpb.FilterChain) error { + var srcType SourceType + switch st := fc.GetFilterChainMatch().GetSourceType(); st { + case v3listenerpb.FilterChainMatch_ANY: + srcType = SourceTypeAny + case v3listenerpb.FilterChainMatch_SAME_IP_OR_LOOPBACK: + srcType = SourceTypeSameOrLoopback + case v3listenerpb.FilterChainMatch_EXTERNAL: + srcType = SourceTypeExternal + default: + return fmt.Errorf("unsupported source type: %v", st) + } + + st := int(srcType) + if dstEntry.srcTypeArr[st] == nil { + dstEntry.srcTypeArr[st] = &sourcePrefixes{srcPrefixMap: make(map[string]*sourcePrefixEntry)} + } + return fci.addFilterChainsForSourcePrefixes(dstEntry.srcTypeArr[st].srcPrefixMap, fc) +} + +// addFilterChainsForSourcePrefixes adds source prefixes to the internal data +// structures and delegates control to addFilterChainsForSourcePorts to continue +// building the internal data structure. +func (fci *FilterChainManager) addFilterChainsForSourcePrefixes(srcPrefixMap map[string]*sourcePrefixEntry, fc *v3listenerpb.FilterChain) error { + ranges := fc.GetFilterChainMatch().GetSourcePrefixRanges() + srcPrefixes := make([]*net.IPNet, 0, len(ranges)) + for _, pr := range fc.GetFilterChainMatch().GetSourcePrefixRanges() { + cidr := fmt.Sprintf("%s/%d", pr.GetAddressPrefix(), pr.GetPrefixLen().GetValue()) + _, ipnet, err := net.ParseCIDR(cidr) + if err != nil { + return fmt.Errorf("failed to parse source prefix range: %+v", pr) + } + srcPrefixes = append(srcPrefixes, ipnet) + } + + if len(srcPrefixes) == 0 { + // Use the unspecified entry when destination prefix is unspecified, and + // set the `net` field to nil. + if srcPrefixMap[unspecifiedPrefixMapKey] == nil { + srcPrefixMap[unspecifiedPrefixMapKey] = &sourcePrefixEntry{ + srcPortMap: make(map[int]*FilterChain), + } + } + return fci.addFilterChainsForSourcePorts(srcPrefixMap[unspecifiedPrefixMapKey], fc) + } + for _, prefix := range srcPrefixes { + p := prefix.String() + if srcPrefixMap[p] == nil { + srcPrefixMap[p] = &sourcePrefixEntry{ + net: prefix, + srcPortMap: make(map[int]*FilterChain), + } + } + if err := fci.addFilterChainsForSourcePorts(srcPrefixMap[p], fc); err != nil { + return err + } + } + return nil +} + +// addFilterChainsForSourcePorts adds source ports to the internal data +// structures and completes the process of building the internal data structure. +// It is here that we determine if there are multiple filter chains with +// overlapping matching rules. +func (fci *FilterChainManager) addFilterChainsForSourcePorts(srcEntry *sourcePrefixEntry, fcProto *v3listenerpb.FilterChain) error { + ports := fcProto.GetFilterChainMatch().GetSourcePorts() + srcPorts := make([]int, 0, len(ports)) + for _, port := range ports { + srcPorts = append(srcPorts, int(port)) + } + + fc, err := fci.filterChainFromProto(fcProto) + if err != nil { + return err + } + + if len(srcPorts) == 0 { + // Use the wildcard port '0', when source ports are unspecified. + if curFC := srcEntry.srcPortMap[0]; curFC != nil { + return errors.New("multiple filter chains with overlapping matching rules are defined") + } + srcEntry.srcPortMap[0] = fc + return nil + } + for _, port := range srcPorts { + if curFC := srcEntry.srcPortMap[port]; curFC != nil { + return errors.New("multiple filter chains with overlapping matching rules are defined") + } + srcEntry.srcPortMap[port] = fc + } + return nil +} + +// filterChainFromProto extracts the relevant information from the FilterChain +// proto and stores it in our internal representation. It also persists any +// RouteNames which need to be queried dynamically via RDS. +func (fci *FilterChainManager) filterChainFromProto(fc *v3listenerpb.FilterChain) (*FilterChain, error) { + filterChain, err := processNetworkFilters(fc.GetFilters()) + if err != nil { + return nil, err + } + // These route names will be dynamically queried via RDS in the wrapped + // listener, which receives the LDS response, if specified for the filter + // chain. + if filterChain.RouteConfigName != "" { + fci.RouteConfigNames[filterChain.RouteConfigName] = true + } + // If the transport_socket field is not specified, it means that the control + // plane has not sent us any security config. This is fine and the server + // will use the fallback credentials configured as part of the + // xdsCredentials. + ts := fc.GetTransportSocket() + if ts == nil { + return filterChain, nil + } + if name := ts.GetName(); name != transportSocketName { + return nil, fmt.Errorf("transport_socket field has unexpected name: %s", name) + } + any := ts.GetTypedConfig() + if any == nil || any.TypeUrl != version.V3DownstreamTLSContextURL { + return nil, fmt.Errorf("transport_socket field has unexpected typeURL: %s", any.TypeUrl) + } + downstreamCtx := &v3tlspb.DownstreamTlsContext{} + if err := proto.Unmarshal(any.GetValue(), downstreamCtx); err != nil { + return nil, fmt.Errorf("failed to unmarshal DownstreamTlsContext in LDS response: %v", err) + } + if downstreamCtx.GetRequireSni().GetValue() { + return nil, fmt.Errorf("require_sni field set to true in DownstreamTlsContext message: %v", downstreamCtx) + } + if downstreamCtx.GetOcspStaplePolicy() != v3tlspb.DownstreamTlsContext_LENIENT_STAPLING { + return nil, fmt.Errorf("ocsp_staple_policy field set to unsupported value in DownstreamTlsContext message: %v", downstreamCtx) + } + // The following fields from `DownstreamTlsContext` are ignore: + // - disable_stateless_session_resumption + // - session_ticket_keys + // - session_ticket_keys_sds_secret_config + // - session_timeout + if downstreamCtx.GetCommonTlsContext() == nil { + return nil, errors.New("DownstreamTlsContext in LDS response does not contain a CommonTlsContext") + } + sc, err := securityConfigFromCommonTLSContext(downstreamCtx.GetCommonTlsContext(), true) + if err != nil { + return nil, err + } + if sc == nil { + // sc == nil is a valid case where the control plane has not sent us any + // security configuration. xDS creds will use fallback creds. + return filterChain, nil + } + sc.RequireClientCert = downstreamCtx.GetRequireClientCertificate().GetValue() + if sc.RequireClientCert && sc.RootInstanceName == "" { + return nil, errors.New("security configuration on the server-side does not contain root certificate provider instance name, but require_client_cert field is set") + } + filterChain.SecurityCfg = sc + return filterChain, nil +} + +func processNetworkFilters(filters []*v3listenerpb.Filter) (*FilterChain, error) { + filterChain := &FilterChain{} + seenNames := make(map[string]bool, len(filters)) + seenHCM := false + for _, filter := range filters { + name := filter.GetName() + if name == "" { + return nil, fmt.Errorf("network filters {%+v} is missing name field in filter: {%+v}", filters, filter) + } + if seenNames[name] { + return nil, fmt.Errorf("network filters {%+v} has duplicate filter name %q", filters, name) + } + seenNames[name] = true + + // Network filters have a oneof field named `config_type` where we + // only support `TypedConfig` variant. + switch typ := filter.GetConfigType().(type) { + case *v3listenerpb.Filter_TypedConfig: + // The typed_config field has an `anypb.Any` proto which could + // directly contain the serialized bytes of the actual filter + // configuration, or it could be encoded as a `TypedStruct`. + // TODO: Add support for `TypedStruct`. + tc := filter.GetTypedConfig() + + // The only network filter that we currently support is the v3 + // HttpConnectionManager. So, we can directly check the type_url + // and unmarshal the config. + // TODO: Implement a registry of supported network filters (like + // we have for HTTP filters), when we have to support network + // filters other than HttpConnectionManager. + if tc.GetTypeUrl() != version.V3HTTPConnManagerURL { + return nil, fmt.Errorf("network filters {%+v} has unsupported network filter %q in filter {%+v}", filters, tc.GetTypeUrl(), filter) + } + hcm := &v3httppb.HttpConnectionManager{} + if err := ptypes.UnmarshalAny(tc, hcm); err != nil { + return nil, fmt.Errorf("network filters {%+v} failed unmarshaling of network filter {%+v}: %v", filters, filter, err) + } + // "Any filters after HttpConnectionManager should be ignored during + // connection processing but still be considered for validity. + // HTTPConnectionManager must have valid http_filters." - A36 + filters, err := processHTTPFilters(hcm.GetHttpFilters(), true) + if err != nil { + return nil, fmt.Errorf("network filters {%+v} had invalid server side HTTP Filters {%+v}: %v", filters, hcm.GetHttpFilters(), err) + } + if !seenHCM { + // Validate for RBAC in only the HCM that will be used, since this isn't a logical validation failure, + // it's simply a validation to support RBAC HTTP Filter. + // "HttpConnectionManager.xff_num_trusted_hops must be unset or zero and + // HttpConnectionManager.original_ip_detection_extensions must be empty. If + // either field has an incorrect value, the Listener must be NACKed." - A41 + if hcm.XffNumTrustedHops != 0 { + return nil, fmt.Errorf("xff_num_trusted_hops must be unset or zero %+v", hcm) + } + if len(hcm.OriginalIpDetectionExtensions) != 0 { + return nil, fmt.Errorf("original_ip_detection_extensions must be empty %+v", hcm) + } + + // TODO: Implement terminal filter logic, as per A36. + filterChain.HTTPFilters = filters + seenHCM = true + if !env.RBACSupport { + continue + } + switch hcm.RouteSpecifier.(type) { + case *v3httppb.HttpConnectionManager_Rds: + if hcm.GetRds().GetConfigSource().GetAds() == nil { + return nil, fmt.Errorf("ConfigSource is not ADS: %+v", hcm) + } + name := hcm.GetRds().GetRouteConfigName() + if name == "" { + return nil, fmt.Errorf("empty route_config_name: %+v", hcm) + } + filterChain.RouteConfigName = name + case *v3httppb.HttpConnectionManager_RouteConfig: + // "RouteConfiguration validation logic inherits all + // previous validations made for client-side usage as RDS + // does not distinguish between client-side and + // server-side." - A36 + // Can specify v3 here, as will never get to this function + // if v2. + routeU, err := generateRDSUpdateFromRouteConfiguration(hcm.GetRouteConfig(), nil, false) + if err != nil { + return nil, fmt.Errorf("failed to parse inline RDS resp: %v", err) + } + filterChain.InlineRouteConfig = &routeU + case nil: + // No-op, as no route specifier is a valid configuration on + // the server side. + default: + return nil, fmt.Errorf("unsupported type %T for RouteSpecifier", hcm.RouteSpecifier) + } + } + default: + return nil, fmt.Errorf("network filters {%+v} has unsupported config_type %T in filter %s", filters, typ, filter.GetName()) + } + } + if !seenHCM { + return nil, fmt.Errorf("network filters {%+v} missing HttpConnectionManager filter", filters) + } + return filterChain, nil +} + +// FilterChainLookupParams wraps parameters to be passed to Lookup. +type FilterChainLookupParams struct { + // IsUnspecified indicates whether the server is listening on a wildcard + // address, "0.0.0.0" for IPv4 and "::" for IPv6. Only when this is set to + // true, do we consider the destination prefixes specified in the filter + // chain match criteria. + IsUnspecifiedListener bool + // DestAddr is the local address of an incoming connection. + DestAddr net.IP + // SourceAddr is the remote address of an incoming connection. + SourceAddr net.IP + // SourcePort is the remote port of an incoming connection. + SourcePort int +} + +// Lookup returns the most specific matching filter chain to be used for an +// incoming connection on the server side. +// +// Returns a non-nil error if no matching filter chain could be found or +// multiple matching filter chains were found, and in both cases, the incoming +// connection must be dropped. +func (fci *FilterChainManager) Lookup(params FilterChainLookupParams) (*FilterChain, error) { + dstPrefixes := filterByDestinationPrefixes(fci.dstPrefixes, params.IsUnspecifiedListener, params.DestAddr) + if len(dstPrefixes) == 0 { + if fci.def != nil { + return fci.def, nil + } + return nil, fmt.Errorf("no matching filter chain based on destination prefix match for %+v", params) + } + + srcType := SourceTypeExternal + if params.SourceAddr.Equal(params.DestAddr) || params.SourceAddr.IsLoopback() { + srcType = SourceTypeSameOrLoopback + } + srcPrefixes := filterBySourceType(dstPrefixes, srcType) + if len(srcPrefixes) == 0 { + if fci.def != nil { + return fci.def, nil + } + return nil, fmt.Errorf("no matching filter chain based on source type match for %+v", params) + } + srcPrefixEntry, err := filterBySourcePrefixes(srcPrefixes, params.SourceAddr) + if err != nil { + return nil, err + } + if fc := filterBySourcePorts(srcPrefixEntry, params.SourcePort); fc != nil { + return fc, nil + } + if fci.def != nil { + return fci.def, nil + } + return nil, fmt.Errorf("no matching filter chain after all match criteria for %+v", params) +} + +// filterByDestinationPrefixes is the first stage of the filter chain +// matching algorithm. It takes the complete set of configured filter chain +// matchers and returns the most specific matchers based on the destination +// prefix match criteria (the prefixes which match the most number of bits). +func filterByDestinationPrefixes(dstPrefixes []*destPrefixEntry, isUnspecified bool, dstAddr net.IP) []*destPrefixEntry { + if !isUnspecified { + // Destination prefix matchers are considered only when the listener is + // bound to the wildcard address. + return dstPrefixes + } + + var matchingDstPrefixes []*destPrefixEntry + maxSubnetMatch := noPrefixMatch + for _, prefix := range dstPrefixes { + if prefix.net != nil && !prefix.net.Contains(dstAddr) { + // Skip prefixes which don't match. + continue + } + // For unspecified prefixes, since we do not store a real net.IPNet + // inside prefix, we do not perform a match. Instead we simply set + // the matchSize to -1, which is less than the matchSize (0) for a + // wildcard prefix. + matchSize := unspecifiedPrefixMatch + if prefix.net != nil { + matchSize, _ = prefix.net.Mask.Size() + } + if matchSize < maxSubnetMatch { + continue + } + if matchSize > maxSubnetMatch { + maxSubnetMatch = matchSize + matchingDstPrefixes = make([]*destPrefixEntry, 0, 1) + } + matchingDstPrefixes = append(matchingDstPrefixes, prefix) + } + return matchingDstPrefixes +} + +// filterBySourceType is the second stage of the matching algorithm. It +// trims the filter chains based on the most specific source type match. +func filterBySourceType(dstPrefixes []*destPrefixEntry, srcType SourceType) []*sourcePrefixes { + var ( + srcPrefixes []*sourcePrefixes + bestSrcTypeMatch int + ) + for _, prefix := range dstPrefixes { + var ( + srcPrefix *sourcePrefixes + match int + ) + switch srcType { + case SourceTypeExternal: + match = int(SourceTypeExternal) + srcPrefix = prefix.srcTypeArr[match] + case SourceTypeSameOrLoopback: + match = int(SourceTypeSameOrLoopback) + srcPrefix = prefix.srcTypeArr[match] + } + if srcPrefix == nil { + match = int(SourceTypeAny) + srcPrefix = prefix.srcTypeArr[match] + } + if match < bestSrcTypeMatch { + continue + } + if match > bestSrcTypeMatch { + bestSrcTypeMatch = match + srcPrefixes = make([]*sourcePrefixes, 0) + } + if srcPrefix != nil { + // The source type array always has 3 entries, but these could be + // nil if the appropriate source type match was not specified. + srcPrefixes = append(srcPrefixes, srcPrefix) + } + } + return srcPrefixes +} + +// filterBySourcePrefixes is the third stage of the filter chain matching +// algorithm. It trims the filter chains based on the source prefix. At most one +// filter chain with the most specific match progress to the next stage. +func filterBySourcePrefixes(srcPrefixes []*sourcePrefixes, srcAddr net.IP) (*sourcePrefixEntry, error) { + var matchingSrcPrefixes []*sourcePrefixEntry + maxSubnetMatch := noPrefixMatch + for _, sp := range srcPrefixes { + for _, prefix := range sp.srcPrefixes { + if prefix.net != nil && !prefix.net.Contains(srcAddr) { + // Skip prefixes which don't match. + continue + } + // For unspecified prefixes, since we do not store a real net.IPNet + // inside prefix, we do not perform a match. Instead we simply set + // the matchSize to -1, which is less than the matchSize (0) for a + // wildcard prefix. + matchSize := unspecifiedPrefixMatch + if prefix.net != nil { + matchSize, _ = prefix.net.Mask.Size() + } + if matchSize < maxSubnetMatch { + continue + } + if matchSize > maxSubnetMatch { + maxSubnetMatch = matchSize + matchingSrcPrefixes = make([]*sourcePrefixEntry, 0, 1) + } + matchingSrcPrefixes = append(matchingSrcPrefixes, prefix) + } + } + if len(matchingSrcPrefixes) == 0 { + // Finding no match is not an error condition. The caller will end up + // using the default filter chain if one was configured. + return nil, nil + } + // We expect at most a single matching source prefix entry at this point. If + // we have multiple entries here, and some of their source port matchers had + // wildcard entries, we could be left with more than one matching filter + // chain and hence would have been flagged as an invalid configuration at + // config validation time. + if len(matchingSrcPrefixes) != 1 { + return nil, errors.New("multiple matching filter chains") + } + return matchingSrcPrefixes[0], nil +} + +// filterBySourcePorts is the last stage of the filter chain matching +// algorithm. It trims the filter chains based on the source ports. +func filterBySourcePorts(spe *sourcePrefixEntry, srcPort int) *FilterChain { + if spe == nil { + return nil + } + // A match could be a wildcard match (this happens when the match + // criteria does not specify source ports) or a specific port match (this + // happens when the match criteria specifies a set of ports and the source + // port of the incoming connection matches one of the specified ports). The + // latter is considered to be a more specific match. + if fc := spe.srcPortMap[srcPort]; fc != nil { + return fc + } + if fc := spe.srcPortMap[0]; fc != nil { + return fc + } + return nil +} diff --git a/xds/internal/xdsclient/filter_chain_test.go b/xds/internal/xdsclient/filter_chain_test.go new file mode 100644 index 00000000000..2cc73b0a511 --- /dev/null +++ b/xds/internal/xdsclient/filter_chain_test.go @@ -0,0 +1,2939 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package xdsclient + +import ( + "context" + "errors" + "fmt" + "net" + "strings" + "testing" + + v3corepb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" + v3listenerpb "github.com/envoyproxy/go-control-plane/envoy/config/listener/v3" + v3routepb "github.com/envoyproxy/go-control-plane/envoy/config/route/v3" + v3routerpb "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/http/router/v3" + v3httppb "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/network/http_connection_manager/v3" + v3tlspb "github.com/envoyproxy/go-control-plane/envoy/extensions/transport_sockets/tls/v3" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "google.golang.org/protobuf/testing/protocmp" + "google.golang.org/protobuf/types/known/anypb" + "google.golang.org/protobuf/types/known/wrapperspb" + + iresolver "google.golang.org/grpc/internal/resolver" + "google.golang.org/grpc/internal/testutils" + "google.golang.org/grpc/internal/xds/env" + "google.golang.org/grpc/xds/internal/httpfilter" + "google.golang.org/grpc/xds/internal/httpfilter/router" + "google.golang.org/grpc/xds/internal/testutils/e2e" + "google.golang.org/grpc/xds/internal/version" +) + +const ( + topLevel = "top level" + vhLevel = "virtual host level" + rLevel = "route level" +) + +var ( + routeConfig = &v3routepb.RouteConfiguration{ + Name: "routeName", + VirtualHosts: []*v3routepb.VirtualHost{{ + Domains: []string{"lds.target.good:3333"}, + Routes: []*v3routepb.Route{{ + Match: &v3routepb.RouteMatch{ + PathSpecifier: &v3routepb.RouteMatch_Prefix{Prefix: "/"}, + }, + Action: &v3routepb.Route_NonForwardingAction{}, + }}}}} + inlineRouteConfig = &RouteConfigUpdate{ + VirtualHosts: []*VirtualHost{{ + Domains: []string{"lds.target.good:3333"}, + Routes: []*Route{{Prefix: newStringP("/"), RouteAction: RouteActionNonForwardingAction}}, + }}} + emptyValidNetworkFilters = []*v3listenerpb.Filter{ + { + Name: "filter-1", + ConfigType: &v3listenerpb.Filter_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3httppb.HttpConnectionManager{ + RouteSpecifier: &v3httppb.HttpConnectionManager_RouteConfig{ + RouteConfig: routeConfig, + }, + HttpFilters: []*v3httppb.HttpFilter{emptyRouterFilter}, + }), + }, + }, + } + validServerSideHTTPFilter1 = &v3httppb.HttpFilter{ + Name: "serverOnlyCustomFilter", + ConfigType: &v3httppb.HttpFilter_TypedConfig{TypedConfig: serverOnlyCustomFilterConfig}, + } + validServerSideHTTPFilter2 = &v3httppb.HttpFilter{ + Name: "serverOnlyCustomFilter2", + ConfigType: &v3httppb.HttpFilter_TypedConfig{TypedConfig: serverOnlyCustomFilterConfig}, + } + emptyRouterFilter = e2e.RouterHTTPFilter + routerBuilder = httpfilter.Get(router.TypeURL) + routerConfig, _ = routerBuilder.ParseFilterConfig(testutils.MarshalAny(&v3routerpb.Router{})) + routerFilter = HTTPFilter{Name: "router", Filter: routerBuilder, Config: routerConfig} + routerFilterList = []HTTPFilter{routerFilter} +) + +// TestNewFilterChainImpl_Failure_BadMatchFields verifies cases where we have a +// single filter chain with match criteria that contains unsupported fields. +func TestNewFilterChainImpl_Failure_BadMatchFields(t *testing.T) { + tests := []struct { + desc string + lis *v3listenerpb.Listener + }{ + { + desc: "unsupported destination port field", + lis: &v3listenerpb.Listener{ + FilterChains: []*v3listenerpb.FilterChain{ + { + FilterChainMatch: &v3listenerpb.FilterChainMatch{DestinationPort: &wrapperspb.UInt32Value{Value: 666}}, + }, + }, + }, + }, + { + desc: "unsupported server names field", + lis: &v3listenerpb.Listener{ + FilterChains: []*v3listenerpb.FilterChain{ + { + FilterChainMatch: &v3listenerpb.FilterChainMatch{ServerNames: []string{"example-server"}}, + }, + }, + }, + }, + { + desc: "unsupported transport protocol ", + lis: &v3listenerpb.Listener{ + FilterChains: []*v3listenerpb.FilterChain{ + { + FilterChainMatch: &v3listenerpb.FilterChainMatch{TransportProtocol: "tls"}, + }, + }, + }, + }, + { + desc: "unsupported application protocol field", + lis: &v3listenerpb.Listener{ + FilterChains: []*v3listenerpb.FilterChain{ + { + FilterChainMatch: &v3listenerpb.FilterChainMatch{ApplicationProtocols: []string{"h2"}}, + }, + }, + }, + }, + { + desc: "bad dest address prefix", + lis: &v3listenerpb.Listener{ + FilterChains: []*v3listenerpb.FilterChain{ + { + FilterChainMatch: &v3listenerpb.FilterChainMatch{PrefixRanges: []*v3corepb.CidrRange{{AddressPrefix: "a.b.c.d"}}}, + }, + }, + }, + }, + { + desc: "bad dest prefix length", + lis: &v3listenerpb.Listener{ + FilterChains: []*v3listenerpb.FilterChain{ + { + FilterChainMatch: &v3listenerpb.FilterChainMatch{PrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("10.1.1.0", 50)}}, + }, + }, + }, + }, + { + desc: "bad source address prefix", + lis: &v3listenerpb.Listener{ + FilterChains: []*v3listenerpb.FilterChain{ + { + FilterChainMatch: &v3listenerpb.FilterChainMatch{SourcePrefixRanges: []*v3corepb.CidrRange{{AddressPrefix: "a.b.c.d"}}}, + }, + }, + }, + }, + { + desc: "bad source prefix length", + lis: &v3listenerpb.Listener{ + FilterChains: []*v3listenerpb.FilterChain{ + { + FilterChainMatch: &v3listenerpb.FilterChainMatch{SourcePrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("10.1.1.0", 50)}}, + }, + }, + }, + }, + } + + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + if fci, err := NewFilterChainManager(test.lis); err == nil { + t.Fatalf("NewFilterChainManager() returned %v when expected to fail", fci) + } + }) + } +} + +// TestNewFilterChainImpl_Failure_OverlappingMatchingRules verifies cases where +// there are multiple filter chains and they have overlapping match rules. +func TestNewFilterChainImpl_Failure_OverlappingMatchingRules(t *testing.T) { + tests := []struct { + desc string + lis *v3listenerpb.Listener + }{ + { + desc: "matching destination prefixes with no other matchers", + lis: &v3listenerpb.Listener{ + FilterChains: []*v3listenerpb.FilterChain{ + { + FilterChainMatch: &v3listenerpb.FilterChainMatch{ + PrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("192.168.1.1", 16), cidrRangeFromAddressAndPrefixLen("10.0.0.0", 0)}, + }, + Filters: emptyValidNetworkFilters, + }, + { + FilterChainMatch: &v3listenerpb.FilterChainMatch{ + PrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("192.168.2.2", 16)}, + }, + Filters: emptyValidNetworkFilters, + }, + }, + }, + }, + { + desc: "matching source type", + lis: &v3listenerpb.Listener{ + FilterChains: []*v3listenerpb.FilterChain{ + { + FilterChainMatch: &v3listenerpb.FilterChainMatch{SourceType: v3listenerpb.FilterChainMatch_ANY}, + Filters: emptyValidNetworkFilters, + }, + { + FilterChainMatch: &v3listenerpb.FilterChainMatch{SourceType: v3listenerpb.FilterChainMatch_SAME_IP_OR_LOOPBACK}, + Filters: emptyValidNetworkFilters, + }, + { + FilterChainMatch: &v3listenerpb.FilterChainMatch{SourceType: v3listenerpb.FilterChainMatch_EXTERNAL}, + Filters: emptyValidNetworkFilters, + }, + { + FilterChainMatch: &v3listenerpb.FilterChainMatch{SourceType: v3listenerpb.FilterChainMatch_EXTERNAL}, + Filters: emptyValidNetworkFilters, + }, + }, + }, + }, + { + desc: "matching source prefixes", + lis: &v3listenerpb.Listener{ + FilterChains: []*v3listenerpb.FilterChain{ + { + FilterChainMatch: &v3listenerpb.FilterChainMatch{ + SourcePrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("192.168.1.1", 16), cidrRangeFromAddressAndPrefixLen("10.0.0.0", 0)}, + }, + Filters: emptyValidNetworkFilters, + }, + { + FilterChainMatch: &v3listenerpb.FilterChainMatch{ + SourcePrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("192.168.2.2", 16)}, + }, + Filters: emptyValidNetworkFilters, + }, + }, + }, + }, + { + desc: "matching source ports", + lis: &v3listenerpb.Listener{ + FilterChains: []*v3listenerpb.FilterChain{ + { + FilterChainMatch: &v3listenerpb.FilterChainMatch{SourcePorts: []uint32{1, 2, 3, 4, 5}}, + Filters: emptyValidNetworkFilters, + }, + { + FilterChainMatch: &v3listenerpb.FilterChainMatch{}, + Filters: emptyValidNetworkFilters, + }, + { + FilterChainMatch: &v3listenerpb.FilterChainMatch{SourcePorts: []uint32{5, 6, 7}}, + Filters: emptyValidNetworkFilters, + }, + }, + }, + }, + } + + const wantErr = "multiple filter chains with overlapping matching rules are defined" + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + if _, err := NewFilterChainManager(test.lis); err == nil || !strings.Contains(err.Error(), wantErr) { + t.Fatalf("NewFilterChainManager() returned err: %v, wantErr: %s", err, wantErr) + } + }) + } +} + +// TestNewFilterChainImpl_Failure_BadSecurityConfig verifies cases where the +// security configuration in the filter chain is invalid. +func TestNewFilterChainImpl_Failure_BadSecurityConfig(t *testing.T) { + tests := []struct { + desc string + lis *v3listenerpb.Listener + wantErr string + }{ + { + desc: "no filter chains", + lis: &v3listenerpb.Listener{}, + wantErr: "no supported filter chains and no default filter chain", + }, + { + desc: "unexpected transport socket name", + lis: &v3listenerpb.Listener{ + FilterChains: []*v3listenerpb.FilterChain{ + { + TransportSocket: &v3corepb.TransportSocket{Name: "unsupported-transport-socket-name"}, + Filters: emptyValidNetworkFilters, + }, + }, + }, + wantErr: "transport_socket field has unexpected name", + }, + { + desc: "unexpected transport socket URL", + lis: &v3listenerpb.Listener{ + FilterChains: []*v3listenerpb.FilterChain{ + { + TransportSocket: &v3corepb.TransportSocket{ + Name: "envoy.transport_sockets.tls", + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3tlspb.UpstreamTlsContext{}), + }, + }, + Filters: emptyValidNetworkFilters, + }, + }, + }, + wantErr: "transport_socket field has unexpected typeURL", + }, + { + desc: "badly marshaled transport socket", + lis: &v3listenerpb.Listener{ + FilterChains: []*v3listenerpb.FilterChain{ + { + TransportSocket: &v3corepb.TransportSocket{ + Name: "envoy.transport_sockets.tls", + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: &anypb.Any{ + TypeUrl: version.V3DownstreamTLSContextURL, + Value: []byte{1, 2, 3, 4}, + }, + }, + }, + Filters: emptyValidNetworkFilters, + }, + }, + }, + wantErr: "failed to unmarshal DownstreamTlsContext in LDS response", + }, + { + desc: "missing CommonTlsContext", + lis: &v3listenerpb.Listener{ + FilterChains: []*v3listenerpb.FilterChain{ + { + TransportSocket: &v3corepb.TransportSocket{ + Name: "envoy.transport_sockets.tls", + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3tlspb.DownstreamTlsContext{}), + }, + }, + Filters: emptyValidNetworkFilters, + }, + }, + }, + wantErr: "DownstreamTlsContext in LDS response does not contain a CommonTlsContext", + }, + { + desc: "require_sni-set-to-true-in-downstreamTlsContext", + lis: &v3listenerpb.Listener{ + FilterChains: []*v3listenerpb.FilterChain{ + { + TransportSocket: &v3corepb.TransportSocket{ + Name: "envoy.transport_sockets.tls", + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3tlspb.DownstreamTlsContext{ + RequireSni: &wrapperspb.BoolValue{Value: true}, + }), + }, + }, + Filters: emptyValidNetworkFilters, + }, + }, + }, + wantErr: "require_sni field set to true in DownstreamTlsContext message", + }, + { + desc: "unsupported-ocsp_staple_policy-in-downstreamTlsContext", + lis: &v3listenerpb.Listener{ + FilterChains: []*v3listenerpb.FilterChain{ + { + TransportSocket: &v3corepb.TransportSocket{ + Name: "envoy.transport_sockets.tls", + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3tlspb.DownstreamTlsContext{ + OcspStaplePolicy: v3tlspb.DownstreamTlsContext_STRICT_STAPLING, + }), + }, + }, + Filters: emptyValidNetworkFilters, + }, + }, + }, + wantErr: "ocsp_staple_policy field set to unsupported value in DownstreamTlsContext message", + }, + { + desc: "unsupported validation context in transport socket", + lis: &v3listenerpb.Listener{ + FilterChains: []*v3listenerpb.FilterChain{ + { + TransportSocket: &v3corepb.TransportSocket{ + Name: "envoy.transport_sockets.tls", + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3tlspb.DownstreamTlsContext{ + CommonTlsContext: &v3tlspb.CommonTlsContext{ + ValidationContextType: &v3tlspb.CommonTlsContext_ValidationContextSdsSecretConfig{ + ValidationContextSdsSecretConfig: &v3tlspb.SdsSecretConfig{ + Name: "foo-sds-secret", + }, + }, + }, + }), + }, + }, + Filters: emptyValidNetworkFilters, + }, + }, + }, + wantErr: "validation context contains unexpected type", + }, + { + desc: "unsupported match_subject_alt_names field in transport socket", + lis: &v3listenerpb.Listener{ + FilterChains: []*v3listenerpb.FilterChain{ + { + TransportSocket: &v3corepb.TransportSocket{ + Name: "envoy.transport_sockets.tls", + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3tlspb.DownstreamTlsContext{ + CommonTlsContext: &v3tlspb.CommonTlsContext{ + ValidationContextType: &v3tlspb.CommonTlsContext_ValidationContextSdsSecretConfig{ + ValidationContextSdsSecretConfig: &v3tlspb.SdsSecretConfig{ + Name: "foo-sds-secret", + }, + }, + }, + }), + }, + }, + Filters: emptyValidNetworkFilters, + }, + }, + }, + wantErr: "validation context contains unexpected type", + }, + { + desc: "no root certificate provider with require_client_cert", + lis: &v3listenerpb.Listener{ + FilterChains: []*v3listenerpb.FilterChain{ + { + TransportSocket: &v3corepb.TransportSocket{ + Name: "envoy.transport_sockets.tls", + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3tlspb.DownstreamTlsContext{ + RequireClientCertificate: &wrapperspb.BoolValue{Value: true}, + CommonTlsContext: &v3tlspb.CommonTlsContext{ + TlsCertificateCertificateProviderInstance: &v3tlspb.CommonTlsContext_CertificateProviderInstance{ + InstanceName: "identityPluginInstance", + CertificateName: "identityCertName", + }, + }, + }), + }, + }, + Filters: emptyValidNetworkFilters, + }, + }, + }, + wantErr: "security configuration on the server-side does not contain root certificate provider instance name, but require_client_cert field is set", + }, + { + desc: "no identity certificate provider", + lis: &v3listenerpb.Listener{ + FilterChains: []*v3listenerpb.FilterChain{ + { + TransportSocket: &v3corepb.TransportSocket{ + Name: "envoy.transport_sockets.tls", + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3tlspb.DownstreamTlsContext{ + CommonTlsContext: &v3tlspb.CommonTlsContext{}, + }), + }, + }, + Filters: emptyValidNetworkFilters, + }, + }, + }, + wantErr: "security configuration on the server-side does not contain identity certificate provider instance name", + }, + } + + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + _, err := NewFilterChainManager(test.lis) + if err == nil || !strings.Contains(err.Error(), test.wantErr) { + t.Fatalf("NewFilterChainManager() returned err: %v, wantErr: %s", err, test.wantErr) + } + }) + } +} + +// TestNewFilterChainImpl_Success_RouteUpdate tests the construction of the +// filter chain with valid HTTP Filters present. +func TestNewFilterChainImpl_Success_RouteUpdate(t *testing.T) { + oldRBAC := env.RBACSupport + env.RBACSupport = true + defer func() { + env.RBACSupport = oldRBAC + }() + tests := []struct { + name string + lis *v3listenerpb.Listener + wantFC *FilterChainManager + }{ + { + name: "rds", + lis: &v3listenerpb.Listener{ + FilterChains: []*v3listenerpb.FilterChain{ + { + Name: "filter-chain-1", + Filters: []*v3listenerpb.Filter{ + { + Name: "hcm", + ConfigType: &v3listenerpb.Filter_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3httppb.HttpConnectionManager{ + RouteSpecifier: &v3httppb.HttpConnectionManager_Rds{ + Rds: &v3httppb.Rds{ + ConfigSource: &v3corepb.ConfigSource{ + ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{Ads: &v3corepb.AggregatedConfigSource{}}, + }, + RouteConfigName: "route-1", + }, + }, + HttpFilters: []*v3httppb.HttpFilter{emptyRouterFilter}, + }), + }, + }, + }, + }, + }, + DefaultFilterChain: &v3listenerpb.FilterChain{ + Filters: []*v3listenerpb.Filter{ + { + Name: "hcm", + ConfigType: &v3listenerpb.Filter_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3httppb.HttpConnectionManager{ + RouteSpecifier: &v3httppb.HttpConnectionManager_Rds{ + Rds: &v3httppb.Rds{ + ConfigSource: &v3corepb.ConfigSource{ + ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{Ads: &v3corepb.AggregatedConfigSource{}}, + }, + RouteConfigName: "route-1", + }, + }, + HttpFilters: []*v3httppb.HttpFilter{emptyRouterFilter}, + }), + }, + }, + }, + }, + }, + wantFC: &FilterChainManager{ + dstPrefixMap: map[string]*destPrefixEntry{ + unspecifiedPrefixMapKey: { + srcTypeArr: [3]*sourcePrefixes{ + { + srcPrefixMap: map[string]*sourcePrefixEntry{ + unspecifiedPrefixMapKey: { + srcPortMap: map[int]*FilterChain{ + 0: { + RouteConfigName: "route-1", + HTTPFilters: routerFilterList, + }, + }, + }, + }, + }, + }, + }, + }, + def: &FilterChain{ + RouteConfigName: "route-1", + HTTPFilters: routerFilterList, + }, + RouteConfigNames: map[string]bool{"route-1": true}, + }, + }, + { + name: "inline route config", + lis: &v3listenerpb.Listener{ + FilterChains: []*v3listenerpb.FilterChain{ + { + Name: "filter-chain-1", + Filters: []*v3listenerpb.Filter{ + { + Name: "hcm", + ConfigType: &v3listenerpb.Filter_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3httppb.HttpConnectionManager{ + RouteSpecifier: &v3httppb.HttpConnectionManager_RouteConfig{ + RouteConfig: routeConfig, + }, + HttpFilters: []*v3httppb.HttpFilter{emptyRouterFilter}, + }), + }, + }, + }, + }, + }, + DefaultFilterChain: &v3listenerpb.FilterChain{ + Filters: []*v3listenerpb.Filter{ + { + Name: "hcm", + ConfigType: &v3listenerpb.Filter_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3httppb.HttpConnectionManager{ + RouteSpecifier: &v3httppb.HttpConnectionManager_RouteConfig{ + RouteConfig: routeConfig, + }, + HttpFilters: []*v3httppb.HttpFilter{emptyRouterFilter}, + }), + }, + }, + }, + }, + }, + wantFC: &FilterChainManager{ + dstPrefixMap: map[string]*destPrefixEntry{ + unspecifiedPrefixMapKey: { + srcTypeArr: [3]*sourcePrefixes{ + { + srcPrefixMap: map[string]*sourcePrefixEntry{ + unspecifiedPrefixMapKey: { + srcPortMap: map[int]*FilterChain{ + 0: { + InlineRouteConfig: inlineRouteConfig, + HTTPFilters: routerFilterList, + }, + }, + }, + }, + }, + }, + }, + }, + def: &FilterChain{ + InlineRouteConfig: inlineRouteConfig, + HTTPFilters: routerFilterList, + }, + }, + }, + // two rds tests whether the Filter Chain Manager successfully persists + // the two RDS names that need to be dynamically queried. + { + name: "two rds", + lis: &v3listenerpb.Listener{ + FilterChains: []*v3listenerpb.FilterChain{ + { + Name: "filter-chain-1", + Filters: []*v3listenerpb.Filter{ + { + Name: "hcm", + ConfigType: &v3listenerpb.Filter_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3httppb.HttpConnectionManager{ + RouteSpecifier: &v3httppb.HttpConnectionManager_Rds{ + Rds: &v3httppb.Rds{ + ConfigSource: &v3corepb.ConfigSource{ + ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{Ads: &v3corepb.AggregatedConfigSource{}}, + }, + RouteConfigName: "route-1", + }, + }, + HttpFilters: []*v3httppb.HttpFilter{emptyRouterFilter}, + }), + }, + }, + }, + }, + }, + DefaultFilterChain: &v3listenerpb.FilterChain{ + Filters: []*v3listenerpb.Filter{ + { + Name: "hcm", + ConfigType: &v3listenerpb.Filter_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3httppb.HttpConnectionManager{ + RouteSpecifier: &v3httppb.HttpConnectionManager_Rds{ + Rds: &v3httppb.Rds{ + ConfigSource: &v3corepb.ConfigSource{ + ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{Ads: &v3corepb.AggregatedConfigSource{}}, + }, + RouteConfigName: "route-2", + }, + }, + HttpFilters: []*v3httppb.HttpFilter{emptyRouterFilter}, + }), + }, + }, + }, + }, + }, + wantFC: &FilterChainManager{ + dstPrefixMap: map[string]*destPrefixEntry{ + unspecifiedPrefixMapKey: { + srcTypeArr: [3]*sourcePrefixes{ + { + srcPrefixMap: map[string]*sourcePrefixEntry{ + unspecifiedPrefixMapKey: { + srcPortMap: map[int]*FilterChain{ + 0: { + RouteConfigName: "route-1", + HTTPFilters: routerFilterList, + }, + }, + }, + }, + }, + }, + }, + }, + def: &FilterChain{ + RouteConfigName: "route-2", + HTTPFilters: routerFilterList, + }, + RouteConfigNames: map[string]bool{ + "route-1": true, + "route-2": true, + }, + }, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + gotFC, err := NewFilterChainManager(test.lis) + if err != nil { + t.Fatalf("NewFilterChainManager() returned err: %v, wantErr: nil", err) + } + if !cmp.Equal(gotFC, test.wantFC, cmp.AllowUnexported(FilterChainManager{}, destPrefixEntry{}, sourcePrefixes{}, sourcePrefixEntry{}), cmpOpts) { + t.Fatalf("NewFilterChainManager() returned %+v, want: %+v", gotFC, test.wantFC) + } + }) + } +} + +// TestNewFilterChainImpl_Failure_BadRouteUpdate verifies cases where the Route +// Update in the filter chain are invalid. +func TestNewFilterChainImpl_Failure_BadRouteUpdate(t *testing.T) { + oldRBAC := env.RBACSupport + env.RBACSupport = true + defer func() { + env.RBACSupport = oldRBAC + }() + tests := []struct { + name string + lis *v3listenerpb.Listener + wantErr string + }{ + { + name: "not-ads", + lis: &v3listenerpb.Listener{ + FilterChains: []*v3listenerpb.FilterChain{ + { + Name: "filter-chain-1", + Filters: []*v3listenerpb.Filter{ + { + Name: "hcm", + ConfigType: &v3listenerpb.Filter_TypedConfig{ + + TypedConfig: testutils.MarshalAny(&v3httppb.HttpConnectionManager{ + RouteSpecifier: &v3httppb.HttpConnectionManager_Rds{ + Rds: &v3httppb.Rds{ + RouteConfigName: "route-1", + }, + }, + HttpFilters: []*v3httppb.HttpFilter{emptyRouterFilter}, + }), + }, + }, + }, + }, + }, + DefaultFilterChain: &v3listenerpb.FilterChain{ + Filters: []*v3listenerpb.Filter{ + { + Name: "hcm", + ConfigType: &v3listenerpb.Filter_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3httppb.HttpConnectionManager{ + RouteSpecifier: &v3httppb.HttpConnectionManager_Rds{ + Rds: &v3httppb.Rds{ + RouteConfigName: "route-1", + }, + }, + HttpFilters: []*v3httppb.HttpFilter{emptyRouterFilter}, + }), + }, + }, + }, + }, + }, + wantErr: "ConfigSource is not ADS", + }, + { + name: "unsupported-route-specifier", + lis: &v3listenerpb.Listener{ + FilterChains: []*v3listenerpb.FilterChain{ + { + Name: "filter-chain-1", + Filters: []*v3listenerpb.Filter{ + { + Name: "hcm", + ConfigType: &v3listenerpb.Filter_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3httppb.HttpConnectionManager{ + RouteSpecifier: &v3httppb.HttpConnectionManager_ScopedRoutes{}, + HttpFilters: []*v3httppb.HttpFilter{emptyRouterFilter}, + }), + }, + }, + }, + }, + }, + DefaultFilterChain: &v3listenerpb.FilterChain{ + Filters: []*v3listenerpb.Filter{ + { + Name: "hcm", + ConfigType: &v3listenerpb.Filter_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3httppb.HttpConnectionManager{ + RouteSpecifier: &v3httppb.HttpConnectionManager_ScopedRoutes{}, + HttpFilters: []*v3httppb.HttpFilter{emptyRouterFilter}, + }), + }, + }, + }, + }, + }, + wantErr: "unsupported type", + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + _, err := NewFilterChainManager(test.lis) + if err == nil || !strings.Contains(err.Error(), test.wantErr) { + t.Fatalf("NewFilterChainManager() returned err: %v, wantErr: %s", err, test.wantErr) + } + }) + } +} + +// TestNewFilterChainImpl_Failure_BadHTTPFilters verifies cases where the HTTP +// Filters in the filter chain are invalid. +func TestNewFilterChainImpl_Failure_BadHTTPFilters(t *testing.T) { + tests := []struct { + name string + lis *v3listenerpb.Listener + wantErr string + }{ + { + name: "client side HTTP filter", + lis: &v3listenerpb.Listener{ + Name: "grpc/server?xds.resource.listening_address=0.0.0.0:9999", + FilterChains: []*v3listenerpb.FilterChain{ + { + Name: "filter-chain-1", + Filters: []*v3listenerpb.Filter{ + { + Name: "hcm", + ConfigType: &v3listenerpb.Filter_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3httppb.HttpConnectionManager{ + HttpFilters: []*v3httppb.HttpFilter{ + { + Name: "clientOnlyCustomFilter", + ConfigType: &v3httppb.HttpFilter_TypedConfig{TypedConfig: clientOnlyCustomFilterConfig}, + }, + }, + }), + }, + }, + }, + }, + }, + }, + wantErr: "invalid server side HTTP Filters", + }, + { + name: "one valid then one invalid HTTP filter", + lis: &v3listenerpb.Listener{ + Name: "grpc/server?xds.resource.listening_address=0.0.0.0:9999", + FilterChains: []*v3listenerpb.FilterChain{ + { + Name: "filter-chain-1", + Filters: []*v3listenerpb.Filter{ + { + Name: "hcm", + ConfigType: &v3listenerpb.Filter_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3httppb.HttpConnectionManager{ + HttpFilters: []*v3httppb.HttpFilter{ + validServerSideHTTPFilter1, + { + Name: "clientOnlyCustomFilter", + ConfigType: &v3httppb.HttpFilter_TypedConfig{TypedConfig: clientOnlyCustomFilterConfig}, + }, + }, + }), + }, + }, + }, + }, + }, + }, + wantErr: "invalid server side HTTP Filters", + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + _, err := NewFilterChainManager(test.lis) + if err == nil || !strings.Contains(err.Error(), test.wantErr) { + t.Fatalf("NewFilterChainManager() returned err: %v, wantErr: %s", err, test.wantErr) + } + }) + } +} + +// TestNewFilterChainImpl_Success_HTTPFilters tests the construction of the +// filter chain with valid HTTP Filters present. +func TestNewFilterChainImpl_Success_HTTPFilters(t *testing.T) { + oldRBAC := env.RBACSupport + env.RBACSupport = true + defer func() { + env.RBACSupport = oldRBAC + }() + tests := []struct { + name string + lis *v3listenerpb.Listener + wantFC *FilterChainManager + }{ + { + name: "singular valid http filter", + lis: &v3listenerpb.Listener{ + FilterChains: []*v3listenerpb.FilterChain{ + { + Name: "filter-chain-1", + Filters: []*v3listenerpb.Filter{ + { + Name: "hcm", + ConfigType: &v3listenerpb.Filter_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3httppb.HttpConnectionManager{ + HttpFilters: []*v3httppb.HttpFilter{ + validServerSideHTTPFilter1, + emptyRouterFilter, + }, + RouteSpecifier: &v3httppb.HttpConnectionManager_RouteConfig{ + RouteConfig: routeConfig, + }, + }), + }, + }, + }, + }, + }, + DefaultFilterChain: &v3listenerpb.FilterChain{ + Filters: []*v3listenerpb.Filter{ + { + Name: "hcm", + ConfigType: &v3listenerpb.Filter_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3httppb.HttpConnectionManager{ + HttpFilters: []*v3httppb.HttpFilter{ + validServerSideHTTPFilter1, + emptyRouterFilter, + }, + RouteSpecifier: &v3httppb.HttpConnectionManager_RouteConfig{ + RouteConfig: routeConfig, + }, + }), + }, + }, + }, + }, + }, + wantFC: &FilterChainManager{ + dstPrefixMap: map[string]*destPrefixEntry{ + unspecifiedPrefixMapKey: { + srcTypeArr: [3]*sourcePrefixes{ + { + srcPrefixMap: map[string]*sourcePrefixEntry{ + unspecifiedPrefixMapKey: { + srcPortMap: map[int]*FilterChain{ + 0: {HTTPFilters: []HTTPFilter{ + { + Name: "serverOnlyCustomFilter", + Filter: serverOnlyHTTPFilter{}, + Config: filterConfig{Cfg: serverOnlyCustomFilterConfig}, + }, + routerFilter, + }, + InlineRouteConfig: inlineRouteConfig, + }, + }, + }, + }, + }, + }, + }, + }, + def: &FilterChain{ + HTTPFilters: []HTTPFilter{ + { + Name: "serverOnlyCustomFilter", + Filter: serverOnlyHTTPFilter{}, + Config: filterConfig{Cfg: serverOnlyCustomFilterConfig}, + }, + routerFilter, + }, + InlineRouteConfig: inlineRouteConfig, + }, + }, + }, + { + name: "two valid http filters", + lis: &v3listenerpb.Listener{ + FilterChains: []*v3listenerpb.FilterChain{ + { + Name: "filter-chain-1", + Filters: []*v3listenerpb.Filter{ + { + Name: "hcm", + ConfigType: &v3listenerpb.Filter_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3httppb.HttpConnectionManager{ + HttpFilters: []*v3httppb.HttpFilter{ + validServerSideHTTPFilter1, + validServerSideHTTPFilter2, + emptyRouterFilter, + }, + RouteSpecifier: &v3httppb.HttpConnectionManager_RouteConfig{ + RouteConfig: routeConfig, + }, + }), + }, + }, + }, + }, + }, + DefaultFilterChain: &v3listenerpb.FilterChain{ + Filters: []*v3listenerpb.Filter{ + { + Name: "hcm", + ConfigType: &v3listenerpb.Filter_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3httppb.HttpConnectionManager{ + HttpFilters: []*v3httppb.HttpFilter{ + validServerSideHTTPFilter1, + validServerSideHTTPFilter2, + emptyRouterFilter, + }, + RouteSpecifier: &v3httppb.HttpConnectionManager_RouteConfig{ + RouteConfig: routeConfig, + }, + }), + }, + }, + }, + }, + }, + wantFC: &FilterChainManager{ + dstPrefixMap: map[string]*destPrefixEntry{ + unspecifiedPrefixMapKey: { + srcTypeArr: [3]*sourcePrefixes{ + { + srcPrefixMap: map[string]*sourcePrefixEntry{ + unspecifiedPrefixMapKey: { + srcPortMap: map[int]*FilterChain{ + 0: {HTTPFilters: []HTTPFilter{ + { + Name: "serverOnlyCustomFilter", + Filter: serverOnlyHTTPFilter{}, + Config: filterConfig{Cfg: serverOnlyCustomFilterConfig}, + }, + { + Name: "serverOnlyCustomFilter2", + Filter: serverOnlyHTTPFilter{}, + Config: filterConfig{Cfg: serverOnlyCustomFilterConfig}, + }, + routerFilter, + }, + InlineRouteConfig: inlineRouteConfig, + }, + }, + }, + }, + }, + }, + }, + }, + def: &FilterChain{HTTPFilters: []HTTPFilter{ + { + Name: "serverOnlyCustomFilter", + Filter: serverOnlyHTTPFilter{}, + Config: filterConfig{Cfg: serverOnlyCustomFilterConfig}, + }, + { + Name: "serverOnlyCustomFilter2", + Filter: serverOnlyHTTPFilter{}, + Config: filterConfig{Cfg: serverOnlyCustomFilterConfig}, + }, + routerFilter, + }, + InlineRouteConfig: inlineRouteConfig, + }, + }, + }, + // In the case of two HTTP Connection Manager's being present, the + // second HTTP Connection Manager should be validated, but ignored. + { + name: "two hcms", + lis: &v3listenerpb.Listener{ + FilterChains: []*v3listenerpb.FilterChain{ + { + Name: "filter-chain-1", + Filters: []*v3listenerpb.Filter{ + { + Name: "hcm", + ConfigType: &v3listenerpb.Filter_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3httppb.HttpConnectionManager{ + HttpFilters: []*v3httppb.HttpFilter{ + validServerSideHTTPFilter1, + validServerSideHTTPFilter2, + emptyRouterFilter, + }, + RouteSpecifier: &v3httppb.HttpConnectionManager_RouteConfig{ + RouteConfig: routeConfig, + }, + }), + }, + }, + { + Name: "hcm2", + ConfigType: &v3listenerpb.Filter_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3httppb.HttpConnectionManager{ + HttpFilters: []*v3httppb.HttpFilter{ + validServerSideHTTPFilter1, + emptyRouterFilter, + }, + RouteSpecifier: &v3httppb.HttpConnectionManager_RouteConfig{ + RouteConfig: routeConfig, + }, + }), + }, + }, + }, + }, + }, + DefaultFilterChain: &v3listenerpb.FilterChain{ + Filters: []*v3listenerpb.Filter{ + { + Name: "hcm", + ConfigType: &v3listenerpb.Filter_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3httppb.HttpConnectionManager{ + HttpFilters: []*v3httppb.HttpFilter{ + validServerSideHTTPFilter1, + validServerSideHTTPFilter2, + emptyRouterFilter, + }, + RouteSpecifier: &v3httppb.HttpConnectionManager_RouteConfig{ + RouteConfig: routeConfig, + }, + }), + }, + }, + { + Name: "hcm2", + ConfigType: &v3listenerpb.Filter_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3httppb.HttpConnectionManager{ + HttpFilters: []*v3httppb.HttpFilter{ + validServerSideHTTPFilter1, + emptyRouterFilter, + }, + RouteSpecifier: &v3httppb.HttpConnectionManager_RouteConfig{ + RouteConfig: routeConfig, + }, + }), + }, + }, + }, + }, + }, + wantFC: &FilterChainManager{ + dstPrefixMap: map[string]*destPrefixEntry{ + unspecifiedPrefixMapKey: { + srcTypeArr: [3]*sourcePrefixes{ + { + srcPrefixMap: map[string]*sourcePrefixEntry{ + unspecifiedPrefixMapKey: { + srcPortMap: map[int]*FilterChain{ + 0: {HTTPFilters: []HTTPFilter{ + { + Name: "serverOnlyCustomFilter", + Filter: serverOnlyHTTPFilter{}, + Config: filterConfig{Cfg: serverOnlyCustomFilterConfig}, + }, + { + Name: "serverOnlyCustomFilter2", + Filter: serverOnlyHTTPFilter{}, + Config: filterConfig{Cfg: serverOnlyCustomFilterConfig}, + }, + routerFilter, + }, + InlineRouteConfig: inlineRouteConfig, + }, + }, + }, + }, + }, + }, + }, + }, + def: &FilterChain{HTTPFilters: []HTTPFilter{ + { + Name: "serverOnlyCustomFilter", + Filter: serverOnlyHTTPFilter{}, + Config: filterConfig{Cfg: serverOnlyCustomFilterConfig}, + }, + { + Name: "serverOnlyCustomFilter2", + Filter: serverOnlyHTTPFilter{}, + Config: filterConfig{Cfg: serverOnlyCustomFilterConfig}, + }, + routerFilter, + }, + InlineRouteConfig: inlineRouteConfig, + }, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + gotFC, err := NewFilterChainManager(test.lis) + if err != nil { + t.Fatalf("NewFilterChainManager() returned err: %v, wantErr: nil", err) + } + if !cmp.Equal(gotFC, test.wantFC, cmp.AllowUnexported(FilterChainManager{}, destPrefixEntry{}, sourcePrefixes{}, sourcePrefixEntry{}), cmpOpts) { + t.Fatalf("NewFilterChainManager() returned %+v, want: %+v", gotFC, test.wantFC) + } + }) + } +} + +// TestNewFilterChainImpl_Success_SecurityConfig verifies cases where the +// security configuration in the filter chain contains valid data. +func TestNewFilterChainImpl_Success_SecurityConfig(t *testing.T) { + oldRBAC := env.RBACSupport + env.RBACSupport = true + defer func() { + env.RBACSupport = oldRBAC + }() + tests := []struct { + desc string + lis *v3listenerpb.Listener + wantFC *FilterChainManager + }{ + { + desc: "empty transport socket", + lis: &v3listenerpb.Listener{ + FilterChains: []*v3listenerpb.FilterChain{ + { + Name: "filter-chain-1", + Filters: emptyValidNetworkFilters, + }, + }, + DefaultFilterChain: &v3listenerpb.FilterChain{ + Filters: emptyValidNetworkFilters, + }, + }, + wantFC: &FilterChainManager{ + dstPrefixMap: map[string]*destPrefixEntry{ + unspecifiedPrefixMapKey: { + srcTypeArr: [3]*sourcePrefixes{ + { + srcPrefixMap: map[string]*sourcePrefixEntry{ + unspecifiedPrefixMapKey: { + srcPortMap: map[int]*FilterChain{ + 0: { + InlineRouteConfig: inlineRouteConfig, + HTTPFilters: routerFilterList, + }, + }, + }, + }, + }, + }, + }, + }, + def: &FilterChain{ + InlineRouteConfig: inlineRouteConfig, + HTTPFilters: routerFilterList, + }, + }, + }, + { + desc: "no validation context", + lis: &v3listenerpb.Listener{ + FilterChains: []*v3listenerpb.FilterChain{ + { + TransportSocket: &v3corepb.TransportSocket{ + Name: "envoy.transport_sockets.tls", + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3tlspb.DownstreamTlsContext{ + CommonTlsContext: &v3tlspb.CommonTlsContext{ + TlsCertificateCertificateProviderInstance: &v3tlspb.CommonTlsContext_CertificateProviderInstance{ + InstanceName: "identityPluginInstance", + CertificateName: "identityCertName", + }, + }, + }), + }, + }, + Filters: emptyValidNetworkFilters, + }, + }, + DefaultFilterChain: &v3listenerpb.FilterChain{ + TransportSocket: &v3corepb.TransportSocket{ + Name: "envoy.transport_sockets.tls", + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3tlspb.DownstreamTlsContext{ + CommonTlsContext: &v3tlspb.CommonTlsContext{ + TlsCertificateCertificateProviderInstance: &v3tlspb.CommonTlsContext_CertificateProviderInstance{ + InstanceName: "defaultIdentityPluginInstance", + CertificateName: "defaultIdentityCertName", + }, + }, + }), + }, + }, + Filters: emptyValidNetworkFilters, + }, + }, + wantFC: &FilterChainManager{ + dstPrefixMap: map[string]*destPrefixEntry{ + unspecifiedPrefixMapKey: { + srcTypeArr: [3]*sourcePrefixes{ + { + srcPrefixMap: map[string]*sourcePrefixEntry{ + unspecifiedPrefixMapKey: { + srcPortMap: map[int]*FilterChain{ + 0: { + SecurityCfg: &SecurityConfig{ + IdentityInstanceName: "identityPluginInstance", + IdentityCertName: "identityCertName", + }, + InlineRouteConfig: inlineRouteConfig, + HTTPFilters: routerFilterList, + }, + }, + }, + }, + }, + }, + }, + }, + def: &FilterChain{ + SecurityCfg: &SecurityConfig{ + IdentityInstanceName: "defaultIdentityPluginInstance", + IdentityCertName: "defaultIdentityCertName", + }, + InlineRouteConfig: inlineRouteConfig, + HTTPFilters: routerFilterList, + }, + }, + }, + { + desc: "validation context with certificate provider", + lis: &v3listenerpb.Listener{ + FilterChains: []*v3listenerpb.FilterChain{ + { + TransportSocket: &v3corepb.TransportSocket{ + Name: "envoy.transport_sockets.tls", + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3tlspb.DownstreamTlsContext{ + RequireClientCertificate: &wrapperspb.BoolValue{Value: true}, + CommonTlsContext: &v3tlspb.CommonTlsContext{ + TlsCertificateCertificateProviderInstance: &v3tlspb.CommonTlsContext_CertificateProviderInstance{ + InstanceName: "identityPluginInstance", + CertificateName: "identityCertName", + }, + ValidationContextType: &v3tlspb.CommonTlsContext_ValidationContextCertificateProviderInstance{ + ValidationContextCertificateProviderInstance: &v3tlspb.CommonTlsContext_CertificateProviderInstance{ + InstanceName: "rootPluginInstance", + CertificateName: "rootCertName", + }, + }, + }, + }), + }, + }, + Filters: emptyValidNetworkFilters, + }, + }, + DefaultFilterChain: &v3listenerpb.FilterChain{ + Name: "default-filter-chain-1", + TransportSocket: &v3corepb.TransportSocket{ + Name: "envoy.transport_sockets.tls", + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3tlspb.DownstreamTlsContext{ + RequireClientCertificate: &wrapperspb.BoolValue{Value: true}, + CommonTlsContext: &v3tlspb.CommonTlsContext{ + TlsCertificateCertificateProviderInstance: &v3tlspb.CommonTlsContext_CertificateProviderInstance{ + InstanceName: "defaultIdentityPluginInstance", + CertificateName: "defaultIdentityCertName", + }, + ValidationContextType: &v3tlspb.CommonTlsContext_ValidationContextCertificateProviderInstance{ + ValidationContextCertificateProviderInstance: &v3tlspb.CommonTlsContext_CertificateProviderInstance{ + InstanceName: "defaultRootPluginInstance", + CertificateName: "defaultRootCertName", + }, + }, + }, + }), + }, + }, + Filters: emptyValidNetworkFilters, + }, + }, + wantFC: &FilterChainManager{ + dstPrefixMap: map[string]*destPrefixEntry{ + unspecifiedPrefixMapKey: { + srcTypeArr: [3]*sourcePrefixes{ + { + srcPrefixMap: map[string]*sourcePrefixEntry{ + unspecifiedPrefixMapKey: { + srcPortMap: map[int]*FilterChain{ + 0: { + SecurityCfg: &SecurityConfig{ + RootInstanceName: "rootPluginInstance", + RootCertName: "rootCertName", + IdentityInstanceName: "identityPluginInstance", + IdentityCertName: "identityCertName", + RequireClientCert: true, + }, + InlineRouteConfig: inlineRouteConfig, + HTTPFilters: routerFilterList, + }, + }, + }, + }, + }, + }, + }, + }, + def: &FilterChain{ + SecurityCfg: &SecurityConfig{ + RootInstanceName: "defaultRootPluginInstance", + RootCertName: "defaultRootCertName", + IdentityInstanceName: "defaultIdentityPluginInstance", + IdentityCertName: "defaultIdentityCertName", + RequireClientCert: true, + }, + InlineRouteConfig: inlineRouteConfig, + HTTPFilters: routerFilterList, + }, + }, + }, + } + + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + gotFC, err := NewFilterChainManager(test.lis) + if err != nil { + t.Fatalf("NewFilterChainManager() returned err: %v, wantErr: nil", err) + } + if !cmp.Equal(gotFC, test.wantFC, cmp.AllowUnexported(FilterChainManager{}, destPrefixEntry{}, sourcePrefixes{}, sourcePrefixEntry{}), cmpopts.EquateEmpty()) { + t.Fatalf("NewFilterChainManager() returned %+v, want: %+v", gotFC, test.wantFC) + } + }) + } +} + +// TestNewFilterChainImpl_Success_UnsupportedMatchFields verifies cases where +// there are multiple filter chains, and one of them is valid while the other +// contains unsupported match fields. These configurations should lead to +// success at config validation time and the filter chains which contains +// unsupported match fields will be skipped at lookup time. +func TestNewFilterChainImpl_Success_UnsupportedMatchFields(t *testing.T) { + oldRBAC := env.RBACSupport + env.RBACSupport = true + defer func() { + env.RBACSupport = oldRBAC + }() + unspecifiedEntry := &destPrefixEntry{ + srcTypeArr: [3]*sourcePrefixes{ + { + srcPrefixMap: map[string]*sourcePrefixEntry{ + unspecifiedPrefixMapKey: { + srcPortMap: map[int]*FilterChain{ + 0: { + InlineRouteConfig: inlineRouteConfig, + HTTPFilters: routerFilterList, + }, + }, + }, + }, + }, + }, + } + + tests := []struct { + desc string + lis *v3listenerpb.Listener + wantFC *FilterChainManager + }{ + { + desc: "unsupported destination port", + lis: &v3listenerpb.Listener{ + FilterChains: []*v3listenerpb.FilterChain{ + { + Name: "good-chain", + Filters: emptyValidNetworkFilters, + }, + { + Name: "unsupported-destination-port", + FilterChainMatch: &v3listenerpb.FilterChainMatch{ + PrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("192.168.1.1", 16)}, + DestinationPort: &wrapperspb.UInt32Value{Value: 666}, + }, + Filters: emptyValidNetworkFilters, + }, + }, + DefaultFilterChain: &v3listenerpb.FilterChain{Filters: emptyValidNetworkFilters}, + }, + wantFC: &FilterChainManager{ + dstPrefixMap: map[string]*destPrefixEntry{ + unspecifiedPrefixMapKey: unspecifiedEntry, + }, + def: &FilterChain{ + InlineRouteConfig: inlineRouteConfig, + HTTPFilters: routerFilterList, + }, + }, + }, + { + desc: "unsupported server names", + lis: &v3listenerpb.Listener{ + FilterChains: []*v3listenerpb.FilterChain{ + { + Name: "good-chain", + Filters: emptyValidNetworkFilters, + }, + { + Name: "unsupported-server-names", + FilterChainMatch: &v3listenerpb.FilterChainMatch{ + PrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("192.168.1.1", 16)}, + ServerNames: []string{"example-server"}, + }, + Filters: emptyValidNetworkFilters, + }, + }, + DefaultFilterChain: &v3listenerpb.FilterChain{Filters: emptyValidNetworkFilters}, + }, + wantFC: &FilterChainManager{ + dstPrefixMap: map[string]*destPrefixEntry{ + unspecifiedPrefixMapKey: unspecifiedEntry, + "192.168.0.0/16": { + net: ipNetFromCIDR("192.168.2.2/16"), + }, + }, + def: &FilterChain{ + InlineRouteConfig: inlineRouteConfig, + HTTPFilters: routerFilterList, + }, + }, + }, + { + desc: "unsupported transport protocol", + lis: &v3listenerpb.Listener{ + FilterChains: []*v3listenerpb.FilterChain{ + { + Name: "good-chain", + Filters: emptyValidNetworkFilters, + }, + { + Name: "unsupported-transport-protocol", + FilterChainMatch: &v3listenerpb.FilterChainMatch{ + PrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("192.168.1.1", 16)}, + TransportProtocol: "tls", + }, + Filters: emptyValidNetworkFilters, + }, + }, + DefaultFilterChain: &v3listenerpb.FilterChain{Filters: emptyValidNetworkFilters}, + }, + wantFC: &FilterChainManager{ + dstPrefixMap: map[string]*destPrefixEntry{ + unspecifiedPrefixMapKey: unspecifiedEntry, + "192.168.0.0/16": { + net: ipNetFromCIDR("192.168.2.2/16"), + }, + }, + def: &FilterChain{ + InlineRouteConfig: inlineRouteConfig, + HTTPFilters: routerFilterList, + }, + }, + }, + { + desc: "unsupported application protocol", + lis: &v3listenerpb.Listener{ + FilterChains: []*v3listenerpb.FilterChain{ + { + Name: "good-chain", + Filters: emptyValidNetworkFilters, + }, + { + Name: "unsupported-application-protocol", + FilterChainMatch: &v3listenerpb.FilterChainMatch{ + PrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("192.168.1.1", 16)}, + ApplicationProtocols: []string{"h2"}, + }, + Filters: emptyValidNetworkFilters, + }, + }, + DefaultFilterChain: &v3listenerpb.FilterChain{Filters: emptyValidNetworkFilters}, + }, + wantFC: &FilterChainManager{ + dstPrefixMap: map[string]*destPrefixEntry{ + unspecifiedPrefixMapKey: unspecifiedEntry, + "192.168.0.0/16": { + net: ipNetFromCIDR("192.168.2.2/16"), + }, + }, + def: &FilterChain{ + InlineRouteConfig: inlineRouteConfig, + HTTPFilters: routerFilterList, + }, + }, + }, + } + + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + gotFC, err := NewFilterChainManager(test.lis) + if err != nil { + t.Fatalf("NewFilterChainManager() returned err: %v, wantErr: nil", err) + } + if !cmp.Equal(gotFC, test.wantFC, cmp.AllowUnexported(FilterChainManager{}, destPrefixEntry{}, sourcePrefixes{}, sourcePrefixEntry{}), cmpopts.EquateEmpty()) { + t.Fatalf("NewFilterChainManager() returned %+v, want: %+v", gotFC, test.wantFC) + } + }) + } +} + +// TestNewFilterChainImpl_Success_AllCombinations verifies different +// combinations of the supported match criteria. +func TestNewFilterChainImpl_Success_AllCombinations(t *testing.T) { + oldRBAC := env.RBACSupport + env.RBACSupport = true + defer func() { + env.RBACSupport = oldRBAC + }() + tests := []struct { + desc string + lis *v3listenerpb.Listener + wantFC *FilterChainManager + }{ + { + desc: "multiple destination prefixes", + lis: &v3listenerpb.Listener{ + FilterChains: []*v3listenerpb.FilterChain{ + { + // Unspecified destination prefix. + FilterChainMatch: &v3listenerpb.FilterChainMatch{}, + Filters: emptyValidNetworkFilters, + }, + { + // v4 wildcard destination prefix. + FilterChainMatch: &v3listenerpb.FilterChainMatch{ + PrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("0.0.0.0", 0)}, + SourceType: v3listenerpb.FilterChainMatch_EXTERNAL, + }, + Filters: emptyValidNetworkFilters, + }, + { + // v6 wildcard destination prefix. + FilterChainMatch: &v3listenerpb.FilterChainMatch{ + PrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("::", 0)}, + SourceType: v3listenerpb.FilterChainMatch_EXTERNAL, + }, + Filters: emptyValidNetworkFilters, + }, + { + FilterChainMatch: &v3listenerpb.FilterChainMatch{PrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("192.168.1.1", 16)}}, + Filters: emptyValidNetworkFilters, + }, + { + FilterChainMatch: &v3listenerpb.FilterChainMatch{PrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("10.0.0.0", 8)}}, + Filters: emptyValidNetworkFilters, + }, + }, + DefaultFilterChain: &v3listenerpb.FilterChain{Filters: emptyValidNetworkFilters}, + }, + wantFC: &FilterChainManager{ + dstPrefixMap: map[string]*destPrefixEntry{ + unspecifiedPrefixMapKey: { + srcTypeArr: [3]*sourcePrefixes{ + { + srcPrefixMap: map[string]*sourcePrefixEntry{ + unspecifiedPrefixMapKey: { + srcPortMap: map[int]*FilterChain{ + 0: { + InlineRouteConfig: inlineRouteConfig, + HTTPFilters: routerFilterList, + }, + }, + }, + }, + }, + }, + }, + "0.0.0.0/0": { + net: ipNetFromCIDR("0.0.0.0/0"), + srcTypeArr: [3]*sourcePrefixes{ + nil, + nil, + { + srcPrefixMap: map[string]*sourcePrefixEntry{ + unspecifiedPrefixMapKey: { + srcPortMap: map[int]*FilterChain{ + 0: { + InlineRouteConfig: inlineRouteConfig, + HTTPFilters: routerFilterList, + }, + }, + }, + }, + }, + }, + }, + "::/0": { + net: ipNetFromCIDR("::/0"), + srcTypeArr: [3]*sourcePrefixes{ + nil, + nil, + { + srcPrefixMap: map[string]*sourcePrefixEntry{ + unspecifiedPrefixMapKey: { + srcPortMap: map[int]*FilterChain{ + 0: { + InlineRouteConfig: inlineRouteConfig, + HTTPFilters: routerFilterList, + }, + }, + }, + }, + }, + }, + }, + "192.168.0.0/16": { + net: ipNetFromCIDR("192.168.2.2/16"), + srcTypeArr: [3]*sourcePrefixes{ + { + srcPrefixMap: map[string]*sourcePrefixEntry{ + unspecifiedPrefixMapKey: { + srcPortMap: map[int]*FilterChain{ + 0: { + InlineRouteConfig: inlineRouteConfig, + HTTPFilters: routerFilterList, + }, + }, + }, + }, + }, + }, + }, + "10.0.0.0/8": { + net: ipNetFromCIDR("10.0.0.0/8"), + srcTypeArr: [3]*sourcePrefixes{ + { + srcPrefixMap: map[string]*sourcePrefixEntry{ + unspecifiedPrefixMapKey: { + srcPortMap: map[int]*FilterChain{ + 0: { + InlineRouteConfig: inlineRouteConfig, + HTTPFilters: routerFilterList, + }, + }, + }, + }, + }, + }, + }, + }, + def: &FilterChain{ + InlineRouteConfig: inlineRouteConfig, + HTTPFilters: routerFilterList, + }, + }, + }, + { + desc: "multiple source types", + lis: &v3listenerpb.Listener{ + FilterChains: []*v3listenerpb.FilterChain{ + { + FilterChainMatch: &v3listenerpb.FilterChainMatch{SourceType: v3listenerpb.FilterChainMatch_SAME_IP_OR_LOOPBACK}, + Filters: emptyValidNetworkFilters, + }, + { + FilterChainMatch: &v3listenerpb.FilterChainMatch{ + PrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("192.168.1.1", 16)}, + SourceType: v3listenerpb.FilterChainMatch_EXTERNAL, + }, + Filters: emptyValidNetworkFilters, + }, + }, + DefaultFilterChain: &v3listenerpb.FilterChain{Filters: emptyValidNetworkFilters}, + }, + wantFC: &FilterChainManager{ + dstPrefixMap: map[string]*destPrefixEntry{ + unspecifiedPrefixMapKey: { + srcTypeArr: [3]*sourcePrefixes{ + nil, + { + srcPrefixMap: map[string]*sourcePrefixEntry{ + unspecifiedPrefixMapKey: { + srcPortMap: map[int]*FilterChain{ + 0: { + InlineRouteConfig: inlineRouteConfig, + HTTPFilters: routerFilterList, + }, + }, + }, + }, + }, + }, + }, + "192.168.0.0/16": { + net: ipNetFromCIDR("192.168.2.2/16"), + srcTypeArr: [3]*sourcePrefixes{ + nil, + nil, + { + srcPrefixMap: map[string]*sourcePrefixEntry{ + unspecifiedPrefixMapKey: { + srcPortMap: map[int]*FilterChain{ + 0: { + InlineRouteConfig: inlineRouteConfig, + HTTPFilters: routerFilterList, + }, + }, + }, + }, + }, + }, + }, + }, + def: &FilterChain{ + InlineRouteConfig: inlineRouteConfig, + HTTPFilters: routerFilterList, + }, + }, + }, + { + desc: "multiple source prefixes", + lis: &v3listenerpb.Listener{ + FilterChains: []*v3listenerpb.FilterChain{ + { + FilterChainMatch: &v3listenerpb.FilterChainMatch{SourcePrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("10.0.0.0", 8)}}, + Filters: emptyValidNetworkFilters, + }, + { + FilterChainMatch: &v3listenerpb.FilterChainMatch{ + PrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("192.168.1.1", 16)}, + SourcePrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("192.168.1.1", 16)}, + }, + Filters: emptyValidNetworkFilters, + }, + }, + DefaultFilterChain: &v3listenerpb.FilterChain{Filters: emptyValidNetworkFilters}, + }, + wantFC: &FilterChainManager{ + dstPrefixMap: map[string]*destPrefixEntry{ + unspecifiedPrefixMapKey: { + srcTypeArr: [3]*sourcePrefixes{ + { + srcPrefixMap: map[string]*sourcePrefixEntry{ + "10.0.0.0/8": { + net: ipNetFromCIDR("10.0.0.0/8"), + srcPortMap: map[int]*FilterChain{ + 0: { + InlineRouteConfig: inlineRouteConfig, + HTTPFilters: routerFilterList, + }, + }, + }, + }, + }, + }, + }, + "192.168.0.0/16": { + net: ipNetFromCIDR("192.168.2.2/16"), + srcTypeArr: [3]*sourcePrefixes{ + { + srcPrefixMap: map[string]*sourcePrefixEntry{ + "192.168.0.0/16": { + net: ipNetFromCIDR("192.168.0.0/16"), + srcPortMap: map[int]*FilterChain{ + 0: { + InlineRouteConfig: inlineRouteConfig, + HTTPFilters: routerFilterList, + }, + }, + }, + }, + }, + }, + }, + }, + def: &FilterChain{ + InlineRouteConfig: inlineRouteConfig, + HTTPFilters: routerFilterList, + }, + }, + }, + { + desc: "multiple source ports", + lis: &v3listenerpb.Listener{ + FilterChains: []*v3listenerpb.FilterChain{ + { + FilterChainMatch: &v3listenerpb.FilterChainMatch{SourcePorts: []uint32{1, 2, 3}}, + Filters: emptyValidNetworkFilters, + }, + { + FilterChainMatch: &v3listenerpb.FilterChainMatch{ + PrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("192.168.1.1", 16)}, + SourcePrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("192.168.1.1", 16)}, + SourceType: v3listenerpb.FilterChainMatch_EXTERNAL, + SourcePorts: []uint32{1, 2, 3}, + }, + Filters: emptyValidNetworkFilters, + }, + }, + DefaultFilterChain: &v3listenerpb.FilterChain{Filters: emptyValidNetworkFilters}, + }, + wantFC: &FilterChainManager{ + dstPrefixMap: map[string]*destPrefixEntry{ + unspecifiedPrefixMapKey: { + srcTypeArr: [3]*sourcePrefixes{ + { + srcPrefixMap: map[string]*sourcePrefixEntry{ + unspecifiedPrefixMapKey: { + srcPortMap: map[int]*FilterChain{ + 1: { + InlineRouteConfig: inlineRouteConfig, + HTTPFilters: routerFilterList, + }, + 2: { + InlineRouteConfig: inlineRouteConfig, + HTTPFilters: routerFilterList, + }, + 3: { + InlineRouteConfig: inlineRouteConfig, + HTTPFilters: routerFilterList, + }, + }, + }, + }, + }, + }, + }, + "192.168.0.0/16": { + net: ipNetFromCIDR("192.168.2.2/16"), + srcTypeArr: [3]*sourcePrefixes{ + nil, + nil, + { + srcPrefixMap: map[string]*sourcePrefixEntry{ + "192.168.0.0/16": { + net: ipNetFromCIDR("192.168.0.0/16"), + srcPortMap: map[int]*FilterChain{ + 1: { + InlineRouteConfig: inlineRouteConfig, + HTTPFilters: routerFilterList, + }, + 2: { + InlineRouteConfig: inlineRouteConfig, + HTTPFilters: routerFilterList, + }, + 3: { + InlineRouteConfig: inlineRouteConfig, + HTTPFilters: routerFilterList, + }, + }, + }, + }, + }, + }, + }, + }, + def: &FilterChain{ + InlineRouteConfig: inlineRouteConfig, + HTTPFilters: routerFilterList, + }, + }, + }, + { + desc: "some chains have unsupported fields", + lis: &v3listenerpb.Listener{ + FilterChains: []*v3listenerpb.FilterChain{ + { + FilterChainMatch: &v3listenerpb.FilterChainMatch{}, + Filters: emptyValidNetworkFilters, + }, + { + FilterChainMatch: &v3listenerpb.FilterChainMatch{PrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("192.168.1.1", 16)}}, + Filters: emptyValidNetworkFilters, + }, + { + FilterChainMatch: &v3listenerpb.FilterChainMatch{ + PrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("10.0.0.0", 8)}, + TransportProtocol: "raw_buffer", + }, + Filters: emptyValidNetworkFilters, + }, + { + // This chain will be dropped in favor of the above + // filter chain because they both have the same + // destination prefix, but this one has an empty + // transport protocol while the above chain has the more + // preferred "raw_buffer". + FilterChainMatch: &v3listenerpb.FilterChainMatch{ + PrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("10.0.0.0", 8)}, + TransportProtocol: "", + SourceType: v3listenerpb.FilterChainMatch_EXTERNAL, + SourcePrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("10.0.0.0", 16)}, + }, + Filters: emptyValidNetworkFilters, + }, + { + // This chain will be dropped for unsupported server + // names. + FilterChainMatch: &v3listenerpb.FilterChainMatch{ + PrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("192.168.100.1", 32)}, + ServerNames: []string{"foo", "bar"}, + }, + Filters: emptyValidNetworkFilters, + }, + { + // This chain will be dropped for unsupported transport + // protocol. + FilterChainMatch: &v3listenerpb.FilterChainMatch{ + PrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("192.168.100.2", 32)}, + TransportProtocol: "not-raw-buffer", + }, + Filters: emptyValidNetworkFilters, + }, + { + // This chain will be dropped for unsupported + // application protocol. + FilterChainMatch: &v3listenerpb.FilterChainMatch{ + PrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("192.168.100.3", 32)}, + ApplicationProtocols: []string{"h2"}, + }, + Filters: emptyValidNetworkFilters, + }, + }, + DefaultFilterChain: &v3listenerpb.FilterChain{Filters: emptyValidNetworkFilters}, + }, + wantFC: &FilterChainManager{ + dstPrefixMap: map[string]*destPrefixEntry{ + unspecifiedPrefixMapKey: { + srcTypeArr: [3]*sourcePrefixes{ + { + srcPrefixMap: map[string]*sourcePrefixEntry{ + unspecifiedPrefixMapKey: { + srcPortMap: map[int]*FilterChain{ + 0: { + InlineRouteConfig: inlineRouteConfig, + HTTPFilters: routerFilterList, + }, + }, + }, + }, + }, + }, + }, + "192.168.0.0/16": { + net: ipNetFromCIDR("192.168.2.2/16"), + srcTypeArr: [3]*sourcePrefixes{ + { + srcPrefixMap: map[string]*sourcePrefixEntry{ + unspecifiedPrefixMapKey: { + srcPortMap: map[int]*FilterChain{ + 0: { + InlineRouteConfig: inlineRouteConfig, + HTTPFilters: routerFilterList, + }, + }, + }, + }, + }, + }, + }, + "10.0.0.0/8": { + net: ipNetFromCIDR("10.0.0.0/8"), + srcTypeArr: [3]*sourcePrefixes{ + { + srcPrefixMap: map[string]*sourcePrefixEntry{ + unspecifiedPrefixMapKey: { + srcPortMap: map[int]*FilterChain{ + 0: { + InlineRouteConfig: inlineRouteConfig, + HTTPFilters: routerFilterList, + }, + }, + }, + }, + }, + }, + }, + "192.168.100.1/32": { + net: ipNetFromCIDR("192.168.100.1/32"), + srcTypeArr: [3]*sourcePrefixes{}, + }, + "192.168.100.2/32": { + net: ipNetFromCIDR("192.168.100.2/32"), + srcTypeArr: [3]*sourcePrefixes{}, + }, + "192.168.100.3/32": { + net: ipNetFromCIDR("192.168.100.3/32"), + srcTypeArr: [3]*sourcePrefixes{}, + }, + }, + def: &FilterChain{ + InlineRouteConfig: inlineRouteConfig, + HTTPFilters: routerFilterList, + }, + }, + }, + } + + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + gotFC, err := NewFilterChainManager(test.lis) + if err != nil { + t.Fatalf("NewFilterChainManager() returned err: %v, wantErr: nil", err) + } + if !cmp.Equal(gotFC, test.wantFC, cmp.AllowUnexported(FilterChainManager{}, destPrefixEntry{}, sourcePrefixes{}, sourcePrefixEntry{})) { + t.Fatalf("NewFilterChainManager() returned %+v, want: %+v", gotFC, test.wantFC) + } + }) + } +} + +func TestLookup_Failures(t *testing.T) { + tests := []struct { + desc string + lis *v3listenerpb.Listener + params FilterChainLookupParams + wantErr string + }{ + { + desc: "no destination prefix match", + lis: &v3listenerpb.Listener{ + FilterChains: []*v3listenerpb.FilterChain{ + { + FilterChainMatch: &v3listenerpb.FilterChainMatch{PrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("192.168.1.1", 16)}}, + Filters: emptyValidNetworkFilters, + }, + }, + }, + params: FilterChainLookupParams{ + IsUnspecifiedListener: true, + DestAddr: net.IPv4(10, 1, 1, 1), + }, + wantErr: "no matching filter chain based on destination prefix match", + }, + { + desc: "no source type match", + lis: &v3listenerpb.Listener{ + FilterChains: []*v3listenerpb.FilterChain{ + { + FilterChainMatch: &v3listenerpb.FilterChainMatch{ + PrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("192.168.1.1", 16)}, + SourceType: v3listenerpb.FilterChainMatch_SAME_IP_OR_LOOPBACK, + }, + Filters: emptyValidNetworkFilters, + }, + }, + }, + params: FilterChainLookupParams{ + IsUnspecifiedListener: true, + DestAddr: net.IPv4(192, 168, 100, 1), + SourceAddr: net.IPv4(192, 168, 100, 2), + }, + wantErr: "no matching filter chain based on source type match", + }, + { + desc: "no source prefix match", + lis: &v3listenerpb.Listener{ + FilterChains: []*v3listenerpb.FilterChain{ + { + FilterChainMatch: &v3listenerpb.FilterChainMatch{ + SourcePrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("192.168.1.1", 24)}, + SourceType: v3listenerpb.FilterChainMatch_SAME_IP_OR_LOOPBACK, + }, + Filters: emptyValidNetworkFilters, + }, + }, + }, + params: FilterChainLookupParams{ + IsUnspecifiedListener: true, + DestAddr: net.IPv4(192, 168, 100, 1), + SourceAddr: net.IPv4(192, 168, 100, 1), + }, + wantErr: "no matching filter chain after all match criteria", + }, + { + desc: "multiple matching filter chains", + lis: &v3listenerpb.Listener{ + FilterChains: []*v3listenerpb.FilterChain{ + { + FilterChainMatch: &v3listenerpb.FilterChainMatch{SourcePorts: []uint32{1, 2, 3}}, + Filters: emptyValidNetworkFilters, + }, + { + FilterChainMatch: &v3listenerpb.FilterChainMatch{ + PrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("192.168.1.1", 16)}, + SourcePorts: []uint32{1}, + }, + Filters: emptyValidNetworkFilters, + }, + }, + }, + params: FilterChainLookupParams{ + // IsUnspecified is not set. This means that the destination + // prefix matchers will be ignored. + DestAddr: net.IPv4(192, 168, 100, 1), + SourceAddr: net.IPv4(192, 168, 100, 1), + SourcePort: 1, + }, + wantErr: "multiple matching filter chains", + }, + { + desc: "no default filter chain", + lis: &v3listenerpb.Listener{ + FilterChains: []*v3listenerpb.FilterChain{ + { + FilterChainMatch: &v3listenerpb.FilterChainMatch{SourcePorts: []uint32{1, 2, 3}}, + Filters: emptyValidNetworkFilters, + }, + }, + }, + params: FilterChainLookupParams{ + IsUnspecifiedListener: true, + DestAddr: net.IPv4(192, 168, 100, 1), + SourceAddr: net.IPv4(192, 168, 100, 1), + SourcePort: 80, + }, + wantErr: "no matching filter chain after all match criteria", + }, + { + desc: "most specific match dropped for unsupported field", + lis: &v3listenerpb.Listener{ + FilterChains: []*v3listenerpb.FilterChain{ + { + // This chain will be picked in the destination prefix + // stage, but will be dropped at the server names stage. + FilterChainMatch: &v3listenerpb.FilterChainMatch{ + PrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("192.168.100.1", 32)}, + ServerNames: []string{"foo"}, + }, + Filters: emptyValidNetworkFilters, + }, + { + FilterChainMatch: &v3listenerpb.FilterChainMatch{ + PrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("192.168.100.0", 16)}, + }, + Filters: emptyValidNetworkFilters, + }, + }, + }, + params: FilterChainLookupParams{ + IsUnspecifiedListener: true, + DestAddr: net.IPv4(192, 168, 100, 1), + SourceAddr: net.IPv4(192, 168, 100, 1), + SourcePort: 80, + }, + wantErr: "no matching filter chain based on source type match", + }, + } + + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + fci, err := NewFilterChainManager(test.lis) + if err != nil { + t.Fatalf("NewFilterChainManager() failed: %v", err) + } + fc, err := fci.Lookup(test.params) + if err == nil || !strings.Contains(err.Error(), test.wantErr) { + t.Fatalf("FilterChainManager.Lookup(%v) = (%v, %v) want (nil, %s)", test.params, fc, err, test.wantErr) + } + }) + } +} + +func TestLookup_Successes(t *testing.T) { + oldRBAC := env.RBACSupport + env.RBACSupport = true + defer func() { + env.RBACSupport = oldRBAC + }() + lisWithDefaultChain := &v3listenerpb.Listener{ + FilterChains: []*v3listenerpb.FilterChain{ + { + FilterChainMatch: &v3listenerpb.FilterChainMatch{PrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("192.168.1.1", 16)}}, + TransportSocket: &v3corepb.TransportSocket{ + Name: "envoy.transport_sockets.tls", + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3tlspb.DownstreamTlsContext{ + CommonTlsContext: &v3tlspb.CommonTlsContext{ + TlsCertificateCertificateProviderInstance: &v3tlspb.CommonTlsContext_CertificateProviderInstance{InstanceName: "instance1"}, + }, + }), + }, + }, + Filters: emptyValidNetworkFilters, + }, + }, + // A default filter chain with an empty transport socket. + DefaultFilterChain: &v3listenerpb.FilterChain{ + TransportSocket: &v3corepb.TransportSocket{ + Name: "envoy.transport_sockets.tls", + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3tlspb.DownstreamTlsContext{ + CommonTlsContext: &v3tlspb.CommonTlsContext{ + TlsCertificateCertificateProviderInstance: &v3tlspb.CommonTlsContext_CertificateProviderInstance{InstanceName: "default"}, + }, + }), + }, + }, + Filters: emptyValidNetworkFilters, + }, + } + lisWithoutDefaultChain := &v3listenerpb.Listener{ + FilterChains: []*v3listenerpb.FilterChain{ + { + TransportSocket: transportSocketWithInstanceName("unspecified-dest-and-source-prefix"), + Filters: emptyValidNetworkFilters, + }, + { + FilterChainMatch: &v3listenerpb.FilterChainMatch{ + PrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("0.0.0.0", 0)}, + SourcePrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("0.0.0.0", 0)}, + }, + TransportSocket: transportSocketWithInstanceName("wildcard-prefixes-v4"), + Filters: emptyValidNetworkFilters, + }, + { + FilterChainMatch: &v3listenerpb.FilterChainMatch{ + SourcePrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("::", 0)}, + }, + TransportSocket: transportSocketWithInstanceName("wildcard-source-prefix-v6"), + Filters: emptyValidNetworkFilters, + }, + { + FilterChainMatch: &v3listenerpb.FilterChainMatch{PrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("192.168.1.1", 16)}}, + TransportSocket: transportSocketWithInstanceName("specific-destination-prefix-unspecified-source-type"), + Filters: emptyValidNetworkFilters, + }, + { + FilterChainMatch: &v3listenerpb.FilterChainMatch{ + PrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("192.168.1.1", 24)}, + SourceType: v3listenerpb.FilterChainMatch_EXTERNAL, + }, + TransportSocket: transportSocketWithInstanceName("specific-destination-prefix-specific-source-type"), + Filters: emptyValidNetworkFilters, + }, + { + FilterChainMatch: &v3listenerpb.FilterChainMatch{ + PrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("192.168.1.1", 24)}, + SourcePrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("192.168.92.1", 24)}, + SourceType: v3listenerpb.FilterChainMatch_EXTERNAL, + }, + TransportSocket: transportSocketWithInstanceName("specific-destination-prefix-specific-source-type-specific-source-prefix"), + Filters: emptyValidNetworkFilters, + }, + { + FilterChainMatch: &v3listenerpb.FilterChainMatch{ + PrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("192.168.1.1", 24)}, + SourcePrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("192.168.92.1", 24)}, + SourceType: v3listenerpb.FilterChainMatch_EXTERNAL, + SourcePorts: []uint32{80}, + }, + TransportSocket: transportSocketWithInstanceName("specific-destination-prefix-specific-source-type-specific-source-prefix-specific-source-port"), + Filters: emptyValidNetworkFilters, + }, + }, + } + + tests := []struct { + desc string + lis *v3listenerpb.Listener + params FilterChainLookupParams + wantFC *FilterChain + }{ + { + desc: "default filter chain", + lis: lisWithDefaultChain, + params: FilterChainLookupParams{ + IsUnspecifiedListener: true, + DestAddr: net.IPv4(10, 1, 1, 1), + }, + wantFC: &FilterChain{ + SecurityCfg: &SecurityConfig{IdentityInstanceName: "default"}, + InlineRouteConfig: inlineRouteConfig, + HTTPFilters: routerFilterList, + }, + }, + { + desc: "unspecified destination match", + lis: lisWithoutDefaultChain, + params: FilterChainLookupParams{ + IsUnspecifiedListener: true, + DestAddr: net.ParseIP("2001:68::db8"), + SourceAddr: net.IPv4(10, 1, 1, 1), + SourcePort: 1, + }, + wantFC: &FilterChain{ + SecurityCfg: &SecurityConfig{IdentityInstanceName: "unspecified-dest-and-source-prefix"}, + InlineRouteConfig: inlineRouteConfig, + HTTPFilters: routerFilterList, + }, + }, + { + desc: "wildcard destination match v4", + lis: lisWithoutDefaultChain, + params: FilterChainLookupParams{ + IsUnspecifiedListener: true, + DestAddr: net.IPv4(10, 1, 1, 1), + SourceAddr: net.IPv4(10, 1, 1, 1), + SourcePort: 1, + }, + wantFC: &FilterChain{ + SecurityCfg: &SecurityConfig{IdentityInstanceName: "wildcard-prefixes-v4"}, + InlineRouteConfig: inlineRouteConfig, + HTTPFilters: routerFilterList, + }, + }, + { + desc: "wildcard source match v6", + lis: lisWithoutDefaultChain, + params: FilterChainLookupParams{ + IsUnspecifiedListener: true, + DestAddr: net.ParseIP("2001:68::1"), + SourceAddr: net.ParseIP("2001:68::2"), + SourcePort: 1, + }, + wantFC: &FilterChain{ + SecurityCfg: &SecurityConfig{IdentityInstanceName: "wildcard-source-prefix-v6"}, + InlineRouteConfig: inlineRouteConfig, + HTTPFilters: routerFilterList, + }, + }, + { + desc: "specific destination and wildcard source type match", + lis: lisWithoutDefaultChain, + params: FilterChainLookupParams{ + IsUnspecifiedListener: true, + DestAddr: net.IPv4(192, 168, 100, 1), + SourceAddr: net.IPv4(192, 168, 100, 1), + SourcePort: 80, + }, + wantFC: &FilterChain{ + SecurityCfg: &SecurityConfig{IdentityInstanceName: "specific-destination-prefix-unspecified-source-type"}, + InlineRouteConfig: inlineRouteConfig, + HTTPFilters: routerFilterList, + }, + }, + { + desc: "specific destination and source type match", + lis: lisWithoutDefaultChain, + params: FilterChainLookupParams{ + IsUnspecifiedListener: true, + DestAddr: net.IPv4(192, 168, 1, 1), + SourceAddr: net.IPv4(10, 1, 1, 1), + SourcePort: 80, + }, + wantFC: &FilterChain{ + SecurityCfg: &SecurityConfig{IdentityInstanceName: "specific-destination-prefix-specific-source-type"}, + InlineRouteConfig: inlineRouteConfig, + HTTPFilters: routerFilterList, + }, + }, + { + desc: "specific destination source type and source prefix", + lis: lisWithoutDefaultChain, + params: FilterChainLookupParams{ + IsUnspecifiedListener: true, + DestAddr: net.IPv4(192, 168, 1, 1), + SourceAddr: net.IPv4(192, 168, 92, 100), + SourcePort: 70, + }, + wantFC: &FilterChain{ + SecurityCfg: &SecurityConfig{IdentityInstanceName: "specific-destination-prefix-specific-source-type-specific-source-prefix"}, + InlineRouteConfig: inlineRouteConfig, + HTTPFilters: routerFilterList, + }, + }, + { + desc: "specific destination source type source prefix and source port", + lis: lisWithoutDefaultChain, + params: FilterChainLookupParams{ + IsUnspecifiedListener: true, + DestAddr: net.IPv4(192, 168, 1, 1), + SourceAddr: net.IPv4(192, 168, 92, 100), + SourcePort: 80, + }, + wantFC: &FilterChain{ + SecurityCfg: &SecurityConfig{IdentityInstanceName: "specific-destination-prefix-specific-source-type-specific-source-prefix-specific-source-port"}, + InlineRouteConfig: inlineRouteConfig, + HTTPFilters: routerFilterList, + }, + }, + } + + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + fci, err := NewFilterChainManager(test.lis) + if err != nil { + t.Fatalf("NewFilterChainManager() failed: %v", err) + } + gotFC, err := fci.Lookup(test.params) + if err != nil { + t.Fatalf("FilterChainManager.Lookup(%v) failed: %v", test.params, err) + } + if !cmp.Equal(gotFC, test.wantFC, cmpopts.EquateEmpty()) { + t.Fatalf("FilterChainManager.Lookup(%v) = %v, want %v", test.params, gotFC, test.wantFC) + } + }) + } +} + +type filterCfg struct { + httpfilter.FilterConfig + // Level is what differentiates top level filters ("top level") vs. second + // level ("virtual host level"), and third level ("route level"). + level string +} + +type filterBuilder struct { + httpfilter.Filter +} + +var _ httpfilter.ServerInterceptorBuilder = &filterBuilder{} + +func (fb *filterBuilder) BuildServerInterceptor(config httpfilter.FilterConfig, override httpfilter.FilterConfig) (iresolver.ServerInterceptor, error) { + var level string + level = config.(filterCfg).level + + if override != nil { + level = override.(filterCfg).level + } + return &serverInterceptor{level: level}, nil +} + +type serverInterceptor struct { + level string +} + +func (si *serverInterceptor) AllowRPC(context.Context) error { + return errors.New(si.level) +} + +func TestHTTPFilterInstantiation(t *testing.T) { + tests := []struct { + name string + filters []HTTPFilter + routeConfig RouteConfigUpdate + // A list of strings which will be built from iterating through the + // filters ["top level", "vh level", "route level", "route level"...] + // wantErrs is the list of error strings that will be constructed from + // the deterministic iteration through the vh list and route list. The + // error string will be determined by the level of config that the + // filter builder receives (i.e. top level, vs. virtual host level vs. + // route level). + wantErrs []string + }{ + { + name: "one http filter no overrides", + filters: []HTTPFilter{ + {Name: "server-interceptor", Filter: &filterBuilder{}, Config: filterCfg{level: topLevel}}, + }, + routeConfig: RouteConfigUpdate{ + VirtualHosts: []*VirtualHost{ + { + Domains: []string{"target"}, + Routes: []*Route{{ + Prefix: newStringP("1"), + }, + }, + }, + }}, + wantErrs: []string{topLevel}, + }, + { + name: "one http filter vh override", + filters: []HTTPFilter{ + {Name: "server-interceptor", Filter: &filterBuilder{}, Config: filterCfg{level: topLevel}}, + }, + routeConfig: RouteConfigUpdate{ + VirtualHosts: []*VirtualHost{ + { + Domains: []string{"target"}, + Routes: []*Route{{ + Prefix: newStringP("1"), + }, + }, + HTTPFilterConfigOverride: map[string]httpfilter.FilterConfig{ + "server-interceptor": filterCfg{level: vhLevel}, + }, + }, + }}, + wantErrs: []string{vhLevel}, + }, + { + name: "one http filter route override", + filters: []HTTPFilter{ + {Name: "server-interceptor", Filter: &filterBuilder{}, Config: filterCfg{level: topLevel}}, + }, + routeConfig: RouteConfigUpdate{ + VirtualHosts: []*VirtualHost{ + { + Domains: []string{"target"}, + Routes: []*Route{{ + Prefix: newStringP("1"), + HTTPFilterConfigOverride: map[string]httpfilter.FilterConfig{ + "server-interceptor": filterCfg{level: rLevel}, + }, + }, + }, + }, + }}, + wantErrs: []string{rLevel}, + }, + // This tests the scenario where there are three http filters, and one + // gets overridden by route and one by virtual host. + { + name: "three http filters vh override route override", + filters: []HTTPFilter{ + {Name: "server-interceptor1", Filter: &filterBuilder{}, Config: filterCfg{level: topLevel}}, + {Name: "server-interceptor2", Filter: &filterBuilder{}, Config: filterCfg{level: topLevel}}, + {Name: "server-interceptor3", Filter: &filterBuilder{}, Config: filterCfg{level: topLevel}}, + }, + routeConfig: RouteConfigUpdate{ + VirtualHosts: []*VirtualHost{ + { + Domains: []string{"target"}, + Routes: []*Route{{ + Prefix: newStringP("1"), + HTTPFilterConfigOverride: map[string]httpfilter.FilterConfig{ + "server-interceptor3": filterCfg{level: rLevel}, + }, + }, + }, + HTTPFilterConfigOverride: map[string]httpfilter.FilterConfig{ + "server-interceptor2": filterCfg{level: vhLevel}, + }, + }, + }}, + wantErrs: []string{topLevel, vhLevel, rLevel}, + }, + // This tests the scenario where there are three http filters, and two + // virtual hosts with different vh + route overrides for each virtual + // host. + { + name: "three http filters two vh", + filters: []HTTPFilter{ + {Name: "server-interceptor1", Filter: &filterBuilder{}, Config: filterCfg{level: topLevel}}, + {Name: "server-interceptor2", Filter: &filterBuilder{}, Config: filterCfg{level: topLevel}}, + {Name: "server-interceptor3", Filter: &filterBuilder{}, Config: filterCfg{level: topLevel}}, + }, + routeConfig: RouteConfigUpdate{ + VirtualHosts: []*VirtualHost{ + { + Domains: []string{"target"}, + Routes: []*Route{{ + Prefix: newStringP("1"), + HTTPFilterConfigOverride: map[string]httpfilter.FilterConfig{ + "server-interceptor3": filterCfg{level: rLevel}, + }, + }, + }, + HTTPFilterConfigOverride: map[string]httpfilter.FilterConfig{ + "server-interceptor2": filterCfg{level: vhLevel}, + }, + }, + { + Domains: []string{"target"}, + Routes: []*Route{{ + Prefix: newStringP("1"), + HTTPFilterConfigOverride: map[string]httpfilter.FilterConfig{ + "server-interceptor1": filterCfg{level: rLevel}, + "server-interceptor2": filterCfg{level: rLevel}, + }, + }, + }, + HTTPFilterConfigOverride: map[string]httpfilter.FilterConfig{ + "server-interceptor2": filterCfg{level: vhLevel}, + "server-interceptor3": filterCfg{level: vhLevel}, + }, + }, + }}, + wantErrs: []string{topLevel, vhLevel, rLevel, rLevel, rLevel, vhLevel}, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + fc := FilterChain{ + HTTPFilters: test.filters, + } + vhswi, err := fc.ConstructUsableRouteConfiguration(test.routeConfig) + if err != nil { + t.Fatalf("Error constructing usable route configuration: %v", err) + } + // Build out list of errors by iterating through the virtual hosts and routes, + // and running the filters in route configurations. + var errs []string + for _, vh := range vhswi { + for _, r := range vh.Routes { + for _, int := range r.Interceptors { + errs = append(errs, int.AllowRPC(context.Background()).Error()) + } + } + } + if !cmp.Equal(errs, test.wantErrs) { + t.Fatalf("List of errors %v, want %v", errs, test.wantErrs) + } + }) + } +} + +// The Equal() methods defined below help with using cmp.Equal() on these types +// which contain all unexported fields. + +func (fci *FilterChainManager) Equal(other *FilterChainManager) bool { + if (fci == nil) != (other == nil) { + return false + } + if fci == nil { + return true + } + switch { + case !cmp.Equal(fci.dstPrefixMap, other.dstPrefixMap, cmpopts.EquateEmpty()): + return false + // TODO: Support comparing dstPrefixes slice? + case !cmp.Equal(fci.def, other.def, cmpopts.EquateEmpty(), protocmp.Transform()): + return false + case !cmp.Equal(fci.RouteConfigNames, other.RouteConfigNames, cmpopts.EquateEmpty()): + return false + } + return true +} + +func (dpe *destPrefixEntry) Equal(other *destPrefixEntry) bool { + if (dpe == nil) != (other == nil) { + return false + } + if dpe == nil { + return true + } + if !cmp.Equal(dpe.net, other.net) { + return false + } + for i, st := range dpe.srcTypeArr { + if !cmp.Equal(st, other.srcTypeArr[i], cmpopts.EquateEmpty()) { + return false + } + } + return true +} + +func (sp *sourcePrefixes) Equal(other *sourcePrefixes) bool { + if (sp == nil) != (other == nil) { + return false + } + if sp == nil { + return true + } + // TODO: Support comparing srcPrefixes slice? + return cmp.Equal(sp.srcPrefixMap, other.srcPrefixMap, cmpopts.EquateEmpty()) +} + +func (spe *sourcePrefixEntry) Equal(other *sourcePrefixEntry) bool { + if (spe == nil) != (other == nil) { + return false + } + if spe == nil { + return true + } + switch { + case !cmp.Equal(spe.net, other.net): + return false + case !cmp.Equal(spe.srcPortMap, other.srcPortMap, cmpopts.EquateEmpty(), protocmp.Transform()): + return false + } + return true +} + +// The String() methods defined below help with debugging test failures as the +// regular %v or %+v formatting directives do not expands pointer fields inside +// structs, and these types have a lot of pointers pointing to other structs. +func (fci *FilterChainManager) String() string { + if fci == nil { + return "" + } + + var sb strings.Builder + if fci.dstPrefixMap != nil { + sb.WriteString("destination_prefix_map: map {\n") + for k, v := range fci.dstPrefixMap { + sb.WriteString(fmt.Sprintf("%q: %v\n", k, v)) + } + sb.WriteString("}\n") + } + if fci.dstPrefixes != nil { + sb.WriteString("destination_prefixes: [") + for _, p := range fci.dstPrefixes { + sb.WriteString(fmt.Sprintf("%v ", p)) + } + sb.WriteString("]") + } + if fci.def != nil { + sb.WriteString(fmt.Sprintf("default_filter_chain: %+v ", fci.def)) + } + return sb.String() +} + +func (dpe *destPrefixEntry) String() string { + if dpe == nil { + return "" + } + var sb strings.Builder + if dpe.net != nil { + sb.WriteString(fmt.Sprintf("destination_prefix: %s ", dpe.net.String())) + } + sb.WriteString("source_types_array: [") + for _, st := range dpe.srcTypeArr { + sb.WriteString(fmt.Sprintf("%v ", st)) + } + sb.WriteString("]") + return sb.String() +} + +func (sp *sourcePrefixes) String() string { + if sp == nil { + return "" + } + var sb strings.Builder + if sp.srcPrefixMap != nil { + sb.WriteString("source_prefix_map: map {") + for k, v := range sp.srcPrefixMap { + sb.WriteString(fmt.Sprintf("%q: %v ", k, v)) + } + sb.WriteString("}") + } + if sp.srcPrefixes != nil { + sb.WriteString("source_prefixes: [") + for _, p := range sp.srcPrefixes { + sb.WriteString(fmt.Sprintf("%v ", p)) + } + sb.WriteString("]") + } + return sb.String() +} + +func (spe *sourcePrefixEntry) String() string { + if spe == nil { + return "" + } + var sb strings.Builder + if spe.net != nil { + sb.WriteString(fmt.Sprintf("source_prefix: %s ", spe.net.String())) + } + if spe.srcPortMap != nil { + sb.WriteString("source_ports_map: map {") + for k, v := range spe.srcPortMap { + sb.WriteString(fmt.Sprintf("%d: %+v ", k, v)) + } + sb.WriteString("}") + } + return sb.String() +} + +func (f *FilterChain) String() string { + if f == nil || f.SecurityCfg == nil { + return "" + } + return fmt.Sprintf("security_config: %v", f.SecurityCfg) +} + +func ipNetFromCIDR(cidr string) *net.IPNet { + _, ipnet, err := net.ParseCIDR(cidr) + if err != nil { + panic(err) + } + return ipnet +} + +func transportSocketWithInstanceName(name string) *v3corepb.TransportSocket { + return &v3corepb.TransportSocket{ + Name: "envoy.transport_sockets.tls", + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3tlspb.DownstreamTlsContext{ + CommonTlsContext: &v3tlspb.CommonTlsContext{ + TlsCertificateCertificateProviderInstance: &v3tlspb.CommonTlsContext_CertificateProviderInstance{InstanceName: name}, + }, + }), + }, + } +} + +func cidrRangeFromAddressAndPrefixLen(address string, len int) *v3corepb.CidrRange { + return &v3corepb.CidrRange{ + AddressPrefix: address, + PrefixLen: &wrapperspb.UInt32Value{ + Value: uint32(len), + }, + } +} diff --git a/xds/internal/xdsclient/lds_test.go b/xds/internal/xdsclient/lds_test.go new file mode 100644 index 00000000000..18e2f55ede4 --- /dev/null +++ b/xds/internal/xdsclient/lds_test.go @@ -0,0 +1,1944 @@ +/* + * + * Copyright 2020 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package xdsclient + +import ( + "fmt" + "strings" + "testing" + "time" + + v1typepb "github.com/cncf/udpa/go/udpa/type/v1" + v3routepb "github.com/envoyproxy/go-control-plane/envoy/config/route/v3" + "github.com/golang/protobuf/proto" + spb "github.com/golang/protobuf/ptypes/struct" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "google.golang.org/protobuf/types/known/durationpb" + + "google.golang.org/grpc/internal/testutils" + "google.golang.org/grpc/internal/xds/env" + "google.golang.org/grpc/xds/internal/httpfilter" + _ "google.golang.org/grpc/xds/internal/httpfilter/router" + "google.golang.org/grpc/xds/internal/testutils/e2e" + "google.golang.org/grpc/xds/internal/version" + + v2xdspb "github.com/envoyproxy/go-control-plane/envoy/api/v2" + v2corepb "github.com/envoyproxy/go-control-plane/envoy/api/v2/core" + v3corepb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" + v2httppb "github.com/envoyproxy/go-control-plane/envoy/config/filter/network/http_connection_manager/v2" + v2listenerpb "github.com/envoyproxy/go-control-plane/envoy/config/listener/v2" + v3listenerpb "github.com/envoyproxy/go-control-plane/envoy/config/listener/v3" + v3httppb "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/network/http_connection_manager/v3" + v3tlspb "github.com/envoyproxy/go-control-plane/envoy/extensions/transport_sockets/tls/v3" + anypb "github.com/golang/protobuf/ptypes/any" + wrapperspb "github.com/golang/protobuf/ptypes/wrappers" +) + +func (s) TestUnmarshalListener_ClientSide(t *testing.T) { + const ( + v2LDSTarget = "lds.target.good:2222" + v3LDSTarget = "lds.target.good:3333" + v2RouteConfigName = "v2RouteConfig" + v3RouteConfigName = "v3RouteConfig" + routeName = "routeName" + testVersion = "test-version-lds-client" + ) + + var ( + v2Lis = testutils.MarshalAny(&v2xdspb.Listener{ + Name: v2LDSTarget, + ApiListener: &v2listenerpb.ApiListener{ + ApiListener: testutils.MarshalAny(&v2httppb.HttpConnectionManager{ + RouteSpecifier: &v2httppb.HttpConnectionManager_Rds{ + Rds: &v2httppb.Rds{ + ConfigSource: &v2corepb.ConfigSource{ + ConfigSourceSpecifier: &v2corepb.ConfigSource_Ads{Ads: &v2corepb.AggregatedConfigSource{}}, + }, + RouteConfigName: v2RouteConfigName, + }, + }, + }), + }, + }) + customFilter = &v3httppb.HttpFilter{ + Name: "customFilter", + ConfigType: &v3httppb.HttpFilter_TypedConfig{TypedConfig: customFilterConfig}, + } + typedStructFilter = &v3httppb.HttpFilter{ + Name: "customFilter", + ConfigType: &v3httppb.HttpFilter_TypedConfig{TypedConfig: wrappedCustomFilterTypedStructConfig}, + } + customOptionalFilter = &v3httppb.HttpFilter{ + Name: "customFilter", + ConfigType: &v3httppb.HttpFilter_TypedConfig{TypedConfig: customFilterConfig}, + IsOptional: true, + } + customFilter2 = &v3httppb.HttpFilter{ + Name: "customFilter2", + ConfigType: &v3httppb.HttpFilter_TypedConfig{TypedConfig: customFilterConfig}, + } + errFilter = &v3httppb.HttpFilter{ + Name: "errFilter", + ConfigType: &v3httppb.HttpFilter_TypedConfig{TypedConfig: errFilterConfig}, + } + errOptionalFilter = &v3httppb.HttpFilter{ + Name: "errFilter", + ConfigType: &v3httppb.HttpFilter_TypedConfig{TypedConfig: errFilterConfig}, + IsOptional: true, + } + clientOnlyCustomFilter = &v3httppb.HttpFilter{ + Name: "clientOnlyCustomFilter", + ConfigType: &v3httppb.HttpFilter_TypedConfig{TypedConfig: clientOnlyCustomFilterConfig}, + } + serverOnlyCustomFilter = &v3httppb.HttpFilter{ + Name: "serverOnlyCustomFilter", + ConfigType: &v3httppb.HttpFilter_TypedConfig{TypedConfig: serverOnlyCustomFilterConfig}, + } + serverOnlyOptionalCustomFilter = &v3httppb.HttpFilter{ + Name: "serverOnlyOptionalCustomFilter", + ConfigType: &v3httppb.HttpFilter_TypedConfig{TypedConfig: serverOnlyCustomFilterConfig}, + IsOptional: true, + } + unknownFilter = &v3httppb.HttpFilter{ + Name: "unknownFilter", + ConfigType: &v3httppb.HttpFilter_TypedConfig{TypedConfig: unknownFilterConfig}, + } + unknownOptionalFilter = &v3httppb.HttpFilter{ + Name: "unknownFilter", + ConfigType: &v3httppb.HttpFilter_TypedConfig{TypedConfig: unknownFilterConfig}, + IsOptional: true, + } + v3LisWithInlineRoute = testutils.MarshalAny(&v3listenerpb.Listener{ + Name: v3LDSTarget, + ApiListener: &v3listenerpb.ApiListener{ + ApiListener: testutils.MarshalAny(&v3httppb.HttpConnectionManager{ + RouteSpecifier: &v3httppb.HttpConnectionManager_RouteConfig{ + RouteConfig: &v3routepb.RouteConfiguration{ + Name: routeName, + VirtualHosts: []*v3routepb.VirtualHost{{ + Domains: []string{v3LDSTarget}, + Routes: []*v3routepb.Route{{ + Match: &v3routepb.RouteMatch{ + PathSpecifier: &v3routepb.RouteMatch_Prefix{Prefix: "/"}, + }, + Action: &v3routepb.Route_Route{ + Route: &v3routepb.RouteAction{ + ClusterSpecifier: &v3routepb.RouteAction_Cluster{Cluster: clusterName}, + }}}}}}}, + }, + HttpFilters: []*v3httppb.HttpFilter{emptyRouterFilter}, + CommonHttpProtocolOptions: &v3corepb.HttpProtocolOptions{ + MaxStreamDuration: durationpb.New(time.Second), + }, + }), + }, + }) + v3LisWithFilters = func(fs ...*v3httppb.HttpFilter) *anypb.Any { + fs = append(fs, emptyRouterFilter) + return testutils.MarshalAny(&v3listenerpb.Listener{ + Name: v3LDSTarget, + ApiListener: &v3listenerpb.ApiListener{ + ApiListener: testutils.MarshalAny( + &v3httppb.HttpConnectionManager{ + RouteSpecifier: &v3httppb.HttpConnectionManager_Rds{ + Rds: &v3httppb.Rds{ + ConfigSource: &v3corepb.ConfigSource{ + ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{Ads: &v3corepb.AggregatedConfigSource{}}, + }, + RouteConfigName: v3RouteConfigName, + }, + }, + CommonHttpProtocolOptions: &v3corepb.HttpProtocolOptions{ + MaxStreamDuration: durationpb.New(time.Second), + }, + HttpFilters: fs, + }), + }, + }) + } + v3LisToTestRBAC = func(xffNumTrustedHops uint32, originalIpDetectionExtensions []*v3corepb.TypedExtensionConfig) *anypb.Any { + return testutils.MarshalAny(&v3listenerpb.Listener{ + Name: v3LDSTarget, + ApiListener: &v3listenerpb.ApiListener{ + ApiListener: testutils.MarshalAny( + &v3httppb.HttpConnectionManager{ + RouteSpecifier: &v3httppb.HttpConnectionManager_Rds{ + Rds: &v3httppb.Rds{ + ConfigSource: &v3corepb.ConfigSource{ + ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{Ads: &v3corepb.AggregatedConfigSource{}}, + }, + RouteConfigName: v3RouteConfigName, + }, + }, + CommonHttpProtocolOptions: &v3corepb.HttpProtocolOptions{ + MaxStreamDuration: durationpb.New(time.Second), + }, + HttpFilters: []*v3httppb.HttpFilter{emptyRouterFilter}, + XffNumTrustedHops: xffNumTrustedHops, + OriginalIpDetectionExtensions: originalIpDetectionExtensions, + }), + }, + }) + } + errMD = UpdateMetadata{ + Status: ServiceStatusNACKed, + Version: testVersion, + ErrState: &UpdateErrorMetadata{ + Version: testVersion, + Err: cmpopts.AnyError, + }, + } + ) + + tests := []struct { + name string + resources []*anypb.Any + wantUpdate map[string]ListenerUpdateErrTuple + wantMD UpdateMetadata + wantErr bool + }{ + { + name: "non-listener resource", + resources: []*anypb.Any{{TypeUrl: version.V3HTTPConnManagerURL}}, + wantMD: errMD, + wantErr: true, + }, + { + name: "badly marshaled listener resource", + resources: []*anypb.Any{ + { + TypeUrl: version.V3ListenerURL, + Value: func() []byte { + lis := &v3listenerpb.Listener{ + Name: v3LDSTarget, + ApiListener: &v3listenerpb.ApiListener{ + ApiListener: &anypb.Any{ + TypeUrl: version.V3HTTPConnManagerURL, + Value: []byte{1, 2, 3, 4}, + }, + }, + } + mLis, _ := proto.Marshal(lis) + return mLis + }(), + }, + }, + wantUpdate: map[string]ListenerUpdateErrTuple{v3LDSTarget: {Err: cmpopts.AnyError}}, + wantMD: errMD, + wantErr: true, + }, + { + name: "wrong type in apiListener", + resources: []*anypb.Any{testutils.MarshalAny(&v3listenerpb.Listener{ + Name: v3LDSTarget, + ApiListener: &v3listenerpb.ApiListener{ + ApiListener: testutils.MarshalAny(&v2xdspb.Listener{}), + }, + })}, + wantUpdate: map[string]ListenerUpdateErrTuple{v3LDSTarget: {Err: cmpopts.AnyError}}, + wantMD: errMD, + wantErr: true, + }, + { + name: "empty httpConnMgr in apiListener", + resources: []*anypb.Any{testutils.MarshalAny(&v3listenerpb.Listener{ + Name: v3LDSTarget, + ApiListener: &v3listenerpb.ApiListener{ + ApiListener: testutils.MarshalAny(&v3httppb.HttpConnectionManager{ + RouteSpecifier: &v3httppb.HttpConnectionManager_Rds{ + Rds: &v3httppb.Rds{}, + }, + }), + }, + })}, + wantUpdate: map[string]ListenerUpdateErrTuple{v3LDSTarget: {Err: cmpopts.AnyError}}, + wantMD: errMD, + wantErr: true, + }, + { + name: "scopedRoutes routeConfig in apiListener", + resources: []*anypb.Any{testutils.MarshalAny(&v3listenerpb.Listener{ + Name: v3LDSTarget, + ApiListener: &v3listenerpb.ApiListener{ + ApiListener: testutils.MarshalAny(&v3httppb.HttpConnectionManager{ + RouteSpecifier: &v3httppb.HttpConnectionManager_ScopedRoutes{}, + }), + }, + })}, + wantUpdate: map[string]ListenerUpdateErrTuple{v3LDSTarget: {Err: cmpopts.AnyError}}, + wantMD: errMD, + wantErr: true, + }, + { + name: "rds.ConfigSource in apiListener is not ADS", + resources: []*anypb.Any{testutils.MarshalAny(&v3listenerpb.Listener{ + Name: v3LDSTarget, + ApiListener: &v3listenerpb.ApiListener{ + ApiListener: testutils.MarshalAny(&v3httppb.HttpConnectionManager{ + RouteSpecifier: &v3httppb.HttpConnectionManager_Rds{ + Rds: &v3httppb.Rds{ + ConfigSource: &v3corepb.ConfigSource{ + ConfigSourceSpecifier: &v3corepb.ConfigSource_Path{ + Path: "/some/path", + }, + }, + RouteConfigName: v3RouteConfigName, + }, + }, + }), + }, + })}, + wantUpdate: map[string]ListenerUpdateErrTuple{v3LDSTarget: {Err: cmpopts.AnyError}}, + wantMD: errMD, + wantErr: true, + }, + { + name: "empty resource list", + wantMD: UpdateMetadata{ + Status: ServiceStatusACKed, + Version: testVersion, + }, + }, + { + name: "v3 with no filters", + resources: []*anypb.Any{v3LisWithFilters()}, + wantUpdate: map[string]ListenerUpdateErrTuple{ + v3LDSTarget: {Update: ListenerUpdate{RouteConfigName: v3RouteConfigName, MaxStreamDuration: time.Second, HTTPFilters: routerFilterList, Raw: v3LisWithFilters()}}, + }, + wantMD: UpdateMetadata{ + Status: ServiceStatusACKed, + Version: testVersion, + }, + }, + { + name: "v3 no terminal filter", + resources: []*anypb.Any{testutils.MarshalAny(&v3listenerpb.Listener{ + Name: v3LDSTarget, + ApiListener: &v3listenerpb.ApiListener{ + ApiListener: testutils.MarshalAny( + &v3httppb.HttpConnectionManager{ + RouteSpecifier: &v3httppb.HttpConnectionManager_Rds{ + Rds: &v3httppb.Rds{ + ConfigSource: &v3corepb.ConfigSource{ + ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{Ads: &v3corepb.AggregatedConfigSource{}}, + }, + RouteConfigName: v3RouteConfigName, + }, + }, + CommonHttpProtocolOptions: &v3corepb.HttpProtocolOptions{ + MaxStreamDuration: durationpb.New(time.Second), + }, + }), + }, + })}, + wantUpdate: map[string]ListenerUpdateErrTuple{v3LDSTarget: {Err: cmpopts.AnyError}}, + wantMD: errMD, + wantErr: true, + }, + { + name: "v3 with custom filter", + resources: []*anypb.Any{v3LisWithFilters(customFilter)}, + wantUpdate: map[string]ListenerUpdateErrTuple{ + v3LDSTarget: {Update: ListenerUpdate{ + RouteConfigName: v3RouteConfigName, MaxStreamDuration: time.Second, + HTTPFilters: []HTTPFilter{ + { + Name: "customFilter", + Filter: httpFilter{}, + Config: filterConfig{Cfg: customFilterConfig}, + }, + routerFilter, + }, + Raw: v3LisWithFilters(customFilter), + }}, + }, + wantMD: UpdateMetadata{ + Status: ServiceStatusACKed, + Version: testVersion, + }, + }, + { + name: "v3 with custom filter in typed struct", + resources: []*anypb.Any{v3LisWithFilters(typedStructFilter)}, + wantUpdate: map[string]ListenerUpdateErrTuple{ + v3LDSTarget: {Update: ListenerUpdate{ + RouteConfigName: v3RouteConfigName, MaxStreamDuration: time.Second, + HTTPFilters: []HTTPFilter{ + { + Name: "customFilter", + Filter: httpFilter{}, + Config: filterConfig{Cfg: customFilterTypedStructConfig}, + }, + routerFilter, + }, + Raw: v3LisWithFilters(typedStructFilter), + }}, + }, + wantMD: UpdateMetadata{ + Status: ServiceStatusACKed, + Version: testVersion, + }, + }, + { + name: "v3 with optional custom filter", + resources: []*anypb.Any{v3LisWithFilters(customOptionalFilter)}, + wantUpdate: map[string]ListenerUpdateErrTuple{ + v3LDSTarget: {Update: ListenerUpdate{ + RouteConfigName: v3RouteConfigName, MaxStreamDuration: time.Second, + HTTPFilters: []HTTPFilter{ + { + Name: "customFilter", + Filter: httpFilter{}, + Config: filterConfig{Cfg: customFilterConfig}, + }, + routerFilter, + }, + Raw: v3LisWithFilters(customOptionalFilter), + }}, + }, + wantMD: UpdateMetadata{ + Status: ServiceStatusACKed, + Version: testVersion, + }, + }, + { + name: "v3 with two filters with same name", + resources: []*anypb.Any{v3LisWithFilters(customFilter, customFilter)}, + wantUpdate: map[string]ListenerUpdateErrTuple{v3LDSTarget: {Err: cmpopts.AnyError}}, + wantMD: errMD, + wantErr: true, + }, + { + name: "v3 with two filters - same type different name", + resources: []*anypb.Any{v3LisWithFilters(customFilter, customFilter2)}, + wantUpdate: map[string]ListenerUpdateErrTuple{ + v3LDSTarget: {Update: ListenerUpdate{ + RouteConfigName: v3RouteConfigName, MaxStreamDuration: time.Second, + HTTPFilters: []HTTPFilter{{ + Name: "customFilter", + Filter: httpFilter{}, + Config: filterConfig{Cfg: customFilterConfig}, + }, { + Name: "customFilter2", + Filter: httpFilter{}, + Config: filterConfig{Cfg: customFilterConfig}, + }, + routerFilter, + }, + Raw: v3LisWithFilters(customFilter, customFilter2), + }}, + }, + wantMD: UpdateMetadata{ + Status: ServiceStatusACKed, + Version: testVersion, + }, + }, + { + name: "v3 with server-only filter", + resources: []*anypb.Any{v3LisWithFilters(serverOnlyCustomFilter)}, + wantUpdate: map[string]ListenerUpdateErrTuple{v3LDSTarget: {Err: cmpopts.AnyError}}, + wantMD: errMD, + wantErr: true, + }, + { + name: "v3 with optional server-only filter", + resources: []*anypb.Any{v3LisWithFilters(serverOnlyOptionalCustomFilter)}, + wantUpdate: map[string]ListenerUpdateErrTuple{ + v3LDSTarget: {Update: ListenerUpdate{ + RouteConfigName: v3RouteConfigName, + MaxStreamDuration: time.Second, + Raw: v3LisWithFilters(serverOnlyOptionalCustomFilter), + HTTPFilters: routerFilterList, + }}, + }, + wantMD: UpdateMetadata{ + Status: ServiceStatusACKed, + Version: testVersion, + }, + }, + { + name: "v3 with client-only filter", + resources: []*anypb.Any{v3LisWithFilters(clientOnlyCustomFilter)}, + wantUpdate: map[string]ListenerUpdateErrTuple{ + v3LDSTarget: {Update: ListenerUpdate{ + RouteConfigName: v3RouteConfigName, MaxStreamDuration: time.Second, + HTTPFilters: []HTTPFilter{ + { + Name: "clientOnlyCustomFilter", + Filter: clientOnlyHTTPFilter{}, + Config: filterConfig{Cfg: clientOnlyCustomFilterConfig}, + }, + routerFilter}, + Raw: v3LisWithFilters(clientOnlyCustomFilter), + }}, + }, + wantMD: UpdateMetadata{ + Status: ServiceStatusACKed, + Version: testVersion, + }, + }, + { + name: "v3 with err filter", + resources: []*anypb.Any{v3LisWithFilters(errFilter)}, + wantUpdate: map[string]ListenerUpdateErrTuple{v3LDSTarget: {Err: cmpopts.AnyError}}, + wantMD: errMD, + wantErr: true, + }, + { + name: "v3 with optional err filter", + resources: []*anypb.Any{v3LisWithFilters(errOptionalFilter)}, + wantUpdate: map[string]ListenerUpdateErrTuple{v3LDSTarget: {Err: cmpopts.AnyError}}, + wantMD: errMD, + wantErr: true, + }, + { + name: "v3 with unknown filter", + resources: []*anypb.Any{v3LisWithFilters(unknownFilter)}, + wantUpdate: map[string]ListenerUpdateErrTuple{v3LDSTarget: {Err: cmpopts.AnyError}}, + wantMD: errMD, + wantErr: true, + }, + { + name: "v3 with unknown filter (optional)", + resources: []*anypb.Any{v3LisWithFilters(unknownOptionalFilter)}, + wantUpdate: map[string]ListenerUpdateErrTuple{ + v3LDSTarget: {Update: ListenerUpdate{ + RouteConfigName: v3RouteConfigName, + MaxStreamDuration: time.Second, + HTTPFilters: routerFilterList, + Raw: v3LisWithFilters(unknownOptionalFilter), + }}, + }, + wantMD: UpdateMetadata{ + Status: ServiceStatusACKed, + Version: testVersion, + }, + }, + { + name: "v2 listener resource", + resources: []*anypb.Any{v2Lis}, + wantUpdate: map[string]ListenerUpdateErrTuple{ + v2LDSTarget: {Update: ListenerUpdate{RouteConfigName: v2RouteConfigName, Raw: v2Lis}}, + }, + wantMD: UpdateMetadata{ + Status: ServiceStatusACKed, + Version: testVersion, + }, + }, + { + name: "v3 listener resource", + resources: []*anypb.Any{v3LisWithFilters()}, + wantUpdate: map[string]ListenerUpdateErrTuple{ + v3LDSTarget: {Update: ListenerUpdate{RouteConfigName: v3RouteConfigName, MaxStreamDuration: time.Second, HTTPFilters: routerFilterList, Raw: v3LisWithFilters()}}, + }, + wantMD: UpdateMetadata{ + Status: ServiceStatusACKed, + Version: testVersion, + }, + }, + // "To allow equating RBAC's direct_remote_ip and + // remote_ip...HttpConnectionManager.xff_num_trusted_hops must be unset + // or zero and HttpConnectionManager.original_ip_detection_extensions + // must be empty." - A41 + { + name: "rbac-allow-equating-direct-remote-ip-and-remote-ip-valid", + resources: []*anypb.Any{v3LisToTestRBAC(0, nil)}, + wantUpdate: map[string]ListenerUpdateErrTuple{ + v3LDSTarget: {Update: ListenerUpdate{ + RouteConfigName: v3RouteConfigName, + MaxStreamDuration: time.Second, + HTTPFilters: []HTTPFilter{routerFilter}, + Raw: v3LisToTestRBAC(0, nil), + }}, + }, + wantMD: UpdateMetadata{ + Status: ServiceStatusACKed, + Version: testVersion, + }, + }, + // In order to support xDS Configured RBAC HTTPFilter equating direct + // remote ip and remote ip, xffNumTrustedHops cannot be greater than + // zero. This is because if you can trust a ingress proxy hop when + // determining an origin clients ip address, direct remote ip != remote + // ip. + { + name: "rbac-allow-equating-direct-remote-ip-and-remote-ip-invalid-num-untrusted-hops", + resources: []*anypb.Any{v3LisToTestRBAC(1, nil)}, + wantUpdate: map[string]ListenerUpdateErrTuple{v3LDSTarget: {Err: cmpopts.AnyError}}, + wantMD: errMD, + wantErr: true, + }, + // In order to support xDS Configured RBAC HTTPFilter equating direct + // remote ip and remote ip, originalIpDetectionExtensions must be empty. + // This is because if you have to ask ip-detection-extension for the + // original ip, direct remote ip might not equal remote ip. + { + name: "rbac-allow-equating-direct-remote-ip-and-remote-ip-invalid-original-ip-detection-extension", + resources: []*anypb.Any{v3LisToTestRBAC(0, []*v3corepb.TypedExtensionConfig{{Name: "something"}})}, + wantUpdate: map[string]ListenerUpdateErrTuple{v3LDSTarget: {Err: cmpopts.AnyError}}, + wantMD: errMD, + wantErr: true, + }, + { + name: "v3 listener with inline route configuration", + resources: []*anypb.Any{v3LisWithInlineRoute}, + wantUpdate: map[string]ListenerUpdateErrTuple{ + v3LDSTarget: {Update: ListenerUpdate{ + InlineRouteConfig: &RouteConfigUpdate{ + VirtualHosts: []*VirtualHost{{ + Domains: []string{v3LDSTarget}, + Routes: []*Route{{Prefix: newStringP("/"), WeightedClusters: map[string]WeightedCluster{clusterName: {Weight: 1}}, RouteAction: RouteActionRoute}}, + }}}, + MaxStreamDuration: time.Second, + Raw: v3LisWithInlineRoute, + HTTPFilters: routerFilterList, + }}, + }, + wantMD: UpdateMetadata{ + Status: ServiceStatusACKed, + Version: testVersion, + }, + }, + { + name: "multiple listener resources", + resources: []*anypb.Any{v2Lis, v3LisWithFilters()}, + wantUpdate: map[string]ListenerUpdateErrTuple{ + v2LDSTarget: {Update: ListenerUpdate{RouteConfigName: v2RouteConfigName, Raw: v2Lis}}, + v3LDSTarget: {Update: ListenerUpdate{RouteConfigName: v3RouteConfigName, MaxStreamDuration: time.Second, Raw: v3LisWithFilters(), HTTPFilters: routerFilterList}}, + }, + wantMD: UpdateMetadata{ + Status: ServiceStatusACKed, + Version: testVersion, + }, + }, + { + // To test that unmarshal keeps processing on errors. + name: "good and bad listener resources", + resources: []*anypb.Any{ + v2Lis, + testutils.MarshalAny(&v3listenerpb.Listener{ + Name: "bad", + ApiListener: &v3listenerpb.ApiListener{ + ApiListener: testutils.MarshalAny(&v3httppb.HttpConnectionManager{ + RouteSpecifier: &v3httppb.HttpConnectionManager_ScopedRoutes{}, + }), + }}), + v3LisWithFilters(), + }, + wantUpdate: map[string]ListenerUpdateErrTuple{ + v2LDSTarget: {Update: ListenerUpdate{RouteConfigName: v2RouteConfigName, Raw: v2Lis}}, + v3LDSTarget: {Update: ListenerUpdate{RouteConfigName: v3RouteConfigName, MaxStreamDuration: time.Second, Raw: v3LisWithFilters(), HTTPFilters: routerFilterList}}, + "bad": {Err: cmpopts.AnyError}, + }, + wantMD: errMD, + wantErr: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + opts := &UnmarshalOptions{ + Version: testVersion, + Resources: test.resources, + } + update, md, err := UnmarshalListener(opts) + if (err != nil) != test.wantErr { + t.Fatalf("UnmarshalListener(%+v), got err: %v, wantErr: %v", opts, err, test.wantErr) + } + if diff := cmp.Diff(update, test.wantUpdate, cmpOpts); diff != "" { + t.Errorf("got unexpected update, diff (-got +want): %v", diff) + } + if diff := cmp.Diff(md, test.wantMD, cmpOptsIgnoreDetails); diff != "" { + t.Errorf("got unexpected metadata, diff (-got +want): %v", diff) + } + }) + } +} + +func (s) TestUnmarshalListener_ServerSide(t *testing.T) { + oldRBAC := env.RBACSupport + env.RBACSupport = true + defer func() { + env.RBACSupport = oldRBAC + }() + const ( + v3LDSTarget = "grpc/server?xds.resource.listening_address=0.0.0.0:9999" + testVersion = "test-version-lds-server" + ) + + var ( + serverOnlyCustomFilter = &v3httppb.HttpFilter{ + Name: "serverOnlyCustomFilter", + ConfigType: &v3httppb.HttpFilter_TypedConfig{TypedConfig: serverOnlyCustomFilterConfig}, + } + routeConfig = &v3routepb.RouteConfiguration{ + Name: "routeName", + VirtualHosts: []*v3routepb.VirtualHost{{ + Domains: []string{"lds.target.good:3333"}, + Routes: []*v3routepb.Route{{ + Match: &v3routepb.RouteMatch{ + PathSpecifier: &v3routepb.RouteMatch_Prefix{Prefix: "/"}, + }, + Action: &v3routepb.Route_NonForwardingAction{}, + }}}}} + inlineRouteConfig = &RouteConfigUpdate{ + VirtualHosts: []*VirtualHost{{ + Domains: []string{"lds.target.good:3333"}, + Routes: []*Route{{Prefix: newStringP("/"), RouteAction: RouteActionNonForwardingAction}}, + }}} + emptyValidNetworkFilters = []*v3listenerpb.Filter{ + { + Name: "filter-1", + ConfigType: &v3listenerpb.Filter_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3httppb.HttpConnectionManager{ + RouteSpecifier: &v3httppb.HttpConnectionManager_RouteConfig{ + RouteConfig: routeConfig, + }, + HttpFilters: []*v3httppb.HttpFilter{e2e.RouterHTTPFilter}, + }), + }, + }, + } + localSocketAddress = &v3corepb.Address{ + Address: &v3corepb.Address_SocketAddress{ + SocketAddress: &v3corepb.SocketAddress{ + Address: "0.0.0.0", + PortSpecifier: &v3corepb.SocketAddress_PortValue{ + PortValue: 9999, + }, + }, + }, + } + listenerEmptyTransportSocket = testutils.MarshalAny(&v3listenerpb.Listener{ + Name: v3LDSTarget, + Address: localSocketAddress, + FilterChains: []*v3listenerpb.FilterChain{ + { + Name: "filter-chain-1", + Filters: emptyValidNetworkFilters, + }, + }, + }) + listenerNoValidationContextDeprecatedFields = testutils.MarshalAny(&v3listenerpb.Listener{ + Name: v3LDSTarget, + Address: localSocketAddress, + FilterChains: []*v3listenerpb.FilterChain{ + { + Name: "filter-chain-1", + Filters: emptyValidNetworkFilters, + TransportSocket: &v3corepb.TransportSocket{ + Name: "envoy.transport_sockets.tls", + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3tlspb.DownstreamTlsContext{ + CommonTlsContext: &v3tlspb.CommonTlsContext{ + TlsCertificateCertificateProviderInstance: &v3tlspb.CommonTlsContext_CertificateProviderInstance{ + InstanceName: "identityPluginInstance", + CertificateName: "identityCertName", + }, + }, + }), + }, + }, + }, + }, + DefaultFilterChain: &v3listenerpb.FilterChain{ + Name: "default-filter-chain-1", + Filters: emptyValidNetworkFilters, + TransportSocket: &v3corepb.TransportSocket{ + Name: "envoy.transport_sockets.tls", + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3tlspb.DownstreamTlsContext{ + CommonTlsContext: &v3tlspb.CommonTlsContext{ + TlsCertificateCertificateProviderInstance: &v3tlspb.CommonTlsContext_CertificateProviderInstance{ + InstanceName: "defaultIdentityPluginInstance", + CertificateName: "defaultIdentityCertName", + }, + }, + }), + }, + }, + }, + }) + listenerNoValidationContextNewFields = testutils.MarshalAny(&v3listenerpb.Listener{ + Name: v3LDSTarget, + Address: localSocketAddress, + FilterChains: []*v3listenerpb.FilterChain{ + { + Name: "filter-chain-1", + Filters: emptyValidNetworkFilters, + TransportSocket: &v3corepb.TransportSocket{ + Name: "envoy.transport_sockets.tls", + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3tlspb.DownstreamTlsContext{ + CommonTlsContext: &v3tlspb.CommonTlsContext{ + TlsCertificateProviderInstance: &v3tlspb.CertificateProviderPluginInstance{ + InstanceName: "identityPluginInstance", + CertificateName: "identityCertName", + }, + }, + }), + }, + }, + }, + }, + DefaultFilterChain: &v3listenerpb.FilterChain{ + Name: "default-filter-chain-1", + Filters: emptyValidNetworkFilters, + TransportSocket: &v3corepb.TransportSocket{ + Name: "envoy.transport_sockets.tls", + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3tlspb.DownstreamTlsContext{ + CommonTlsContext: &v3tlspb.CommonTlsContext{ + TlsCertificateProviderInstance: &v3tlspb.CertificateProviderPluginInstance{ + InstanceName: "defaultIdentityPluginInstance", + CertificateName: "defaultIdentityCertName", + }, + }, + }), + }, + }, + }, + }) + listenerWithValidationContextDeprecatedFields = testutils.MarshalAny(&v3listenerpb.Listener{ + Name: v3LDSTarget, + Address: localSocketAddress, + FilterChains: []*v3listenerpb.FilterChain{ + { + Name: "filter-chain-1", + Filters: emptyValidNetworkFilters, + TransportSocket: &v3corepb.TransportSocket{ + Name: "envoy.transport_sockets.tls", + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3tlspb.DownstreamTlsContext{ + RequireClientCertificate: &wrapperspb.BoolValue{Value: true}, + CommonTlsContext: &v3tlspb.CommonTlsContext{ + TlsCertificateCertificateProviderInstance: &v3tlspb.CommonTlsContext_CertificateProviderInstance{ + InstanceName: "identityPluginInstance", + CertificateName: "identityCertName", + }, + ValidationContextType: &v3tlspb.CommonTlsContext_ValidationContextCertificateProviderInstance{ + ValidationContextCertificateProviderInstance: &v3tlspb.CommonTlsContext_CertificateProviderInstance{ + InstanceName: "rootPluginInstance", + CertificateName: "rootCertName", + }, + }, + }, + }), + }, + }, + }, + }, + DefaultFilterChain: &v3listenerpb.FilterChain{ + Name: "default-filter-chain-1", + Filters: emptyValidNetworkFilters, + TransportSocket: &v3corepb.TransportSocket{ + Name: "envoy.transport_sockets.tls", + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3tlspb.DownstreamTlsContext{ + RequireClientCertificate: &wrapperspb.BoolValue{Value: true}, + CommonTlsContext: &v3tlspb.CommonTlsContext{ + TlsCertificateCertificateProviderInstance: &v3tlspb.CommonTlsContext_CertificateProviderInstance{ + InstanceName: "defaultIdentityPluginInstance", + CertificateName: "defaultIdentityCertName", + }, + ValidationContextType: &v3tlspb.CommonTlsContext_ValidationContextCertificateProviderInstance{ + ValidationContextCertificateProviderInstance: &v3tlspb.CommonTlsContext_CertificateProviderInstance{ + InstanceName: "defaultRootPluginInstance", + CertificateName: "defaultRootCertName", + }, + }, + }, + }), + }, + }, + }, + }) + listenerWithValidationContextNewFields = testutils.MarshalAny(&v3listenerpb.Listener{ + Name: v3LDSTarget, + Address: localSocketAddress, + FilterChains: []*v3listenerpb.FilterChain{ + { + Name: "filter-chain-1", + Filters: emptyValidNetworkFilters, + TransportSocket: &v3corepb.TransportSocket{ + Name: "envoy.transport_sockets.tls", + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3tlspb.DownstreamTlsContext{ + RequireClientCertificate: &wrapperspb.BoolValue{Value: true}, + CommonTlsContext: &v3tlspb.CommonTlsContext{ + TlsCertificateProviderInstance: &v3tlspb.CertificateProviderPluginInstance{ + InstanceName: "identityPluginInstance", + CertificateName: "identityCertName", + }, + ValidationContextType: &v3tlspb.CommonTlsContext_ValidationContext{ + ValidationContext: &v3tlspb.CertificateValidationContext{ + CaCertificateProviderInstance: &v3tlspb.CertificateProviderPluginInstance{ + InstanceName: "rootPluginInstance", + CertificateName: "rootCertName", + }, + }, + }, + }, + }), + }, + }, + }, + }, + DefaultFilterChain: &v3listenerpb.FilterChain{ + Name: "default-filter-chain-1", + Filters: emptyValidNetworkFilters, + TransportSocket: &v3corepb.TransportSocket{ + Name: "envoy.transport_sockets.tls", + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3tlspb.DownstreamTlsContext{ + RequireClientCertificate: &wrapperspb.BoolValue{Value: true}, + CommonTlsContext: &v3tlspb.CommonTlsContext{ + TlsCertificateProviderInstance: &v3tlspb.CertificateProviderPluginInstance{ + InstanceName: "defaultIdentityPluginInstance", + CertificateName: "defaultIdentityCertName", + }, + ValidationContextType: &v3tlspb.CommonTlsContext_CombinedValidationContext{ + CombinedValidationContext: &v3tlspb.CommonTlsContext_CombinedCertificateValidationContext{ + DefaultValidationContext: &v3tlspb.CertificateValidationContext{ + CaCertificateProviderInstance: &v3tlspb.CertificateProviderPluginInstance{ + InstanceName: "defaultRootPluginInstance", + CertificateName: "defaultRootCertName", + }, + }, + }, + }, + }, + }), + }, + }, + }, + }) + errMD = UpdateMetadata{ + Status: ServiceStatusNACKed, + Version: testVersion, + ErrState: &UpdateErrorMetadata{ + Version: testVersion, + Err: cmpopts.AnyError, + }, + } + ) + v3LisToTestRBAC := func(xffNumTrustedHops uint32, originalIpDetectionExtensions []*v3corepb.TypedExtensionConfig) *anypb.Any { + return testutils.MarshalAny(&v3listenerpb.Listener{ + Name: v3LDSTarget, + Address: localSocketAddress, + FilterChains: []*v3listenerpb.FilterChain{ + { + Name: "filter-chain-1", + Filters: []*v3listenerpb.Filter{ + { + Name: "filter-1", + ConfigType: &v3listenerpb.Filter_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3httppb.HttpConnectionManager{ + RouteSpecifier: &v3httppb.HttpConnectionManager_RouteConfig{ + RouteConfig: routeConfig, + }, + HttpFilters: []*v3httppb.HttpFilter{e2e.RouterHTTPFilter}, + XffNumTrustedHops: xffNumTrustedHops, + OriginalIpDetectionExtensions: originalIpDetectionExtensions, + }), + }, + }, + }, + }, + }, + }) + } + + tests := []struct { + name string + resources []*anypb.Any + wantUpdate map[string]ListenerUpdateErrTuple + wantMD UpdateMetadata + wantErr string + }{ + { + name: "non-empty listener filters", + resources: []*anypb.Any{testutils.MarshalAny(&v3listenerpb.Listener{ + Name: v3LDSTarget, + ListenerFilters: []*v3listenerpb.ListenerFilter{ + {Name: "listener-filter-1"}, + }, + })}, + wantUpdate: map[string]ListenerUpdateErrTuple{v3LDSTarget: {Err: cmpopts.AnyError}}, + wantMD: errMD, + wantErr: "unsupported field 'listener_filters'", + }, + { + name: "use_original_dst is set", + resources: []*anypb.Any{testutils.MarshalAny(&v3listenerpb.Listener{ + Name: v3LDSTarget, + UseOriginalDst: &wrapperspb.BoolValue{Value: true}, + })}, + wantUpdate: map[string]ListenerUpdateErrTuple{v3LDSTarget: {Err: cmpopts.AnyError}}, + wantMD: errMD, + wantErr: "unsupported field 'use_original_dst'", + }, + { + name: "no address field", + resources: []*anypb.Any{testutils.MarshalAny(&v3listenerpb.Listener{Name: v3LDSTarget})}, + wantUpdate: map[string]ListenerUpdateErrTuple{v3LDSTarget: {Err: cmpopts.AnyError}}, + wantMD: errMD, + wantErr: "no address field in LDS response", + }, + { + name: "no socket address field", + resources: []*anypb.Any{testutils.MarshalAny(&v3listenerpb.Listener{ + Name: v3LDSTarget, + Address: &v3corepb.Address{}, + })}, + wantUpdate: map[string]ListenerUpdateErrTuple{v3LDSTarget: {Err: cmpopts.AnyError}}, + wantMD: errMD, + wantErr: "no socket_address field in LDS response", + }, + { + name: "no filter chains and no default filter chain", + resources: []*anypb.Any{testutils.MarshalAny(&v3listenerpb.Listener{ + Name: v3LDSTarget, + Address: localSocketAddress, + FilterChains: []*v3listenerpb.FilterChain{ + { + FilterChainMatch: &v3listenerpb.FilterChainMatch{DestinationPort: &wrapperspb.UInt32Value{Value: 666}}, + Filters: emptyValidNetworkFilters, + }, + }, + })}, + wantUpdate: map[string]ListenerUpdateErrTuple{v3LDSTarget: {Err: cmpopts.AnyError}}, + wantMD: errMD, + wantErr: "no supported filter chains and no default filter chain", + }, + { + name: "missing http connection manager network filter", + resources: []*anypb.Any{testutils.MarshalAny(&v3listenerpb.Listener{ + Name: v3LDSTarget, + Address: localSocketAddress, + FilterChains: []*v3listenerpb.FilterChain{ + { + Name: "filter-chain-1", + }, + }, + })}, + wantUpdate: map[string]ListenerUpdateErrTuple{v3LDSTarget: {Err: cmpopts.AnyError}}, + wantMD: errMD, + wantErr: "missing HttpConnectionManager filter", + }, + { + name: "missing filter name in http filter", + resources: []*anypb.Any{testutils.MarshalAny(&v3listenerpb.Listener{ + Name: v3LDSTarget, + Address: localSocketAddress, + FilterChains: []*v3listenerpb.FilterChain{ + { + Name: "filter-chain-1", + Filters: []*v3listenerpb.Filter{ + { + ConfigType: &v3listenerpb.Filter_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3httppb.HttpConnectionManager{}), + }, + }, + }, + }, + }, + })}, + wantUpdate: map[string]ListenerUpdateErrTuple{v3LDSTarget: {Err: cmpopts.AnyError}}, + wantMD: errMD, + wantErr: "missing name field in filter", + }, + { + name: "duplicate filter names in http filter", + resources: []*anypb.Any{testutils.MarshalAny(&v3listenerpb.Listener{ + Name: v3LDSTarget, + Address: localSocketAddress, + FilterChains: []*v3listenerpb.FilterChain{ + { + Name: "filter-chain-1", + Filters: []*v3listenerpb.Filter{ + { + Name: "name", + ConfigType: &v3listenerpb.Filter_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3httppb.HttpConnectionManager{ + RouteSpecifier: &v3httppb.HttpConnectionManager_RouteConfig{ + RouteConfig: routeConfig, + }, + HttpFilters: []*v3httppb.HttpFilter{emptyRouterFilter}, + }), + }, + }, + { + Name: "name", + ConfigType: &v3listenerpb.Filter_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3httppb.HttpConnectionManager{ + RouteSpecifier: &v3httppb.HttpConnectionManager_RouteConfig{ + RouteConfig: routeConfig, + }, + HttpFilters: []*v3httppb.HttpFilter{emptyRouterFilter}, + }), + }, + }, + }, + }, + }, + })}, + wantUpdate: map[string]ListenerUpdateErrTuple{v3LDSTarget: {Err: cmpopts.AnyError}}, + wantMD: errMD, + wantErr: "duplicate filter name", + }, + { + name: "no terminal filter", + resources: []*anypb.Any{testutils.MarshalAny(&v3listenerpb.Listener{ + Name: v3LDSTarget, + Address: localSocketAddress, + FilterChains: []*v3listenerpb.FilterChain{ + { + Name: "filter-chain-1", + Filters: []*v3listenerpb.Filter{ + { + Name: "name", + ConfigType: &v3listenerpb.Filter_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3httppb.HttpConnectionManager{ + RouteSpecifier: &v3httppb.HttpConnectionManager_RouteConfig{ + RouteConfig: routeConfig, + }, + }), + }, + }, + }, + }, + }, + })}, + wantUpdate: map[string]ListenerUpdateErrTuple{v3LDSTarget: {Err: cmpopts.AnyError}}, + wantMD: errMD, + wantErr: "http filters list is empty", + }, + { + name: "terminal filter not last", + resources: []*anypb.Any{testutils.MarshalAny(&v3listenerpb.Listener{ + Name: v3LDSTarget, + Address: localSocketAddress, + FilterChains: []*v3listenerpb.FilterChain{ + { + Name: "filter-chain-1", + Filters: []*v3listenerpb.Filter{ + { + Name: "name", + ConfigType: &v3listenerpb.Filter_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3httppb.HttpConnectionManager{ + RouteSpecifier: &v3httppb.HttpConnectionManager_RouteConfig{ + RouteConfig: routeConfig, + }, + HttpFilters: []*v3httppb.HttpFilter{emptyRouterFilter, serverOnlyCustomFilter}, + }), + }, + }, + }, + }, + }, + })}, + wantUpdate: map[string]ListenerUpdateErrTuple{v3LDSTarget: {Err: cmpopts.AnyError}}, + wantMD: errMD, + wantErr: "is a terminal filter but it is not last in the filter chain", + }, + { + name: "last not terminal filter", + resources: []*anypb.Any{testutils.MarshalAny(&v3listenerpb.Listener{ + Name: v3LDSTarget, + Address: localSocketAddress, + FilterChains: []*v3listenerpb.FilterChain{ + { + Name: "filter-chain-1", + Filters: []*v3listenerpb.Filter{ + { + Name: "name", + ConfigType: &v3listenerpb.Filter_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3httppb.HttpConnectionManager{ + RouteSpecifier: &v3httppb.HttpConnectionManager_RouteConfig{ + RouteConfig: routeConfig, + }, + HttpFilters: []*v3httppb.HttpFilter{serverOnlyCustomFilter}, + }), + }, + }, + }, + }, + }, + })}, + wantUpdate: map[string]ListenerUpdateErrTuple{v3LDSTarget: {Err: cmpopts.AnyError}}, + wantMD: errMD, + wantErr: "is not a terminal filter", + }, + { + name: "unsupported oneof in typed config of http filter", + resources: []*anypb.Any{testutils.MarshalAny(&v3listenerpb.Listener{ + Name: v3LDSTarget, + Address: localSocketAddress, + FilterChains: []*v3listenerpb.FilterChain{ + { + Name: "filter-chain-1", + Filters: []*v3listenerpb.Filter{ + { + Name: "name", + ConfigType: &v3listenerpb.Filter_ConfigDiscovery{}, + }, + }, + }, + }, + })}, + wantUpdate: map[string]ListenerUpdateErrTuple{v3LDSTarget: {Err: cmpopts.AnyError}}, + wantMD: errMD, + wantErr: "unsupported config_type", + }, + { + name: "overlapping filter chain match criteria", + resources: []*anypb.Any{testutils.MarshalAny(&v3listenerpb.Listener{ + Name: v3LDSTarget, + Address: localSocketAddress, + FilterChains: []*v3listenerpb.FilterChain{ + { + FilterChainMatch: &v3listenerpb.FilterChainMatch{SourcePorts: []uint32{1, 2, 3, 4, 5}}, + Filters: emptyValidNetworkFilters, + }, + { + FilterChainMatch: &v3listenerpb.FilterChainMatch{}, + Filters: emptyValidNetworkFilters, + }, + { + FilterChainMatch: &v3listenerpb.FilterChainMatch{SourcePorts: []uint32{5, 6, 7}}, + Filters: emptyValidNetworkFilters, + }, + }, + })}, + wantUpdate: map[string]ListenerUpdateErrTuple{v3LDSTarget: {Err: cmpopts.AnyError}}, + wantMD: errMD, + wantErr: "multiple filter chains with overlapping matching rules are defined", + }, + { + name: "unsupported network filter", + resources: []*anypb.Any{testutils.MarshalAny(&v3listenerpb.Listener{ + Name: v3LDSTarget, + Address: localSocketAddress, + FilterChains: []*v3listenerpb.FilterChain{ + { + Name: "filter-chain-1", + Filters: []*v3listenerpb.Filter{ + { + Name: "name", + ConfigType: &v3listenerpb.Filter_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3httppb.LocalReplyConfig{}), + }, + }, + }, + }, + }, + })}, + wantUpdate: map[string]ListenerUpdateErrTuple{v3LDSTarget: {Err: cmpopts.AnyError}}, + wantMD: errMD, + wantErr: "unsupported network filter", + }, + { + name: "badly marshaled network filter", + resources: []*anypb.Any{testutils.MarshalAny(&v3listenerpb.Listener{ + Name: v3LDSTarget, + Address: localSocketAddress, + FilterChains: []*v3listenerpb.FilterChain{ + { + Name: "filter-chain-1", + Filters: []*v3listenerpb.Filter{ + { + Name: "name", + ConfigType: &v3listenerpb.Filter_TypedConfig{ + TypedConfig: &anypb.Any{ + TypeUrl: version.V3HTTPConnManagerURL, + Value: []byte{1, 2, 3, 4}, + }, + }, + }, + }, + }, + }, + })}, + wantUpdate: map[string]ListenerUpdateErrTuple{v3LDSTarget: {Err: cmpopts.AnyError}}, + wantMD: errMD, + wantErr: "failed unmarshaling of network filter", + }, + { + name: "unexpected transport socket name", + resources: []*anypb.Any{testutils.MarshalAny(&v3listenerpb.Listener{ + Name: v3LDSTarget, + Address: localSocketAddress, + FilterChains: []*v3listenerpb.FilterChain{ + { + Name: "filter-chain-1", + Filters: emptyValidNetworkFilters, + TransportSocket: &v3corepb.TransportSocket{ + Name: "unsupported-transport-socket-name", + }, + }, + }, + })}, + wantUpdate: map[string]ListenerUpdateErrTuple{v3LDSTarget: {Err: cmpopts.AnyError}}, + wantMD: errMD, + wantErr: "transport_socket field has unexpected name", + }, + { + name: "unexpected transport socket typedConfig URL", + resources: []*anypb.Any{testutils.MarshalAny(&v3listenerpb.Listener{ + Name: v3LDSTarget, + Address: localSocketAddress, + FilterChains: []*v3listenerpb.FilterChain{ + { + Name: "filter-chain-1", + Filters: emptyValidNetworkFilters, + TransportSocket: &v3corepb.TransportSocket{ + Name: "envoy.transport_sockets.tls", + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3tlspb.UpstreamTlsContext{}), + }, + }, + }, + }, + })}, + wantUpdate: map[string]ListenerUpdateErrTuple{v3LDSTarget: {Err: cmpopts.AnyError}}, + wantMD: errMD, + wantErr: "transport_socket field has unexpected typeURL", + }, + { + name: "badly marshaled transport socket", + resources: []*anypb.Any{testutils.MarshalAny(&v3listenerpb.Listener{ + Name: v3LDSTarget, + Address: localSocketAddress, + FilterChains: []*v3listenerpb.FilterChain{ + { + Name: "filter-chain-1", + Filters: emptyValidNetworkFilters, + TransportSocket: &v3corepb.TransportSocket{ + Name: "envoy.transport_sockets.tls", + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: &anypb.Any{ + TypeUrl: version.V3DownstreamTLSContextURL, + Value: []byte{1, 2, 3, 4}, + }, + }, + }, + }, + }, + })}, + wantUpdate: map[string]ListenerUpdateErrTuple{v3LDSTarget: {Err: cmpopts.AnyError}}, + wantMD: errMD, + wantErr: "failed to unmarshal DownstreamTlsContext in LDS response", + }, + { + name: "missing CommonTlsContext", + resources: []*anypb.Any{testutils.MarshalAny(&v3listenerpb.Listener{ + Name: v3LDSTarget, + Address: localSocketAddress, + FilterChains: []*v3listenerpb.FilterChain{ + { + Name: "filter-chain-1", + Filters: emptyValidNetworkFilters, + TransportSocket: &v3corepb.TransportSocket{ + Name: "envoy.transport_sockets.tls", + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3tlspb.DownstreamTlsContext{}), + }, + }, + }, + }, + })}, + wantUpdate: map[string]ListenerUpdateErrTuple{v3LDSTarget: {Err: cmpopts.AnyError}}, + wantMD: errMD, + wantErr: "DownstreamTlsContext in LDS response does not contain a CommonTlsContext", + }, + { + name: "rbac-allow-equating-direct-remote-ip-and-remote-ip-valid", + resources: []*anypb.Any{v3LisToTestRBAC(0, nil)}, + wantUpdate: map[string]ListenerUpdateErrTuple{ + v3LDSTarget: {Update: ListenerUpdate{ + InboundListenerCfg: &InboundListenerConfig{ + Address: "0.0.0.0", + Port: "9999", + FilterChains: &FilterChainManager{ + dstPrefixMap: map[string]*destPrefixEntry{ + unspecifiedPrefixMapKey: { + srcTypeArr: [3]*sourcePrefixes{ + { + srcPrefixMap: map[string]*sourcePrefixEntry{ + unspecifiedPrefixMapKey: { + srcPortMap: map[int]*FilterChain{ + 0: { + InlineRouteConfig: inlineRouteConfig, + HTTPFilters: routerFilterList, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + Raw: listenerEmptyTransportSocket, + }}, + }, + wantMD: UpdateMetadata{ + Status: ServiceStatusACKed, + Version: testVersion, + }, + }, + { + name: "rbac-allow-equating-direct-remote-ip-and-remote-ip-invalid-num-untrusted-hops", + resources: []*anypb.Any{v3LisToTestRBAC(1, nil)}, + wantUpdate: map[string]ListenerUpdateErrTuple{v3LDSTarget: {Err: cmpopts.AnyError}}, + wantMD: errMD, + wantErr: "xff_num_trusted_hops must be unset or zero", + }, + { + name: "rbac-allow-equating-direct-remote-ip-and-remote-ip-invalid-original-ip-detection-extension", + resources: []*anypb.Any{v3LisToTestRBAC(0, []*v3corepb.TypedExtensionConfig{{Name: "something"}})}, + wantUpdate: map[string]ListenerUpdateErrTuple{v3LDSTarget: {Err: cmpopts.AnyError}}, + wantMD: errMD, + wantErr: "original_ip_detection_extensions must be empty", + }, + { + name: "unsupported validation context in transport socket", + resources: []*anypb.Any{testutils.MarshalAny(&v3listenerpb.Listener{ + Name: v3LDSTarget, + Address: localSocketAddress, + FilterChains: []*v3listenerpb.FilterChain{ + { + Name: "filter-chain-1", + Filters: emptyValidNetworkFilters, + TransportSocket: &v3corepb.TransportSocket{ + Name: "envoy.transport_sockets.tls", + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3tlspb.DownstreamTlsContext{ + CommonTlsContext: &v3tlspb.CommonTlsContext{ + ValidationContextType: &v3tlspb.CommonTlsContext_ValidationContextSdsSecretConfig{ + ValidationContextSdsSecretConfig: &v3tlspb.SdsSecretConfig{ + Name: "foo-sds-secret", + }, + }, + }, + }), + }, + }, + }, + }, + })}, + wantUpdate: map[string]ListenerUpdateErrTuple{v3LDSTarget: {Err: cmpopts.AnyError}}, + wantMD: errMD, + wantErr: "validation context contains unexpected type", + }, + { + name: "empty transport socket", + resources: []*anypb.Any{listenerEmptyTransportSocket}, + wantUpdate: map[string]ListenerUpdateErrTuple{ + v3LDSTarget: {Update: ListenerUpdate{ + InboundListenerCfg: &InboundListenerConfig{ + Address: "0.0.0.0", + Port: "9999", + FilterChains: &FilterChainManager{ + dstPrefixMap: map[string]*destPrefixEntry{ + unspecifiedPrefixMapKey: { + srcTypeArr: [3]*sourcePrefixes{ + { + srcPrefixMap: map[string]*sourcePrefixEntry{ + unspecifiedPrefixMapKey: { + srcPortMap: map[int]*FilterChain{ + 0: { + InlineRouteConfig: inlineRouteConfig, + HTTPFilters: routerFilterList, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + Raw: listenerEmptyTransportSocket, + }}, + }, + wantMD: UpdateMetadata{ + Status: ServiceStatusACKed, + Version: testVersion, + }, + }, + { + name: "no identity and root certificate providers using deprecated fields", + resources: []*anypb.Any{testutils.MarshalAny(&v3listenerpb.Listener{ + Name: v3LDSTarget, + Address: localSocketAddress, + FilterChains: []*v3listenerpb.FilterChain{ + { + Name: "filter-chain-1", + Filters: emptyValidNetworkFilters, + TransportSocket: &v3corepb.TransportSocket{ + Name: "envoy.transport_sockets.tls", + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3tlspb.DownstreamTlsContext{ + RequireClientCertificate: &wrapperspb.BoolValue{Value: true}, + CommonTlsContext: &v3tlspb.CommonTlsContext{ + TlsCertificateCertificateProviderInstance: &v3tlspb.CommonTlsContext_CertificateProviderInstance{ + InstanceName: "identityPluginInstance", + CertificateName: "identityCertName", + }, + }, + }), + }, + }, + }, + }, + })}, + wantUpdate: map[string]ListenerUpdateErrTuple{v3LDSTarget: {Err: cmpopts.AnyError}}, + wantMD: errMD, + wantErr: "security configuration on the server-side does not contain root certificate provider instance name, but require_client_cert field is set", + }, + { + name: "no identity and root certificate providers using new fields", + resources: []*anypb.Any{testutils.MarshalAny(&v3listenerpb.Listener{ + Name: v3LDSTarget, + Address: localSocketAddress, + FilterChains: []*v3listenerpb.FilterChain{ + { + Name: "filter-chain-1", + Filters: emptyValidNetworkFilters, + TransportSocket: &v3corepb.TransportSocket{ + Name: "envoy.transport_sockets.tls", + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3tlspb.DownstreamTlsContext{ + RequireClientCertificate: &wrapperspb.BoolValue{Value: true}, + CommonTlsContext: &v3tlspb.CommonTlsContext{ + TlsCertificateProviderInstance: &v3tlspb.CertificateProviderPluginInstance{ + InstanceName: "identityPluginInstance", + CertificateName: "identityCertName", + }, + }, + }), + }, + }, + }, + }, + })}, + wantUpdate: map[string]ListenerUpdateErrTuple{v3LDSTarget: {Err: cmpopts.AnyError}}, + wantMD: errMD, + wantErr: "security configuration on the server-side does not contain root certificate provider instance name, but require_client_cert field is set", + }, + { + name: "no identity certificate provider with require_client_cert", + resources: []*anypb.Any{testutils.MarshalAny(&v3listenerpb.Listener{ + Name: v3LDSTarget, + Address: localSocketAddress, + FilterChains: []*v3listenerpb.FilterChain{ + { + Name: "filter-chain-1", + Filters: emptyValidNetworkFilters, + TransportSocket: &v3corepb.TransportSocket{ + Name: "envoy.transport_sockets.tls", + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3tlspb.DownstreamTlsContext{ + CommonTlsContext: &v3tlspb.CommonTlsContext{}, + }), + }, + }, + }, + }, + })}, + wantUpdate: map[string]ListenerUpdateErrTuple{v3LDSTarget: {Err: cmpopts.AnyError}}, + wantMD: errMD, + wantErr: "security configuration on the server-side does not contain identity certificate provider instance name", + }, + { + name: "happy case with no validation context using deprecated fields", + resources: []*anypb.Any{listenerNoValidationContextDeprecatedFields}, + wantUpdate: map[string]ListenerUpdateErrTuple{ + v3LDSTarget: {Update: ListenerUpdate{ + InboundListenerCfg: &InboundListenerConfig{ + Address: "0.0.0.0", + Port: "9999", + FilterChains: &FilterChainManager{ + dstPrefixMap: map[string]*destPrefixEntry{ + unspecifiedPrefixMapKey: { + srcTypeArr: [3]*sourcePrefixes{ + { + srcPrefixMap: map[string]*sourcePrefixEntry{ + unspecifiedPrefixMapKey: { + srcPortMap: map[int]*FilterChain{ + 0: { + SecurityCfg: &SecurityConfig{ + IdentityInstanceName: "identityPluginInstance", + IdentityCertName: "identityCertName", + }, + InlineRouteConfig: inlineRouteConfig, + HTTPFilters: routerFilterList, + }, + }, + }, + }, + }, + }, + }, + }, + def: &FilterChain{ + SecurityCfg: &SecurityConfig{ + IdentityInstanceName: "defaultIdentityPluginInstance", + IdentityCertName: "defaultIdentityCertName", + }, + InlineRouteConfig: inlineRouteConfig, + HTTPFilters: routerFilterList, + }, + }, + }, + Raw: listenerNoValidationContextDeprecatedFields, + }}, + }, + wantMD: UpdateMetadata{ + Status: ServiceStatusACKed, + Version: testVersion, + }, + }, + { + name: "happy case with no validation context using new fields", + resources: []*anypb.Any{listenerNoValidationContextNewFields}, + wantUpdate: map[string]ListenerUpdateErrTuple{ + v3LDSTarget: {Update: ListenerUpdate{ + InboundListenerCfg: &InboundListenerConfig{ + Address: "0.0.0.0", + Port: "9999", + FilterChains: &FilterChainManager{ + dstPrefixMap: map[string]*destPrefixEntry{ + unspecifiedPrefixMapKey: { + srcTypeArr: [3]*sourcePrefixes{ + { + srcPrefixMap: map[string]*sourcePrefixEntry{ + unspecifiedPrefixMapKey: { + srcPortMap: map[int]*FilterChain{ + 0: { + SecurityCfg: &SecurityConfig{ + IdentityInstanceName: "identityPluginInstance", + IdentityCertName: "identityCertName", + }, + InlineRouteConfig: inlineRouteConfig, + HTTPFilters: routerFilterList, + }, + }, + }, + }, + }, + }, + }, + }, + def: &FilterChain{ + SecurityCfg: &SecurityConfig{ + IdentityInstanceName: "defaultIdentityPluginInstance", + IdentityCertName: "defaultIdentityCertName", + }, + InlineRouteConfig: inlineRouteConfig, + HTTPFilters: routerFilterList, + }, + }, + }, + Raw: listenerNoValidationContextNewFields, + }}, + }, + wantMD: UpdateMetadata{ + Status: ServiceStatusACKed, + Version: testVersion, + }, + }, + { + name: "happy case with validation context provider instance with deprecated fields", + resources: []*anypb.Any{listenerWithValidationContextDeprecatedFields}, + wantUpdate: map[string]ListenerUpdateErrTuple{ + v3LDSTarget: {Update: ListenerUpdate{ + InboundListenerCfg: &InboundListenerConfig{ + Address: "0.0.0.0", + Port: "9999", + FilterChains: &FilterChainManager{ + dstPrefixMap: map[string]*destPrefixEntry{ + unspecifiedPrefixMapKey: { + srcTypeArr: [3]*sourcePrefixes{ + { + srcPrefixMap: map[string]*sourcePrefixEntry{ + unspecifiedPrefixMapKey: { + srcPortMap: map[int]*FilterChain{ + 0: { + SecurityCfg: &SecurityConfig{ + RootInstanceName: "rootPluginInstance", + RootCertName: "rootCertName", + IdentityInstanceName: "identityPluginInstance", + IdentityCertName: "identityCertName", + RequireClientCert: true, + }, + InlineRouteConfig: inlineRouteConfig, + HTTPFilters: routerFilterList, + }, + }, + }, + }, + }, + }, + }, + }, + def: &FilterChain{ + SecurityCfg: &SecurityConfig{ + RootInstanceName: "defaultRootPluginInstance", + RootCertName: "defaultRootCertName", + IdentityInstanceName: "defaultIdentityPluginInstance", + IdentityCertName: "defaultIdentityCertName", + RequireClientCert: true, + }, + InlineRouteConfig: inlineRouteConfig, + HTTPFilters: routerFilterList, + }, + }, + }, + Raw: listenerWithValidationContextDeprecatedFields, + }}, + }, + wantMD: UpdateMetadata{ + Status: ServiceStatusACKed, + Version: testVersion, + }, + }, + { + name: "happy case with validation context provider instance with new fields", + resources: []*anypb.Any{listenerWithValidationContextNewFields}, + wantUpdate: map[string]ListenerUpdateErrTuple{ + v3LDSTarget: {Update: ListenerUpdate{ + InboundListenerCfg: &InboundListenerConfig{ + Address: "0.0.0.0", + Port: "9999", + FilterChains: &FilterChainManager{ + dstPrefixMap: map[string]*destPrefixEntry{ + unspecifiedPrefixMapKey: { + srcTypeArr: [3]*sourcePrefixes{ + { + srcPrefixMap: map[string]*sourcePrefixEntry{ + unspecifiedPrefixMapKey: { + srcPortMap: map[int]*FilterChain{ + 0: { + SecurityCfg: &SecurityConfig{ + RootInstanceName: "rootPluginInstance", + RootCertName: "rootCertName", + IdentityInstanceName: "identityPluginInstance", + IdentityCertName: "identityCertName", + RequireClientCert: true, + }, + InlineRouteConfig: inlineRouteConfig, + HTTPFilters: routerFilterList, + }, + }, + }, + }, + }, + }, + }, + }, + def: &FilterChain{ + SecurityCfg: &SecurityConfig{ + RootInstanceName: "defaultRootPluginInstance", + RootCertName: "defaultRootCertName", + IdentityInstanceName: "defaultIdentityPluginInstance", + IdentityCertName: "defaultIdentityCertName", + RequireClientCert: true, + }, + InlineRouteConfig: inlineRouteConfig, + HTTPFilters: routerFilterList, + }, + }, + }, + Raw: listenerWithValidationContextNewFields, + }}, + }, + wantMD: UpdateMetadata{ + Status: ServiceStatusACKed, + Version: testVersion, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + opts := &UnmarshalOptions{ + Version: testVersion, + Resources: test.resources, + } + gotUpdate, md, err := UnmarshalListener(opts) + if (err != nil) != (test.wantErr != "") { + t.Fatalf("UnmarshalListener(%+v), got err: %v, wantErr: %v", opts, err, test.wantErr) + } + if err != nil && !strings.Contains(err.Error(), test.wantErr) { + t.Fatalf("UnmarshalListener(%+v) = %v wantErr: %q", opts, err, test.wantErr) + } + if diff := cmp.Diff(gotUpdate, test.wantUpdate, cmpOpts); diff != "" { + t.Errorf("got unexpected update, diff (-got +want): %v", diff) + } + if diff := cmp.Diff(md, test.wantMD, cmpOptsIgnoreDetails); diff != "" { + t.Errorf("got unexpected metadata, diff (-got +want): %v", diff) + } + }) + } +} + +type filterConfig struct { + httpfilter.FilterConfig + Cfg proto.Message + Override proto.Message +} + +// httpFilter allows testing the http filter registry and parsing functionality. +type httpFilter struct { + httpfilter.ClientInterceptorBuilder + httpfilter.ServerInterceptorBuilder +} + +func (httpFilter) TypeURLs() []string { return []string{"custom.filter"} } + +func (httpFilter) ParseFilterConfig(cfg proto.Message) (httpfilter.FilterConfig, error) { + return filterConfig{Cfg: cfg}, nil +} + +func (httpFilter) ParseFilterConfigOverride(override proto.Message) (httpfilter.FilterConfig, error) { + return filterConfig{Override: override}, nil +} + +func (httpFilter) IsTerminal() bool { + return false +} + +// errHTTPFilter returns errors no matter what is passed to ParseFilterConfig. +type errHTTPFilter struct { + httpfilter.ClientInterceptorBuilder +} + +func (errHTTPFilter) TypeURLs() []string { return []string{"err.custom.filter"} } + +func (errHTTPFilter) ParseFilterConfig(cfg proto.Message) (httpfilter.FilterConfig, error) { + return nil, fmt.Errorf("error from ParseFilterConfig") +} + +func (errHTTPFilter) ParseFilterConfigOverride(override proto.Message) (httpfilter.FilterConfig, error) { + return nil, fmt.Errorf("error from ParseFilterConfigOverride") +} + +func (errHTTPFilter) IsTerminal() bool { + return false +} + +func init() { + httpfilter.Register(httpFilter{}) + httpfilter.Register(errHTTPFilter{}) + httpfilter.Register(serverOnlyHTTPFilter{}) + httpfilter.Register(clientOnlyHTTPFilter{}) +} + +// serverOnlyHTTPFilter does not implement ClientInterceptorBuilder +type serverOnlyHTTPFilter struct { + httpfilter.ServerInterceptorBuilder +} + +func (serverOnlyHTTPFilter) TypeURLs() []string { return []string{"serverOnly.custom.filter"} } + +func (serverOnlyHTTPFilter) ParseFilterConfig(cfg proto.Message) (httpfilter.FilterConfig, error) { + return filterConfig{Cfg: cfg}, nil +} + +func (serverOnlyHTTPFilter) ParseFilterConfigOverride(override proto.Message) (httpfilter.FilterConfig, error) { + return filterConfig{Override: override}, nil +} + +func (serverOnlyHTTPFilter) IsTerminal() bool { + return false +} + +// clientOnlyHTTPFilter does not implement ServerInterceptorBuilder +type clientOnlyHTTPFilter struct { + httpfilter.ClientInterceptorBuilder +} + +func (clientOnlyHTTPFilter) TypeURLs() []string { return []string{"clientOnly.custom.filter"} } + +func (clientOnlyHTTPFilter) ParseFilterConfig(cfg proto.Message) (httpfilter.FilterConfig, error) { + return filterConfig{Cfg: cfg}, nil +} + +func (clientOnlyHTTPFilter) ParseFilterConfigOverride(override proto.Message) (httpfilter.FilterConfig, error) { + return filterConfig{Override: override}, nil +} + +func (clientOnlyHTTPFilter) IsTerminal() bool { + return false +} + +var customFilterConfig = &anypb.Any{ + TypeUrl: "custom.filter", + Value: []byte{1, 2, 3}, +} + +var errFilterConfig = &anypb.Any{ + TypeUrl: "err.custom.filter", + Value: []byte{1, 2, 3}, +} + +var serverOnlyCustomFilterConfig = &anypb.Any{ + TypeUrl: "serverOnly.custom.filter", + Value: []byte{1, 2, 3}, +} + +var clientOnlyCustomFilterConfig = &anypb.Any{ + TypeUrl: "clientOnly.custom.filter", + Value: []byte{1, 2, 3}, +} + +var customFilterTypedStructConfig = &v1typepb.TypedStruct{ + TypeUrl: "custom.filter", + Value: &spb.Struct{ + Fields: map[string]*spb.Value{ + "foo": {Kind: &spb.Value_StringValue{StringValue: "bar"}}, + }, + }, +} +var wrappedCustomFilterTypedStructConfig *anypb.Any + +func init() { + wrappedCustomFilterTypedStructConfig = testutils.MarshalAny(customFilterTypedStructConfig) +} + +var unknownFilterConfig = &anypb.Any{ + TypeUrl: "unknown.custom.filter", + Value: []byte{1, 2, 3}, +} + +func wrappedOptionalFilter(name string) *anypb.Any { + return testutils.MarshalAny(&v3routepb.FilterConfig{ + IsOptional: true, + Config: &anypb.Any{ + TypeUrl: name, + Value: []byte{1, 2, 3}, + }, + }) +} diff --git a/xds/internal/client/load/reporter.go b/xds/internal/xdsclient/load/reporter.go similarity index 100% rename from xds/internal/client/load/reporter.go rename to xds/internal/xdsclient/load/reporter.go diff --git a/xds/internal/client/load/store.go b/xds/internal/xdsclient/load/store.go similarity index 100% rename from xds/internal/client/load/store.go rename to xds/internal/xdsclient/load/store.go diff --git a/xds/internal/client/load/store_test.go b/xds/internal/xdsclient/load/store_test.go similarity index 100% rename from xds/internal/client/load/store_test.go rename to xds/internal/xdsclient/load/store_test.go diff --git a/xds/internal/client/loadreport.go b/xds/internal/xdsclient/loadreport.go similarity index 98% rename from xds/internal/client/loadreport.go rename to xds/internal/xdsclient/loadreport.go index be42a6e0c38..32a71dada7f 100644 --- a/xds/internal/client/loadreport.go +++ b/xds/internal/xdsclient/loadreport.go @@ -15,13 +15,13 @@ * limitations under the License. */ -package client +package xdsclient import ( "context" "google.golang.org/grpc" - "google.golang.org/grpc/xds/internal/client/load" + "google.golang.org/grpc/xds/internal/xdsclient/load" ) // ReportLoad starts an load reporting stream to the given server. If the server diff --git a/xds/internal/client/tests/loadreport_test.go b/xds/internal/xdsclient/loadreport_test.go similarity index 94% rename from xds/internal/client/tests/loadreport_test.go rename to xds/internal/xdsclient/loadreport_test.go index af145e7f2a9..88a08eb43fd 100644 --- a/xds/internal/client/tests/loadreport_test.go +++ b/xds/internal/xdsclient/loadreport_test.go @@ -16,7 +16,7 @@ * */ -package tests_test +package xdsclient_test import ( "context" @@ -32,13 +32,13 @@ import ( "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/status" - "google.golang.org/grpc/xds/internal/client" - "google.golang.org/grpc/xds/internal/client/bootstrap" "google.golang.org/grpc/xds/internal/testutils/fakeserver" "google.golang.org/grpc/xds/internal/version" + "google.golang.org/grpc/xds/internal/xdsclient" + "google.golang.org/grpc/xds/internal/xdsclient/bootstrap" "google.golang.org/protobuf/testing/protocmp" - _ "google.golang.org/grpc/xds/internal/client/v2" // Register the v2 xDS API client. + _ "google.golang.org/grpc/xds/internal/xdsclient/v2" // Register the v2 xDS API client. ) const ( @@ -54,7 +54,7 @@ func (s) TestLRSClient(t *testing.T) { } defer sCleanup() - xdsC, err := client.NewWithConfigForTesting(&bootstrap.Config{ + xdsC, err := xdsclient.NewWithConfigForTesting(&bootstrap.Config{ BalancerName: fs.Address, Creds: grpc.WithTransportCredentials(insecure.NewCredentials()), NodeProto: &v2corepb.Node{}, diff --git a/xds/internal/client/logging.go b/xds/internal/xdsclient/logging.go similarity index 98% rename from xds/internal/client/logging.go rename to xds/internal/xdsclient/logging.go index bff3fb1d3df..e28ea0d0410 100644 --- a/xds/internal/client/logging.go +++ b/xds/internal/xdsclient/logging.go @@ -16,7 +16,7 @@ * */ -package client +package xdsclient import ( "fmt" diff --git a/xds/internal/xdsclient/matcher.go b/xds/internal/xdsclient/matcher.go new file mode 100644 index 00000000000..e663e02769f --- /dev/null +++ b/xds/internal/xdsclient/matcher.go @@ -0,0 +1,278 @@ +/* + * + * Copyright 2020 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package xdsclient + +import ( + "fmt" + "strings" + + "google.golang.org/grpc/internal/grpcrand" + "google.golang.org/grpc/internal/grpcutil" + iresolver "google.golang.org/grpc/internal/resolver" + "google.golang.org/grpc/internal/xds/matcher" + "google.golang.org/grpc/metadata" +) + +// RouteToMatcher converts a route to a Matcher to match incoming RPC's against. +func RouteToMatcher(r *Route) (*CompositeMatcher, error) { + var pm pathMatcher + switch { + case r.Regex != nil: + pm = newPathRegexMatcher(r.Regex) + case r.Path != nil: + pm = newPathExactMatcher(*r.Path, r.CaseInsensitive) + case r.Prefix != nil: + pm = newPathPrefixMatcher(*r.Prefix, r.CaseInsensitive) + default: + return nil, fmt.Errorf("illegal route: missing path_matcher") + } + + headerMatchers := make([]matcher.HeaderMatcher, 0, len(r.Headers)) + for _, h := range r.Headers { + var matcherT matcher.HeaderMatcher + switch { + case h.ExactMatch != nil && *h.ExactMatch != "": + matcherT = matcher.NewHeaderExactMatcher(h.Name, *h.ExactMatch) + case h.RegexMatch != nil: + matcherT = matcher.NewHeaderRegexMatcher(h.Name, h.RegexMatch) + case h.PrefixMatch != nil && *h.PrefixMatch != "": + matcherT = matcher.NewHeaderPrefixMatcher(h.Name, *h.PrefixMatch) + case h.SuffixMatch != nil && *h.SuffixMatch != "": + matcherT = matcher.NewHeaderSuffixMatcher(h.Name, *h.SuffixMatch) + case h.RangeMatch != nil: + matcherT = matcher.NewHeaderRangeMatcher(h.Name, h.RangeMatch.Start, h.RangeMatch.End) + case h.PresentMatch != nil: + matcherT = matcher.NewHeaderPresentMatcher(h.Name, *h.PresentMatch) + default: + return nil, fmt.Errorf("illegal route: missing header_match_specifier") + } + if h.InvertMatch != nil && *h.InvertMatch { + matcherT = matcher.NewInvertMatcher(matcherT) + } + headerMatchers = append(headerMatchers, matcherT) + } + + var fractionMatcher *fractionMatcher + if r.Fraction != nil { + fractionMatcher = newFractionMatcher(*r.Fraction) + } + return newCompositeMatcher(pm, headerMatchers, fractionMatcher), nil +} + +// CompositeMatcher is a matcher that holds onto many matchers and aggregates +// the matching results. +type CompositeMatcher struct { + pm pathMatcher + hms []matcher.HeaderMatcher + fm *fractionMatcher +} + +func newCompositeMatcher(pm pathMatcher, hms []matcher.HeaderMatcher, fm *fractionMatcher) *CompositeMatcher { + return &CompositeMatcher{pm: pm, hms: hms, fm: fm} +} + +// Match returns true if all matchers return true. +func (a *CompositeMatcher) Match(info iresolver.RPCInfo) bool { + if a.pm != nil && !a.pm.match(info.Method) { + return false + } + + // Call headerMatchers even if md is nil, because routes may match + // non-presence of some headers. + var md metadata.MD + if info.Context != nil { + md, _ = metadata.FromOutgoingContext(info.Context) + if extraMD, ok := grpcutil.ExtraMetadata(info.Context); ok { + md = metadata.Join(md, extraMD) + // Remove all binary headers. They are hard to match with. May need + // to add back if asked by users. + for k := range md { + if strings.HasSuffix(k, "-bin") { + delete(md, k) + } + } + } + } + for _, m := range a.hms { + if !m.Match(md) { + return false + } + } + + if a.fm != nil && !a.fm.match() { + return false + } + return true +} + +func (a *CompositeMatcher) String() string { + var ret string + if a.pm != nil { + ret += a.pm.String() + } + for _, m := range a.hms { + ret += m.String() + } + if a.fm != nil { + ret += a.fm.String() + } + return ret +} + +type fractionMatcher struct { + fraction int64 // real fraction is fraction/1,000,000. +} + +func newFractionMatcher(fraction uint32) *fractionMatcher { + return &fractionMatcher{fraction: int64(fraction)} +} + +// RandInt63n overwrites grpcrand for control in tests. +var RandInt63n = grpcrand.Int63n + +func (fm *fractionMatcher) match() bool { + t := RandInt63n(1000000) + return t <= fm.fraction +} + +func (fm *fractionMatcher) String() string { + return fmt.Sprintf("fraction:%v", fm.fraction) +} + +type domainMatchType int + +const ( + domainMatchTypeInvalid domainMatchType = iota + domainMatchTypeUniversal + domainMatchTypePrefix + domainMatchTypeSuffix + domainMatchTypeExact +) + +// Exact > Suffix > Prefix > Universal > Invalid. +func (t domainMatchType) betterThan(b domainMatchType) bool { + return t > b +} + +func matchTypeForDomain(d string) domainMatchType { + if d == "" { + return domainMatchTypeInvalid + } + if d == "*" { + return domainMatchTypeUniversal + } + if strings.HasPrefix(d, "*") { + return domainMatchTypeSuffix + } + if strings.HasSuffix(d, "*") { + return domainMatchTypePrefix + } + if strings.Contains(d, "*") { + return domainMatchTypeInvalid + } + return domainMatchTypeExact +} + +func match(domain, host string) (domainMatchType, bool) { + switch typ := matchTypeForDomain(domain); typ { + case domainMatchTypeInvalid: + return typ, false + case domainMatchTypeUniversal: + return typ, true + case domainMatchTypePrefix: + // abc.* + return typ, strings.HasPrefix(host, strings.TrimSuffix(domain, "*")) + case domainMatchTypeSuffix: + // *.123 + return typ, strings.HasSuffix(host, strings.TrimPrefix(domain, "*")) + case domainMatchTypeExact: + return typ, domain == host + default: + return domainMatchTypeInvalid, false + } +} + +// FindBestMatchingVirtualHost returns the virtual host whose domains field best +// matches host +// +// The domains field support 4 different matching pattern types: +// - Exact match +// - Suffix match (e.g. “*ABC”) +// - Prefix match (e.g. “ABC*) +// - Universal match (e.g. “*”) +// +// The best match is defined as: +// - A match is better if it’s matching pattern type is better +// - Exact match > suffix match > prefix match > universal match +// - If two matches are of the same pattern type, the longer match is better +// - This is to compare the length of the matching pattern, e.g. “*ABCDE” > +// “*ABC” +func FindBestMatchingVirtualHost(host string, vHosts []*VirtualHost) *VirtualHost { // Maybe move this crap to client + var ( + matchVh *VirtualHost + matchType = domainMatchTypeInvalid + matchLen int + ) + for _, vh := range vHosts { + for _, domain := range vh.Domains { + typ, matched := match(domain, host) + if typ == domainMatchTypeInvalid { + // The rds response is invalid. + return nil + } + if matchType.betterThan(typ) || matchType == typ && matchLen >= len(domain) || !matched { + // The previous match has better type, or the previous match has + // better length, or this domain isn't a match. + continue + } + matchVh = vh + matchType = typ + matchLen = len(domain) + } + } + return matchVh +} + +// FindBestMatchingVirtualHostServer returns the virtual host whose domains field best +// matches authority. +func FindBestMatchingVirtualHostServer(authority string, vHosts []VirtualHostWithInterceptors) *VirtualHostWithInterceptors { + var ( + matchVh *VirtualHostWithInterceptors + matchType = domainMatchTypeInvalid + matchLen int + ) + for _, vh := range vHosts { + for _, domain := range vh.Domains { + typ, matched := match(domain, authority) + if typ == domainMatchTypeInvalid { + // The rds response is invalid. + return nil + } + if matchType.betterThan(typ) || matchType == typ && matchLen >= len(domain) || !matched { + // The previous match has better type, or the previous match has + // better length, or this domain isn't a match. + continue + } + matchVh = &vh + matchType = typ + matchLen = len(domain) + } + } + return matchVh +} diff --git a/xds/internal/resolver/matcher_path.go b/xds/internal/xdsclient/matcher_path.go similarity index 97% rename from xds/internal/resolver/matcher_path.go rename to xds/internal/xdsclient/matcher_path.go index 011d1a94c49..a00c6954ef5 100644 --- a/xds/internal/resolver/matcher_path.go +++ b/xds/internal/xdsclient/matcher_path.go @@ -16,14 +16,14 @@ * */ -package resolver +package xdsclient import ( "regexp" "strings" ) -type pathMatcherInterface interface { +type pathMatcher interface { match(path string) bool String() string } diff --git a/xds/internal/resolver/matcher_path_test.go b/xds/internal/xdsclient/matcher_path_test.go similarity index 99% rename from xds/internal/resolver/matcher_path_test.go rename to xds/internal/xdsclient/matcher_path_test.go index 263a049108e..a211034a60d 100644 --- a/xds/internal/resolver/matcher_path_test.go +++ b/xds/internal/xdsclient/matcher_path_test.go @@ -16,7 +16,7 @@ * */ -package resolver +package xdsclient import ( "regexp" diff --git a/xds/internal/resolver/matcher_test.go b/xds/internal/xdsclient/matcher_test.go similarity index 55% rename from xds/internal/resolver/matcher_test.go rename to xds/internal/xdsclient/matcher_test.go index 7657b87bf45..f750d07d6e4 100644 --- a/xds/internal/resolver/matcher_test.go +++ b/xds/internal/xdsclient/matcher_test.go @@ -16,7 +16,7 @@ * */ -package resolver +package xdsclient import ( "context" @@ -25,21 +25,22 @@ import ( "google.golang.org/grpc/internal/grpcrand" "google.golang.org/grpc/internal/grpcutil" iresolver "google.golang.org/grpc/internal/resolver" + "google.golang.org/grpc/internal/xds/matcher" "google.golang.org/grpc/metadata" ) func TestAndMatcherMatch(t *testing.T) { tests := []struct { name string - pm pathMatcherInterface - hm headerMatcherInterface + pm pathMatcher + hm matcher.HeaderMatcher info iresolver.RPCInfo want bool }{ { name: "both match", pm: newPathExactMatcher("/a/b", false), - hm: newHeaderExactMatcher("th", "tv"), + hm: matcher.NewHeaderExactMatcher("th", "tv"), info: iresolver.RPCInfo{ Method: "/a/b", Context: metadata.NewOutgoingContext(context.Background(), metadata.Pairs("th", "tv")), @@ -49,7 +50,7 @@ func TestAndMatcherMatch(t *testing.T) { { name: "both match with path case insensitive", pm: newPathExactMatcher("/A/B", true), - hm: newHeaderExactMatcher("th", "tv"), + hm: matcher.NewHeaderExactMatcher("th", "tv"), info: iresolver.RPCInfo{ Method: "/a/b", Context: metadata.NewOutgoingContext(context.Background(), metadata.Pairs("th", "tv")), @@ -59,7 +60,7 @@ func TestAndMatcherMatch(t *testing.T) { { name: "only one match", pm: newPathExactMatcher("/a/b", false), - hm: newHeaderExactMatcher("th", "tv"), + hm: matcher.NewHeaderExactMatcher("th", "tv"), info: iresolver.RPCInfo{ Method: "/z/y", Context: metadata.NewOutgoingContext(context.Background(), metadata.Pairs("th", "tv")), @@ -69,7 +70,7 @@ func TestAndMatcherMatch(t *testing.T) { { name: "both not match", pm: newPathExactMatcher("/z/y", false), - hm: newHeaderExactMatcher("th", "abc"), + hm: matcher.NewHeaderExactMatcher("th", "abc"), info: iresolver.RPCInfo{ Method: "/a/b", Context: metadata.NewOutgoingContext(context.Background(), metadata.Pairs("th", "tv")), @@ -79,7 +80,7 @@ func TestAndMatcherMatch(t *testing.T) { { name: "fake header", pm: newPathPrefixMatcher("/", false), - hm: newHeaderExactMatcher("content-type", "fake"), + hm: matcher.NewHeaderExactMatcher("content-type", "fake"), info: iresolver.RPCInfo{ Method: "/a/b", Context: grpcutil.WithExtraMetadata(context.Background(), metadata.Pairs( @@ -91,7 +92,7 @@ func TestAndMatcherMatch(t *testing.T) { { name: "binary header", pm: newPathPrefixMatcher("/", false), - hm: newHeaderPresentMatcher("t-bin", true), + hm: matcher.NewHeaderPresentMatcher("t-bin", true), info: iresolver.RPCInfo{ Method: "/a/b", Context: grpcutil.WithExtraMetadata( @@ -105,8 +106,8 @@ func TestAndMatcherMatch(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - a := newCompositeMatcher(tt.pm, []headerMatcherInterface{tt.hm}, nil) - if got := a.match(tt.info); got != tt.want { + a := newCompositeMatcher(tt.pm, []matcher.HeaderMatcher{tt.hm}, nil) + if got := a.Match(tt.info); got != tt.want { t.Errorf("match() = %v, want %v", got, tt.want) } }) @@ -117,11 +118,11 @@ func TestFractionMatcherMatch(t *testing.T) { const fraction = 500000 fm := newFractionMatcher(fraction) defer func() { - grpcrandInt63n = grpcrand.Int63n + RandInt63n = grpcrand.Int63n }() // rand > fraction, should return false. - grpcrandInt63n = func(n int64) int64 { + RandInt63n = func(n int64) int64 { return fraction + 1 } if matched := fm.match(); matched { @@ -129,7 +130,7 @@ func TestFractionMatcherMatch(t *testing.T) { } // rand == fraction, should return true. - grpcrandInt63n = func(n int64) int64 { + RandInt63n = func(n int64) int64 { return fraction } if matched := fm.match(); !matched { @@ -137,10 +138,56 @@ func TestFractionMatcherMatch(t *testing.T) { } // rand < fraction, should return true. - grpcrandInt63n = func(n int64) int64 { + RandInt63n = func(n int64) int64 { return fraction - 1 } if matched := fm.match(); !matched { t.Errorf("match() = %v, want match", matched) } } + +func (s) TestMatchTypeForDomain(t *testing.T) { + tests := []struct { + d string + want domainMatchType + }{ + {d: "", want: domainMatchTypeInvalid}, + {d: "*", want: domainMatchTypeUniversal}, + {d: "bar.*", want: domainMatchTypePrefix}, + {d: "*.abc.com", want: domainMatchTypeSuffix}, + {d: "foo.bar.com", want: domainMatchTypeExact}, + {d: "foo.*.com", want: domainMatchTypeInvalid}, + } + for _, tt := range tests { + if got := matchTypeForDomain(tt.d); got != tt.want { + t.Errorf("matchTypeForDomain(%q) = %v, want %v", tt.d, got, tt.want) + } + } +} + +func (s) TestMatch(t *testing.T) { + tests := []struct { + name string + domain string + host string + wantTyp domainMatchType + wantMatched bool + }{ + {name: "invalid-empty", domain: "", host: "", wantTyp: domainMatchTypeInvalid, wantMatched: false}, + {name: "invalid", domain: "a.*.b", host: "", wantTyp: domainMatchTypeInvalid, wantMatched: false}, + {name: "universal", domain: "*", host: "abc.com", wantTyp: domainMatchTypeUniversal, wantMatched: true}, + {name: "prefix-match", domain: "abc.*", host: "abc.123", wantTyp: domainMatchTypePrefix, wantMatched: true}, + {name: "prefix-no-match", domain: "abc.*", host: "abcd.123", wantTyp: domainMatchTypePrefix, wantMatched: false}, + {name: "suffix-match", domain: "*.123", host: "abc.123", wantTyp: domainMatchTypeSuffix, wantMatched: true}, + {name: "suffix-no-match", domain: "*.123", host: "abc.1234", wantTyp: domainMatchTypeSuffix, wantMatched: false}, + {name: "exact-match", domain: "foo.bar", host: "foo.bar", wantTyp: domainMatchTypeExact, wantMatched: true}, + {name: "exact-no-match", domain: "foo.bar.com", host: "foo.bar", wantTyp: domainMatchTypeExact, wantMatched: false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if gotTyp, gotMatched := match(tt.domain, tt.host); gotTyp != tt.wantTyp || gotMatched != tt.wantMatched { + t.Errorf("match() = %v, %v, want %v, %v", gotTyp, gotMatched, tt.wantTyp, tt.wantMatched) + } + }) + } +} diff --git a/xds/internal/client/rds_test.go b/xds/internal/xdsclient/rds_test.go similarity index 55% rename from xds/internal/client/rds_test.go rename to xds/internal/xdsclient/rds_test.go index 0c1e2b28538..c89e8cddca4 100644 --- a/xds/internal/client/rds_test.go +++ b/xds/internal/xdsclient/rds_test.go @@ -16,28 +16,31 @@ * */ -package client +package xdsclient import ( "fmt" + "regexp" "testing" "time" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/internal/testutils" + "google.golang.org/grpc/internal/xds/env" + "google.golang.org/grpc/xds/internal/httpfilter" + "google.golang.org/grpc/xds/internal/version" + "google.golang.org/protobuf/types/known/durationpb" + v2xdspb "github.com/envoyproxy/go-control-plane/envoy/api/v2" v2routepb "github.com/envoyproxy/go-control-plane/envoy/api/v2/route" v3corepb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" v3routepb "github.com/envoyproxy/go-control-plane/envoy/config/route/v3" + v3matcherpb "github.com/envoyproxy/go-control-plane/envoy/type/matcher/v3" v3typepb "github.com/envoyproxy/go-control-plane/envoy/type/v3" - "github.com/golang/protobuf/proto" anypb "github.com/golang/protobuf/ptypes/any" wrapperspb "github.com/golang/protobuf/ptypes/wrappers" - "github.com/google/go-cmp/cmp" - "github.com/google/go-cmp/cmp/cmpopts" - - "google.golang.org/grpc/xds/internal/env" - "google.golang.org/grpc/xds/internal/httpfilter" - "google.golang.org/grpc/xds/internal/version" - "google.golang.org/protobuf/types/known/durationpb" ) func (s) TestRDSGenerateRDSUpdateFromRouteConfiguration(t *testing.T) { @@ -72,11 +75,55 @@ func (s) TestRDSGenerateRDSUpdateFromRouteConfiguration(t *testing.T) { Routes: []*Route{{ Prefix: newStringP("/"), WeightedClusters: map[string]WeightedCluster{clusterName: {Weight: 1}}, + RouteAction: RouteActionRoute, }}, HTTPFilterConfigOverride: cfgs, }}, } } + goodRouteConfigWithRetryPolicy = func(vhrp *v3routepb.RetryPolicy, rrp *v3routepb.RetryPolicy) *v3routepb.RouteConfiguration { + return &v3routepb.RouteConfiguration{ + Name: routeName, + VirtualHosts: []*v3routepb.VirtualHost{{ + Domains: []string{ldsTarget}, + Routes: []*v3routepb.Route{{ + Match: &v3routepb.RouteMatch{PathSpecifier: &v3routepb.RouteMatch_Prefix{Prefix: "/"}}, + Action: &v3routepb.Route_Route{ + Route: &v3routepb.RouteAction{ + ClusterSpecifier: &v3routepb.RouteAction_Cluster{Cluster: clusterName}, + RetryPolicy: rrp, + }, + }, + }}, + RetryPolicy: vhrp, + }}, + } + } + goodUpdateWithRetryPolicy = func(vhrc *RetryConfig, rrc *RetryConfig) RouteConfigUpdate { + if !env.RetrySupport { + vhrc = nil + rrc = nil + } + return RouteConfigUpdate{ + VirtualHosts: []*VirtualHost{{ + Domains: []string{ldsTarget}, + Routes: []*Route{{ + Prefix: newStringP("/"), + WeightedClusters: map[string]WeightedCluster{clusterName: {Weight: 1}}, + RouteAction: RouteActionRoute, + RetryConfig: rrc, + }}, + RetryConfig: vhrc, + }}, + } + } + defaultRetryBackoff = RetryBackoff{BaseInterval: 25 * time.Millisecond, MaxInterval: 250 * time.Millisecond} + goodUpdateIfRetryDisabled = func() RouteConfigUpdate { + if env.RetrySupport { + return RouteConfigUpdate{} + } + return goodUpdateWithRetryPolicy(nil, nil) + } ) tests := []struct { @@ -84,7 +131,6 @@ func (s) TestRDSGenerateRDSUpdateFromRouteConfiguration(t *testing.T) { rc *v3routepb.RouteConfiguration wantUpdate RouteConfigUpdate wantError bool - disableFI bool // disable fault injection }{ { name: "default-route-match-field-is-nil", @@ -175,7 +221,10 @@ func (s) TestRDSGenerateRDSUpdateFromRouteConfiguration(t *testing.T) { VirtualHosts: []*VirtualHost{ { Domains: []string{ldsTarget}, - Routes: []*Route{{Prefix: newStringP("/"), CaseInsensitive: true, WeightedClusters: map[string]WeightedCluster{clusterName: {Weight: 1}}}}, + Routes: []*Route{{Prefix: newStringP("/"), + CaseInsensitive: true, + WeightedClusters: map[string]WeightedCluster{clusterName: {Weight: 1}}, + RouteAction: RouteActionRoute}}, }, }, }, @@ -217,11 +266,15 @@ func (s) TestRDSGenerateRDSUpdateFromRouteConfiguration(t *testing.T) { VirtualHosts: []*VirtualHost{ { Domains: []string{uninterestingDomain}, - Routes: []*Route{{Prefix: newStringP(""), WeightedClusters: map[string]WeightedCluster{uninterestingClusterName: {Weight: 1}}}}, + Routes: []*Route{{Prefix: newStringP(""), + WeightedClusters: map[string]WeightedCluster{uninterestingClusterName: {Weight: 1}}, + RouteAction: RouteActionRoute}}, }, { Domains: []string{ldsTarget}, - Routes: []*Route{{Prefix: newStringP(""), WeightedClusters: map[string]WeightedCluster{clusterName: {Weight: 1}}}}, + Routes: []*Route{{Prefix: newStringP(""), + WeightedClusters: map[string]WeightedCluster{clusterName: {Weight: 1}}, + RouteAction: RouteActionRoute}}, }, }, }, @@ -251,7 +304,9 @@ func (s) TestRDSGenerateRDSUpdateFromRouteConfiguration(t *testing.T) { VirtualHosts: []*VirtualHost{ { Domains: []string{ldsTarget}, - Routes: []*Route{{Prefix: newStringP("/"), WeightedClusters: map[string]WeightedCluster{clusterName: {Weight: 1}}}}, + Routes: []*Route{{Prefix: newStringP("/"), + WeightedClusters: map[string]WeightedCluster{clusterName: {Weight: 1}}, + RouteAction: RouteActionRoute}}, }, }, }, @@ -328,6 +383,7 @@ func (s) TestRDSGenerateRDSUpdateFromRouteConfiguration(t *testing.T) { "b": {Weight: 3}, "c": {Weight: 5}, }, + RouteAction: RouteActionRoute, }}, }, }, @@ -362,6 +418,7 @@ func (s) TestRDSGenerateRDSUpdateFromRouteConfiguration(t *testing.T) { Prefix: newStringP("/"), WeightedClusters: map[string]WeightedCluster{clusterName: {Weight: 1}}, MaxStreamDuration: newDurationP(time.Second), + RouteAction: RouteActionRoute, }}, }, }, @@ -396,6 +453,7 @@ func (s) TestRDSGenerateRDSUpdateFromRouteConfiguration(t *testing.T) { Prefix: newStringP("/"), WeightedClusters: map[string]WeightedCluster{clusterName: {Weight: 1}}, MaxStreamDuration: newDurationP(time.Second), + RouteAction: RouteActionRoute, }}, }, }, @@ -430,6 +488,7 @@ func (s) TestRDSGenerateRDSUpdateFromRouteConfiguration(t *testing.T) { Prefix: newStringP("/"), WeightedClusters: map[string]WeightedCluster{clusterName: {Weight: 1}}, MaxStreamDuration: newDurationP(0), + RouteAction: RouteActionRoute, }}, }, }, @@ -471,18 +530,50 @@ func (s) TestRDSGenerateRDSUpdateFromRouteConfiguration(t *testing.T) { wantUpdate: goodUpdateWithFilterConfigs(nil), }, { - name: "good-route-config-with-http-err-filter-config-fi-disabled", - disableFI: true, - rc: goodRouteConfigWithFilterConfigs(map[string]*anypb.Any{"foo": errFilterConfig}), - wantUpdate: goodUpdateWithFilterConfigs(nil), + name: "good-route-config-with-retry-policy", + rc: goodRouteConfigWithRetryPolicy( + &v3routepb.RetryPolicy{RetryOn: "cancelled"}, + &v3routepb.RetryPolicy{RetryOn: "deadline-exceeded,unsupported", NumRetries: &wrapperspb.UInt32Value{Value: 2}}), + wantUpdate: goodUpdateWithRetryPolicy( + &RetryConfig{RetryOn: map[codes.Code]bool{codes.Canceled: true}, NumRetries: 1, RetryBackoff: defaultRetryBackoff}, + &RetryConfig{RetryOn: map[codes.Code]bool{codes.DeadlineExceeded: true}, NumRetries: 2, RetryBackoff: defaultRetryBackoff}), + }, + { + name: "good-route-config-with-retry-backoff", + rc: goodRouteConfigWithRetryPolicy( + &v3routepb.RetryPolicy{RetryOn: "internal", RetryBackOff: &v3routepb.RetryPolicy_RetryBackOff{BaseInterval: durationpb.New(10 * time.Millisecond), MaxInterval: durationpb.New(10 * time.Millisecond)}}, + &v3routepb.RetryPolicy{RetryOn: "resource-exhausted", RetryBackOff: &v3routepb.RetryPolicy_RetryBackOff{BaseInterval: durationpb.New(10 * time.Millisecond)}}), + wantUpdate: goodUpdateWithRetryPolicy( + &RetryConfig{RetryOn: map[codes.Code]bool{codes.Internal: true}, NumRetries: 1, RetryBackoff: RetryBackoff{BaseInterval: 10 * time.Millisecond, MaxInterval: 10 * time.Millisecond}}, + &RetryConfig{RetryOn: map[codes.Code]bool{codes.ResourceExhausted: true}, NumRetries: 1, RetryBackoff: RetryBackoff{BaseInterval: 10 * time.Millisecond, MaxInterval: 100 * time.Millisecond}}), + }, + { + name: "bad-retry-policy-0-retries", + rc: goodRouteConfigWithRetryPolicy(&v3routepb.RetryPolicy{RetryOn: "cancelled", NumRetries: &wrapperspb.UInt32Value{Value: 0}}, nil), + wantUpdate: goodUpdateIfRetryDisabled(), + wantError: env.RetrySupport, + }, + { + name: "bad-retry-policy-0-base-interval", + rc: goodRouteConfigWithRetryPolicy(&v3routepb.RetryPolicy{RetryOn: "cancelled", RetryBackOff: &v3routepb.RetryPolicy_RetryBackOff{BaseInterval: durationpb.New(0)}}, nil), + wantUpdate: goodUpdateIfRetryDisabled(), + wantError: env.RetrySupport, + }, + { + name: "bad-retry-policy-negative-max-interval", + rc: goodRouteConfigWithRetryPolicy(&v3routepb.RetryPolicy{RetryOn: "cancelled", RetryBackOff: &v3routepb.RetryPolicy_RetryBackOff{MaxInterval: durationpb.New(-time.Second)}}, nil), + wantUpdate: goodUpdateIfRetryDisabled(), + wantError: env.RetrySupport, + }, + { + name: "bad-retry-policy-negative-max-interval-no-known-retry-on", + rc: goodRouteConfigWithRetryPolicy(&v3routepb.RetryPolicy{RetryOn: "something", RetryBackOff: &v3routepb.RetryPolicy_RetryBackOff{MaxInterval: durationpb.New(-time.Second)}}, nil), + wantUpdate: goodUpdateIfRetryDisabled(), + wantError: env.RetrySupport, }, } - for _, test := range tests { t.Run(test.name, func(t *testing.T) { - oldFI := env.FaultInjectionSupport - env.FaultInjectionSupport = !test.disableFI - gotUpdate, gotError := generateRDSUpdateFromRouteConfiguration(test.rc, nil, false) if (gotError != nil) != test.wantError || !cmp.Equal(gotUpdate, test.wantUpdate, cmpopts.EquateEmpty(), @@ -490,8 +581,6 @@ func (s) TestRDSGenerateRDSUpdateFromRouteConfiguration(t *testing.T) { return fmt.Sprint(fc) })) { t.Errorf("generateRDSUpdateFromRouteConfiguration(%+v, %v) returned unexpected, diff (-want +got):\\n%s", test.rc, ldsTarget, cmp.Diff(test.wantUpdate, gotUpdate, cmpopts.EquateEmpty())) - - env.FaultInjectionSupport = oldFI } }) } @@ -537,17 +626,10 @@ func (s) TestUnmarshalRouteConfig(t *testing.T) { }, }, } - v2RouteConfig = &anypb.Any{ - TypeUrl: version.V2RouteConfigURL, - Value: func() []byte { - rc := &v2xdspb.RouteConfiguration{ - Name: v2RouteConfigName, - VirtualHosts: v2VirtualHost, - } - m, _ := proto.Marshal(rc) - return m - }(), - } + v2RouteConfig = testutils.MarshalAny(&v2xdspb.RouteConfiguration{ + Name: v2RouteConfigName, + VirtualHosts: v2VirtualHost, + }) v3VirtualHost = []*v3routepb.VirtualHost{ { Domains: []string{uninterestingDomain}, @@ -576,24 +658,17 @@ func (s) TestUnmarshalRouteConfig(t *testing.T) { }, }, } - v3RouteConfig = &anypb.Any{ - TypeUrl: version.V2RouteConfigURL, - Value: func() []byte { - rc := &v3routepb.RouteConfiguration{ - Name: v3RouteConfigName, - VirtualHosts: v3VirtualHost, - } - m, _ := proto.Marshal(rc) - return m - }(), - } + v3RouteConfig = testutils.MarshalAny(&v3routepb.RouteConfiguration{ + Name: v3RouteConfigName, + VirtualHosts: v3VirtualHost, + }) ) const testVersion = "test-version-rds" tests := []struct { name string resources []*anypb.Any - wantUpdate map[string]RouteConfigUpdate + wantUpdate map[string]RouteConfigUpdateErrTuple wantMD UpdateMetadata wantErr bool }{ @@ -605,7 +680,7 @@ func (s) TestUnmarshalRouteConfig(t *testing.T) { Version: testVersion, ErrState: &UpdateErrorMetadata{ Version: testVersion, - Err: errPlaceHolder, + Err: cmpopts.AnyError, }, }, wantErr: true, @@ -623,7 +698,7 @@ func (s) TestUnmarshalRouteConfig(t *testing.T) { Version: testVersion, ErrState: &UpdateErrorMetadata{ Version: testVersion, - Err: errPlaceHolder, + Err: cmpopts.AnyError, }, }, wantErr: true, @@ -638,20 +713,24 @@ func (s) TestUnmarshalRouteConfig(t *testing.T) { { name: "v2 routeConfig resource", resources: []*anypb.Any{v2RouteConfig}, - wantUpdate: map[string]RouteConfigUpdate{ - v2RouteConfigName: { + wantUpdate: map[string]RouteConfigUpdateErrTuple{ + v2RouteConfigName: {Update: RouteConfigUpdate{ VirtualHosts: []*VirtualHost{ { Domains: []string{uninterestingDomain}, - Routes: []*Route{{Prefix: newStringP(""), WeightedClusters: map[string]WeightedCluster{uninterestingClusterName: {Weight: 1}}}}, + Routes: []*Route{{Prefix: newStringP(""), + WeightedClusters: map[string]WeightedCluster{uninterestingClusterName: {Weight: 1}}, + RouteAction: RouteActionRoute}}, }, { Domains: []string{ldsTarget}, - Routes: []*Route{{Prefix: newStringP(""), WeightedClusters: map[string]WeightedCluster{v2ClusterName: {Weight: 1}}}}, + Routes: []*Route{{Prefix: newStringP(""), + WeightedClusters: map[string]WeightedCluster{v2ClusterName: {Weight: 1}}, + RouteAction: RouteActionRoute}}, }, }, Raw: v2RouteConfig, - }, + }}, }, wantMD: UpdateMetadata{ Status: ServiceStatusACKed, @@ -661,20 +740,24 @@ func (s) TestUnmarshalRouteConfig(t *testing.T) { { name: "v3 routeConfig resource", resources: []*anypb.Any{v3RouteConfig}, - wantUpdate: map[string]RouteConfigUpdate{ - v3RouteConfigName: { + wantUpdate: map[string]RouteConfigUpdateErrTuple{ + v3RouteConfigName: {Update: RouteConfigUpdate{ VirtualHosts: []*VirtualHost{ { Domains: []string{uninterestingDomain}, - Routes: []*Route{{Prefix: newStringP(""), WeightedClusters: map[string]WeightedCluster{uninterestingClusterName: {Weight: 1}}}}, + Routes: []*Route{{Prefix: newStringP(""), + WeightedClusters: map[string]WeightedCluster{uninterestingClusterName: {Weight: 1}}, + RouteAction: RouteActionRoute}}, }, { Domains: []string{ldsTarget}, - Routes: []*Route{{Prefix: newStringP(""), WeightedClusters: map[string]WeightedCluster{v3ClusterName: {Weight: 1}}}}, + Routes: []*Route{{Prefix: newStringP(""), + WeightedClusters: map[string]WeightedCluster{v3ClusterName: {Weight: 1}}, + RouteAction: RouteActionRoute}}, }, }, Raw: v3RouteConfig, - }, + }}, }, wantMD: UpdateMetadata{ Status: ServiceStatusACKed, @@ -684,33 +767,41 @@ func (s) TestUnmarshalRouteConfig(t *testing.T) { { name: "multiple routeConfig resources", resources: []*anypb.Any{v2RouteConfig, v3RouteConfig}, - wantUpdate: map[string]RouteConfigUpdate{ - v3RouteConfigName: { + wantUpdate: map[string]RouteConfigUpdateErrTuple{ + v3RouteConfigName: {Update: RouteConfigUpdate{ VirtualHosts: []*VirtualHost{ { Domains: []string{uninterestingDomain}, - Routes: []*Route{{Prefix: newStringP(""), WeightedClusters: map[string]WeightedCluster{uninterestingClusterName: {Weight: 1}}}}, + Routes: []*Route{{Prefix: newStringP(""), + WeightedClusters: map[string]WeightedCluster{uninterestingClusterName: {Weight: 1}}, + RouteAction: RouteActionRoute}}, }, { Domains: []string{ldsTarget}, - Routes: []*Route{{Prefix: newStringP(""), WeightedClusters: map[string]WeightedCluster{v3ClusterName: {Weight: 1}}}}, + Routes: []*Route{{Prefix: newStringP(""), + WeightedClusters: map[string]WeightedCluster{v3ClusterName: {Weight: 1}}, + RouteAction: RouteActionRoute}}, }, }, Raw: v3RouteConfig, - }, - v2RouteConfigName: { + }}, + v2RouteConfigName: {Update: RouteConfigUpdate{ VirtualHosts: []*VirtualHost{ { Domains: []string{uninterestingDomain}, - Routes: []*Route{{Prefix: newStringP(""), WeightedClusters: map[string]WeightedCluster{uninterestingClusterName: {Weight: 1}}}}, + Routes: []*Route{{Prefix: newStringP(""), + WeightedClusters: map[string]WeightedCluster{uninterestingClusterName: {Weight: 1}}, + RouteAction: RouteActionRoute}}, }, { Domains: []string{ldsTarget}, - Routes: []*Route{{Prefix: newStringP(""), WeightedClusters: map[string]WeightedCluster{v2ClusterName: {Weight: 1}}}}, + Routes: []*Route{{Prefix: newStringP(""), + WeightedClusters: map[string]WeightedCluster{v2ClusterName: {Weight: 1}}, + RouteAction: RouteActionRoute}}, }, }, Raw: v2RouteConfig, - }, + }}, }, wantMD: UpdateMetadata{ Status: ServiceStatusACKed, @@ -722,57 +813,58 @@ func (s) TestUnmarshalRouteConfig(t *testing.T) { name: "good and bad routeConfig resources", resources: []*anypb.Any{ v2RouteConfig, - { - TypeUrl: version.V2RouteConfigURL, - Value: func() []byte { - rc := &v3routepb.RouteConfiguration{ - Name: "bad", - VirtualHosts: []*v3routepb.VirtualHost{ - {Domains: []string{ldsTarget}, - Routes: []*v3routepb.Route{{ - Match: &v3routepb.RouteMatch{PathSpecifier: &v3routepb.RouteMatch_ConnectMatcher_{}}, - }}}}} - m, _ := proto.Marshal(rc) - return m - }(), - }, + testutils.MarshalAny(&v3routepb.RouteConfiguration{ + Name: "bad", + VirtualHosts: []*v3routepb.VirtualHost{ + {Domains: []string{ldsTarget}, + Routes: []*v3routepb.Route{{ + Match: &v3routepb.RouteMatch{PathSpecifier: &v3routepb.RouteMatch_ConnectMatcher_{}}, + }}}}}), v3RouteConfig, }, - wantUpdate: map[string]RouteConfigUpdate{ - v3RouteConfigName: { + wantUpdate: map[string]RouteConfigUpdateErrTuple{ + v3RouteConfigName: {Update: RouteConfigUpdate{ VirtualHosts: []*VirtualHost{ { Domains: []string{uninterestingDomain}, - Routes: []*Route{{Prefix: newStringP(""), WeightedClusters: map[string]WeightedCluster{uninterestingClusterName: {Weight: 1}}}}, + Routes: []*Route{{Prefix: newStringP(""), + WeightedClusters: map[string]WeightedCluster{uninterestingClusterName: {Weight: 1}}, + RouteAction: RouteActionRoute}}, }, { Domains: []string{ldsTarget}, - Routes: []*Route{{Prefix: newStringP(""), WeightedClusters: map[string]WeightedCluster{v3ClusterName: {Weight: 1}}}}, + Routes: []*Route{{Prefix: newStringP(""), + WeightedClusters: map[string]WeightedCluster{v3ClusterName: {Weight: 1}}, + RouteAction: RouteActionRoute}}, }, }, Raw: v3RouteConfig, - }, - v2RouteConfigName: { + }}, + v2RouteConfigName: {Update: RouteConfigUpdate{ VirtualHosts: []*VirtualHost{ { Domains: []string{uninterestingDomain}, - Routes: []*Route{{Prefix: newStringP(""), WeightedClusters: map[string]WeightedCluster{uninterestingClusterName: {Weight: 1}}}}, + Routes: []*Route{{Prefix: newStringP(""), + WeightedClusters: map[string]WeightedCluster{uninterestingClusterName: {Weight: 1}}, + RouteAction: RouteActionRoute}}, }, { Domains: []string{ldsTarget}, - Routes: []*Route{{Prefix: newStringP(""), WeightedClusters: map[string]WeightedCluster{v2ClusterName: {Weight: 1}}}}, + Routes: []*Route{{Prefix: newStringP(""), + WeightedClusters: map[string]WeightedCluster{v2ClusterName: {Weight: 1}}, + RouteAction: RouteActionRoute}}, }, }, Raw: v2RouteConfig, - }, - "bad": {}, + }}, + "bad": {Err: cmpopts.AnyError}, }, wantMD: UpdateMetadata{ Status: ServiceStatusNACKed, Version: testVersion, ErrState: &UpdateErrorMetadata{ Version: testVersion, - Err: errPlaceHolder, + Err: cmpopts.AnyError, }, }, wantErr: true, @@ -780,9 +872,13 @@ func (s) TestUnmarshalRouteConfig(t *testing.T) { } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - update, md, err := UnmarshalRouteConfig(testVersion, test.resources, nil) + opts := &UnmarshalOptions{ + Version: testVersion, + Resources: test.resources, + } + update, md, err := UnmarshalRouteConfig(opts) if (err != nil) != test.wantErr { - t.Fatalf("UnmarshalRouteConfig(), got err: %v, wantErr: %v", err, test.wantErr) + t.Fatalf("UnmarshalRouteConfig(%+v), got err: %v, wantErr: %v", opts, err, test.wantErr) } if diff := cmp.Diff(update, test.wantUpdate, cmpOpts); diff != "" { t.Errorf("got unexpected update, diff (-got +want): %v", diff) @@ -823,6 +919,7 @@ func (s) TestRoutesProtoToSlice(t *testing.T) { CaseInsensitive: true, WeightedClusters: map[string]WeightedCluster{"A": {Weight: 40}, "B": {Weight: 60, HTTPFilterConfigOverride: cfgs}}, HTTPFilterConfigOverride: cfgs, + RouteAction: RouteActionRoute, }} } ) @@ -832,7 +929,6 @@ func (s) TestRoutesProtoToSlice(t *testing.T) { routes []*v3routepb.Route wantRoutes []*Route wantErr bool - disableFI bool // disable fault injection }{ { name: "no path", @@ -863,6 +959,7 @@ func (s) TestRoutesProtoToSlice(t *testing.T) { Prefix: newStringP("/"), CaseInsensitive: true, WeightedClusters: map[string]WeightedCluster{"A": {Weight: 40}, "B": {Weight: 60}}, + RouteAction: RouteActionRoute, }}, }, { @@ -910,6 +1007,53 @@ func (s) TestRoutesProtoToSlice(t *testing.T) { }, Fraction: newUInt32P(10000), WeightedClusters: map[string]WeightedCluster{"A": {Weight: 40}, "B": {Weight: 60}}, + RouteAction: RouteActionRoute, + }}, + wantErr: false, + }, + { + name: "good with regex matchers", + routes: []*v3routepb.Route{ + { + Match: &v3routepb.RouteMatch{ + PathSpecifier: &v3routepb.RouteMatch_SafeRegex{SafeRegex: &v3matcherpb.RegexMatcher{Regex: "/a/"}}, + Headers: []*v3routepb.HeaderMatcher{ + { + Name: "th", + HeaderMatchSpecifier: &v3routepb.HeaderMatcher_SafeRegexMatch{SafeRegexMatch: &v3matcherpb.RegexMatcher{Regex: "tv"}}, + }, + }, + RuntimeFraction: &v3corepb.RuntimeFractionalPercent{ + DefaultValue: &v3typepb.FractionalPercent{ + Numerator: 1, + Denominator: v3typepb.FractionalPercent_HUNDRED, + }, + }, + }, + Action: &v3routepb.Route_Route{ + Route: &v3routepb.RouteAction{ + ClusterSpecifier: &v3routepb.RouteAction_WeightedClusters{ + WeightedClusters: &v3routepb.WeightedCluster{ + Clusters: []*v3routepb.WeightedCluster_ClusterWeight{ + {Name: "B", Weight: &wrapperspb.UInt32Value{Value: 60}}, + {Name: "A", Weight: &wrapperspb.UInt32Value{Value: 40}}, + }, + TotalWeight: &wrapperspb.UInt32Value{Value: 100}, + }}}}, + }, + }, + wantRoutes: []*Route{{ + Regex: func() *regexp.Regexp { return regexp.MustCompile("/a/") }(), + Headers: []*HeaderMatcher{ + { + Name: "th", + InvertMatch: newBoolP(false), + RegexMatch: func() *regexp.Regexp { return regexp.MustCompile("tv") }(), + }, + }, + Fraction: newUInt32P(10000), + WeightedClusters: map[string]WeightedCluster{"A": {Weight: 40}, "B": {Weight: 60}}, + RouteAction: RouteActionRoute, }}, wantErr: false, }, @@ -944,6 +1088,7 @@ func (s) TestRoutesProtoToSlice(t *testing.T) { wantRoutes: []*Route{{ Prefix: newStringP("/a/"), WeightedClusters: map[string]WeightedCluster{"A": {Weight: 40}, "B": {Weight: 60}}, + RouteAction: RouteActionRoute, }}, wantErr: false, }, @@ -958,6 +1103,44 @@ func (s) TestRoutesProtoToSlice(t *testing.T) { }, wantErr: true, }, + { + name: "bad regex in path specifier", + routes: []*v3routepb.Route{ + { + Match: &v3routepb.RouteMatch{ + PathSpecifier: &v3routepb.RouteMatch_SafeRegex{SafeRegex: &v3matcherpb.RegexMatcher{Regex: "??"}}, + Headers: []*v3routepb.HeaderMatcher{ + { + HeaderMatchSpecifier: &v3routepb.HeaderMatcher_PrefixMatch{PrefixMatch: "tv"}, + }, + }, + }, + Action: &v3routepb.Route_Route{ + Route: &v3routepb.RouteAction{ClusterSpecifier: &v3routepb.RouteAction_Cluster{Cluster: clusterName}}, + }, + }, + }, + wantErr: true, + }, + { + name: "bad regex in header specifier", + routes: []*v3routepb.Route{ + { + Match: &v3routepb.RouteMatch{ + PathSpecifier: &v3routepb.RouteMatch_Prefix{Prefix: "/a/"}, + Headers: []*v3routepb.HeaderMatcher{ + { + HeaderMatchSpecifier: &v3routepb.HeaderMatcher_SafeRegexMatch{SafeRegexMatch: &v3matcherpb.RegexMatcher{Regex: "??"}}, + }, + }, + }, + Action: &v3routepb.Route_Route{ + Route: &v3routepb.RouteAction{ClusterSpecifier: &v3routepb.RouteAction_Cluster{Cluster: clusterName}}, + }, + }, + }, + wantErr: true, + }, { name: "unrecognized header match specifier", routes: []*v3routepb.Route{ @@ -967,7 +1150,7 @@ func (s) TestRoutesProtoToSlice(t *testing.T) { Headers: []*v3routepb.HeaderMatcher{ { Name: "th", - HeaderMatchSpecifier: &v3routepb.HeaderMatcher_HiddenEnvoyDeprecatedRegexMatch{}, + HeaderMatchSpecifier: &v3routepb.HeaderMatcher_StringMatch{}, }, }, }, @@ -1011,6 +1194,227 @@ func (s) TestRoutesProtoToSlice(t *testing.T) { }, wantErr: true, }, + { + name: "totalWeight is nil in weighted clusters action", + routes: []*v3routepb.Route{ + { + Match: &v3routepb.RouteMatch{ + PathSpecifier: &v3routepb.RouteMatch_Prefix{Prefix: "/a/"}, + }, + Action: &v3routepb.Route_Route{ + Route: &v3routepb.RouteAction{ + ClusterSpecifier: &v3routepb.RouteAction_WeightedClusters{ + WeightedClusters: &v3routepb.WeightedCluster{ + Clusters: []*v3routepb.WeightedCluster_ClusterWeight{ + {Name: "B", Weight: &wrapperspb.UInt32Value{Value: 20}}, + {Name: "A", Weight: &wrapperspb.UInt32Value{Value: 30}}, + }, + }}}}, + }, + }, + wantErr: true, + }, + { + name: "The sum of all weighted clusters is not equal totalWeight", + routes: []*v3routepb.Route{ + { + Match: &v3routepb.RouteMatch{ + PathSpecifier: &v3routepb.RouteMatch_Prefix{Prefix: "/a/"}, + }, + Action: &v3routepb.Route_Route{ + Route: &v3routepb.RouteAction{ + ClusterSpecifier: &v3routepb.RouteAction_WeightedClusters{ + WeightedClusters: &v3routepb.WeightedCluster{ + Clusters: []*v3routepb.WeightedCluster_ClusterWeight{ + {Name: "B", Weight: &wrapperspb.UInt32Value{Value: 50}}, + {Name: "A", Weight: &wrapperspb.UInt32Value{Value: 20}}, + }, + TotalWeight: &wrapperspb.UInt32Value{Value: 100}, + }}}}, + }, + }, + wantErr: true, + }, + { + name: "unsupported cluster specifier", + routes: []*v3routepb.Route{ + { + Match: &v3routepb.RouteMatch{ + PathSpecifier: &v3routepb.RouteMatch_Prefix{Prefix: "/a/"}, + }, + Action: &v3routepb.Route_Route{ + Route: &v3routepb.RouteAction{ + ClusterSpecifier: &v3routepb.RouteAction_ClusterSpecifierPlugin{}}}, + }, + }, + wantErr: true, + }, + { + name: "default totalWeight is 100 in weighted clusters action", + routes: []*v3routepb.Route{ + { + Match: &v3routepb.RouteMatch{ + PathSpecifier: &v3routepb.RouteMatch_Prefix{Prefix: "/a/"}, + }, + Action: &v3routepb.Route_Route{ + Route: &v3routepb.RouteAction{ + ClusterSpecifier: &v3routepb.RouteAction_WeightedClusters{ + WeightedClusters: &v3routepb.WeightedCluster{ + Clusters: []*v3routepb.WeightedCluster_ClusterWeight{ + {Name: "B", Weight: &wrapperspb.UInt32Value{Value: 60}}, + {Name: "A", Weight: &wrapperspb.UInt32Value{Value: 40}}, + }, + }}}}, + }, + }, + wantRoutes: []*Route{{ + Prefix: newStringP("/a/"), + WeightedClusters: map[string]WeightedCluster{"A": {Weight: 40}, "B": {Weight: 60}}, + RouteAction: RouteActionRoute, + }}, + wantErr: false, + }, + { + name: "default totalWeight is 100 in weighted clusters action", + routes: []*v3routepb.Route{ + { + Match: &v3routepb.RouteMatch{ + PathSpecifier: &v3routepb.RouteMatch_Prefix{Prefix: "/a/"}, + }, + Action: &v3routepb.Route_Route{ + Route: &v3routepb.RouteAction{ + ClusterSpecifier: &v3routepb.RouteAction_WeightedClusters{ + WeightedClusters: &v3routepb.WeightedCluster{ + Clusters: []*v3routepb.WeightedCluster_ClusterWeight{ + {Name: "B", Weight: &wrapperspb.UInt32Value{Value: 30}}, + {Name: "A", Weight: &wrapperspb.UInt32Value{Value: 20}}, + }, + TotalWeight: &wrapperspb.UInt32Value{Value: 50}, + }}}}, + }, + }, + wantRoutes: []*Route{{ + Prefix: newStringP("/a/"), + WeightedClusters: map[string]WeightedCluster{"A": {Weight: 20}, "B": {Weight: 30}}, + RouteAction: RouteActionRoute, + }}, + wantErr: false, + }, + { + name: "good-with-channel-id-hash-policy", + routes: []*v3routepb.Route{ + { + Match: &v3routepb.RouteMatch{ + PathSpecifier: &v3routepb.RouteMatch_Prefix{Prefix: "/a/"}, + Headers: []*v3routepb.HeaderMatcher{ + { + Name: "th", + HeaderMatchSpecifier: &v3routepb.HeaderMatcher_PrefixMatch{ + PrefixMatch: "tv", + }, + InvertMatch: true, + }, + }, + RuntimeFraction: &v3corepb.RuntimeFractionalPercent{ + DefaultValue: &v3typepb.FractionalPercent{ + Numerator: 1, + Denominator: v3typepb.FractionalPercent_HUNDRED, + }, + }, + }, + Action: &v3routepb.Route_Route{ + Route: &v3routepb.RouteAction{ + ClusterSpecifier: &v3routepb.RouteAction_WeightedClusters{ + WeightedClusters: &v3routepb.WeightedCluster{ + Clusters: []*v3routepb.WeightedCluster_ClusterWeight{ + {Name: "B", Weight: &wrapperspb.UInt32Value{Value: 60}}, + {Name: "A", Weight: &wrapperspb.UInt32Value{Value: 40}}, + }, + TotalWeight: &wrapperspb.UInt32Value{Value: 100}, + }}, + HashPolicy: []*v3routepb.RouteAction_HashPolicy{ + {PolicySpecifier: &v3routepb.RouteAction_HashPolicy_FilterState_{FilterState: &v3routepb.RouteAction_HashPolicy_FilterState{Key: "io.grpc.channel_id"}}}, + }, + }}, + }, + }, + wantRoutes: []*Route{{ + Prefix: newStringP("/a/"), + Headers: []*HeaderMatcher{ + { + Name: "th", + InvertMatch: newBoolP(true), + PrefixMatch: newStringP("tv"), + }, + }, + Fraction: newUInt32P(10000), + WeightedClusters: map[string]WeightedCluster{"A": {Weight: 40}, "B": {Weight: 60}}, + HashPolicies: []*HashPolicy{ + {HashPolicyType: HashPolicyTypeChannelID}, + }, + RouteAction: RouteActionRoute, + }}, + wantErr: false, + }, + // This tests that policy.Regex ends up being nil if RegexRewrite is not + // set in xds response. + { + name: "good-with-header-hash-policy-no-regex-specified", + routes: []*v3routepb.Route{ + { + Match: &v3routepb.RouteMatch{ + PathSpecifier: &v3routepb.RouteMatch_Prefix{Prefix: "/a/"}, + Headers: []*v3routepb.HeaderMatcher{ + { + Name: "th", + HeaderMatchSpecifier: &v3routepb.HeaderMatcher_PrefixMatch{ + PrefixMatch: "tv", + }, + InvertMatch: true, + }, + }, + RuntimeFraction: &v3corepb.RuntimeFractionalPercent{ + DefaultValue: &v3typepb.FractionalPercent{ + Numerator: 1, + Denominator: v3typepb.FractionalPercent_HUNDRED, + }, + }, + }, + Action: &v3routepb.Route_Route{ + Route: &v3routepb.RouteAction{ + ClusterSpecifier: &v3routepb.RouteAction_WeightedClusters{ + WeightedClusters: &v3routepb.WeightedCluster{ + Clusters: []*v3routepb.WeightedCluster_ClusterWeight{ + {Name: "B", Weight: &wrapperspb.UInt32Value{Value: 60}}, + {Name: "A", Weight: &wrapperspb.UInt32Value{Value: 40}}, + }, + TotalWeight: &wrapperspb.UInt32Value{Value: 100}, + }}, + HashPolicy: []*v3routepb.RouteAction_HashPolicy{ + {PolicySpecifier: &v3routepb.RouteAction_HashPolicy_Header_{Header: &v3routepb.RouteAction_HashPolicy_Header{HeaderName: ":path"}}}, + }, + }}, + }, + }, + wantRoutes: []*Route{{ + Prefix: newStringP("/a/"), + Headers: []*HeaderMatcher{ + { + Name: "th", + InvertMatch: newBoolP(true), + PrefixMatch: newStringP("tv"), + }, + }, + Fraction: newUInt32P(10000), + WeightedClusters: map[string]WeightedCluster{"A": {Weight: 40}, "B": {Weight: 60}}, + HashPolicies: []*HashPolicy{ + {HashPolicyType: HashPolicyTypeHeader, + HeaderName: ":path"}, + }, + RouteAction: RouteActionRoute, + }}, + wantErr: false, + }, { name: "with custom HTTP filter config", routes: goodRouteWithFilterConfigs(map[string]*anypb.Any{"foo": customFilterConfig}), @@ -1026,12 +1430,6 @@ func (s) TestRoutesProtoToSlice(t *testing.T) { routes: goodRouteWithFilterConfigs(map[string]*anypb.Any{"foo": wrappedOptionalFilter("custom.filter")}), wantRoutes: goodUpdateWithFilterConfigs(map[string]httpfilter.FilterConfig{"foo": filterConfig{Override: customFilterConfig}}), }, - { - name: "with custom HTTP filter config, FI disabled", - disableFI: true, - routes: goodRouteWithFilterConfigs(map[string]*anypb.Any{"foo": customFilterConfig}), - wantRoutes: goodUpdateWithFilterConfigs(nil), - }, { name: "with erroring custom HTTP filter config", routes: goodRouteWithFilterConfigs(map[string]*anypb.Any{"foo": errFilterConfig}), @@ -1042,12 +1440,6 @@ func (s) TestRoutesProtoToSlice(t *testing.T) { routes: goodRouteWithFilterConfigs(map[string]*anypb.Any{"foo": wrappedOptionalFilter("err.custom.filter")}), wantErr: true, }, - { - name: "with erroring custom HTTP filter config, FI disabled", - disableFI: true, - routes: goodRouteWithFilterConfigs(map[string]*anypb.Any{"foo": errFilterConfig}), - wantRoutes: goodUpdateWithFilterConfigs(nil), - }, { name: "with unknown custom HTTP filter config", routes: goodRouteWithFilterConfigs(map[string]*anypb.Any{"foo": unknownFilterConfig}), @@ -1061,28 +1453,137 @@ func (s) TestRoutesProtoToSlice(t *testing.T) { } cmpOpts := []cmp.Option{ - cmp.AllowUnexported(Route{}, HeaderMatcher{}, Int64Range{}), + cmp.AllowUnexported(Route{}, HeaderMatcher{}, Int64Range{}, regexp.Regexp{}), cmpopts.EquateEmpty(), cmp.Transformer("FilterConfig", func(fc httpfilter.FilterConfig) string { return fmt.Sprint(fc) }), } - + oldRingHashSupport := env.RingHashSupport + env.RingHashSupport = true + defer func() { env.RingHashSupport = oldRingHashSupport }() for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - oldFI := env.FaultInjectionSupport - env.FaultInjectionSupport = !tt.disableFI - got, err := routesProtoToSlice(tt.routes, nil, false) if (err != nil) != tt.wantErr { - t.Errorf("routesProtoToSlice() error = %v, wantErr %v", err, tt.wantErr) - return + t.Fatalf("routesProtoToSlice() error = %v, wantErr %v", err, tt.wantErr) } - if !cmp.Equal(got, tt.wantRoutes, cmpOpts...) { - t.Errorf("routesProtoToSlice() got = %v, want %v, diff: %v", got, tt.wantRoutes, cmp.Diff(got, tt.wantRoutes, cmpOpts...)) + if diff := cmp.Diff(got, tt.wantRoutes, cmpOpts...); diff != "" { + t.Fatalf("routesProtoToSlice() returned unexpected diff (-got +want):\n%s", diff) } + }) + } +} + +func (s) TestHashPoliciesProtoToSlice(t *testing.T) { + tests := []struct { + name string + hashPolicies []*v3routepb.RouteAction_HashPolicy + wantHashPolicies []*HashPolicy + wantErr bool + }{ + // header-hash-policy tests a basic hash policy that specifies to hash a + // certain header. + { + name: "header-hash-policy", + hashPolicies: []*v3routepb.RouteAction_HashPolicy{ + { + PolicySpecifier: &v3routepb.RouteAction_HashPolicy_Header_{ + Header: &v3routepb.RouteAction_HashPolicy_Header{ + HeaderName: ":path", + RegexRewrite: &v3matcherpb.RegexMatchAndSubstitute{ + Pattern: &v3matcherpb.RegexMatcher{Regex: "/products"}, + Substitution: "/products", + }, + }, + }, + }, + }, + wantHashPolicies: []*HashPolicy{ + { + HashPolicyType: HashPolicyTypeHeader, + HeaderName: ":path", + Regex: func() *regexp.Regexp { return regexp.MustCompile("/products") }(), + RegexSubstitution: "/products", + }, + }, + }, + // channel-id-hash-policy tests a basic hash policy that specifies to + // hash a unique identifier of the channel. + { + name: "channel-id-hash-policy", + hashPolicies: []*v3routepb.RouteAction_HashPolicy{ + {PolicySpecifier: &v3routepb.RouteAction_HashPolicy_FilterState_{FilterState: &v3routepb.RouteAction_HashPolicy_FilterState{Key: "io.grpc.channel_id"}}}, + }, + wantHashPolicies: []*HashPolicy{ + {HashPolicyType: HashPolicyTypeChannelID}, + }, + }, + // unsupported-filter-state-key tests that an unsupported key in the + // filter state hash policy are treated as a no-op. + { + name: "wrong-filter-state-key", + hashPolicies: []*v3routepb.RouteAction_HashPolicy{ + {PolicySpecifier: &v3routepb.RouteAction_HashPolicy_FilterState_{FilterState: &v3routepb.RouteAction_HashPolicy_FilterState{Key: "unsupported key"}}}, + }, + }, + // no-op-hash-policy tests that hash policies that are not supported by + // grpc are treated as a no-op. + { + name: "no-op-hash-policy", + hashPolicies: []*v3routepb.RouteAction_HashPolicy{ + {PolicySpecifier: &v3routepb.RouteAction_HashPolicy_FilterState_{}}, + }, + }, + // header-and-channel-id-hash-policy test that a list of header and + // channel id hash policies are successfully converted to an internal + // struct. + { + name: "header-and-channel-id-hash-policy", + hashPolicies: []*v3routepb.RouteAction_HashPolicy{ + { + PolicySpecifier: &v3routepb.RouteAction_HashPolicy_Header_{ + Header: &v3routepb.RouteAction_HashPolicy_Header{ + HeaderName: ":path", + RegexRewrite: &v3matcherpb.RegexMatchAndSubstitute{ + Pattern: &v3matcherpb.RegexMatcher{Regex: "/products"}, + Substitution: "/products", + }, + }, + }, + }, + { + PolicySpecifier: &v3routepb.RouteAction_HashPolicy_FilterState_{FilterState: &v3routepb.RouteAction_HashPolicy_FilterState{Key: "io.grpc.channel_id"}}, + Terminal: true, + }, + }, + wantHashPolicies: []*HashPolicy{ + { + HashPolicyType: HashPolicyTypeHeader, + HeaderName: ":path", + Regex: func() *regexp.Regexp { return regexp.MustCompile("/products") }(), + RegexSubstitution: "/products", + }, + { + HashPolicyType: HashPolicyTypeChannelID, + Terminal: true, + }, + }, + }, + } - env.FaultInjectionSupport = oldFI + oldRingHashSupport := env.RingHashSupport + env.RingHashSupport = true + defer func() { env.RingHashSupport = oldRingHashSupport }() + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := hashPoliciesProtoToSlice(tt.hashPolicies, nil) + if (err != nil) != tt.wantErr { + t.Fatalf("hashPoliciesProtoToSlice() error = %v, wantErr %v", err, tt.wantErr) + } + if diff := cmp.Diff(got, tt.wantHashPolicies, cmp.AllowUnexported(regexp.Regexp{})); diff != "" { + t.Fatalf("hashPoliciesProtoToSlice() returned unexpected diff (-got +want):\n%s", diff) + } }) } } diff --git a/xds/internal/xdsclient/requests_counter.go b/xds/internal/xdsclient/requests_counter.go new file mode 100644 index 00000000000..beed2e9d0ad --- /dev/null +++ b/xds/internal/xdsclient/requests_counter.go @@ -0,0 +1,107 @@ +/* + * + * Copyright 2020 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package xdsclient + +import ( + "fmt" + "sync" + "sync/atomic" +) + +type clusterNameAndServiceName struct { + clusterName, edsServcieName string +} + +type clusterRequestsCounter struct { + mu sync.Mutex + clusters map[clusterNameAndServiceName]*ClusterRequestsCounter +} + +var src = &clusterRequestsCounter{ + clusters: make(map[clusterNameAndServiceName]*ClusterRequestsCounter), +} + +// ClusterRequestsCounter is used to track the total inflight requests for a +// service with the provided name. +type ClusterRequestsCounter struct { + ClusterName string + EDSServiceName string + numRequests uint32 +} + +// GetClusterRequestsCounter returns the ClusterRequestsCounter with the +// provided serviceName. If one does not exist, it creates it. +func GetClusterRequestsCounter(clusterName, edsServiceName string) *ClusterRequestsCounter { + src.mu.Lock() + defer src.mu.Unlock() + k := clusterNameAndServiceName{ + clusterName: clusterName, + edsServcieName: edsServiceName, + } + c, ok := src.clusters[k] + if !ok { + c = &ClusterRequestsCounter{ClusterName: clusterName} + src.clusters[k] = c + } + return c +} + +// StartRequest starts a request for a cluster, incrementing its number of +// requests by 1. Returns an error if the max number of requests is exceeded. +func (c *ClusterRequestsCounter) StartRequest(max uint32) error { + // Note that during race, the limits could be exceeded. This is allowed: + // "Since the implementation is eventually consistent, races between threads + // may allow limits to be potentially exceeded." + // https://www.envoyproxy.io/docs/envoy/latest/intro/arch_overview/upstream/circuit_breaking#arch-overview-circuit-break. + if atomic.LoadUint32(&c.numRequests) >= max { + return fmt.Errorf("max requests %v exceeded on service %v", max, c.ClusterName) + } + atomic.AddUint32(&c.numRequests, 1) + return nil +} + +// EndRequest ends a request for a service, decrementing its number of requests +// by 1. +func (c *ClusterRequestsCounter) EndRequest() { + atomic.AddUint32(&c.numRequests, ^uint32(0)) +} + +// ClearCounterForTesting clears the counter for the service. Should be only +// used in tests. +func ClearCounterForTesting(clusterName, edsServiceName string) { + src.mu.Lock() + defer src.mu.Unlock() + k := clusterNameAndServiceName{ + clusterName: clusterName, + edsServcieName: edsServiceName, + } + c, ok := src.clusters[k] + if !ok { + return + } + c.numRequests = 0 +} + +// ClearAllCountersForTesting clears all the counters. Should be only used in +// tests. +func ClearAllCountersForTesting() { + src.mu.Lock() + defer src.mu.Unlock() + src.clusters = make(map[clusterNameAndServiceName]*ClusterRequestsCounter) +} diff --git a/xds/internal/client/requests_counter_test.go b/xds/internal/xdsclient/requests_counter_test.go similarity index 76% rename from xds/internal/client/requests_counter_test.go rename to xds/internal/xdsclient/requests_counter_test.go index fe532724d14..e2eeea774e2 100644 --- a/xds/internal/client/requests_counter_test.go +++ b/xds/internal/xdsclient/requests_counter_test.go @@ -16,7 +16,7 @@ * */ -package client +package xdsclient import ( "sync" @@ -24,6 +24,8 @@ import ( "testing" ) +const testService = "test-service-name" + type counterTest struct { name string maxRequests uint32 @@ -49,9 +51,9 @@ var tests = []counterTest{ }, } -func resetServiceRequestsCounter() { - src = &servicesRequestsCounter{ - services: make(map[string]*ServiceRequestsCounter), +func resetClusterRequestsCounter() { + src = &clusterRequestsCounter{ + clusters: make(map[clusterNameAndServiceName]*ClusterRequestsCounter), } } @@ -65,7 +67,7 @@ func testCounter(t *testing.T, test counterTest) { var successes, errors uint32 for i := 0; i < int(test.numRequests); i++ { go func() { - counter := GetServiceRequestsCounter(test.name) + counter := GetClusterRequestsCounter(test.name, testService) defer requestsDone.Done() err := counter.StartRequest(test.maxRequests) if err == nil { @@ -91,13 +93,17 @@ func testCounter(t *testing.T, test counterTest) { if test.expectedErrors == 0 && loadedError != nil { t.Errorf("error starting request: %v", loadedError.(error)) } - if successes != test.expectedSuccesses || errors != test.expectedErrors { + // We allow the limits to be exceeded during races. + // + // But we should never over-limit, so this test fails if there are less + // successes than expected. + if successes < test.expectedSuccesses || errors > test.expectedErrors { t.Errorf("unexpected number of (successes, errors), expected (%v, %v), encountered (%v, %v)", test.expectedSuccesses, test.expectedErrors, successes, errors) } } func (s) TestRequestsCounter(t *testing.T) { - defer resetServiceRequestsCounter() + defer resetClusterRequestsCounter() for _, test := range tests { t.Run(test.name, func(t *testing.T) { testCounter(t, test) @@ -105,18 +111,18 @@ func (s) TestRequestsCounter(t *testing.T) { } } -func (s) TestGetServiceRequestsCounter(t *testing.T) { - defer resetServiceRequestsCounter() +func (s) TestGetClusterRequestsCounter(t *testing.T) { + defer resetClusterRequestsCounter() for _, test := range tests { - counterA := GetServiceRequestsCounter(test.name) - counterB := GetServiceRequestsCounter(test.name) + counterA := GetClusterRequestsCounter(test.name, testService) + counterB := GetClusterRequestsCounter(test.name, testService) if counterA != counterB { t.Errorf("counter %v %v != counter %v %v", counterA, *counterA, counterB, *counterB) } } } -func startRequests(t *testing.T, n uint32, max uint32, counter *ServiceRequestsCounter) { +func startRequests(t *testing.T, n uint32, max uint32, counter *ClusterRequestsCounter) { for i := uint32(0); i < n; i++ { if err := counter.StartRequest(max); err != nil { t.Fatalf("error starting initial request: %v", err) @@ -125,11 +131,11 @@ func startRequests(t *testing.T, n uint32, max uint32, counter *ServiceRequestsC } func (s) TestSetMaxRequestsIncreased(t *testing.T) { - defer resetServiceRequestsCounter() - const serviceName string = "set-max-requests-increased" + defer resetClusterRequestsCounter() + const clusterName string = "set-max-requests-increased" var initialMax uint32 = 16 - counter := GetServiceRequestsCounter(serviceName) + counter := GetClusterRequestsCounter(clusterName, testService) startRequests(t, initialMax, initialMax, counter) if err := counter.StartRequest(initialMax); err == nil { t.Fatal("unexpected success on start request after max met") @@ -142,11 +148,11 @@ func (s) TestSetMaxRequestsIncreased(t *testing.T) { } func (s) TestSetMaxRequestsDecreased(t *testing.T) { - defer resetServiceRequestsCounter() - const serviceName string = "set-max-requests-decreased" + defer resetClusterRequestsCounter() + const clusterName string = "set-max-requests-decreased" var initialMax uint32 = 16 - counter := GetServiceRequestsCounter(serviceName) + counter := GetClusterRequestsCounter(clusterName, testService) startRequests(t, initialMax-1, initialMax, counter) newMax := initialMax - 1 diff --git a/xds/internal/xdsclient/singleton.go b/xds/internal/xdsclient/singleton.go new file mode 100644 index 00000000000..f045790e2a4 --- /dev/null +++ b/xds/internal/xdsclient/singleton.go @@ -0,0 +1,198 @@ +/* + * + * Copyright 2020 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package xdsclient + +import ( + "bytes" + "encoding/json" + "fmt" + "sync" + "time" + + "google.golang.org/grpc/xds/internal/xdsclient/bootstrap" +) + +const defaultWatchExpiryTimeout = 15 * time.Second + +// This is the Client returned by New(). It contains one client implementation, +// and maintains the refcount. +var singletonClient = &clientRefCounted{} + +// To override in tests. +var bootstrapNewConfig = bootstrap.NewConfig + +// clientRefCounted is ref-counted, and to be shared by the xds resolver and +// balancer implementations, across multiple ClientConns and Servers. +type clientRefCounted struct { + *clientImpl + + // This mu protects all the fields, including the embedded clientImpl above. + mu sync.Mutex + refCount int +} + +// New returns a new xdsClient configured by the bootstrap file specified in env +// variable GRPC_XDS_BOOTSTRAP or GRPC_XDS_BOOTSTRAP_CONFIG. +// +// The returned xdsClient is a singleton. This function creates the xds client +// if it doesn't already exist. +// +// Note that the first invocation of New() or NewWithConfig() sets the client +// singleton. The following calls will return the singleton xds client without +// checking or using the config. +func New() (XDSClient, error) { + // This cannot just return newRefCounted(), because in error cases, the + // returned nil is a typed nil (*clientRefCounted), which may cause nil + // checks fail. + c, err := newRefCounted() + if err != nil { + return nil, err + } + return c, nil +} + +func newRefCounted() (*clientRefCounted, error) { + singletonClient.mu.Lock() + defer singletonClient.mu.Unlock() + // If the client implementation was created, increment ref count and return + // the client. + if singletonClient.clientImpl != nil { + singletonClient.refCount++ + return singletonClient, nil + } + + // Create the new client implementation. + config, err := bootstrapNewConfig() + if err != nil { + return nil, fmt.Errorf("xds: failed to read bootstrap file: %v", err) + } + c, err := newWithConfig(config, defaultWatchExpiryTimeout) + if err != nil { + return nil, err + } + + singletonClient.clientImpl = c + singletonClient.refCount++ + return singletonClient, nil +} + +// NewWithConfig returns a new xdsClient configured by the given config. +// +// The returned xdsClient is a singleton. This function creates the xds client +// if it doesn't already exist. +// +// Note that the first invocation of New() or NewWithConfig() sets the client +// singleton. The following calls will return the singleton xds client without +// checking or using the config. +// +// This function is internal only, for c2p resolver and testing to use. DO NOT +// use this elsewhere. Use New() instead. +func NewWithConfig(config *bootstrap.Config) (XDSClient, error) { + singletonClient.mu.Lock() + defer singletonClient.mu.Unlock() + // If the client implementation was created, increment ref count and return + // the client. + if singletonClient.clientImpl != nil { + singletonClient.refCount++ + return singletonClient, nil + } + + // Create the new client implementation. + c, err := newWithConfig(config, defaultWatchExpiryTimeout) + if err != nil { + return nil, err + } + + singletonClient.clientImpl = c + singletonClient.refCount++ + return singletonClient, nil +} + +// Close closes the client. It does ref count of the xds client implementation, +// and closes the gRPC connection to the management server when ref count +// reaches 0. +func (c *clientRefCounted) Close() { + c.mu.Lock() + defer c.mu.Unlock() + c.refCount-- + if c.refCount == 0 { + c.clientImpl.Close() + // Set clientImpl back to nil. So if New() is called after this, a new + // implementation will be created. + c.clientImpl = nil + } +} + +// NewWithConfigForTesting is exported for testing only. +// +// Note that this function doesn't set the singleton, so that the testing states +// don't leak. +func NewWithConfigForTesting(config *bootstrap.Config, watchExpiryTimeout time.Duration) (XDSClient, error) { + cl, err := newWithConfig(config, watchExpiryTimeout) + if err != nil { + return nil, err + } + return &clientRefCounted{clientImpl: cl, refCount: 1}, nil +} + +// NewClientWithBootstrapContents returns an xds client for this config, +// separate from the global singleton. This should be used for testing +// purposes only. +func NewClientWithBootstrapContents(contents []byte) (XDSClient, error) { + // Normalize the contents + buf := bytes.Buffer{} + err := json.Indent(&buf, contents, "", "") + if err != nil { + return nil, fmt.Errorf("xds: error normalizing JSON: %v", err) + } + contents = bytes.TrimSpace(buf.Bytes()) + + clientsMu.Lock() + defer clientsMu.Unlock() + if c := clients[string(contents)]; c != nil { + c.mu.Lock() + // Since we don't remove the *Client from the map when it is closed, we + // need to recreate the impl if the ref count dropped to zero. + if c.refCount > 0 { + c.refCount++ + c.mu.Unlock() + return c, nil + } + c.mu.Unlock() + } + + bcfg, err := bootstrap.NewConfigFromContents(contents) + if err != nil { + return nil, fmt.Errorf("xds: error with bootstrap config: %v", err) + } + + cImpl, err := newWithConfig(bcfg, defaultWatchExpiryTimeout) + if err != nil { + return nil, err + } + + c := &clientRefCounted{clientImpl: cImpl, refCount: 1} + clients[string(contents)] = c + return c, nil +} + +var ( + clients = map[string]*clientRefCounted{} + clientsMu sync.Mutex +) diff --git a/xds/internal/client/transport_helper.go b/xds/internal/xdsclient/transport_helper.go similarity index 98% rename from xds/internal/client/transport_helper.go rename to xds/internal/xdsclient/transport_helper.go index b286a61d638..4c56daaf011 100644 --- a/xds/internal/client/transport_helper.go +++ b/xds/internal/xdsclient/transport_helper.go @@ -16,7 +16,7 @@ * */ -package client +package xdsclient import ( "context" @@ -24,7 +24,7 @@ import ( "time" "github.com/golang/protobuf/proto" - "google.golang.org/grpc/xds/internal/client/load" + "google.golang.org/grpc/xds/internal/xdsclient/load" "google.golang.org/grpc" "google.golang.org/grpc/internal/buffer" @@ -297,7 +297,7 @@ func (t *TransportHelper) sendExisting(stream grpc.ClientStream) bool { for rType, s := range t.watchMap { if err := t.vClient.SendRequest(stream, mapToSlice(s), rType, "", "", ""); err != nil { - t.logger.Errorf("ADS request failed: %v", err) + t.logger.Warningf("ADS request failed: %v", err) return false } } @@ -342,11 +342,12 @@ func (t *TransportHelper) recv(stream grpc.ClientStream) bool { } } -func mapToSlice(m map[string]bool) (ret []string) { +func mapToSlice(m map[string]bool) []string { + ret := make([]string, 0, len(m)) for i := range m { ret = append(ret, i) } - return + return ret } type watchAction struct { diff --git a/xds/internal/client/v2/ack_test.go b/xds/internal/xdsclient/v2/ack_test.go similarity index 99% rename from xds/internal/client/v2/ack_test.go rename to xds/internal/xdsclient/v2/ack_test.go index 813d8baa79d..d2f0605f6d0 100644 --- a/xds/internal/client/v2/ack_test.go +++ b/xds/internal/xdsclient/v2/ack_test.go @@ -31,9 +31,9 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/internal/testutils" - xdsclient "google.golang.org/grpc/xds/internal/client" "google.golang.org/grpc/xds/internal/testutils/fakeserver" "google.golang.org/grpc/xds/internal/version" + "google.golang.org/grpc/xds/internal/xdsclient" ) const ( diff --git a/xds/internal/client/v2/cds_test.go b/xds/internal/xdsclient/v2/cds_test.go similarity index 86% rename from xds/internal/client/v2/cds_test.go rename to xds/internal/xdsclient/v2/cds_test.go index c71b8453231..cef7563017c 100644 --- a/xds/internal/client/v2/cds_test.go +++ b/xds/internal/xdsclient/v2/cds_test.go @@ -24,10 +24,11 @@ import ( xdspb "github.com/envoyproxy/go-control-plane/envoy/api/v2" corepb "github.com/envoyproxy/go-control-plane/envoy/api/v2/core" - "github.com/golang/protobuf/ptypes" anypb "github.com/golang/protobuf/ptypes/any" - xdsclient "google.golang.org/grpc/xds/internal/client" + "github.com/google/go-cmp/cmp/cmpopts" + "google.golang.org/grpc/internal/testutils" "google.golang.org/grpc/xds/internal/version" + "google.golang.org/grpc/xds/internal/xdsclient" ) const ( @@ -63,8 +64,8 @@ var ( }, }, } - marshaledCluster1, _ = ptypes.MarshalAny(goodCluster1) - goodCluster2 = &xdspb.Cluster{ + marshaledCluster1 = testutils.MarshalAny(goodCluster1) + goodCluster2 = &xdspb.Cluster{ Name: goodClusterName2, ClusterDiscoveryType: &xdspb.Cluster_Type{Type: xdspb.Cluster_EDS}, EdsClusterConfig: &xdspb.Cluster_EdsClusterConfig{ @@ -77,8 +78,8 @@ var ( }, LbPolicy: xdspb.Cluster_ROUND_ROBIN, } - marshaledCluster2, _ = ptypes.MarshalAny(goodCluster2) - goodCDSResponse1 = &xdspb.DiscoveryResponse{ + marshaledCluster2 = testutils.MarshalAny(goodCluster2) + goodCDSResponse1 = &xdspb.DiscoveryResponse{ Resources: []*anypb.Any{ marshaledCluster1, }, @@ -100,7 +101,7 @@ func (s) TestCDSHandleResponse(t *testing.T) { name string cdsResponse *xdspb.DiscoveryResponse wantErr bool - wantUpdate map[string]xdsclient.ClusterUpdate + wantUpdate map[string]xdsclient.ClusterUpdateErrTuple wantUpdateMD xdsclient.UpdateMetadata wantUpdateErr bool }{ @@ -113,7 +114,7 @@ func (s) TestCDSHandleResponse(t *testing.T) { wantUpdateMD: xdsclient.UpdateMetadata{ Status: xdsclient.ServiceStatusNACKed, ErrState: &xdsclient.UpdateErrorMetadata{ - Err: errPlaceHolder, + Err: cmpopts.AnyError, }, }, wantUpdateErr: false, @@ -127,7 +128,7 @@ func (s) TestCDSHandleResponse(t *testing.T) { wantUpdateMD: xdsclient.UpdateMetadata{ Status: xdsclient.ServiceStatusNACKed, ErrState: &xdsclient.UpdateErrorMetadata{ - Err: errPlaceHolder, + Err: cmpopts.AnyError, }, }, wantUpdateErr: false, @@ -148,8 +149,8 @@ func (s) TestCDSHandleResponse(t *testing.T) { name: "one-uninteresting-cluster", cdsResponse: goodCDSResponse2, wantErr: false, - wantUpdate: map[string]xdsclient.ClusterUpdate{ - goodClusterName2: {ServiceName: serviceName2, Raw: marshaledCluster2}, + wantUpdate: map[string]xdsclient.ClusterUpdateErrTuple{ + goodClusterName2: {Update: xdsclient.ClusterUpdate{ClusterName: goodClusterName2, EDSServiceName: serviceName2, Raw: marshaledCluster2}}, }, wantUpdateMD: xdsclient.UpdateMetadata{ Status: xdsclient.ServiceStatusACKed, @@ -161,8 +162,8 @@ func (s) TestCDSHandleResponse(t *testing.T) { name: "one-good-cluster", cdsResponse: goodCDSResponse1, wantErr: false, - wantUpdate: map[string]xdsclient.ClusterUpdate{ - goodClusterName1: {ServiceName: serviceName1, EnableLRS: true, Raw: marshaledCluster1}, + wantUpdate: map[string]xdsclient.ClusterUpdateErrTuple{ + goodClusterName1: {Update: xdsclient.ClusterUpdate{ClusterName: goodClusterName1, EDSServiceName: serviceName1, EnableLRS: true, Raw: marshaledCluster1}}, }, wantUpdateMD: xdsclient.UpdateMetadata{ Status: xdsclient.ServiceStatusACKed, diff --git a/xds/internal/client/v2/client.go b/xds/internal/xdsclient/v2/client.go similarity index 82% rename from xds/internal/client/v2/client.go rename to xds/internal/xdsclient/v2/client.go index b6bc4908120..dc137f63e5f 100644 --- a/xds/internal/client/v2/client.go +++ b/xds/internal/xdsclient/v2/client.go @@ -27,8 +27,9 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/internal/grpclog" - xdsclient "google.golang.org/grpc/xds/internal/client" + "google.golang.org/grpc/internal/pretty" "google.golang.org/grpc/xds/internal/version" + "google.golang.org/grpc/xds/internal/xdsclient" v2xdspb "github.com/envoyproxy/go-control-plane/envoy/api/v2" v2corepb "github.com/envoyproxy/go-control-plane/envoy/api/v2/core" @@ -65,10 +66,11 @@ func newClient(cc *grpc.ClientConn, opts xdsclient.BuildOptions) (xdsclient.APIC return nil, fmt.Errorf("xds: unsupported Node proto type: %T, want %T", opts.NodeProto, (*v2corepb.Node)(nil)) } v2c := &client{ - cc: cc, - parent: opts.Parent, - nodeProto: nodeProto, - logger: opts.Logger, + cc: cc, + parent: opts.Parent, + nodeProto: nodeProto, + logger: opts.Logger, + updateValidator: opts.Validator, } v2c.ctx, v2c.cancelCtx = context.WithCancel(context.Background()) v2c.TransportHelper = xdsclient.NewTransportHelper(v2c, opts.Logger, opts.Backoff) @@ -89,8 +91,9 @@ type client struct { logger *grpclog.PrefixLogger // ClientConn to the xDS gRPC server. Owned by the parent xdsClient. - cc *grpc.ClientConn - nodeProto *v2corepb.Node + cc *grpc.ClientConn + nodeProto *v2corepb.Node + updateValidator xdsclient.UpdateValidatorFunc } func (v2c *client) NewStream(ctx context.Context) (grpc.ClientStream, error) { @@ -125,7 +128,7 @@ func (v2c *client) SendRequest(s grpc.ClientStream, resourceNames []string, rTyp if err := stream.Send(req); err != nil { return fmt.Errorf("xds: stream.Send(%+v) failed: %v", req, err) } - v2c.logger.Debugf("ADS request sent: %v", req) + v2c.logger.Debugf("ADS request sent: %v", pretty.ToJSON(req)) return nil } @@ -139,11 +142,11 @@ func (v2c *client) RecvResponse(s grpc.ClientStream) (proto.Message, error) { resp, err := stream.Recv() if err != nil { - // TODO: call watch callbacks with error when stream is broken. + v2c.parent.NewConnectionError(err) return nil, fmt.Errorf("xds: stream.Recv() failed: %v", err) } v2c.logger.Infof("ADS response received, type: %v", resp.GetTypeUrl()) - v2c.logger.Debugf("ADS response received: %v", resp) + v2c.logger.Debugf("ADS response received: %v", pretty.ToJSON(resp)) return resp, nil } @@ -185,7 +188,12 @@ func (v2c *client) HandleResponse(r proto.Message) (xdsclient.ResourceType, stri // server. On receipt of a good response, it also invokes the registered watcher // callback. func (v2c *client) handleLDSResponse(resp *v2xdspb.DiscoveryResponse) error { - update, md, err := xdsclient.UnmarshalListener(resp.GetVersionInfo(), resp.GetResources(), v2c.logger) + update, md, err := xdsclient.UnmarshalListener(&xdsclient.UnmarshalOptions{ + Version: resp.GetVersionInfo(), + Resources: resp.GetResources(), + Logger: v2c.logger, + UpdateValidator: v2c.updateValidator, + }) v2c.parent.NewListeners(update, md) return err } @@ -194,7 +202,12 @@ func (v2c *client) handleLDSResponse(resp *v2xdspb.DiscoveryResponse) error { // server. On receipt of a good response, it caches validated resources and also // invokes the registered watcher callback. func (v2c *client) handleRDSResponse(resp *v2xdspb.DiscoveryResponse) error { - update, md, err := xdsclient.UnmarshalRouteConfig(resp.GetVersionInfo(), resp.GetResources(), v2c.logger) + update, md, err := xdsclient.UnmarshalRouteConfig(&xdsclient.UnmarshalOptions{ + Version: resp.GetVersionInfo(), + Resources: resp.GetResources(), + Logger: v2c.logger, + UpdateValidator: v2c.updateValidator, + }) v2c.parent.NewRouteConfigs(update, md) return err } @@ -203,13 +216,23 @@ func (v2c *client) handleRDSResponse(resp *v2xdspb.DiscoveryResponse) error { // server. On receipt of a good response, it also invokes the registered watcher // callback. func (v2c *client) handleCDSResponse(resp *v2xdspb.DiscoveryResponse) error { - update, md, err := xdsclient.UnmarshalCluster(resp.GetVersionInfo(), resp.GetResources(), v2c.logger) + update, md, err := xdsclient.UnmarshalCluster(&xdsclient.UnmarshalOptions{ + Version: resp.GetVersionInfo(), + Resources: resp.GetResources(), + Logger: v2c.logger, + UpdateValidator: v2c.updateValidator, + }) v2c.parent.NewClusters(update, md) return err } func (v2c *client) handleEDSResponse(resp *v2xdspb.DiscoveryResponse) error { - update, md, err := xdsclient.UnmarshalEndpoints(resp.GetVersionInfo(), resp.GetResources(), v2c.logger) + update, md, err := xdsclient.UnmarshalEndpoints(&xdsclient.UnmarshalOptions{ + Version: resp.GetVersionInfo(), + Resources: resp.GetResources(), + Logger: v2c.logger, + UpdateValidator: v2c.updateValidator, + }) v2c.parent.NewEndpoints(update, md) return err } diff --git a/xds/internal/client/v2/client_test.go b/xds/internal/xdsclient/v2/client_test.go similarity index 88% rename from xds/internal/client/v2/client_test.go rename to xds/internal/xdsclient/v2/client_test.go index e770324e1b1..ed4322b0dc5 100644 --- a/xds/internal/client/v2/client_test.go +++ b/xds/internal/xdsclient/v2/client_test.go @@ -21,12 +21,10 @@ package v2 import ( "context" "errors" - "fmt" "testing" "time" "github.com/golang/protobuf/proto" - "github.com/golang/protobuf/ptypes" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "google.golang.org/grpc" @@ -36,9 +34,9 @@ import ( "google.golang.org/grpc/internal/testutils" "google.golang.org/grpc/resolver" "google.golang.org/grpc/resolver/manual" - xdsclient "google.golang.org/grpc/xds/internal/client" "google.golang.org/grpc/xds/internal/testutils/fakeserver" "google.golang.org/grpc/xds/internal/version" + "google.golang.org/grpc/xds/internal/xdsclient" "google.golang.org/protobuf/testing/protocmp" xdspb "github.com/envoyproxy/go-control-plane/envoy/api/v2" @@ -112,30 +110,24 @@ var ( }, }, } - marshaledConnMgr1, _ = proto.Marshal(goodHTTPConnManager1) - goodListener1 = &xdspb.Listener{ + marshaledConnMgr1 = testutils.MarshalAny(goodHTTPConnManager1) + goodListener1 = &xdspb.Listener{ Name: goodLDSTarget1, ApiListener: &listenerpb.ApiListener{ - ApiListener: &anypb.Any{ - TypeUrl: httpConnManagerURL, - Value: marshaledConnMgr1, - }, + ApiListener: marshaledConnMgr1, }, } - marshaledListener1, _ = ptypes.MarshalAny(goodListener1) - goodListener2 = &xdspb.Listener{ + marshaledListener1 = testutils.MarshalAny(goodListener1) + goodListener2 = &xdspb.Listener{ Name: goodLDSTarget2, ApiListener: &listenerpb.ApiListener{ - ApiListener: &anypb.Any{ - TypeUrl: httpConnManagerURL, - Value: marshaledConnMgr1, - }, + ApiListener: marshaledConnMgr1, }, } - marshaledListener2, _ = ptypes.MarshalAny(goodListener2) - noAPIListener = &xdspb.Listener{Name: goodLDSTarget1} - marshaledNoAPIListener, _ = proto.Marshal(noAPIListener) - badAPIListener2 = &xdspb.Listener{ + marshaledListener2 = testutils.MarshalAny(goodListener2) + noAPIListener = &xdspb.Listener{Name: goodLDSTarget1} + marshaledNoAPIListener = testutils.MarshalAny(noAPIListener) + badAPIListener2 = &xdspb.Listener{ Name: goodLDSTarget2, ApiListener: &listenerpb.ApiListener{ ApiListener: &anypb.Any{ @@ -168,13 +160,8 @@ var ( TypeUrl: version.V2ListenerURL, } badResourceTypeInLDSResponse = &xdspb.DiscoveryResponse{ - Resources: []*anypb.Any{ - { - TypeUrl: httpConnManagerURL, - Value: marshaledConnMgr1, - }, - }, - TypeUrl: version.V2ListenerURL, + Resources: []*anypb.Any{marshaledConnMgr1}, + TypeUrl: version.V2ListenerURL, } ldsResponseWithMultipleResources = &xdspb.DiscoveryResponse{ Resources: []*anypb.Any{ @@ -184,13 +171,8 @@ var ( TypeUrl: version.V2ListenerURL, } noAPIListenerLDSResponse = &xdspb.DiscoveryResponse{ - Resources: []*anypb.Any{ - { - TypeUrl: version.V2ListenerURL, - Value: marshaledNoAPIListener, - }, - }, - TypeUrl: version.V2ListenerURL, + Resources: []*anypb.Any{marshaledNoAPIListener}, + TypeUrl: version.V2ListenerURL, } goodBadUglyLDSResponse = &xdspb.DiscoveryResponse{ Resources: []*anypb.Any{ @@ -213,19 +195,14 @@ var ( TypeUrl: version.V2RouteConfigURL, } badResourceTypeInRDSResponse = &xdspb.DiscoveryResponse{ - Resources: []*anypb.Any{ - { - TypeUrl: httpConnManagerURL, - Value: marshaledConnMgr1, - }, - }, - TypeUrl: version.V2RouteConfigURL, + Resources: []*anypb.Any{marshaledConnMgr1}, + TypeUrl: version.V2RouteConfigURL, } noVirtualHostsRouteConfig = &xdspb.RouteConfiguration{ Name: goodRouteName1, } - marshaledNoVirtualHostsRouteConfig, _ = ptypes.MarshalAny(noVirtualHostsRouteConfig) - noVirtualHostsInRDSResponse = &xdspb.DiscoveryResponse{ + marshaledNoVirtualHostsRouteConfig = testutils.MarshalAny(noVirtualHostsRouteConfig) + noVirtualHostsInRDSResponse = &xdspb.DiscoveryResponse{ Resources: []*anypb.Any{ marshaledNoVirtualHostsRouteConfig, }, @@ -262,8 +239,8 @@ var ( }, }, } - marshaledGoodRouteConfig1, _ = ptypes.MarshalAny(goodRouteConfig1) - goodRouteConfig2 = &xdspb.RouteConfiguration{ + marshaledGoodRouteConfig1 = testutils.MarshalAny(goodRouteConfig1) + goodRouteConfig2 = &xdspb.RouteConfiguration{ Name: goodRouteName2, VirtualHosts: []*routepb.VirtualHost{ { @@ -294,8 +271,8 @@ var ( }, }, } - marshaledGoodRouteConfig2, _ = ptypes.MarshalAny(goodRouteConfig2) - goodRDSResponse1 = &xdspb.DiscoveryResponse{ + marshaledGoodRouteConfig2 = testutils.MarshalAny(goodRouteConfig2) + goodRDSResponse1 = &xdspb.DiscoveryResponse{ Resources: []*anypb.Any{ marshaledGoodRouteConfig1, }, @@ -307,9 +284,6 @@ var ( }, TypeUrl: version.V2RouteConfigURL, } - // An place holder error. When comparing UpdateErrorMetadata, we only check - // if error is nil, and don't compare error content. - errPlaceHolder = fmt.Errorf("err place holder") ) type watchHandleTestcase struct { @@ -327,7 +301,7 @@ type testUpdateReceiver struct { f func(rType xdsclient.ResourceType, d map[string]interface{}, md xdsclient.UpdateMetadata) } -func (t *testUpdateReceiver) NewListeners(d map[string]xdsclient.ListenerUpdate, metadata xdsclient.UpdateMetadata) { +func (t *testUpdateReceiver) NewListeners(d map[string]xdsclient.ListenerUpdateErrTuple, metadata xdsclient.UpdateMetadata) { dd := make(map[string]interface{}) for k, v := range d { dd[k] = v @@ -335,7 +309,7 @@ func (t *testUpdateReceiver) NewListeners(d map[string]xdsclient.ListenerUpdate, t.newUpdate(xdsclient.ListenerResource, dd, metadata) } -func (t *testUpdateReceiver) NewRouteConfigs(d map[string]xdsclient.RouteConfigUpdate, metadata xdsclient.UpdateMetadata) { +func (t *testUpdateReceiver) NewRouteConfigs(d map[string]xdsclient.RouteConfigUpdateErrTuple, metadata xdsclient.UpdateMetadata) { dd := make(map[string]interface{}) for k, v := range d { dd[k] = v @@ -343,7 +317,7 @@ func (t *testUpdateReceiver) NewRouteConfigs(d map[string]xdsclient.RouteConfigU t.newUpdate(xdsclient.RouteConfigResource, dd, metadata) } -func (t *testUpdateReceiver) NewClusters(d map[string]xdsclient.ClusterUpdate, metadata xdsclient.UpdateMetadata) { +func (t *testUpdateReceiver) NewClusters(d map[string]xdsclient.ClusterUpdateErrTuple, metadata xdsclient.UpdateMetadata) { dd := make(map[string]interface{}) for k, v := range d { dd[k] = v @@ -351,7 +325,7 @@ func (t *testUpdateReceiver) NewClusters(d map[string]xdsclient.ClusterUpdate, m t.newUpdate(xdsclient.ClusterResource, dd, metadata) } -func (t *testUpdateReceiver) NewEndpoints(d map[string]xdsclient.EndpointsUpdate, metadata xdsclient.UpdateMetadata) { +func (t *testUpdateReceiver) NewEndpoints(d map[string]xdsclient.EndpointsUpdateErrTuple, metadata xdsclient.UpdateMetadata) { dd := make(map[string]interface{}) for k, v := range d { dd[k] = v @@ -359,6 +333,8 @@ func (t *testUpdateReceiver) NewEndpoints(d map[string]xdsclient.EndpointsUpdate t.newUpdate(xdsclient.EndpointsResource, dd, metadata) } +func (t *testUpdateReceiver) NewConnectionError(error) {} + func (t *testUpdateReceiver) newUpdate(rType xdsclient.ResourceType, d map[string]interface{}, metadata xdsclient.UpdateMetadata) { t.f(rType, d, metadata) } @@ -387,27 +363,27 @@ func testWatchHandle(t *testing.T, test *watchHandleTestcase) { if rType == test.rType { switch test.rType { case xdsclient.ListenerResource: - dd := make(map[string]xdsclient.ListenerUpdate) + dd := make(map[string]xdsclient.ListenerUpdateErrTuple) for n, u := range d { - dd[n] = u.(xdsclient.ListenerUpdate) + dd[n] = u.(xdsclient.ListenerUpdateErrTuple) } gotUpdateCh.Send(updateErr{dd, md, nil}) case xdsclient.RouteConfigResource: - dd := make(map[string]xdsclient.RouteConfigUpdate) + dd := make(map[string]xdsclient.RouteConfigUpdateErrTuple) for n, u := range d { - dd[n] = u.(xdsclient.RouteConfigUpdate) + dd[n] = u.(xdsclient.RouteConfigUpdateErrTuple) } gotUpdateCh.Send(updateErr{dd, md, nil}) case xdsclient.ClusterResource: - dd := make(map[string]xdsclient.ClusterUpdate) + dd := make(map[string]xdsclient.ClusterUpdateErrTuple) for n, u := range d { - dd[n] = u.(xdsclient.ClusterUpdate) + dd[n] = u.(xdsclient.ClusterUpdateErrTuple) } gotUpdateCh.Send(updateErr{dd, md, nil}) case xdsclient.EndpointsResource: - dd := make(map[string]xdsclient.EndpointsUpdate) + dd := make(map[string]xdsclient.EndpointsUpdateErrTuple) for n, u := range d { - dd[n] = u.(xdsclient.EndpointsUpdate) + dd[n] = u.(xdsclient.EndpointsUpdateErrTuple) } gotUpdateCh.Send(updateErr{dd, md, nil}) } @@ -457,7 +433,7 @@ func testWatchHandle(t *testing.T, test *watchHandleTestcase) { cmpopts.EquateEmpty(), protocmp.Transform(), cmpopts.IgnoreFields(xdsclient.UpdateMetadata{}, "Timestamp"), cmpopts.IgnoreFields(xdsclient.UpdateErrorMetadata{}, "Timestamp"), - cmp.Comparer(func(x, y error) bool { return (x == nil) == (y == nil) }), + cmp.FilterValues(func(x, y error) bool { return true }, cmpopts.EquateErrors()), } uErr, err := gotUpdateCh.Receive(ctx) if err == context.DeadlineExceeded { @@ -546,7 +522,7 @@ func (s) TestV2ClientBackoffAfterRecvError(t *testing.T) { fakeServer.XDSResponseChan <- &fakeserver.Response{Err: errors.New("RPC error")} t.Log("Bad LDS response pushed to fakeServer...") - timer := time.NewTimer(defaultTestShortTimeout) + timer := time.NewTimer(defaultTestTimeout) select { case <-timer.C: t.Fatal("Timeout when expecting LDS update") @@ -688,7 +664,7 @@ func (s) TestV2ClientWatchWithoutStream(t *testing.T) { if v, err := callbackCh.Receive(ctx); err != nil { t.Fatal("Timeout when expecting LDS update") - } else if _, ok := v.(xdsclient.ListenerUpdate); !ok { + } else if _, ok := v.(xdsclient.ListenerUpdateErrTuple); !ok { t.Fatalf("Expect an LDS update from watcher, got %v", v) } } diff --git a/xds/internal/client/v2/eds_test.go b/xds/internal/xdsclient/v2/eds_test.go similarity index 85% rename from xds/internal/client/v2/eds_test.go rename to xds/internal/xdsclient/v2/eds_test.go index 0990e7ebae0..8176b6dfb93 100644 --- a/xds/internal/client/v2/eds_test.go +++ b/xds/internal/xdsclient/v2/eds_test.go @@ -23,12 +23,13 @@ import ( "time" v2xdspb "github.com/envoyproxy/go-control-plane/envoy/api/v2" - "github.com/golang/protobuf/ptypes" anypb "github.com/golang/protobuf/ptypes/any" + "github.com/google/go-cmp/cmp/cmpopts" + "google.golang.org/grpc/internal/testutils" "google.golang.org/grpc/xds/internal" - xdsclient "google.golang.org/grpc/xds/internal/client" - "google.golang.org/grpc/xds/internal/testutils" + xtestutils "google.golang.org/grpc/xds/internal/testutils" "google.golang.org/grpc/xds/internal/version" + "google.golang.org/grpc/xds/internal/xdsclient" ) var ( @@ -42,20 +43,14 @@ var ( TypeUrl: version.V2EndpointsURL, } badResourceTypeInEDSResponse = &v2xdspb.DiscoveryResponse{ - Resources: []*anypb.Any{ - { - TypeUrl: httpConnManagerURL, - Value: marshaledConnMgr1, - }, - }, - TypeUrl: version.V2EndpointsURL, + Resources: []*anypb.Any{marshaledConnMgr1}, + TypeUrl: version.V2EndpointsURL, } marshaledGoodCLA1 = func() *anypb.Any { - clab0 := testutils.NewClusterLoadAssignmentBuilder(goodEDSName, nil) + clab0 := xtestutils.NewClusterLoadAssignmentBuilder(goodEDSName, nil) clab0.AddLocality("locality-1", 1, 1, []string{"addr1:314"}, nil) clab0.AddLocality("locality-2", 1, 0, []string{"addr2:159"}, nil) - a, _ := ptypes.MarshalAny(clab0.Build()) - return a + return testutils.MarshalAny(clab0.Build()) }() goodEDSResponse1 = &v2xdspb.DiscoveryResponse{ Resources: []*anypb.Any{ @@ -64,10 +59,9 @@ var ( TypeUrl: version.V2EndpointsURL, } marshaledGoodCLA2 = func() *anypb.Any { - clab0 := testutils.NewClusterLoadAssignmentBuilder("not-goodEDSName", nil) + clab0 := xtestutils.NewClusterLoadAssignmentBuilder("not-goodEDSName", nil) clab0.AddLocality("locality-1", 1, 0, []string{"addr1:314"}, nil) - a, _ := ptypes.MarshalAny(clab0.Build()) - return a + return testutils.MarshalAny(clab0.Build()) }() goodEDSResponse2 = &v2xdspb.DiscoveryResponse{ Resources: []*anypb.Any{ @@ -82,7 +76,7 @@ func (s) TestEDSHandleResponse(t *testing.T) { name string edsResponse *v2xdspb.DiscoveryResponse wantErr bool - wantUpdate map[string]xdsclient.EndpointsUpdate + wantUpdate map[string]xdsclient.EndpointsUpdateErrTuple wantUpdateMD xdsclient.UpdateMetadata wantUpdateErr bool }{ @@ -95,7 +89,7 @@ func (s) TestEDSHandleResponse(t *testing.T) { wantUpdateMD: xdsclient.UpdateMetadata{ Status: xdsclient.ServiceStatusNACKed, ErrState: &xdsclient.UpdateErrorMetadata{ - Err: errPlaceHolder, + Err: cmpopts.AnyError, }, }, wantUpdateErr: false, @@ -109,7 +103,7 @@ func (s) TestEDSHandleResponse(t *testing.T) { wantUpdateMD: xdsclient.UpdateMetadata{ Status: xdsclient.ServiceStatusNACKed, ErrState: &xdsclient.UpdateErrorMetadata{ - Err: errPlaceHolder, + Err: cmpopts.AnyError, }, }, wantUpdateErr: false, @@ -119,8 +113,8 @@ func (s) TestEDSHandleResponse(t *testing.T) { name: "one-uninterestring-assignment", edsResponse: goodEDSResponse2, wantErr: false, - wantUpdate: map[string]xdsclient.EndpointsUpdate{ - "not-goodEDSName": { + wantUpdate: map[string]xdsclient.EndpointsUpdateErrTuple{ + "not-goodEDSName": {Update: xdsclient.EndpointsUpdate{ Localities: []xdsclient.Locality{ { Endpoints: []xdsclient.Endpoint{{Address: "addr1:314"}}, @@ -130,7 +124,7 @@ func (s) TestEDSHandleResponse(t *testing.T) { }, }, Raw: marshaledGoodCLA2, - }, + }}, }, wantUpdateMD: xdsclient.UpdateMetadata{ Status: xdsclient.ServiceStatusACKed, @@ -142,8 +136,8 @@ func (s) TestEDSHandleResponse(t *testing.T) { name: "one-good-assignment", edsResponse: goodEDSResponse1, wantErr: false, - wantUpdate: map[string]xdsclient.EndpointsUpdate{ - goodEDSName: { + wantUpdate: map[string]xdsclient.EndpointsUpdateErrTuple{ + goodEDSName: {Update: xdsclient.EndpointsUpdate{ Localities: []xdsclient.Locality{ { Endpoints: []xdsclient.Endpoint{{Address: "addr1:314"}}, @@ -159,7 +153,7 @@ func (s) TestEDSHandleResponse(t *testing.T) { }, }, Raw: marshaledGoodCLA1, - }, + }}, }, wantUpdateMD: xdsclient.UpdateMetadata{ Status: xdsclient.ServiceStatusACKed, diff --git a/xds/internal/client/v2/lds_test.go b/xds/internal/xdsclient/v2/lds_test.go similarity index 81% rename from xds/internal/client/v2/lds_test.go rename to xds/internal/xdsclient/v2/lds_test.go index 1f4c980fae5..a0600550095 100644 --- a/xds/internal/client/v2/lds_test.go +++ b/xds/internal/xdsclient/v2/lds_test.go @@ -23,8 +23,9 @@ import ( "time" v2xdspb "github.com/envoyproxy/go-control-plane/envoy/api/v2" + "github.com/google/go-cmp/cmp/cmpopts" - xdsclient "google.golang.org/grpc/xds/internal/client" + "google.golang.org/grpc/xds/internal/xdsclient" ) // TestLDSHandleResponse starts a fake xDS server, makes a ClientConn to it, @@ -35,7 +36,7 @@ func (s) TestLDSHandleResponse(t *testing.T) { name string ldsResponse *v2xdspb.DiscoveryResponse wantErr bool - wantUpdate map[string]xdsclient.ListenerUpdate + wantUpdate map[string]xdsclient.ListenerUpdateErrTuple wantUpdateMD xdsclient.UpdateMetadata wantUpdateErr bool }{ @@ -48,7 +49,7 @@ func (s) TestLDSHandleResponse(t *testing.T) { wantUpdateMD: xdsclient.UpdateMetadata{ Status: xdsclient.ServiceStatusNACKed, ErrState: &xdsclient.UpdateErrorMetadata{ - Err: errPlaceHolder, + Err: cmpopts.AnyError, }, }, wantUpdateErr: false, @@ -62,7 +63,7 @@ func (s) TestLDSHandleResponse(t *testing.T) { wantUpdateMD: xdsclient.UpdateMetadata{ Status: xdsclient.ServiceStatusNACKed, ErrState: &xdsclient.UpdateErrorMetadata{ - Err: errPlaceHolder, + Err: cmpopts.AnyError, }, }, wantUpdateErr: false, @@ -74,13 +75,13 @@ func (s) TestLDSHandleResponse(t *testing.T) { name: "no-apiListener-in-response", ldsResponse: noAPIListenerLDSResponse, wantErr: true, - wantUpdate: map[string]xdsclient.ListenerUpdate{ - goodLDSTarget1: {}, + wantUpdate: map[string]xdsclient.ListenerUpdateErrTuple{ + goodLDSTarget1: {Err: cmpopts.AnyError}, }, wantUpdateMD: xdsclient.UpdateMetadata{ Status: xdsclient.ServiceStatusNACKed, ErrState: &xdsclient.UpdateErrorMetadata{ - Err: errPlaceHolder, + Err: cmpopts.AnyError, }, }, wantUpdateErr: false, @@ -90,8 +91,8 @@ func (s) TestLDSHandleResponse(t *testing.T) { name: "one-good-listener", ldsResponse: goodLDSResponse1, wantErr: false, - wantUpdate: map[string]xdsclient.ListenerUpdate{ - goodLDSTarget1: {RouteConfigName: goodRouteName1, Raw: marshaledListener1}, + wantUpdate: map[string]xdsclient.ListenerUpdateErrTuple{ + goodLDSTarget1: {Update: xdsclient.ListenerUpdate{RouteConfigName: goodRouteName1, Raw: marshaledListener1}}, }, wantUpdateMD: xdsclient.UpdateMetadata{ Status: xdsclient.ServiceStatusACKed, @@ -104,9 +105,9 @@ func (s) TestLDSHandleResponse(t *testing.T) { name: "multiple-good-listener", ldsResponse: ldsResponseWithMultipleResources, wantErr: false, - wantUpdate: map[string]xdsclient.ListenerUpdate{ - goodLDSTarget1: {RouteConfigName: goodRouteName1, Raw: marshaledListener1}, - goodLDSTarget2: {RouteConfigName: goodRouteName1, Raw: marshaledListener2}, + wantUpdate: map[string]xdsclient.ListenerUpdateErrTuple{ + goodLDSTarget1: {Update: xdsclient.ListenerUpdate{RouteConfigName: goodRouteName1, Raw: marshaledListener1}}, + goodLDSTarget2: {Update: xdsclient.ListenerUpdate{RouteConfigName: goodRouteName1, Raw: marshaledListener2}}, }, wantUpdateMD: xdsclient.UpdateMetadata{ Status: xdsclient.ServiceStatusACKed, @@ -120,14 +121,14 @@ func (s) TestLDSHandleResponse(t *testing.T) { name: "good-bad-ugly-listeners", ldsResponse: goodBadUglyLDSResponse, wantErr: true, - wantUpdate: map[string]xdsclient.ListenerUpdate{ - goodLDSTarget1: {RouteConfigName: goodRouteName1, Raw: marshaledListener1}, - goodLDSTarget2: {}, + wantUpdate: map[string]xdsclient.ListenerUpdateErrTuple{ + goodLDSTarget1: {Update: xdsclient.ListenerUpdate{RouteConfigName: goodRouteName1, Raw: marshaledListener1}}, + goodLDSTarget2: {Err: cmpopts.AnyError}, }, wantUpdateMD: xdsclient.UpdateMetadata{ Status: xdsclient.ServiceStatusNACKed, ErrState: &xdsclient.UpdateErrorMetadata{ - Err: errPlaceHolder, + Err: cmpopts.AnyError, }, }, wantUpdateErr: false, @@ -137,8 +138,8 @@ func (s) TestLDSHandleResponse(t *testing.T) { name: "one-uninteresting-listener", ldsResponse: goodLDSResponse2, wantErr: false, - wantUpdate: map[string]xdsclient.ListenerUpdate{ - goodLDSTarget2: {RouteConfigName: goodRouteName1, Raw: marshaledListener2}, + wantUpdate: map[string]xdsclient.ListenerUpdateErrTuple{ + goodLDSTarget2: {Update: xdsclient.ListenerUpdate{RouteConfigName: goodRouteName1, Raw: marshaledListener2}}, }, wantUpdateMD: xdsclient.UpdateMetadata{ Status: xdsclient.ServiceStatusACKed, diff --git a/xds/internal/client/v2/loadreport.go b/xds/internal/xdsclient/v2/loadreport.go similarity index 88% rename from xds/internal/client/v2/loadreport.go rename to xds/internal/xdsclient/v2/loadreport.go index 69405fcd9ad..f0034e21c35 100644 --- a/xds/internal/client/v2/loadreport.go +++ b/xds/internal/xdsclient/v2/loadreport.go @@ -26,7 +26,8 @@ import ( "github.com/golang/protobuf/proto" "github.com/golang/protobuf/ptypes" - "google.golang.org/grpc/xds/internal/client/load" + "google.golang.org/grpc/internal/pretty" + "google.golang.org/grpc/xds/internal/xdsclient/load" v2corepb "github.com/envoyproxy/go-control-plane/envoy/api/v2/core" v2endpointpb "github.com/envoyproxy/go-control-plane/envoy/api/v2/endpoint" @@ -57,7 +58,7 @@ func (v2c *client) SendFirstLoadStatsRequest(s grpc.ClientStream) error { node.ClientFeatures = append(node.ClientFeatures, clientFeatureLRSSendAllClusters) req := &lrspb.LoadStatsRequest{Node: node} - v2c.logger.Infof("lrs: sending init LoadStatsRequest: %v", req) + v2c.logger.Infof("lrs: sending init LoadStatsRequest: %v", pretty.ToJSON(req)) return stream.Send(req) } @@ -71,7 +72,7 @@ func (v2c *client) HandleLoadStatsResponse(s grpc.ClientStream) ([]string, time. if err != nil { return nil, 0, fmt.Errorf("lrs: failed to receive first response: %v", err) } - v2c.logger.Infof("lrs: received first LoadStatsResponse: %+v", resp) + v2c.logger.Infof("lrs: received first LoadStatsResponse: %+v", pretty.ToJSON(resp)) interval, err := ptypes.Duration(resp.GetLoadReportingInterval()) if err != nil { @@ -98,24 +99,22 @@ func (v2c *client) SendLoadStatsRequest(s grpc.ClientStream, loads []*load.Data) return fmt.Errorf("lrs: Attempt to send request on unsupported stream type: %T", s) } - var clusterStats []*v2endpointpb.ClusterStats + clusterStats := make([]*v2endpointpb.ClusterStats, 0, len(loads)) for _, sd := range loads { - var ( - droppedReqs []*v2endpointpb.ClusterStats_DroppedRequests - localityStats []*v2endpointpb.UpstreamLocalityStats - ) + droppedReqs := make([]*v2endpointpb.ClusterStats_DroppedRequests, 0, len(sd.Drops)) for category, count := range sd.Drops { droppedReqs = append(droppedReqs, &v2endpointpb.ClusterStats_DroppedRequests{ Category: category, DroppedCount: count, }) } + localityStats := make([]*v2endpointpb.UpstreamLocalityStats, 0, len(sd.LocalityStats)) for l, localityData := range sd.LocalityStats { lid, err := internal.LocalityIDFromString(l) if err != nil { return err } - var loadMetricStats []*v2endpointpb.EndpointLoadMetricStats + loadMetricStats := make([]*v2endpointpb.EndpointLoadMetricStats, 0, len(localityData.LoadStats)) for name, loadData := range localityData.LoadStats { loadMetricStats = append(loadMetricStats, &v2endpointpb.EndpointLoadMetricStats{ MetricName: name, @@ -149,6 +148,6 @@ func (v2c *client) SendLoadStatsRequest(s grpc.ClientStream, loads []*load.Data) } req := &lrspb.LoadStatsRequest{ClusterStats: clusterStats} - v2c.logger.Infof("lrs: sending LRS loads: %+v", req) + v2c.logger.Infof("lrs: sending LRS loads: %+v", pretty.ToJSON(req)) return stream.Send(req) } diff --git a/xds/internal/client/v2/rds_test.go b/xds/internal/xdsclient/v2/rds_test.go similarity index 77% rename from xds/internal/client/v2/rds_test.go rename to xds/internal/xdsclient/v2/rds_test.go index dd145158b8a..3389f053946 100644 --- a/xds/internal/client/v2/rds_test.go +++ b/xds/internal/xdsclient/v2/rds_test.go @@ -24,9 +24,10 @@ import ( "time" xdspb "github.com/envoyproxy/go-control-plane/envoy/api/v2" + "github.com/google/go-cmp/cmp/cmpopts" - xdsclient "google.golang.org/grpc/xds/internal/client" "google.golang.org/grpc/xds/internal/testutils/fakeserver" + "google.golang.org/grpc/xds/internal/xdsclient" ) // doLDS makes a LDS watch, and waits for the response and ack to finish. @@ -49,7 +50,7 @@ func (s) TestRDSHandleResponseWithRouting(t *testing.T) { name string rdsResponse *xdspb.DiscoveryResponse wantErr bool - wantUpdate map[string]xdsclient.RouteConfigUpdate + wantUpdate map[string]xdsclient.RouteConfigUpdateErrTuple wantUpdateMD xdsclient.UpdateMetadata wantUpdateErr bool }{ @@ -62,7 +63,7 @@ func (s) TestRDSHandleResponseWithRouting(t *testing.T) { wantUpdateMD: xdsclient.UpdateMetadata{ Status: xdsclient.ServiceStatusNACKed, ErrState: &xdsclient.UpdateErrorMetadata{ - Err: errPlaceHolder, + Err: cmpopts.AnyError, }, }, wantUpdateErr: false, @@ -76,23 +77,23 @@ func (s) TestRDSHandleResponseWithRouting(t *testing.T) { wantUpdateMD: xdsclient.UpdateMetadata{ Status: xdsclient.ServiceStatusNACKed, ErrState: &xdsclient.UpdateErrorMetadata{ - Err: errPlaceHolder, + Err: cmpopts.AnyError, }, }, wantUpdateErr: false, }, - // No VirtualHosts in the response. Just one test case here for a bad + // No virtualHosts in the response. Just one test case here for a bad // RouteConfiguration, since the others are covered in // TestGetClusterFromRouteConfiguration. { name: "no-virtual-hosts-in-response", rdsResponse: noVirtualHostsInRDSResponse, wantErr: false, - wantUpdate: map[string]xdsclient.RouteConfigUpdate{ - goodRouteName1: { + wantUpdate: map[string]xdsclient.RouteConfigUpdateErrTuple{ + goodRouteName1: {Update: xdsclient.RouteConfigUpdate{ VirtualHosts: nil, Raw: marshaledNoVirtualHostsRouteConfig, - }, + }}, }, wantUpdateMD: xdsclient.UpdateMetadata{ Status: xdsclient.ServiceStatusACKed, @@ -104,20 +105,25 @@ func (s) TestRDSHandleResponseWithRouting(t *testing.T) { name: "one-uninteresting-route-config", rdsResponse: goodRDSResponse2, wantErr: false, - wantUpdate: map[string]xdsclient.RouteConfigUpdate{ - goodRouteName2: { + wantUpdate: map[string]xdsclient.RouteConfigUpdateErrTuple{ + goodRouteName2: {Update: xdsclient.RouteConfigUpdate{ VirtualHosts: []*xdsclient.VirtualHost{ { Domains: []string{uninterestingDomain}, - Routes: []*xdsclient.Route{{Prefix: newStringP(""), WeightedClusters: map[string]xdsclient.WeightedCluster{uninterestingClusterName: {Weight: 1}}}}, + Routes: []*xdsclient.Route{{Prefix: newStringP(""), + WeightedClusters: map[string]xdsclient.WeightedCluster{uninterestingClusterName: {Weight: 1}}, + RouteAction: xdsclient.RouteActionRoute}}, }, { Domains: []string{goodLDSTarget1}, - Routes: []*xdsclient.Route{{Prefix: newStringP(""), WeightedClusters: map[string]xdsclient.WeightedCluster{goodClusterName2: {Weight: 1}}}}, + Routes: []*xdsclient.Route{{ + Prefix: newStringP(""), + WeightedClusters: map[string]xdsclient.WeightedCluster{goodClusterName2: {Weight: 1}}, + RouteAction: xdsclient.RouteActionRoute}}, }, }, Raw: marshaledGoodRouteConfig2, - }, + }}, }, wantUpdateMD: xdsclient.UpdateMetadata{ Status: xdsclient.ServiceStatusACKed, @@ -129,20 +135,25 @@ func (s) TestRDSHandleResponseWithRouting(t *testing.T) { name: "one-good-route-config", rdsResponse: goodRDSResponse1, wantErr: false, - wantUpdate: map[string]xdsclient.RouteConfigUpdate{ - goodRouteName1: { + wantUpdate: map[string]xdsclient.RouteConfigUpdateErrTuple{ + goodRouteName1: {Update: xdsclient.RouteConfigUpdate{ VirtualHosts: []*xdsclient.VirtualHost{ { Domains: []string{uninterestingDomain}, - Routes: []*xdsclient.Route{{Prefix: newStringP(""), WeightedClusters: map[string]xdsclient.WeightedCluster{uninterestingClusterName: {Weight: 1}}}}, + Routes: []*xdsclient.Route{{ + Prefix: newStringP(""), + WeightedClusters: map[string]xdsclient.WeightedCluster{uninterestingClusterName: {Weight: 1}}, + RouteAction: xdsclient.RouteActionRoute}}, }, { Domains: []string{goodLDSTarget1}, - Routes: []*xdsclient.Route{{Prefix: newStringP(""), WeightedClusters: map[string]xdsclient.WeightedCluster{goodClusterName1: {Weight: 1}}}}, + Routes: []*xdsclient.Route{{Prefix: newStringP(""), + WeightedClusters: map[string]xdsclient.WeightedCluster{goodClusterName1: {Weight: 1}}, + RouteAction: xdsclient.RouteActionRoute}}, }, }, Raw: marshaledGoodRouteConfig1, - }, + }}, }, wantUpdateMD: xdsclient.UpdateMetadata{ Status: xdsclient.ServiceStatusACKed, diff --git a/xds/internal/client/v3/client.go b/xds/internal/xdsclient/v3/client.go similarity index 82% rename from xds/internal/client/v3/client.go rename to xds/internal/xdsclient/v3/client.go index 55cae56d8cc..827c06b741b 100644 --- a/xds/internal/client/v3/client.go +++ b/xds/internal/xdsclient/v3/client.go @@ -28,8 +28,9 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/internal/grpclog" - xdsclient "google.golang.org/grpc/xds/internal/client" + "google.golang.org/grpc/internal/pretty" "google.golang.org/grpc/xds/internal/version" + "google.golang.org/grpc/xds/internal/xdsclient" v3corepb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" v3adsgrpc "github.com/envoyproxy/go-control-plane/envoy/service/discovery/v3" @@ -65,10 +66,11 @@ func newClient(cc *grpc.ClientConn, opts xdsclient.BuildOptions) (xdsclient.APIC return nil, fmt.Errorf("xds: unsupported Node proto type: %T, want %T", opts.NodeProto, v3corepb.Node{}) } v3c := &client{ - cc: cc, - parent: opts.Parent, - nodeProto: nodeProto, - logger: opts.Logger, + cc: cc, + parent: opts.Parent, + nodeProto: nodeProto, + logger: opts.Logger, + updateValidator: opts.Validator, } v3c.ctx, v3c.cancelCtx = context.WithCancel(context.Background()) v3c.TransportHelper = xdsclient.NewTransportHelper(v3c, opts.Logger, opts.Backoff) @@ -89,8 +91,9 @@ type client struct { logger *grpclog.PrefixLogger // ClientConn to the xDS gRPC server. Owned by the parent xdsClient. - cc *grpc.ClientConn - nodeProto *v3corepb.Node + cc *grpc.ClientConn + nodeProto *v3corepb.Node + updateValidator xdsclient.UpdateValidatorFunc } func (v3c *client) NewStream(ctx context.Context) (grpc.ClientStream, error) { @@ -125,7 +128,7 @@ func (v3c *client) SendRequest(s grpc.ClientStream, resourceNames []string, rTyp if err := stream.Send(req); err != nil { return fmt.Errorf("xds: stream.Send(%+v) failed: %v", req, err) } - v3c.logger.Debugf("ADS request sent: %v", req) + v3c.logger.Debugf("ADS request sent: %v", pretty.ToJSON(req)) return nil } @@ -139,11 +142,11 @@ func (v3c *client) RecvResponse(s grpc.ClientStream) (proto.Message, error) { resp, err := stream.Recv() if err != nil { - // TODO: call watch callbacks with error when stream is broken. + v3c.parent.NewConnectionError(err) return nil, fmt.Errorf("xds: stream.Recv() failed: %v", err) } v3c.logger.Infof("ADS response received, type: %v", resp.GetTypeUrl()) - v3c.logger.Debugf("ADS response received: %+v", resp) + v3c.logger.Debugf("ADS response received: %+v", pretty.ToJSON(resp)) return resp, nil } @@ -185,7 +188,12 @@ func (v3c *client) HandleResponse(r proto.Message) (xdsclient.ResourceType, stri // server. On receipt of a good response, it also invokes the registered watcher // callback. func (v3c *client) handleLDSResponse(resp *v3discoverypb.DiscoveryResponse) error { - update, md, err := xdsclient.UnmarshalListener(resp.GetVersionInfo(), resp.GetResources(), v3c.logger) + update, md, err := xdsclient.UnmarshalListener(&xdsclient.UnmarshalOptions{ + Version: resp.GetVersionInfo(), + Resources: resp.GetResources(), + Logger: v3c.logger, + UpdateValidator: v3c.updateValidator, + }) v3c.parent.NewListeners(update, md) return err } @@ -194,7 +202,12 @@ func (v3c *client) handleLDSResponse(resp *v3discoverypb.DiscoveryResponse) erro // server. On receipt of a good response, it caches validated resources and also // invokes the registered watcher callback. func (v3c *client) handleRDSResponse(resp *v3discoverypb.DiscoveryResponse) error { - update, md, err := xdsclient.UnmarshalRouteConfig(resp.GetVersionInfo(), resp.GetResources(), v3c.logger) + update, md, err := xdsclient.UnmarshalRouteConfig(&xdsclient.UnmarshalOptions{ + Version: resp.GetVersionInfo(), + Resources: resp.GetResources(), + Logger: v3c.logger, + UpdateValidator: v3c.updateValidator, + }) v3c.parent.NewRouteConfigs(update, md) return err } @@ -203,13 +216,23 @@ func (v3c *client) handleRDSResponse(resp *v3discoverypb.DiscoveryResponse) erro // server. On receipt of a good response, it also invokes the registered watcher // callback. func (v3c *client) handleCDSResponse(resp *v3discoverypb.DiscoveryResponse) error { - update, md, err := xdsclient.UnmarshalCluster(resp.GetVersionInfo(), resp.GetResources(), v3c.logger) + update, md, err := xdsclient.UnmarshalCluster(&xdsclient.UnmarshalOptions{ + Version: resp.GetVersionInfo(), + Resources: resp.GetResources(), + Logger: v3c.logger, + UpdateValidator: v3c.updateValidator, + }) v3c.parent.NewClusters(update, md) return err } func (v3c *client) handleEDSResponse(resp *v3discoverypb.DiscoveryResponse) error { - update, md, err := xdsclient.UnmarshalEndpoints(resp.GetVersionInfo(), resp.GetResources(), v3c.logger) + update, md, err := xdsclient.UnmarshalEndpoints(&xdsclient.UnmarshalOptions{ + Version: resp.GetVersionInfo(), + Resources: resp.GetResources(), + Logger: v3c.logger, + UpdateValidator: v3c.updateValidator, + }) v3c.parent.NewEndpoints(update, md) return err } diff --git a/xds/internal/client/v3/loadreport.go b/xds/internal/xdsclient/v3/loadreport.go similarity index 88% rename from xds/internal/client/v3/loadreport.go rename to xds/internal/xdsclient/v3/loadreport.go index 74e18632aa0..8cdb5476fbb 100644 --- a/xds/internal/client/v3/loadreport.go +++ b/xds/internal/xdsclient/v3/loadreport.go @@ -26,7 +26,8 @@ import ( "github.com/golang/protobuf/proto" "github.com/golang/protobuf/ptypes" - "google.golang.org/grpc/xds/internal/client/load" + "google.golang.org/grpc/internal/pretty" + "google.golang.org/grpc/xds/internal/xdsclient/load" v3corepb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" v3endpointpb "github.com/envoyproxy/go-control-plane/envoy/config/endpoint/v3" @@ -57,7 +58,7 @@ func (v3c *client) SendFirstLoadStatsRequest(s grpc.ClientStream) error { node.ClientFeatures = append(node.ClientFeatures, clientFeatureLRSSendAllClusters) req := &lrspb.LoadStatsRequest{Node: node} - v3c.logger.Infof("lrs: sending init LoadStatsRequest: %v", req) + v3c.logger.Infof("lrs: sending init LoadStatsRequest: %v", pretty.ToJSON(req)) return stream.Send(req) } @@ -71,7 +72,7 @@ func (v3c *client) HandleLoadStatsResponse(s grpc.ClientStream) ([]string, time. if err != nil { return nil, 0, fmt.Errorf("lrs: failed to receive first response: %v", err) } - v3c.logger.Infof("lrs: received first LoadStatsResponse: %+v", resp) + v3c.logger.Infof("lrs: received first LoadStatsResponse: %+v", pretty.ToJSON(resp)) interval, err := ptypes.Duration(resp.GetLoadReportingInterval()) if err != nil { @@ -98,24 +99,22 @@ func (v3c *client) SendLoadStatsRequest(s grpc.ClientStream, loads []*load.Data) return fmt.Errorf("lrs: Attempt to send request on unsupported stream type: %T", s) } - var clusterStats []*v3endpointpb.ClusterStats + clusterStats := make([]*v3endpointpb.ClusterStats, 0, len(loads)) for _, sd := range loads { - var ( - droppedReqs []*v3endpointpb.ClusterStats_DroppedRequests - localityStats []*v3endpointpb.UpstreamLocalityStats - ) + droppedReqs := make([]*v3endpointpb.ClusterStats_DroppedRequests, 0, len(sd.Drops)) for category, count := range sd.Drops { droppedReqs = append(droppedReqs, &v3endpointpb.ClusterStats_DroppedRequests{ Category: category, DroppedCount: count, }) } + localityStats := make([]*v3endpointpb.UpstreamLocalityStats, 0, len(sd.LocalityStats)) for l, localityData := range sd.LocalityStats { lid, err := internal.LocalityIDFromString(l) if err != nil { return err } - var loadMetricStats []*v3endpointpb.EndpointLoadMetricStats + loadMetricStats := make([]*v3endpointpb.EndpointLoadMetricStats, 0, len(localityData.LoadStats)) for name, loadData := range localityData.LoadStats { loadMetricStats = append(loadMetricStats, &v3endpointpb.EndpointLoadMetricStats{ MetricName: name, @@ -148,6 +147,6 @@ func (v3c *client) SendLoadStatsRequest(s grpc.ClientStream, loads []*load.Data) } req := &lrspb.LoadStatsRequest{ClusterStats: clusterStats} - v3c.logger.Infof("lrs: sending LRS loads: %+v", req) + v3c.logger.Infof("lrs: sending LRS loads: %+v", pretty.ToJSON(req)) return stream.Send(req) } diff --git a/xds/internal/client/watchers.go b/xds/internal/xdsclient/watchers.go similarity index 95% rename from xds/internal/client/watchers.go rename to xds/internal/xdsclient/watchers.go index 9fafe5a60f8..e26ed360308 100644 --- a/xds/internal/client/watchers.go +++ b/xds/internal/xdsclient/watchers.go @@ -16,12 +16,14 @@ * */ -package client +package xdsclient import ( "fmt" "sync" "time" + + "google.golang.org/grpc/internal/pretty" ) type watchInfoState int @@ -64,6 +66,17 @@ func (wi *watchInfo) newUpdate(update interface{}) { wi.c.scheduleCallback(wi, update, nil) } +func (wi *watchInfo) newError(err error) { + wi.mu.Lock() + defer wi.mu.Unlock() + if wi.state == watchInfoStateCanceled { + return + } + wi.state = watchInfoStateRespReceived + wi.expiryTimer.Stop() + wi.sendErrorLocked(err) +} + func (wi *watchInfo) resourceNotFound() { wi.mu.Lock() defer wi.mu.Unlock() @@ -161,22 +174,22 @@ func (c *clientImpl) watch(wi *watchInfo) (cancel func()) { switch wi.rType { case ListenerResource: if v, ok := c.ldsCache[resourceName]; ok { - c.logger.Debugf("LDS resource with name %v found in cache: %+v", wi.target, v) + c.logger.Debugf("LDS resource with name %v found in cache: %+v", wi.target, pretty.ToJSON(v)) wi.newUpdate(v) } case RouteConfigResource: if v, ok := c.rdsCache[resourceName]; ok { - c.logger.Debugf("RDS resource with name %v found in cache: %+v", wi.target, v) + c.logger.Debugf("RDS resource with name %v found in cache: %+v", wi.target, pretty.ToJSON(v)) wi.newUpdate(v) } case ClusterResource: if v, ok := c.cdsCache[resourceName]; ok { - c.logger.Debugf("CDS resource with name %v found in cache: %+v", wi.target, v) + c.logger.Debugf("CDS resource with name %v found in cache: %+v", wi.target, pretty.ToJSON(v)) wi.newUpdate(v) } case EndpointsResource: if v, ok := c.edsCache[resourceName]; ok { - c.logger.Debugf("EDS resource with name %v found in cache: %+v", wi.target, v) + c.logger.Debugf("EDS resource with name %v found in cache: %+v", wi.target, pretty.ToJSON(v)) wi.newUpdate(v) } } diff --git a/xds/internal/client/watchers_cluster_test.go b/xds/internal/xdsclient/watchers_cluster_test.go similarity index 58% rename from xds/internal/client/watchers_cluster_test.go rename to xds/internal/xdsclient/watchers_cluster_test.go index fdef0cf6164..c06319e959c 100644 --- a/xds/internal/client/watchers_cluster_test.go +++ b/xds/internal/xdsclient/watchers_cluster_test.go @@ -16,22 +16,19 @@ * */ -package client +package xdsclient import ( "context" + "fmt" "testing" "github.com/google/go-cmp/cmp" + "google.golang.org/protobuf/types/known/anypb" "google.golang.org/grpc/internal/testutils" ) -type clusterUpdateErr struct { - u ClusterUpdate - err error -} - // TestClusterWatch covers the cases: // - an update is received after a watch() // - an update for another resource name @@ -56,30 +53,34 @@ func (s) TestClusterWatch(t *testing.T) { clusterUpdateCh := testutils.NewChannel() cancelWatch := client.WatchCluster(testCDSName, func(update ClusterUpdate, err error) { - clusterUpdateCh.Send(clusterUpdateErr{u: update, err: err}) + clusterUpdateCh.Send(ClusterUpdateErrTuple{Update: update, Err: err}) }) if _, err := apiClient.addWatches[ClusterResource].Receive(ctx); err != nil { t.Fatalf("want new watch to start, got error %v", err) } - wantUpdate := ClusterUpdate{ServiceName: testEDSName} - client.NewClusters(map[string]ClusterUpdate{testCDSName: wantUpdate}, UpdateMetadata{}) - if err := verifyClusterUpdate(ctx, clusterUpdateCh, wantUpdate); err != nil { + wantUpdate := ClusterUpdate{ClusterName: testEDSName} + client.NewClusters(map[string]ClusterUpdateErrTuple{testCDSName: {Update: wantUpdate}}, UpdateMetadata{}) + if err := verifyClusterUpdate(ctx, clusterUpdateCh, wantUpdate, nil); err != nil { t.Fatal(err) } - // Another update, with an extra resource for a different resource name. - client.NewClusters(map[string]ClusterUpdate{ - testCDSName: wantUpdate, + // Push an update, with an extra resource for a different resource name. + // Specify a non-nil raw proto in the original resource to ensure that the + // new update is not considered equal to the old one. + newUpdate := wantUpdate + newUpdate.Raw = &anypb.Any{} + client.NewClusters(map[string]ClusterUpdateErrTuple{ + testCDSName: {Update: newUpdate}, "randomName": {}, }, UpdateMetadata{}) - if err := verifyClusterUpdate(ctx, clusterUpdateCh, wantUpdate); err != nil { + if err := verifyClusterUpdate(ctx, clusterUpdateCh, newUpdate, nil); err != nil { t.Fatal(err) } // Cancel watch, and send update again. cancelWatch() - client.NewClusters(map[string]ClusterUpdate{testCDSName: wantUpdate}, UpdateMetadata{}) + client.NewClusters(map[string]ClusterUpdateErrTuple{testCDSName: {Update: wantUpdate}}, UpdateMetadata{}) sCtx, sCancel := context.WithTimeout(ctx, defaultTestShortTimeout) defer sCancel() if u, err := clusterUpdateCh.Receive(sCtx); err != context.DeadlineExceeded { @@ -114,7 +115,7 @@ func (s) TestClusterTwoWatchSameResourceName(t *testing.T) { clusterUpdateCh := testutils.NewChannel() clusterUpdateChs = append(clusterUpdateChs, clusterUpdateCh) cancelLastWatch = client.WatchCluster(testCDSName, func(update ClusterUpdate, err error) { - clusterUpdateCh.Send(clusterUpdateErr{u: update, err: err}) + clusterUpdateCh.Send(ClusterUpdateErrTuple{Update: update, Err: err}) }) if i == 0 { @@ -126,27 +127,36 @@ func (s) TestClusterTwoWatchSameResourceName(t *testing.T) { } } - wantUpdate := ClusterUpdate{ServiceName: testEDSName} - client.NewClusters(map[string]ClusterUpdate{testCDSName: wantUpdate}, UpdateMetadata{}) + wantUpdate := ClusterUpdate{ClusterName: testEDSName} + client.NewClusters(map[string]ClusterUpdateErrTuple{testCDSName: {Update: wantUpdate}}, UpdateMetadata{}) for i := 0; i < count; i++ { - if err := verifyClusterUpdate(ctx, clusterUpdateChs[i], wantUpdate); err != nil { + if err := verifyClusterUpdate(ctx, clusterUpdateChs[i], wantUpdate, nil); err != nil { t.Fatal(err) } } - // Cancel the last watch, and send update again. + // Cancel the last watch, and send update again. None of the watchers should + // be notified because one has been cancelled, and the other is receiving + // the same update. cancelLastWatch() - client.NewClusters(map[string]ClusterUpdate{testCDSName: wantUpdate}, UpdateMetadata{}) - for i := 0; i < count-1; i++ { - if err := verifyClusterUpdate(ctx, clusterUpdateChs[i], wantUpdate); err != nil { - t.Fatal(err) - } + client.NewClusters(map[string]ClusterUpdateErrTuple{testCDSName: {Update: wantUpdate}}, UpdateMetadata{}) + for i := 0; i < count; i++ { + func() { + sCtx, sCancel := context.WithTimeout(ctx, defaultTestShortTimeout) + defer sCancel() + if u, err := clusterUpdateChs[i].Receive(sCtx); err != context.DeadlineExceeded { + t.Errorf("unexpected ClusterUpdate: %v, %v, want channel recv timeout", u, err) + } + }() } - sCtx, sCancel := context.WithTimeout(ctx, defaultTestShortTimeout) - defer sCancel() - if u, err := clusterUpdateChs[count-1].Receive(sCtx); err != context.DeadlineExceeded { - t.Errorf("unexpected clusterUpdate: %v, %v, want channel recv timeout", u, err) + // Push a new update and make sure the uncancelled watcher is invoked. + // Specify a non-nil raw proto to ensure that the new update is not + // considered equal to the old one. + newUpdate := ClusterUpdate{ClusterName: testEDSName, Raw: &anypb.Any{}} + client.NewClusters(map[string]ClusterUpdateErrTuple{testCDSName: {Update: newUpdate}}, UpdateMetadata{}) + if err := verifyClusterUpdate(ctx, clusterUpdateChs[0], newUpdate, nil); err != nil { + t.Fatal(err) } } @@ -177,7 +187,7 @@ func (s) TestClusterThreeWatchDifferentResourceName(t *testing.T) { clusterUpdateCh := testutils.NewChannel() clusterUpdateChs = append(clusterUpdateChs, clusterUpdateCh) client.WatchCluster(testCDSName+"1", func(update ClusterUpdate, err error) { - clusterUpdateCh.Send(clusterUpdateErr{u: update, err: err}) + clusterUpdateCh.Send(ClusterUpdateErrTuple{Update: update, Err: err}) }) if i == 0 { @@ -192,25 +202,25 @@ func (s) TestClusterThreeWatchDifferentResourceName(t *testing.T) { // Third watch for a different name. clusterUpdateCh2 := testutils.NewChannel() client.WatchCluster(testCDSName+"2", func(update ClusterUpdate, err error) { - clusterUpdateCh2.Send(clusterUpdateErr{u: update, err: err}) + clusterUpdateCh2.Send(ClusterUpdateErrTuple{Update: update, Err: err}) }) if _, err := apiClient.addWatches[ClusterResource].Receive(ctx); err != nil { t.Fatalf("want new watch to start, got error %v", err) } - wantUpdate1 := ClusterUpdate{ServiceName: testEDSName + "1"} - wantUpdate2 := ClusterUpdate{ServiceName: testEDSName + "2"} - client.NewClusters(map[string]ClusterUpdate{ - testCDSName + "1": wantUpdate1, - testCDSName + "2": wantUpdate2, + wantUpdate1 := ClusterUpdate{ClusterName: testEDSName + "1"} + wantUpdate2 := ClusterUpdate{ClusterName: testEDSName + "2"} + client.NewClusters(map[string]ClusterUpdateErrTuple{ + testCDSName + "1": {Update: wantUpdate1}, + testCDSName + "2": {Update: wantUpdate2}, }, UpdateMetadata{}) for i := 0; i < count; i++ { - if err := verifyClusterUpdate(ctx, clusterUpdateChs[i], wantUpdate1); err != nil { + if err := verifyClusterUpdate(ctx, clusterUpdateChs[i], wantUpdate1, nil); err != nil { t.Fatal(err) } } - if err := verifyClusterUpdate(ctx, clusterUpdateCh2, wantUpdate2); err != nil { + if err := verifyClusterUpdate(ctx, clusterUpdateCh2, wantUpdate2, nil); err != nil { t.Fatal(err) } } @@ -237,24 +247,24 @@ func (s) TestClusterWatchAfterCache(t *testing.T) { clusterUpdateCh := testutils.NewChannel() client.WatchCluster(testCDSName, func(update ClusterUpdate, err error) { - clusterUpdateCh.Send(clusterUpdateErr{u: update, err: err}) + clusterUpdateCh.Send(ClusterUpdateErrTuple{Update: update, Err: err}) }) if _, err := apiClient.addWatches[ClusterResource].Receive(ctx); err != nil { t.Fatalf("want new watch to start, got error %v", err) } - wantUpdate := ClusterUpdate{ServiceName: testEDSName} - client.NewClusters(map[string]ClusterUpdate{ - testCDSName: wantUpdate, + wantUpdate := ClusterUpdate{ClusterName: testEDSName} + client.NewClusters(map[string]ClusterUpdateErrTuple{ + testCDSName: {Update: wantUpdate}, }, UpdateMetadata{}) - if err := verifyClusterUpdate(ctx, clusterUpdateCh, wantUpdate); err != nil { + if err := verifyClusterUpdate(ctx, clusterUpdateCh, wantUpdate, nil); err != nil { t.Fatal(err) } // Another watch for the resource in cache. clusterUpdateCh2 := testutils.NewChannel() client.WatchCluster(testCDSName, func(update ClusterUpdate, err error) { - clusterUpdateCh2.Send(clusterUpdateErr{u: update, err: err}) + clusterUpdateCh2.Send(ClusterUpdateErrTuple{Update: update, Err: err}) }) sCtx, sCancel := context.WithTimeout(ctx, defaultTestShortTimeout) defer sCancel() @@ -263,7 +273,7 @@ func (s) TestClusterWatchAfterCache(t *testing.T) { } // New watch should receives the update. - if err := verifyClusterUpdate(ctx, clusterUpdateCh2, wantUpdate); err != nil { + if err := verifyClusterUpdate(ctx, clusterUpdateCh2, wantUpdate, nil); err != nil { t.Fatal(err) } @@ -298,7 +308,7 @@ func (s) TestClusterWatchExpiryTimer(t *testing.T) { clusterUpdateCh := testutils.NewChannel() client.WatchCluster(testCDSName, func(u ClusterUpdate, err error) { - clusterUpdateCh.Send(clusterUpdateErr{u: u, err: err}) + clusterUpdateCh.Send(ClusterUpdateErrTuple{Update: u, Err: err}) }) if _, err := apiClient.addWatches[ClusterResource].Receive(ctx); err != nil { t.Fatalf("want new watch to start, got error %v", err) @@ -308,9 +318,9 @@ func (s) TestClusterWatchExpiryTimer(t *testing.T) { if err != nil { t.Fatalf("timeout when waiting for cluster update: %v", err) } - gotUpdate := u.(clusterUpdateErr) - if gotUpdate.err == nil || !cmp.Equal(gotUpdate.u, ClusterUpdate{}) { - t.Fatalf("unexpected clusterUpdate: (%v, %v), want: (ClusterUpdate{}, nil)", gotUpdate.u, gotUpdate.err) + gotUpdate := u.(ClusterUpdateErrTuple) + if gotUpdate.Err == nil || !cmp.Equal(gotUpdate.Update, ClusterUpdate{}) { + t.Fatalf("unexpected clusterUpdate: (%v, %v), want: (ClusterUpdate{}, nil)", gotUpdate.Update, gotUpdate.Err) } } @@ -337,17 +347,17 @@ func (s) TestClusterWatchExpiryTimerStop(t *testing.T) { clusterUpdateCh := testutils.NewChannel() client.WatchCluster(testCDSName, func(u ClusterUpdate, err error) { - clusterUpdateCh.Send(clusterUpdateErr{u: u, err: err}) + clusterUpdateCh.Send(ClusterUpdateErrTuple{Update: u, Err: err}) }) if _, err := apiClient.addWatches[ClusterResource].Receive(ctx); err != nil { t.Fatalf("want new watch to start, got error %v", err) } - wantUpdate := ClusterUpdate{ServiceName: testEDSName} - client.NewClusters(map[string]ClusterUpdate{ - testCDSName: wantUpdate, + wantUpdate := ClusterUpdate{ClusterName: testEDSName} + client.NewClusters(map[string]ClusterUpdateErrTuple{ + testCDSName: {Update: wantUpdate}, }, UpdateMetadata{}) - if err := verifyClusterUpdate(ctx, clusterUpdateCh, wantUpdate); err != nil { + if err := verifyClusterUpdate(ctx, clusterUpdateCh, wantUpdate, nil); err != nil { t.Fatal(err) } @@ -385,7 +395,7 @@ func (s) TestClusterResourceRemoved(t *testing.T) { clusterUpdateCh1 := testutils.NewChannel() client.WatchCluster(testCDSName+"1", func(update ClusterUpdate, err error) { - clusterUpdateCh1.Send(clusterUpdateErr{u: update, err: err}) + clusterUpdateCh1.Send(ClusterUpdateErrTuple{Update: update, Err: err}) }) if _, err := apiClient.addWatches[ClusterResource].Receive(ctx); err != nil { t.Fatalf("want new watch to start, got error %v", err) @@ -394,50 +404,152 @@ func (s) TestClusterResourceRemoved(t *testing.T) { // Another watch for a different name. clusterUpdateCh2 := testutils.NewChannel() client.WatchCluster(testCDSName+"2", func(update ClusterUpdate, err error) { - clusterUpdateCh2.Send(clusterUpdateErr{u: update, err: err}) + clusterUpdateCh2.Send(ClusterUpdateErrTuple{Update: update, Err: err}) }) if _, err := apiClient.addWatches[ClusterResource].Receive(ctx); err != nil { t.Fatalf("want new watch to start, got error %v", err) } - wantUpdate1 := ClusterUpdate{ServiceName: testEDSName + "1"} - wantUpdate2 := ClusterUpdate{ServiceName: testEDSName + "2"} - client.NewClusters(map[string]ClusterUpdate{ - testCDSName + "1": wantUpdate1, - testCDSName + "2": wantUpdate2, + wantUpdate1 := ClusterUpdate{ClusterName: testEDSName + "1"} + wantUpdate2 := ClusterUpdate{ClusterName: testEDSName + "2"} + client.NewClusters(map[string]ClusterUpdateErrTuple{ + testCDSName + "1": {Update: wantUpdate1}, + testCDSName + "2": {Update: wantUpdate2}, }, UpdateMetadata{}) - if err := verifyClusterUpdate(ctx, clusterUpdateCh1, wantUpdate1); err != nil { + if err := verifyClusterUpdate(ctx, clusterUpdateCh1, wantUpdate1, nil); err != nil { t.Fatal(err) } - if err := verifyClusterUpdate(ctx, clusterUpdateCh2, wantUpdate2); err != nil { + if err := verifyClusterUpdate(ctx, clusterUpdateCh2, wantUpdate2, nil); err != nil { t.Fatal(err) } // Send another update to remove resource 1. - client.NewClusters(map[string]ClusterUpdate{testCDSName + "2": wantUpdate2}, UpdateMetadata{}) + client.NewClusters(map[string]ClusterUpdateErrTuple{testCDSName + "2": {Update: wantUpdate2}}, UpdateMetadata{}) // Watcher 1 should get an error. - if u, err := clusterUpdateCh1.Receive(ctx); err != nil || ErrType(u.(clusterUpdateErr).err) != ErrorTypeResourceNotFound { + if u, err := clusterUpdateCh1.Receive(ctx); err != nil || ErrType(u.(ClusterUpdateErrTuple).Err) != ErrorTypeResourceNotFound { t.Errorf("unexpected clusterUpdate: %v, error receiving from channel: %v, want update with error resource not found", u, err) } - // Watcher 2 should get the same update again. - if err := verifyClusterUpdate(ctx, clusterUpdateCh2, wantUpdate2); err != nil { - t.Fatal(err) + // Watcher 2 should not see an update since the resource has not changed. + sCtx, sCancel := context.WithTimeout(ctx, defaultTestShortTimeout) + defer sCancel() + if u, err := clusterUpdateCh2.Receive(sCtx); err != context.DeadlineExceeded { + t.Errorf("unexpected ClusterUpdate: %v, want receiving from channel timeout", u) } - // Send one more update without resource 1. - client.NewClusters(map[string]ClusterUpdate{testCDSName + "2": wantUpdate2}, UpdateMetadata{}) + // Send another update with resource 2 modified. Specify a non-nil raw proto + // to ensure that the new update is not considered equal to the old one. + wantUpdate2 = ClusterUpdate{ClusterName: testEDSName + "2", Raw: &anypb.Any{}} + client.NewClusters(map[string]ClusterUpdateErrTuple{testCDSName + "2": {Update: wantUpdate2}}, UpdateMetadata{}) // Watcher 1 should not see an update. - sCtx, sCancel := context.WithTimeout(ctx, defaultTestShortTimeout) + sCtx, sCancel = context.WithTimeout(ctx, defaultTestShortTimeout) defer sCancel() if u, err := clusterUpdateCh1.Receive(sCtx); err != context.DeadlineExceeded { - t.Errorf("unexpected clusterUpdate: %v, %v, want channel recv timeout", u, err) + t.Errorf("unexpected Cluster: %v, want receiving from channel timeout", u) + } + + // Watcher 2 should get the update. + if err := verifyClusterUpdate(ctx, clusterUpdateCh2, wantUpdate2, nil); err != nil { + t.Fatal(err) + } +} + +// TestClusterWatchNACKError covers the case that an update is NACK'ed, and the +// watcher should also receive the error. +func (s) TestClusterWatchNACKError(t *testing.T) { + apiClientCh, cleanup := overrideNewAPIClient() + defer cleanup() + + client, err := newWithConfig(clientOpts(testXDSServer, false)) + if err != nil { + t.Fatalf("failed to create client: %v", err) + } + defer client.Close() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + c, err := apiClientCh.Receive(ctx) + if err != nil { + t.Fatalf("timeout when waiting for API client to be created: %v", err) + } + apiClient := c.(*testAPIClient) + + clusterUpdateCh := testutils.NewChannel() + cancelWatch := client.WatchCluster(testCDSName, func(update ClusterUpdate, err error) { + clusterUpdateCh.Send(ClusterUpdateErrTuple{Update: update, Err: err}) + }) + defer cancelWatch() + if _, err := apiClient.addWatches[ClusterResource].Receive(ctx); err != nil { + t.Fatalf("want new watch to start, got error %v", err) + } + + wantError := fmt.Errorf("testing error") + client.NewClusters(map[string]ClusterUpdateErrTuple{testCDSName: { + Err: wantError, + }}, UpdateMetadata{ErrState: &UpdateErrorMetadata{Err: wantError}}) + if err := verifyClusterUpdate(ctx, clusterUpdateCh, ClusterUpdate{}, wantError); err != nil { + t.Fatal(err) + } +} + +// TestClusterWatchPartialValid covers the case that a response contains both +// valid and invalid resources. This response will be NACK'ed by the xdsclient. +// But the watchers with valid resources should receive the update, those with +// invalida resources should receive an error. +func (s) TestClusterWatchPartialValid(t *testing.T) { + apiClientCh, cleanup := overrideNewAPIClient() + defer cleanup() + + client, err := newWithConfig(clientOpts(testXDSServer, false)) + if err != nil { + t.Fatalf("failed to create client: %v", err) + } + defer client.Close() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + c, err := apiClientCh.Receive(ctx) + if err != nil { + t.Fatalf("timeout when waiting for API client to be created: %v", err) + } + apiClient := c.(*testAPIClient) + + const badResourceName = "bad-resource" + updateChs := make(map[string]*testutils.Channel) + + for _, name := range []string{testCDSName, badResourceName} { + clusterUpdateCh := testutils.NewChannel() + cancelWatch := client.WatchCluster(name, func(update ClusterUpdate, err error) { + clusterUpdateCh.Send(ClusterUpdateErrTuple{Update: update, Err: err}) + }) + defer func() { + cancelWatch() + if _, err := apiClient.removeWatches[ClusterResource].Receive(ctx); err != nil { + t.Fatalf("want watch to be canceled, got err: %v", err) + } + }() + if _, err := apiClient.addWatches[ClusterResource].Receive(ctx); err != nil { + t.Fatalf("want new watch to start, got error %v", err) + } + updateChs[name] = clusterUpdateCh + } + + wantError := fmt.Errorf("testing error") + wantError2 := fmt.Errorf("individual error") + client.NewClusters(map[string]ClusterUpdateErrTuple{ + testCDSName: {Update: ClusterUpdate{ClusterName: testEDSName}}, + badResourceName: {Err: wantError2}, + }, UpdateMetadata{ErrState: &UpdateErrorMetadata{Err: wantError}}) + + // The valid resource should be sent to the watcher. + if err := verifyClusterUpdate(ctx, updateChs[testCDSName], ClusterUpdate{ClusterName: testEDSName}, nil); err != nil { + t.Fatal(err) } - // Watcher 2 should get the same update again. - if err := verifyClusterUpdate(ctx, clusterUpdateCh2, wantUpdate2); err != nil { + // The failed watcher should receive an error. + if err := verifyClusterUpdate(ctx, updateChs[badResourceName], ClusterUpdate{}, wantError2); err != nil { t.Fatal(err) } } diff --git a/xds/internal/client/watchers_endpoints_test.go b/xds/internal/xdsclient/watchers_endpoints_test.go similarity index 57% rename from xds/internal/client/watchers_endpoints_test.go rename to xds/internal/xdsclient/watchers_endpoints_test.go index b79397414d4..b87723e5086 100644 --- a/xds/internal/client/watchers_endpoints_test.go +++ b/xds/internal/xdsclient/watchers_endpoints_test.go @@ -16,13 +16,15 @@ * */ -package client +package xdsclient import ( "context" + "fmt" "testing" "github.com/google/go-cmp/cmp" + "google.golang.org/protobuf/types/known/anypb" "google.golang.org/grpc/internal/testutils" "google.golang.org/grpc/xds/internal" @@ -45,11 +47,6 @@ var ( } ) -type endpointsUpdateErr struct { - u EndpointsUpdate - err error -} - // TestEndpointsWatch covers the cases: // - an update is received after a watch() // - an update for another resource name (which doesn't trigger callback) @@ -74,30 +71,35 @@ func (s) TestEndpointsWatch(t *testing.T) { endpointsUpdateCh := testutils.NewChannel() cancelWatch := client.WatchEndpoints(testCDSName, func(update EndpointsUpdate, err error) { - endpointsUpdateCh.Send(endpointsUpdateErr{u: update, err: err}) + endpointsUpdateCh.Send(EndpointsUpdateErrTuple{Update: update, Err: err}) }) if _, err := apiClient.addWatches[EndpointsResource].Receive(ctx); err != nil { t.Fatalf("want new watch to start, got error %v", err) } wantUpdate := EndpointsUpdate{Localities: []Locality{testLocalities[0]}} - client.NewEndpoints(map[string]EndpointsUpdate{testCDSName: wantUpdate}, UpdateMetadata{}) - if err := verifyEndpointsUpdate(ctx, endpointsUpdateCh, wantUpdate); err != nil { + client.NewEndpoints(map[string]EndpointsUpdateErrTuple{testCDSName: {Update: wantUpdate}}, UpdateMetadata{}) + if err := verifyEndpointsUpdate(ctx, endpointsUpdateCh, wantUpdate, nil); err != nil { t.Fatal(err) } - // Another update for a different resource name. - client.NewEndpoints(map[string]EndpointsUpdate{"randomName": {}}, UpdateMetadata{}) - sCtx, sCancel := context.WithTimeout(ctx, defaultTestShortTimeout) - defer sCancel() - if u, err := endpointsUpdateCh.Receive(sCtx); err != context.DeadlineExceeded { - t.Errorf("unexpected endpointsUpdate: %v, %v, want channel recv timeout", u, err) + // Push an update, with an extra resource for a different resource name. + // Specify a non-nil raw proto in the original resource to ensure that the + // new update is not considered equal to the old one. + newUpdate := wantUpdate + newUpdate.Raw = &anypb.Any{} + client.NewEndpoints(map[string]EndpointsUpdateErrTuple{ + testCDSName: {Update: newUpdate}, + "randomName": {}, + }, UpdateMetadata{}) + if err := verifyEndpointsUpdate(ctx, endpointsUpdateCh, newUpdate, nil); err != nil { + t.Fatal(err) } // Cancel watch, and send update again. cancelWatch() - client.NewEndpoints(map[string]EndpointsUpdate{testCDSName: wantUpdate}, UpdateMetadata{}) - sCtx, sCancel = context.WithTimeout(ctx, defaultTestShortTimeout) + client.NewEndpoints(map[string]EndpointsUpdateErrTuple{testCDSName: {Update: wantUpdate}}, UpdateMetadata{}) + sCtx, sCancel := context.WithTimeout(ctx, defaultTestShortTimeout) defer sCancel() if u, err := endpointsUpdateCh.Receive(sCtx); err != context.DeadlineExceeded { t.Errorf("unexpected endpointsUpdate: %v, %v, want channel recv timeout", u, err) @@ -133,7 +135,7 @@ func (s) TestEndpointsTwoWatchSameResourceName(t *testing.T) { endpointsUpdateCh := testutils.NewChannel() endpointsUpdateChs = append(endpointsUpdateChs, endpointsUpdateCh) cancelLastWatch = client.WatchEndpoints(testCDSName, func(update EndpointsUpdate, err error) { - endpointsUpdateCh.Send(endpointsUpdateErr{u: update, err: err}) + endpointsUpdateCh.Send(EndpointsUpdateErrTuple{Update: update, Err: err}) }) if i == 0 { @@ -146,26 +148,35 @@ func (s) TestEndpointsTwoWatchSameResourceName(t *testing.T) { } wantUpdate := EndpointsUpdate{Localities: []Locality{testLocalities[0]}} - client.NewEndpoints(map[string]EndpointsUpdate{testCDSName: wantUpdate}, UpdateMetadata{}) + client.NewEndpoints(map[string]EndpointsUpdateErrTuple{testCDSName: {Update: wantUpdate}}, UpdateMetadata{}) for i := 0; i < count; i++ { - if err := verifyEndpointsUpdate(ctx, endpointsUpdateChs[i], wantUpdate); err != nil { + if err := verifyEndpointsUpdate(ctx, endpointsUpdateChs[i], wantUpdate, nil); err != nil { t.Fatal(err) } } - // Cancel the last watch, and send update again. + // Cancel the last watch, and send update again. None of the watchers should + // be notified because one has been cancelled, and the other is receiving + // the same update. cancelLastWatch() - client.NewEndpoints(map[string]EndpointsUpdate{testCDSName: wantUpdate}, UpdateMetadata{}) - for i := 0; i < count-1; i++ { - if err := verifyEndpointsUpdate(ctx, endpointsUpdateChs[i], wantUpdate); err != nil { - t.Fatal(err) - } + client.NewEndpoints(map[string]EndpointsUpdateErrTuple{testCDSName: {Update: wantUpdate}}, UpdateMetadata{}) + for i := 0; i < count; i++ { + func() { + sCtx, sCancel := context.WithTimeout(ctx, defaultTestShortTimeout) + defer sCancel() + if u, err := endpointsUpdateChs[i].Receive(sCtx); err != context.DeadlineExceeded { + t.Errorf("unexpected endpointsUpdate: %v, %v, want channel recv timeout", u, err) + } + }() } - sCtx, sCancel := context.WithTimeout(ctx, defaultTestShortTimeout) - defer sCancel() - if u, err := endpointsUpdateChs[count-1].Receive(sCtx); err != context.DeadlineExceeded { - t.Errorf("unexpected endpointsUpdate: %v, %v, want channel recv timeout", u, err) + // Push a new update and make sure the uncancelled watcher is invoked. + // Specify a non-nil raw proto to ensure that the new update is not + // considered equal to the old one. + newUpdate := EndpointsUpdate{Localities: []Locality{testLocalities[0]}, Raw: &anypb.Any{}} + client.NewEndpoints(map[string]EndpointsUpdateErrTuple{testCDSName: {Update: newUpdate}}, UpdateMetadata{}) + if err := verifyEndpointsUpdate(ctx, endpointsUpdateChs[0], newUpdate, nil); err != nil { + t.Fatal(err) } } @@ -196,7 +207,7 @@ func (s) TestEndpointsThreeWatchDifferentResourceName(t *testing.T) { endpointsUpdateCh := testutils.NewChannel() endpointsUpdateChs = append(endpointsUpdateChs, endpointsUpdateCh) client.WatchEndpoints(testCDSName+"1", func(update EndpointsUpdate, err error) { - endpointsUpdateCh.Send(endpointsUpdateErr{u: update, err: err}) + endpointsUpdateCh.Send(EndpointsUpdateErrTuple{Update: update, Err: err}) }) if i == 0 { @@ -211,7 +222,7 @@ func (s) TestEndpointsThreeWatchDifferentResourceName(t *testing.T) { // Third watch for a different name. endpointsUpdateCh2 := testutils.NewChannel() client.WatchEndpoints(testCDSName+"2", func(update EndpointsUpdate, err error) { - endpointsUpdateCh2.Send(endpointsUpdateErr{u: update, err: err}) + endpointsUpdateCh2.Send(EndpointsUpdateErrTuple{Update: update, Err: err}) }) if _, err := apiClient.addWatches[EndpointsResource].Receive(ctx); err != nil { t.Fatalf("want new watch to start, got error %v", err) @@ -219,17 +230,17 @@ func (s) TestEndpointsThreeWatchDifferentResourceName(t *testing.T) { wantUpdate1 := EndpointsUpdate{Localities: []Locality{testLocalities[0]}} wantUpdate2 := EndpointsUpdate{Localities: []Locality{testLocalities[1]}} - client.NewEndpoints(map[string]EndpointsUpdate{ - testCDSName + "1": wantUpdate1, - testCDSName + "2": wantUpdate2, + client.NewEndpoints(map[string]EndpointsUpdateErrTuple{ + testCDSName + "1": {Update: wantUpdate1}, + testCDSName + "2": {Update: wantUpdate2}, }, UpdateMetadata{}) for i := 0; i < count; i++ { - if err := verifyEndpointsUpdate(ctx, endpointsUpdateChs[i], wantUpdate1); err != nil { + if err := verifyEndpointsUpdate(ctx, endpointsUpdateChs[i], wantUpdate1, nil); err != nil { t.Fatal(err) } } - if err := verifyEndpointsUpdate(ctx, endpointsUpdateCh2, wantUpdate2); err != nil { + if err := verifyEndpointsUpdate(ctx, endpointsUpdateCh2, wantUpdate2, nil); err != nil { t.Fatal(err) } } @@ -256,22 +267,22 @@ func (s) TestEndpointsWatchAfterCache(t *testing.T) { endpointsUpdateCh := testutils.NewChannel() client.WatchEndpoints(testCDSName, func(update EndpointsUpdate, err error) { - endpointsUpdateCh.Send(endpointsUpdateErr{u: update, err: err}) + endpointsUpdateCh.Send(EndpointsUpdateErrTuple{Update: update, Err: err}) }) if _, err := apiClient.addWatches[EndpointsResource].Receive(ctx); err != nil { t.Fatalf("want new watch to start, got error %v", err) } wantUpdate := EndpointsUpdate{Localities: []Locality{testLocalities[0]}} - client.NewEndpoints(map[string]EndpointsUpdate{testCDSName: wantUpdate}, UpdateMetadata{}) - if err := verifyEndpointsUpdate(ctx, endpointsUpdateCh, wantUpdate); err != nil { + client.NewEndpoints(map[string]EndpointsUpdateErrTuple{testCDSName: {Update: wantUpdate}}, UpdateMetadata{}) + if err := verifyEndpointsUpdate(ctx, endpointsUpdateCh, wantUpdate, nil); err != nil { t.Fatal(err) } // Another watch for the resource in cache. endpointsUpdateCh2 := testutils.NewChannel() client.WatchEndpoints(testCDSName, func(update EndpointsUpdate, err error) { - endpointsUpdateCh2.Send(endpointsUpdateErr{u: update, err: err}) + endpointsUpdateCh2.Send(EndpointsUpdateErrTuple{Update: update, Err: err}) }) sCtx, sCancel := context.WithTimeout(ctx, defaultTestShortTimeout) defer sCancel() @@ -280,7 +291,7 @@ func (s) TestEndpointsWatchAfterCache(t *testing.T) { } // New watch should receives the update. - if err := verifyEndpointsUpdate(ctx, endpointsUpdateCh2, wantUpdate); err != nil { + if err := verifyEndpointsUpdate(ctx, endpointsUpdateCh2, wantUpdate, nil); err != nil { t.Fatal(err) } @@ -315,7 +326,7 @@ func (s) TestEndpointsWatchExpiryTimer(t *testing.T) { endpointsUpdateCh := testutils.NewChannel() client.WatchEndpoints(testCDSName, func(update EndpointsUpdate, err error) { - endpointsUpdateCh.Send(endpointsUpdateErr{u: update, err: err}) + endpointsUpdateCh.Send(EndpointsUpdateErrTuple{Update: update, Err: err}) }) if _, err := apiClient.addWatches[EndpointsResource].Receive(ctx); err != nil { t.Fatalf("want new watch to start, got error %v", err) @@ -325,8 +336,104 @@ func (s) TestEndpointsWatchExpiryTimer(t *testing.T) { if err != nil { t.Fatalf("timeout when waiting for endpoints update: %v", err) } - gotUpdate := u.(endpointsUpdateErr) - if gotUpdate.err == nil || !cmp.Equal(gotUpdate.u, EndpointsUpdate{}) { - t.Fatalf("unexpected endpointsUpdate: (%v, %v), want: (EndpointsUpdate{}, nil)", gotUpdate.u, gotUpdate.err) + gotUpdate := u.(EndpointsUpdateErrTuple) + if gotUpdate.Err == nil || !cmp.Equal(gotUpdate.Update, EndpointsUpdate{}) { + t.Fatalf("unexpected endpointsUpdate: (%v, %v), want: (EndpointsUpdate{}, nil)", gotUpdate.Update, gotUpdate.Err) + } +} + +// TestEndpointsWatchNACKError covers the case that an update is NACK'ed, and +// the watcher should also receive the error. +func (s) TestEndpointsWatchNACKError(t *testing.T) { + apiClientCh, cleanup := overrideNewAPIClient() + defer cleanup() + + client, err := newWithConfig(clientOpts(testXDSServer, false)) + if err != nil { + t.Fatalf("failed to create client: %v", err) + } + defer client.Close() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + c, err := apiClientCh.Receive(ctx) + if err != nil { + t.Fatalf("timeout when waiting for API client to be created: %v", err) + } + apiClient := c.(*testAPIClient) + + endpointsUpdateCh := testutils.NewChannel() + cancelWatch := client.WatchEndpoints(testCDSName, func(update EndpointsUpdate, err error) { + endpointsUpdateCh.Send(EndpointsUpdateErrTuple{Update: update, Err: err}) + }) + defer cancelWatch() + if _, err := apiClient.addWatches[EndpointsResource].Receive(ctx); err != nil { + t.Fatalf("want new watch to start, got error %v", err) + } + + wantError := fmt.Errorf("testing error") + client.NewEndpoints(map[string]EndpointsUpdateErrTuple{testCDSName: {Err: wantError}}, UpdateMetadata{ErrState: &UpdateErrorMetadata{Err: wantError}}) + if err := verifyEndpointsUpdate(ctx, endpointsUpdateCh, EndpointsUpdate{}, wantError); err != nil { + t.Fatal(err) + } +} + +// TestEndpointsWatchPartialValid covers the case that a response contains both +// valid and invalid resources. This response will be NACK'ed by the xdsclient. +// But the watchers with valid resources should receive the update, those with +// invalida resources should receive an error. +func (s) TestEndpointsWatchPartialValid(t *testing.T) { + apiClientCh, cleanup := overrideNewAPIClient() + defer cleanup() + + client, err := newWithConfig(clientOpts(testXDSServer, false)) + if err != nil { + t.Fatalf("failed to create client: %v", err) + } + defer client.Close() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + c, err := apiClientCh.Receive(ctx) + if err != nil { + t.Fatalf("timeout when waiting for API client to be created: %v", err) + } + apiClient := c.(*testAPIClient) + + const badResourceName = "bad-resource" + updateChs := make(map[string]*testutils.Channel) + + for _, name := range []string{testCDSName, badResourceName} { + endpointsUpdateCh := testutils.NewChannel() + cancelWatch := client.WatchEndpoints(name, func(update EndpointsUpdate, err error) { + endpointsUpdateCh.Send(EndpointsUpdateErrTuple{Update: update, Err: err}) + }) + defer func() { + cancelWatch() + if _, err := apiClient.removeWatches[EndpointsResource].Receive(ctx); err != nil { + t.Fatalf("want watch to be canceled, got err: %v", err) + } + }() + if _, err := apiClient.addWatches[EndpointsResource].Receive(ctx); err != nil { + t.Fatalf("want new watch to start, got error %v", err) + } + updateChs[name] = endpointsUpdateCh + } + + wantError := fmt.Errorf("testing error") + wantError2 := fmt.Errorf("individual error") + client.NewEndpoints(map[string]EndpointsUpdateErrTuple{ + testCDSName: {Update: EndpointsUpdate{Localities: []Locality{testLocalities[0]}}}, + badResourceName: {Err: wantError2}, + }, UpdateMetadata{ErrState: &UpdateErrorMetadata{Err: wantError}}) + + // The valid resource should be sent to the watcher. + if err := verifyEndpointsUpdate(ctx, updateChs[testCDSName], EndpointsUpdate{Localities: []Locality{testLocalities[0]}}, nil); err != nil { + t.Fatal(err) + } + + // The failed watcher should receive an error. + if err := verifyEndpointsUpdate(ctx, updateChs[badResourceName], EndpointsUpdate{}, wantError2); err != nil { + t.Fatal(err) } } diff --git a/xds/internal/xdsclient/watchers_listener_test.go b/xds/internal/xdsclient/watchers_listener_test.go new file mode 100644 index 00000000000..176e6bbcb7b --- /dev/null +++ b/xds/internal/xdsclient/watchers_listener_test.go @@ -0,0 +1,591 @@ +/* + * + * Copyright 2020 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package xdsclient + +import ( + "context" + "fmt" + "testing" + + v3listenerpb "github.com/envoyproxy/go-control-plane/envoy/config/listener/v3" + v3httppb "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/network/http_connection_manager/v3" + "google.golang.org/grpc/internal/testutils" + "google.golang.org/protobuf/types/known/anypb" +) + +// TestLDSWatch covers the cases: +// - an update is received after a watch() +// - an update for another resource name +// - an update is received after cancel() +func (s) TestLDSWatch(t *testing.T) { + apiClientCh, cleanup := overrideNewAPIClient() + defer cleanup() + + client, err := newWithConfig(clientOpts(testXDSServer, false)) + if err != nil { + t.Fatalf("failed to create client: %v", err) + } + defer client.Close() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + c, err := apiClientCh.Receive(ctx) + if err != nil { + t.Fatalf("timeout when waiting for API client to be created: %v", err) + } + apiClient := c.(*testAPIClient) + + ldsUpdateCh := testutils.NewChannel() + cancelWatch := client.WatchListener(testLDSName, func(update ListenerUpdate, err error) { + ldsUpdateCh.Send(ListenerUpdateErrTuple{Update: update, Err: err}) + }) + if _, err := apiClient.addWatches[ListenerResource].Receive(ctx); err != nil { + t.Fatalf("want new watch to start, got error %v", err) + } + + wantUpdate := ListenerUpdate{RouteConfigName: testRDSName} + client.NewListeners(map[string]ListenerUpdateErrTuple{testLDSName: {Update: wantUpdate}}, UpdateMetadata{}) + if err := verifyListenerUpdate(ctx, ldsUpdateCh, wantUpdate, nil); err != nil { + t.Fatal(err) + } + + // Push an update, with an extra resource for a different resource name. + // Specify a non-nil raw proto in the original resource to ensure that the + // new update is not considered equal to the old one. + newUpdate := ListenerUpdate{RouteConfigName: testRDSName, Raw: &anypb.Any{}} + client.NewListeners(map[string]ListenerUpdateErrTuple{ + testLDSName: {Update: newUpdate}, + "randomName": {}, + }, UpdateMetadata{}) + if err := verifyListenerUpdate(ctx, ldsUpdateCh, newUpdate, nil); err != nil { + t.Fatal(err) + } + + // Cancel watch, and send update again. + cancelWatch() + client.NewListeners(map[string]ListenerUpdateErrTuple{testLDSName: {Update: wantUpdate}}, UpdateMetadata{}) + sCtx, sCancel := context.WithTimeout(ctx, defaultTestShortTimeout) + defer sCancel() + if u, err := ldsUpdateCh.Receive(sCtx); err != context.DeadlineExceeded { + t.Fatalf("unexpected ListenerUpdate: %v, %v, want channel recv timeout", u, err) + } +} + +// TestLDSTwoWatchSameResourceName covers the case where an update is received +// after two watch() for the same resource name. +func (s) TestLDSTwoWatchSameResourceName(t *testing.T) { + apiClientCh, cleanup := overrideNewAPIClient() + defer cleanup() + + client, err := newWithConfig(clientOpts(testXDSServer, false)) + if err != nil { + t.Fatalf("failed to create client: %v", err) + } + defer client.Close() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + c, err := apiClientCh.Receive(ctx) + if err != nil { + t.Fatalf("timeout when waiting for API client to be created: %v", err) + } + apiClient := c.(*testAPIClient) + + const count = 2 + var ( + ldsUpdateChs []*testutils.Channel + cancelLastWatch func() + ) + + for i := 0; i < count; i++ { + ldsUpdateCh := testutils.NewChannel() + ldsUpdateChs = append(ldsUpdateChs, ldsUpdateCh) + cancelLastWatch = client.WatchListener(testLDSName, func(update ListenerUpdate, err error) { + ldsUpdateCh.Send(ListenerUpdateErrTuple{Update: update, Err: err}) + }) + + if i == 0 { + // A new watch is registered on the underlying API client only for + // the first iteration because we are using the same resource name. + if _, err := apiClient.addWatches[ListenerResource].Receive(ctx); err != nil { + t.Fatalf("want new watch to start, got error %v", err) + } + } + } + + wantUpdate := ListenerUpdate{RouteConfigName: testRDSName} + client.NewListeners(map[string]ListenerUpdateErrTuple{testLDSName: {Update: wantUpdate}}, UpdateMetadata{}) + for i := 0; i < count; i++ { + if err := verifyListenerUpdate(ctx, ldsUpdateChs[i], wantUpdate, nil); err != nil { + t.Fatal(err) + } + } + + // Cancel the last watch, and send update again. None of the watchers should + // be notified because one has been cancelled, and the other is receiving + // the same update. + cancelLastWatch() + client.NewListeners(map[string]ListenerUpdateErrTuple{testLDSName: {Update: wantUpdate}}, UpdateMetadata{}) + for i := 0; i < count; i++ { + func() { + sCtx, sCancel := context.WithTimeout(ctx, defaultTestShortTimeout) + defer sCancel() + if u, err := ldsUpdateChs[i].Receive(sCtx); err != context.DeadlineExceeded { + t.Errorf("unexpected ListenerUpdate: %v, %v, want channel recv timeout", u, err) + } + }() + } + + // Push a new update and make sure the uncancelled watcher is invoked. + // Specify a non-nil raw proto to ensure that the new update is not + // considered equal to the old one. + newUpdate := ListenerUpdate{RouteConfigName: testRDSName, Raw: &anypb.Any{}} + client.NewListeners(map[string]ListenerUpdateErrTuple{testLDSName: {Update: newUpdate}}, UpdateMetadata{}) + if err := verifyListenerUpdate(ctx, ldsUpdateChs[0], newUpdate, nil); err != nil { + t.Fatal(err) + } +} + +// TestLDSThreeWatchDifferentResourceName covers the case where an update is +// received after three watch() for different resource names. +func (s) TestLDSThreeWatchDifferentResourceName(t *testing.T) { + apiClientCh, cleanup := overrideNewAPIClient() + defer cleanup() + + client, err := newWithConfig(clientOpts(testXDSServer, false)) + if err != nil { + t.Fatalf("failed to create client: %v", err) + } + defer client.Close() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + c, err := apiClientCh.Receive(ctx) + if err != nil { + t.Fatalf("timeout when waiting for API client to be created: %v", err) + } + apiClient := c.(*testAPIClient) + + var ldsUpdateChs []*testutils.Channel + const count = 2 + + // Two watches for the same name. + for i := 0; i < count; i++ { + ldsUpdateCh := testutils.NewChannel() + ldsUpdateChs = append(ldsUpdateChs, ldsUpdateCh) + client.WatchListener(testLDSName+"1", func(update ListenerUpdate, err error) { + ldsUpdateCh.Send(ListenerUpdateErrTuple{Update: update, Err: err}) + }) + + if i == 0 { + // A new watch is registered on the underlying API client only for + // the first iteration because we are using the same resource name. + if _, err := apiClient.addWatches[ListenerResource].Receive(ctx); err != nil { + t.Fatalf("want new watch to start, got error %v", err) + } + } + } + + // Third watch for a different name. + ldsUpdateCh2 := testutils.NewChannel() + client.WatchListener(testLDSName+"2", func(update ListenerUpdate, err error) { + ldsUpdateCh2.Send(ListenerUpdateErrTuple{Update: update, Err: err}) + }) + if _, err := apiClient.addWatches[ListenerResource].Receive(ctx); err != nil { + t.Fatalf("want new watch to start, got error %v", err) + } + + wantUpdate1 := ListenerUpdate{RouteConfigName: testRDSName + "1"} + wantUpdate2 := ListenerUpdate{RouteConfigName: testRDSName + "2"} + client.NewListeners(map[string]ListenerUpdateErrTuple{ + testLDSName + "1": {Update: wantUpdate1}, + testLDSName + "2": {Update: wantUpdate2}, + }, UpdateMetadata{}) + + for i := 0; i < count; i++ { + if err := verifyListenerUpdate(ctx, ldsUpdateChs[i], wantUpdate1, nil); err != nil { + t.Fatal(err) + } + } + if err := verifyListenerUpdate(ctx, ldsUpdateCh2, wantUpdate2, nil); err != nil { + t.Fatal(err) + } +} + +// TestLDSWatchAfterCache covers the case where watch is called after the update +// is in cache. +func (s) TestLDSWatchAfterCache(t *testing.T) { + apiClientCh, cleanup := overrideNewAPIClient() + defer cleanup() + + client, err := newWithConfig(clientOpts(testXDSServer, false)) + if err != nil { + t.Fatalf("failed to create client: %v", err) + } + defer client.Close() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + c, err := apiClientCh.Receive(ctx) + if err != nil { + t.Fatalf("timeout when waiting for API client to be created: %v", err) + } + apiClient := c.(*testAPIClient) + + ldsUpdateCh := testutils.NewChannel() + client.WatchListener(testLDSName, func(update ListenerUpdate, err error) { + ldsUpdateCh.Send(ListenerUpdateErrTuple{Update: update, Err: err}) + }) + if _, err := apiClient.addWatches[ListenerResource].Receive(ctx); err != nil { + t.Fatalf("want new watch to start, got error %v", err) + } + + wantUpdate := ListenerUpdate{RouteConfigName: testRDSName} + client.NewListeners(map[string]ListenerUpdateErrTuple{testLDSName: {Update: wantUpdate}}, UpdateMetadata{}) + if err := verifyListenerUpdate(ctx, ldsUpdateCh, wantUpdate, nil); err != nil { + t.Fatal(err) + } + + // Another watch for the resource in cache. + ldsUpdateCh2 := testutils.NewChannel() + client.WatchListener(testLDSName, func(update ListenerUpdate, err error) { + ldsUpdateCh2.Send(ListenerUpdateErrTuple{Update: update, Err: err}) + }) + sCtx, sCancel := context.WithTimeout(ctx, defaultTestShortTimeout) + defer sCancel() + if n, err := apiClient.addWatches[ListenerResource].Receive(sCtx); err != context.DeadlineExceeded { + t.Fatalf("want no new watch to start (recv timeout), got resource name: %v error %v", n, err) + } + + // New watch should receive the update. + if err := verifyListenerUpdate(ctx, ldsUpdateCh2, wantUpdate, nil); err != nil { + t.Fatal(err) + } + + // Old watch should see nothing. + sCtx, sCancel = context.WithTimeout(ctx, defaultTestShortTimeout) + defer sCancel() + if u, err := ldsUpdateCh.Receive(sCtx); err != context.DeadlineExceeded { + t.Errorf("unexpected ListenerUpdate: %v, %v, want channel recv timeout", u, err) + } +} + +// TestLDSResourceRemoved covers the cases: +// - an update is received after a watch() +// - another update is received, with one resource removed +// - this should trigger callback with resource removed error +// - one more update without the removed resource +// - the callback (above) shouldn't receive any update +func (s) TestLDSResourceRemoved(t *testing.T) { + apiClientCh, cleanup := overrideNewAPIClient() + defer cleanup() + + client, err := newWithConfig(clientOpts(testXDSServer, false)) + if err != nil { + t.Fatalf("failed to create client: %v", err) + } + defer client.Close() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + c, err := apiClientCh.Receive(ctx) + if err != nil { + t.Fatalf("timeout when waiting for API client to be created: %v", err) + } + apiClient := c.(*testAPIClient) + + ldsUpdateCh1 := testutils.NewChannel() + client.WatchListener(testLDSName+"1", func(update ListenerUpdate, err error) { + ldsUpdateCh1.Send(ListenerUpdateErrTuple{Update: update, Err: err}) + }) + if _, err := apiClient.addWatches[ListenerResource].Receive(ctx); err != nil { + t.Fatalf("want new watch to start, got error %v", err) + } + // Another watch for a different name. + ldsUpdateCh2 := testutils.NewChannel() + client.WatchListener(testLDSName+"2", func(update ListenerUpdate, err error) { + ldsUpdateCh2.Send(ListenerUpdateErrTuple{Update: update, Err: err}) + }) + if _, err := apiClient.addWatches[ListenerResource].Receive(ctx); err != nil { + t.Fatalf("want new watch to start, got error %v", err) + } + + wantUpdate1 := ListenerUpdate{RouteConfigName: testEDSName + "1"} + wantUpdate2 := ListenerUpdate{RouteConfigName: testEDSName + "2"} + client.NewListeners(map[string]ListenerUpdateErrTuple{ + testLDSName + "1": {Update: wantUpdate1}, + testLDSName + "2": {Update: wantUpdate2}, + }, UpdateMetadata{}) + if err := verifyListenerUpdate(ctx, ldsUpdateCh1, wantUpdate1, nil); err != nil { + t.Fatal(err) + } + if err := verifyListenerUpdate(ctx, ldsUpdateCh2, wantUpdate2, nil); err != nil { + t.Fatal(err) + } + + // Send another update to remove resource 1. + client.NewListeners(map[string]ListenerUpdateErrTuple{testLDSName + "2": {Update: wantUpdate2}}, UpdateMetadata{}) + + // Watcher 1 should get an error. + if u, err := ldsUpdateCh1.Receive(ctx); err != nil || ErrType(u.(ListenerUpdateErrTuple).Err) != ErrorTypeResourceNotFound { + t.Errorf("unexpected ListenerUpdate: %v, error receiving from channel: %v, want update with error resource not found", u, err) + } + + // Watcher 2 should not see an update since the resource has not changed. + sCtx, sCancel := context.WithTimeout(ctx, defaultTestShortTimeout) + defer sCancel() + if u, err := ldsUpdateCh2.Receive(sCtx); err != context.DeadlineExceeded { + t.Errorf("unexpected ListenerUpdate: %v, want receiving from channel timeout", u) + } + + // Send another update with resource 2 modified. Specify a non-nil raw proto + // to ensure that the new update is not considered equal to the old one. + wantUpdate2 = ListenerUpdate{RouteConfigName: testEDSName + "2", Raw: &anypb.Any{}} + client.NewListeners(map[string]ListenerUpdateErrTuple{testLDSName + "2": {Update: wantUpdate2}}, UpdateMetadata{}) + + // Watcher 1 should not see an update. + sCtx, sCancel = context.WithTimeout(ctx, defaultTestShortTimeout) + defer sCancel() + if u, err := ldsUpdateCh1.Receive(sCtx); err != context.DeadlineExceeded { + t.Errorf("unexpected ListenerUpdate: %v, want receiving from channel timeout", u) + } + + // Watcher 2 should get the update. + if err := verifyListenerUpdate(ctx, ldsUpdateCh2, wantUpdate2, nil); err != nil { + t.Fatal(err) + } +} + +// TestListenerWatchNACKError covers the case that an update is NACK'ed, and the +// watcher should also receive the error. +func (s) TestListenerWatchNACKError(t *testing.T) { + apiClientCh, cleanup := overrideNewAPIClient() + defer cleanup() + + client, err := newWithConfig(clientOpts(testXDSServer, false)) + if err != nil { + t.Fatalf("failed to create client: %v", err) + } + defer client.Close() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + c, err := apiClientCh.Receive(ctx) + if err != nil { + t.Fatalf("timeout when waiting for API client to be created: %v", err) + } + apiClient := c.(*testAPIClient) + + ldsUpdateCh := testutils.NewChannel() + cancelWatch := client.WatchListener(testLDSName, func(update ListenerUpdate, err error) { + ldsUpdateCh.Send(ListenerUpdateErrTuple{Update: update, Err: err}) + }) + defer cancelWatch() + if _, err := apiClient.addWatches[ListenerResource].Receive(ctx); err != nil { + t.Fatalf("want new watch to start, got error %v", err) + } + + wantError := fmt.Errorf("testing error") + client.NewListeners(map[string]ListenerUpdateErrTuple{testLDSName: {Err: wantError}}, UpdateMetadata{ErrState: &UpdateErrorMetadata{Err: wantError}}) + if err := verifyListenerUpdate(ctx, ldsUpdateCh, ListenerUpdate{}, wantError); err != nil { + t.Fatal(err) + } +} + +// TestListenerWatchPartialValid covers the case that a response contains both +// valid and invalid resources. This response will be NACK'ed by the xdsclient. +// But the watchers with valid resources should receive the update, those with +// invalida resources should receive an error. +func (s) TestListenerWatchPartialValid(t *testing.T) { + apiClientCh, cleanup := overrideNewAPIClient() + defer cleanup() + + client, err := newWithConfig(clientOpts(testXDSServer, false)) + if err != nil { + t.Fatalf("failed to create client: %v", err) + } + defer client.Close() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + c, err := apiClientCh.Receive(ctx) + if err != nil { + t.Fatalf("timeout when waiting for API client to be created: %v", err) + } + apiClient := c.(*testAPIClient) + + const badResourceName = "bad-resource" + updateChs := make(map[string]*testutils.Channel) + + for _, name := range []string{testLDSName, badResourceName} { + ldsUpdateCh := testutils.NewChannel() + cancelWatch := client.WatchListener(name, func(update ListenerUpdate, err error) { + ldsUpdateCh.Send(ListenerUpdateErrTuple{Update: update, Err: err}) + }) + defer func() { + cancelWatch() + if _, err := apiClient.removeWatches[ListenerResource].Receive(ctx); err != nil { + t.Fatalf("want watch to be canceled, got err: %v", err) + } + }() + if _, err := apiClient.addWatches[ListenerResource].Receive(ctx); err != nil { + t.Fatalf("want new watch to start, got error %v", err) + } + updateChs[name] = ldsUpdateCh + } + + wantError := fmt.Errorf("testing error") + wantError2 := fmt.Errorf("individual error") + client.NewListeners(map[string]ListenerUpdateErrTuple{ + testLDSName: {Update: ListenerUpdate{RouteConfigName: testEDSName}}, + badResourceName: {Err: wantError2}, + }, UpdateMetadata{ErrState: &UpdateErrorMetadata{Err: wantError}}) + + // The valid resource should be sent to the watcher. + if err := verifyListenerUpdate(ctx, updateChs[testLDSName], ListenerUpdate{RouteConfigName: testEDSName}, nil); err != nil { + t.Fatal(err) + } + + // The failed watcher should receive an error. + if err := verifyListenerUpdate(ctx, updateChs[badResourceName], ListenerUpdate{}, wantError2); err != nil { + t.Fatal(err) + } +} + +// TestListenerWatch_RedundantUpdateSupression tests scenarios where an update +// with an unmodified resource is suppressed, and modified resource is not. +func (s) TestListenerWatch_RedundantUpdateSupression(t *testing.T) { + apiClientCh, cleanup := overrideNewAPIClient() + defer cleanup() + + client, err := newWithConfig(clientOpts(testXDSServer, false)) + if err != nil { + t.Fatalf("failed to create client: %v", err) + } + defer client.Close() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + c, err := apiClientCh.Receive(ctx) + if err != nil { + t.Fatalf("timeout when waiting for API client to be created: %v", err) + } + apiClient := c.(*testAPIClient) + + ldsUpdateCh := testutils.NewChannel() + client.WatchListener(testLDSName, func(update ListenerUpdate, err error) { + ldsUpdateCh.Send(ListenerUpdateErrTuple{Update: update, Err: err}) + }) + if _, err := apiClient.addWatches[ListenerResource].Receive(ctx); err != nil { + t.Fatalf("want new watch to start, got error %v", err) + } + + basicListener := testutils.MarshalAny(&v3listenerpb.Listener{ + Name: testLDSName, + ApiListener: &v3listenerpb.ApiListener{ + ApiListener: testutils.MarshalAny(&v3httppb.HttpConnectionManager{ + RouteSpecifier: &v3httppb.HttpConnectionManager_Rds{ + Rds: &v3httppb.Rds{RouteConfigName: "route-config-name"}, + }, + }), + }, + }) + listenerWithFilter1 := testutils.MarshalAny(&v3listenerpb.Listener{ + Name: testLDSName, + ApiListener: &v3listenerpb.ApiListener{ + ApiListener: testutils.MarshalAny(&v3httppb.HttpConnectionManager{ + RouteSpecifier: &v3httppb.HttpConnectionManager_Rds{ + Rds: &v3httppb.Rds{RouteConfigName: "route-config-name"}, + }, + HttpFilters: []*v3httppb.HttpFilter{ + { + Name: "customFilter1", + ConfigType: &v3httppb.HttpFilter_TypedConfig{TypedConfig: customFilterConfig}, + }, + }, + }), + }, + }) + listenerWithFilter2 := testutils.MarshalAny(&v3listenerpb.Listener{ + Name: testLDSName, + ApiListener: &v3listenerpb.ApiListener{ + ApiListener: testutils.MarshalAny(&v3httppb.HttpConnectionManager{ + RouteSpecifier: &v3httppb.HttpConnectionManager_Rds{ + Rds: &v3httppb.Rds{RouteConfigName: "route-config-name"}, + }, + HttpFilters: []*v3httppb.HttpFilter{ + { + Name: "customFilter2", + ConfigType: &v3httppb.HttpFilter_TypedConfig{TypedConfig: customFilterConfig}, + }, + }, + }), + }, + }) + + tests := []struct { + update ListenerUpdate + wantCallback bool + }{ + { + // First update. Callback should be invoked. + update: ListenerUpdate{Raw: basicListener}, + wantCallback: true, + }, + { + // Same update as previous. Callback should be skipped. + update: ListenerUpdate{Raw: basicListener}, + wantCallback: false, + }, + { + // New update. Callback should be invoked. + update: ListenerUpdate{Raw: listenerWithFilter1}, + wantCallback: true, + }, + { + // Same update as previous. Callback should be skipped. + update: ListenerUpdate{Raw: listenerWithFilter1}, + wantCallback: false, + }, + { + // New update. Callback should be invoked. + update: ListenerUpdate{Raw: listenerWithFilter2}, + wantCallback: true, + }, + { + // Same update as previous. Callback should be skipped. + update: ListenerUpdate{Raw: listenerWithFilter2}, + wantCallback: false, + }, + } + for _, test := range tests { + client.NewListeners(map[string]ListenerUpdateErrTuple{testLDSName: {Update: test.update}}, UpdateMetadata{}) + if test.wantCallback { + if err := verifyListenerUpdate(ctx, ldsUpdateCh, test.update, nil); err != nil { + t.Fatal(err) + } + } else { + sCtx, sCancel := context.WithTimeout(ctx, defaultTestShortTimeout) + defer sCancel() + if u, err := ldsUpdateCh.Receive(sCtx); err != context.DeadlineExceeded { + t.Errorf("unexpected ListenerUpdate: %v, want receiving from channel timeout", u) + } + } + } +} diff --git a/xds/internal/client/watchers_route_test.go b/xds/internal/xdsclient/watchers_route_test.go similarity index 54% rename from xds/internal/client/watchers_route_test.go rename to xds/internal/xdsclient/watchers_route_test.go index 5f44e549333..70c8dd829e9 100644 --- a/xds/internal/client/watchers_route_test.go +++ b/xds/internal/xdsclient/watchers_route_test.go @@ -16,22 +16,19 @@ * */ -package client +package xdsclient import ( "context" + "fmt" "testing" "github.com/google/go-cmp/cmp" + "google.golang.org/protobuf/types/known/anypb" "google.golang.org/grpc/internal/testutils" ) -type rdsUpdateErr struct { - u RouteConfigUpdate - err error -} - // TestRDSWatch covers the cases: // - an update is received after a watch() // - an update for another resource name (which doesn't trigger callback) @@ -56,7 +53,7 @@ func (s) TestRDSWatch(t *testing.T) { rdsUpdateCh := testutils.NewChannel() cancelWatch := client.WatchRouteConfig(testRDSName, func(update RouteConfigUpdate, err error) { - rdsUpdateCh.Send(rdsUpdateErr{u: update, err: err}) + rdsUpdateCh.Send(RouteConfigUpdateErrTuple{Update: update, Err: err}) }) if _, err := apiClient.addWatches[RouteConfigResource].Receive(ctx); err != nil { t.Fatalf("want new watch to start, got error %v", err) @@ -70,23 +67,28 @@ func (s) TestRDSWatch(t *testing.T) { }, }, } - client.NewRouteConfigs(map[string]RouteConfigUpdate{testRDSName: wantUpdate}, UpdateMetadata{}) - if err := verifyRouteConfigUpdate(ctx, rdsUpdateCh, wantUpdate); err != nil { + client.NewRouteConfigs(map[string]RouteConfigUpdateErrTuple{testRDSName: {Update: wantUpdate}}, UpdateMetadata{}) + if err := verifyRouteConfigUpdate(ctx, rdsUpdateCh, wantUpdate, nil); err != nil { t.Fatal(err) } - // Another update for a different resource name. - client.NewRouteConfigs(map[string]RouteConfigUpdate{"randomName": {}}, UpdateMetadata{}) - sCtx, sCancel := context.WithTimeout(ctx, defaultTestShortTimeout) - defer sCancel() - if u, err := rdsUpdateCh.Receive(sCtx); err != context.DeadlineExceeded { - t.Errorf("unexpected RouteConfigUpdate: %v, %v, want channel recv timeout", u, err) + // Push an update, with an extra resource for a different resource name. + // Specify a non-nil raw proto in the original resource to ensure that the + // new update is not considered equal to the old one. + newUpdate := wantUpdate + newUpdate.Raw = &anypb.Any{} + client.NewRouteConfigs(map[string]RouteConfigUpdateErrTuple{ + testRDSName: {Update: newUpdate}, + "randomName": {}, + }, UpdateMetadata{}) + if err := verifyRouteConfigUpdate(ctx, rdsUpdateCh, newUpdate, nil); err != nil { + t.Fatal(err) } // Cancel watch, and send update again. cancelWatch() - client.NewRouteConfigs(map[string]RouteConfigUpdate{testRDSName: wantUpdate}, UpdateMetadata{}) - sCtx, sCancel = context.WithTimeout(ctx, defaultTestShortTimeout) + client.NewRouteConfigs(map[string]RouteConfigUpdateErrTuple{testRDSName: {Update: wantUpdate}}, UpdateMetadata{}) + sCtx, sCancel := context.WithTimeout(ctx, defaultTestShortTimeout) defer sCancel() if u, err := rdsUpdateCh.Receive(sCtx); err != context.DeadlineExceeded { t.Errorf("unexpected RouteConfigUpdate: %v, %v, want channel recv timeout", u, err) @@ -122,7 +124,7 @@ func (s) TestRDSTwoWatchSameResourceName(t *testing.T) { rdsUpdateCh := testutils.NewChannel() rdsUpdateChs = append(rdsUpdateChs, rdsUpdateCh) cancelLastWatch = client.WatchRouteConfig(testRDSName, func(update RouteConfigUpdate, err error) { - rdsUpdateCh.Send(rdsUpdateErr{u: update, err: err}) + rdsUpdateCh.Send(RouteConfigUpdateErrTuple{Update: update, Err: err}) }) if i == 0 { @@ -142,26 +144,36 @@ func (s) TestRDSTwoWatchSameResourceName(t *testing.T) { }, }, } - client.NewRouteConfigs(map[string]RouteConfigUpdate{testRDSName: wantUpdate}, UpdateMetadata{}) + client.NewRouteConfigs(map[string]RouteConfigUpdateErrTuple{testRDSName: {Update: wantUpdate}}, UpdateMetadata{}) for i := 0; i < count; i++ { - if err := verifyRouteConfigUpdate(ctx, rdsUpdateChs[i], wantUpdate); err != nil { + if err := verifyRouteConfigUpdate(ctx, rdsUpdateChs[i], wantUpdate, nil); err != nil { t.Fatal(err) } } - // Cancel the last watch, and send update again. + // Cancel the last watch, and send update again. None of the watchers should + // be notified because one has been cancelled, and the other is receiving + // the same update. cancelLastWatch() - client.NewRouteConfigs(map[string]RouteConfigUpdate{testRDSName: wantUpdate}, UpdateMetadata{}) - for i := 0; i < count-1; i++ { - if err := verifyRouteConfigUpdate(ctx, rdsUpdateChs[i], wantUpdate); err != nil { - t.Fatal(err) - } + client.NewRouteConfigs(map[string]RouteConfigUpdateErrTuple{testRDSName: {Update: wantUpdate}}, UpdateMetadata{}) + for i := 0; i < count; i++ { + func() { + sCtx, sCancel := context.WithTimeout(ctx, defaultTestShortTimeout) + defer sCancel() + if u, err := rdsUpdateChs[i].Receive(sCtx); err != context.DeadlineExceeded { + t.Errorf("unexpected RouteConfigUpdate: %v, %v, want channel recv timeout", u, err) + } + }() } - sCtx, sCancel := context.WithTimeout(ctx, defaultTestShortTimeout) - defer sCancel() - if u, err := rdsUpdateChs[count-1].Receive(sCtx); err != context.DeadlineExceeded { - t.Errorf("unexpected RouteConfigUpdate: %v, %v, want channel recv timeout", u, err) + // Push a new update and make sure the uncancelled watcher is invoked. + // Specify a non-nil raw proto to ensure that the new update is not + // considered equal to the old one. + newUpdate := wantUpdate + newUpdate.Raw = &anypb.Any{} + client.NewRouteConfigs(map[string]RouteConfigUpdateErrTuple{testRDSName: {Update: newUpdate}}, UpdateMetadata{}) + if err := verifyRouteConfigUpdate(ctx, rdsUpdateChs[0], newUpdate, nil); err != nil { + t.Fatal(err) } } @@ -192,7 +204,7 @@ func (s) TestRDSThreeWatchDifferentResourceName(t *testing.T) { rdsUpdateCh := testutils.NewChannel() rdsUpdateChs = append(rdsUpdateChs, rdsUpdateCh) client.WatchRouteConfig(testRDSName+"1", func(update RouteConfigUpdate, err error) { - rdsUpdateCh.Send(rdsUpdateErr{u: update, err: err}) + rdsUpdateCh.Send(RouteConfigUpdateErrTuple{Update: update, Err: err}) }) if i == 0 { @@ -207,7 +219,7 @@ func (s) TestRDSThreeWatchDifferentResourceName(t *testing.T) { // Third watch for a different name. rdsUpdateCh2 := testutils.NewChannel() client.WatchRouteConfig(testRDSName+"2", func(update RouteConfigUpdate, err error) { - rdsUpdateCh2.Send(rdsUpdateErr{u: update, err: err}) + rdsUpdateCh2.Send(RouteConfigUpdateErrTuple{Update: update, Err: err}) }) if _, err := apiClient.addWatches[RouteConfigResource].Receive(ctx); err != nil { t.Fatalf("want new watch to start, got error %v", err) @@ -229,17 +241,17 @@ func (s) TestRDSThreeWatchDifferentResourceName(t *testing.T) { }, }, } - client.NewRouteConfigs(map[string]RouteConfigUpdate{ - testRDSName + "1": wantUpdate1, - testRDSName + "2": wantUpdate2, + client.NewRouteConfigs(map[string]RouteConfigUpdateErrTuple{ + testRDSName + "1": {Update: wantUpdate1}, + testRDSName + "2": {Update: wantUpdate2}, }, UpdateMetadata{}) for i := 0; i < count; i++ { - if err := verifyRouteConfigUpdate(ctx, rdsUpdateChs[i], wantUpdate1); err != nil { + if err := verifyRouteConfigUpdate(ctx, rdsUpdateChs[i], wantUpdate1, nil); err != nil { t.Fatal(err) } } - if err := verifyRouteConfigUpdate(ctx, rdsUpdateCh2, wantUpdate2); err != nil { + if err := verifyRouteConfigUpdate(ctx, rdsUpdateCh2, wantUpdate2, nil); err != nil { t.Fatal(err) } } @@ -266,7 +278,7 @@ func (s) TestRDSWatchAfterCache(t *testing.T) { rdsUpdateCh := testutils.NewChannel() client.WatchRouteConfig(testRDSName, func(update RouteConfigUpdate, err error) { - rdsUpdateCh.Send(rdsUpdateErr{u: update, err: err}) + rdsUpdateCh.Send(RouteConfigUpdateErrTuple{Update: update, Err: err}) }) if _, err := apiClient.addWatches[RouteConfigResource].Receive(ctx); err != nil { t.Fatalf("want new watch to start, got error %v", err) @@ -280,15 +292,15 @@ func (s) TestRDSWatchAfterCache(t *testing.T) { }, }, } - client.NewRouteConfigs(map[string]RouteConfigUpdate{testRDSName: wantUpdate}, UpdateMetadata{}) - if err := verifyRouteConfigUpdate(ctx, rdsUpdateCh, wantUpdate); err != nil { + client.NewRouteConfigs(map[string]RouteConfigUpdateErrTuple{testRDSName: {Update: wantUpdate}}, UpdateMetadata{}) + if err := verifyRouteConfigUpdate(ctx, rdsUpdateCh, wantUpdate, nil); err != nil { t.Fatal(err) } // Another watch for the resource in cache. rdsUpdateCh2 := testutils.NewChannel() client.WatchRouteConfig(testRDSName, func(update RouteConfigUpdate, err error) { - rdsUpdateCh2.Send(rdsUpdateErr{u: update, err: err}) + rdsUpdateCh2.Send(RouteConfigUpdateErrTuple{Update: update, Err: err}) }) sCtx, sCancel := context.WithTimeout(ctx, defaultTestShortTimeout) defer sCancel() @@ -297,7 +309,7 @@ func (s) TestRDSWatchAfterCache(t *testing.T) { } // New watch should receives the update. - if u, err := rdsUpdateCh2.Receive(ctx); err != nil || !cmp.Equal(u, rdsUpdateErr{wantUpdate, nil}, cmp.AllowUnexported(rdsUpdateErr{})) { + if u, err := rdsUpdateCh2.Receive(ctx); err != nil || !cmp.Equal(u, RouteConfigUpdateErrTuple{wantUpdate, nil}, cmp.AllowUnexported(RouteConfigUpdateErrTuple{})) { t.Errorf("unexpected RouteConfigUpdate: %v, error receiving from channel: %v", u, err) } @@ -308,3 +320,105 @@ func (s) TestRDSWatchAfterCache(t *testing.T) { t.Errorf("unexpected RouteConfigUpdate: %v, %v, want channel recv timeout", u, err) } } + +// TestRouteWatchNACKError covers the case that an update is NACK'ed, and the +// watcher should also receive the error. +func (s) TestRouteWatchNACKError(t *testing.T) { + apiClientCh, cleanup := overrideNewAPIClient() + defer cleanup() + + client, err := newWithConfig(clientOpts(testXDSServer, false)) + if err != nil { + t.Fatalf("failed to create client: %v", err) + } + defer client.Close() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + c, err := apiClientCh.Receive(ctx) + if err != nil { + t.Fatalf("timeout when waiting for API client to be created: %v", err) + } + apiClient := c.(*testAPIClient) + + rdsUpdateCh := testutils.NewChannel() + cancelWatch := client.WatchRouteConfig(testCDSName, func(update RouteConfigUpdate, err error) { + rdsUpdateCh.Send(RouteConfigUpdateErrTuple{Update: update, Err: err}) + }) + defer cancelWatch() + if _, err := apiClient.addWatches[RouteConfigResource].Receive(ctx); err != nil { + t.Fatalf("want new watch to start, got error %v", err) + } + + wantError := fmt.Errorf("testing error") + client.NewRouteConfigs(map[string]RouteConfigUpdateErrTuple{testCDSName: {Err: wantError}}, UpdateMetadata{ErrState: &UpdateErrorMetadata{Err: wantError}}) + if err := verifyRouteConfigUpdate(ctx, rdsUpdateCh, RouteConfigUpdate{}, wantError); err != nil { + t.Fatal(err) + } +} + +// TestRouteWatchPartialValid covers the case that a response contains both +// valid and invalid resources. This response will be NACK'ed by the xdsclient. +// But the watchers with valid resources should receive the update, those with +// invalida resources should receive an error. +func (s) TestRouteWatchPartialValid(t *testing.T) { + apiClientCh, cleanup := overrideNewAPIClient() + defer cleanup() + + client, err := newWithConfig(clientOpts(testXDSServer, false)) + if err != nil { + t.Fatalf("failed to create client: %v", err) + } + defer client.Close() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + c, err := apiClientCh.Receive(ctx) + if err != nil { + t.Fatalf("timeout when waiting for API client to be created: %v", err) + } + apiClient := c.(*testAPIClient) + + const badResourceName = "bad-resource" + updateChs := make(map[string]*testutils.Channel) + + for _, name := range []string{testRDSName, badResourceName} { + rdsUpdateCh := testutils.NewChannel() + cancelWatch := client.WatchRouteConfig(name, func(update RouteConfigUpdate, err error) { + rdsUpdateCh.Send(RouteConfigUpdateErrTuple{Update: update, Err: err}) + }) + defer func() { + cancelWatch() + if _, err := apiClient.removeWatches[RouteConfigResource].Receive(ctx); err != nil { + t.Fatalf("want watch to be canceled, got err: %v", err) + } + }() + if _, err := apiClient.addWatches[RouteConfigResource].Receive(ctx); err != nil { + t.Fatalf("want new watch to start, got error %v", err) + } + updateChs[name] = rdsUpdateCh + } + + wantError := fmt.Errorf("testing error") + wantError2 := fmt.Errorf("individual error") + client.NewRouteConfigs(map[string]RouteConfigUpdateErrTuple{ + testRDSName: {Update: RouteConfigUpdate{VirtualHosts: []*VirtualHost{{ + Domains: []string{testLDSName}, + Routes: []*Route{{Prefix: newStringP(""), WeightedClusters: map[string]WeightedCluster{testCDSName: {Weight: 1}}}}, + }}}}, + badResourceName: {Err: wantError2}, + }, UpdateMetadata{ErrState: &UpdateErrorMetadata{Err: wantError}}) + + // The valid resource should be sent to the watcher. + if err := verifyRouteConfigUpdate(ctx, updateChs[testRDSName], RouteConfigUpdate{VirtualHosts: []*VirtualHost{{ + Domains: []string{testLDSName}, + Routes: []*Route{{Prefix: newStringP(""), WeightedClusters: map[string]WeightedCluster{testCDSName: {Weight: 1}}}}, + }}}, nil); err != nil { + t.Fatal(err) + } + + // The failed watcher should receive an error. + if err := verifyRouteConfigUpdate(ctx, updateChs[badResourceName], RouteConfigUpdate{}, wantError2); err != nil { + t.Fatal(err) + } +} diff --git a/xds/internal/xdsclient/xds.go b/xds/internal/xdsclient/xds.go new file mode 100644 index 00000000000..732c4e6addc --- /dev/null +++ b/xds/internal/xdsclient/xds.go @@ -0,0 +1,1334 @@ +/* + * + * Copyright 2020 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package xdsclient + +import ( + "errors" + "fmt" + "net" + "regexp" + "strconv" + "strings" + "time" + + v1typepb "github.com/cncf/udpa/go/udpa/type/v1" + v3clusterpb "github.com/envoyproxy/go-control-plane/envoy/config/cluster/v3" + v3corepb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" + v3endpointpb "github.com/envoyproxy/go-control-plane/envoy/config/endpoint/v3" + v3listenerpb "github.com/envoyproxy/go-control-plane/envoy/config/listener/v3" + v3routepb "github.com/envoyproxy/go-control-plane/envoy/config/route/v3" + v3aggregateclusterpb "github.com/envoyproxy/go-control-plane/envoy/extensions/clusters/aggregate/v3" + v3httppb "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/network/http_connection_manager/v3" + v3tlspb "github.com/envoyproxy/go-control-plane/envoy/extensions/transport_sockets/tls/v3" + v3typepb "github.com/envoyproxy/go-control-plane/envoy/type/v3" + "github.com/golang/protobuf/proto" + "github.com/golang/protobuf/ptypes" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/internal/pretty" + "google.golang.org/grpc/internal/xds/matcher" + "google.golang.org/protobuf/types/known/anypb" + + "google.golang.org/grpc/internal/grpclog" + "google.golang.org/grpc/internal/xds/env" + "google.golang.org/grpc/xds/internal" + "google.golang.org/grpc/xds/internal/httpfilter" + "google.golang.org/grpc/xds/internal/version" +) + +// TransportSocket proto message has a `name` field which is expected to be set +// to this value by the management server. +const transportSocketName = "envoy.transport_sockets.tls" + +// UnmarshalOptions wraps the input parameters for `UnmarshalXxx` functions. +type UnmarshalOptions struct { + // Version is the version of the received response. + Version string + // Resources are the xDS resources resources in the received response. + Resources []*anypb.Any + // Logger is the prefix logger to be used during unmarshaling. + Logger *grpclog.PrefixLogger + // UpdateValidator is a post unmarshal validation check provided by the + // upper layer. + UpdateValidator UpdateValidatorFunc +} + +// UnmarshalListener processes resources received in an LDS response, validates +// them, and transforms them into a native struct which contains only fields we +// are interested in. +func UnmarshalListener(opts *UnmarshalOptions) (map[string]ListenerUpdateErrTuple, UpdateMetadata, error) { + update := make(map[string]ListenerUpdateErrTuple) + md, err := processAllResources(opts, update) + return update, md, err +} + +func unmarshalListenerResource(r *anypb.Any, f UpdateValidatorFunc, logger *grpclog.PrefixLogger) (string, ListenerUpdate, error) { + if !IsListenerResource(r.GetTypeUrl()) { + return "", ListenerUpdate{}, fmt.Errorf("unexpected resource type: %q ", r.GetTypeUrl()) + } + // TODO: Pass version.TransportAPI instead of relying upon the type URL + v2 := r.GetTypeUrl() == version.V2ListenerURL + lis := &v3listenerpb.Listener{} + if err := proto.Unmarshal(r.GetValue(), lis); err != nil { + return "", ListenerUpdate{}, fmt.Errorf("failed to unmarshal resource: %v", err) + } + logger.Infof("Resource with name: %v, type: %T, contains: %v", lis.GetName(), lis, pretty.ToJSON(lis)) + + lu, err := processListener(lis, logger, v2) + if err != nil { + return lis.GetName(), ListenerUpdate{}, err + } + if f != nil { + if err := f(*lu); err != nil { + return lis.GetName(), ListenerUpdate{}, err + } + } + lu.Raw = r + return lis.GetName(), *lu, nil +} + +func processListener(lis *v3listenerpb.Listener, logger *grpclog.PrefixLogger, v2 bool) (*ListenerUpdate, error) { + if lis.GetApiListener() != nil { + return processClientSideListener(lis, logger, v2) + } + return processServerSideListener(lis) +} + +// processClientSideListener checks if the provided Listener proto meets +// the expected criteria. If so, it returns a non-empty routeConfigName. +func processClientSideListener(lis *v3listenerpb.Listener, logger *grpclog.PrefixLogger, v2 bool) (*ListenerUpdate, error) { + update := &ListenerUpdate{} + + apiLisAny := lis.GetApiListener().GetApiListener() + if !IsHTTPConnManagerResource(apiLisAny.GetTypeUrl()) { + return nil, fmt.Errorf("unexpected resource type: %q", apiLisAny.GetTypeUrl()) + } + apiLis := &v3httppb.HttpConnectionManager{} + if err := proto.Unmarshal(apiLisAny.GetValue(), apiLis); err != nil { + return nil, fmt.Errorf("failed to unmarshal api_listner: %v", err) + } + // "HttpConnectionManager.xff_num_trusted_hops must be unset or zero and + // HttpConnectionManager.original_ip_detection_extensions must be empty. If + // either field has an incorrect value, the Listener must be NACKed." - A41 + if apiLis.XffNumTrustedHops != 0 { + return nil, fmt.Errorf("xff_num_trusted_hops must be unset or zero %+v", apiLis) + } + if len(apiLis.OriginalIpDetectionExtensions) != 0 { + return nil, fmt.Errorf("original_ip_detection_extensions must be empty %+v", apiLis) + } + + switch apiLis.RouteSpecifier.(type) { + case *v3httppb.HttpConnectionManager_Rds: + if apiLis.GetRds().GetConfigSource().GetAds() == nil { + return nil, fmt.Errorf("ConfigSource is not ADS: %+v", lis) + } + name := apiLis.GetRds().GetRouteConfigName() + if name == "" { + return nil, fmt.Errorf("empty route_config_name: %+v", lis) + } + update.RouteConfigName = name + case *v3httppb.HttpConnectionManager_RouteConfig: + routeU, err := generateRDSUpdateFromRouteConfiguration(apiLis.GetRouteConfig(), logger, v2) + if err != nil { + return nil, fmt.Errorf("failed to parse inline RDS resp: %v", err) + } + update.InlineRouteConfig = &routeU + case nil: + return nil, fmt.Errorf("no RouteSpecifier: %+v", apiLis) + default: + return nil, fmt.Errorf("unsupported type %T for RouteSpecifier", apiLis.RouteSpecifier) + } + + if v2 { + return update, nil + } + + // The following checks and fields only apply to xDS protocol versions v3+. + + update.MaxStreamDuration = apiLis.GetCommonHttpProtocolOptions().GetMaxStreamDuration().AsDuration() + + var err error + if update.HTTPFilters, err = processHTTPFilters(apiLis.GetHttpFilters(), false); err != nil { + return nil, err + } + + return update, nil +} + +func unwrapHTTPFilterConfig(config *anypb.Any) (proto.Message, string, error) { + // The real type name is inside the TypedStruct. + s := new(v1typepb.TypedStruct) + if !ptypes.Is(config, s) { + return config, config.GetTypeUrl(), nil + } + if err := ptypes.UnmarshalAny(config, s); err != nil { + return nil, "", fmt.Errorf("error unmarshalling TypedStruct filter config: %v", err) + } + return s, s.GetTypeUrl(), nil +} + +func validateHTTPFilterConfig(cfg *anypb.Any, lds, optional bool) (httpfilter.Filter, httpfilter.FilterConfig, error) { + config, typeURL, err := unwrapHTTPFilterConfig(cfg) + if err != nil { + return nil, nil, err + } + filterBuilder := httpfilter.Get(typeURL) + if filterBuilder == nil { + if optional { + return nil, nil, nil + } + return nil, nil, fmt.Errorf("no filter implementation found for %q", typeURL) + } + parseFunc := filterBuilder.ParseFilterConfig + if !lds { + parseFunc = filterBuilder.ParseFilterConfigOverride + } + filterConfig, err := parseFunc(config) + if err != nil { + return nil, nil, fmt.Errorf("error parsing config for filter %q: %v", typeURL, err) + } + return filterBuilder, filterConfig, nil +} + +func processHTTPFilterOverrides(cfgs map[string]*anypb.Any) (map[string]httpfilter.FilterConfig, error) { + if len(cfgs) == 0 { + return nil, nil + } + m := make(map[string]httpfilter.FilterConfig) + for name, cfg := range cfgs { + optional := false + s := new(v3routepb.FilterConfig) + if ptypes.Is(cfg, s) { + if err := ptypes.UnmarshalAny(cfg, s); err != nil { + return nil, fmt.Errorf("filter override %q: error unmarshalling FilterConfig: %v", name, err) + } + cfg = s.GetConfig() + optional = s.GetIsOptional() + } + + httpFilter, config, err := validateHTTPFilterConfig(cfg, false, optional) + if err != nil { + return nil, fmt.Errorf("filter override %q: %v", name, err) + } + if httpFilter == nil { + // Optional configs are ignored. + continue + } + m[name] = config + } + return m, nil +} + +func processHTTPFilters(filters []*v3httppb.HttpFilter, server bool) ([]HTTPFilter, error) { + ret := make([]HTTPFilter, 0, len(filters)) + seenNames := make(map[string]bool, len(filters)) + for _, filter := range filters { + name := filter.GetName() + if name == "" { + return nil, errors.New("filter missing name field") + } + if seenNames[name] { + return nil, fmt.Errorf("duplicate filter name %q", name) + } + seenNames[name] = true + + httpFilter, config, err := validateHTTPFilterConfig(filter.GetTypedConfig(), true, filter.GetIsOptional()) + if err != nil { + return nil, err + } + if httpFilter == nil { + // Optional configs are ignored. + continue + } + if server { + if _, ok := httpFilter.(httpfilter.ServerInterceptorBuilder); !ok { + if filter.GetIsOptional() { + continue + } + return nil, fmt.Errorf("HTTP filter %q not supported server-side", name) + } + } else if _, ok := httpFilter.(httpfilter.ClientInterceptorBuilder); !ok { + if filter.GetIsOptional() { + continue + } + return nil, fmt.Errorf("HTTP filter %q not supported client-side", name) + } + + // Save name/config + ret = append(ret, HTTPFilter{Name: name, Filter: httpFilter, Config: config}) + } + // "Validation will fail if a terminal filter is not the last filter in the + // chain or if a non-terminal filter is the last filter in the chain." - A39 + if len(ret) == 0 { + return nil, fmt.Errorf("http filters list is empty") + } + var i int + for ; i < len(ret)-1; i++ { + if ret[i].Filter.IsTerminal() { + return nil, fmt.Errorf("http filter %q is a terminal filter but it is not last in the filter chain", ret[i].Name) + } + } + if !ret[i].Filter.IsTerminal() { + return nil, fmt.Errorf("http filter %q is not a terminal filter", ret[len(ret)-1].Name) + } + return ret, nil +} + +func processServerSideListener(lis *v3listenerpb.Listener) (*ListenerUpdate, error) { + if n := len(lis.ListenerFilters); n != 0 { + return nil, fmt.Errorf("unsupported field 'listener_filters' contains %d entries", n) + } + if useOrigDst := lis.GetUseOriginalDst(); useOrigDst != nil && useOrigDst.GetValue() { + return nil, errors.New("unsupported field 'use_original_dst' is present and set to true") + } + addr := lis.GetAddress() + if addr == nil { + return nil, fmt.Errorf("no address field in LDS response: %+v", lis) + } + sockAddr := addr.GetSocketAddress() + if sockAddr == nil { + return nil, fmt.Errorf("no socket_address field in LDS response: %+v", lis) + } + lu := &ListenerUpdate{ + InboundListenerCfg: &InboundListenerConfig{ + Address: sockAddr.GetAddress(), + Port: strconv.Itoa(int(sockAddr.GetPortValue())), + }, + } + + fcMgr, err := NewFilterChainManager(lis) + if err != nil { + return nil, err + } + lu.InboundListenerCfg.FilterChains = fcMgr + return lu, nil +} + +// UnmarshalRouteConfig processes resources received in an RDS response, +// validates them, and transforms them into a native struct which contains only +// fields we are interested in. The provided hostname determines the route +// configuration resources of interest. +func UnmarshalRouteConfig(opts *UnmarshalOptions) (map[string]RouteConfigUpdateErrTuple, UpdateMetadata, error) { + update := make(map[string]RouteConfigUpdateErrTuple) + md, err := processAllResources(opts, update) + return update, md, err +} + +func unmarshalRouteConfigResource(r *anypb.Any, logger *grpclog.PrefixLogger) (string, RouteConfigUpdate, error) { + if !IsRouteConfigResource(r.GetTypeUrl()) { + return "", RouteConfigUpdate{}, fmt.Errorf("unexpected resource type: %q ", r.GetTypeUrl()) + } + rc := &v3routepb.RouteConfiguration{} + if err := proto.Unmarshal(r.GetValue(), rc); err != nil { + return "", RouteConfigUpdate{}, fmt.Errorf("failed to unmarshal resource: %v", err) + } + logger.Infof("Resource with name: %v, type: %T, contains: %v.", rc.GetName(), rc, pretty.ToJSON(rc)) + + // TODO: Pass version.TransportAPI instead of relying upon the type URL + v2 := r.GetTypeUrl() == version.V2RouteConfigURL + u, err := generateRDSUpdateFromRouteConfiguration(rc, logger, v2) + if err != nil { + return rc.GetName(), RouteConfigUpdate{}, err + } + u.Raw = r + return rc.GetName(), u, nil +} + +// generateRDSUpdateFromRouteConfiguration checks if the provided +// RouteConfiguration meets the expected criteria. If so, it returns a +// RouteConfigUpdate with nil error. +// +// A RouteConfiguration resource is considered valid when only if it contains a +// VirtualHost whose domain field matches the server name from the URI passed +// to the gRPC channel, and it contains a clusterName or a weighted cluster. +// +// The RouteConfiguration includes a list of virtualHosts, which may have zero +// or more elements. We are interested in the element whose domains field +// matches the server name specified in the "xds:" URI. The only field in the +// VirtualHost proto that the we are interested in is the list of routes. We +// only look at the last route in the list (the default route), whose match +// field must be empty and whose route field must be set. Inside that route +// message, the cluster field will contain the clusterName or weighted clusters +// we are looking for. +func generateRDSUpdateFromRouteConfiguration(rc *v3routepb.RouteConfiguration, logger *grpclog.PrefixLogger, v2 bool) (RouteConfigUpdate, error) { + vhs := make([]*VirtualHost, 0, len(rc.GetVirtualHosts())) + for _, vh := range rc.GetVirtualHosts() { + routes, err := routesProtoToSlice(vh.Routes, logger, v2) + if err != nil { + return RouteConfigUpdate{}, fmt.Errorf("received route is invalid: %v", err) + } + rc, err := generateRetryConfig(vh.GetRetryPolicy()) + if err != nil { + return RouteConfigUpdate{}, fmt.Errorf("received route is invalid: %v", err) + } + vhOut := &VirtualHost{ + Domains: vh.GetDomains(), + Routes: routes, + RetryConfig: rc, + } + if !v2 { + cfgs, err := processHTTPFilterOverrides(vh.GetTypedPerFilterConfig()) + if err != nil { + return RouteConfigUpdate{}, fmt.Errorf("virtual host %+v: %v", vh, err) + } + vhOut.HTTPFilterConfigOverride = cfgs + } + vhs = append(vhs, vhOut) + } + return RouteConfigUpdate{VirtualHosts: vhs}, nil +} + +func generateRetryConfig(rp *v3routepb.RetryPolicy) (*RetryConfig, error) { + if !env.RetrySupport || rp == nil { + return nil, nil + } + + cfg := &RetryConfig{RetryOn: make(map[codes.Code]bool)} + for _, s := range strings.Split(rp.GetRetryOn(), ",") { + switch strings.TrimSpace(strings.ToLower(s)) { + case "cancelled": + cfg.RetryOn[codes.Canceled] = true + case "deadline-exceeded": + cfg.RetryOn[codes.DeadlineExceeded] = true + case "internal": + cfg.RetryOn[codes.Internal] = true + case "resource-exhausted": + cfg.RetryOn[codes.ResourceExhausted] = true + case "unavailable": + cfg.RetryOn[codes.Unavailable] = true + } + } + + if rp.NumRetries == nil { + cfg.NumRetries = 1 + } else { + cfg.NumRetries = rp.GetNumRetries().Value + if cfg.NumRetries < 1 { + return nil, fmt.Errorf("retry_policy.num_retries = %v; must be >= 1", cfg.NumRetries) + } + } + + backoff := rp.GetRetryBackOff() + if backoff == nil { + cfg.RetryBackoff.BaseInterval = 25 * time.Millisecond + } else { + cfg.RetryBackoff.BaseInterval = backoff.GetBaseInterval().AsDuration() + if cfg.RetryBackoff.BaseInterval <= 0 { + return nil, fmt.Errorf("retry_policy.base_interval = %v; must be > 0", cfg.RetryBackoff.BaseInterval) + } + } + if max := backoff.GetMaxInterval(); max == nil { + cfg.RetryBackoff.MaxInterval = 10 * cfg.RetryBackoff.BaseInterval + } else { + cfg.RetryBackoff.MaxInterval = max.AsDuration() + if cfg.RetryBackoff.MaxInterval <= 0 { + return nil, fmt.Errorf("retry_policy.max_interval = %v; must be > 0", cfg.RetryBackoff.MaxInterval) + } + } + + if len(cfg.RetryOn) == 0 { + return &RetryConfig{}, nil + } + return cfg, nil +} + +func routesProtoToSlice(routes []*v3routepb.Route, logger *grpclog.PrefixLogger, v2 bool) ([]*Route, error) { + var routesRet []*Route + for _, r := range routes { + match := r.GetMatch() + if match == nil { + return nil, fmt.Errorf("route %+v doesn't have a match", r) + } + + if len(match.GetQueryParameters()) != 0 { + // Ignore route with query parameters. + logger.Warningf("route %+v has query parameter matchers, the route will be ignored", r) + continue + } + + pathSp := match.GetPathSpecifier() + if pathSp == nil { + return nil, fmt.Errorf("route %+v doesn't have a path specifier", r) + } + + var route Route + switch pt := pathSp.(type) { + case *v3routepb.RouteMatch_Prefix: + route.Prefix = &pt.Prefix + case *v3routepb.RouteMatch_Path: + route.Path = &pt.Path + case *v3routepb.RouteMatch_SafeRegex: + regex := pt.SafeRegex.GetRegex() + re, err := regexp.Compile(regex) + if err != nil { + return nil, fmt.Errorf("route %+v contains an invalid regex %q", r, regex) + } + route.Regex = re + default: + return nil, fmt.Errorf("route %+v has an unrecognized path specifier: %+v", r, pt) + } + + if caseSensitive := match.GetCaseSensitive(); caseSensitive != nil { + route.CaseInsensitive = !caseSensitive.Value + } + + for _, h := range match.GetHeaders() { + var header HeaderMatcher + switch ht := h.GetHeaderMatchSpecifier().(type) { + case *v3routepb.HeaderMatcher_ExactMatch: + header.ExactMatch = &ht.ExactMatch + case *v3routepb.HeaderMatcher_SafeRegexMatch: + regex := ht.SafeRegexMatch.GetRegex() + re, err := regexp.Compile(regex) + if err != nil { + return nil, fmt.Errorf("route %+v contains an invalid regex %q", r, regex) + } + header.RegexMatch = re + case *v3routepb.HeaderMatcher_RangeMatch: + header.RangeMatch = &Int64Range{ + Start: ht.RangeMatch.Start, + End: ht.RangeMatch.End, + } + case *v3routepb.HeaderMatcher_PresentMatch: + header.PresentMatch = &ht.PresentMatch + case *v3routepb.HeaderMatcher_PrefixMatch: + header.PrefixMatch = &ht.PrefixMatch + case *v3routepb.HeaderMatcher_SuffixMatch: + header.SuffixMatch = &ht.SuffixMatch + default: + return nil, fmt.Errorf("route %+v has an unrecognized header matcher: %+v", r, ht) + } + header.Name = h.GetName() + invert := h.GetInvertMatch() + header.InvertMatch = &invert + route.Headers = append(route.Headers, &header) + } + + if fr := match.GetRuntimeFraction(); fr != nil { + d := fr.GetDefaultValue() + n := d.GetNumerator() + switch d.GetDenominator() { + case v3typepb.FractionalPercent_HUNDRED: + n *= 10000 + case v3typepb.FractionalPercent_TEN_THOUSAND: + n *= 100 + case v3typepb.FractionalPercent_MILLION: + } + route.Fraction = &n + } + + switch r.GetAction().(type) { + case *v3routepb.Route_Route: + route.WeightedClusters = make(map[string]WeightedCluster) + action := r.GetRoute() + + // Hash Policies are only applicable for a Ring Hash LB. + if env.RingHashSupport { + hp, err := hashPoliciesProtoToSlice(action.HashPolicy, logger) + if err != nil { + return nil, err + } + route.HashPolicies = hp + } + + switch a := action.GetClusterSpecifier().(type) { + case *v3routepb.RouteAction_Cluster: + route.WeightedClusters[a.Cluster] = WeightedCluster{Weight: 1} + case *v3routepb.RouteAction_WeightedClusters: + wcs := a.WeightedClusters + var totalWeight uint32 + for _, c := range wcs.Clusters { + w := c.GetWeight().GetValue() + if w == 0 { + continue + } + wc := WeightedCluster{Weight: w} + if !v2 { + cfgs, err := processHTTPFilterOverrides(c.GetTypedPerFilterConfig()) + if err != nil { + return nil, fmt.Errorf("route %+v, action %+v: %v", r, a, err) + } + wc.HTTPFilterConfigOverride = cfgs + } + route.WeightedClusters[c.GetName()] = wc + totalWeight += w + } + // envoy xds doc + // default TotalWeight https://www.envoyproxy.io/docs/envoy/latest/api-v3/config/route/v3/route_components.proto.html#envoy-v3-api-field-config-route-v3-weightedcluster-total-weight + wantTotalWeight := uint32(100) + if tw := wcs.GetTotalWeight(); tw != nil { + wantTotalWeight = tw.GetValue() + } + if totalWeight != wantTotalWeight { + return nil, fmt.Errorf("route %+v, action %+v, weights of clusters do not add up to total total weight, got: %v, expected total weight from response: %v", r, a, totalWeight, wantTotalWeight) + } + if totalWeight == 0 { + return nil, fmt.Errorf("route %+v, action %+v, has no valid cluster in WeightedCluster action", r, a) + } + case *v3routepb.RouteAction_ClusterHeader: + continue + default: + return nil, fmt.Errorf("route %+v, has an unknown ClusterSpecifier: %+v", r, a) + } + + msd := action.GetMaxStreamDuration() + // Prefer grpc_timeout_header_max, if set. + dur := msd.GetGrpcTimeoutHeaderMax() + if dur == nil { + dur = msd.GetMaxStreamDuration() + } + if dur != nil { + d := dur.AsDuration() + route.MaxStreamDuration = &d + } + + var err error + route.RetryConfig, err = generateRetryConfig(action.GetRetryPolicy()) + if err != nil { + return nil, fmt.Errorf("route %+v, action %+v: %v", r, action, err) + } + + route.RouteAction = RouteActionRoute + + case *v3routepb.Route_NonForwardingAction: + // Expected to be used on server side. + route.RouteAction = RouteActionNonForwardingAction + default: + route.RouteAction = RouteActionUnsupported + } + + if !v2 { + cfgs, err := processHTTPFilterOverrides(r.GetTypedPerFilterConfig()) + if err != nil { + return nil, fmt.Errorf("route %+v: %v", r, err) + } + route.HTTPFilterConfigOverride = cfgs + } + routesRet = append(routesRet, &route) + } + return routesRet, nil +} + +func hashPoliciesProtoToSlice(policies []*v3routepb.RouteAction_HashPolicy, logger *grpclog.PrefixLogger) ([]*HashPolicy, error) { + var hashPoliciesRet []*HashPolicy + for _, p := range policies { + policy := HashPolicy{Terminal: p.Terminal} + switch p.GetPolicySpecifier().(type) { + case *v3routepb.RouteAction_HashPolicy_Header_: + policy.HashPolicyType = HashPolicyTypeHeader + policy.HeaderName = p.GetHeader().GetHeaderName() + if rr := p.GetHeader().GetRegexRewrite(); rr != nil { + regex := rr.GetPattern().GetRegex() + re, err := regexp.Compile(regex) + if err != nil { + return nil, fmt.Errorf("hash policy %+v contains an invalid regex %q", p, regex) + } + policy.Regex = re + policy.RegexSubstitution = rr.GetSubstitution() + } + case *v3routepb.RouteAction_HashPolicy_FilterState_: + if p.GetFilterState().GetKey() != "io.grpc.channel_id" { + logger.Infof("hash policy %+v contains an invalid key for filter state policy %q", p, p.GetFilterState().GetKey()) + continue + } + policy.HashPolicyType = HashPolicyTypeChannelID + default: + logger.Infof("hash policy %T is an unsupported hash policy", p.GetPolicySpecifier()) + continue + } + + hashPoliciesRet = append(hashPoliciesRet, &policy) + } + return hashPoliciesRet, nil +} + +// UnmarshalCluster processes resources received in an CDS response, validates +// them, and transforms them into a native struct which contains only fields we +// are interested in. +func UnmarshalCluster(opts *UnmarshalOptions) (map[string]ClusterUpdateErrTuple, UpdateMetadata, error) { + update := make(map[string]ClusterUpdateErrTuple) + md, err := processAllResources(opts, update) + return update, md, err +} + +func unmarshalClusterResource(r *anypb.Any, f UpdateValidatorFunc, logger *grpclog.PrefixLogger) (string, ClusterUpdate, error) { + if !IsClusterResource(r.GetTypeUrl()) { + return "", ClusterUpdate{}, fmt.Errorf("unexpected resource type: %q ", r.GetTypeUrl()) + } + + cluster := &v3clusterpb.Cluster{} + if err := proto.Unmarshal(r.GetValue(), cluster); err != nil { + return "", ClusterUpdate{}, fmt.Errorf("failed to unmarshal resource: %v", err) + } + logger.Infof("Resource with name: %v, type: %T, contains: %v", cluster.GetName(), cluster, pretty.ToJSON(cluster)) + cu, err := validateClusterAndConstructClusterUpdate(cluster) + if err != nil { + return cluster.GetName(), ClusterUpdate{}, err + } + cu.Raw = r + if f != nil { + if err := f(cu); err != nil { + return "", ClusterUpdate{}, err + } + } + + return cluster.GetName(), cu, nil +} + +const ( + defaultRingHashMinSize = 1024 + defaultRingHashMaxSize = 8 * 1024 * 1024 // 8M + ringHashSizeUpperBound = 8 * 1024 * 1024 // 8M +) + +func validateClusterAndConstructClusterUpdate(cluster *v3clusterpb.Cluster) (ClusterUpdate, error) { + var lbPolicy *ClusterLBPolicyRingHash + switch cluster.GetLbPolicy() { + case v3clusterpb.Cluster_ROUND_ROBIN: + lbPolicy = nil // The default is round_robin, and there's no config to set. + case v3clusterpb.Cluster_RING_HASH: + if !env.RingHashSupport { + return ClusterUpdate{}, fmt.Errorf("unexpected lbPolicy %v in response: %+v", cluster.GetLbPolicy(), cluster) + } + rhc := cluster.GetRingHashLbConfig() + if rhc.GetHashFunction() != v3clusterpb.Cluster_RingHashLbConfig_XX_HASH { + return ClusterUpdate{}, fmt.Errorf("unsupported ring_hash hash function %v in response: %+v", rhc.GetHashFunction(), cluster) + } + // Minimum defaults to 1024 entries, and limited to 8M entries Maximum + // defaults to 8M entries, and limited to 8M entries + var minSize, maxSize uint64 = defaultRingHashMinSize, defaultRingHashMaxSize + if min := rhc.GetMinimumRingSize(); min != nil { + if min.GetValue() > ringHashSizeUpperBound { + return ClusterUpdate{}, fmt.Errorf("unexpected ring_hash mininum ring size %v in response: %+v", min.GetValue(), cluster) + } + minSize = min.GetValue() + } + if max := rhc.GetMaximumRingSize(); max != nil { + if max.GetValue() > ringHashSizeUpperBound { + return ClusterUpdate{}, fmt.Errorf("unexpected ring_hash maxinum ring size %v in response: %+v", max.GetValue(), cluster) + } + maxSize = max.GetValue() + } + if minSize > maxSize { + return ClusterUpdate{}, fmt.Errorf("ring_hash config min size %v is greater than max %v", minSize, maxSize) + } + lbPolicy = &ClusterLBPolicyRingHash{MinimumRingSize: minSize, MaximumRingSize: maxSize} + default: + return ClusterUpdate{}, fmt.Errorf("unexpected lbPolicy %v in response: %+v", cluster.GetLbPolicy(), cluster) + } + + // Process security configuration received from the control plane iff the + // corresponding environment variable is set. + var sc *SecurityConfig + if env.ClientSideSecuritySupport { + var err error + if sc, err = securityConfigFromCluster(cluster); err != nil { + return ClusterUpdate{}, err + } + } + + ret := ClusterUpdate{ + ClusterName: cluster.GetName(), + EnableLRS: cluster.GetLrsServer().GetSelf() != nil, + SecurityCfg: sc, + MaxRequests: circuitBreakersFromCluster(cluster), + LBPolicy: lbPolicy, + } + + // Validate and set cluster type from the response. + switch { + case cluster.GetType() == v3clusterpb.Cluster_EDS: + if cluster.GetEdsClusterConfig().GetEdsConfig().GetAds() == nil { + return ClusterUpdate{}, fmt.Errorf("unexpected edsConfig in response: %+v", cluster) + } + ret.ClusterType = ClusterTypeEDS + ret.EDSServiceName = cluster.GetEdsClusterConfig().GetServiceName() + return ret, nil + case cluster.GetType() == v3clusterpb.Cluster_LOGICAL_DNS: + if !env.AggregateAndDNSSupportEnv { + return ClusterUpdate{}, fmt.Errorf("unsupported cluster type (%v, %v) in response: %+v", cluster.GetType(), cluster.GetClusterType(), cluster) + } + ret.ClusterType = ClusterTypeLogicalDNS + dnsHN, err := dnsHostNameFromCluster(cluster) + if err != nil { + return ClusterUpdate{}, err + } + ret.DNSHostName = dnsHN + return ret, nil + case cluster.GetClusterType() != nil && cluster.GetClusterType().Name == "envoy.clusters.aggregate": + if !env.AggregateAndDNSSupportEnv { + return ClusterUpdate{}, fmt.Errorf("unsupported cluster type (%v, %v) in response: %+v", cluster.GetType(), cluster.GetClusterType(), cluster) + } + clusters := &v3aggregateclusterpb.ClusterConfig{} + if err := proto.Unmarshal(cluster.GetClusterType().GetTypedConfig().GetValue(), clusters); err != nil { + return ClusterUpdate{}, fmt.Errorf("failed to unmarshal resource: %v", err) + } + ret.ClusterType = ClusterTypeAggregate + ret.PrioritizedClusterNames = clusters.Clusters + return ret, nil + default: + return ClusterUpdate{}, fmt.Errorf("unsupported cluster type (%v, %v) in response: %+v", cluster.GetType(), cluster.GetClusterType(), cluster) + } +} + +// dnsHostNameFromCluster extracts the DNS host name from the cluster's load +// assignment. +// +// There should be exactly one locality, with one endpoint, whose address +// contains the address and port. +func dnsHostNameFromCluster(cluster *v3clusterpb.Cluster) (string, error) { + loadAssignment := cluster.GetLoadAssignment() + if loadAssignment == nil { + return "", fmt.Errorf("load_assignment not present for LOGICAL_DNS cluster") + } + if len(loadAssignment.GetEndpoints()) != 1 { + return "", fmt.Errorf("load_assignment for LOGICAL_DNS cluster must have exactly one locality, got: %+v", loadAssignment) + } + endpoints := loadAssignment.GetEndpoints()[0].GetLbEndpoints() + if len(endpoints) != 1 { + return "", fmt.Errorf("locality for LOGICAL_DNS cluster must have exactly one endpoint, got: %+v", endpoints) + } + endpoint := endpoints[0].GetEndpoint() + if endpoint == nil { + return "", fmt.Errorf("endpoint for LOGICAL_DNS cluster not set") + } + socketAddr := endpoint.GetAddress().GetSocketAddress() + if socketAddr == nil { + return "", fmt.Errorf("socket address for endpoint for LOGICAL_DNS cluster not set") + } + if socketAddr.GetResolverName() != "" { + return "", fmt.Errorf("socket address for endpoint for LOGICAL_DNS cluster not set has unexpected custom resolver name: %v", socketAddr.GetResolverName()) + } + host := socketAddr.GetAddress() + if host == "" { + return "", fmt.Errorf("host for endpoint for LOGICAL_DNS cluster not set") + } + port := socketAddr.GetPortValue() + if port == 0 { + return "", fmt.Errorf("port for endpoint for LOGICAL_DNS cluster not set") + } + return net.JoinHostPort(host, strconv.Itoa(int(port))), nil +} + +// securityConfigFromCluster extracts the relevant security configuration from +// the received Cluster resource. +func securityConfigFromCluster(cluster *v3clusterpb.Cluster) (*SecurityConfig, error) { + if tsm := cluster.GetTransportSocketMatches(); len(tsm) != 0 { + return nil, fmt.Errorf("unsupport transport_socket_matches field is non-empty: %+v", tsm) + } + // The Cluster resource contains a `transport_socket` field, which contains + // a oneof `typed_config` field of type `protobuf.Any`. The any proto + // contains a marshaled representation of an `UpstreamTlsContext` message. + ts := cluster.GetTransportSocket() + if ts == nil { + return nil, nil + } + if name := ts.GetName(); name != transportSocketName { + return nil, fmt.Errorf("transport_socket field has unexpected name: %s", name) + } + any := ts.GetTypedConfig() + if any == nil || any.TypeUrl != version.V3UpstreamTLSContextURL { + return nil, fmt.Errorf("transport_socket field has unexpected typeURL: %s", any.TypeUrl) + } + upstreamCtx := &v3tlspb.UpstreamTlsContext{} + if err := proto.Unmarshal(any.GetValue(), upstreamCtx); err != nil { + return nil, fmt.Errorf("failed to unmarshal UpstreamTlsContext in CDS response: %v", err) + } + // The following fields from `UpstreamTlsContext` are ignored: + // - sni + // - allow_renegotiation + // - max_session_keys + if upstreamCtx.GetCommonTlsContext() == nil { + return nil, errors.New("UpstreamTlsContext in CDS response does not contain a CommonTlsContext") + } + + return securityConfigFromCommonTLSContext(upstreamCtx.GetCommonTlsContext(), false) +} + +// common is expected to be not nil. +// The `alpn_protocols` field is ignored. +func securityConfigFromCommonTLSContext(common *v3tlspb.CommonTlsContext, server bool) (*SecurityConfig, error) { + if common.GetTlsParams() != nil { + return nil, fmt.Errorf("unsupported tls_params field in CommonTlsContext message: %+v", common) + } + if common.GetCustomHandshaker() != nil { + return nil, fmt.Errorf("unsupported custom_handshaker field in CommonTlsContext message: %+v", common) + } + + // For now, if we can't get a valid security config from the new fields, we + // fallback to the old deprecated fields. + // TODO: Drop support for deprecated fields. NACK if err != nil here. + sc, _ := securityConfigFromCommonTLSContextUsingNewFields(common, server) + if sc == nil || sc.Equal(&SecurityConfig{}) { + var err error + sc, err = securityConfigFromCommonTLSContextWithDeprecatedFields(common, server) + if err != nil { + return nil, err + } + } + if sc != nil { + // sc == nil is a valid case where the control plane has not sent us any + // security configuration. xDS creds will use fallback creds. + if server { + if sc.IdentityInstanceName == "" { + return nil, errors.New("security configuration on the server-side does not contain identity certificate provider instance name") + } + } else { + if sc.RootInstanceName == "" { + return nil, errors.New("security configuration on the client-side does not contain root certificate provider instance name") + } + } + } + return sc, nil +} + +func securityConfigFromCommonTLSContextWithDeprecatedFields(common *v3tlspb.CommonTlsContext, server bool) (*SecurityConfig, error) { + // The `CommonTlsContext` contains a + // `tls_certificate_certificate_provider_instance` field of type + // `CertificateProviderInstance`, which contains the provider instance name + // and the certificate name to fetch identity certs. + sc := &SecurityConfig{} + if identity := common.GetTlsCertificateCertificateProviderInstance(); identity != nil { + sc.IdentityInstanceName = identity.GetInstanceName() + sc.IdentityCertName = identity.GetCertificateName() + } + + // The `CommonTlsContext` contains a `validation_context_type` field which + // is a oneof. We can get the values that we are interested in from two of + // those possible values: + // - combined validation context: + // - contains a default validation context which holds the list of + // matchers for accepted SANs. + // - contains certificate provider instance configuration + // - certificate provider instance configuration + // - in this case, we do not get a list of accepted SANs. + switch t := common.GetValidationContextType().(type) { + case *v3tlspb.CommonTlsContext_CombinedValidationContext: + combined := common.GetCombinedValidationContext() + var matchers []matcher.StringMatcher + if def := combined.GetDefaultValidationContext(); def != nil { + for _, m := range def.GetMatchSubjectAltNames() { + matcher, err := matcher.StringMatcherFromProto(m) + if err != nil { + return nil, err + } + matchers = append(matchers, matcher) + } + } + if server && len(matchers) != 0 { + return nil, fmt.Errorf("match_subject_alt_names field in validation context is not supported on the server: %v", common) + } + sc.SubjectAltNameMatchers = matchers + if pi := combined.GetValidationContextCertificateProviderInstance(); pi != nil { + sc.RootInstanceName = pi.GetInstanceName() + sc.RootCertName = pi.GetCertificateName() + } + case *v3tlspb.CommonTlsContext_ValidationContextCertificateProviderInstance: + pi := common.GetValidationContextCertificateProviderInstance() + sc.RootInstanceName = pi.GetInstanceName() + sc.RootCertName = pi.GetCertificateName() + case nil: + // It is valid for the validation context to be nil on the server side. + default: + return nil, fmt.Errorf("validation context contains unexpected type: %T", t) + } + return sc, nil +} + +// gRFC A29 https://github.com/grpc/proposal/blob/master/A29-xds-tls-security.md +// specifies the new way to fetch security configuration and says the following: +// +// Although there are various ways to obtain certificates as per this proto +// (which are supported by Envoy), gRPC supports only one of them and that is +// the `CertificateProviderPluginInstance` proto. +// +// This helper function attempts to fetch security configuration from the +// `CertificateProviderPluginInstance` message, given a CommonTlsContext. +func securityConfigFromCommonTLSContextUsingNewFields(common *v3tlspb.CommonTlsContext, server bool) (*SecurityConfig, error) { + // The `tls_certificate_provider_instance` field of type + // `CertificateProviderPluginInstance` is used to fetch the identity + // certificate provider. + sc := &SecurityConfig{} + identity := common.GetTlsCertificateProviderInstance() + if identity == nil && len(common.GetTlsCertificates()) != 0 { + return nil, fmt.Errorf("expected field tls_certificate_provider_instance is not set, while unsupported field tls_certificates is set in CommonTlsContext message: %+v", common) + } + if identity == nil && common.GetTlsCertificateSdsSecretConfigs() != nil { + return nil, fmt.Errorf("expected field tls_certificate_provider_instance is not set, while unsupported field tls_certificate_sds_secret_configs is set in CommonTlsContext message: %+v", common) + } + sc.IdentityInstanceName = identity.GetInstanceName() + sc.IdentityCertName = identity.GetCertificateName() + + // The `CommonTlsContext` contains a oneof field `validation_context_type`, + // which contains the `CertificateValidationContext` message in one of the + // following ways: + // - `validation_context` field + // - this is directly of type `CertificateValidationContext` + // - `combined_validation_context` field + // - this is of type `CombinedCertificateValidationContext` and contains + // a `default validation context` field of type + // `CertificateValidationContext` + // + // The `CertificateValidationContext` message has the following fields that + // we are interested in: + // - `ca_certificate_provider_instance` + // - this is of type `CertificateProviderPluginInstance` + // - `match_subject_alt_names` + // - this is a list of string matchers + // + // The `CertificateProviderPluginInstance` message contains two fields + // - instance_name + // - this is the certificate provider instance name to be looked up in + // the bootstrap configuration + // - certificate_name + // - this is an opaque name passed to the certificate provider + var validationCtx *v3tlspb.CertificateValidationContext + switch typ := common.GetValidationContextType().(type) { + case *v3tlspb.CommonTlsContext_ValidationContext: + validationCtx = common.GetValidationContext() + case *v3tlspb.CommonTlsContext_CombinedValidationContext: + validationCtx = common.GetCombinedValidationContext().GetDefaultValidationContext() + case nil: + // It is valid for the validation context to be nil on the server side. + return sc, nil + default: + return nil, fmt.Errorf("validation context contains unexpected type: %T", typ) + } + // If we get here, it means that the `CertificateValidationContext` message + // was found through one of the supported ways. It is an error if the + // validation context is specified, but it does not contain the + // ca_certificate_provider_instance field which contains information about + // the certificate provider to be used for the root certificates. + if validationCtx.GetCaCertificateProviderInstance() == nil { + return nil, fmt.Errorf("expected field ca_certificate_provider_instance is missing in CommonTlsContext message: %+v", common) + } + // The following fields are ignored: + // - trusted_ca + // - watched_directory + // - allow_expired_certificate + // - trust_chain_verification + switch { + case len(validationCtx.GetVerifyCertificateSpki()) != 0: + return nil, fmt.Errorf("unsupported verify_certificate_spki field in CommonTlsContext message: %+v", common) + case len(validationCtx.GetVerifyCertificateHash()) != 0: + return nil, fmt.Errorf("unsupported verify_certificate_hash field in CommonTlsContext message: %+v", common) + case validationCtx.GetRequireSignedCertificateTimestamp().GetValue(): + return nil, fmt.Errorf("unsupported require_sugned_ceritificate_timestamp field in CommonTlsContext message: %+v", common) + case validationCtx.GetCrl() != nil: + return nil, fmt.Errorf("unsupported crl field in CommonTlsContext message: %+v", common) + case validationCtx.GetCustomValidatorConfig() != nil: + return nil, fmt.Errorf("unsupported custom_validator_config field in CommonTlsContext message: %+v", common) + } + + if rootProvider := validationCtx.GetCaCertificateProviderInstance(); rootProvider != nil { + sc.RootInstanceName = rootProvider.GetInstanceName() + sc.RootCertName = rootProvider.GetCertificateName() + } + var matchers []matcher.StringMatcher + for _, m := range validationCtx.GetMatchSubjectAltNames() { + matcher, err := matcher.StringMatcherFromProto(m) + if err != nil { + return nil, err + } + matchers = append(matchers, matcher) + } + if server && len(matchers) != 0 { + return nil, fmt.Errorf("match_subject_alt_names field in validation context is not supported on the server: %v", common) + } + sc.SubjectAltNameMatchers = matchers + return sc, nil +} + +// circuitBreakersFromCluster extracts the circuit breakers configuration from +// the received cluster resource. Returns nil if no CircuitBreakers or no +// Thresholds in CircuitBreakers. +func circuitBreakersFromCluster(cluster *v3clusterpb.Cluster) *uint32 { + for _, threshold := range cluster.GetCircuitBreakers().GetThresholds() { + if threshold.GetPriority() != v3corepb.RoutingPriority_DEFAULT { + continue + } + maxRequestsPb := threshold.GetMaxRequests() + if maxRequestsPb == nil { + return nil + } + maxRequests := maxRequestsPb.GetValue() + return &maxRequests + } + return nil +} + +// UnmarshalEndpoints processes resources received in an EDS response, +// validates them, and transforms them into a native struct which contains only +// fields we are interested in. +func UnmarshalEndpoints(opts *UnmarshalOptions) (map[string]EndpointsUpdateErrTuple, UpdateMetadata, error) { + update := make(map[string]EndpointsUpdateErrTuple) + md, err := processAllResources(opts, update) + return update, md, err +} + +func unmarshalEndpointsResource(r *anypb.Any, logger *grpclog.PrefixLogger) (string, EndpointsUpdate, error) { + if !IsEndpointsResource(r.GetTypeUrl()) { + return "", EndpointsUpdate{}, fmt.Errorf("unexpected resource type: %q ", r.GetTypeUrl()) + } + + cla := &v3endpointpb.ClusterLoadAssignment{} + if err := proto.Unmarshal(r.GetValue(), cla); err != nil { + return "", EndpointsUpdate{}, fmt.Errorf("failed to unmarshal resource: %v", err) + } + logger.Infof("Resource with name: %v, type: %T, contains: %v", cla.GetClusterName(), cla, pretty.ToJSON(cla)) + + u, err := parseEDSRespProto(cla) + if err != nil { + return cla.GetClusterName(), EndpointsUpdate{}, err + } + u.Raw = r + return cla.GetClusterName(), u, nil +} + +func parseAddress(socketAddress *v3corepb.SocketAddress) string { + return net.JoinHostPort(socketAddress.GetAddress(), strconv.Itoa(int(socketAddress.GetPortValue()))) +} + +func parseDropPolicy(dropPolicy *v3endpointpb.ClusterLoadAssignment_Policy_DropOverload) OverloadDropConfig { + percentage := dropPolicy.GetDropPercentage() + var ( + numerator = percentage.GetNumerator() + denominator uint32 + ) + switch percentage.GetDenominator() { + case v3typepb.FractionalPercent_HUNDRED: + denominator = 100 + case v3typepb.FractionalPercent_TEN_THOUSAND: + denominator = 10000 + case v3typepb.FractionalPercent_MILLION: + denominator = 1000000 + } + return OverloadDropConfig{ + Category: dropPolicy.GetCategory(), + Numerator: numerator, + Denominator: denominator, + } +} + +func parseEndpoints(lbEndpoints []*v3endpointpb.LbEndpoint) []Endpoint { + endpoints := make([]Endpoint, 0, len(lbEndpoints)) + for _, lbEndpoint := range lbEndpoints { + endpoints = append(endpoints, Endpoint{ + HealthStatus: EndpointHealthStatus(lbEndpoint.GetHealthStatus()), + Address: parseAddress(lbEndpoint.GetEndpoint().GetAddress().GetSocketAddress()), + Weight: lbEndpoint.GetLoadBalancingWeight().GetValue(), + }) + } + return endpoints +} + +func parseEDSRespProto(m *v3endpointpb.ClusterLoadAssignment) (EndpointsUpdate, error) { + ret := EndpointsUpdate{} + for _, dropPolicy := range m.GetPolicy().GetDropOverloads() { + ret.Drops = append(ret.Drops, parseDropPolicy(dropPolicy)) + } + priorities := make(map[uint32]struct{}) + for _, locality := range m.Endpoints { + l := locality.GetLocality() + if l == nil { + return EndpointsUpdate{}, fmt.Errorf("EDS response contains a locality without ID, locality: %+v", locality) + } + lid := internal.LocalityID{ + Region: l.Region, + Zone: l.Zone, + SubZone: l.SubZone, + } + priority := locality.GetPriority() + priorities[priority] = struct{}{} + ret.Localities = append(ret.Localities, Locality{ + ID: lid, + Endpoints: parseEndpoints(locality.GetLbEndpoints()), + Weight: locality.GetLoadBalancingWeight().GetValue(), + Priority: priority, + }) + } + for i := 0; i < len(priorities); i++ { + if _, ok := priorities[uint32(i)]; !ok { + return EndpointsUpdate{}, fmt.Errorf("priority %v missing (with different priorities %v received)", i, priorities) + } + } + return ret, nil +} + +// ListenerUpdateErrTuple is a tuple with the update and error. It contains the +// results from unmarshal functions. It's used to pass unmarshal results of +// multiple resources together, e.g. in maps like `map[string]{Update,error}`. +type ListenerUpdateErrTuple struct { + Update ListenerUpdate + Err error +} + +// RouteConfigUpdateErrTuple is a tuple with the update and error. It contains +// the results from unmarshal functions. It's used to pass unmarshal results of +// multiple resources together, e.g. in maps like `map[string]{Update,error}`. +type RouteConfigUpdateErrTuple struct { + Update RouteConfigUpdate + Err error +} + +// ClusterUpdateErrTuple is a tuple with the update and error. It contains the +// results from unmarshal functions. It's used to pass unmarshal results of +// multiple resources together, e.g. in maps like `map[string]{Update,error}`. +type ClusterUpdateErrTuple struct { + Update ClusterUpdate + Err error +} + +// EndpointsUpdateErrTuple is a tuple with the update and error. It contains the +// results from unmarshal functions. It's used to pass unmarshal results of +// multiple resources together, e.g. in maps like `map[string]{Update,error}`. +type EndpointsUpdateErrTuple struct { + Update EndpointsUpdate + Err error +} + +// processAllResources unmarshals and validates the resources, populates the +// provided ret (a map), and returns metadata and error. +// +// After this function, the ret map will be populated with both valid and +// invalid updates. Invalid resources will have an entry with the key as the +// resource name, value as an empty update. +// +// The type of the resource is determined by the type of ret. E.g. +// map[string]ListenerUpdate means this is for LDS. +func processAllResources(opts *UnmarshalOptions, ret interface{}) (UpdateMetadata, error) { + timestamp := time.Now() + md := UpdateMetadata{ + Version: opts.Version, + Timestamp: timestamp, + } + var topLevelErrors []error + perResourceErrors := make(map[string]error) + + for _, r := range opts.Resources { + switch ret2 := ret.(type) { + case map[string]ListenerUpdateErrTuple: + name, update, err := unmarshalListenerResource(r, opts.UpdateValidator, opts.Logger) + if err == nil { + ret2[name] = ListenerUpdateErrTuple{Update: update} + continue + } + if name == "" { + topLevelErrors = append(topLevelErrors, err) + continue + } + perResourceErrors[name] = err + // Add place holder in the map so we know this resource name was in + // the response. + ret2[name] = ListenerUpdateErrTuple{Err: err} + case map[string]RouteConfigUpdateErrTuple: + name, update, err := unmarshalRouteConfigResource(r, opts.Logger) + if err == nil { + ret2[name] = RouteConfigUpdateErrTuple{Update: update} + continue + } + if name == "" { + topLevelErrors = append(topLevelErrors, err) + continue + } + perResourceErrors[name] = err + // Add place holder in the map so we know this resource name was in + // the response. + ret2[name] = RouteConfigUpdateErrTuple{Err: err} + case map[string]ClusterUpdateErrTuple: + name, update, err := unmarshalClusterResource(r, opts.UpdateValidator, opts.Logger) + if err == nil { + ret2[name] = ClusterUpdateErrTuple{Update: update} + continue + } + if name == "" { + topLevelErrors = append(topLevelErrors, err) + continue + } + perResourceErrors[name] = err + // Add place holder in the map so we know this resource name was in + // the response. + ret2[name] = ClusterUpdateErrTuple{Err: err} + case map[string]EndpointsUpdateErrTuple: + name, update, err := unmarshalEndpointsResource(r, opts.Logger) + if err == nil { + ret2[name] = EndpointsUpdateErrTuple{Update: update} + continue + } + if name == "" { + topLevelErrors = append(topLevelErrors, err) + continue + } + perResourceErrors[name] = err + // Add place holder in the map so we know this resource name was in + // the response. + ret2[name] = EndpointsUpdateErrTuple{Err: err} + } + } + + if len(topLevelErrors) == 0 && len(perResourceErrors) == 0 { + md.Status = ServiceStatusACKed + return md, nil + } + + var typeStr string + switch ret.(type) { + case map[string]ListenerUpdate: + typeStr = "LDS" + case map[string]RouteConfigUpdate: + typeStr = "RDS" + case map[string]ClusterUpdate: + typeStr = "CDS" + case map[string]EndpointsUpdate: + typeStr = "EDS" + } + + md.Status = ServiceStatusNACKed + errRet := combineErrors(typeStr, topLevelErrors, perResourceErrors) + md.ErrState = &UpdateErrorMetadata{ + Version: opts.Version, + Err: errRet, + Timestamp: timestamp, + } + return md, errRet +} + +func combineErrors(rType string, topLevelErrors []error, perResourceErrors map[string]error) error { + var errStrB strings.Builder + errStrB.WriteString(fmt.Sprintf("error parsing %q response: ", rType)) + if len(topLevelErrors) > 0 { + errStrB.WriteString("top level errors: ") + for i, err := range topLevelErrors { + if i != 0 { + errStrB.WriteString(";\n") + } + errStrB.WriteString(err.Error()) + } + } + if len(perResourceErrors) > 0 { + var i int + for name, err := range perResourceErrors { + if i != 0 { + errStrB.WriteString(";\n") + } + i++ + errStrB.WriteString(fmt.Sprintf("resource %q: %v", name, err.Error())) + } + } + return errors.New(errStrB.String()) +} diff --git a/xds/internal/client/tests/client_test.go b/xds/internal/xdsclient/xdsclient_test.go similarity index 92% rename from xds/internal/client/tests/client_test.go rename to xds/internal/xdsclient/xdsclient_test.go index f5a57fbcd21..f348df48161 100644 --- a/xds/internal/client/tests/client_test.go +++ b/xds/internal/xdsclient/xdsclient_test.go @@ -16,7 +16,7 @@ * */ -package tests_test +package xdsclient_test import ( "testing" @@ -25,11 +25,11 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/internal/grpctest" - xdsclient "google.golang.org/grpc/xds/internal/client" - "google.golang.org/grpc/xds/internal/client/bootstrap" - _ "google.golang.org/grpc/xds/internal/client/v2" // Register the v2 API client. "google.golang.org/grpc/xds/internal/testutils" "google.golang.org/grpc/xds/internal/version" + "google.golang.org/grpc/xds/internal/xdsclient" + "google.golang.org/grpc/xds/internal/xdsclient/bootstrap" + _ "google.golang.org/grpc/xds/internal/xdsclient/v2" // Register the v2 API client. ) type s struct { diff --git a/xds/server.go b/xds/server.go index f1c1e4181b8..b36fa64b500 100644 --- a/xds/server.go +++ b/xds/server.go @@ -23,89 +23,78 @@ import ( "errors" "fmt" "net" + "strings" "sync" - "time" "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/connectivity" "google.golang.org/grpc/credentials" - "google.golang.org/grpc/credentials/tls/certprovider" "google.golang.org/grpc/grpclog" "google.golang.org/grpc/internal" - xdsinternal "google.golang.org/grpc/internal/credentials/xds" + "google.golang.org/grpc/internal/buffer" internalgrpclog "google.golang.org/grpc/internal/grpclog" "google.golang.org/grpc/internal/grpcsync" - xdsclient "google.golang.org/grpc/xds/internal/client" - "google.golang.org/grpc/xds/internal/client/bootstrap" + iresolver "google.golang.org/grpc/internal/resolver" + "google.golang.org/grpc/internal/transport" + "google.golang.org/grpc/internal/xds/env" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" + "google.golang.org/grpc/xds/internal/server" + "google.golang.org/grpc/xds/internal/xdsclient" ) const serverPrefix = "[xds-server %p] " var ( // These new functions will be overridden in unit tests. - newXDSClient = func() (xdsClientInterface, error) { + newXDSClient = func() (xdsclient.XDSClient, error) { return xdsclient.New() } - newGRPCServer = func(opts ...grpc.ServerOption) grpcServerInterface { + newGRPCServer = func(opts ...grpc.ServerOption) grpcServer { return grpc.NewServer(opts...) } - // Unexported function to retrieve transport credentials from a gRPC server. - grpcGetServerCreds = internal.GetServerCredentials.(func(*grpc.Server) credentials.TransportCredentials) - buildProvider = buildProviderFunc - logger = grpclog.Component("xds") + grpcGetServerCreds = internal.GetServerCredentials.(func(*grpc.Server) credentials.TransportCredentials) + drainServerTransports = internal.DrainServerTransports.(func(*grpc.Server, string)) + logger = grpclog.Component("xds") ) func prefixLogger(p *GRPCServer) *internalgrpclog.PrefixLogger { return internalgrpclog.NewPrefixLogger(logger, fmt.Sprintf(serverPrefix, p)) } -// xdsClientInterface contains methods from xdsClient.Client which are used by -// the server. This is useful for overriding in unit tests. -type xdsClientInterface interface { - WatchListener(string, func(xdsclient.ListenerUpdate, error)) func() - BootstrapConfig() *bootstrap.Config - Close() -} - -// grpcServerInterface contains methods from grpc.Server which are used by the +// grpcServer contains methods from grpc.Server which are used by the // GRPCServer type here. This is useful for overriding in unit tests. -type grpcServerInterface interface { +type grpcServer interface { RegisterService(*grpc.ServiceDesc, interface{}) Serve(net.Listener) error Stop() GracefulStop() + GetServiceInfo() map[string]grpc.ServiceInfo } // GRPCServer wraps a gRPC server and provides server-side xDS functionality, by // communication with a management server using xDS APIs. It implements the // grpc.ServiceRegistrar interface and can be passed to service registration // functions in IDL generated code. -// -// Experimental -// -// Notice: This type is EXPERIMENTAL and may be changed or removed in a -// later release. type GRPCServer struct { - gs grpcServerInterface + gs grpcServer quit *grpcsync.Event logger *internalgrpclog.PrefixLogger xdsCredsInUse bool + opts *serverOptions // clientMu is used only in initXDSClient(), which is called at the // beginning of Serve(), where we have to decide if we have to create a // client or use an existing one. clientMu sync.Mutex - xdsC xdsClientInterface + xdsC xdsclient.XDSClient } // NewGRPCServer creates an xDS-enabled gRPC server using the passed in opts. // The underlying gRPC server has no service registered and has not started to // accept requests yet. -// -// Experimental -// -// Notice: This API is EXPERIMENTAL and may be changed or removed in a later -// release. func NewGRPCServer(opts ...grpc.ServerOption) *GRPCServer { newOpts := []grpc.ServerOption{ grpc.ChainUnaryInterceptor(xdsUnaryInterceptor), @@ -115,6 +104,7 @@ func NewGRPCServer(opts ...grpc.ServerOption) *GRPCServer { s := &GRPCServer{ gs: newGRPCServer(newOpts...), quit: grpcsync.NewEvent(), + opts: handleServerOptions(opts), } s.logger = prefixLogger(s) s.logger.Infof("Created xds.GRPCServer") @@ -135,6 +125,18 @@ func NewGRPCServer(opts ...grpc.ServerOption) *GRPCServer { return s } +// handleServerOptions iterates through the list of server options passed in by +// the user, and handles the xDS server specific options. +func handleServerOptions(opts []grpc.ServerOption) *serverOptions { + so := &serverOptions{} + for _, opt := range opts { + if o, ok := opt.(*serverOption); ok { + o.apply(so) + } + } + return so +} + // RegisterService registers a service and its implementation to the underlying // gRPC server. It is called from the IDL generated code. This must be called // before invoking Serve. @@ -142,6 +144,12 @@ func (s *GRPCServer) RegisterService(sd *grpc.ServiceDesc, ss interface{}) { s.gs.RegisterService(sd, ss) } +// GetServiceInfo returns a map from service names to ServiceInfo. +// Service names include the package names, in the form of .. +func (s *GRPCServer) GetServiceInfo() map[string]grpc.ServiceInfo { + return s.gs.GetServiceInfo() +} + // initXDSClient creates a new xdsClient if there is no existing one available. func (s *GRPCServer) initXDSClient() error { s.clientMu.Lock() @@ -151,6 +159,12 @@ func (s *GRPCServer) initXDSClient() error { return nil } + newXDSClient := newXDSClient + if s.opts.bootstrapContents != nil { + newXDSClient = func() (xdsclient.XDSClient, error) { + return xdsclient.NewClientWithBootstrapContents(s.opts.bootstrapContents) + } + } client, err := newXDSClient() if err != nil { return fmt.Errorf("xds: failed to create xds-client: %v", err) @@ -178,76 +192,60 @@ func (s *GRPCServer) Serve(lis net.Listener) error { if err := s.initXDSClient(); err != nil { return err } + cfg := s.xdsC.BootstrapConfig() + if cfg == nil { + return errors.New("bootstrap configuration is empty") + } // If xds credentials were specified by the user, but bootstrap configs do // not contain any certificate provider configuration, it is better to fail // right now rather than failing when attempting to create certificate // providers after receiving an LDS response with security configuration. if s.xdsCredsInUse { - bc := s.xdsC.BootstrapConfig() - if bc == nil || len(bc.CertProviderConfigs) == 0 { + if len(cfg.CertProviderConfigs) == 0 { return errors.New("xds: certificate_providers config missing in bootstrap file") } } - lw, err := s.newListenerWrapper(lis) - if lw == nil { - // Error returned can be nil (when Stop/GracefulStop() is called). So, - // we need to check the returned listenerWrapper instead. - return err - } - return s.gs.Serve(lw) -} - -// newListenerWrapper creates and returns a listenerWrapper, which is a thin -// wrapper around the passed in listener lis, that can be passed to -// grpcServer.Serve(). -// -// It then registers a watch for a Listener resource and blocks until a good -// response is received or the server is stopped by a call to -// Stop/GracefulStop(). -func (s *GRPCServer) newListenerWrapper(lis net.Listener) (*listenerWrapper, error) { - lw := &listenerWrapper{ - Listener: lis, - closed: grpcsync.NewEvent(), - xdsHI: xdsinternal.NewHandshakeInfo(nil, nil), - } - - // This is used to notify that a good update has been received and that - // Serve() can be invoked on the underlying gRPC server. Using a - // grpcsync.Event instead of a vanilla channel simplifies the update handler - // as it need not keep track of whether the received update is the first one - // or not. - goodUpdate := grpcsync.NewEvent() - - // The resource_name in the LDS request sent by the xDS-enabled gRPC server - // is of the following format: - // "/path/to/resource?udpa.resource.listening_address=IP:Port". The - // `/path/to/resource` part of the name is sourced from the bootstrap config - // field `grpc_server_resource_name_id`. If this field is not specified in - // the bootstrap file, we will use a default of `grpc/server`. - path := "grpc/server" - if cfg := s.xdsC.BootstrapConfig(); cfg != nil && cfg.ServerResourceNameID != "" { - path = cfg.ServerResourceNameID - } - name := fmt.Sprintf("%s?udpa.resource.listening_address=%s", path, lis.Addr().String()) - - // Register an LDS watch using our xdsClient, and specify the listening - // address as the resource name. - cancelWatch := s.xdsC.WatchListener(name, func(update xdsclient.ListenerUpdate, err error) { - s.handleListenerUpdate(listenerUpdate{ - lw: lw, - name: name, - lds: update, - err: err, - goodUpdate: goodUpdate, - }) + // The server listener resource name template from the bootstrap + // configuration contains a template for the name of the Listener resource + // to subscribe to for a gRPC server. If the token `%s` is present in the + // string, it will be replaced with the server's listening "IP:port" (e.g., + // "0.0.0.0:8080", "[::]:8080"). The absence of a template will be treated + // as an error since we do not have any default value for this. + if cfg.ServerListenerResourceNameTemplate == "" { + return errors.New("missing server_listener_resource_name_template in the bootstrap configuration") + } + name := cfg.ServerListenerResourceNameTemplate + if strings.Contains(cfg.ServerListenerResourceNameTemplate, "%s") { + name = strings.Replace(cfg.ServerListenerResourceNameTemplate, "%s", lis.Addr().String(), -1) + } + + modeUpdateCh := buffer.NewUnbounded() + go func() { + s.handleServingModeChanges(modeUpdateCh) + }() + + // Create a listenerWrapper which handles all functionality required by + // this particular instance of Serve(). + lw, goodUpdateCh := server.NewListenerWrapper(server.ListenerWrapperParams{ + Listener: lis, + ListenerResourceName: name, + XDSCredsInUse: s.xdsCredsInUse, + XDSClient: s.xdsC, + ModeCallback: func(addr net.Addr, mode connectivity.ServingMode, err error) { + modeUpdateCh.Put(&modeChangeArgs{ + addr: addr, + mode: mode, + err: err, + }) + }, + DrainCallback: func(addr net.Addr) { + if gs, ok := s.gs.(*grpc.Server); ok { + drainServerTransports(gs, addr.String()) + } + }, }) - s.logger.Infof("Watch started on resource name %v", name) - lw.cancelWatch = func() { - cancelWatch() - s.logger.Infof("Watch cancelled on resource name %v", name) - } // Block until a good LDS response is received or the server is stopped. select { @@ -256,123 +254,51 @@ func (s *GRPCServer) newListenerWrapper(lis net.Listener) (*listenerWrapper, err // need to explicitly close the listener. Cancellation of the xDS watch // is handled by the listenerWrapper. lw.Close() - return nil, nil - case <-goodUpdate.Done(): + return nil + case <-goodUpdateCh: } - return lw, nil + return s.gs.Serve(lw) } -// listenerUpdate wraps the information received from a registered LDS watcher. -type listenerUpdate struct { - lw *listenerWrapper // listener associated with this watch - name string // resource name being watched - lds xdsclient.ListenerUpdate // received update - err error // received error - goodUpdate *grpcsync.Event // event to fire upon a good update +// modeChangeArgs wraps argument required for invoking mode change callback. +type modeChangeArgs struct { + addr net.Addr + mode connectivity.ServingMode + err error } -func (s *GRPCServer) handleListenerUpdate(update listenerUpdate) { - if update.lw.closed.HasFired() { - s.logger.Warningf("Resource %q received update: %v with error: %v, after for listener was closed", update.name, update.lds, update.err) - return - } - - if update.err != nil { - // We simply log an error here and hope we get a successful update - // in the future. The error could be because of a timeout or an - // actual error, like the requested resource not found. In any case, - // it is fine for the server to hang indefinitely until Stop() is - // called. - s.logger.Warningf("Received error for resource %q: %+v", update.name, update.err) - return - } - s.logger.Infof("Received update for resource %q: %+v", update.name, update.lds.String()) - - if err := s.handleSecurityConfig(update.lds.SecurityCfg, update.lw); err != nil { - s.logger.Warningf("Invalid security config update: %v", err) - return - } - - // If we got all the way here, it means the received update was a good one. - update.goodUpdate.Fire() -} - -func (s *GRPCServer) handleSecurityConfig(config *xdsclient.SecurityConfig, lw *listenerWrapper) error { - // If xdsCredentials are not in use, i.e, the user did not want to get - // security configuration from the control plane, we should not be acting on - // the received security config here. Doing so poses a security threat. - if !s.xdsCredsInUse { - return nil - } - - // Security config being nil is a valid case where the control plane has - // not sent any security configuration. The xdsCredentials implementation - // handles this by delegating to its fallback credentials. - if config == nil { - // We need to explicitly set the fields to nil here since this might be - // a case of switching from a good security configuration to an empty - // one where fallback credentials are to be used. - lw.xdsHI.SetRootCertProvider(nil) - lw.xdsHI.SetIdentityCertProvider(nil) - lw.xdsHI.SetRequireClientCert(false) - return nil - } - - cpc := s.xdsC.BootstrapConfig().CertProviderConfigs - // Identity provider is mandatory on the server side. - identityProvider, err := buildProvider(cpc, config.IdentityInstanceName, config.IdentityCertName, true, false) - if err != nil { - return err - } - - // A root provider is required only when doing mTLS. - var rootProvider certprovider.Provider - if config.RootInstanceName != "" { - rootProvider, err = buildProvider(cpc, config.RootInstanceName, config.RootCertName, false, true) - if err != nil { - return err +// handleServingModeChanges runs as a separate goroutine, spawned from Serve(). +// It reads a channel on to which mode change arguments are pushed, and in turn +// invokes the user registered callback. It also calls an internal method on the +// underlying grpc.Server to gracefully close existing connections, if the +// listener moved to a "not-serving" mode. +func (s *GRPCServer) handleServingModeChanges(updateCh *buffer.Unbounded) { + for { + select { + case <-s.quit.Done(): + return + case u := <-updateCh.Get(): + updateCh.Load() + args := u.(*modeChangeArgs) + if args.mode == connectivity.ServingModeNotServing { + // We type assert our underlying gRPC server to the real + // grpc.Server here before trying to initiate the drain + // operation. This approach avoids performing the same type + // assertion in the grpc package which provides the + // implementation for internal.GetServerCredentials, and allows + // us to use a fake gRPC server in tests. + if gs, ok := s.gs.(*grpc.Server); ok { + drainServerTransports(gs, args.addr.String()) + } + } + if s.opts.modeCallback != nil { + s.opts.modeCallback(args.addr, ServingModeChangeArgs{ + Mode: args.mode, + Err: args.err, + }) + } } } - - // Close the old providers and cache the new ones. - lw.providerMu.Lock() - if lw.cachedIdentity != nil { - lw.cachedIdentity.Close() - } - if lw.cachedRoot != nil { - lw.cachedRoot.Close() - } - lw.cachedRoot = rootProvider - lw.cachedIdentity = identityProvider - - // We set all fields here, even if some of them are nil, since they - // could have been non-nil earlier. - lw.xdsHI.SetRootCertProvider(rootProvider) - lw.xdsHI.SetIdentityCertProvider(identityProvider) - lw.xdsHI.SetRequireClientCert(config.RequireClientCert) - lw.providerMu.Unlock() - - return nil -} - -func buildProviderFunc(configs map[string]*certprovider.BuildableConfig, instanceName, certName string, wantIdentity, wantRoot bool) (certprovider.Provider, error) { - cfg, ok := configs[instanceName] - if !ok { - return nil, fmt.Errorf("certificate provider instance %q not found in bootstrap file", instanceName) - } - provider, err := cfg.Build(certprovider.BuildOptions{ - CertName: certName, - WantIdentity: wantIdentity, - WantRoot: wantRoot, - }) - if err != nil { - // This error is not expected since the bootstrap process parses the - // config and makes sure that it is acceptable to the plugin. Still, it - // is possible that the plugin parses the config successfully, but its - // Build() method errors out. - return nil, fmt.Errorf("failed to get security plugin instance (%+v): %v", cfg, err) - } - return provider, nil } // Stop stops the underlying gRPC server. It immediately closes all open @@ -398,120 +324,79 @@ func (s *GRPCServer) GracefulStop() { } } +// routeAndProcess routes the incoming RPC to a configured route in the route +// table and also processes the RPC by running the incoming RPC through any HTTP +// Filters configured. +func routeAndProcess(ctx context.Context) error { + conn := transport.GetConnection(ctx) + cw, ok := conn.(interface { + VirtualHosts() []xdsclient.VirtualHostWithInterceptors + }) + if !ok { + return errors.New("missing virtual hosts in incoming context") + } + mn, ok := grpc.Method(ctx) + if !ok { + return errors.New("missing method name in incoming context") + } + md, ok := metadata.FromIncomingContext(ctx) + if !ok { + return errors.New("missing metadata in incoming context") + } + // A41 added logic to the core grpc implementation to guarantee that once + // the RPC gets to this point, there will be a single, unambiguous authority + // present in the header map. + authority := md.Get(":authority") + vh := xdsclient.FindBestMatchingVirtualHostServer(authority[0], cw.VirtualHosts()) + if vh == nil { + return status.Error(codes.Unavailable, "the incoming RPC did not match a configured Virtual Host") + } + + var rwi *xdsclient.RouteWithInterceptors + rpcInfo := iresolver.RPCInfo{ + Context: ctx, + Method: mn, + } + for _, r := range vh.Routes { + if r.M.Match(rpcInfo) { + // "NonForwardingAction is expected for all Routes used on server-side; a route with an inappropriate action causes + // RPCs matching that route to fail with UNAVAILABLE." - A36 + if r.RouteAction != xdsclient.RouteActionNonForwardingAction { + return status.Error(codes.Unavailable, "the incoming RPC matched to a route that was not of action type non forwarding") + } + rwi = &r + break + } + } + if rwi == nil { + return status.Error(codes.Unavailable, "the incoming RPC did not match a configured Route") + } + for _, interceptor := range rwi.Interceptors { + if err := interceptor.AllowRPC(ctx); err != nil { + return status.Errorf(codes.PermissionDenied, "Incoming RPC is not allowed: %v", err) + } + } + return nil +} + // xdsUnaryInterceptor is the unary interceptor added to the gRPC server to // perform any xDS specific functionality on unary RPCs. -// -// This is a no-op at this point. func xdsUnaryInterceptor(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) { + if env.RBACSupport { + if err := routeAndProcess(ctx); err != nil { + return nil, err + } + } return handler(ctx, req) } // xdsStreamInterceptor is the stream interceptor added to the gRPC server to // perform any xDS specific functionality on streaming RPCs. -// -// This is a no-op at this point. func xdsStreamInterceptor(srv interface{}, ss grpc.ServerStream, _ *grpc.StreamServerInfo, handler grpc.StreamHandler) error { - return handler(srv, ss) -} - -// listenerWrapper wraps the net.Listener associated with the listening address -// passed to Serve(). It also contains all other state associated with this -// particular invocation of Serve(). -type listenerWrapper struct { - net.Listener - cancelWatch func() - - // A small race exists in the xdsClient code where an xDS response is - // received while the user is calling cancel(). In this small window the - // registered callback can be called after the watcher is canceled. We avoid - // processing updates received in callbacks once the listener is closed, to - // make sure that we do not process updates received during this race - // window. - closed *grpcsync.Event - - // The certificate providers are cached here to that they can be closed when - // a new provider is to be created. - providerMu sync.Mutex - cachedRoot certprovider.Provider - cachedIdentity certprovider.Provider - - // Wraps all information required by the xds handshaker. - xdsHI *xdsinternal.HandshakeInfo -} - -// Accept blocks on an Accept() on the underlying listener, and wraps the -// returned net.Conn with the configured certificate providers. -func (l *listenerWrapper) Accept() (net.Conn, error) { - c, err := l.Listener.Accept() - if err != nil { - return nil, err - } - return &conn{Conn: c, xdsHI: l.xdsHI}, nil -} - -// Close closes the underlying listener. It also cancels the xDS watch -// registered in Serve() and closes any certificate provider instances created -// based on security configuration received in the LDS response. -func (l *listenerWrapper) Close() error { - l.closed.Fire() - l.Listener.Close() - if l.cancelWatch != nil { - l.cancelWatch() - } - - l.providerMu.Lock() - if l.cachedIdentity != nil { - l.cachedIdentity.Close() - l.cachedIdentity = nil - } - if l.cachedRoot != nil { - l.cachedRoot.Close() - l.cachedRoot = nil + if env.RBACSupport { + if err := routeAndProcess(ss.Context()); err != nil { + return err + } } - l.providerMu.Unlock() - - return nil -} - -// conn is a thin wrapper around a net.Conn returned by Accept(). -type conn struct { - net.Conn - - // This is the same HandshakeInfo as stored in the listenerWrapper that - // created this conn. The former updates the HandshakeInfo whenever it - // receives new security configuration. - xdsHI *xdsinternal.HandshakeInfo - - // The connection deadline as configured by the grpc.Server on the rawConn - // that is returned by a call to Accept(). This is set to the connection - // timeout value configured by the user (or to a default value) before - // initiating the transport credential handshake, and set to zero after - // completing the HTTP2 handshake. - deadlineMu sync.Mutex - deadline time.Time -} - -// SetDeadline makes a copy of the passed in deadline and forwards the call to -// the underlying rawConn. -func (c *conn) SetDeadline(t time.Time) error { - c.deadlineMu.Lock() - c.deadline = t - c.deadlineMu.Unlock() - return c.Conn.SetDeadline(t) -} - -// GetDeadline returns the configured deadline. This will be invoked by the -// ServerHandshake() method of the XdsCredentials, which needs a deadline to -// pass to the certificate provider. -func (c *conn) GetDeadline() time.Time { - c.deadlineMu.Lock() - t := c.deadline - c.deadlineMu.Unlock() - return t -} - -// XDSHandshakeInfo returns a pointer to the HandshakeInfo stored in conn. This -// will be invoked by the ServerHandshake() method of the XdsCredentials. -func (c *conn) XDSHandshakeInfo() *xdsinternal.HandshakeInfo { - return c.xdsHI + return handler(srv, ss) } diff --git a/xds/server_options.go b/xds/server_options.go new file mode 100644 index 00000000000..1d46c3adb7b --- /dev/null +++ b/xds/server_options.go @@ -0,0 +1,76 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package xds + +import ( + "net" + + "google.golang.org/grpc" + "google.golang.org/grpc/connectivity" +) + +type serverOptions struct { + modeCallback ServingModeCallbackFunc + bootstrapContents []byte +} + +type serverOption struct { + grpc.EmptyServerOption + apply func(*serverOptions) +} + +// ServingModeCallback returns a grpc.ServerOption which allows users to +// register a callback to get notified about serving mode changes. +func ServingModeCallback(cb ServingModeCallbackFunc) grpc.ServerOption { + return &serverOption{apply: func(o *serverOptions) { o.modeCallback = cb }} +} + +// ServingModeCallbackFunc is the callback that users can register to get +// notified about the server's serving mode changes. The callback is invoked +// with the address of the listener and its new mode. +// +// Users must not perform any blocking operations in this callback. +type ServingModeCallbackFunc func(addr net.Addr, args ServingModeChangeArgs) + +// ServingModeChangeArgs wraps the arguments passed to the serving mode callback +// function. +type ServingModeChangeArgs struct { + // Mode is the new serving mode of the server listener. + Mode connectivity.ServingMode + // Err is set to a non-nil error if the server has transitioned into + // not-serving mode. + Err error +} + +// BootstrapContentsForTesting returns a grpc.ServerOption which allows users +// to inject a bootstrap configuration used by only this server, instead of the +// global configuration from the environment variables. +// +// Testing Only +// +// This function should ONLY be used for testing and may not work with some +// other features, including the CSDS service. +// +// Experimental +// +// Notice: This API is EXPERIMENTAL and may be changed or removed in a +// later release. +func BootstrapContentsForTesting(contents []byte) grpc.ServerOption { + return &serverOption{apply: func(o *serverOptions) { o.bootstrapContents = contents }} +} diff --git a/xds/server_test.go b/xds/server_test.go index cde96307926..0866e0414ae 100644 --- a/xds/server_test.go +++ b/xds/server_test.go @@ -24,27 +24,99 @@ import ( "fmt" "net" "reflect" + "strings" "testing" "time" + v3corepb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" + v3listenerpb "github.com/envoyproxy/go-control-plane/envoy/config/listener/v3" + v3routepb "github.com/envoyproxy/go-control-plane/envoy/config/route/v3" + v3httppb "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/network/http_connection_manager/v3" + v3tlspb "github.com/envoyproxy/go-control-plane/envoy/extensions/transport_sockets/tls/v3" + wrapperspb "github.com/golang/protobuf/ptypes/wrappers" "google.golang.org/grpc" + "google.golang.org/grpc/connectivity" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/credentials/tls/certprovider" "google.golang.org/grpc/credentials/xds" "google.golang.org/grpc/internal/grpctest" "google.golang.org/grpc/internal/testutils" - xdsclient "google.golang.org/grpc/xds/internal/client" - "google.golang.org/grpc/xds/internal/client/bootstrap" + _ "google.golang.org/grpc/xds/internal/httpfilter/router" xdstestutils "google.golang.org/grpc/xds/internal/testutils" + "google.golang.org/grpc/xds/internal/testutils/e2e" "google.golang.org/grpc/xds/internal/testutils/fakeclient" + "google.golang.org/grpc/xds/internal/xdsclient" + "google.golang.org/grpc/xds/internal/xdsclient/bootstrap" ) const ( - defaultTestTimeout = 5 * time.Second - defaultTestShortTimeout = 10 * time.Millisecond - testServerResourceNameID = "/path/to/resource" + defaultTestTimeout = 5 * time.Second + defaultTestShortTimeout = 10 * time.Millisecond + testServerListenerResourceNameTemplate = "/path/to/resource/%s/%s" ) +var listenerWithFilterChains = &v3listenerpb.Listener{ + FilterChains: []*v3listenerpb.FilterChain{ + { + FilterChainMatch: &v3listenerpb.FilterChainMatch{ + PrefixRanges: []*v3corepb.CidrRange{ + { + AddressPrefix: "192.168.0.0", + PrefixLen: &wrapperspb.UInt32Value{ + Value: uint32(16), + }, + }, + }, + SourceType: v3listenerpb.FilterChainMatch_SAME_IP_OR_LOOPBACK, + SourcePrefixRanges: []*v3corepb.CidrRange{ + { + AddressPrefix: "192.168.0.0", + PrefixLen: &wrapperspb.UInt32Value{ + Value: uint32(16), + }, + }, + }, + SourcePorts: []uint32{80}, + }, + TransportSocket: &v3corepb.TransportSocket{ + Name: "envoy.transport_sockets.tls", + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3tlspb.DownstreamTlsContext{ + CommonTlsContext: &v3tlspb.CommonTlsContext{ + TlsCertificateCertificateProviderInstance: &v3tlspb.CommonTlsContext_CertificateProviderInstance{ + InstanceName: "identityPluginInstance", + CertificateName: "identityCertName", + }, + }, + }), + }, + }, + Filters: []*v3listenerpb.Filter{ + { + Name: "filter-1", + ConfigType: &v3listenerpb.Filter_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3httppb.HttpConnectionManager{ + RouteSpecifier: &v3httppb.HttpConnectionManager_RouteConfig{ + RouteConfig: &v3routepb.RouteConfiguration{ + Name: "routeName", + VirtualHosts: []*v3routepb.VirtualHost{{ + Domains: []string{"lds.target.good:3333"}, + Routes: []*v3routepb.Route{{ + Match: &v3routepb.RouteMatch{ + PathSpecifier: &v3routepb.RouteMatch_Prefix{Prefix: "/"}, + }, + Action: &v3routepb.Route_NonForwardingAction{}, + }}}}}, + }, + HttpFilters: []*v3httppb.HttpFilter{e2e.RouterHTTPFilter}, + }), + }, + }, + }, + }, + }, +} + type s struct { grpctest.Tester } @@ -65,9 +137,10 @@ func (f *fakeGRPCServer) RegisterService(*grpc.ServiceDesc, interface{}) { f.registerServiceCh.Send(nil) } -func (f *fakeGRPCServer) Serve(net.Listener) error { +func (f *fakeGRPCServer) Serve(lis net.Listener) error { f.serveCh.Send(nil) <-f.done + lis.Close() return nil } @@ -80,6 +153,10 @@ func (f *fakeGRPCServer) GracefulStop() { f.gracefulStopCh.Send(nil) } +func (f *fakeGRPCServer) GetServiceInfo() map[string]grpc.ServiceInfo { + panic("implement me") +} + func newFakeGRPCServer() *fakeGRPCServer { return &fakeGRPCServer{ done: make(chan struct{}), @@ -90,6 +167,14 @@ func newFakeGRPCServer() *fakeGRPCServer { } } +func splitHostPort(hostport string) (string, string) { + addr, port, err := net.SplitHostPort(hostport) + if err != nil { + panic(fmt.Sprintf("listener address %q does not parse: %v", hostport, err)) + } + return addr, port +} + func (s) TestNewServer(t *testing.T) { xdsCreds, err := xds.NewServerCredentials(xds.ServerOptions{FallbackCreds: insecure.NewCredentials()}) if err != nil { @@ -119,7 +204,7 @@ func (s) TestNewServer(t *testing.T) { wantServerOpts := len(test.serverOpts) + 2 origNewGRPCServer := newGRPCServer - newGRPCServer = func(opts ...grpc.ServerOption) grpcServerInterface { + newGRPCServer = func(opts ...grpc.ServerOption) grpcServer { if got := len(opts); got != wantServerOpts { t.Fatalf("%d ServerOptions passed to grpc.Server, want %d", got, wantServerOpts) } @@ -147,7 +232,7 @@ func (s) TestRegisterService(t *testing.T) { fs := newFakeGRPCServer() origNewGRPCServer := newGRPCServer - newGRPCServer = func(opts ...grpc.ServerOption) grpcServerInterface { return fs } + newGRPCServer = func(opts ...grpc.ServerOption) grpcServer { return fs } defer func() { newGRPCServer = origNewGRPCServer }() s := NewGRPCServer() @@ -173,8 +258,14 @@ var ( ) func init() { - fpb1 = &fakeProviderBuilder{name: fakeProvider1Name} - fpb2 = &fakeProviderBuilder{name: fakeProvider2Name} + fpb1 = &fakeProviderBuilder{ + name: fakeProvider1Name, + buildCh: testutils.NewChannel(), + } + fpb2 = &fakeProviderBuilder{ + name: fakeProvider2Name, + buildCh: testutils.NewChannel(), + } cfg1, _ := fpb1.ParseConfig(fakeConfig + "1111") cfg2, _ := fpb2.ParseConfig(fakeConfig + "2222") certProviderConfigs = map[string]*certprovider.BuildableConfig{ @@ -188,7 +279,8 @@ func init() { // fakeProviderBuilder builds new instances of fakeProvider and interprets the // config provided to it as a string. type fakeProviderBuilder struct { - name string + name string + buildCh *testutils.Channel } func (b *fakeProviderBuilder) ParseConfig(config interface{}) (*certprovider.BuildableConfig, error) { @@ -197,6 +289,7 @@ func (b *fakeProviderBuilder) ParseConfig(config interface{}) (*certprovider.Bui return nil, fmt.Errorf("providerBuilder %s received config of type %T, want string", b.name, config) } return certprovider.NewBuildableConfig(b.name, []byte(s), func(certprovider.BuildOptions) certprovider.Provider { + b.buildCh.Send(nil) return &fakeProvider{ Distributor: certprovider.NewDistributor(), config: s, @@ -220,19 +313,19 @@ func (p *fakeProvider) Close() { p.Distributor.Stop() } -// setupOverrides sets up overrides for bootstrap config, new xdsClient creation, -// new gRPC.Server creation, and certificate provider creation. -func setupOverrides() (*fakeGRPCServer, *testutils.Channel, *testutils.Channel, func()) { +// setupOverrides sets up overrides for bootstrap config, new xdsClient creation +// and new gRPC.Server creation. +func setupOverrides() (*fakeGRPCServer, *testutils.Channel, func()) { clientCh := testutils.NewChannel() origNewXDSClient := newXDSClient - newXDSClient = func() (xdsClientInterface, error) { + newXDSClient = func() (xdsclient.XDSClient, error) { c := fakeclient.NewClient() c.SetBootstrapConfig(&bootstrap.Config{ - BalancerName: "dummyBalancer", - Creds: grpc.WithTransportCredentials(insecure.NewCredentials()), - NodeProto: xdstestutils.EmptyNodeProtoV3, - ServerResourceNameID: testServerResourceNameID, - CertProviderConfigs: certProviderConfigs, + BalancerName: "dummyBalancer", + Creds: grpc.WithTransportCredentials(insecure.NewCredentials()), + NodeProto: xdstestutils.EmptyNodeProtoV3, + ServerListenerResourceNameTemplate: testServerListenerResourceNameTemplate, + CertProviderConfigs: certProviderConfigs, }) clientCh.Send(c) return c, nil @@ -240,20 +333,11 @@ func setupOverrides() (*fakeGRPCServer, *testutils.Channel, *testutils.Channel, fs := newFakeGRPCServer() origNewGRPCServer := newGRPCServer - newGRPCServer = func(opts ...grpc.ServerOption) grpcServerInterface { return fs } + newGRPCServer = func(opts ...grpc.ServerOption) grpcServer { return fs } - providerCh := testutils.NewChannel() - origBuildProvider := buildProvider - buildProvider = func(c map[string]*certprovider.BuildableConfig, id, cert string, wi, wr bool) (certprovider.Provider, error) { - p, err := origBuildProvider(c, id, cert, wi, wr) - providerCh.Send(nil) - return p, err - } - - return fs, clientCh, providerCh, func() { + return fs, clientCh, func() { newXDSClient = origNewXDSClient newGRPCServer = origNewGRPCServer - buildProvider = origBuildProvider } } @@ -261,16 +345,16 @@ func setupOverrides() (*fakeGRPCServer, *testutils.Channel, *testutils.Channel, // one. Tests that use xdsCredentials need a real grpc.Server instead of a fake // one, because the xDS-enabled server needs to read configured creds from the // underlying grpc.Server to confirm whether xdsCreds were configured. -func setupOverridesForXDSCreds(includeCertProviderCfg bool) (*testutils.Channel, *testutils.Channel, func()) { +func setupOverridesForXDSCreds(includeCertProviderCfg bool) (*testutils.Channel, func()) { clientCh := testutils.NewChannel() origNewXDSClient := newXDSClient - newXDSClient = func() (xdsClientInterface, error) { + newXDSClient = func() (xdsclient.XDSClient, error) { c := fakeclient.NewClient() bc := &bootstrap.Config{ - BalancerName: "dummyBalancer", - Creds: grpc.WithTransportCredentials(insecure.NewCredentials()), - NodeProto: xdstestutils.EmptyNodeProtoV3, - ServerResourceNameID: testServerResourceNameID, + BalancerName: "dummyBalancer", + Creds: grpc.WithTransportCredentials(insecure.NewCredentials()), + NodeProto: xdstestutils.EmptyNodeProtoV3, + ServerListenerResourceNameTemplate: testServerListenerResourceNameTemplate, } if includeCertProviderCfg { bc.CertProviderConfigs = certProviderConfigs @@ -280,18 +364,7 @@ func setupOverridesForXDSCreds(includeCertProviderCfg bool) (*testutils.Channel, return c, nil } - providerCh := testutils.NewChannel() - origBuildProvider := buildProvider - buildProvider = func(c map[string]*certprovider.BuildableConfig, id, cert string, wi, wr bool) (certprovider.Provider, error) { - p, err := origBuildProvider(c, id, cert, wi, wr) - providerCh.Send(nil) - return p, err - } - - return clientCh, providerCh, func() { - newXDSClient = origNewXDSClient - buildProvider = origBuildProvider - } + return clientCh, func() { newXDSClient = origNewXDSClient } } // TestServeSuccess tests the successful case of calling Serve(). @@ -303,10 +376,17 @@ func setupOverridesForXDSCreds(includeCertProviderCfg bool) (*testutils.Channel, // 4. Push a good response from the xdsClient, and make sure that Serve() on the // underlying grpc.Server is called. func (s) TestServeSuccess(t *testing.T) { - fs, clientCh, _, cleanup := setupOverrides() + fs, clientCh, cleanup := setupOverrides() defer cleanup() - server := NewGRPCServer() + // Create a new xDS-enabled gRPC server and pass it a server option to get + // notified about serving mode changes. + modeChangeCh := testutils.NewChannel() + modeChangeOption := ServingModeCallback(func(addr net.Addr, args ServingModeChangeArgs) { + t.Logf("server mode change callback invoked for listener %q with mode %q and error %v", addr.String(), args.Mode, args.Err) + modeChangeCh.Send(args.Mode) + }) + server := NewGRPCServer(modeChangeOption) defer server.Stop() lis, err := xdstestutils.LocalTCPListener() @@ -337,33 +417,89 @@ func (s) TestServeSuccess(t *testing.T) { if err != nil { t.Fatalf("error when waiting for a ListenerWatch: %v", err) } - wantName := fmt.Sprintf("%s?udpa.resource.listening_address=%s", client.BootstrapConfig().ServerResourceNameID, lis.Addr().String()) + wantName := strings.Replace(testServerListenerResourceNameTemplate, "%s", lis.Addr().String(), -1) if name != wantName { t.Fatalf("LDS watch registered for name %q, want %q", name, wantName) } // Push an error to the registered listener watch callback and make sure // that Serve does not return. - client.InvokeWatchListenerCallback(xdsclient.ListenerUpdate{}, errors.New("LDS error")) + client.InvokeWatchListenerCallback(xdsclient.ListenerUpdate{}, xdsclient.NewErrorf(xdsclient.ErrorTypeResourceNotFound, "LDS resource not found")) sCtx, sCancel := context.WithTimeout(context.Background(), defaultTestShortTimeout) defer sCancel() if _, err := serveDone.Receive(sCtx); err != context.DeadlineExceeded { t.Fatal("Serve() returned after a bad LDS response") } + // Make sure the serving mode changes appropriately. + v, err := modeChangeCh.Receive(ctx) + if err != nil { + t.Fatalf("error when waiting for serving mode to change: %v", err) + } + if mode := v.(connectivity.ServingMode); mode != connectivity.ServingModeNotServing { + t.Fatalf("server mode is %q, want %q", mode, connectivity.ServingModeNotServing) + } + // Push a good LDS response, and wait for Serve() to be invoked on the // underlying grpc.Server. - client.InvokeWatchListenerCallback(xdsclient.ListenerUpdate{RouteConfigName: "routeconfig"}, nil) + fcm, err := xdsclient.NewFilterChainManager(listenerWithFilterChains) + if err != nil { + t.Fatalf("xdsclient.NewFilterChainManager() failed with error: %v", err) + } + addr, port := splitHostPort(lis.Addr().String()) + client.InvokeWatchListenerCallback(xdsclient.ListenerUpdate{ + RouteConfigName: "routeconfig", + InboundListenerCfg: &xdsclient.InboundListenerConfig{ + Address: addr, + Port: port, + FilterChains: fcm, + }, + }, nil) if _, err := fs.serveCh.Receive(ctx); err != nil { t.Fatalf("error when waiting for Serve() to be invoked on the grpc.Server") } + + // Make sure the serving mode changes appropriately. + v, err = modeChangeCh.Receive(ctx) + if err != nil { + t.Fatalf("error when waiting for serving mode to change: %v", err) + } + if mode := v.(connectivity.ServingMode); mode != connectivity.ServingModeServing { + t.Fatalf("server mode is %q, want %q", mode, connectivity.ServingModeServing) + } + + // Push an update to the registered listener watch callback with a Listener + // resource whose host:port does not match the actual listening address and + // port. This will push the listener to "not-serving" mode. + client.InvokeWatchListenerCallback(xdsclient.ListenerUpdate{ + RouteConfigName: "routeconfig", + InboundListenerCfg: &xdsclient.InboundListenerConfig{ + Address: "10.20.30.40", + Port: "666", + FilterChains: fcm, + }, + }, nil) + sCtx, sCancel = context.WithTimeout(context.Background(), defaultTestShortTimeout) + defer sCancel() + if _, err := serveDone.Receive(sCtx); err != context.DeadlineExceeded { + t.Fatal("Serve() returned after a bad LDS response") + } + + // Make sure the serving mode changes appropriately. + v, err = modeChangeCh.Receive(ctx) + if err != nil { + t.Fatalf("error when waiting for serving mode to change: %v", err) + } + if mode := v.(connectivity.ServingMode); mode != connectivity.ServingModeNotServing { + t.Fatalf("server mode is %q, want %q", mode, connectivity.ServingModeNotServing) + } } // TestServeWithStop tests the case where Stop() is called before an LDS update // is received. This should cause Serve() to exit before calling Serve() on the // underlying grpc.Server. func (s) TestServeWithStop(t *testing.T) { - fs, clientCh, _, cleanup := setupOverrides() + fs, clientCh, cleanup := setupOverrides() defer cleanup() // Note that we are not deferring the Stop() here since we explicitly call @@ -399,7 +535,7 @@ func (s) TestServeWithStop(t *testing.T) { server.Stop() t.Fatalf("error when waiting for a ListenerWatch: %v", err) } - wantName := fmt.Sprintf("%s?udpa.resource.listening_address=%s", client.BootstrapConfig().ServerResourceNameID, lis.Addr().String()) + wantName := strings.Replace(testServerListenerResourceNameTemplate, "%s", lis.Addr().String(), -1) if name != wantName { server.Stop() t.Fatalf("LDS watch registered for name %q, wantPrefix %q", name, wantName) @@ -448,39 +584,79 @@ func (s) TestServeBootstrapFailure(t *testing.T) { } } -// TestServeBootstrapWithMissingCertProviders tests the case where the bootstrap -// config does not contain certificate provider configuration, but xdsCreds are -// passed to the server. Verifies that the call to Serve() fails. -func (s) TestServeBootstrapWithMissingCertProviders(t *testing.T) { - _, _, cleanup := setupOverridesForXDSCreds(false) - defer cleanup() - - xdsCreds, err := xds.NewServerCredentials(xds.ServerOptions{FallbackCreds: insecure.NewCredentials()}) - if err != nil { - t.Fatalf("failed to create xds server credentials: %v", err) +// TestServeBootstrapConfigInvalid tests the cases where the bootstrap config +// does not contain expected fields. Verifies that the call to Serve() fails. +func (s) TestServeBootstrapConfigInvalid(t *testing.T) { + tests := []struct { + desc string + bootstrapConfig *bootstrap.Config + }{ + { + desc: "bootstrap config is missing", + bootstrapConfig: nil, + }, + { + desc: "certificate provider config is missing", + bootstrapConfig: &bootstrap.Config{ + BalancerName: "dummyBalancer", + Creds: grpc.WithTransportCredentials(insecure.NewCredentials()), + NodeProto: xdstestutils.EmptyNodeProtoV3, + ServerListenerResourceNameTemplate: testServerListenerResourceNameTemplate, + }, + }, + { + desc: "server_listener_resource_name_template is missing", + bootstrapConfig: &bootstrap.Config{ + BalancerName: "dummyBalancer", + Creds: grpc.WithTransportCredentials(insecure.NewCredentials()), + NodeProto: xdstestutils.EmptyNodeProtoV3, + CertProviderConfigs: certProviderConfigs, + }, + }, } - server := NewGRPCServer(grpc.Creds(xdsCreds)) - defer server.Stop() - lis, err := xdstestutils.LocalTCPListener() - if err != nil { - t.Fatalf("xdstestutils.LocalTCPListener() failed: %v", err) - } + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + // Override the xdsClient creation with one that returns a fake + // xdsClient with the specified bootstrap configuration. + clientCh := testutils.NewChannel() + origNewXDSClient := newXDSClient + newXDSClient = func() (xdsclient.XDSClient, error) { + c := fakeclient.NewClient() + c.SetBootstrapConfig(test.bootstrapConfig) + clientCh.Send(c) + return c, nil + } + defer func() { newXDSClient = origNewXDSClient }() - serveDone := testutils.NewChannel() - go func() { - err := server.Serve(lis) - serveDone.Send(err) - }() + xdsCreds, err := xds.NewServerCredentials(xds.ServerOptions{FallbackCreds: insecure.NewCredentials()}) + if err != nil { + t.Fatalf("failed to create xds server credentials: %v", err) + } + server := NewGRPCServer(grpc.Creds(xdsCreds)) + defer server.Stop() - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() - v, err := serveDone.Receive(ctx) - if err != nil { - t.Fatalf("error when waiting for Serve() to exit: %v", err) - } - if err, ok := v.(error); !ok || err == nil { - t.Fatal("Serve() did not exit with error") + lis, err := xdstestutils.LocalTCPListener() + if err != nil { + t.Fatalf("xdstestutils.LocalTCPListener() failed: %v", err) + } + + serveDone := testutils.NewChannel() + go func() { + err := server.Serve(lis) + serveDone.Send(err) + }() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + v, err := serveDone.Receive(ctx) + if err != nil { + t.Fatalf("error when waiting for Serve() to exit: %v", err) + } + if err, ok := v.(error); !ok || err == nil { + t.Fatal("Serve() did not exit with error") + } + }) } } @@ -488,7 +664,7 @@ func (s) TestServeBootstrapWithMissingCertProviders(t *testing.T) { // verifies that Server() exits with a non-nil error. func (s) TestServeNewClientFailure(t *testing.T) { origNewXDSClient := newXDSClient - newXDSClient = func() (xdsClientInterface, error) { + newXDSClient = func() (xdsclient.XDSClient, error) { return nil, errors.New("xdsClient creation failed") } defer func() { newXDSClient = origNewXDSClient }() @@ -522,7 +698,7 @@ func (s) TestServeNewClientFailure(t *testing.T) { // server is not configured with xDS credentials. Verifies that the security // config received as part of a Listener update is not acted upon. func (s) TestHandleListenerUpdate_NoXDSCreds(t *testing.T) { - fs, clientCh, providerCh, cleanup := setupOverrides() + fs, clientCh, cleanup := setupOverrides() defer cleanup() server := NewGRPCServer() @@ -556,7 +732,7 @@ func (s) TestHandleListenerUpdate_NoXDSCreds(t *testing.T) { if err != nil { t.Fatalf("error when waiting for a ListenerWatch: %v", err) } - wantName := fmt.Sprintf("%s?udpa.resource.listening_address=%s", client.BootstrapConfig().ServerResourceNameID, lis.Addr().String()) + wantName := strings.Replace(testServerListenerResourceNameTemplate, "%s", lis.Addr().String(), -1) if name != wantName { t.Fatalf("LDS watch registered for name %q, want %q", name, wantName) } @@ -564,12 +740,57 @@ func (s) TestHandleListenerUpdate_NoXDSCreds(t *testing.T) { // Push a good LDS response with security config, and wait for Serve() to be // invoked on the underlying grpc.Server. Also make sure that certificate // providers are not created. + fcm, err := xdsclient.NewFilterChainManager(&v3listenerpb.Listener{ + FilterChains: []*v3listenerpb.FilterChain{ + { + TransportSocket: &v3corepb.TransportSocket{ + Name: "envoy.transport_sockets.tls", + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3tlspb.DownstreamTlsContext{ + CommonTlsContext: &v3tlspb.CommonTlsContext{ + TlsCertificateCertificateProviderInstance: &v3tlspb.CommonTlsContext_CertificateProviderInstance{ + InstanceName: "identityPluginInstance", + CertificateName: "identityCertName", + }, + }, + }), + }, + }, + Filters: []*v3listenerpb.Filter{ + { + Name: "filter-1", + ConfigType: &v3listenerpb.Filter_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3httppb.HttpConnectionManager{ + RouteSpecifier: &v3httppb.HttpConnectionManager_RouteConfig{ + RouteConfig: &v3routepb.RouteConfiguration{ + Name: "routeName", + VirtualHosts: []*v3routepb.VirtualHost{{ + Domains: []string{"lds.target.good:3333"}, + Routes: []*v3routepb.Route{{ + Match: &v3routepb.RouteMatch{ + PathSpecifier: &v3routepb.RouteMatch_Prefix{Prefix: "/"}, + }, + Action: &v3routepb.Route_NonForwardingAction{}, + }}}}}, + }, + HttpFilters: []*v3httppb.HttpFilter{e2e.RouterHTTPFilter}, + }), + }, + }, + }, + }, + }, + }) + if err != nil { + t.Fatalf("xdsclient.NewFilterChainManager() failed with error: %v", err) + } + addr, port := splitHostPort(lis.Addr().String()) client.InvokeWatchListenerCallback(xdsclient.ListenerUpdate{ RouteConfigName: "routeconfig", - SecurityCfg: &xdsclient.SecurityConfig{ - RootInstanceName: "default1", - IdentityInstanceName: "default2", - RequireClientCert: true, + InboundListenerCfg: &xdsclient.InboundListenerConfig{ + Address: addr, + Port: port, + FilterChains: fcm, }, }, nil) if _, err := fs.serveCh.Receive(ctx); err != nil { @@ -577,10 +798,8 @@ func (s) TestHandleListenerUpdate_NoXDSCreds(t *testing.T) { } // Make sure the security configuration is not acted upon. - sCtx, sCancel := context.WithTimeout(context.Background(), defaultTestShortTimeout) - defer sCancel() - if _, err := providerCh.Receive(sCtx); err != context.DeadlineExceeded { - t.Fatalf("certificate provider created when no xDS creds were specified") + if err := verifyCertProviderNotCreated(); err != nil { + t.Fatal(err) } } @@ -588,7 +807,7 @@ func (s) TestHandleListenerUpdate_NoXDSCreds(t *testing.T) { // server is configured with xDS credentials, but receives a Listener update // with an error. Verifies that no certificate providers are created. func (s) TestHandleListenerUpdate_ErrorUpdate(t *testing.T) { - clientCh, providerCh, cleanup := setupOverridesForXDSCreds(true) + clientCh, cleanup := setupOverridesForXDSCreds(true) defer cleanup() xdsCreds, err := xds.NewServerCredentials(xds.ServerOptions{FallbackCreds: insecure.NewCredentials()}) @@ -627,21 +846,14 @@ func (s) TestHandleListenerUpdate_ErrorUpdate(t *testing.T) { if err != nil { t.Fatalf("error when waiting for a ListenerWatch: %v", err) } - wantName := fmt.Sprintf("%s?udpa.resource.listening_address=%s", client.BootstrapConfig().ServerResourceNameID, lis.Addr().String()) + wantName := strings.Replace(testServerListenerResourceNameTemplate, "%s", lis.Addr().String(), -1) if name != wantName { t.Fatalf("LDS watch registered for name %q, want %q", name, wantName) } // Push an error to the registered listener watch callback and make sure // that Serve does not return. - client.InvokeWatchListenerCallback(xdsclient.ListenerUpdate{ - RouteConfigName: "routeconfig", - SecurityCfg: &xdsclient.SecurityConfig{ - RootInstanceName: "default1", - IdentityInstanceName: "default2", - RequireClientCert: true, - }, - }, errors.New("LDS error")) + client.InvokeWatchListenerCallback(xdsclient.ListenerUpdate{}, errors.New("LDS error")) sCtx, sCancel := context.WithTimeout(context.Background(), defaultTestShortTimeout) defer sCancel() if _, err := serveDone.Receive(sCtx); err != context.DeadlineExceeded { @@ -649,84 +861,21 @@ func (s) TestHandleListenerUpdate_ErrorUpdate(t *testing.T) { } // Also make sure that no certificate providers are created. - sCtx, sCancel = context.WithTimeout(context.Background(), defaultTestShortTimeout) - defer sCancel() - if _, err := providerCh.Receive(sCtx); err != context.DeadlineExceeded { - t.Fatalf("certificate provider created when no xDS creds were specified") + if err := verifyCertProviderNotCreated(); err != nil { + t.Fatal(err) } } -func (s) TestHandleListenerUpdate_ClosedListener(t *testing.T) { - clientCh, providerCh, cleanup := setupOverridesForXDSCreds(true) - defer cleanup() - - xdsCreds, err := xds.NewServerCredentials(xds.ServerOptions{FallbackCreds: insecure.NewCredentials()}) - if err != nil { - t.Fatalf("failed to create xds server credentials: %v", err) - } - - server := NewGRPCServer(grpc.Creds(xdsCreds)) - defer server.Stop() - - lis, err := xdstestutils.LocalTCPListener() - if err != nil { - t.Fatalf("xdstestutils.LocalTCPListener() failed: %v", err) - } - - // Call Serve() in a goroutine, and push on a channel when Serve returns. - serveDone := testutils.NewChannel() - go func() { serveDone.Send(server.Serve(lis)) }() - - // Wait for an xdsClient to be created. - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() - c, err := clientCh.Receive(ctx) - if err != nil { - t.Fatalf("error when waiting for new xdsClient to be created: %v", err) - } - client := c.(*fakeclient.Client) - - // Wait for a listener watch to be registered on the xdsClient. - name, err := client.WaitForWatchListener(ctx) - if err != nil { - t.Fatalf("error when waiting for a ListenerWatch: %v", err) - } - wantName := fmt.Sprintf("%s?udpa.resource.listening_address=%s", client.BootstrapConfig().ServerResourceNameID, lis.Addr().String()) - if name != wantName { - t.Fatalf("LDS watch registered for name %q, want %q", name, wantName) - } - - // Push a good update to the registered listener watch callback. This will - // unblock the xds-enabled server which is waiting for a good listener - // update before calling Serve() on the underlying grpc.Server. - client.InvokeWatchListenerCallback(xdsclient.ListenerUpdate{ - RouteConfigName: "routeconfig", - SecurityCfg: &xdsclient.SecurityConfig{IdentityInstanceName: "default2"}, - }, nil) - if _, err := providerCh.Receive(ctx); err != nil { - t.Fatal("error when waiting for certificate provider to be created") - } - - // Close the listener passed to Serve(), and wait for the latter to return a - // non-nil error. - lis.Close() - v, err := serveDone.Receive(ctx) - if err != nil { - t.Fatalf("error when waiting for Serve() to exit: %v", err) - } - if err, ok := v.(error); !ok || err == nil { - t.Fatal("Serve() did not exit with error") - } - - // Push another listener update and make sure that no certificate providers - // are created. - client.InvokeWatchListenerCallback(xdsclient.ListenerUpdate{ - RouteConfigName: "routeconfig", - SecurityCfg: &xdsclient.SecurityConfig{IdentityInstanceName: "default1"}, - }, nil) +func verifyCertProviderNotCreated() error { sCtx, sCancel := context.WithTimeout(context.Background(), defaultTestShortTimeout) defer sCancel() - if _, err := providerCh.Receive(sCtx); err != context.DeadlineExceeded { - t.Fatalf("certificate provider created when no xDS creds were specified") + if _, err := fpb1.buildCh.Receive(sCtx); err != context.DeadlineExceeded { + return errors.New("certificate provider created when no xDS creds were specified") } + sCtx, sCancel = context.WithTimeout(context.Background(), defaultTestShortTimeout) + defer sCancel() + if _, err := fpb2.buildCh.Receive(sCtx); err != context.DeadlineExceeded { + return errors.New("certificate provider created when no xDS creds were specified") + } + return nil } diff --git a/xds/xds.go b/xds/xds.go index 5deafd130a2..27547b56d22 100644 --- a/xds/xds.go +++ b/xds/xds.go @@ -28,9 +28,67 @@ package xds import ( + "fmt" + + v3statusgrpc "github.com/envoyproxy/go-control-plane/envoy/service/status/v3" + "google.golang.org/grpc" + internaladmin "google.golang.org/grpc/internal/admin" + "google.golang.org/grpc/resolver" + "google.golang.org/grpc/xds/csds" + _ "google.golang.org/grpc/credentials/tls/certprovider/pemfile" // Register the file watcher certificate provider plugin. _ "google.golang.org/grpc/xds/internal/balancer" // Register the balancers. - _ "google.golang.org/grpc/xds/internal/client/v2" // Register the v2 xDS API client. - _ "google.golang.org/grpc/xds/internal/client/v3" // Register the v3 xDS API client. - _ "google.golang.org/grpc/xds/internal/resolver" // Register the xds_resolver. + _ "google.golang.org/grpc/xds/internal/httpfilter/fault" // Register the fault injection filter. + _ "google.golang.org/grpc/xds/internal/httpfilter/rbac" // Register the RBAC filter. + _ "google.golang.org/grpc/xds/internal/httpfilter/router" // Register the router filter. + xdsresolver "google.golang.org/grpc/xds/internal/resolver" // Register the xds_resolver. + _ "google.golang.org/grpc/xds/internal/xdsclient/v2" // Register the v2 xDS API client. + _ "google.golang.org/grpc/xds/internal/xdsclient/v3" // Register the v3 xDS API client. ) + +func init() { + internaladmin.AddService(func(registrar grpc.ServiceRegistrar) (func(), error) { + var grpcServer *grpc.Server + switch ss := registrar.(type) { + case *grpc.Server: + grpcServer = ss + case *GRPCServer: + sss, ok := ss.gs.(*grpc.Server) + if !ok { + logger.Warningf("grpc server within xds.GRPCServer is not *grpc.Server, CSDS will not be registered") + return nil, nil + } + grpcServer = sss + default: + // Returning an error would cause the top level admin.Register() to + // fail. Log a warning instead. + logger.Warningf("server to register service on is neither a *grpc.Server or a *xds.GRPCServer, CSDS will not be registered") + return nil, nil + } + + csdss, err := csds.NewClientStatusDiscoveryServer() + if err != nil { + return nil, fmt.Errorf("failed to create csds server: %v", err) + } + v3statusgrpc.RegisterClientStatusDiscoveryServiceServer(grpcServer, csdss) + return csdss.Close, nil + }) +} + +// NewXDSResolverWithConfigForTesting creates a new xds resolver builder using +// the provided xds bootstrap config instead of the global configuration from +// the supported environment variables. The resolver.Builder is meant to be +// used in conjunction with the grpc.WithResolvers DialOption. +// +// Testing Only +// +// This function should ONLY be used for testing and may not work with some +// other features, including the CSDS service. +// +// Experimental +// +// Notice: This API is EXPERIMENTAL and may be changed or removed in a +// later release. +func NewXDSResolverWithConfigForTesting(bootstrapConfig []byte) (resolver.Builder, error) { + return xdsresolver.NewBuilder(bootstrapConfig) +}